Composite Bayesian Optimization with Multi-Task Gaussian Processes

In this tutorial, we describe how to perform multi-task Bayesian optimization over composite functions. In these problems, there are several related outputs and an overall objective function, cheap to evaluate, that we wish to maximize.

Multi-task Bayesian optimization was first proposed by Swersky et al, NeurIPS, '13 in the context of fast hyper-parameter tuning for neural network models. Here we demonstrate a more advanced use case, composite Bayesian optimization, in which the overall function that we wish to optimize is a cheap-to-evaluate (and known) function of the outputs. In general, we expect that using more information about the function should yield improved performance when attempting to optimize it, particularly if the metric function itself varies quickly.

See the composite BO tutorial with HOGP for a more technical introduction. In general, we suggest using MTGPs for unstructured task outputs and the HOGP for matrix- or tensor-structured outputs.

We will use a multi-task Gaussian process (MTGP) with an ICM kernel to model all of the outputs in this problem. MTGPs can be easily accessed in BoTorch via the botorch.models.KroneckerMultiTaskGP model class (for the "block design" case of fully observed outputs at all inputs). Given $T$ tasks (outputs) and $n$ data points, they assume that the responses, $Y \in \mathbb{R}^{n \times T}$, are distributed as $\text{vec}(Y) \sim \mathcal{N}(f, D)$ and $f \sim \mathcal{GP}(\mu_{\theta}, K_{XX} \otimes K_{T})$, where $D$ is a (diagonal) noise term.
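
To make the Kronecker structure concrete, here is a minimal sketch in plain PyTorch. It is not the internals of KroneckerMultiTaskGP; the RBF input kernel, the random low-rank task covariance, and all variable names are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)
n, T, d = 4, 3, 2  # data points, tasks, input dimension (illustrative sizes)

X = torch.rand(n, d, dtype=torch.double)

# Input covariance K_XX: an RBF kernel with unit lengthscale (assumed here).
K_XX = torch.exp(-0.5 * torch.cdist(X, X) ** 2)

# Task covariance K_T: ICM-style low-rank-plus-diagonal matrix B B^T + diag(v).
B = torch.randn(T, 2, dtype=torch.double)
K_T = B @ B.T + 0.1 * torch.eye(T, dtype=torch.double)

# Joint covariance of vec(f), with f flattened row by row (point-major).
K = torch.kron(K_XX, K_T)

# The covariance between task s at point i and task t at point j
# factorizes as K_XX[i, j] * K_T[s, t].
i, s, j, t = 1, 2, 3, 0
entry = K[i * T + s, j * T + t]
```

The full matrix has size $nT \times nT$, which is why exploiting the two small factors instead of the product matters as $n$ and $T$ grow.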

# Install dependencies if we are running in colab
import sys
if 'google.colab' in sys.modules:
    %pip install botorch
import os
import time

import torch
from botorch.acquisition.logei import qLogExpectedImprovement
from botorch.acquisition.objective import GenericMCObjective
from botorch.models import KroneckerMultiTaskGP
from botorch.optim import optimize_acqf
from botorch.sampling.normal import IIDNormalSampler

from botorch.test_functions import Hartmann
from gpytorch.mlls import ExactMarginalLogLikelihood

import warnings
warnings.filterwarnings("ignore")

SMOKE_TEST = os.environ.get("SMOKE_TEST")

Set device, dtype and random seed

torch.random.manual_seed(10)

tkwargs = {
    "dtype": torch.double,
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
}

Problem Definition

The function that we wish to optimize is based on a contextual version of the Hartmann-6 test function, where, following Feng et al, NeurIPS, '20, we convert the sixth input dimension into a task indicator. Here we assume that we evaluate all contexts at once.

from torch import Tensor


class ContextualHartmann6(Hartmann):
    def __init__(self, num_tasks: int = 20, noise_std=None, negate=False):
        super().__init__(dim=6, noise_std=noise_std, negate=negate)
        # evenly spaced values of the sixth Hartmann input, one per task
        self.task_range = torch.linspace(0, 1, num_tasks).unsqueeze(-1)
        # only the first five dimensions are optimized over
        self._bounds = [(0.0, 1.0) for _ in range(self.dim - 1)]
        self.bounds = torch.tensor(self._bounds).t()

    def evaluate_true(self, X: Tensor) -> Tensor:
        batch_X = X.unsqueeze(-2)
        batch_dims = X.ndim - 1

        # broadcast the task values across the batch dimensions of X
        expanded_task_range = self.task_range
        for _ in range(batch_dims):
            expanded_task_range = expanded_task_range.unsqueeze(0)
        task_range = expanded_task_range.repeat(*X.shape[:-1], 1, 1).to(X)
        # repeat each input once per task and append the task value
        concatenated_X = torch.cat(
            (
                batch_X.repeat(*[1] * batch_dims, self.task_range.shape[0], 1),
                task_range,
            ),
            dim=-1,
        )
        return super().evaluate_true(concatenated_X)
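
The tensor manipulation in the class above can be hard to follow, so here is a minimal standalone sketch of the same idea: each 5-dimensional input is repeated once per task and the task value is appended as a sixth column. The sizes and variable names are illustrative assumptions, and the sketch handles only a single batch dimension.

```python
import torch

num_tasks = 4
X = torch.rand(3, 5, dtype=torch.double)              # 3 inputs, 5 free dimensions
task_values = torch.linspace(0, 1, num_tasks).to(X)   # evenly spaced contexts

# Repeat each input once per task, then append the task value as a sixth column.
X_rep = X.unsqueeze(-2).expand(-1, num_tasks, -1)                      # (3, 4, 5)
t_col = task_values.view(1, num_tasks, 1).expand(X.shape[0], -1, -1)   # (3, 4, 1)
X_full = torch.cat((X_rep, t_col), dim=-1)                             # (3, 4, 6)
```

Evaluating a 6-dimensional function on `X_full` then yields one output per (input, task) pair, which is the block-design layout the MTGP expects.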

We use GenericMCObjective to define the differentiable function that we are optimizing. Here, it is defined as

$$g(f) = \sum_{i=1}^T \cos(f_i^2 + f_i w_i)$$

where $w$ is a weight vector (drawn randomly once at the start of the optimization); the implementation below maximizes the negative of this sum. As $g$ is a non-linear function of the outputs $f$, we cannot compute acquisition functions from the posterior mean and variance alone. Instead, we draw posterior samples and evaluate acquisition functions with Monte Carlo sampling.
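
A small sketch of why the posterior mean alone is not enough: for a non-linear $g$, the expectation of $g(f)$ generally differs from $g$ evaluated at the mean of $f$. The Gaussian "posterior" below (zero mean, standard deviation 0.5) and the weight vector are hypothetical stand-ins, not outputs of the tutorial's model, and the sign follows the formula rather than the negated implementation.

```python
import torch

torch.manual_seed(0)
T = 6
w = torch.randn(T, dtype=torch.double)  # stand-in for the weight vector

def g(f):
    # Sum of cosines from the formula above, applied along the task dimension.
    return torch.cos(f**2 + f * w).sum(dim=-1)

# A hypothetical Gaussian posterior over the T outputs at a single input.
mean = torch.zeros(T, dtype=torch.double)
std = torch.full((T,), 0.5, dtype=torch.double)

samples = mean + std * torch.randn(4096, T, dtype=torch.double)
mc_estimate = g(samples).mean()   # Monte Carlo estimate of E[g(f)]
plug_in = g(mean)                 # g evaluated at the posterior mean
```

Here `plug_in` equals 6 because every cosine is evaluated at zero, while `mc_estimate` is strictly smaller, so plugging in the mean would overstate the expected objective.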

For more than $10$ or so tasks, it is computationally challenging to sample the posterior over all tasks jointly using conventional approaches. However, Maddox et al, '21 have devised an efficient method for exploiting the structure in the posterior distribution of the MTGP, enabling efficient MC-based optimization of objectives using MTGPs. In this tutorial, we choose 6 contexts/tasks for demonstration.
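
To give a feel for how Kronecker structure reduces sampling cost, the sketch below draws a zero-mean sample with covariance $K_{XX} \otimes K_T$ in two ways. This illustrates the general idea for a prior-style covariance only; it is not the posterior sampling method of Maddox et al, and the helper `random_spd` and the matrix sizes are assumptions for illustration.

```python
import torch

torch.manual_seed(0)
n, T = 5, 3
dtype = torch.double

def random_spd(m):
    # Hypothetical helper: a random symmetric positive-definite matrix.
    A = torch.randn(m, m, dtype=dtype)
    return A @ A.T + m * torch.eye(m, dtype=dtype)

K_XX, K_T = random_spd(n), random_spd(T)

# Naive route: factorize the full (nT x nT) covariance, cost O((nT)^3).
L_full = torch.linalg.cholesky(torch.kron(K_XX, K_T))

# Structured route: factorize each factor separately, cost O(n^3 + T^3).
L_X = torch.linalg.cholesky(K_XX)
L_T = torch.linalg.cholesky(K_T)

Z = torch.randn(n, T, dtype=dtype)
# Row-major vec(L_X Z L_T^T) equals (L_X kron L_T) vec(Z), so this is a
# zero-mean draw with covariance K_XX kron K_T.
sample_structured = L_X @ Z @ L_T.T
sample_naive = (L_full @ Z.reshape(-1)).reshape(n, T)
```

Both routes produce the same sample from the same noise `Z`, but the structured one never forms or factorizes the large matrix.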

num_tasks = 6
problem = ContextualHartmann6(num_tasks=num_tasks, noise_std=0.001, negate=True).to(**tkwargs)

# we choose num_tasks random weights
weights = torch.randn(num_tasks, **tkwargs)


def callable_func(samples, X=None):
    # negated sum of cosines over the task (last) dimension
    res = -torch.cos((samples**2) + samples * weights)
    return res.sum(dim=-1)


objective = GenericMCObjective(callable_func)
bounds = problem.bounds

BO Loop

Set environmental parameters: we use 10 initial data points and optimize for 10 steps with a batch size of 3 candidate points at each evaluation.

if SMOKE_TEST:
    n_init = 5
    n_steps = 1
    batch_size = 2
    num_samples = 4
    # For L-BFGS inner optimization loop
    MAXITER = 10
else:
    n_init = 10
    n_steps = 10
    batch_size = 3
    num_samples = 64
    MAXITER = 200

from botorch.fit import fit_gpytorch_mll
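
As a quick sanity check on the evaluation budget implied by the non-smoke-test settings, the arithmetic can be written out as follows. The variable names `total_inputs` and `total_outputs` are introduced here for illustration only.

```python
n_init, n_steps, batch_size, num_tasks = 10, 10, 3, 6

total_inputs = n_init + n_steps * batch_size   # inputs evaluated per method
total_outputs = total_inputs * num_tasks       # scalar observations, one per task
print(total_inputs, total_outputs)  # 40 240
```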

Finally, run the optimization loop.

Warning: this optimization loop can take a while, especially on the CPU. We compare against random sampling as a baseline.
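
The random baseline and the initial design both draw points uniformly from the box defined by the bounds. A minimal sketch of that step is shown here; the helper name `sample_uniform`, its `seed` argument, and the example bounds are assumptions for illustration and are not part of BoTorch.

```python
import torch

def sample_uniform(bounds, n, seed=0):
    # Hypothetical helper: n points drawn uniformly from the box given by
    # a 2 x d tensor of lower (row 0) and upper (row 1) bounds.
    gen = torch.Generator().manual_seed(seed)
    u = torch.rand(n, bounds.shape[1], dtype=bounds.dtype, generator=gen)
    return bounds[0] + (bounds[1] - bounds[0]) * u

bounds = torch.tensor([[0.0, -1.0], [1.0, 3.0]], dtype=torch.double)
points = sample_uniform(bounds, n=5, seed=0)
```

Scaling unit-cube draws by the bound widths keeps the code independent of the particular search space.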

# New version
torch.manual_seed(0)

init_x = (bounds[1] - bounds[0]) * torch.rand(
    n_init, bounds.shape[1], **tkwargs
) + bounds[0]

init_y = problem(init_x)

mtgp_train_x, mtgp_train_y = init_x, init_y
rand_x, rand_y = init_x, init_y

best_value_mtgp = objective(init_y).max()
best_random = best_value_mtgp

for iteration in range(n_steps):
    # we empty the cache to clear memory out
    torch.cuda.empty_cache()

    # MTGP: refit the model on all data collected so far
    mtgp_t0 = time.monotonic()
    mtgp = KroneckerMultiTaskGP(mtgp_train_x, mtgp_train_y)
    mtgp_mll = ExactMarginalLogLikelihood(mtgp.likelihood, mtgp)
    fit_gpytorch_mll(mll=mtgp_mll, optimizer_kwargs={"options": {"maxiter": 50}})

    sampler = IIDNormalSampler(sample_shape=torch.Size([num_samples]))
    mtgp_acqf = qLogExpectedImprovement(
        model=mtgp,
        best_f=best_value_mtgp,
        sampler=sampler,
        objective=objective,
    )
    new_mtgp_x, _ = optimize_acqf(
        acq_function=mtgp_acqf,
        bounds=bounds,
        q=batch_size,
        num_restarts=10,
        raw_samples=512,  # used for initialization heuristic
        options={"batch_limit": 5, "maxiter": MAXITER, "init_batch_limit": 5},
    )
    mtgp_train_x = torch.cat((mtgp_train_x, new_mtgp_x), dim=0)
    mtgp_train_y = torch.cat((mtgp_train_y, problem(new_mtgp_x)), dim=0)
    best_value_mtgp = objective(mtgp_train_y).max()
    mtgp_t1 = time.monotonic()

    # rand: draw the same number of points uniformly within the bounds
    new_rand_x = (bounds[1] - bounds[0]) * torch.rand(
        batch_size, bounds.shape[1], **tkwargs
    ) + bounds[0]
    rand_x = torch.cat((rand_x, new_rand_x))
    rand_y = torch.cat((rand_y, problem(new_rand_x)))
    best_random = objective(rand_y).max()

    print(
        f"\nBatch {iteration:>2}: best_value (random, mtgp) = "
        f"({best_random:>4.2f}, {best_value_mtgp:>4.2f}, "
        f"mtgp time = {mtgp_t1-mtgp_t0:>4.2f}",
        end="",
    )

objectives = {
    "MTGP": objective(mtgp_train_y).detach().cpu(),
    "Random": objective(rand_y).detach().cpu(),
}
Out:

Batch 0: best_value (random, mtgp) = (-4.76, -4.76, mtgp time = 15.64

Batch 1: best_value (random, mtgp) = (-4.76, -4.76, mtgp time = 36.96

Batch 2: best_value (random, mtgp) = (-4.39, -4.76, mtgp time = 23.20

Batch 3: best_value (random, mtgp) = (-4.39, -4.76, mtgp time = 23.07

Batch 4: best_value (random, mtgp) = (-4.39, -4.76, mtgp time = 35.79

Batch 5: best_value (random, mtgp) = (-4.22, -4.76, mtgp time = 50.71

Batch 6: best_value (random, mtgp) = (-4.22, -2.88, mtgp time = 61.87

Batch 7: best_value (random, mtgp) = (-4.22, -2.88, mtgp time = 115.77

Batch 8: best_value (random, mtgp) = (-4.22, -1.91, mtgp time = 67.89

Batch 9: best_value (random, mtgp) = (-4.22, -1.91, mtgp time = 48.66

Plot Results

import matplotlib.pyplot as plt

Finally, we plot the results. In this run, the MTGP reaches a best value of -1.91 after 10 batches, compared to -4.22 for the random baseline.

results = {
    k: t[n_init:].cummax(0).values for k, t in objectives.items()
}
for name, vals in results.items():
    plt.plot(vals, label=name)
plt.legend()
[Figure: running best objective value over the evaluations after initialization, with one curve each for the MTGP and the random baseline.]
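
The plotted curves are running maxima of the objective values. As a small illustration of what `cummax` computes, consider the following sketch; the values are made up and are not results from the run above.

```python
import torch

# Made-up objective values in evaluation order.
values = torch.tensor([-4.8, -5.1, -4.4, -4.6, -2.9, -3.3], dtype=torch.double)

# cummax returns (values, indices); .values is the best value seen so far.
running_best = values.cummax(0).values
print(running_best.tolist())  # [-4.8, -4.8, -4.4, -4.4, -2.9, -2.9]
```

Because the running best never decreases, the curve makes it easy to see at which evaluation each method last improved.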