#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
A wrapper around AquisitionFunctions to fix certain features for optimization.
This is useful e.g. for performing contextual optimization.
"""
from __future__ import annotations
from numbers import Number
from typing import List, Sequence, Union
import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from torch import Tensor
from torch.nn import Module
[docs]class FixedFeatureAcquisitionFunction(AcquisitionFunction):
"""A wrapper around AquisitionFunctions to fix a subset of features.
Example:
>>> model = SingleTaskGP(train_X, train_Y) # d = 5
>>> qEI = qExpectedImprovement(model, best_f=0.0)
>>> columns = [2, 4]
>>> values = X[..., columns]
>>> qEI_FF = FixedFeatureAcquisitionFunction(qEI, 5, columns, values)
>>> qei = qEI_FF(test_X) # d' = 3
"""
def __init__(
self,
acq_function: AcquisitionFunction,
d: int,
columns: List[int],
values: Union[Tensor, Sequence[Union[Tensor, float]]],
) -> None:
r"""Derived Acquisition Function by fixing a subset of input features.
Args:
acq_function: The base acquisition function, operating on input
tensors `X_full` of feature dimension `d`.
d: The feature dimension expected by `acq_function`.
columns: `d_f < d` indices of columns in `X_full` that are to be
fixed to the provided values.
values: The values to which to fix the columns in `columns`. Either
a full `batch_shape x q x d_f` tensor of values (if values are
different for each of the `q` input points), or an array-like of
values that is broadcastable to the input across `t`-batch and
`q`-batch dimensions, e.g. a list of length `d_f` if values
are the same across all `t` and `q`-batch dimensions, or a
combination of `Tensor`s and numbers which can be broadcasted
to form a tensor with trailing dimension size of `d_f`.
"""
Module.__init__(self)
self.acq_func = acq_function
self.d = d
if isinstance(values, Tensor):
new_values = values.detach().clone()
else:
new_values = []
for value in values:
if isinstance(value, Number):
new_values.append(torch.tensor([float(value)]))
else:
new_values.append(value.detach().clone())
# There are 3 cases for when `values` is a `Sequence`.
# 1) `values` == list of floats as earlier.
# 2) `values` == combination of floats and `Tensor`s.
# 3) `values` == a list of `Tensor`s.
# For 1), the below step creates a vector of length `len(values)`
# For 2), the below step creates a `Tensor` of shape `batch_shape x q x d_f`
# with the broadcasting functionality.
# For 3), this is simply a concatenation, yielding a `Tensor` with the
# same shape as in 2).
# The key difference arises when `_construct_X_full` is invoked.
# In 1), the expansion (`self.values.expand`) will expand the `Tensor` to
# size `batch_shape x q x d_f`.
# In 2) and 3), this expansion is a no-op because they are already of the
# required size. However, 2) and 3) _cannot_ support varying `batch_shape`,
# which means that all calls to `FixedFeatureAcquisitionFunction` have
# to have the same size throughout when `values` contains a `Tensor`.
# This is consistent with the scenario when a singular `Tensor` is passed
# as the `values` argument.
new_values = torch.cat(torch.broadcast_tensors(*new_values), dim=-1)
self.register_buffer("values", new_values)
# build selector for _construct_X_full
self._selector = []
idx_X, idx_f = 0, d - new_values.shape[-1]
for i in range(self.d):
if i in columns:
self._selector.append(idx_f)
idx_f += 1
else:
self._selector.append(idx_X)
idx_X += 1
[docs] def forward(self, X: Tensor):
r"""Evaluate base acquisition function under the fixed features.
Args:
X: Input tensor of feature dimension `d' < d` such that `d' + d_f = d`.
Returns:
Base acquisition function evaluated on tensor `X_full` constructed
by adding `values` in the appropriate places (see
`_construct_X_full`).
"""
X_full = self._construct_X_full(X)
return self.acq_func(X_full)
def _construct_X_full(self, X: Tensor) -> Tensor:
r"""Constructs the full input for the base acquisition function.
Args:
X: Input tensor with shape `batch_shape x q x d'` such that
`d' + d_f = d`.
Returns:
Tensor `X_full` of shape `batch_shape x q x d`, where
`X_full[..., i] = values[..., i]` if `i in columns`,
and `X_full[..., i] = X[..., j]`, with
`j = i - sum_{l<=i} 1_{l in fixed_colunns}`.
"""
d_prime, d_f = X.shape[-1], self.values.shape[-1]
if d_prime + d_f != self.d:
raise ValueError(
f"Feature dimension d' ({d_prime}) of input must be "
f"d - d_f ({self.d - d_f})."
)
# concatenate values to the end
values = self.values.to(X).expand(*X.shape[:-1], d_f)
X_perm = torch.cat([X, values], dim=-1)
# now select the appropriate column order
return X_perm[..., self._selector]