Source code for botorch.posteriors.posterior

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Abstract base module for all botorch posteriors.
"""

from __future__ import annotations

from abc import ABC, abstractmethod, abstractproperty
from typing import Optional

import torch
from torch import Tensor


class Posterior(ABC):
    r"""Abstract base class for botorch posteriors.

    Concrete subclasses must implement the `device`, `dtype` and
    `event_shape` properties as well as the `rsample` method. The `mean`
    and `variance` properties are optional; the default implementations
    raise a `NotImplementedError`.
    """

    @property
    def base_sample_shape(self) -> torch.Size:
        r"""The shape of a base sample used for constructing posterior samples.

        This function may be overwritten by subclasses in case
        `base_sample_shape` and `event_shape` do not agree (e.g. if the
        posterior is a Multivariate Gaussian that is not full rank).
        """
        return self.event_shape

    # NOTE: `@property` stacked on `@abstractmethod` is the supported
    # replacement for `abc.abstractproperty` (deprecated since Python 3.3);
    # the resulting attribute is still registered as abstract by `ABC`.
    @property
    @abstractmethod
    def device(self) -> torch.device:
        r"""The torch device of the posterior."""
        pass  # pragma: no cover

    @property
    @abstractmethod
    def dtype(self) -> torch.dtype:
        r"""The torch dtype of the posterior."""
        pass  # pragma: no cover

    @property
    @abstractmethod
    def event_shape(self) -> torch.Size:
        r"""The event shape (i.e. the shape of a single sample)."""
        pass  # pragma: no cover

    @property
    def mean(self) -> Tensor:
        r"""The mean of the posterior as a `(b) x n x m`-dim Tensor.

        Raises:
            NotImplementedError: If the subclass does not override this
                property.
        """
        raise NotImplementedError(
            f"Property `mean` not implemented for {self.__class__.__name__}"
        )

    @property
    def variance(self) -> Tensor:
        r"""The variance of the posterior as a `(b) x n x m`-dim Tensor.

        Raises:
            NotImplementedError: If the subclass does not override this
                property.
        """
        raise NotImplementedError(
            f"Property `variance` not implemented for {self.__class__.__name__}"
        )

    @abstractmethod
    def rsample(
        self,
        sample_shape: Optional[torch.Size] = None,
        base_samples: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.
            base_samples: An (optional) Tensor of `N(0, I)` base samples of
                appropriate dimension, typically obtained from a `Sampler`.
                This is used for deterministic optimization.

        Returns:
            A `sample_shape x event_shape`-dim Tensor of samples from the
            posterior.
        """
        pass  # pragma: no cover

    def sample(
        self,
        sample_shape: Optional[torch.Size] = None,
        base_samples: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Sample from the posterior (without gradients).

        This is a simple wrapper calling `rsample` using `with torch.no_grad()`.

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.
            base_samples: An (optional) Tensor of `N(0, I)` base samples of
                appropriate dimension, typically obtained from a `Sampler`
                object. This is used for deterministic optimization.

        Returns:
            A `sample_shape x event_shape`-dim Tensor of samples from the
            posterior.
        """
        # Gradient tracking is disabled only for the duration of the draw.
        with torch.no_grad():
            return self.rsample(sample_shape=sample_shape, base_samples=base_samples)