Source code for botorch.acquisition.fixed_feature

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
A wrapper around AquisitionFunctions to fix certain features for optimization.
This is useful e.g. for performing contextual optimization.
"""

from __future__ import annotations

from numbers import Number
from typing import List, Sequence, Union

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from torch import Tensor
from torch.nn import Module


[docs]class FixedFeatureAcquisitionFunction(AcquisitionFunction): """A wrapper around AquisitionFunctions to fix a subset of features. Example: >>> model = SingleTaskGP(train_X, train_Y) # d = 5 >>> qEI = qExpectedImprovement(model, best_f=0.0) >>> columns = [2, 4] >>> values = X[..., columns] >>> qEI_FF = FixedFeatureAcquisitionFunction(qEI, 5, columns, values) >>> qei = qEI_FF(test_X) # d' = 3 """ def __init__( self, acq_function: AcquisitionFunction, d: int, columns: List[int], values: Union[Tensor, Sequence[Union[Tensor, float]]], ) -> None: r"""Derived Acquisition Function by fixing a subset of input features. Args: acq_function: The base acquisition function, operating on input tensors `X_full` of feature dimension `d`. d: The feature dimension expected by `acq_function`. columns: `d_f < d` indices of columns in `X_full` that are to be fixed to the provided values. values: The values to which to fix the columns in `columns`. Either a full `batch_shape x q x d_f` tensor of values (if values are different for each of the `q` input points), or an array-like of values that is broadcastable to the input across `t`-batch and `q`-batch dimensions, e.g. a list of length `d_f` if values are the same across all `t` and `q`-batch dimensions, or a combination of `Tensor`s and numbers which can be broadcasted to form a tensor with trailing dimension size of `d_f`. """ Module.__init__(self) self.acq_func = acq_function dtype = torch.float device = torch.device("cpu") self.d = d if isinstance(values, Tensor): new_values = values.detach().clone() else: new_values = [] for value in values: if isinstance(value, Number): new_values.append(torch.tensor([float(value)])) else: # if any value uses double, use double for all values # likewise if any value uses cuda, use cuda for all values dtype = value.dtype if value.dtype == torch.double else dtype device = value.device if value.device.type == "cuda" else device new_values.append(value.detach().clone()) # move all values to same device for i, val in enumerate(new_values): new_values[i] = val.to(dtype=dtype, device=device) # There are 3 cases for when `values` is a `Sequence`. # 1) `values` == list of floats as earlier. # 2) `values` == combination of floats and `Tensor`s. # 3) `values` == a list of `Tensor`s. # For 1), the below step creates a vector of length `len(values)` # For 2), the below step creates a `Tensor` of shape `batch_shape x q x d_f` # with the broadcasting functionality. # For 3), this is simply a concatenation, yielding a `Tensor` with the # same shape as in 2). # The key difference arises when `_construct_X_full` is invoked. # In 1), the expansion (`self.values.expand`) will expand the `Tensor` to # size `batch_shape x q x d_f`. # In 2) and 3), this expansion is a no-op because they are already of the # required size. However, 2) and 3) _cannot_ support varying `batch_shape`, # which means that all calls to `FixedFeatureAcquisitionFunction` have # to have the same size throughout when `values` contains a `Tensor`. # This is consistent with the scenario when a singular `Tensor` is passed # as the `values` argument. new_values = torch.cat(torch.broadcast_tensors(*new_values), dim=-1) self.register_buffer("values", new_values) # build selector for _construct_X_full self._selector = [] idx_X, idx_f = 0, d - new_values.shape[-1] for i in range(self.d): if i in columns: self._selector.append(idx_f) idx_f += 1 else: self._selector.append(idx_X) idx_X += 1
[docs] def forward(self, X: Tensor): r"""Evaluate base acquisition function under the fixed features. Args: X: Input tensor of feature dimension `d' < d` such that `d' + d_f = d`. Returns: Base acquisition function evaluated on tensor `X_full` constructed by adding `values` in the appropriate places (see `_construct_X_full`). """ X_full = self._construct_X_full(X) return self.acq_func(X_full)
def _construct_X_full(self, X: Tensor) -> Tensor: r"""Constructs the full input for the base acquisition function. Args: X: Input tensor with shape `batch_shape x q x d'` such that `d' + d_f = d`. Returns: Tensor `X_full` of shape `batch_shape x q x d`, where `X_full[..., i] = values[..., i]` if `i in columns`, and `X_full[..., i] = X[..., j]`, with `j = i - sum_{l<=i} 1_{l in fixed_colunns}`. """ d_prime, d_f = X.shape[-1], self.values.shape[-1] if d_prime + d_f != self.d: raise ValueError( f"Feature dimension d' ({d_prime}) of input must be " f"d - d_f ({self.d - d_f})." ) # concatenate values to the end values = self.values.to(X).expand(*X.shape[:-1], d_f) X_perm = torch.cat([X, values], dim=-1) # now select the appropriate column order return X_perm[..., self._selector]