# Source code for botorch.posteriors.deterministic
#! /usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Deterministic (degenerate) posteriors. Used in conjunction with deterministic
models.
"""
from typing import Optional
import torch
from torch import Tensor
from .posterior import Posterior
class DeterministicPosterior(Posterior):
    r"""Deterministic (degenerate) posterior.

    Wraps a fixed tensor of values: the mean is the values themselves, the
    variance is identically zero, and every "sample" equals the values.
    """

    def __init__(self, values: Tensor) -> None:
        r"""Initialize the posterior.

        Args:
            values: The deterministic values. Their shape is the event shape
                of the posterior.
        """
        self.values = values

    @property
    def device(self) -> torch.device:
        r"""The torch device of the posterior (that of `values`)."""
        return self.values.device

    @property
    def dtype(self) -> torch.dtype:
        r"""The torch dtype of the posterior (that of `values`)."""
        return self.values.dtype

    @property
    def event_shape(self) -> torch.Size:
        r"""The event shape (i.e. the shape of a single sample)."""
        return self.values.shape

    @property
    def mean(self) -> Tensor:
        r"""The mean of the posterior, i.e. the stored `values`.

        Typically a `(b) x n x m`-dim Tensor.
        """
        return self.values

    @property
    def variance(self) -> Tensor:
        r"""The variance of the posterior, same shape as `values`.

        As this is a deterministic posterior, this is a tensor of zeros.
        """
        return torch.zeros_like(self.values)

    def rsample(
        self,
        sample_shape: Optional[torch.Size] = None,
        base_samples: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).

        For the deterministic posterior, this just returns the values expanded
        to the requested shape.

        Args:
            sample_shape: A `torch.Size` object (or any sequence of ints)
                specifying the sample shape. To draw `n` samples, set to
                `torch.Size([n])`. To draw `b` batches of `n` samples each,
                set to `torch.Size([b, n])`. Defaults to `torch.Size([1])`.
            base_samples: An (optional) Tensor of `N(0, I)` base samples of
                appropriate dimension, typically obtained from a `Sampler`.
                Ignored in construction of the samples (used only for shape
                validation).

        Returns:
            A `sample_shape x event`-dim Tensor of samples from the posterior.
            This is an expanded view of `values` (no copy), so its elements
            share storage; clone it before writing to it in place.

        Raises:
            RuntimeError: If the leading dimensions of `base_samples` do not
                match `sample_shape`.
        """
        if sample_shape is None:
            sample_shape = torch.Size([1])
        else:
            # Normalize lists/tuples so concatenation with `values.shape` and
            # comparison with `base_samples.shape` behave consistently.
            sample_shape = torch.Size(sample_shape)
        if base_samples is not None:
            if base_samples.shape[: len(sample_shape)] != sample_shape:
                raise RuntimeError("sample_shape disagrees with shape of base_samples.")
        return self.values.expand(sample_shape + self.values.shape)