#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from copy import deepcopy
from math import pi
from typing import Callable, List, Optional

import torch
from botorch.models.converter import batched_to_model_list
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.model import Model, ModelList
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import MultiTaskGP
from botorch.utils.sampling import manual_seed
from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel
from linear_operator.utils.cholesky import psd_safe_cholesky
from torch import Tensor
from torch.distributions import MultivariateNormal
from torch.nn import Module


class GPDraw(Module):
    r"""Convenience wrapper for sampling a function from a GP prior.

    This wrapper implicitly defines the GP sample as a self-updating function by
    keeping track of the evaluated points and respective base samples used during
    the evaluation.

    This does not yet support multi-output models.
    """

    def __init__(self, model: Model, seed: Optional[int] = None) -> None:
        r"""Construct a GP function sampler.

        Args:
            model: The Model defining the GP prior. A deep copy is stored on
                the sampler.
            seed: Integer seed used when drawing base samples. If `None`, an
                integer is drawn uniformly from `[0, 1000000)`.
        """
        super().__init__()
        self._model = deepcopy(model)
        seed = torch.tensor(
            seed if seed is not None else torch.randint(0, 1000000, (1,)).item()
        )
        self.register_buffer("_seed", seed)

    @property
    def Xs(self) -> Optional[Tensor]:
        """A `(batch_shape) x n_eval x d`-dim tensor of locations at which the GP
        was evaluated (or `None` if the sample has never been evaluated).
        """
        # `_Xs` is only registered as a buffer on the first `forward` call.
        return getattr(self, "_Xs", None)

    @property
    def Ys(self) -> Optional[Tensor]:
        """The tensor of function values associated with `Xs` (or `None` if the
        sample has never been evaluated).
        """
        # `_Ys` is only registered as a buffer on the first `forward` call.
        return getattr(self, "_Ys", None)

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the GP sample function at a set of points X.

        All previously evaluated points are re-evaluated jointly with `X`, using
        the stored base samples for the old points and freshly drawn base samples
        for the new ones. The buffers `_Xs`, `_Ys`, `_seed` and `_base_samples`
        are updated as a side effect.

        Args:
            X: A `batch_shape x n x d`-dim tensor of points

        Returns:
            The value of the GP sample at the `n` points.
        """
        prev_Xs = self.Xs
        if prev_Xs is None:
            X_eval = X  # first time, no previous evaluation points
        else:
            X_eval = torch.cat([prev_Xs, X], dim=-2)
        posterior = self._model.posterior(X=X_eval)
        base_sample_shape = posterior.base_sample_shape
        # Only draw base samples for the `n` new points; the old ones are re-used.
        bs_shape = base_sample_shape[:-2] + X.shape[-2:-1] + base_sample_shape[-1:]
        with manual_seed(seed=int(self._seed)):
            new_base_samples = torch.randn(bs_shape, device=X.device, dtype=X.dtype)
        # Advance the seed so that the next call draws different base samples.
        seed = self._seed + 1
        if prev_Xs is None:
            base_samples = new_base_samples
        else:
            base_samples = torch.cat([self._base_samples, new_base_samples], dim=-2)
        # TODO: Deduplicate repeated evaluations / deal with numerical degeneracies
        # that could lead to non-deterministic evaluations. We could use SVD- or
        # eigendecomposition-based sampling, but we probably don't want to use this
        # by default for performance reasons.
        Ys = posterior.rsample(torch.Size(), base_samples=base_samples)
        self.register_buffer("_Xs", X_eval)
        self.register_buffer("_Ys", Ys)
        self.register_buffer("_seed", seed)
        self.register_buffer("_base_samples", base_samples)
        # Slice with an explicit non-negative start index: `-(n):` would return
        # the entire history (instead of nothing) when `n == 0`.
        num_new = X.size(-2)
        return Ys[..., Ys.size(-2) - num_new :, :]
class RandomFourierFeatures(Module):
    """A class that represents Random Fourier Features."""

    def __init__(
        self,
        kernel: Kernel,
        input_dim: int,
        num_rff_features: int,
        sample_shape: Optional[torch.Size] = None,
    ) -> None:
        r"""Initialize RandomFourierFeatures.

        Args:
            kernel: The GP kernel. Either a Matern / RBF kernel, or a
                `ScaleKernel` wrapping one of those.
            input_dim: The input dimension to the GP kernel.
            num_rff_features: The number of Fourier features.
            sample_shape: The shape of a single sample. For a single-element
                `torch.Size` object, this is simply the number of RFF draws.

        Raises:
            NotImplementedError: If the (base) kernel is neither a
                `MaternKernel` nor an `RBFKernel`.
        """
        if isinstance(kernel, ScaleKernel):
            base_kernel = kernel.base_kernel
            outputscale = kernel.outputscale.detach().clone()
        else:
            base_kernel = kernel
            outputscale = None
        # Validate the kernel type *before* touching `lengthscale`, so that an
        # unsupported kernel always produces the intended error rather than an
        # attribute error on a kernel that has no lengthscale.
        if not isinstance(base_kernel, (MaternKernel, RBFKernel)):
            raise NotImplementedError("Only Matern and RBF kernels are supported.")
        lengthscale = base_kernel.lengthscale
        if outputscale is None:
            # No ScaleKernel wrapper: use a unit output scale.
            outputscale = torch.tensor(
                1.0,
                dtype=lengthscale.dtype,
                device=lengthscale.device,
            )
        super().__init__()
        self.kernel_batch_shape = base_kernel.batch_shape
        self.register_buffer("outputscale", outputscale)
        self.register_buffer("lengthscale", lengthscale.detach().clone())
        self.sample_shape = torch.Size() if sample_shape is None else sample_shape
        self.register_buffer(
            "weights",
            self._get_weights(
                base_kernel=base_kernel,
                input_dim=input_dim,
                num_rff_features=num_rff_features,
                sample_shape=self.sample_shape,
            ),
        )
        # initialize uniformly in [0, 2 * pi]
        self.register_buffer(
            "bias",
            2
            * pi
            * torch.rand(
                *self.sample_shape,
                *self.kernel_batch_shape,
                num_rff_features,
                dtype=lengthscale.dtype,
                device=lengthscale.device,
            ),
        )

    def _get_weights(
        self,
        base_kernel: Kernel,
        input_dim: int,
        num_rff_features: int,
        sample_shape: Optional[torch.Size] = None,
    ) -> Tensor:
        r"""Sample weights for RFF.

        Args:
            base_kernel: The GP base kernel.
            input_dim: The input dimension to the GP kernel.
            num_rff_features: The number of Fourier features.
            sample_shape: The sample shape of weights.

        Returns:
            A tensor of weights with shape
            `(*sample_shape, *kernel_batch_shape, input_dim, num_rff_features)`.
        """
        sample_shape = torch.Size() if sample_shape is None else sample_shape
        weights = torch.randn(
            *sample_shape,
            *self.kernel_batch_shape,
            input_dim,
            num_rff_features,
            dtype=base_kernel.lengthscale.dtype,
            device=base_kernel.lengthscale.device,
        )
        if isinstance(base_kernel, MaternKernel):
            # For Matern kernels the Gaussian weights are rescaled by the inverse
            # square root of Gamma(nu, nu) draws, one draw per Fourier feature.
            gamma_dist = torch.distributions.Gamma(base_kernel.nu, base_kernel.nu)
            gamma_samples = gamma_dist.sample(torch.Size([1, num_rff_features])).to(
                weights
            )
            weights = torch.rsqrt(gamma_samples) * weights
        return weights

    def forward(self, X: Tensor) -> Tensor:
        """Get Fourier basis features for the provided inputs.

        Note that the right-most subset of the batch shape of `X` should
        be `(sample_shape) x (kernel_batch_shape)` if using either the
        `sample_shape` argument or a batched kernel. In other words,
        `X` should be of shape `(added_batch_shape) x (sample_shape) x
        (kernel_batch_shape) x n x input_dim`, where parentheses denote
        that the given batch shape can be empty. `X` can always be
        a tensor of shape `n x input_dim`, in which case broadcasting
        will take care of the batch shape. This will raise a `ValueError`
        if the batch shapes are not compatible.

        Args:
            X: Input tensor of shape `(batch_shape) x n x input_dim`.

        Returns:
            A Tensor of shape `(batch_shape) x n x rff`. If `X` does not have
            a `batch_shape`, the output `batch_shape` will be
            `(sample_shape) x (kernel_batch_shape)`.
        """
        self._check_forward_X_shape_compatibility(X)
        # X is of shape (additional_batch_shape) x (sample_shape)
        # x (kernel_batch_shape) x n x d.
        # Weights is of shape (sample_shape) x (kernel_batch_shape) x d x num_rff.
        X_scaled = torch.div(X, self.lengthscale)
        batchmatmul = X_scaled @ self.weights
        bias = self.bias
        # Bias is of shape (sample_shape) x (kernel_batch_shape) x num_rff.
        # Batchmatmul is of shape (additional_batch_shape) x (sample_shape)
        # x (kernel_batch_shape) x n x num_rff.
        outputs = torch.cos(batchmatmul + bias.unsqueeze(-2))
        # Make sure we divide at the correct (i.e., kernel's) batch dimension.
        if len(self.kernel_batch_shape) > 0:
            outputscale = self.outputscale.view(*self.kernel_batch_shape, 1, 1)
        else:
            outputscale = self.outputscale
        return torch.sqrt(2.0 * outputscale / self.weights.shape[-1]) * outputs

    def _check_forward_X_shape_compatibility(self, X: Tensor) -> None:
        r"""Check that the `batch_shape` of X, if any, is compatible with the
        `sample_shape` & `kernel_batch_shape`.

        Raises:
            ValueError: If a trailing batch dimension of `X` is neither equal
                to the matching expected dimension nor equal to 1.
        """
        full_batch_shape_X = X.shape[:-2]
        len_full_batch_shape_X = len(full_batch_shape_X)
        if len_full_batch_shape_X == 0:
            # Non-batched X.
            return
        expected_batch_shape = self.sample_shape + self.kernel_batch_shape
        # Check if they're broadcastable, comparing dimensions right-to-left.
        for b_idx in range(min(len(expected_batch_shape), len_full_batch_shape_X)):
            neg_idx = -b_idx - 1
            if (
                full_batch_shape_X[neg_idx] != expected_batch_shape[neg_idx]
                and full_batch_shape_X[neg_idx] != 1
            ):
                raise ValueError(
                    "the batch shape of X is expected to follow the pattern: "
                    f"`... x {tuple(expected_batch_shape)}`"
                )
def get_deterministic_model_multi_samples(
    weights: List[Tensor],
    bases: List[RandomFourierFeatures],
) -> GenericDeterministicModel:
    """
    Build a batched deterministic model that evaluates `n_samples` function
    samples at once. Multi-output models are supported: one weight / basis pair
    is expected per output.

    Args:
        weights: A list of weights with `num_outputs` elements. Each weight is of
            shape `(batch_shape_input) x n_samples x num_rff_features`, where
            `(batch_shape_input)` is the batch shape of the inputs used to obtain
            the posterior weights.
        bases: A list of `RandomFourierFeatures` with `num_outputs` elements. Each
            basis has a sample shape of `n_samples`.

    Returns:
        A batched `GenericDeterministicModel` that batch evaluates `n_samples`
        function samples.
    """
    # One evaluation closure per output.
    per_output_fns = []
    for w, basis in zip(weights, bases):
        per_output_fns.append(get_eval_gp_sample_callable(w=w, basis=basis))

    def evaluate_gps_X(X):
        # Each closure yields a trailing dimension of size one; concatenating
        # them along the last dimension stacks the outputs.
        per_output_values = [fn(X) for fn in per_output_fns]
        return torch.cat(per_output_values, dim=-1)

    num_outputs = len(weights)
    return GenericDeterministicModel(f=evaluate_gps_X, num_outputs=num_outputs)
def get_eval_gp_sample_callable(
    w: Tensor, basis: RandomFourierFeatures
) -> Callable[[Tensor], Tensor]:
    """Build a function that evaluates a single-output GP sample.

    Args:
        w: The weights applied to the basis features. A trailing singleton
            dimension is appended before the matrix product.
        basis: A callable (e.g. `RandomFourierFeatures`) mapping inputs `X` to
            basis features.

    Returns:
        A callable mapping `X` to `basis(X) @ w.unsqueeze(-1)`, i.e. a tensor
        whose last dimension has size one.
    """

    def _f(X: Tensor) -> Tensor:
        return basis(X) @ w.unsqueeze(-1)

    return _f
def get_deterministic_model(
    weights: List[Tensor], bases: List[RandomFourierFeatures]
) -> GenericDeterministicModel:
    """Assemble a deterministic model from per-output weights and bases.

    Args:
        weights: A list of weights with `m` elements.
        bases: A list of `RandomFourierFeatures` with `m` elements.

    Returns:
        A deterministic model with `len(weights)` outputs.
    """
    # Pair every weight with its basis and wrap each pair in a closure.
    callables = []
    for w, basis in zip(weights, bases):
        callables.append(get_eval_gp_sample_callable(w=w, basis=basis))

    def evaluate_gp_sample(X):
        # Stack the single-column outputs of all closures side by side.
        columns = [c(X) for c in callables]
        return torch.cat(columns, dim=-1)

    num_outputs = len(weights)
    return GenericDeterministicModel(f=evaluate_gp_sample, num_outputs=num_outputs)
def get_deterministic_model_list(
    weights: List[Tensor],
    bases: List[RandomFourierFeatures],
) -> ModelList:
    """Assemble a list of single-output deterministic models, one per
    weight / basis pair.

    Args:
        weights: A list of weights with `m` elements.
        bases: A list of `RandomFourierFeatures` with `m` elements.

    Returns:
        A `ModelList` wrapping one single-output `GenericDeterministicModel`
        per weight / basis pair.
    """
    samples = [
        GenericDeterministicModel(
            f=get_eval_gp_sample_callable(w=w, basis=basis),
            num_outputs=1,
        )
        for w, basis in zip(weights, bases)
    ]
    return ModelList(*samples)
def get_weights_posterior(X: Tensor, y: Tensor, sigma_sq: Tensor) -> MultivariateNormal:
    r"""Compute the posterior over Bayesian linear regression weights.

    The computation runs under `torch.no_grad()`, so the returned distribution
    carries no gradient information.

    Args:
        X: A tensor of inputs with shape `(*batch_shape, n, num_rff_features)`.
        y: A tensor of outcomes with shape `(*batch_shape, n)`.
        sigma_sq: The likelihood noise variance. This should be a tensor with
            shape `kernel_batch_shape, 1, 1` if using a batched kernel.
            Otherwise, it should be a scalar tensor.

    Returns:
        The posterior distribution over the weights.
    """
    with torch.no_grad():
        X_trans = X.transpose(-2, -1)
        A = X_trans @ X + sigma_sq * torch.eye(
            X.shape[-1], dtype=X.dtype, device=X.device
        )
        # mean is given by: m = S @ x.T @ y, where S = A_inv
        # compute inverse of A using solves
        # covariance is A_inv * sigma
        L_A = psd_safe_cholesky(A)
        # solve L_A @ u = I
        Iw = torch.eye(L_A.shape[-1], dtype=X.dtype, device=X.device)
        # `torch.triangular_solve` is deprecated; `torch.linalg.solve_triangular`
        # takes the coefficient matrix first and requires an explicit `upper`.
        u = torch.linalg.solve_triangular(L_A, Iw, upper=False)
        # solve L_A^T @ S = u
        A_inv = torch.linalg.solve_triangular(L_A.transpose(-2, -1), u, upper=True)
        m = (A_inv @ X_trans @ y.unsqueeze(-1)).squeeze(-1)
        L = psd_safe_cholesky(A_inv * sigma_sq)
        return MultivariateNormal(loc=m, scale_tril=L)
def get_gp_samples(
    model: Model, num_outputs: int, n_samples: int, num_rff_features: int = 512
) -> GenericDeterministicModel:
    r"""Sample functions from GP posterior using RFFs. The returned
    `GenericDeterministicModel` effectively wraps `num_outputs` models,
    each of which has a batch shape of `n_samples`. Refer
    `get_deterministic_model_multi_samples` for more details.

    NOTE: If using input / outcome transforms, the gp samples must be accessed via
    the `gp_sample.posterior(X)` call. Otherwise, `gp_sample(X)` will produce bogus
    values that do not agree with the underlying `model`. It is also highly
    recommended to use outcome transforms to standardize the input data, since the
    gp samples do not work well when training outcomes are not zero-mean.

    The input / outcome transforms of `model` (and the outcome transforms of its
    sub-models) are temporarily removed while the samples are constructed. They
    are always re-attached before this function returns or raises.

    Args:
        model: The model.
        num_outputs: The number of outputs.
        n_samples: The number of functions to be sampled IID.
        num_rff_features: The number of random Fourier features.

    Returns:
        A `GenericDeterministicModel` that evaluates `n_samples` sampled functions.
        If `n_samples > 1`, this will be a batched model. If any sub-model
        carries its own input / outcome transform, a model list of single-output
        deterministic models (with those transforms attached) is returned instead.

    Raises:
        NotImplementedError: If the (first) model is a `MultiTaskGP`.
    """
    # Get transforms from the model.
    intf = getattr(model, "input_transform", None)
    octf = getattr(model, "outcome_transform", None)
    # (sub-model, outcome transform) pairs removed inside the loop below; kept so
    # that they can be re-attached no matter how this function exits.
    stripped_octfs = []
    try:
        # Remove the outcome transform - leads to buggy draws.
        if octf is not None:
            del model.outcome_transform
        if intf is not None:
            del model.input_transform
        if num_outputs > 1:
            if not isinstance(model, ModelListGP):
                models = batched_to_model_list(model).models
            else:
                models = model.models
        else:
            models = [model]
        if isinstance(models[0], MultiTaskGP):
            raise NotImplementedError
        weights = []
        bases = []
        octfs = []
        intfs = []
        for m in range(num_outputs):
            _model = models[m]
            train_X = _model.train_inputs[0]
            train_targets = _model.train_targets
            _intf = getattr(_model, "input_transform", None)
            _octf = getattr(_model, "outcome_transform", None)
            # Remove the outcome transform - leads to buggy draws.
            if _octf is not None:
                del _model.outcome_transform
                stripped_octfs.append((_model, _octf))
            octfs.append(_octf)
            intfs.append(_intf)
            # Get random Fourier features.
            # sample_shape controls the number of iid functions.
            basis = RandomFourierFeatures(
                kernel=_model.covar_module,
                input_dim=train_X.shape[-1],
                num_rff_features=num_rff_features,
                sample_shape=torch.Size([n_samples] if n_samples > 1 else []),
            )
            bases.append(basis)
            phi_X = basis(train_X)
            # Sample weights from bayesian linear model.
            # weights.sample().shape ==
            #     (n_samples, batch_shape_input, num_rff_features)
            sigma_sq = _model.likelihood.noise
            if len(basis.kernel_batch_shape) > 0:
                sigma_sq = sigma_sq.unsqueeze(-2)
            mvn = get_weights_posterior(
                X=phi_X,
                y=train_targets,
                sigma_sq=sigma_sq,
            )
            weights.append(mvn.sample())
        # TODO: Ideally support RFFs for multi-outputs instead of having to
        # generate a basis for each output serially.
        has_sub_model_transforms = any(t is not None for t in octfs) or any(
            t is not None for t in intfs
        )
        if has_sub_model_transforms:
            # Per-output transforms require one deterministic model per output.
            base_gp_samples = get_deterministic_model_list(
                weights=weights,
                bases=bases,
            )
            for m in range(len(weights)):
                if octfs[m] is not None:
                    base_gp_samples.models[m].outcome_transform = octfs[m]
                if intfs[m] is not None:
                    base_gp_samples.models[m].input_transform = intfs[m]
            return base_gp_samples
        if n_samples > 1:
            base_gp_samples = get_deterministic_model_multi_samples(
                weights=weights,
                bases=bases,
            )
        else:
            base_gp_samples = get_deterministic_model(
                weights=weights,
                bases=bases,
            )
        # Load the transforms on the sampled model.
        if intf is not None:
            base_gp_samples.input_transform = intf
        if octf is not None:
            base_gp_samples.outcome_transform = octf
        return base_gp_samples
    finally:
        # Re-attach everything that was removed, also on the error path, so the
        # caller's model is never left without its transforms.
        for _model, _octf in stripped_octfs:
            _model.outcome_transform = _octf
        if intf is not None:
            model.input_transform = intf
        if octf is not None:
            model.outcome_transform = octf