Source code for botorch.posteriors.transformed

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Callable, Optional, Tuple

import torch
from botorch.posteriors.posterior import Posterior
from torch import Tensor


[docs]class TransformedPosterior(Posterior): r"""A generic transformation of a posterior (implicitly represented).""" def __init__( self, posterior: Posterior, sample_transform: Callable[[Tensor], Tensor], mean_transform: Optional[Callable[[Tensor, Tensor], Tensor]] = None, variance_transform: Optional[Callable[[Tensor, Tensor], Tensor]] = None, ) -> None: r"""An implicitly represented transformed posterior. Args: posterior: The posterior object to be transformed. sample_transform: A callable applying a sample-level transform to a `sample_shape x batch_shape x q x m`-dim tensor of samples from the original posterior, returning a tensor of samples of the same shape. mean_transform: A callable transforming a 2-tuple of mean and variance (both of shape `batch_shape x m x o`) of the original posterior to the mean of the transformed posterior. variance_transform: A callable transforming a 2-tuple of mean and variance (both of shape `batch_shape x m x o`) of the original posterior to a variance of the transformed posterior. """ self._posterior = posterior self._sample_transform = sample_transform self._mean_transform = mean_transform self._variance_transform = variance_transform @property def base_sample_shape(self) -> torch.Size: r"""The shape of a base sample used for constructing posterior samples.""" return self._posterior.base_sample_shape @property def batch_range(self) -> Tuple[int, int]: r"""The t-batch range. This is used in samplers to identify the t-batch component of the `base_sample_shape`. The base samples are expanded over the t-batches to provide consistency in the acquisition values, i.e., to ensure that a candidate produces same value regardless of its position on the t-batch. """ return self._posterior.batch_range @property def device(self) -> torch.device: r"""The torch device of the posterior.""" return self._posterior.device @property def dtype(self) -> torch.dtype: r"""The torch dtype of the posterior.""" return self._posterior.dtype def _extended_shape( self, sample_shape: torch.Size = torch.Size() # noqa: B008 ) -> torch.Size: r"""Returns the shape of the samples produced by the posterior with the given `sample_shape`. NOTE: This assumes that the `sample_transform` does not change the shape of the samples. """ return self._posterior._extended_shape(sample_shape=sample_shape) @property def mean(self) -> Tensor: r"""The mean of the posterior as a `batch_shape x n x m`-dim Tensor.""" if self._mean_transform is None: raise NotImplementedError("No mean transform provided.") try: variance = self._posterior.variance except (NotImplementedError, AttributeError): variance = None return self._mean_transform(self._posterior.mean, variance) @property def variance(self) -> Tensor: r"""The variance of the posterior as a `batch_shape x n x m`-dim Tensor.""" if self._variance_transform is None: raise NotImplementedError("No variance transform provided.") return self._variance_transform(self._posterior.mean, self._posterior.variance)
[docs] def rsample_from_base_samples( self, sample_shape: torch.Size, base_samples: Tensor, ) -> Tensor: r"""Sample from the posterior (with gradients) using base samples. This is intended to be used with a sampler that produces the corresponding base samples, and enables acquisition optimization via Sample Average Approximation. Args: sample_shape: A `torch.Size` object specifying the sample shape. To draw `n` samples, set to `torch.Size([n])`. To draw `b` batches of `n` samples each, set to `torch.Size([b, n])`. base_samples: The base samples, obtained from the appropriate sampler. This is a tensor of shape `sample_shape x base_sample_shape`. Returns: Samples from the posterior, a tensor of shape `self._extended_shape(sample_shape=sample_shape)`. """ samples = self._posterior.rsample_from_base_samples( sample_shape=sample_shape, base_samples=base_samples ) return self._sample_transform(samples)
[docs] def rsample( self, sample_shape: Optional[torch.Size] = None, ) -> Tensor: r"""Sample from the posterior (with gradients). Args: sample_shape: A `torch.Size` object specifying the sample shape. To draw `n` samples, set to `torch.Size([n])`. To draw `b` batches of `n` samples each, set to `torch.Size([b, n])`. Returns: Samples from the posterior, a tensor of shape `self._extended_shape(sample_shape=sample_shape)`. """ samples = self._posterior.rsample(sample_shape=sample_shape) return self._sample_transform(samples)