Source code for botorch.acquisition.proximal
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
A wrapper around AcquisitionFunctions to add proximal weighting of the
acquisition function.
"""
from __future__ import annotations
import torch
from botorch.acquisition import AcquisitionFunction
from botorch.exceptions.errors import UnsupportedError
from botorch.utils import t_batch_mode_transform
from torch import Tensor
from torch.nn import Module
class ProximalAcquisitionFunction(AcquisitionFunction):
    """A wrapper around AcquisitionFunctions to add proximal weighting of the
    acquisition function. Acquisition function is weighted via a squared exponential
    centered at the last training point, with varying lengthscales corresponding to
    `proximal_weights`. Can only be used with acquisition functions based on single
    batch models.

    Small values of `proximal_weights` correspond to strong biasing towards recently
    observed points, which smooths optimization with a small potential decrease in
    convergence rate.

    Example:
        >>> model = SingleTaskGP(train_X, train_Y)
        >>> EI = ExpectedImprovement(model, best_f=0.0)
        >>> proximal_weights = torch.ones(d)
        >>> EI_proximal = ProximalAcquisitionFunction(EI, proximal_weights)
        >>> eip = EI_proximal(test_X)
    """

    def __init__(
        self,
        acq_function: AcquisitionFunction,
        proximal_weights: Tensor,
    ) -> None:
        r"""Derived Acquisition Function weighted by proximity to recently
        observed point.

        Args:
            acq_function: The base acquisition function, operating on input tensors
                of feature dimension `d`.
            proximal_weights: A `d` dim tensor used to bias locality
                along each axis.

        Raises:
            UnsupportedError: If `acq_function.X_pending` is not None, if the
                model of `acq_function` has no `train_inputs` attribute, or if
                the model does not pass the single batch check.
            ValueError: If `proximal_weights` is not one dimensional or its
                length differs from the feature dimension of the last
                training input.
        """
        # Only `Module.__init__` is invoked here (not `super().__init__`); the
        # model is always reached through the wrapped acquisition function.
        Module.__init__(self)
        self.acq_func = acq_function

        # Pending points are rejected; when the attribute exists it is mirrored
        # on the wrapper (it is necessarily None past the check).
        if hasattr(acq_function, "X_pending"):
            if acq_function.X_pending is not None:
                raise UnsupportedError(
                    "Proximal acquisition function requires `X_pending` to be None."
                )
            self.X_pending = acq_function.X_pending

        self.register_buffer("proximal_weights", proximal_weights)

        # check model for train_inputs and single batch
        model = self.acq_func.model
        if not hasattr(model, "train_inputs"):
            raise UnsupportedError(
                "Acquisition function model must have `train_inputs`."
            )

        if (
            model.batch_shape != torch.Size([])
            and model.train_inputs[0].shape[1] != 1
        ):
            raise UnsupportedError(
                "Proximal acquisition function requires a single batch model"
            )

        # check to make sure that weights match the training data shape
        feature_dim = model.train_inputs[0][-1].shape[-1]
        if (
            self.proximal_weights.dim() != 1
            or self.proximal_weights.shape[0] != feature_dim
        ):
            raise ValueError(
                "`proximal_weights` must be a one dimensional tensor with "
                "same feature dimension as model."
            )

    @t_batch_mode_transform(expected_q=1, assert_output_shape=False)
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate base acquisition function with proximal weighting.

        Args:
            X: Input tensor of feature dimension `d` .

        Returns:
            Base acquisition function evaluated on tensor `X` multiplied by proximal
            weighting.
        """
        # The weighting is centered on the last entry of the first training
        # input tensor, reshaped to `1 x 1 x d` so it broadcasts against `X`.
        last_X = self.acq_func.model.train_inputs[0][-1].reshape(1, 1, -1)
        diff = X - last_X

        # Squared norm of the per-axis scaled offset, reduced over the feature
        # dimension.
        M = torch.linalg.norm(diff / self.proximal_weights, dim=-1) ** 2
        proximal_acq_weight = torch.exp(-0.5 * M)

        # Drop only the trailing dimension rather than flattening, so that the
        # weight keeps every leading batch dimension of `X` and lines up with
        # the base acquisition value even for multiple batch dimensions.
        # NOTE(review): assumes the decorator hands over `X` as
        # `batch_shape x 1 x d` (expected_q=1) — confirm against botorch.
        return self.acq_func(X) * proximal_acq_weight.squeeze(-1)