BoTorch provides a convenient botorch.fit.fit_gpytorch_model function with sensible defaults that work on most basic models, including those that BoTorch ships with. Internally, this function uses L-BFGS-B to fit the parameters. However, in more advanced use cases you may need or want to implement your own model fitting logic.
This tutorial shows how to customize model fitting to your needs using the familiar PyTorch-style model fitting loop.
This tutorial is adapted from GPyTorch's Simple GP Regression Tutorial and has very few changes because the out-of-the-box models that BoTorch provides are GPyTorch models; in fact, they are proper subclasses that add the botorch.models.Model API functions.
import math
import torch
# use a GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float
In this tutorial we will model a simple sinusoidal function with i.i.d. Gaussian noise of standard deviation 0.15:
$$y = \sin(2\pi x) + \epsilon, ~\epsilon \sim \mathcal N(0, 0.15^2)$$
# use regularly spaced points on the interval [0, 1] (training data needs an explicit dimension)
train_X = torch.linspace(0, 1, 15, dtype=dtype, device=device).unsqueeze(1)
# sample observed values and add some synthetic noise
train_Y = torch.sin(train_X * (2 * math.pi)) + 0.15 * torch.randn_like(train_X)
# observed targets must be one-dimensional for a single-output model
train_Y = train_Y.view(-1)
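The 0.15 in the data-generating code multiplies a standard normal draw, so it is the noise standard deviation (variance 0.15² = 0.0225), not the variance. A quick stdlib-only check of that reading, using Python's `random.gauss` as a stand-in for `torch.randn_like`:

```python
import random
import statistics

random.seed(0)
NOISE_STD = 0.15  # the multiplier applied to standard normal draws in the code above

# scale standard normal draws exactly as the tutorial does with torch.randn_like
samples = [NOISE_STD * random.gauss(0.0, 1.0) for _ in range(20_000)]

print(f"sample std: {statistics.pstdev(samples):.4f}")          # close to 0.15
print(f"sample variance: {statistics.pvariance(samples):.5f}")  # close to 0.0225
```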
We will model the function using a SingleTaskGP, which by default uses a GaussianLikelihood and infers the unknown noise level.
The SingleTaskGP is typically fit using L-BFGS-B with explicit bounds on the noise parameter. This means that internally the model does not use a transform to ensure the noise level remains positive. Since the PyTorch optimizers in torch.optim don't handle explicit constraints very well, we need to manually register a constraint on the noise level that enforces a lower bound using a softplus transformation (note that the constraint defined in the constructor of SingleTaskGP has transform=None, which means that no transform is applied). See the GPyTorch constraints module for additional information.
from botorch.models import SingleTaskGP
from gpytorch.constraints import GreaterThan
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
model.likelihood.noise_covar.register_constraint("raw_noise", GreaterThan(1e-5))
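To make the transform concrete, here is a stdlib-only sketch of the mapping that a `GreaterThan(1e-5)` constraint with a softplus transform applies, as we understand it: the optimizer updates an unconstrained `raw_noise` value, and the likelihood sees `1e-5 + softplus(raw_noise)`. This is not GPyTorch's code; the helper names `softplus`, `inv_softplus`, `to_noise`, and `to_raw` are hypothetical.

```python
import math

LOWER_BOUND = 1e-5  # same bound as GreaterThan(1e-5) above


def softplus(r):
    # numerically stable log(1 + exp(r))
    return max(r, 0.0) + math.log1p(math.exp(-abs(r)))


def inv_softplus(y):
    # inverse of softplus for y > 0, i.e. log(exp(y) - 1), written stably
    return y + math.log(-math.expm1(-y))


def to_noise(raw_noise):
    # hypothetical helper: constrained noise level = bound + softplus(raw)
    return LOWER_BOUND + softplus(raw_noise)


def to_raw(noise):
    # hypothetical helper: raw value for a given noise level (requires noise > LOWER_BOUND)
    return inv_softplus(noise - LOWER_BOUND)


# any raw value, however negative, maps to a noise level at or above the bound
for raw in (-50.0, -5.0, 0.0, 5.0):
    print(f"raw={raw:>6}: noise={to_noise(raw):.6g}")
```

Because every real raw value maps into the constrained range, an unconstrained optimizer such as SGD can update `raw_noise` freely.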
We will optimize the kernel hyperparameters and the likelihood's noise parameter jointly by minimizing the negative gpytorch.mlls.ExactMarginalLogLikelihood (our loss function).
from gpytorch.mlls import ExactMarginalLogLikelihood
mll = ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
# set mll and all submodules to the specified dtype and device
mll = mll.to(train_X)
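For reference, here is the standard exact GP marginal log likelihood for $n$ training points with targets $\mathbf{y}$, a constant prior mean $m$ (the SingleTaskGP default), kernel matrix $K_\theta$ on the training inputs $X$ (with hyperparameters $\theta$ such as the lengthscale), and noise variance $\sigma^2$. As far as we understand GPyTorch's implementation, ExactMarginalLogLikelihood also adds the log-densities $\log p(\theta, \sigma^2)$ of any priors registered on the hyperparameters and divides the total by $n$. The loss minimized in the loop is therefore the negative of that quantity:

$$\log p(\mathbf{y} \mid X, \theta, \sigma^2) = -\tfrac{1}{2} (\mathbf{y} - m\mathbf{1})^\top (K_\theta + \sigma^2 I)^{-1} (\mathbf{y} - m\mathbf{1}) - \tfrac{1}{2} \log \lvert K_\theta + \sigma^2 I \rvert - \tfrac{n}{2} \log 2\pi$$

$$\text{loss} = -\frac{1}{n} \Big( \log p(\mathbf{y} \mid X, \theta, \sigma^2) + \log p(\theta, \sigma^2) \Big)$$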
We will use stochastic gradient descent (torch.optim.SGD) to optimize the kernel hyperparameters and the noise level. In this example, we will use a simple fixed learning rate of 0.1, but in practice the learning rate may need to be adjusted. You can use any of the other optimizers in torch.optim in the same way.
Note:
Since the GaussianLikelihood module is a child (submodule) of the SingleTaskGP module, model.parameters() will also include the noise level of the GaussianLikelihood.
from torch.optim import SGD
optimizer = SGD([{'params': model.parameters()}], lr=0.1)
Now we are ready to write our optimization loop. We will perform 150 epochs of gradient descent with torch.optim.SGD, using our entire training set in every step.
NUM_EPOCHS = 150
model.train()
for epoch in range(NUM_EPOCHS):
    # clear gradients
    optimizer.zero_grad()
    # forward pass through the model to obtain the output MultivariateNormal
    output = model(train_X)
    # Compute negative marginal log likelihood
    loss = - mll(output, train_Y)
    # back prop gradients
    loss.backward()
    # print every 10 iterations
    if (epoch + 1) % 10 == 0:
        print(
            f"Epoch {epoch+1:>3}/{NUM_EPOCHS} - Loss: {loss.item():>4.3f} "
            f"lengthscale: {model.covar_module.base_kernel.lengthscale.item():>4.3f} "
            f"noise: {model.likelihood.noise.item():>4.3f}"
        )
    optimizer.step()
Epoch  10/150 - Loss: 1.938 lengthscale: 0.646 noise: 1.997
Epoch  20/150 - Loss: 1.898 lengthscale: 0.602 noise: 1.850
Epoch  30/150 - Loss: 1.858 lengthscale: 0.565 noise: 1.700
Epoch  40/150 - Loss: 1.817 lengthscale: 0.533 noise: 1.549
Epoch  50/150 - Loss: 1.773 lengthscale: 0.506 noise: 1.397
Epoch  60/150 - Loss: 1.726 lengthscale: 0.483 noise: 1.246
Epoch  70/150 - Loss: 1.676 lengthscale: 0.462 noise: 1.097
Epoch  80/150 - Loss: 1.623 lengthscale: 0.444 noise: 0.953
Epoch  90/150 - Loss: 1.565 lengthscale: 0.428 noise: 0.816
Epoch 100/150 - Loss: 1.504 lengthscale: 0.413 noise: 0.688
Epoch 110/150 - Loss: 1.439 lengthscale: 0.400 noise: 0.571
Epoch 120/150 - Loss: 1.372 lengthscale: 0.387 noise: 0.467
Epoch 130/150 - Loss: 1.302 lengthscale: 0.375 noise: 0.378
Epoch 140/150 - Loss: 1.231 lengthscale: 0.364 noise: 0.302
Epoch 150/150 - Loss: 1.161 lengthscale: 0.355 noise: 0.240
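The loop above follows the standard PyTorch pattern: zero the gradients, evaluate the loss on the full training set, backpropagate, and take a fixed-size step. As a rough illustration of how that pattern interacts with the softplus constraint, the stdlib-only sketch below shrinks the problem to a single parameter. It fits only a noise variance to zero-mean residuals, using a toy per-point Gaussian negative log likelihood in place of the GP marginal likelihood, and differentiates by hand instead of with autograd. The names `fit_noise`, `softplus`, and `sigmoid` are hypothetical.

```python
import math
import random

LOWER_BOUND = 1e-5  # same lower bound as the GreaterThan constraint above


def softplus(r):
    # numerically stable log(1 + exp(r))
    return max(r, 0.0) + math.log1p(math.exp(-abs(r)))


def sigmoid(r):
    # derivative of softplus, computed without overflow for large |r|
    if r >= 0:
        return 1.0 / (1.0 + math.exp(-r))
    z = math.exp(r)
    return z / (1.0 + z)


def fit_noise(residuals, lr=0.1, num_steps=2000, raw=0.0):
    """Hypothetical toy: full-batch gradient descent on an unconstrained raw noise value.

    Toy loss (constants dropped): 0.5 * (mean(r^2) / noise + log(noise)),
    with noise = LOWER_BOUND + softplus(raw). Its minimizer is noise = mean(r^2).
    """
    s = sum(r * r for r in residuals) / len(residuals)
    for _ in range(num_steps):
        noise = LOWER_BOUND + softplus(raw)
        # gradient of the toy loss w.r.t. noise, then chain rule through softplus
        grad_noise = 0.5 * (1.0 / noise - s / noise**2)
        grad_raw = grad_noise * sigmoid(raw)
        raw -= lr * grad_raw  # plain gradient step with a fixed learning rate
    return LOWER_BOUND + softplus(raw), s


random.seed(0)
residuals = [random.gauss(0.0, 0.15) for _ in range(15)]
noise, target = fit_noise(residuals)
print(f"fitted noise variance: {noise:.5f}, closed-form optimum mean(r^2): {target:.5f}")
```

Because the update is applied to the raw value, the variance can never drop below the 1e-5 bound no matter how far the raw value moves. The fixed learning rate only changes how quickly the variance approaches its optimum.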
We plot the posterior mean together with a shaded band extending two standard deviations above and below the mean.
# set model (and likelihood) to evaluation mode
model.eval();
from matplotlib import pyplot as plt
%matplotlib inline
# Initialize plot
f, ax = plt.subplots(1, 1, figsize=(6, 4))
# test model on 101 regularly spaced points on the interval [0, 1]
test_X = torch.linspace(0, 1, 101, dtype=dtype, device=device)
# no need for gradients
with torch.no_grad():
    # compute posterior
    posterior = model.posterior(test_X)
    # Get upper and lower confidence bounds (2 standard deviations from the mean)
    lower, upper = posterior.mvn.confidence_region()
    # Plot training points as black stars
    ax.plot(train_X.cpu().numpy(), train_Y.cpu().numpy(), 'k*')
    # Plot posterior means as blue line
    ax.plot(test_X.cpu().numpy(), posterior.mean.cpu().numpy(), 'b')
    # Shade between the lower and upper confidence bounds
    ax.fill_between(test_X.cpu().numpy(), lower.cpu().numpy(), upper.cpu().numpy(), alpha=0.5)
ax.legend(['Observed Data', 'Mean', 'Confidence'])
plt.tight_layout()
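The shaded band is the mean plus or minus two marginal standard deviations at each test point. To our knowledge, this is what GPyTorch's `MultivariateNormal.confidence_region()` returns. The stdlib-only sketch below reproduces that arithmetic for plain lists of means and variances. The function name `confidence_band` is hypothetical.

```python
import math


def confidence_band(means, variances):
    """Hypothetical helper: (lower, upper) = mean -/+ 2 * std at each point."""
    stds = [math.sqrt(v) for v in variances]
    lower = [m - 2.0 * s for m, s in zip(means, stds)]
    upper = [m + 2.0 * s for m, s in zip(means, stds)]
    return lower, upper


lower_band, upper_band = confidence_band([0.0, 1.0], [0.25, 0.04])
# lower_band is approximately [-1.0, 0.6], upper_band approximately [1.0, 1.4]
print(lower_band, upper_band)
```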
It is simple to package up a custom optimizer loop like the one above and use it within Ax. As described in the Using BoTorch with Ax tutorial, this requires defining a custom model_constructor callable that can then be passed to the get_botorch factory function.
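def _get_and_fit_model(Xs, Ys, **kwargs):
    train_X, train_Y = Xs[0], Ys[0]
    model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
    # same softplus-based lower bound on the noise as above, needed for torch optimizers
    model.likelihood.noise_covar.register_constraint("raw_noise", GreaterThan(1e-5))
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to(train_X)
    model.train()
    # fall back to the settings used above if lr / epochs are not supplied
    optimizer = SGD([{'params': model.parameters()}], lr=kwargs.get("lr", 0.1))
    for epoch in range(kwargs.get("epochs", 150)):
        optimizer.zero_grad()
        output = model(train_X)
        loss = -mll(output, train_Y)
        loss.backward()
        optimizer.step()
    return model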
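One way to choose `lr` and `epochs` without relying on how `get_botorch` forwards extra keyword arguments is to bind them with `functools.partial` before handing the callable over as `model_constructor`. The sketch below uses a hypothetical stub, `get_and_fit_model_stub`, in place of `_get_and_fit_model` so it runs without BoTorch or Ax. The `Yvars` keyword stands in for any additional argument a caller might pass.

```python
from functools import partial


def get_and_fit_model_stub(Xs, Ys, **kwargs):
    """Hypothetical stand-in for _get_and_fit_model that only reports its settings."""
    return {"lr": kwargs.get("lr", 0.1), "epochs": kwargs.get("epochs", 150)}


# bind the fitting hyperparameters once; the result still accepts Xs, Ys and
# any further keyword arguments supplied by the caller
model_constructor = partial(get_and_fit_model_stub, lr=0.05, epochs=300)

settings = model_constructor(Xs=[None], Ys=[None], Yvars=[None])
print(settings)
```

With the real function, the same idea would read `partial(_get_and_fit_model, lr=0.05, epochs=300)`. Keyword arguments passed at call time override the ones bound by `partial`.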