#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
.. [wilson2020sampling]
J. Wilson, V. Borovitskiy, A. Terenin, P. Mostowsky, and M. Deisenroth. Efficiently
sampling functions from Gaussian process posteriors. International Conference on
Machine Learning (2020).
.. [wilson2021pathwise]
J. Wilson, V. Borovitskiy, A. Terenin, P. Mostowsky, and M. Deisenroth. Pathwise
Conditioning of Gaussian Processes. Journal of Machine Learning Research (2021).
"""
from __future__ import annotations
import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.models.approximate_gp import ApproximateGPyTorchModel
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.model import ModelList
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.sampling.pathwise.paths import PathDict, PathList, SamplePath
from botorch.sampling.pathwise.prior_samplers import (
draw_kernel_feature_paths,
TPathwisePriorSampler,
)
from botorch.sampling.pathwise.update_strategies import gaussian_update, TPathwiseUpdate
from botorch.sampling.pathwise.utils import (
get_output_transform,
get_train_inputs,
get_train_targets,
TInputTransform,
TOutputTransform,
)
from botorch.utils.context_managers import delattr_ctx
from botorch.utils.dispatcher import Dispatcher
from botorch.utils.transforms import is_ensemble
from gpytorch.models import ApproximateGP, ExactGP, GP
from torch import Size, Tensor
DrawMatheronPaths = Dispatcher("draw_matheron_paths")
[docs]
class MatheronPath(PathDict):
r"""Represents function draws from a GP posterior via Matheron's rule:
.. code-block:: text
"Prior path"
v
(f | y)(·) = f(·) + Cov(f(·), y) Cov(y, y)^{-1} (y - f(X) - ε),
\_______________________________________/
v
"Update path"
where `=` denotes equality in distribution, :math:`f \sim GP(0, k)`,
:math:`y \sim N(f(X), \Sigma)`, and :math:`\epsilon \sim N(0, \Sigma)`.
For more information, see [wilson2020sampling]_ and [wilson2021pathwise]_.
"""
def __init__(
self,
prior_paths: SamplePath,
update_paths: SamplePath,
input_transform: TInputTransform | None = None,
output_transform: TOutputTransform | None = None,
) -> None:
r"""Initializes a MatheronPath instance.
Args:
prior_paths: Sample paths used to represent the prior.
update_paths: Sample paths used to represent the data.
input_transform: An optional input transform for the module.
output_transform: An optional output transform for the module.
"""
super().__init__(
join=sum,
paths={"prior_paths": prior_paths, "update_paths": update_paths},
input_transform=input_transform,
output_transform=output_transform,
)
[docs]
def get_matheron_path_model(
model: GP, sample_shape: Size | None = None
) -> GenericDeterministicModel:
r"""Generates a deterministic model using a single Matheron path drawn
from the model's posterior.
The deterministic model evalutes the output of `draw_matheron_paths`,
and reshapes it to mimic the output behavior of the model's posterior.
Args:
model: The model whose posterior is to be sampled.
sample_shape: The shape of the sample paths to be drawn, if an ensemble
of sample paths is desired. If this is specified, the resulting
deterministic model will behave as if the `sample_shape` is prepended
to the `batch_shape` of the model. The inputs used to evaluate the model
must be adjusted to match.
Returns:
A deterministic model that evaluates the Matheron path.
"""
sample_shape = Size() if sample_shape is None else sample_shape
path = draw_matheron_paths(model, sample_shape=sample_shape)
num_outputs = model.num_outputs
if isinstance(model, ModelList) and len(model.models) != num_outputs:
raise UnsupportedError("A model-list of multi-output models is not supported.")
def f(X: Tensor) -> Tensor:
r"""Reshapes the path evaluations to bring the output dimension to the end.
Args:
X: The input tensor of shape `batch_shape x q x d`.
If the model is batched, `batch_shape` must be broadcastable to
the model batch shape.
Returns:
The output tensor of shape `batch_shape x q x m`.
"""
if num_outputs == 1:
# For single-output, we lack the output dimension. Add one.
res = path(X).unsqueeze(-1)
elif isinstance(model, ModelList):
# For model list, path evaluates to a list of tensors. Stack them.
res = torch.stack(path(X), dim=-1)
else:
# For multi-output, path expects inputs broadcastable to
# `model._aug_batch_shape x q x d` and returns outputs of shape
# `model._aug_batch_shape x q`. Augmented batch shape includes the
# `m` dimension, so we will unsqueeze that and transpose after.
res = path(X.unsqueeze(-3)).transpose(-1, -2)
return res
path_model = GenericDeterministicModel(f=f, num_outputs=num_outputs)
path_model._is_ensemble = is_ensemble(model) or len(sample_shape) > 0
return path_model
[docs]
def draw_matheron_paths(
model: GP,
sample_shape: Size,
prior_sampler: TPathwisePriorSampler = draw_kernel_feature_paths,
update_strategy: TPathwiseUpdate = gaussian_update,
) -> MatheronPath:
r"""Generates function draws from (an approximate) Gaussian process posterior.
When evaluted, sample paths produced by this method return Tensors with dimensions
`sample_dims x batch_dims x [joint_dim]`, where `joint_dim` denotes the penultimate
dimension of the input tensor. For multioutput models, outputs are returned as the
final batch dimension.
Args:
model: Gaussian process whose posterior is to be sampled.
sample_shape: Sizes of sample dimensions.
prior_sample: A callable that takes a model and a sample shape and returns
a set of sample paths representing the prior.
update_strategy: A callable that takes a model and a tensor of prior process
values and returns a set of sample paths representing the data.
"""
return DrawMatheronPaths(
model,
sample_shape=sample_shape,
prior_sampler=prior_sampler,
update_strategy=update_strategy,
)
@DrawMatheronPaths.register(ModelListGP)
def _draw_matheron_paths_ModelListGP(
model: ModelListGP,
sample_shape: Size,
*,
prior_sampler: TPathwisePriorSampler = draw_kernel_feature_paths,
update_strategy: TPathwiseUpdate = gaussian_update,
):
return PathList(
[
draw_matheron_paths(
model=m,
sample_shape=sample_shape,
prior_sampler=prior_sampler,
update_strategy=update_strategy,
)
for m in model.models
]
)
@DrawMatheronPaths.register(ExactGP)
def _draw_matheron_paths_ExactGP(
model: ExactGP,
*,
sample_shape: Size,
prior_sampler: TPathwisePriorSampler,
update_strategy: TPathwiseUpdate,
) -> MatheronPath:
(train_X,) = get_train_inputs(model, transformed=True)
train_Y = get_train_targets(model, transformed=True)
with delattr_ctx(model, "outcome_transform"):
# Generate draws from the prior
prior_paths = prior_sampler(model=model, sample_shape=sample_shape)
sample_values = prior_paths.forward(train_X)
# Compute pathwise updates
update_paths = update_strategy(
model=model,
sample_values=sample_values,
target_values=train_Y,
)
return MatheronPath(
prior_paths=prior_paths,
update_paths=update_paths,
output_transform=get_output_transform(model),
)
@DrawMatheronPaths.register((ApproximateGP, ApproximateGPyTorchModel))
def _draw_matheron_paths_ApproximateGP(
model: ApproximateGP | ApproximateGPyTorchModel,
*,
sample_shape: Size,
prior_sampler: TPathwisePriorSampler,
update_strategy: TPathwiseUpdate,
) -> MatheronPath:
# Note: Inducing points are assumed to be pre-transformed
Z = (
model.model.variational_strategy.inducing_points
if isinstance(model, ApproximateGPyTorchModel)
else model.variational_strategy.inducing_points
)
with delattr_ctx(model, "outcome_transform"):
# Generate draws from the prior
prior_paths = prior_sampler(model=model, sample_shape=sample_shape)
sample_values = prior_paths.forward(Z) # `forward` bypasses transforms
# Compute pathwise updates
update_paths = update_strategy(model=model, sample_values=sample_values)
return MatheronPath(
prior_paths=prior_paths,
update_paths=update_paths,
output_transform=get_output_transform(model),
)