# Source code for botorch.posteriors.deterministic

#! /usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Deterministic (degenerate) posteriors. Used in conjunction with deterministic
models.
"""

from __future__ import annotations

from typing import Optional

import torch
from torch import Tensor

from .posterior import Posterior


class DeterministicPosterior(Posterior):
    r"""Deterministic posterior.

    A degenerate posterior that wraps a fixed tensor of values: its mean is
    that tensor, its variance is identically zero, and every "sample" drawn
    from it is the wrapped tensor itself.
    """

    def __init__(self, values: Tensor) -> None:
        r"""Construct the posterior from a tensor of deterministic values.

        Args:
            values: The tensor returned as the posterior mean (documented by
                `mean` as a `(b) x n x m`-dim Tensor).
        """
        self.values = values

    @property
    def device(self) -> torch.device:
        r"""The torch device of the posterior."""
        return self.values.device

    @property
    def dtype(self) -> torch.dtype:
        r"""The torch dtype of the posterior."""
        return self.values.dtype

    @property
    def event_shape(self) -> torch.Size:
        r"""The event shape (i.e. the shape of a single sample)."""
        return self.values.shape

    @property
    def mean(self) -> Tensor:
        r"""The mean of the posterior as a `(b) x n x m`-dim Tensor."""
        return self.values

    @property
    def variance(self) -> Tensor:
        r"""The variance of the posterior as a `(b) x n x m`-dim Tensor.

        As this is a deterministic posterior, this is a tensor of zeros.
        """
        return torch.zeros_like(self.values)

    def rsample(
        self,
        sample_shape: Optional[torch.Size] = None,
        base_samples: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).

        For the deterministic posterior, this just returns the values expanded
        to the requested shape.

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`. Defaults to
                `torch.Size([1])` when omitted. Any sequence of ints is
                accepted and converted to a `torch.Size`.
            base_samples: An (optional) Tensor of `N(0, I)` base samples of
                appropriate dimension, typically obtained from a `Sampler`.
                Ignored in construction of the samples (used only for shape
                validation).

        Returns:
            A `sample_shape x event`-dim Tensor of samples from the posterior.
            The result is produced with `Tensor.expand`, i.e. it is a view of
            `self.values` rather than a copy.

        Raises:
            RuntimeError: If `base_samples` is given and its leading dimensions
                do not equal `sample_shape`.
        """
        if sample_shape is None:
            sample_shape = torch.Size([1])
        else:
            # Normalize so that lists/tuples behave like `torch.Size` in both
            # the comparison and the concatenation below.
            sample_shape = torch.Size(sample_shape)
        if (
            base_samples is not None
            and base_samples.shape[: len(sample_shape)] != sample_shape
        ):
            raise RuntimeError("sample_shape disagrees with shape of base_samples.")
        return self.values.expand(sample_shape + self.values.shape)