#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Deterministic Models: Simple wrappers that allow the usage of deterministic
mappings via the BoTorch Model and Posterior APIs.
Deterministic models are useful for expressing known input-output relationships
within the BoTorch Model API. This is useful e.g. for multi-objective
optimization with known objective functions (e.g. the number of parameters of a
Neural Network in the context of Neural Architecture Search is usually a known
function of the architecture configuration), or to encode cost functions for
cost-aware acquisition utilities. Cost-aware optimization is desirable when
evaluations have a cost that is heterogeneous, either in the inputs `X` or in a
particular fidelity parameter that directly encodes the fidelity of the
observation. `GenericDeterministicModel` supports arbitrary deterministic
functions, while `AffineFidelityCostModel` is a particular cost model for
multi-fidelity optimization. Other use cases of deterministic models include
representing approximate GP sample paths, e.g. Matheron paths obtained
with `get_matheron_path_model`, which allows them to be substituted in acquisition
functions or in other places where a `Model` is expected.
"""
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Callable
import torch
from botorch.models.ensemble import EnsembleModel
from botorch.models.model import Model
from torch import Tensor
[docs]
class DeterministicModel(EnsembleModel):
"""Abstract base class for deterministic models."""
[docs]
@abstractmethod
def forward(self, X: Tensor) -> Tensor:
r"""Compute the (deterministic) model output at X.
Args:
X: A `batch_shape x n x d`-dim input tensor `X`.
Returns:
A `batch_shape x n x m`-dimensional output tensor (the outcome
dimension `m` must be explicit if `m=1`).
"""
pass # pragma: no cover
def _forward(self, X: Tensor) -> Tensor:
r"""Compatibilizes the `DeterministicModel` with `EnsemblePosterior`"""
return self.forward(X=X).unsqueeze(-3)
[docs]
class GenericDeterministicModel(DeterministicModel):
r"""A generic deterministic model constructed from a callable.
Example:
>>> f = lambda x: x.sum(dim=-1, keep_dims=True)
>>> model = GenericDeterministicModel(f)
"""
def __init__(self, f: Callable[[Tensor], Tensor], num_outputs: int = 1) -> None:
r"""
Args:
f: A callable mapping a `batch_shape x n x d`-dim input tensor `X`
to a `batch_shape x n x m`-dimensional output tensor (the
outcome dimension `m` must be explicit, even if `m=1`).
num_outputs: The number of outputs `m`.
"""
super().__init__()
self._f = f
self._num_outputs = num_outputs
[docs]
def subset_output(self, idcs: list[int]) -> GenericDeterministicModel:
r"""Subset the model along the output dimension.
Args:
idcs: The output indices to subset the model to.
Returns:
The current model, subset to the specified output indices.
"""
def f_subset(X: Tensor) -> Tensor:
return self._f(X)[..., idcs]
return self.__class__(f=f_subset, num_outputs=len(idcs))
[docs]
def forward(self, X: Tensor) -> Tensor:
r"""Compute the (deterministic) model output at X.
Args:
X: A `batch_shape x n x d`-dim input tensor `X`.
Returns:
A `batch_shape x n x m`-dimensional output tensor.
"""
return self._f(X)
[docs]
class AffineDeterministicModel(DeterministicModel):
r"""An affine deterministic model."""
def __init__(self, a: Tensor, b: Tensor | float = 0.01) -> None:
r"""Affine deterministic model from weights and offset terms.
A simple model of the form
y[..., m] = b[m] + sum_{i=1}^d a[i, m] * X[..., i]
Args:
a: A `d x m`-dim tensor of linear weights, where `m` is the number
of outputs (must be explicit if `m=1`)
b: The affine (offset) term. Either a float (for single-output
models or if the offset is shared), or a `m`-dim tensor (with
different offset values for for the `m` different outputs).
"""
if not a.ndim == 2:
raise ValueError("a must be two-dimensional")
if not torch.is_tensor(b):
b = torch.tensor([b])
if not b.ndim == 1:
raise ValueError("b nust be one-dimensional")
super().__init__()
self.register_buffer("a", a)
self.register_buffer("b", b.expand(a.size(-1)))
self._num_outputs = a.size(-1)
[docs]
def subset_output(self, idcs: list[int]) -> AffineDeterministicModel:
r"""Subset the model along the output dimension.
Args:
idcs: The output indices to subset the model to.
Returns:
The current model, subset to the specified output indices.
"""
a_sub = self.a.detach()[..., idcs].clone()
b_sub = self.b.detach()[..., idcs].clone()
return self.__class__(a=a_sub, b=b_sub)
[docs]
def forward(self, X: Tensor) -> Tensor:
return self.b + torch.einsum("...d,dm", X, self.a)
[docs]
class PosteriorMeanModel(DeterministicModel):
"""A deterministic model that always returns the posterior mean."""
def __init__(self, model: Model) -> None:
r"""
Args:
model: The base model.
"""
super().__init__()
self.model = model
[docs]
def forward(self, X: Tensor) -> Tensor:
return self.model.posterior(X).mean
[docs]
class FixedSingleSampleModel(DeterministicModel):
r"""
A deterministic model defined by a single sample `w`.
Given a base model `f` and a fixed sample `w`, the model always outputs
y = f_mean(x) + f_stddev(x) * w
We assume the outcomes are uncorrelated here.
"""
def __init__(
self,
model: Model,
w: Tensor | None = None,
dim: int | None = None,
jitter: float | None = 1e-8,
dtype: torch.dtype | None = None,
device: torch.dtype | None = None,
) -> None:
r"""
Args:
model: The base model.
w: A 1-d tensor with length model.num_outputs.
If None, draw it from a standard normal distribution.
dim: dimensionality of w.
If None and w is not provided, draw w samples of size model.num_outputs.
jitter: jitter value to be added for numerical stability, 1e-8 by default.
dtype: dtype for w if specified
device: device for w if specified
"""
super().__init__()
self.model = model
self._num_outputs = model.num_outputs
self.jitter = jitter
if w is None:
self.w = (
torch.randn(model.num_outputs, dtype=dtype, device=device)
if dim is None
else torch.randn(dim, dtype=dtype, device=device)
)
else:
self.w = w
[docs]
def forward(self, X: Tensor) -> Tensor:
post = self.model.posterior(X)
return post.mean + torch.sqrt(post.variance + self.jitter) * self.w.to(X)