Source code for botorch.posteriors.posterior
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Abstract base module for all botorch posteriors.
"""
from __future__ import annotations
from abc import ABC, abstractmethod, abstractproperty
from typing import List, Optional
import torch
from torch import Tensor
class Posterior(ABC):
r"""
Abstract base class for botorch posteriors.
:meta private:
"""
@property
def base_sample_shape(self) -> torch.Size:
r"""The shape of a base sample used for constructing posterior samples.
This function may be overwritten by subclasses in case `base_sample_shape`
and `event_shape` do not agree (e.g. if the posterior is a Multivariate
Gaussian that is not full rank).
"""
return self.event_shape
@abstractproperty
def device(self) -> torch.device:
r"""The torch device of the posterior."""
pass # pragma: no cover
@abstractproperty
def dtype(self) -> torch.dtype:
r"""The torch dtype of the posterior."""
pass # pragma: no cover
@abstractproperty
def event_shape(self) -> torch.Size:
r"""The event shape (i.e. the shape of a single sample)."""
pass # pragma: no cover
@property
def mean(self) -> Tensor:
r"""The mean of the posterior as a `(b) x n x m`-dim Tensor."""
raise NotImplementedError(
f"Property `mean` not implemented for {self.__class__.__name__}"
)
@property
def variance(self) -> Tensor:
r"""The variance of the posterior as a `(b) x n x m`-dim Tensor."""
raise NotImplementedError(
f"Property `variance` not implemented for {self.__class__.__name__}"
)
@abstractmethod
def rsample(
self,
sample_shape: Optional[torch.Size] = None,
base_samples: Optional[Tensor] = None,
) -> Tensor:
r"""Sample from the posterior (with gradients).
Args:
sample_shape: A `torch.Size` object specifying the sample shape. To
draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
of `n` samples each, set to `torch.Size([b, n])`.
base_samples: An (optional) Tensor of `N(0, I)` base samples of
appropriate dimension, typically obtained from a `Sampler`.
This is used for deterministic optimization.
Returns:
A `sample_shape x event`-dim Tensor of samples from the posterior.
"""
pass # pragma: no cover
def sample(
self,
sample_shape: Optional[torch.Size] = None,
base_samples: Optional[Tensor] = None,
) -> Tensor:
r"""Sample from the posterior (without gradients).
This is a simple wrapper calling `rsample` using `with torch.no_grad()`.
Args:
sample_shape: A `torch.Size` object specifying the sample shape. To
draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
of `n` samples each, set to `torch.Size([b, n])`.
base_samples: An (optional) Tensor of `N(0, I)` base samples of
appropriate dimension, typically obtained from a `Sampler` object.
This is used for deterministic optimization.
Returns:
A `sample_shape x event_shape`-dim Tensor of samples from the posterior.
"""
with torch.no_grad():
return self.rsample(sample_shape=sample_shape, base_samples=base_samples)
[docs]class PosteriorList(Posterior):
r"""A Posterior represented by a list of independent Posteriors."""
def __init__(self, *posteriors: Posterior) -> None:
r"""A Posterior represented by a list of independent Posteriors.
Args:
*posteriors: A variable number of single-outcome posteriors.
Example:
>>> p_1 = model_1.posterior(test_X)
>>> p_2 = model_2.posterior(test_X)
>>> p_12 = PosteriorList(p_1, p_2)
Note: This is typically produced automatically in `ModelList`; it should
generally not be necessary for the end user to invoke it manually.
"""
self.posteriors = list(posteriors)
@property
def base_sample_shape(self) -> torch.Size:
r"""The shape of a base sample used for constructing posterior samples."""
base_sample_shapes = [
p.base_sample_shape
for p in self.posteriors
if p.base_sample_shape # ignore empty sample shapes
]
batch_shapes = [bss[:-1] for bss in base_sample_shapes]
if len(set(batch_shapes)) > 1:
raise NotImplementedError(
"`PosteriorList` only supported if the constituent posteriors "
f"all have the same `batch_shape`. Batch shapes: {batch_shapes}."
)
elif len(set(batch_shapes)) == 0:
# batch shapes are all zero if and only if the models
# are determinisitic
return torch.Size([])
return batch_shapes[0] + torch.Size(
[sum(bss[-1] for bss in base_sample_shapes)]
)
@property
def device(self) -> torch.device:
r"""The torch device of the posterior."""
devices = {p.device for p in self.posteriors}
if len(devices) > 1:
raise NotImplementedError( # pragma: no cover
"Multi-device posteriors are currently not supported. "
"The devices of the constituent posteriors are: {devices}."
)
return next(iter(devices))
@property
def dtype(self) -> torch.dtype:
r"""The torch dtype of the posterior."""
dtypes = {p.dtype for p in self.posteriors}
if len(dtypes) > 1:
raise NotImplementedError(
"Multi-dtype posteriors are currently not supported. "
"The dtypes of the constituent posteriors are: {dtypes}."
)
return next(iter(dtypes))
@property
def event_shape(self) -> torch.Size:
r"""The event shape (i.e. the shape of a single sample)."""
event_shapes = [p.event_shape for p in self.posteriors]
batch_shapes = [es[:-1] for es in event_shapes]
if len(set(batch_shapes)) > 1:
raise NotImplementedError(
"`PosteriorList` only supported if the constituent posteriors "
f"all have the same `batch_shape`. Batch shapes: {batch_shapes}."
)
# last dimension is the output dimension (concatenation dimension)
return batch_shapes[0] + torch.Size([sum(es[-1] for es in event_shapes)])
@property
def mean(self) -> Tensor:
r"""The mean of the posterior as a `(b) x n x m`-dim Tensor."""
return torch.cat([p.mean for p in self.posteriors], dim=-1)
@property
def variance(self) -> Tensor:
r"""The variance of the posterior as a `(b) x n x m`-dim Tensor."""
return torch.cat([p.variance for p in self.posteriors], dim=-1)
def _rsample(
self,
sample_shape: Optional[torch.Size] = None,
base_samples: Optional[Tensor] = None,
) -> List[Tensor]:
# handle the case where all posteriors have empty
# base_sample_shape
split_bss = False
base_sample_splits = [None] * len(self.posteriors)
if base_samples is not None:
split_sizes = []
for p in self.posteriors:
if p.base_sample_shape:
split_sizes.append(p.base_sample_shape[-1])
split_bss = True
else:
split_sizes.append(0)
if split_bss:
base_sample_splits = torch.split(base_samples, split_sizes, dim=-1)
base_sample_splits = [
bss if ss > 0 else None
for ss, bss in zip(split_sizes, base_sample_splits)
]
return [
p.rsample(sample_shape=sample_shape, base_samples=bss)
for p, bss in zip(self.posteriors, base_sample_splits)
]
[docs] def rsample(
self,
sample_shape: Optional[torch.Size] = None,
base_samples: Optional[Tensor] = None,
) -> Tensor:
r"""Sample from the posterior (with gradients).
Args:
sample_shape: A `torch.Size` object specifying the sample shape. To
draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
of `n` samples each, set to `torch.Size([b, n])`.
base_samples: An (optional) Tensor of `N(0, I)` base samples of
appropriate dimension, typically obtained from a `Sampler`.
This is used for deterministic optimization.
Returns:
A `sample_shape x event`-dim Tensor of samples from the posterior.
"""
samples = self._rsample(sample_shape=sample_shape, base_samples=base_samples)
return torch.cat(samples, dim=-1)