This demo currently considers four approaches to discrete Thompson sampling on m
candidate points:
Exact sampling with Cholesky: Computing a Cholesky decomposition of the corresponding m x m covariance matrix, which requires O(m^3) computational cost and O(m^2) space. This is the standard approach to sampling from a Gaussian process, but the quadratic memory usage and cubic complexity limit the number of candidate points.
Contour integral quadrature (CIQ): CIQ [1] is a Krylov subspace method combined with a rational approximation that can be used for computing matrix square roots of covariance matrices, which is the main bottleneck when sampling from a Gaussian process. CIQ relies on computing matrix-vector multiplications with the exact kernel matrix, which requires O(m^2) computational complexity and space. The space complexity can be lowered to O(m) by using KeOps, which is necessary to scale to large values of m.
Lanczos: Rather than using CIQ, we can solve the linear systems K^(1/2) v = b using Lanczos and the conjugate gradient (CG) method. This will be faster than CIQ, but will generally produce samples of worse quality. Similarly to CIQ, we need to use KeOps as we require computing matrix-vector multiplications with the exact kernel matrix.
Random Fourier features (RFFs): The RFF kernel was originally proposed in [2] and we use it as implemented in GPyTorch. RFFs are computationally cheap to work with as the computational cost and space are both O(km), where k is the number of Fourier features. Note that while Cholesky and CIQ are able to generate exact samples from the GP model, RFFs are an unbiased approximation and the resulting samples often aren't perfectly calibrated.
import os
import time
from contextlib import ExitStack
import torch
from torch.quasirandom import SobolEngine
import gpytorch
import gpytorch.settings as gpts
import pykeops
from botorch.fit import fit_gpytorch_model
from botorch.generation import MaxPosteriorSampling
from botorch.models import SingleTaskGP
from botorch.test_functions import Hartmann
from botorch.utils.transforms import unnormalize
from gpytorch.constraints import Interval
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import MaternKernel, RFFKernel, ScaleKernel
from gpytorch.kernels.keops import MaternKernel as KMaternKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
dtype = torch.double
# Unset (None) unless the SMOKE_TEST environment variable is defined.
SMOKE_TEST = os.environ.get("SMOKE_TEST")
pykeops.test_torch_bindings()  # Make sure the KeOps bindings are working
pyKeOps with torch bindings is working!
# Six-dimensional Hartmann test function, constructed with negate=True and
# moved to the demo's dtype/device.
hart6 = Hartmann(dim=6, negate=True).to(dtype=dtype, device=device)
# Input dimension, taken from the test function itself.
dim = hart6.dim
def eval_objective(x):
    """Unnormalize a point to the bounds of `hart6` and evaluate `hart6` there.

    Args:
        x: Point to evaluate, expressed in normalized coordinates
            (presumably the unit cube [0, 1]^d, as produced by the Sobol
            draws in this file).

    Returns:
        The value of `hart6` at the unnormalized point.
    """
    x_in_bounds = unnormalize(x, hart6.bounds)
    return hart6(x_in_bounds)
def get_initial_points(dim, n_pts, seed=None):
    """Draw the initial design from a scrambled Sobol sequence.

    Args:
        dim: Dimension of the Sobol engine.
        n_pts: Number of points to draw.
        seed: Optional seed forwarded to `SobolEngine`; a fixed seed yields
            the same points on every call.

    Returns:
        The drawn points, converted to the module-level `dtype` and `device`.
    """
    engine = SobolEngine(dimension=dim, scramble=True, seed=seed)
    return engine.draw(n=n_pts).to(dtype=dtype, device=device)
def generate_batch(
    X,
    Y,
    batch_size,
    n_candidates,
    sampler="cholesky",  # "cholesky", "ciq", "rff", "lanczos"
    use_keops=False,
):
    """Fit a GP to (X, Y) and select the next batch by discrete Thompson sampling.

    A `SingleTaskGP` is fitted to the standardized observations, a scrambled
    Sobol sequence of `n_candidates` points is drawn, and
    `MaxPosteriorSampling` picks `batch_size` of those candidates under the
    GPyTorch settings that correspond to `sampler`.

    Args:
        X: Training inputs; every entry must lie in [0, 1].
        Y: Training observations; every entry must be finite.
        batch_size: Number of points to select.
        n_candidates: Number of Sobol candidate points to sample on.
        sampler: One of "cholesky", "ciq", "rff" or "lanczos".
        use_keops: Use the KeOps Matern kernel instead of the standard one.
            Ignored when `sampler == "rff"`, which always uses `RFFKernel`.

    Returns:
        The candidates selected by `MaxPosteriorSampling`
        (presumably `batch_size` rows of the candidate set).

    Raises:
        ValueError: If `sampler` is unknown, `X` has entries outside [0, 1]
            (including NaN), or `Y` has non-finite entries.
    """
    # Validate with explicit exceptions: `assert` is stripped under `python -O`.
    valid_samplers = ("cholesky", "ciq", "rff", "lanczos")
    if sampler not in valid_samplers:
        raise ValueError(f"sampler must be one of {valid_samplers}, got {sampler!r}")
    # Written as `not (a and b)` so that NaN entries are rejected as well.
    if not (X.min() >= 0.0 and X.max() <= 1.0):
        raise ValueError("X must be normalized so that all entries lie in [0, 1]")
    if not torch.all(torch.isfinite(Y)):
        raise ValueError("Y must contain only finite values")

    # NOTE: We probably want to pass in the default priors in SingleTaskGP here later
    kernel_kwargs = {"nu": 2.5, "ard_num_dims": X.shape[-1]}
    if sampler == "rff":
        base_kernel = RFFKernel(**kernel_kwargs, num_samples=1024)
    elif use_keops:
        base_kernel = KMaternKernel(**kernel_kwargs)
    else:
        base_kernel = MaternKernel(**kernel_kwargs)
    covar_module = ScaleKernel(base_kernel)

    # Fit a GP model on standardized observations
    train_Y = (Y - Y.mean()) / Y.std()
    likelihood = GaussianLikelihood(noise_constraint=Interval(1e-8, 1e-3))
    model = SingleTaskGP(X, train_Y, likelihood=likelihood, covar_module=covar_module)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_model(mll)

    # Draw samples on a Sobol sequence; match the training data's dtype/device
    # so the function does not depend on module-level globals.
    sobol = SobolEngine(X.shape[-1], scramble=True)
    X_cand = sobol.draw(n_candidates).to(dtype=X.dtype, device=X.device)

    # Thompson sample
    with ExitStack() as es:
        if sampler == "cholesky":
            # Always use Cholesky, regardless of the number of candidates
            es.enter_context(gpts.max_cholesky_size(float("inf")))
        elif sampler == "ciq":
            es.enter_context(gpts.fast_computations(covar_root_decomposition=True))
            es.enter_context(gpts.max_cholesky_size(0))
            es.enter_context(gpts.ciq_samples(True))
            es.enter_context(gpts.minres_tolerance(2e-3))  # Controls accuracy and runtime
            es.enter_context(gpts.num_contour_quadrature(15))
        elif sampler == "lanczos":
            es.enter_context(gpts.fast_computations(covar_root_decomposition=True))
            es.enter_context(gpts.max_cholesky_size(0))
            es.enter_context(gpts.ciq_samples(False))
        elif sampler == "rff":
            es.enter_context(gpts.fast_computations(covar_root_decomposition=True))

        # Sampling must happen inside the ExitStack so the settings above apply.
        # No gradients are needed to pick among fixed candidates, so skip
        # building an autograd graph over the (potentially huge) candidate set.
        with torch.no_grad():
            thompson_sampling = MaxPosteriorSampling(model=model, replacement=False)
            X_next = thompson_sampling(X_cand, num_samples=batch_size)

    return X_next
def run_optimization(sampler, n_candidates, n_init, max_evals, batch_size, use_keops=False, seed=None):
    """Run a Thompson-sampling optimization loop with the given sampler.

    Starts from `n_init` Sobol points, then repeatedly calls `generate_batch`
    and evaluates the proposed points until `max_evals` points have been
    evaluated. Progress and per-batch generation time are printed.

    Args:
        sampler: Sampler name forwarded to `generate_batch`.
        n_candidates: Number of candidate points forwarded to `generate_batch`.
        n_init: Number of initial Sobol points.
        max_evals: Total evaluation budget, including the initial points.
        batch_size: Maximum batch size; the final batch is truncated so the
            budget is never exceeded.
        use_keops: Forwarded to `generate_batch`.
        seed: Seed for the initial Sobol points.

    Returns:
        Tuple `(X, Y)` of all evaluated points and their objective values.
    """

    def _evaluate(points):
        # Evaluate point by point and stack the results into a column tensor.
        values = [eval_objective(pt) for pt in points]
        return torch.tensor(values, dtype=dtype, device=device).unsqueeze(-1)

    X = get_initial_points(dim, n_init, seed)
    Y = _evaluate(X)
    print(f"{len(X)}) Best value: {Y.max().item():.2e}")

    while True:
        remaining = max_evals - len(X)
        if remaining <= 0:
            break

        # Create a batch, never asking for more points than the budget allows
        tic = time.time()
        X_next = generate_batch(
            X=X,
            Y=Y,
            batch_size=min(batch_size, remaining),
            n_candidates=n_candidates,
            sampler=sampler,
            use_keops=use_keops,
        )
        elapsed = time.time() - tic
        print(f"Generated batch in {elapsed:.1f} seconds")

        Y_next = _evaluate(X_next)

        # Append data
        X = torch.cat((X, X_next), dim=0)
        Y = torch.cat((Y, Y_next), dim=0)
        print(f"{len(X)}) Best value: {Y.max().item():.2e}")

    return X, Y
# Experiment settings shared by every sampler
batch_size = 5
n_init = 10
max_evals = 60
seed = 0  # To get the same Sobol points
shared_args = dict(
    n_init=n_init,
    max_evals=max_evals,
    batch_size=batch_size,
    seed=seed,
)

# When SMOKE_TEST is set, skip KeOps and shrink the candidate sets to 10 points.
USE_KEOPS = not SMOKE_TEST
N_CAND = 10 if SMOKE_TEST else 50000
N_CAND_CHOL = 10 if SMOKE_TEST else 10000
# Load the memory_profiler IPython extension for the %memit magic used in the runs below.
%load_ext memory_profiler
# Cholesky sampler on the smaller N_CAND_CHOL candidate set.
%memit X_chol, Y_chol = run_optimization("cholesky", N_CAND_CHOL, **shared_args)
10) Best value: 1.12e+00 Generated batch in 0.7 seconds 15) Best value: 1.12e+00 Generated batch in 1.2 seconds 20) Best value: 1.12e+00 Generated batch in 1.1 seconds 25) Best value: 1.12e+00 Generated batch in 0.9 seconds 30) Best value: 2.81e+00 Generated batch in 0.9 seconds 35) Best value: 2.81e+00 Generated batch in 1.2 seconds 40) Best value: 3.03e+00 Generated batch in 1.0 seconds 45) Best value: 3.03e+00 Generated batch in 1.3 seconds 50) Best value: 3.03e+00 Generated batch in 1.2 seconds 55) Best value: 3.03e+00 Generated batch in 1.1 seconds 60) Best value: 3.03e+00 peak memory: 2661.07 MiB, increment: 23.34 MiB
# Random Fourier feature sampler on N_CAND candidates (KeOps is not used here).
%memit X_rff, Y_rff = run_optimization("rff", N_CAND, **shared_args)
10) Best value: 1.12e+00 Generated batch in 1.1 seconds 15) Best value: 1.12e+00 Generated batch in 1.0 seconds 20) Best value: 1.40e+00 Generated batch in 1.1 seconds 25) Best value: 1.98e+00 Generated batch in 0.9 seconds 30) Best value: 2.07e+00 Generated batch in 0.8 seconds 35) Best value: 3.03e+00 Generated batch in 1.0 seconds 40) Best value: 3.10e+00 Generated batch in 1.0 seconds 45) Best value: 3.10e+00 Generated batch in 1.2 seconds 50) Best value: 3.10e+00 Generated batch in 1.0 seconds 55) Best value: 3.10e+00 Generated batch in 1.0 seconds 60) Best value: 3.11e+00 peak memory: 2679.02 MiB, increment: 19.55 MiB
# Lanczos sampler on N_CAND candidates, with KeOps enabled unless SMOKE_TEST is set.
%memit X_lanczos, Y_lanczos = run_optimization("lanczos", N_CAND, use_keops=USE_KEOPS, **shared_args)
10) Best value: 1.12e+00 Generated batch in 4.4 seconds 15) Best value: 1.12e+00 Generated batch in 4.6 seconds 20) Best value: 1.39e+00 Generated batch in 4.5 seconds 25) Best value: 1.39e+00 Generated batch in 4.5 seconds 30) Best value: 2.30e+00 Generated batch in 4.4 seconds 35) Best value: 2.66e+00 Generated batch in 4.3 seconds 40) Best value: 2.81e+00 Generated batch in 4.5 seconds 45) Best value: 3.09e+00 Generated batch in 4.7 seconds 50) Best value: 3.09e+00 Generated batch in 4.6 seconds 55) Best value: 3.12e+00 Generated batch in 4.6 seconds 60) Best value: 3.12e+00 peak memory: 2691.95 MiB, increment: 14.54 MiB
# CIQ sampler on N_CAND candidates, with KeOps enabled unless SMOKE_TEST is set.
%memit X_ciq, Y_ciq = run_optimization("ciq", N_CAND, use_keops=USE_KEOPS, **shared_args)
10) Best value: 1.12e+00 Generated batch in 23.7 seconds 15) Best value: 1.12e+00 Generated batch in 30.6 seconds 20) Best value: 1.83e+00 Generated batch in 27.2 seconds 25) Best value: 1.83e+00 Generated batch in 25.4 seconds 30) Best value: 2.36e+00 Generated batch in 22.2 seconds 35) Best value: 3.15e+00 Generated batch in 28.7 seconds 40) Best value: 3.15e+00 Generated batch in 22.4 seconds 45) Best value: 3.15e+00 Generated batch in 27.2 seconds 50) Best value: 3.15e+00 Generated batch in 25.6 seconds 55) Best value: 3.15e+00 Generated batch in 25.6 seconds 60) Best value: 3.15e+00 peak memory: 2695.97 MiB, increment: 4.99 MiB
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
# Plot the best value found so far against the number of evaluations for
# each sampler, together with the global optimum of the test function.
fig = plt.figure(figsize=(10, 8))
matplotlib.rcParams.update({"font.size": 20})

# (observations, legend label, color, marker, marker size, line style)
# Labels are built from the actual candidate counts so they stay correct
# when SMOKE_TEST shrinks the candidate sets.
results = [
    (Y_chol.cpu(), f"Cholesky-{N_CAND_CHOL:,}", "b", "", 14, "--"),
    (Y_rff.cpu(), f"RFF-{N_CAND:,}", "r", ".", 16, "-"),
    (Y_lanczos.cpu(), f"Lanczos-{N_CAND:,}", "m", "^", 9, "-"),
    (Y_ciq.cpu(), f"CIQ-{N_CAND:,}", "g", "*", 12, "-"),
]

optimum = hart6.optimal_value

ax = fig.add_subplot(1, 1, 1)
names = []
for res, name, c, m, ms, ls in results:
    names.append(name)
    fx = res.cummax(dim=0)[0]  # running maximum, i.e. best value so far
    t = 1 + np.arange(len(fx))  # 1-based evaluation counter
    # Only every second point is plotted
    plt.plot(t[0::2], fx[0::2], c=c, marker=m, linestyle=ls, markersize=ms)

plt.plot([0, max_evals], [optimum, optimum], "k--", lw=3)

# Label the y-axis; the original called xlabel twice, so this label was
# overwritten and never shown.
plt.ylabel("Function value", fontsize=18)
plt.xlabel("Number of evaluations", fontsize=18)
plt.title("Hartmann6", fontsize=24)
plt.xlim([0, max_evals])
plt.ylim([0.5, 3.5])

plt.grid(True)
plt.tight_layout()
# Legend entries follow plotting order: one per sampler, then the optimum line.
plt.legend(
    names + ["Global optimal value"],
    loc="lower right",
    ncol=1,
    fontsize=18,
)
plt.show()