Source code for botorch.posteriors.posterior

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Abstract base module for all botorch posteriors.
"""

from __future__ import annotations

from abc import ABC, abstractmethod, abstractproperty
from typing import Optional, Tuple

import torch
from torch import Tensor


class Posterior(ABC):
    r"""Abstract base class for botorch posteriors.

    Subclasses must implement `rsample` and the `device` and `dtype`
    properties. Every other method/property below has a default
    implementation that raises `NotImplementedError`; subclasses override the
    ones they support (e.g. `rsample_from_base_samples`, `base_sample_shape`
    and `batch_range` for sampling driven by sampler-provided base samples).
    """

    def rsample_from_base_samples(
        self,
        sample_shape: torch.Size,
        base_samples: Tensor,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients) using base samples.

        This is intended to be used with a sampler that produces the
        corresponding base samples, and enables acquisition optimization via
        Sample Average Approximation.

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.
            base_samples: The base samples, obtained from the appropriate
                sampler. This is a tensor of shape
                `sample_shape x base_sample_shape`.

        Returns:
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.

        Raises:
            NotImplementedError: Always, in this base class; subclasses that
                support base-sample-based sampling must override it.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement `rsample_from_base_samples`."
        )  # pragma: no cover

    @abstractmethod
    def rsample(
        self,
        sample_shape: Optional[torch.Size] = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.

        Returns:
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.
        """
        pass  # pragma: no cover

    def sample(self, sample_shape: Optional[torch.Size] = None) -> Tensor:
        r"""Sample from the posterior without gradients.

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.

        Returns:
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.
        """
        # Delegate to `rsample`, but with autograd recording disabled so the
        # returned samples are detached from any computation graph.
        with torch.no_grad():
            return self.rsample(sample_shape=sample_shape)

    # `abc.abstractproperty` is deprecated (since Python 3.3); stacking
    # `property` over `abstractmethod` is the documented replacement and still
    # marks the property as abstract on the class.
    @property
    @abstractmethod
    def device(self) -> torch.device:
        r"""The torch device of the distribution."""
        pass  # pragma: no cover

    @property
    @abstractmethod
    def dtype(self) -> torch.dtype:
        r"""The torch dtype of the distribution."""
        pass  # pragma: no cover

    def quantile(self, value: Tensor) -> Tensor:
        r"""Compute quantiles of the distribution.

        For multi-variate distributions, this may return the quantiles of the
        marginal distributions.

        Args:
            value: The values at which to compute the quantiles.

        Returns:
            A tensor of quantiles (as defined by the implementing subclass).

        Raises:
            NotImplementedError: Always, in this base class; subclasses that
                support quantiles must override it.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement a `quantile` method."
        )  # pragma: no cover

    def density(self, value: Tensor) -> Tensor:
        r"""The probability density (or mass) of the distribution.

        For multi-variate distributions, this may return the density of the
        marginal distributions.

        Args:
            value: The values at which to evaluate the density (or mass).

        Returns:
            A tensor of densities (as defined by the implementing subclass).

        Raises:
            NotImplementedError: Always, in this base class; subclasses that
                support density evaluation must override it.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement a `density` method."
        )  # pragma: no cover

    def _extended_shape(
        self, sample_shape: torch.Size = torch.Size()  # noqa: B008
    ) -> torch.Size:
        r"""Returns the shape of the samples produced by the posterior with
        the given `sample_shape`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # `torch.Size` is an immutable tuple subclass, so sharing one default
        # instance across calls is safe (hence the B008 suppression).
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement `_extended_shape`."
        )

    @property
    def base_sample_shape(self) -> torch.Size:
        r"""The base shape of the base samples expected in `rsample`.

        Informs the sampler to produce base samples of shape
        `sample_shape x base_sample_shape`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement `base_sample_shape`."
        )

    @property
    def batch_range(self) -> Tuple[int, int]:
        r"""The t-batch range.

        This is used in samplers to identify the t-batch component of the
        `base_sample_shape`. The base samples are expanded over the t-batches to
        provide consistency in the acquisition values, i.e., to ensure that a
        candidate produces the same value regardless of its position on the
        t-batch.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement `batch_range`."
        )