
Variational Bayesian Last Layers for Bayesian Optimization

  • Contributors: brunzema
  • Last updated: Feb 13, 2025
  • BoTorch version: 0.9.6(dev), commit hash: dccda59d8ef51d8074de82fdb5614bad2db0ee96

In this notebook, we will demonstrate how to use variational Bayesian last layers (VBLLs) for Bayesian optimization [1, 2].

[1] P. Brunzema, M. Jordahn, J. Willes, S. Trimpe, J. Snoek, J. Harrison. Bayesian Optimization via Continual Variational Last Layer Training. International Conference on Learning Representations (ICLR), 2025.

[2] J. Harrison, J. Willes, J. Snoek. Variational Bayesian Last Layers. International Conference on Learning Representations (ICLR), 2024.

Introduction to the VBLL Model

Bayesian optimization (BO) relies on surrogate models that provide uncertainty-aware predictions. Gaussian processes are usually the go-to choice due to their analytical tractability, but they can be limiting for non-Euclidean input spaces, where a careful choice of the kernel is crucial. Here, Bayesian neural networks are promising, as they learn the relevant correlations (features) directly from data.

Variational Bayesian Last Layer (VBLL) [2] models provide a practical and scalable way to approximate Bayesian inference in neural networks. Instead of placing a prior over all network weights, VBLL models represent uncertainty only over the parameters of the last layer while keeping the feature extractor deterministic. This setup retains the expressive power of deep learning while maintaining well-calibrated uncertainty estimates.

In VBLL, we model the output as a generalized linear model with learned features $\phi_{\theta}$ as:

$$ y = \mathbf{w}^\top \phi_{\theta}(\mathbf{x}) + \varepsilon $$

where $\varepsilon \sim \mathcal{N}(0, \sigma^2)$. VBLL models place a variational posterior distribution on the weights, $\mathbf{w} \sim q(\mathbf{w})$ with $q(\mathbf{w}) = \mathcal{N}(\bar{\mathbf{w}}, S)$, where $\bar{\mathbf{w}}$ is the mean and $S$ the full covariance of a multivariate normal distribution (for more information, see [1, 2]). Through conjugacy, we obtain the posterior predictive

$$ p(y \mid \mathbf{x}, \mathcal{D}_T, \theta) = \mathcal{N}\big(\bar{\mathbf{w}}_T^\top \phi_{\theta}(\mathbf{x}),\, \phi_{\theta}(\mathbf{x})^\top S_T \phi_{\theta}(\mathbf{x}) + \sigma^{2}\big). $$

Here, $\bar{\mathbf{w}}_T$ and $S_T$ denote the variational parameters after training on the data $\mathcal{D}_T$. This predictive can then be used in downstream tasks such as BO.
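The predictive only needs the features and the two variational parameters. As a minimal NumPy sketch, assuming the features $\phi_{\theta}(\mathbf{x})$ of $n$ inputs are already stacked row-wise into an array `phi` (the function `vbll_predictive` and its argument names are illustrative, not part of the BoTorch API):

```python
import numpy as np


def vbll_predictive(phi, w_bar, S, sigma2):
    """Gaussian posterior predictive of a Bayesian last layer (illustrative sketch).

    phi:    (n, M) array, row i holds the features phi_theta(x_i)
    w_bar:  (M,) mean of the variational posterior q(w)
    S:      (M, M) covariance of q(w)
    sigma2: observation noise variance sigma^2
    """
    mean = phi @ w_bar
    # phi_i^T S phi_i for every row i, without forming the n x n matrix
    var = np.einsum("nm,mk,nk->n", phi, S, phi) + sigma2
    return mean, var


phi = np.array([[1.0, 0.0], [0.5, 2.0]])
w_bar = np.array([1.0, -1.0])
S = 0.1 * np.eye(2)
mean, var = vbll_predictive(phi, w_bar, S, sigma2=0.01)
# mean is approximately [1.0, -1.5], var approximately [0.11, 0.435]
```

Computing only the diagonal terms $\phi_{\theta}(\mathbf{x}_i)^\top S \phi_{\theta}(\mathbf{x}_i)$ keeps the cost linear in $n$; the full predictive covariance across inputs would instead be `phi @ S @ phi.T` plus $\sigma^2$ on the diagonal.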

VBLLs on Toy Data

Before using the VBLL model for Bayesian optimization, let's look at the model interface. The current implementation allows passing a full backbone to the model. This backbone can be any neural network architecture that outputs a feature vector of dimension $M$ (the input to the VBLL head), such as a pretrained MLP, a GNN, or another problem-specific architecture. Note, however, that the current formulation of the acquisition functions only supports an $\mathbb{R}^d$ or discrete input space. If no backbone is provided, a standard MLP is created. For training, [1] discusses continual learning approaches, which will be added to the implementation; we do not go into detail here. See the docstring of model.fit for the options for configuring the optimizer and the training.

Below, we use the model in a 1D regression example.

import matplotlib.pyplot as plt
import torch
from botorch_community.models.vblls import VBLLModel

torch.set_default_dtype(torch.float64)
torch.manual_seed(42)


def objective(x, noise=True):
    out = torch.sin(x * 5)
    if noise:
        out += 0.05 * torch.randn_like(x)
    return out


X = torch.tensor([[0.0582], [0.0629], [0.1236], [0.0526], [0.5262], [0.9552]])
Y = objective(X)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
ax.set_title("Some observations of a 1D function")
ax.scatter(X, Y, c="k", label="Observations")
ax.legend()
plt.show()

model = VBLLModel(
    in_features=1,
    hidden_features=64,
    num_layers=3,
    out_features=1,
)

# let's print the model: we can see the MLP backbone and the VBLL regression head
print(model)

# fit the model on the data; it is also possible to specify the optimizer
# through `optimization_settings` (see docstring)
model.fit(X, Y)
Output:
VBLLNetwork(
  (activation): ELU(alpha=1.0)
  (backbone): Sequential(
    (0): Linear(in_features=1, out_features=64, bias=True)
    (1): ELU(alpha=1.0)
    (2): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ELU(alpha=1.0)
    )
    (3): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ELU(alpha=1.0)
    )
    (4): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ELU(alpha=1.0)
    )
  )
  (head): Regression()
)
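The printed backbone maps each input to a 64-dimensional feature vector through four Linear layers, each followed by an ELU activation, and the Regression() head places the Bayesian last layer on top of these features. A minimal NumPy sketch of such a deterministic feature map, assuming randomly initialized weights rather than the trained weights of `model` (the helpers `elu` and `mlp_features` are hypothetical):

```python
import numpy as np


def elu(z, alpha=1.0):
    # ELU(z) = z for z > 0, alpha * (exp(z) - 1) otherwise;
    # np.minimum avoids overflow warnings in the unused branch
    return np.where(z > 0, z, alpha * np.expm1(np.minimum(z, 0.0)))


def mlp_features(x, weights, biases):
    """Deterministic feature map phi_theta: each layer is Linear followed by ELU."""
    h = x
    for W, b in zip(weights, biases):
        h = elu(h @ W.T + b)
    return h


rng = np.random.default_rng(0)
in_features, hidden, num_layers = 1, 64, 3
# input layer plus num_layers hidden layers, mirroring the printed structure
shapes = [(hidden, in_features)] + [(hidden, hidden)] * num_layers
weights = [rng.normal(scale=1.0 / np.sqrt(s[1]), size=s) for s in shapes]
biases = [np.zeros(s[0]) for s in shapes]

x = np.linspace(0.0, 1.0, 5).reshape(-1, 1)
phi = mlp_features(x, weights, biases)
print(phi.shape)  # (5, 64)
```

The Bayesian last layer then acts on these features exactly as in the predictive formula of the introduction: the predictive mean is a linear function of `phi`, and only the last-layer weights carry uncertainty.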

Now we can use the model in the same way as a Gaussian process: the code below is exactly the code one would use with a GPyTorch model.

test_X = torch.linspace(0, 1, 100)
with torch.no_grad():
    posterior = model.posterior(test_X.view(-1, 1))
    mean = posterior.mean.squeeze()
    std = posterior.variance.sqrt().squeeze()

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
ax.set_title("VBLL Predictive on Toy Data")
ax.scatter(X, Y, c="k", label="Observations")
ax.plot(test_X, mean, label="Posterior predictive", color="tab:blue")
ax.fill_between(test_X, mean - 2 * std, mean + 2 * std, alpha=0.2, color="tab:blue")
ax.set_ylim(-1.5, 1.5)
ax.legend()
plt.show()

Thompson sampling with VBLLs

VBLLs yield a Gaussian predictive distribution, and thus most acquisition functions that are straightforward to compute for Gaussian processes (GPs) are also straightforward for VBLLs. Moreover, because VBLLs are parametric, they are especially well suited for Thompson sampling: a posterior sample is an explicit function of the input.
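Since the predictive at each $\mathbf{x}$ is Gaussian with mean $\mu(\mathbf{x})$ and standard deviation $s(\mathbf{x})$, expected improvement over the best observed value $f^*$ has the same closed form as for a GP, $\mathrm{EI}(\mathbf{x}) = s\,(z\,\Phi(z) + \varphi(z))$ with $z = (\mu - f^*)/s$. A minimal standard-library sketch (the function name is illustrative; BoTorch's LogExpectedImprovement, used later in this notebook, computes the logarithm in a numerically more stable way than taking `math.log` of this value):

```python
import math


def expected_improvement(mu, sigma, best_f):
    """Closed-form EI for maximization under a Gaussian predictive N(mu, sigma^2)."""
    z = (mu - best_f) / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal density
    cdf = 0.5 * math.erfc(-z / math.sqrt(2.0))  # standard normal CDF
    return sigma * (z * cdf + pdf)


# At mu == best_f, z = 0 and EI reduces to sigma / sqrt(2 * pi)
ei = expected_improvement(mu=0.5, sigma=1.0, best_f=0.5)
print(round(ei, 4))  # 0.3989
```

Far from the incumbent, this EI value underflows to zero, so a naive `math.log(ei)` breaks down; avoiding that failure mode is the motivation for the log formulation in BoTorch.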

For a Thompson sample, we simply sample from the variational posterior of $\mathbf{w}$ at iteration $t$ and then construct a sample from the predictive $\hat{f}$ as a generalized linear model:

$$
\begin{aligned}
\text{First:} &\quad \hat{\mathbf{w}} \sim q(\mathbf{w}) \\
\text{Second:} &\quad \hat{f}(\mathbf{x}) \coloneqq \hat{\mathbf{w}}^\top \phi_{\theta}(\mathbf{x}).
\end{aligned}
$$

We can directly leverage this in BO to choose the next query location as $\mathbf{x}_{t+1} = \arg\max_{\mathbf{x} \in \mathcal{X}} \hat{f}(\mathbf{x})$. Since the sample is linear in the features and the feature extractor $\phi_{\theta}$ is a differentiable network, gradients of $\hat{f}$ with respect to $\mathbf{x}$ are available through automatic differentiation, so the sample path can be maximized numerically with gradient-based optimizers.
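The two steps and the gradient-based maximization can be sketched in NumPy. This is a minimal sketch under explicit assumptions: a hypothetical random-Fourier-feature map stands in for the trained network $\phi_{\theta}$ so that the gradient can be written by hand (with the real model, automatic differentiation would provide it), and the variational parameters are made up. The names `features`, `f_hat` and `grad_f_hat` are illustrative and this is not the BLLMaxPosteriorSampling implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
M = 20  # feature dimension

# Hypothetical stand-in for a trained phi_theta: random Fourier features on [0, 1]
omega = rng.normal(scale=5.0, size=M)
phase = rng.uniform(0.0, 2.0 * np.pi, size=M)


def features(x):
    """phi(x) for a 1D array x of inputs, shape (n, M)."""
    return np.cos(np.outer(x, omega) + phase)


def features_grad(x):
    """Elementwise derivative d phi(x) / dx, shape (n, M)."""
    return -np.sin(np.outer(x, omega) + phase) * omega


# Hypothetical variational posterior q(w) = N(w_bar, S)
w_bar = rng.normal(size=M) / np.sqrt(M)
S = 0.05 * np.eye(M)

# First: draw w_hat ~ q(w) using the Cholesky factor of S
w_hat = w_bar + np.linalg.cholesky(S) @ rng.standard_normal(M)


# Second: the sample path f_hat(x) = w_hat^T phi(x) and its gradient in x
def f_hat(x):
    return features(x) @ w_hat


def grad_f_hat(x):
    return features_grad(x) @ w_hat


# Maximize f_hat on [0, 1]: projected gradient ascent from 5 random restarts
x0 = rng.uniform(0.0, 1.0, size=5)
x = x0.copy()
for _ in range(500):
    x = np.clip(x + 1e-3 * grad_f_hat(x), 0.0, 1.0)

# keep the best point seen, including the starting points as a safeguard
candidates = np.concatenate([x0, x])
x_next = candidates[np.argmax(f_hat(candidates))]
print(f"next query: x = {x_next:.3f}")
```

Once $\hat{\mathbf{w}}$ is drawn, $\hat{f}$ is an ordinary deterministic function, so any multi-start optimizer can be used; the restarts guard against local maxima of the sample path.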

def plot_model(
    model,
    x,
    y,
    new_data=None,
    show_objective=False,
    show_samples=False,
    title="VBLL Predictive on Toy Data with Thompson Samples",
):
    x_test = torch.linspace(0, 1, 100)
    with torch.no_grad():
        posterior = model.posterior(x_test.view(-1, 1))
        mean = posterior.mean.squeeze()
        std = posterior.variance.sqrt().squeeze()

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
    ax.set_title(title)
    ax.scatter(x, y, c="k", label="Observations")

    if new_data is not None:
        x_new, y_new = new_data
        ax.scatter(x_new, y_new, c="tab:red", label="New observations")

    if show_objective:
        ax.plot(x_test, objective(x_test, noise=False), label="Objective", color="k")

    ax.plot(x_test, mean, label="Posterior predictive", color="tab:blue")

    # Posterior samples
    if show_samples:
        torch.manual_seed(0)
        with torch.no_grad():
            for i in range(5):
                # a sample is essentially a standard MLP (nn.Module)
                ts_sample = model.sample()

                ts_mean = ts_sample(x_test.view(-1, 1)).squeeze()
                ax.plot(
                    x_test,
                    ts_mean,
                    color="tab:grey",
                    alpha=0.5,
                    label="Thompson samples" if i == 0 else None,
                )

                # show max of the samples
                max_idx = ts_mean.argmax()
                ax.scatter(
                    x_test[max_idx],
                    ts_mean[max_idx],
                    color="tab:grey",
                    marker="x",
                    s=100,
                    label="Max of Thompson sample" if i == 0 else None,
                )
    ax.set_ylim(-1.5, 1.5)
    ax.fill_between(x_test, mean - 2 * std, mean + 2 * std, alpha=0.2, color="tab:blue")
    ax.legend()
    plt.show()


plot_model(model, X, Y, show_samples=True)

Bayesian optimization with VBLLs

Next, let's use VBLLs for Bayesian optimization of the toy function above using Thompson sampling. Setting up Thompson sampling is similar to using the standard MaxPosteriorSampling in BoTorch. We also include logEI as an acquisition function to show that the model can be used just as easily with other acquisition functions. The plot_bo_step flag controls whether each optimization step is visualized.
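When the candidates form a finite set, batch Thompson sampling in the spirit of MaxPosteriorSampling reduces to drawing one weight sample per batch element and picking the candidate that maximizes each sample path. A minimal NumPy sketch, assuming precomputed candidate features and a given $q(\mathbf{w}) = \mathcal{N}(\bar{\mathbf{w}}, S)$ (the helper `thompson_batch` and the toy quadratic features are hypothetical, not BoTorch code):

```python
import numpy as np


def thompson_batch(phi_cand, w_bar, S, batch_size, rng):
    """Indices of the candidates picked by batch Thompson sampling (illustrative).

    phi_cand: (n, M) features of the n candidate points
    w_bar, S: mean (M,) and covariance (M, M) of q(w)
    """
    # one independent weight sample per batch element: (batch_size, M)
    w_samples = rng.multivariate_normal(w_bar, S, size=batch_size)
    # each row is one sample path evaluated on all candidates: (batch_size, n)
    f_samples = w_samples @ phi_cand.T
    return np.argmax(f_samples, axis=1)


rng = np.random.default_rng(1)
x_cand = np.linspace(0.0, 1.0, 101)
# toy features [1, x, x^2]; the mean path 2x - 2x^2 peaks at x = 0.5
phi_cand = np.stack([np.ones_like(x_cand), x_cand, x_cand**2], axis=1)
w_bar = np.array([0.0, 2.0, -2.0])
S = 0.01 * np.eye(3)
idx = thompson_batch(phi_cand, w_bar, S, batch_size=4, rng=rng)
print(x_cand[idx])  # four picks, each the maximizer of one sample path
```

Since the samples are independent, the same candidate can be picked more than once; BoTorch's MaxPosteriorSampling exposes a replacement option for this case.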

from botorch_community.acquisition.bll_thompson_sampling import BLLMaxPosteriorSampling
from botorch.acquisition.analytic import LogExpectedImprovement
from botorch.optim import optimize_acqf

batch_size = 1
max_iter = 10
plot_bo_step = True

acq_functions = ["TS", "logEI"]  # "logEI" and/or "TS"

# let's define the optimizer settings (the values below are the default values)
optimizer_settings = {
    "num_epochs": 10_000,  # number of epochs
    "freeze_backbone": False,  # whether to freeze the backbone during training
    "patience": 100,  # patience for early stopping
    "batch_size": 32,  # mini-batch size
    "optimizer": torch.optim.AdamW,  # optimizer
    "wd": 1e-4,  # weight decay
    "lr": 1e-3,  # learning rate
    "clip_val": 1.0,  # gradient clipping value
}

results = {}
for acq in acq_functions:
    torch.manual_seed(42)

    X_bo = torch.rand(2, 1)
    Y_bo = objective(X_bo)

    for iteration in range(1, max_iter + 1):
        print(
            f"Acquisition function {acq.upper()}\t Iteration {iteration})"
            f"\t Best Value: {Y_bo.max().item()}"
        )

        # initialize the model
        model = VBLLModel(
            in_features=1,
            hidden_features=64,
            num_layers=3,
            out_features=1,
        )
        model.fit(X_bo, Y_bo, optimization_settings=optimizer_settings)

        if acq == "TS":
            # Thompson sampling as acquisition function
            thompson_sampling = BLLMaxPosteriorSampling(
                model=model,
                num_restarts=5,
            )
            X_next = thompson_sampling(num_samples=batch_size)

        elif acq == "logEI":
            bounds = torch.tensor([[0.0] * 1, [1.0] * 1])
            # (log) Expected Improvement as acquisition function
            acq_func = LogExpectedImprovement(model, best_f=Y_bo.max())
            X_next, _ = optimize_acqf(
                acq_function=acq_func,
                bounds=bounds,
                q=1,
                num_restarts=5,
                raw_samples=20,
            )

        # evaluate the objective
        Y_next = objective(X_next)

        # plot the model, the new observation and the objective
        if plot_bo_step:
            plot_model(
                model,
                X_bo,
                Y_bo,
                new_data=(X_next, Y_next),
                show_objective=True,
                title=f"Iteration {iteration} | Acquisition function: {acq}",
            )

        # update the data
        X_bo = torch.cat([X_bo, X_next])
        Y_bo = torch.cat([Y_bo, Y_next])

    results[acq] = Y_bo
Output:
Acquisition function TS	 Iteration 1)	 Best Value: 0.3209060289482761
Acquisition function TS	 Iteration 2)	 Best Value: 0.3209060289482761
Acquisition function TS	 Iteration 3)	 Best Value: 0.3209060289482761
Acquisition function TS	 Iteration 4)	 Best Value: 1.0017157572316162
Acquisition function TS	 Iteration 5)	 Best Value: 1.0017157572316162
Acquisition function TS	 Iteration 6)	 Best Value: 1.04615523378849
Acquisition function TS	 Iteration 7)	 Best Value: 1.04615523378849
Acquisition function TS	 Iteration 8)	 Best Value: 1.04615523378849
Acquisition function TS	 Iteration 9)	 Best Value: 1.04615523378849
Acquisition function TS	 Iteration 10)	 Best Value: 1.04615523378849
Acquisition function LOGEI	 Iteration 1)	 Best Value: 0.3209060289482761
Acquisition function LOGEI	 Iteration 2)	 Best Value: 0.3209060289482761
Acquisition function LOGEI	 Iteration 3)	 Best Value: 0.3209060289482761
Acquisition function LOGEI	 Iteration 4)	 Best Value: 0.517998203362518
Acquisition function LOGEI	 Iteration 5)	 Best Value: 1.0645000326314427
Acquisition function LOGEI	 Iteration 6)	 Best Value: 1.0645000326314427
Acquisition function LOGEI	 Iteration 7)	 Best Value: 1.0645000326314427
Acquisition function LOGEI	 Iteration 8)	 Best Value: 1.0645000326314427
Acquisition function LOGEI	 Iteration 9)	 Best Value: 1.0645000326314427
Acquisition function LOGEI	 Iteration 10)	 Best Value: 1.0645000326314427

Let's plot the best observed value over the iterations. Note that this by no means shows that one acquisition function should be preferred over the other, since this is a single run with one random seed.

import numpy as np


def plot_performance(results, labels):
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))

    observations = [results[key] for key in results.keys()]
    for obs, label in zip(observations, labels):
        best_value = np.maximum.accumulate(obs.numpy())
        ax.plot(best_value, marker="", lw=3, label=label)

    ax.hlines(
        y=1.0,
        xmin=0,
        xmax=len(best_value) - 1,
        color="k",
        linestyle="--",
        label=r"Maximum of $\mathbb{E}[f(x)]$",
    )

    ax.set_xlabel("Iteration")
    ax.set_ylabel("Best value")
    ax.legend()
    plt.show()


plot_performance(results, ["VBLL (Thompson Sampling)", "VBLL (Expected Improvement)"])
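Because a single seed says little, a more careful comparison would rerun the loop above for several seeds and average the best-so-far curves. A minimal sketch of that aggregation, assuming the observed values of each run are stacked into a `(num_seeds, num_evals)` array (the helper `summarize_runs` and the toy numbers are hypothetical, not results from this notebook):

```python
import numpy as np


def summarize_runs(traces):
    """Mean and standard error of best-so-far curves over repeated runs.

    traces: (num_seeds, num_evals) array of observed objective values per run
    """
    best_so_far = np.maximum.accumulate(traces, axis=1)
    mean = best_so_far.mean(axis=0)
    stderr = best_so_far.std(axis=0, ddof=1) / np.sqrt(traces.shape[0])
    return mean, stderr


traces = np.array(
    [[0.1, 0.5, 0.3, 0.9], [0.4, 0.2, 0.8, 0.7], [0.0, 0.6, 0.6, 1.0]]
)
mean, stderr = summarize_runs(traces)
```

In the loop above, each entry of results is a tensor of shape (num_evals, 1), so one run would enter such an array as `results[acq].squeeze(-1).numpy()`.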