# Source code for botorch.posteriors.torch
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Posterior module to be used with PyTorch distributions.
"""
from __future__ import annotations
from typing import Any, Dict, Optional
import torch
from botorch.posteriors.posterior import Posterior
from torch import Tensor
from torch.distributions.distribution import Distribution
class TorchPosterior(Posterior):
    r"""A posterior based on a PyTorch Distribution.

    NOTE: For any attribute that is not explicitly defined on the Posterior level, this
    returns the corresponding attribute of the distribution. This allows easy access
    to the distribution attributes, without having to expose them on the Posterior.
    """

    def __init__(self, distribution: Distribution) -> None:
        r"""A posterior based on a PyTorch Distribution.

        Args:
            distribution: A PyTorch Distribution object.
        """
        self.distribution = distribution
        # Get the device and dtype from the first tensor-valued attribute found in
        # the distribution's instance dict. NOTE(review): if the distribution stores
        # no tensor attributes, `_device` / `_dtype` are never set, so accessing
        # `device` / `dtype` will most likely fail with an AttributeError.
        for attr in vars(distribution).values():
            if isinstance(attr, Tensor):
                self._device = attr.device
                self._dtype = attr.dtype
                break

    def rsample(
        self,
        sample_shape: Optional[torch.Size] = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).

        This is generally used with a sampler that produces the base samples.

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`. Defaults to
                `torch.Size()` (a single sample) when `None`.

        Returns:
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.
        """
        if sample_shape is None:
            sample_shape = torch.Size()
        return self.distribution.rsample(sample_shape=sample_shape)

    @property
    def device(self) -> torch.device:
        r"""The torch device of the distribution."""
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        r"""The torch dtype of the distribution."""
        return self._dtype

    def __getattr__(self, name: str) -> Any:
        r"""A catch-all for attributes not defined on the posterior level.

        Returns the attributes of the distribution instead.

        Raises:
            AttributeError: If neither the posterior nor its distribution has the
                attribute, or if the posterior has no `distribution` at all.
        """
        # `__getattr__` only runs after normal lookup fails. If `distribution`
        # itself is missing (e.g. an instance created via `__new__` without
        # `__init__`), resolving `self.distribution` would re-enter this method
        # forever; fail with a regular AttributeError instead.
        if name == "distribution":
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )
        return getattr(self.distribution, name)

    def __getstate__(self) -> Dict[str, Any]:
        r"""A minimal utility to support pickle protocol.

        Pickle uses `__get/setstate__` to serialize / deserialize the objects.
        Since we define `__getattr__` above, it takes precedence over these
        methods, and we end up in an infinite loop unless we also define
        `__getstate__` and `__setstate__`.

        Returns:
            A shallow copy of the instance `__dict__`.
        """
        # Return a copy: `copy.copy` hands this object directly to `__setstate__`
        # on the new instance, so returning `self.__dict__` itself would make the
        # copy share (and mutate) the original's attribute dict.
        return self.__dict__.copy()

    def __setstate__(self, d: Dict[str, Any]) -> None:
        r"""A minimal utility to support pickle protocol.

        Args:
            d: The state dict produced by `__getstate__`; it becomes the
                instance `__dict__`.
        """
        self.__dict__ = d

    def quantile(self, value: Tensor) -> Tensor:
        r"""Compute quantiles of the distribution.

        For multi-variate distributions, this may return the quantiles of
        the marginal distributions.

        Args:
            value: The probabilities at which to evaluate the quantiles. If it has
                more than one element, it is iterated along its first dimension,
                each slice is processed recursively, and the results are stacked
                along a new leading dimension.

        Returns:
            The quantiles, as computed by the distribution's `icdf`.
        """
        if value.numel() > 1:
            return torch.stack([self.quantile(v) for v in value], dim=0)
        return self.icdf(value)

    def density(self, value: Tensor) -> Tensor:
        r"""The probability density (or mass if discrete) of the distribution.

        For multi-variate distributions, this may return the density of
        the marginal distributions.

        Args:
            value: The points at which to evaluate the density. If it has more
                than one element, it is iterated along its first dimension, each
                slice is processed recursively, and the results are stacked along
                a new leading dimension.

        Returns:
            `exp(log_prob(value))` of the distribution.
        """
        if value.numel() > 1:
            return torch.stack([self.density(v) for v in value], dim=0)
        return self.log_prob(value).exp()

    def _extended_shape(
        self, sample_shape: torch.Size = torch.Size()  # noqa: B008
    ) -> torch.Size:
        r"""Returns the shape of the samples produced by the distribution with
        the given `sample_shape`.
        """
        return self.distribution._extended_shape(sample_shape=sample_shape)