In this tutorial, we describe how to perform multi-task Bayesian optimization over composite functions. In these problems, there are several related outputs and an overall, easy-to-evaluate objective function of those outputs that we wish to maximize.
Multi-task Bayesian optimization was first proposed by Swersky et al. (NeurIPS '13) in the context of fast hyperparameter tuning for neural network models. Here, we demonstrate a more advanced use case: composite Bayesian optimization, where the overall function that we wish to optimize is a cheap-to-evaluate (and known) function of the outputs. In general, we expect that using more information about the function should improve performance when optimizing it, particularly if the metric function itself varies quickly.
See the composite BO tutorial with the higher-order GP (HOGP) for a more technical introduction. In general, we suggest using MTGPs for unstructured task outputs and the HOGP for matrix- or tensor-structured outputs.
We will use a multi-task Gaussian process (MTGP) with an intrinsic coregionalization model (ICM) kernel to model all of the outputs in this problem. MTGPs are available in BoTorch through the botorch.models.KroneckerMultiTaskGP model class (for the "block design" case, where all outputs are observed at every input). Given $T$ tasks (outputs) and $n$ data points, the model assumes that the responses $Y \in \mathbb{R}^{n \times T}$ are distributed as $\text{vec}(Y) \sim \mathcal{N}(f, D)$ with $f \sim \mathcal{GP}(\mu_{\theta}, K_{XX} \otimes K_{T})$, where $D$ is a (diagonal) noise term.
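To make the covariance structure concrete, here is a small pure-PyTorch sketch (not BoTorch's implementation) that builds $K_{XX} \otimes K_T$ for a toy problem. It assumes an RBF data kernel, a rank-one-plus-diagonal ICM task covariance, and homoskedastic noise $D = \sigma^2 I$; the `rbf` helper, the sizes, and the noise level are all illustrative assumptions. It also shows one reason the Kronecker structure helps: the eigendecomposition of the $nT \times nT$ covariance follows from the eigendecompositions of the $n \times n$ and $T \times T$ factors.

```python
import torch

gen = torch.Generator().manual_seed(0)   # local RNG so the notebook's global seed is untouched
n, T, noise = 6, 4, 1e-2                 # toy sizes and noise variance (assumptions)
X = torch.rand(n, 2, generator=gen, dtype=torch.double)


def rbf(A: torch.Tensor, B: torch.Tensor, lengthscale: float = 0.3) -> torch.Tensor:
    """Squared-exponential kernel between the rows of A and B (hypothetical helper)."""
    return torch.exp(-0.5 * torch.cdist(A, B).pow(2) / lengthscale**2)


K_XX = rbf(X, X)
# ICM task covariance: low-rank part W W^T plus a positive diagonal
W = torch.randn(T, 1, generator=gen, dtype=torch.double)
K_T = W @ W.T + torch.diag(torch.rand(T, generator=gen, dtype=torch.double) + 0.1)

# Covariance of vec(Y) with the rows of Y stacked (entry (i, t) sits at index i * T + t)
K_full = torch.kron(K_XX, K_T) + noise * torch.eye(n * T, dtype=torch.double)

# Eigendecompose the small factors instead of the nT x nT matrix
evals_X, Q_X = torch.linalg.eigh(K_XX)
evals_T, Q_T = torch.linalg.eigh(K_T)
Q = torch.kron(Q_X, Q_T)                      # eigenvectors of kron(K_XX, K_T)
lam = torch.kron(evals_X, evals_T) + noise    # matching eigenvalues, shifted by the noise
K_full_inv = Q @ torch.diag(1.0 / lam) @ Q.T  # (kron(K_XX, K_T) + noise * I)^{-1}

# One draw of Y from N(0, K_full), reshaped back to an n x T matrix
z = torch.randn(n * T, generator=gen, dtype=torch.double)
Y = (torch.linalg.cholesky(K_full) @ z).reshape(n, T)
```

For clarity the sketch still forms `Q` explicitly; a scalable implementation would instead apply the two small factors to the data without ever materializing an $nT \times nT$ matrix.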
import torch
import os
import time
from botorch.test_functions import Hartmann
from botorch.models import SingleTaskGP, KroneckerMultiTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.sampling import IIDNormalSampler
from botorch.optim import optimize_acqf
from botorch.optim.fit import fit_gpytorch_torch
from botorch.acquisition.objective import GenericMCObjective
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.models.transforms.outcome import Standardize
from botorch.models.transforms.input import Normalize
SMOKE_TEST = os.environ.get("SMOKE_TEST")
torch.random.manual_seed(2010)
if torch.cuda.is_available():
    torch.cuda.set_device("cuda:0")
tkwargs = {
    "dtype": torch.double,
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
}
The function that we wish to optimize is based on a contextual version of the Hartmann-6 test function where, following Feng et al. (NeurIPS '20), we convert the sixth input dimension into a task indicator. We assume that all $20$ contexts (tasks) are evaluated at once for each input.
from typing import Optional

from botorch.test_functions import Hartmann
from torch import Tensor


class ContextualHartmann6(Hartmann):
    def __init__(self, num_tasks: int = 20, noise_std: Optional[float] = None, negate: bool = False):
        super().__init__(dim=6, noise_std=noise_std, negate=negate)
        # task values: num_tasks evenly spaced points in [0, 1], shape (num_tasks, 1)
        self.task_range = torch.linspace(0, 1, num_tasks).unsqueeze(-1)
        # the optimizer only sees the first five (non-task) input dimensions
        self._bounds = [(0.0, 1.0) for _ in range(self.dim - 1)]
        self.bounds = torch.tensor(self._bounds).t()

    def evaluate_true(self, X: Tensor) -> Tensor:
        # X: (..., 5) -> append every task value as the sixth input -> output (..., num_tasks)
        batch_X = X.unsqueeze(-2)
        batch_dims = X.ndim - 1
        expanded_task_range = self.task_range
        for _ in range(batch_dims):
            expanded_task_range = expanded_task_range.unsqueeze(0)
        task_range = expanded_task_range.repeat(*X.shape[:-1], 1, 1).to(X)
        concatenated_X = torch.cat(
            (batch_X.repeat(*[1] * batch_dims, self.task_range.shape[0], 1), task_range), dim=-1
        )
        return super().evaluate_true(concatenated_X)
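The `repeat` calls in `evaluate_true` copy the task grid into every batch position. As an illustration only, the hypothetical helper `append_task_column` below sketches the same reshaping with broadcasting via `expand`: inputs of shape `(..., 5)` become `(..., num_tasks, 6)`, with the last column holding the task value. It is not part of BoTorch.

```python
import torch


def append_task_column(X: torch.Tensor, num_tasks: int = 20) -> torch.Tensor:
    """Map X of shape (..., d) to (..., num_tasks, d + 1) (hypothetical helper).

    Entry k along the new task dimension is X with the k-th task value appended,
    where task values are evenly spaced on [0, 1] as in ContextualHartmann6.
    """
    tasks = torch.linspace(0, 1, num_tasks, dtype=X.dtype, device=X.device)
    batch_shape = X.shape[:-1]
    # (..., 1, d) -> (..., num_tasks, d); expand creates a view, cat makes the copy
    X_rep = X.unsqueeze(-2).expand(*batch_shape, num_tasks, X.shape[-1])
    # (num_tasks, 1) -> (..., num_tasks, 1)
    t_rep = tasks.unsqueeze(-1).expand(*batch_shape, num_tasks, 1)
    return torch.cat((X_rep, t_rep), dim=-1)


gen = torch.Generator().manual_seed(0)
X = torch.rand(4, 3, 5, generator=gen)   # e.g. 4 restarts x q=3 candidates x 5 inputs
Z = append_task_column(X)
print(Z.shape)  # torch.Size([4, 3, 20, 6])
```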
We use GenericMCObjective to define the differentiable function that we are optimizing. Here, it is defined as
$$g(f) = -\sum_{i=1}^T \cos(f_i^2 + f_i w_i),$$
where $w$ is a weight vector (drawn randomly once at the start of the optimization). Because $g$ is a non-linear function of the outputs $f$, we cannot compute the acquisition function from the posterior mean and variance alone; instead, we draw posterior samples of $f$ and evaluate the acquisition function with Monte Carlo sampling.
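A minimal pure-PyTorch sketch of this Monte Carlo estimate follows. It assumes we already have posterior draws of all $T$ outputs at $q$ candidate points (faked here with Gaussian noise around a random mean), applies $g$ to each draw, takes the best improvement over `best_f` within each draw, and averages over draws. The names `mc_composite_qei` and `toy_weights` and the toy tensors are hypothetical; BoTorch's `qExpectedImprovement` performs the analogous computation on samples drawn from the model posterior.

```python
import torch

gen = torch.Generator().manual_seed(0)
num_samples, q, T = 256, 3, 20
toy_weights = torch.randn(T, generator=gen)


def g(f: torch.Tensor) -> torch.Tensor:
    """Composite objective g(f) = -sum_i cos(f_i^2 + f_i * w_i) over the last dimension."""
    return -torch.cos(f**2 + f * toy_weights).sum(dim=-1)


def mc_composite_qei(samples: torch.Tensor, best_f: float) -> torch.Tensor:
    """samples: (num_samples, q, T) posterior draws of the T outputs at q points."""
    obj = g(samples)                              # (num_samples, q)
    improvement = (obj - best_f).clamp_min(0.0)   # improvement of each point in each draw
    return improvement.max(dim=-1).values.mean()  # best point per draw, averaged over draws


# Stand-in for posterior samples: a mean plus independent Gaussian noise (assumption)
mean = torch.randn(q, T, generator=gen)
samples = mean + 0.3 * torch.randn(num_samples, q, T, generator=gen)

print(mc_composite_qei(samples, best_f=-5.0))
# Because g is non-linear, g(posterior mean) generally differs from the mean of g(samples)
print(g(mean), g(samples).mean(dim=0))
```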
For more than roughly $10$ tasks, sampling the posterior over all tasks jointly is computationally challenging with conventional approaches. However, Maddox et al. ('21) devised an efficient method that exploits the structure of the MTGP posterior distribution, enabling efficient MC-based optimization of objectives using MTGPs.
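The sketch below illustrates the general idea of exploiting Kronecker structure when sampling; it is not the specific method of Maddox et al. If $A = L_A L_A^\top$ and $B = L_B L_B^\top$, then $L_A \otimes L_B$ is the Cholesky factor of $A \otimes B$, so a draw with covariance $A \otimes B$ can be computed as $L_A Z L_B^\top$ for a standard normal $n \times T$ matrix $Z$, without forming the $nT \times nT$ matrix. The sizes and the `random_spd` helper are illustrative assumptions.

```python
import torch

gen = torch.Generator().manual_seed(0)
n, T = 50, 20


def random_spd(k: int) -> torch.Tensor:
    """Random symmetric positive-definite k x k matrix (hypothetical helper)."""
    M = torch.randn(k, k, generator=gen, dtype=torch.double)
    return M @ M.T + k * torch.eye(k, dtype=torch.double)


A, B = random_spd(n), random_spd(T)   # stand-ins for data and task covariances
L_A, L_B = torch.linalg.cholesky(A), torch.linalg.cholesky(B)

Z = torch.randn(n, T, generator=gen, dtype=torch.double)
# Structured draw: factorizations cost O(n^3 + T^3) instead of O((nT)^3)
sample_structured = L_A @ Z @ L_B.T   # shape (n, T)

# Naive draw through the full 1000 x 1000 Cholesky factor, for comparison only
L_full = torch.linalg.cholesky(torch.kron(A, B))
sample_naive = (L_full @ Z.reshape(-1)).reshape(n, T)
```

Both draws agree because `Z.reshape(-1)` stacks the rows of `Z`, which matches the row-major ordering used by `torch.kron`.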
problem = ContextualHartmann6(noise_std=0.001, negate=True).to(**tkwargs)

# we choose 20 random weights, one per task (output)
weights = torch.randn(20, **tkwargs)


def callable_func(samples, X=None):
    # samples: (..., T) outputs; returns g(f) = -sum_i cos(f_i^2 + f_i * w_i) over the last dim
    res = -torch.cos((samples**2) + samples * weights)
    return res.sum(dim=-1)


objective = GenericMCObjective(callable_func)
bounds = problem.bounds
We define helper functions for constructing the batch expected improvement (qEI) acquisition function and for optimizing it; both are used for the batch GP and for the MTGP.
def optimize_acqf_and_get_candidate(acq_func, bounds, batch_size):
    """Optimizes the acquisition function and returns a batch of new candidate points."""
    # optimize
    candidates, _ = optimize_acqf(
        acq_function=acq_func,
        bounds=bounds,
        q=batch_size,
        num_restarts=10,
        raw_samples=512,  # used for initialization heuristic
        options={"batch_limit": 5, "maxiter": 200, "init_batch_limit": 5},
    )
    # detach the candidates from the autograd graph; the caller evaluates them
    new_x = candidates.detach()
    return new_x


def construct_acqf(model, objective, num_samples, best_f):
    # qEI estimated by Monte Carlo with num_samples i.i.d. normal base samples
    sampler = IIDNormalSampler(num_samples=num_samples)
    qEI = qExpectedImprovement(
        model=model,
        best_f=best_f,
        sampler=sampler,
        objective=objective,
    )
    return qEI
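optimize_acqf uses a multi-start strategy: it scores `raw_samples` random points, starts `num_restarts` gradient-based optimizations from promising ones, and returns the best result. The sketch below imitates that idea on a toy function with plain projected gradient ascent. It simply keeps the top-scoring raw points, which is a simplification of BoTorch's own initialization heuristic and optimizer (L-BFGS-B by default), and the names `multistart_maximize`, `toy`, and `toy_bounds` are hypothetical.

```python
import torch


def multistart_maximize(f, bounds, raw_samples=512, num_restarts=10, steps=200, lr=0.1, generator=None):
    """Maximize f over a box with multi-start projected gradient ascent (hypothetical toy sketch).

    f maps a (b, d) tensor to (b,) values; bounds is a 2 x d tensor [lower; upper].
    """
    lower, upper = bounds[0], bounds[1]
    d = bounds.shape[-1]
    # 1) Score random raw samples and keep the best ones as starting points
    raw = lower + (upper - lower) * torch.rand(raw_samples, d, generator=generator, dtype=bounds.dtype)
    with torch.no_grad():
        top = f(raw).topk(num_restarts).indices
    x = raw[top].clone().requires_grad_(True)
    # 2) Gradient ascent from every start in parallel, projecting back onto the box
    opt = torch.optim.SGD([x], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = -f(x).sum()   # minimizing -f maximizes f
        loss.backward()
        opt.step()
        with torch.no_grad():
            x.copy_(torch.maximum(torch.minimum(x, upper), lower))
    # 3) Return the best of the optimized starts
    with torch.no_grad():
        values = f(x)
    best = values.argmax()
    return x[best].detach(), values[best]


def toy(X):
    """Toy objective, maximized at (0.3, 0.3)."""
    return -((X - 0.3) ** 2).sum(dim=-1)


toy_bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
x_best, v_best = multistart_maximize(toy, toy_bounds, generator=torch.Generator().manual_seed(0))
print(x_best, v_best)
```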
Next, we set the experiment parameters: we use 20 initial data points and optimize for 20 steps, with a batch size of 3 candidate points per step (the SMOKE_TEST settings are much smaller and are only meant for quick testing).
if SMOKE_TEST:
    n_init = 5
    n_steps = 1
    batch_size = 2
    num_samples = 4
    n_trials = 4
    verbose = False
else:
    n_init = 20
    n_steps = 20
    batch_size = 3
    num_samples = 128
    n_trials = 3
    verbose = True
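As a quick check on the evaluation budget these settings imply (plain arithmetic on the full, non-smoke-test values), each method sees `n_init + n_steps * batch_size` inputs per trial, each returning all 20 task outputs, while the results plot at the end covers only the `n_steps * batch_size` points chosen after initialization. The helper `evaluation_budget` is hypothetical and takes its values as arguments so that it does not overwrite the notebook variables.

```python
def evaluation_budget(n_init, n_steps, batch_size, num_tasks=20, num_methods=3):
    """Per-trial query counts for the experiment settings (hypothetical helper)."""
    per_method = n_init + n_steps * batch_size
    return {
        "per_method": per_method,                      # inputs seen by each method
        "after_init": n_steps * batch_size,            # points on the plot's x-axis
        "outputs_per_method": per_method * num_tasks,  # scalar observations per method
        # the initial design is evaluated once and shared by all methods
        "inputs_evaluated": n_init + num_methods * n_steps * batch_size,
    }


print(evaluation_budget(n_init=20, n_steps=20, batch_size=3))
# {'per_method': 80, 'after_init': 60, 'outputs_per_method': 1600, 'inputs_evaluated': 200}
```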
Warning: this optimization loop can take a while, especially on a CPU. We compare against both random sampling and a batch GP (a SingleTaskGP with an independent output for every task) that is used in the same composite manner. The batch GP does not take into account any correlations between the different tasks.
mtgp_trial_objectives = []
batch_trial_objectives = []
rand_trial_objectives = []

for trial in range(n_trials):
    init_x = (bounds[1] - bounds[0]) * torch.rand(n_init, bounds.shape[1], **tkwargs) + bounds[0]
    init_y = problem(init_x)

    mtgp_train_x, mtgp_train_y = init_x, init_y
    batch_train_x, batch_train_y = init_x, init_y
    rand_x, rand_y = init_x, init_y

    best_value_mtgp = objective(init_y).max()
    best_value_batch = best_value_mtgp
    best_random = best_value_mtgp

    for iteration in range(n_steps):
        # we empty the cache to clear memory out
        torch.cuda.empty_cache()

        mtgp_t0 = time.time()
        mtgp = KroneckerMultiTaskGP(
            mtgp_train_x,
            mtgp_train_y,
        )
        mtgp_mll = ExactMarginalLogLikelihood(mtgp.likelihood, mtgp)
        fit_gpytorch_torch(mtgp_mll, options={"maxiter": 3000, "lr": 0.01, "disp": False})
        mtgp_acqf = construct_acqf(mtgp, objective, num_samples, best_value_mtgp)
        new_mtgp_x = optimize_acqf_and_get_candidate(mtgp_acqf, bounds, batch_size)
        mtgp_t1 = time.time()

        batch_t0 = time.time()
        batchgp = SingleTaskGP(
            batch_train_x,
            batch_train_y,
        )
        batch_mll = ExactMarginalLogLikelihood(batchgp.likelihood, batchgp)
        fit_gpytorch_torch(batch_mll, options={"maxiter": 3000, "lr": 0.01, "disp": False})
        batch_acqf = construct_acqf(batchgp, objective, num_samples, best_value_batch)
        new_batch_x = optimize_acqf_and_get_candidate(batch_acqf, bounds, batch_size)
        batch_t1 = time.time()

        mtgp_train_x = torch.cat((mtgp_train_x, new_mtgp_x), dim=0)
        batch_train_x = torch.cat((batch_train_x, new_batch_x), dim=0)
        mtgp_train_y = torch.cat((mtgp_train_y, problem(new_mtgp_x)), dim=0)
        batch_train_y = torch.cat((batch_train_y, problem(new_batch_x)), dim=0)

        best_value_mtgp = objective(mtgp_train_y).max()
        best_value_batch = objective(batch_train_y).max()

        new_rand_x = (bounds[1] - bounds[0]) * torch.rand(batch_size, bounds.shape[1], **tkwargs) + bounds[0]
        rand_x = torch.cat((rand_x, new_rand_x))
        rand_y = torch.cat((rand_y, problem(new_rand_x)))
        best_random = objective(rand_y).max()

        if verbose:
            print(
                f"\nBatch {iteration:>2}: best_value (random, mtgp, batch) = "
                f"({best_random:>4.2f}, {best_value_mtgp:>4.2f}, {best_value_batch:>4.2f}), "
                f"batch time = {batch_t1-batch_t0:>4.2f}, mtgp time = {mtgp_t1-mtgp_t0:>4.2f}", end=""
            )
        else:
            print(".", end="")

    mtgp_trial_objectives.append(objective(mtgp_train_y).detach().cpu())
    batch_trial_objectives.append(objective(batch_train_y).detach().cpu())
    rand_trial_objectives.append(objective(rand_y).detach().cpu())
Batch 0: best_value (random, mtgp, batch) = (-9.52, -9.52, -9.52), batch time = 11.39, mtgp time = 50.09
Batch 1: best_value (random, mtgp, batch) = (-9.52, -9.52, -9.52), batch time = 11.67, mtgp time = 22.61
Batch 2: best_value (random, mtgp, batch) = (-9.52, -9.52, -9.52), batch time = 12.64, mtgp time = 28.46
Batch 3: best_value (random, mtgp, batch) = (-9.52, -9.52, -9.52), batch time = 9.27, mtgp time = 43.40
Batch 4: best_value (random, mtgp, batch) = (-9.52, -5.68, -9.52), batch time = 4.93, mtgp time = 30.97
Batch 5: best_value (random, mtgp, batch) = (-9.52, -5.68, -9.52), batch time = 6.38, mtgp time = 13.53
Batch 6: best_value (random, mtgp, batch) = (-9.52, -5.68, -9.52), batch time = 6.23, mtgp time = 42.83
Batch 7: best_value (random, mtgp, batch) = (-9.24, -4.61, -9.52), batch time = 6.24, mtgp time = 14.94
Batch 8: best_value (random, mtgp, batch) = (-9.24, -4.47, -9.52), batch time = 9.92, mtgp time = 21.80
Batch 9: best_value (random, mtgp, batch) = (-9.24, -4.47, -6.24), batch time = 14.26, mtgp time = 30.95
Batch 10: best_value (random, mtgp, batch) = (-9.24, -3.26, -4.54), batch time = 12.88, mtgp time = 26.41
Batch 11: best_value (random, mtgp, batch) = (-9.24, -3.26, -3.57), batch time = 15.02, mtgp time = 27.82
Batch 12: best_value (random, mtgp, batch) = (-9.24, -3.26, -2.50), batch time = 13.59, mtgp time = 27.51
Batch 13: best_value (random, mtgp, batch) = (-9.24, -2.75, -2.50), batch time = 24.31, mtgp time = 33.26
Batch 14: best_value (random, mtgp, batch) = (-9.24, -2.75, -2.50), batch time = 11.34, mtgp time = 28.75
Batch 15: best_value (random, mtgp, batch) = (-9.24, -2.75, -2.50), batch time = 11.85, mtgp time = 39.74
Batch 16: best_value (random, mtgp, batch) = (-9.24, -2.35, -2.31), batch time = 13.90, mtgp time = 29.78
Batch 17: best_value (random, mtgp, batch) = (-9.24, -2.35, -2.31), batch time = 11.93, mtgp time = 29.11
Batch 18: best_value (random, mtgp, batch) = (-9.24, -2.35, -2.31), batch time = 11.93, mtgp time = 30.58
Batch 19: best_value (random, mtgp, batch) = (-8.54, -2.35, -2.31), batch time = 6.24, mtgp time = 27.86
Batch 0: best_value (random, mtgp, batch) = (-6.34, -6.34, -6.34), batch time = 12.88, mtgp time = 36.25
Batch 1: best_value (random, mtgp, batch) = (-4.86, -6.34, -6.34), batch time = 12.11, mtgp time = 33.74
Batch 2: best_value (random, mtgp, batch) = (-4.86, -6.34, -6.34), batch time = 12.00, mtgp time = 46.60
Batch 3: best_value (random, mtgp, batch) = (-4.86, -6.34, -6.34), batch time = 12.92, mtgp time = 36.88
Batch 4: best_value (random, mtgp, batch) = (-4.86, -6.34, -6.34), batch time = 6.23, mtgp time = 30.78
Batch 5: best_value (random, mtgp, batch) = (-4.86, -3.68, -6.34), batch time = 8.35, mtgp time = 27.63
Batch 6: best_value (random, mtgp, batch) = (-4.86, -2.66, -4.93), batch time = 7.50, mtgp time = 49.98
Batch 7: best_value (random, mtgp, batch) = (-4.86, -2.66, -4.47), batch time = 11.87, mtgp time = 37.50
Batch 8: best_value (random, mtgp, batch) = (-4.86, -2.66, -2.81), batch time = 11.41, mtgp time = 41.93
Batch 9: best_value (random, mtgp, batch) = (-4.86, -2.66, -2.81), batch time = 10.96, mtgp time = 44.75
Batch 10: best_value (random, mtgp, batch) = (-4.86, -2.66, -2.54), batch time = 12.96, mtgp time = 34.05
Batch 11: best_value (random, mtgp, batch) = (-4.86, -2.66, -2.49), batch time = 13.49, mtgp time = 40.42
Batch 12: best_value (random, mtgp, batch) = (-4.86, -2.66, -2.26), batch time = 13.78, mtgp time = 37.96
Batch 13: best_value (random, mtgp, batch) = (-4.86, -2.35, -2.26), batch time = 11.03, mtgp time = 49.89
Batch 14: best_value (random, mtgp, batch) = (-4.86, -2.35, -2.26), batch time = 10.55, mtgp time = 33.78
Batch 15: best_value (random, mtgp, batch) = (-4.86, -2.32, -2.26), batch time = 11.93, mtgp time = 31.45
Batch 16: best_value (random, mtgp, batch) = (-4.86, -2.32, -2.26), batch time = 11.73, mtgp time = 25.93
Batch 17: best_value (random, mtgp, batch) = (-4.86, -2.32, -2.26), batch time = 11.65, mtgp time = 41.50
Batch 18: best_value (random, mtgp, batch) = (-4.86, -2.32, -2.26), batch time = 12.54, mtgp time = 36.79
Batch 19: best_value (random, mtgp, batch) = (-4.86, -2.32, -2.26), batch time = 10.62, mtgp time = 52.57
Batch 0: best_value (random, mtgp, batch) = (-5.52, -5.52, -5.52), batch time = 11.71, mtgp time = 39.39
Batch 1: best_value (random, mtgp, batch) = (-5.52, -5.52, -5.52), batch time = 14.40, mtgp time = 34.61
Batch 2: best_value (random, mtgp, batch) = (-5.52, -5.52, -5.52), batch time = 14.71, mtgp time = 26.74
Batch 3: best_value (random, mtgp, batch) = (-5.52, -5.52, -5.52), batch time = 12.25, mtgp time = 31.19
Batch 4: best_value (random, mtgp, batch) = (-5.52, -2.74, -3.40), batch time = 11.82, mtgp time = 31.47
Batch 5: best_value (random, mtgp, batch) = (-5.52, -2.74, -3.29), batch time = 10.88, mtgp time = 43.98
Batch 6: best_value (random, mtgp, batch) = (-5.52, -2.74, -3.29), batch time = 11.60, mtgp time = 32.99
Batch 7: best_value (random, mtgp, batch) = (-5.52, -2.74, -2.97), batch time = 11.25, mtgp time = 37.71
Batch 8: best_value (random, mtgp, batch) = (-5.52, -2.33, -2.48), batch time = 11.48, mtgp time = 31.68
Batch 9: best_value (random, mtgp, batch) = (-5.52, -2.33, -2.48), batch time = 11.50, mtgp time = 29.20
Batch 10: best_value (random, mtgp, batch) = (-5.52, -2.33, -2.48), batch time = 12.28, mtgp time = 35.44
Batch 11: best_value (random, mtgp, batch) = (-5.52, -2.33, -2.48), batch time = 11.04, mtgp time = 31.78
Batch 12: best_value (random, mtgp, batch) = (-5.52, -2.33, -2.48), batch time = 9.10, mtgp time = 46.31
Batch 13: best_value (random, mtgp, batch) = (-5.52, -2.33, -2.48), batch time = 10.81, mtgp time = 55.77
Batch 14: best_value (random, mtgp, batch) = (-5.52, -2.33, -2.48), batch time = 12.39, mtgp time = 50.37
Batch 15: best_value (random, mtgp, batch) = (-5.52, -2.33, -2.40), batch time = 10.58, mtgp time = 26.95
Batch 16: best_value (random, mtgp, batch) = (-5.52, -2.33, -2.35), batch time = 11.29, mtgp time = 39.82
Batch 17: best_value (random, mtgp, batch) = (-5.52, -2.33, -2.31), batch time = 10.38, mtgp time = 42.04
Batch 18: best_value (random, mtgp, batch) = (-5.52, -2.33, -2.31), batch time = 10.90, mtgp time = 18.43
Batch 19: best_value (random, mtgp, batch) = (-5.52, -2.33, -2.31), batch time = 10.75, mtgp time = 47.58
import matplotlib.pyplot as plt
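Finally, we plot the results. In these runs, the MTGP tends to outperform the batch GP for most of the optimization, reaching good objective values in fewer function queries, although the two end at similar best values once the budget is used up; both clearly outperform the random baseline. The MTGP's optimization trajectory also appears somewhat less noisy than the batch GP's.
However, as the timings above show, fitting the MTGP and optimizing its acquisition function take noticeably longer: roughly 25 to 50 seconds per iteration in this run, versus roughly 10 to 15 seconds for the batch GP.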
mtgp_results = torch.stack(mtgp_trial_objectives)[:, n_init:].cummax(1).values
batch_results = torch.stack(batch_trial_objectives)[:, n_init:].cummax(1).values
random_results = torch.stack(rand_trial_objectives)[:, n_init:].cummax(1).values
plt.plot(mtgp_results.mean(0))
plt.fill_between(
    torch.arange(n_steps * batch_size),
    mtgp_results.mean(0) - 2. * mtgp_results.std(0) / (n_trials ** 0.5),
    mtgp_results.mean(0) + 2. * mtgp_results.std(0) / (n_trials ** 0.5),
    alpha=0.3, label="MTGP",
)
plt.plot(batch_results.mean(0))
plt.fill_between(
    torch.arange(n_steps * batch_size),
    batch_results.mean(0) - 2. * batch_results.std(0) / (n_trials ** 0.5),
    batch_results.mean(0) + 2. * batch_results.std(0) / (n_trials ** 0.5),
    alpha=0.3, label="Batch",
)
plt.plot(random_results.mean(0))
plt.fill_between(
    torch.arange(n_steps * batch_size),
    random_results.mean(0) - 2. * random_results.std(0) / (n_trials ** 0.5),
    random_results.mean(0) + 2. * random_results.std(0) / (n_trials ** 0.5),
    alpha=0.3, label="Random",
)
plt.legend(loc="lower right", fontsize=15)
plt.xlabel("Number of Function Queries")
plt.ylabel("Best Objective Achieved")
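The plotting code summarizes each method by its running best value per trial (`cummax`) and shades the mean plus or minus two standard errors across trials, i.e. $\bar{y} \pm 2 s / \sqrt{n_{\text{trials}}}$. The toy tensor below holds made-up numbers, not results from this tutorial; it only shows what those two steps compute. The variable names are illustrative and chosen so they do not overwrite the notebook's own.

```python
import torch

# Toy objective values: 3 trials x 4 queries (illustrative numbers only)
toy_objectives = torch.tensor([
    [-5.0, -3.0, -4.0, -2.0],
    [-6.0, -6.5, -2.5, -3.0],
    [-4.0, -4.5, -3.5, -1.0],
], dtype=torch.double)

running_best = toy_objectives.cummax(dim=1).values   # best value seen so far in each trial
num_toy_trials = running_best.shape[0]
avg = running_best.mean(dim=0)
stderr = running_best.std(dim=0) / num_toy_trials**0.5  # std uses Bessel's correction by default
lower, upper = avg - 2.0 * stderr, avg + 2.0 * stderr
print(running_best)
```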