In this tutorial, we're going to explore composite Bayesian optimization (Astudillo & Frazier, ICML 2019) with the High Order Gaussian Process (HOGP) model of Zhe et al., AISTATS 2019. The setup for composite Bayesian optimization is that we have an unknown (black box) function mapping input parameters to several outputs, and a second, known function describing the quality of the functional output. We wish to find input parameters that maximize the output metric function in a black-box manner.
Specifically, this can be described as $\max_{x \in \mathcal{X}} g(f(x)),$ where $f$ is unknown and $g$ is known. As in traditional Bayesian optimization, we construct a Gaussian process surrogate model over the expensive-to-evaluate function $f(\cdot)$; here, we use a HOGP to model this function.
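To make the setup concrete, here is a tiny, purely illustrative sketch (the functions f and g below are stand-ins, not the test problem used later in this tutorial): f plays the role of the expensive black box returning several outputs, g is the known, cheap metric of those outputs, and composite BO maximizes g(f(x)) over x.

import torch

def f(x):
    # stand-in for the expensive black-box simulator; returns several outputs per input
    return torch.stack([x.sum(dim=-1), x.prod(dim=-1)], dim=-1)

def g(y):
    # known, cheap, differentiable metric of the outputs
    return -(y - 1.0).pow(2).sum(dim=-1)

x = torch.rand(5, 3)   # five candidate inputs in a 3-dimensional space
print(g(f(x)))         # the composite objective values we wish to maximize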
The High Order Gaussian Process (HOGP) model is a Gaussian process model designed specifically to operate over tensors (multi-dimensional arrays), exploiting Kronecker structure in the covariance to do so efficiently. Specifically, the HOGP takes as inputs $y \in \mathbb{R}^{N \times d_2 \times \cdots \times d_M}$ and assumes that $\text{vec}(y) \sim \mathcal{N}(0, \otimes_{i=1}^M K_i + \sigma^2 I),$ where $K_1 = K_{XX}.$ Each dimension of the tensor has its own kernel function, $K_i,$ as well as a set of $d_i$ latent parameters that can be optimized over.
Recently, Maddox et al., 2021 proposed a method for computing posterior samples from the HOGP by exploiting structure in the posterior distribution, thereby enabling its usage in BO settings. While they show that this approach makes composite BO feasible on problems with tens of thousands of outputs, to keep the runtime of this tutorial manageable we consider a much smaller example here (one that does not require GPU acceleration).
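Before diving into the full example, here is a minimal, self-contained sketch (with made-up data and shapes, not the tutorial's pipeline) of how a HOGP is fit to tensor-valued observations and how joint posterior samples of the full output tensor are drawn; such samples are what allow a Monte Carlo acquisition function to evaluate a composite objective on sampled outputs.

import torch
from botorch.models import HigherOrderGP
from botorch.optim.fit import fit_gpytorch_torch
from botorch.sampling import IIDNormalSampler
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 4)      # 20 inputs in a 4-dimensional space
train_Y = torch.randn(20, 3, 4)  # each observation is a 3 x 4 tensor of outputs

model = HigherOrderGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_torch(mll, options={"maxiter": 100, "disp": False})

# draw joint posterior samples of the full output tensor at a few test inputs
test_X = torch.rand(5, 4)
sampler = IIDNormalSampler(num_samples=64)
samples = sampler(model.posterior(test_X))
print(samples.shape)  # 64 MC samples; trailing dims cover the 3 x 4 output grid (possibly flattened)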
import logging
import math
import os
import time

import gpytorch.settings as gpt_settings
import matplotlib.pyplot as plt
import numpy as np
import torch
from botorch import fit_gpytorch_model
from botorch.acquisition import qExpectedImprovement, qSimpleRegret
from botorch.acquisition.objective import GenericMCObjective
from botorch.models import HigherOrderGP, SingleTaskGP
from botorch.models.higher_order_gp import FlattenedStandardize
from botorch.models.transforms import Normalize, Standardize
from botorch.optim import optimize_acqf
from botorch.optim.fit import fit_gpytorch_torch
from botorch.sampling import IIDNormalSampler
from botorch.utils.sampling import draw_sobol_samples
from gpytorch.mlls import ExactMarginalLogLikelihood

SMOKE_TEST = os.environ.get("SMOKE_TEST")
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")
dtype = torch.float
print("Using ", device)
Using cuda:0
models_used = (
    "rnd",
    "ei",
    "ei_hogp_cf",
)
We use a simple test problem from Astudillo & Frazier, ICML 2019 describing the concentration of pollutants after a chemical spill, defined over a $3 \times 4$ grid of values $(s, t)$; the goal is to recover the true underlying parameter values $x = [M, D, L, \tau].$ The concentration function is given by $$ c(s,t \mid M, D, L, \tau) := \frac{M}{\sqrt{4 \pi D t}} \exp\left\{-\frac{s^2}{4Dt}\right\} + \frac{\mathbb{1}_{t > \tau} M}{\sqrt{4 \pi D(t - \tau)}} \exp\left\{- \frac{(s - L)^2}{4 D (t - \tau)}\right\}, $$ with the cheap-to-evaluate, differentiable objective given by $g(y) := -\sum_{(s,t) \in S \times T} \left(c(s, t \mid x_{\text{true}}) - y(s,t)\right)^2,$ which we maximize (i.e., we minimize the sum of squared differences to the output at the true parameters). As the objective function itself is implemented in PyTorch, we will be able to differentiate through it, enabling gradient-based optimization of the objective with respect to the inputs.
def env_cfun(s, t, M, D, L, tau):
    # contribution of the first spill (mass M released at location 0, time 0)
    c1 = M / torch.sqrt(4 * math.pi * D * t)
    exp1 = torch.exp(-(s ** 2) / 4 / D / t)
    term1 = c1 * exp1
    # contribution of the second spill (at location L, time tau); for t <= tau the
    # square root of a negative number produces NaNs, which are zeroed out below to
    # implement the indicator 1_{t > tau}
    c2 = M / torch.sqrt(4 * math.pi * D * (t - tau))
    exp2 = torch.exp(-((s - L) ** 2) / 4 / D / (t - tau))
    term2 = c2 * exp2
    term2[torch.isnan(term2)] = 0.0
    return term1 + term2
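As a quick sanity check (not part of the BO pipeline itself), we can evaluate this function on the $3 \times 4$ $(s, t)$ grid at the "true" parameter values used later in prepare_data; this assumes the imports and env_cfun defined above.

S = torch.tensor([0.0, 1.0, 2.5])
T = torch.tensor([15.0, 30.0, 45.0, 60.0])
Sgrid, Tgrid = torch.meshgrid(S, T)

# true parameters (M, D, L, tau), matching prepare_data below
c_true = env_cfun(Sgrid, Tgrid, 10.0, 0.07, 1.505, 30.1525)
print(c_true.shape)  # torch.Size([3, 4]) -- one observation is a 3 x 4 tensor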
Below are two helper functions: one to generate random points within the bounds and one to maximize the acquisition function.
def gen_rand_points(bounds, num_samples):
    points_nlzd = torch.rand(num_samples, bounds.shape[-1]).to(bounds)
    return bounds[0] + (bounds[1] - bounds[0]) * points_nlzd


def optimize_ei(qEI, bounds, **options):
    with gpt_settings.fast_computations(covar_root_decomposition=False):
        cands_nlzd, _ = optimize_acqf(qEI, bounds, **options)
    return cands_nlzd
Below is a wrapper function that sets up the problem: it defines the bounds on the parameter space, the evaluation grid, and the objective. We can also vary the size of the grid if we'd like to.
def prepare_data(s_size=3, t_size=4, device=device, dtype=dtype):
    print("---- Running the environmental problem with ", s_size, t_size, " ----")
    # X = [M, D, L, tau]
    bounds = torch.tensor(
        [[7.0, 0.02, 0.01, 30.010], [13.0, 0.12, 3.00, 30.295]],
        device=device,
        dtype=dtype,
    )

    M0 = torch.tensor(10.0, device=device, dtype=dtype)
    D0 = torch.tensor(0.07, device=device, dtype=dtype)
    L0 = torch.tensor(1.505, device=device, dtype=dtype)
    tau0 = torch.tensor(30.1525, device=device, dtype=dtype)

    # we can vectorize everything, no need for loops
    if s_size == 3:
        S = torch.tensor([0.0, 1.0, 2.5], device=device, dtype=dtype)
    else:
        S = torch.linspace(0.0, 2.5, s_size, device=device, dtype=dtype)
    if t_size == 4:
        T = torch.tensor([15.0, 30.0, 45.0, 60.0], device=device, dtype=dtype)
    else:
        T = torch.linspace(15.0, 60.0, t_size, device=device, dtype=dtype)

    Sgrid, Tgrid = torch.meshgrid(S, T)

    # X = [M, D, L, tau]
    def c_batched(X, k=None):
        return torch.stack([env_cfun(Sgrid, Tgrid, *x) for x in X])

    c_true = env_cfun(Sgrid, Tgrid, M0, D0, L0, tau0)

    def neg_sum_squared_diff(samples):
        # if the samples come back flattened, reshape them to the (s, t) grid
        if samples.shape[-1] == (s_size * t_size):
            samples = samples.unsqueeze(-1).reshape(*samples.shape[:-1], s_size, t_size)
        sq_diffs = (samples - c_true).pow(2)
        return sq_diffs.sum(dim=(-1, -2)).mul(-1.0)

    objective = GenericMCObjective(neg_sum_squared_diff)
    num_samples = 32

    return c_batched, objective, bounds, num_samples
In the above, we construct a GenericMCObjective instance to codify the objective function, which is the negative sum of squared differences between the output tensor and the output corresponding to the "true" parameter values. Note that the objective function is encoded in PyTorch and is differentiable (although it technically doesn't have to be). Ultimately, we backpropagate through the objective with respect to the input parameters (and through the HOGP as well).
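To illustrate this differentiability (with a toy stand-in for the true output tensor, not the tutorial's actual objective), gradients flow from the scalar objective values back to whatever produced the samples:

target = torch.randn(3, 4)  # stand-in for c_true
toy_objective = GenericMCObjective(lambda Y: -(Y - target).pow(2).sum(dim=(-1, -2)))

Y = torch.randn(8, 3, 4, requires_grad=True)  # e.g. posterior samples of the output tensor
toy_objective(Y).sum().backward()
print(Y.grad.shape)  # torch.Size([8, 3, 4])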
Finally, we run the BO loop over three trials for $15$ batches each. This loop might take a while.
We will be comparing to both random selection and batch expected improvement on the aggregated metric.
n_init = 30

if SMOKE_TEST:
    n_batches = 1
    batch_size = 2
    n_trials = 1
else:
    n_batches = 15
    batch_size = 4
    n_trials = 3
As a word of caution, we've found that when fitting the HOGP model, first-order optimizers (e.g. Adam, as used in fit_gpytorch_torch) tend to outperform second-order optimizers such as L-BFGS-B, due to the large number of free parameters in the HOGP; L-BFGS-B tends to overfit in practice here.
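For reference, here is a rough sketch of what this first-order fitting amounts to: optimizing the exact marginal log likelihood with Adam. This is illustrative only; fit_gpytorch_torch adds conveniences such as convergence checks and parameter bounds on top of a loop like this.

def fit_mll_with_adam(mll, lr=0.01, maxiter=3000):
    # illustrative Adam training loop over the exact marginal log likelihood
    mll.train()
    optimizer = torch.optim.Adam(mll.parameters(), lr=lr)
    for _ in range(maxiter):
        optimizer.zero_grad()
        output = mll.model(*mll.model.train_inputs)
        loss = -mll(output, mll.model.train_targets).sum()
        loss.backward()
        optimizer.step()
    mll.eval()
    return mll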
all_objective_vals = []

with gpt_settings.cholesky_jitter(1e-4):
    for trial in range(n_trials):
        print("Beginning with trial: ", trial + 1)

        c_batched, objective, bounds, num_samples = prepare_data(device=device, dtype=dtype)

        train_X_init = gen_rand_points(bounds, n_init)
        train_Y_init = c_batched(train_X_init)

        # these will keep track of the points explored
        train_X = {k: train_X_init.clone() for k in models_used}
        train_Y = {k: train_Y_init.clone() for k in train_X}

        # run the BO loop
        for i in range(n_batches):
            tic = time.time()

            # get best observations, log status
            best_f = {k: objective(v).max().detach() for k, v in train_Y.items()}
            logging.info(
                f"It {i+1:>2}/{n_batches}, best obs.: "
                + ", ".join([f"{k}: {v:.3f}" for k, v in best_f.items()])
            )

            # generate random candidates
            cands = {}
            cands["rnd"] = gen_rand_points(bounds, batch_size)

            optimize_acqf_kwargs = {
                "q": batch_size,
                "num_restarts": 10,
                "raw_samples": 512,
            }
            sampler = IIDNormalSampler(num_samples=128)

            train_Y_ei = objective(train_Y["ei"]).unsqueeze(-1)
            model_ei = SingleTaskGP(
                train_X["ei"],
                train_Y_ei,
                input_transform=Normalize(train_X["ei"].shape[-1]),
                outcome_transform=Standardize(train_Y_ei.shape[-1]),
            )
            mll = ExactMarginalLogLikelihood(model_ei.likelihood, model_ei)
            fit_gpytorch_torch(mll, options={"lr": 0.01, "maxiter": 3000, "disp": False})

            # generate qEI candidate (single output modeling)
            qEI = qExpectedImprovement(model_ei, best_f=best_f["ei"], sampler=sampler)
            cands["ei"] = optimize_ei(qEI, bounds, **optimize_acqf_kwargs)

            model_ei_hogp_cf = HigherOrderGP(
                train_X["ei_hogp_cf"],
                train_Y["ei_hogp_cf"],
                outcome_transform=FlattenedStandardize(train_Y["ei_hogp_cf"].shape[1:]),
                input_transform=Normalize(train_X["ei_hogp_cf"].shape[-1]),
                latent_init="gp",
            )
            mll = ExactMarginalLogLikelihood(model_ei_hogp_cf.likelihood, model_ei_hogp_cf)
            fit_gpytorch_torch(mll, options={"lr": 0.01, "maxiter": 3000, "disp": False})

            # generate qEI candidate (multi-output modeling)
            qEI_hogp_cf = qExpectedImprovement(
                model_ei_hogp_cf,
                best_f=best_f["ei_hogp_cf"],
                sampler=sampler,
                objective=objective,
            )
            cands["ei_hogp_cf"] = optimize_ei(qEI_hogp_cf, bounds, **optimize_acqf_kwargs)

            # make observations and update data
            for k, Xold in train_X.items():
                Xnew = cands[k]
                if Xnew.shape[0] > 0:
                    train_X[k] = torch.cat([Xold, Xnew])
                    train_Y[k] = torch.cat([train_Y[k], c_batched(Xnew)])

            logging.info(f"Wall time: {time.time() - tic:1f}")

        objective_dict = {k: objective(train_Y[k]) for k in train_Y}
        all_objective_vals.append(objective_dict)

        print("Finished with trial: ", trial + 1)
Beginning with trial:  1
---- Running the environmental problem with  3 4  ----
[per-iteration logging output truncated: best observed objective values and wall times for each method over 15 iterations]
Finished with trial:  1
Beginning with trial:  2
---- Running the environmental problem with  3 4  ----
[per-iteration logging output truncated]
Finished with trial:  2
Beginning with trial:  3
---- Running the environmental problem with  3 4  ----
[per-iteration logging output truncated]
Finished with trial:  3
# cumulative best objective value per method, stacked over trials
methods_dict = {
    k: torch.stack([trial[k].cpu() for trial in all_objective_vals]).cummax(1)[0]
    for k in models_used
}
# mean and two standard errors (over trials) of the negated best value, after the initial design
mean_and_sem_results = {
    k: [
        -methods_dict[k].mean(0)[n_init:],
        2.0 * methods_dict[k].std(0)[n_init:] / (n_trials ** 0.5),
    ]
    for k in models_used
}
Finally, we plot the results. The HOGP performs well on this task, converging closer to the true parameter values than a batch GP fit directly on the composite metric.
plt.figure(figsize=(8, 6))

labels_dict = {"rnd": "Random", "ei": "EI", "ei_hogp_cf": "Composite EI"}

for k in models_used:
    plt.plot(
        torch.arange(n_batches * batch_size),
        mean_and_sem_results[k][0],
        label=labels_dict[k],
    )
    plt.fill_between(
        torch.arange(n_batches * batch_size),
        (mean_and_sem_results[k][0] - mean_and_sem_results[k][1]).clamp(min=1e-3),
        mean_and_sem_results[k][0] + mean_and_sem_results[k][1],
        alpha=0.3,
    )
plt.legend(fontsize=20)
plt.semilogy()
plt.xlabel("Number of Function Queries")
plt.ylabel("Difference from True Parameter")
[Figure: "Difference from True Parameter" (log scale) vs. "Number of Function Queries" for Random, EI, and Composite EI]