# Source code for botorch.posteriors.latent_kronecker

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torch
from botorch.models.gpytorch import GPyTorchModel
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions import MultivariateNormal
from linear_operator.operators import IdentityLinearOperator, ZeroLinearOperator
from torch import Tensor


r"""
References

.. [wilson2020sampling]
    J. Wilson, V. Borovitskiy, A. Terenin, P. Mostowsky, and M. Deisenroth. Efficiently
    sampling functions from Gaussian process posteriors. International Conference on
    Machine Learning (2020).

.. [wilson2021pathwise]
    J. Wilson, V. Borovitskiy, A. Terenin, P. Mostowsky, and M. Deisenroth. Pathwise
    Conditioning of Gaussian Processes. Journal of Machine Learning Research (2021).
"""


class LatentKroneckerGPPosterior(GPyTorchPosterior):
    r"""Dummy posterior class for a LatentKroneckerGP model.

    Internally calls `model._rsample_from_base_samples` to draw posterior samples
    via pathwise conditioning aka Matheron's rule [wilson2020sampling,
    wilson2021pathwise]. This is necessary because BoTorch instantiates the
    posterior object before creating base samples, whereas pathwise conditioning
    requires the base samples first to calculate the posterior samples.

    To cache expensive computations, which only have to be performed once for the
    same base samples, the results are stored in the model instead of the
    posterior object, because a new posterior object is created in each
    acquisition function call.
    """

    def __init__(
        self,
        model: GPyTorchModel,
        X: Tensor,
    ) -> None:
        r"""A dummy posterior for LatentKroneckerGP models.

        Args:
            model: The LatentKroneckerGP model to which this posterior belongs.
            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
                of the feature space and `q` is the number of points considered
                jointly, on which the posterior shall be evaluated.
        """
        self._dtype = X.dtype
        self._device = X.device
        self.batch_shape = model.batch_shape
        # Combine the model's batch shape with the t-batch dims of `X`.
        self.output_batch_shape = torch.broadcast_shapes(
            model.batch_shape, X.shape[:-2]
        )
        self.q = X.shape[-2]
        # One output per (test point, entry of `model.T`) pair.
        output_dim = self.q * model.T.shape[-1]
        # Zero-mean / identity-covariance placeholder distribution that only
        # carries the output shape; real samples come from the model via
        # `rsample_from_base_samples`.
        mean = ZeroLinearOperator(
            *self.output_batch_shape, output_dim, dtype=X.dtype, device=X.device
        )
        covar = IdentityLinearOperator(
            output_dim,
            batch_shape=self.output_batch_shape,
            dtype=X.dtype,
            device=X.device,
        )
        dummy_mvn = MultivariateNormal(mean=mean, covariance_matrix=covar)
        super().__init__(distribution=dummy_mvn)
        self.model = model
        self.X = X
        # Assigned after `super().__init__`, so this value takes precedence over
        # anything the parent constructor may have set.
        self._is_mt = True

    @property
    def base_sample_shape(self) -> torch.Size:
        r"""The shape of a base sample used for constructing posterior samples.

        Overwrites the standard `base_sample_shape` call to inform samplers that
        `n_train_full + n_train + n_test` samples are needed rather than `n`
        samples.
        """
        n_time = self.model.T.shape[-1]
        # One entry per (training input, entry of `model.T`) pair.
        n_train_full = self.model.train_inputs[0].shape[-2] * n_time
        # One entry per stored training target.
        n_train = self.model.train_targets.shape[-1]
        # One entry per (test point, entry of `model.T`) pair.
        n_test = self.q * n_time
        return self.batch_shape + torch.Size([n_train_full + n_train + n_test])

    @property
    def batch_range(self) -> tuple[int, int]:
        r"""The t-batch range.

        This is used in samplers to identify the t-batch component of the
        `base_sample_shape`. The base samples are expanded over the t-batches to
        provide consistency in the acquisition values, i.e., to ensure that a
        candidate produces the same value regardless of its position on the
        t-batch.
        """
        return (0, -1)

    def _extended_shape(
        self,
        sample_shape: torch.Size = torch.Size(),  # noqa: B008
    ) -> torch.Size:
        r"""Returns the shape of the samples produced by the distribution with
        the given `sample_shape`.

        The result is `sample_shape x output_batch_shape x q x T.shape[-1]`.
        """
        time_shape = torch.Size([self.model.T.shape[-1]])
        q_shape = torch.Size([self.q])
        return sample_shape + self.output_batch_shape + q_shape + time_shape

    def rsample_from_base_samples(
        self,
        sample_shape: torch.Size,
        base_samples: Tensor,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients) using base samples.

        This is intended to be used with a sampler that produces the
        corresponding base samples, and enables acquisition optimization via
        Sample Average Approximation. Since this posterior is a dummy object,
        call the model to perform sampling.

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`. Any sequence
                of ints is accepted and converted to `torch.Size`.
            base_samples: A Tensor of `N(0, I)` base samples of shape
                `sample_shape x base_sample_shape`, typically obtained from a
                `Sampler`. This is used for deterministic optimization.

        Returns:
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.

        Raises:
            RuntimeError: If the leading dimensions of `base_samples` do not
                match `sample_shape`.
        """
        # Normalize first: a `torch.Size` never compares equal to a list, so a
        # list-valued `sample_shape` would otherwise always fail the check below.
        sample_shape = torch.Size(sample_shape)
        if base_samples.shape[: len(sample_shape)] != sample_shape:
            raise RuntimeError(
                "`sample_shape` disagrees with shape of `base_samples`. "
                f"Got {sample_shape=} and {base_samples.shape=}."
            )
        return self.model._rsample_from_base_samples(self.X, base_samples)

    def rsample(
        self,
        sample_shape: torch.Size | None = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`. Defaults to
                `torch.Size([1])` when None. Any sequence of ints is accepted
                and converted to `torch.Size`.

        Returns:
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.
        """
        # Normalize so `sample_shape + self.base_sample_shape` below is a
        # `torch.Size` concatenation even if a list was passed.
        sample_shape = (
            torch.Size([1]) if sample_shape is None else torch.Size(sample_shape)
        )
        base_samples = torch.randn(
            sample_shape + self.base_sample_shape,
            dtype=self.X.dtype,
            device=self.X.device,
        )
        return self.rsample_from_base_samples(sample_shape, base_samples)