#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from collections.abc import Callable
from types import NoneType
from typing import Any
import torch
from botorch.models.approximate_gp import ApproximateGPyTorchModel
from botorch.models.transforms.input import InputTransform
from botorch.sampling.pathwise.features import KernelEvaluationMap
from botorch.sampling.pathwise.paths import GeneralizedLinearPath, SamplePath
from botorch.sampling.pathwise.utils import (
get_input_transform,
get_train_inputs,
get_train_targets,
TInputTransform,
)
from botorch.utils.dispatcher import Dispatcher
from botorch.utils.types import DEFAULT
from gpytorch.kernels.kernel import Kernel
from gpytorch.likelihoods import _GaussianLikelihoodBase, Likelihood
from gpytorch.models import ApproximateGP, ExactGP, GP
from gpytorch.variational import VariationalStrategy
from linear_operator.operators import (
LinearOperator,
SumLinearOperator,
ZeroLinearOperator,
)
from torch import Tensor
# Type alias for a pathwise update strategy: a callable that takes a `GP` and a
# `Tensor` and returns a `SamplePath`.
TPathwiseUpdate = Callable[[GP, Tensor], SamplePath]
# Dispatcher named "gaussian_update". The handlers in this module attach
# themselves to it through `GaussianUpdate.register(...)`.
GaussianUpdate = Dispatcher("gaussian_update")
# [docs]
def gaussian_update(
    model: GP,
    sample_values: Tensor,
    likelihood: Likelihood | None = DEFAULT,
    **kwargs: Any,
) -> GeneralizedLinearPath:
    r"""Computes a Gaussian pathwise update in exact arithmetic:

    .. code-block:: text

        (f | y)(·) = f(·) + Cov(f(·), y) Cov(y, y)^{-1} (y - f(X) - ε),
                            \_______________________________________/
                                                V
                                    "Gaussian pathwise update"

    where `=` denotes equality in distribution, :math:`f \sim GP(0, k)`,
    :math:`y \sim N(f(X), \Sigma)`, and :math:`\epsilon \sim N(0, \Sigma)`.
    For more information, see [wilson2020sampling]_ and [wilson2021pathwise]_.

    Args:
        model: A Gaussian process prior together with a likelihood.
        sample_values: Assumed values for :math:`f(X)`.
        likelihood: An optional likelihood used to help define the desired
            update. Defaults to `model.likelihood` if it exists else None.
        **kwargs: Additional keyword arguments forwarded unchanged to the
            handler selected by the `GaussianUpdate` dispatcher.

    Returns:
        The result of the dispatched `GaussianUpdate` handler.
    """
    # The `DEFAULT` sentinel (rather than None) marks "not supplied", so that an
    # explicit `likelihood=None` is passed through to the dispatcher as-is.
    resolved_likelihood = (
        getattr(model, "likelihood", None) if likelihood is DEFAULT else likelihood
    )
    return GaussianUpdate(
        model, resolved_likelihood, sample_values=sample_values, **kwargs
    )
def _gaussian_update_exact(
    kernel: Kernel,
    points: Tensor,
    target_values: Tensor,
    sample_values: Tensor,
    noise_covariance: Tensor | LinearOperator | None = None,
    scale_tril: Tensor | LinearOperator | None = None,
    input_transform: TInputTransform | None = None,
) -> GeneralizedLinearPath:
    r"""Builds the update path `Cov(f(·), y) Cov(y, y)^{-1} (Y - f(X) - ε)`.

    Args:
        kernel: Kernel evaluated at `points` to obtain the prior covariance, and
            handed to the `KernelEvaluationMap` of the returned path.
        points: Locations at which the process is conditioned.
        target_values: Values `Y` that the update should match.
        sample_values: Assumed values for `f(X)`.
        noise_covariance: Optional noise covariance. When it is None or a
            `ZeroLinearOperator`, no noise is drawn. Otherwise a draw
            `noise_covariance.cholesky() @ z`, with `z` standard normal, is
            added to `sample_values`.
        scale_tril: Optional Cholesky factor of `Cov(y, y)`. When None, it is
            computed from `kernel(points)`, plus `noise_covariance` if noise
            is present.
        input_transform: Passed through to the `KernelEvaluationMap`.

    Returns:
        A `GeneralizedLinearPath` whose feature map evaluates `kernel` against
        `points` and whose weights solve the Cholesky system for the residuals.
    """
    is_noiseless = noise_covariance is None or isinstance(
        noise_covariance, ZeroLinearOperator
    )

    # Sample the noise term ε and absorb it into the sample values as needed
    if not is_noiseless:
        standard_normal = torch.randn_like(sample_values).unsqueeze(-1)
        noise_draw = noise_covariance.cholesky() @ standard_normal
        sample_values = sample_values + noise_draw.squeeze(-1)

    # Prepare the Cholesky factor of `Cov(y, y)` unless the caller supplied one
    if scale_tril is None:
        covariance = kernel(points)
        if not is_noiseless:
            covariance = SumLinearOperator(covariance, noise_covariance)
        scale_tril = covariance.cholesky()

    # Solve for `Cov(y, y)^{-1}(Y - f(X) - ε)`; the trailing dimension added
    # here turns the residuals into a column for the solve and is removed again
    # from the resulting weights below.
    residuals = (target_values - sample_values).unsqueeze(-1)
    solution = torch.cholesky_solve(residuals, scale_tril.to_dense())

    # Define update feature map and paths
    return GeneralizedLinearPath(
        feature_map=KernelEvaluationMap(
            kernel=kernel,
            points=points,
            input_transform=input_transform,
        ),
        weight=solution.squeeze(-1),
    )
@GaussianUpdate.register(ExactGP, _GaussianLikelihoodBase)
def _gaussian_update_ExactGP(
    model: ExactGP,
    likelihood: _GaussianLikelihoodBase,
    *,
    sample_values: Tensor,
    target_values: Tensor | None = None,
    points: Tensor | None = None,
    noise_covariance: Tensor | LinearOperator | None = None,
    scale_tril: Tensor | LinearOperator | None = None,
) -> GeneralizedLinearPath:
    r"""Gaussian update handler for an `ExactGP` with a Gaussian likelihood.

    Arguments left as None are filled in from the model and likelihood before
    delegating to `_gaussian_update_exact`.

    Args:
        model: The exact GP supplying the kernel (`model.covar_module`), the
            input transform and, by default, the training data.
        likelihood: Likelihood whose `noise_covar` provides the default noise
            covariance.
        sample_values: Assumed values for `f(X)`.
        target_values: Values to condition on. Defaults to the model's training
            targets, requested with `transformed=True`.
        points: Conditioning locations. Defaults to the model's single training
            input, requested with `transformed=True`.
        noise_covariance: Noise covariance. Defaults to
            `likelihood.noise_covar(shape=points.shape[:-1])`.
        scale_tril: Optional Cholesky factor forwarded to
            `_gaussian_update_exact`.

    Returns:
        The `GeneralizedLinearPath` produced by `_gaussian_update_exact`.
    """
    if points is None:
        # Unpacking a one-element tuple: exactly one training input is expected.
        train_inputs = get_train_inputs(model, transformed=True)
        (points,) = train_inputs

    target_values = (
        get_train_targets(model, transformed=True)
        if target_values is None
        else target_values
    )

    # Resolved after `points`, whose shape determines the default noise shape.
    noise_covariance = (
        likelihood.noise_covar(shape=points.shape[:-1])
        if noise_covariance is None
        else noise_covariance
    )

    return _gaussian_update_exact(
        kernel=model.covar_module,
        points=points,
        target_values=target_values,
        sample_values=sample_values,
        noise_covariance=noise_covariance,
        scale_tril=scale_tril,
        input_transform=get_input_transform(model),
    )
@GaussianUpdate.register(ApproximateGPyTorchModel, (Likelihood, NoneType))
def _gaussian_update_ApproximateGPyTorchModel(
    model: ApproximateGPyTorchModel,
    likelihood: Likelihood | None,
    **kwargs: Any,
) -> GeneralizedLinearPath:
    r"""Gaussian update handler for an `ApproximateGPyTorchModel` wrapper.

    Re-dispatches on the wrapped `model.model`, adding the wrapper's input
    transform as the `input_transform` keyword argument.

    Args:
        model: The wrapper whose inner `model.model` is dispatched on.
        likelihood: Likelihood (or None) passed along to the dispatcher.
        **kwargs: Keyword arguments forwarded to the dispatched handler.

    Returns:
        The result of the re-dispatched `GaussianUpdate` call.
    """
    wrapped_model = model.model
    wrapper_transform = get_input_transform(model)
    return GaussianUpdate(
        wrapped_model, likelihood, input_transform=wrapper_transform, **kwargs
    )
@GaussianUpdate.register(ApproximateGP, (Likelihood, NoneType))
def _gaussian_update_ApproximateGP(
    model: ApproximateGP, likelihood: Likelihood | None, **kwargs: Any
) -> GeneralizedLinearPath:
    r"""Gaussian update handler for an `ApproximateGP` and a likelihood (or None).

    The `likelihood` argument is not used. The call is re-dispatched with the
    model's variational strategy in its place.

    Args:
        model: The approximate GP whose `variational_strategy` is dispatched on.
        likelihood: Unused; only takes part in handler selection.
        **kwargs: Keyword arguments forwarded to the dispatched handler.

    Returns:
        The result of the re-dispatched `GaussianUpdate` call.
    """
    strategy = model.variational_strategy
    return GaussianUpdate(model, strategy, **kwargs)
@GaussianUpdate.register(ApproximateGP, VariationalStrategy)
def _gaussian_update_ApproximateGP_VariationalStrategy(
    model: ApproximateGP,
    _: VariationalStrategy,
    *,
    sample_values: Tensor,
    target_values: Tensor | None = None,
    noise_covariance: Tensor | LinearOperator | None = None,
    input_transform: InputTransform | None = None,
    **ignore: Any,
) -> GeneralizedLinearPath:
    r"""Gaussian update handler for an `ApproximateGP` with a `VariationalStrategy`.

    Conditions on the model's inducing points, using the strategy's Cholesky
    factor of the prior covariance at those points as `scale_tril`.

    Args:
        model: The approximate GP supplying the kernel, the mean module and the
            variational strategy.
        _: Only takes part in handler selection; the strategy itself is read
            from `model.variational_strategy`.
        sample_values: Assumed values of the prior sample at the inducing points.
        target_values: Values to condition on. When None, they are generated by
            sampling the variational distribution and applying a location-scale
            transform.
        noise_covariance: Must be None or a `ZeroLinearOperator`.
        input_transform: Passed through to `_gaussian_update_exact`.
        **ignore: Additional keyword arguments; accepted and discarded.

    Returns:
        The `GeneralizedLinearPath` produced by `_gaussian_update_exact`.

    Raises:
        NotImplementedError: If `noise_covariance` is anything other than None
            or a `ZeroLinearOperator`.
    """
    # TODO: Account for jitter added by `psd_safe_cholesky`
    noise_requested = noise_covariance is not None and not isinstance(
        noise_covariance, ZeroLinearOperator
    )
    if noise_requested:
        raise NotImplementedError(
            f"`noise_covariance` argument not yet supported for {type(model)}."
        )

    # Inducing points are assumed to live in transformed space
    kernel = model.covar_module
    kernel_batch_shape = kernel.batch_shape
    strategy = model.variational_strategy
    inducing_points = strategy.inducing_points
    prior_at_inducing = strategy(inducing_points, prior=True)
    chol_factor = strategy._cholesky_factor(prior_at_inducing.lazy_covariance_matrix)
    chol_factor = chol_factor.to(dtype=sample_values.dtype)

    # Generate whitened inducing variables, then location-scale transform
    if target_values is None:
        # Leading dimensions of `sample_values` that remain after dropping the
        # kernel batch dimensions and the final dimension.
        num_sample_dims = sample_values.ndim - len(kernel_batch_shape) - 1
        whitened = strategy.variational_distribution.rsample(
            sample_values.shape[:num_sample_dims],
        )
        prior_mean = model.mean_module(inducing_points)
        target_values = prior_mean + (whitened @ chol_factor.transpose(-1, -2))

    return _gaussian_update_exact(
        kernel=kernel,
        points=inducing_points,
        target_values=target_values,
        sample_values=sample_values,
        scale_tril=chol_factor,
        input_transform=input_transform,
    )