# Source code for botorch.acquisition.proximal
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
A wrapper around AcquisitionFunctions to add proximal weighting of the
acquisition function.
"""
from __future__ import annotations
from typing import Optional
import torch
from botorch.acquisition import AcquisitionFunction
from botorch.exceptions.errors import UnsupportedError
from botorch.models import ModelListGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.model import Model
from botorch.models.transforms.input import InputTransform
from botorch.utils import t_batch_mode_transform
from torch import Tensor
from torch.nn import Module
class ProximalAcquisitionFunction(AcquisitionFunction):
    r"""Acquisition function wrapper that adds a proximal (locality) weighting.

    The value of the wrapped acquisition function is multiplied by a squared
    exponential factor centered at the last training point of the underlying
    model. `proximal_weights` plays the role of a per-dimension lengthscale of
    that factor: the difference to the last training point is divided by
    `proximal_weights` before its squared norm is taken. Can only be used with
    acquisition functions based on single batch models.

    Small values of `proximal_weights` correspond to a strong bias towards the
    most recently observed point, which smooths optimization at the cost of a
    small potential decrease in convergence rate.

    Example:
        >>> model = SingleTaskGP(train_X, train_Y)
        >>> EI = ExpectedImprovement(model, best_f=0.0)
        >>> proximal_weights = torch.ones(d)
        >>> EI_proximal = ProximalAcquisitionFunction(EI, proximal_weights)
        >>> eip = EI_proximal(test_X)
    """

    def __init__(
        self,
        acq_function: AcquisitionFunction,
        proximal_weights: Tensor,
        transformed_weighting: bool = True,
    ) -> None:
        r"""Build an acquisition function weighted by proximity to the most
        recently observed point.

        Args:
            acq_function: The base acquisition function, operating on input
                tensors of feature dimension `d`.
            proximal_weights: A `d` dim tensor used to bias locality
                along each axis.
            transformed_weighting: If True, the proximal weights are applied in
                the transformed input space given by
                `acq_function.model.input_transform` (if available), otherwise
                proximal weights are applied in real input space.

        Raises:
            UnsupportedError: If `acq_function` has an `X_pending` attribute
                that is not None, or if the model fails `_validate_model`.
            ValueError: If `proximal_weights` fails `_validate_model`.
        """
        # Only the torch Module initializer is invoked here; the
        # AcquisitionFunction initializer is deliberately not called.
        Module.__init__(self)

        self.acq_func = acq_function
        model = self.acq_func.model

        # Pending points are not supported; mirror the (None) attribute only
        # when the wrapped acquisition function exposes it.
        if hasattr(acq_function, "X_pending"):
            pending = acq_function.X_pending
            if pending is not None:
                raise UnsupportedError(
                    "Proximal acquisition function requires `X_pending` to be None."
                )
            self.X_pending = pending

        self.register_buffer("proximal_weights", proximal_weights)
        self.register_buffer(
            "transformed_weighting", torch.tensor(transformed_weighting)
        )

        # Validation runs last, after the buffers have been registered.
        _validate_model(model, proximal_weights)

    @t_batch_mode_transform(expected_q=1, assert_output_shape=False)
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the base acquisition function with proximal weighting.

        Args:
            X: Input tensor of feature dimension `d`.

        Returns:
            The base acquisition function evaluated on `X`, multiplied by the
            flattened proximal weighting `exp(-0.5 * ||diff / w||^2)`.
        """
        base_model = self.acq_func.model
        reference_inputs = base_model.train_inputs[0]

        # For a ModelListGP, use the inputs and the model of its first entry.
        if isinstance(base_model, ModelListGP):
            reference_inputs = reference_inputs[0]
            base_model = base_model.models[0]

        # For a batched model with several outputs, use the first copy of the
        # training inputs.
        has_multiple_outputs = (
            isinstance(base_model, BatchedMultiOutputGPyTorchModel)
            and base_model.num_outputs > 1
        )
        if has_multiple_outputs:
            reference_inputs = reference_inputs[0]

        transform = _get_input_transform(base_model)
        anchor = reference_inputs[-1].reshape(1, 1, -1)

        if transform is None:
            # no input transform: compare directly
            offset = X - anchor
        elif self.transformed_weighting:
            # weighting in transformed space: transform X before comparing
            offset = transform.transform(X) - anchor
        else:
            # weighting in real space: un-transform the last observed point
            offset = X - transform.untransform(anchor)

        scaled_norm = torch.linalg.norm(offset / self.proximal_weights, dim=-1)
        weight = torch.exp(-0.5 * scaled_norm**2)
        return self.acq_func(X) * weight.flatten()
def _validate_model(model: Model, proximal_weights: Tensor) -> None:
    r"""Check that a model is compatible with proximal weighting.

    Args:
        model: Model associated with the base acquisition function.
        proximal_weights: A `d` dim tensor used to bias locality
            along each axis.

    Raises:
        UnsupportedError: If the model has no `train_inputs` attribute, if the
            sub-models of a `ModelListGP` have unequal training inputs or
            input transforms that do not compare equal, or if the single
            batch check on the model fails.
        ValueError: If `proximal_weights` is not one dimensional or its length
            differs from the last dimension of the training inputs.
    """
    if not hasattr(model, "train_inputs"):
        raise UnsupportedError("Acquisition function model must have `train_inputs`.")

    if isinstance(model, ModelListGP):
        # Every sub-model is compared against the first one: same training
        # inputs and an input transform that compares equal.
        train_inputs = model.train_inputs[0][0]
        reference_transform = _get_input_transform(model.models[0])
        for idx, sub_inputs in enumerate(model.train_inputs):
            if not torch.equal(train_inputs, sub_inputs[0]):
                raise UnsupportedError(
                    "Proximal acquisition function does not support unequal "
                    "training inputs"
                )
            if not reference_transform == _get_input_transform(model.models[idx]):
                raise UnsupportedError(
                    "Proximal acquisition function does not support non-identical "
                    "input transforms"
                )
    else:
        # any non-ModelListGP model
        train_inputs = model.train_inputs[0]

    # A model with a non-empty batch shape is rejected unless dimension 1 of
    # its training inputs has size 1.
    has_batch_shape = model.batch_shape != torch.Size([])
    if has_batch_shape and train_inputs.shape[1] != 1:
        raise UnsupportedError(
            "Proximal acquisition function requires a single batch model"
        )

    # The weights must be 1-D and match the feature dimension of the inputs.
    weights_match = (
        len(proximal_weights.shape) == 1
        and proximal_weights.shape[0] == train_inputs.shape[-1]
    )
    if not weights_match:
        raise ValueError(
            "`proximal_weights` must be a one dimensional tensor with "
            "same feature dimension as model."
        )
def _get_input_transform(model: Model) -> Optional[InputTransform]:
"""get input transform if defined"""
try:
return model.input_transform
except AttributeError:
return None