#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from collections.abc import Sequence
import torch
from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler
from botorch.utils.probability.mvnxpb import MVNXPB
from botorch.utils.probability.utils import get_constants_like
from torch import Tensor
from torch.distributions.multivariate_normal import MultivariateNormal
class TruncatedMultivariateNormal(MultivariateNormal):
    r"""Multivariate normal distribution truncated to a box `a <= x <= b`.

    The log partition is read from an `MVNXPB` solver and samples are drawn
    with a `LinearEllipticalSliceSampler`. Both helpers are built lazily on
    first access unless instances are passed to the constructor.
    """

    def __init__(
        self,
        loc: Tensor,
        covariance_matrix: Tensor | None = None,
        precision_matrix: Tensor | None = None,
        scale_tril: Tensor | None = None,
        bounds: Tensor | None = None,
        solver: MVNXPB | None = None,
        sampler: LinearEllipticalSliceSampler | None = None,
        validate_args: bool | None = None,
    ):
        r"""Initializes an instance of a TruncatedMultivariateNormal distribution.

        Let `x ~ N(loc, K)` be an `n`-dimensional Gaussian random vector. This
        class represents the distribution of the truncated Multivariate normal
        random vector `x | a <= x <= b`.

        Args:
            loc: A mean vector for the distribution, `batch_shape x event_shape`.
            covariance_matrix: Covariance matrix distribution parameter.
            precision_matrix: Inverse covariance matrix distribution parameter.
            scale_tril: Lower triangular, square-root covariance matrix
                distribution parameter.
            bounds: A `batch_shape x event_shape x 2` tensor of strictly
                increasing bounds for `x` so that
                `bounds[..., 0] < bounds[..., 1]` everywhere. The bounds apply
                to `x` itself; `loc` is subtracted internally where needed.
            solver: A pre-solved MVNXPB instance used to approximate the log
                partition.
            sampler: A LinearEllipticalSliceSampler instance used for sample
                generation.
            validate_args: Optional argument to super().__init__.

        Raises:
            SyntaxError: If `bounds` is not provided.
            ValueError: If the last dimension of `bounds` is not of size 2, or
                if any lower bound exceeds its upper bound.
        """
        if bounds is None:
            # `bounds` only has a default because it follows optional
            # parameters; the exception type is kept for backward compatibility.
            raise SyntaxError("Missing required argument `bounds`.")
        elif bounds.shape[-1] != 2:
            raise ValueError(
                f"Expected bounds.shape[-1] to be 2 but bounds shape is {bounds.shape}"
            )
        elif torch.gt(*bounds.unbind(dim=-1)).any():
            # NOTE(review): this rejects `lower > upper` only; equal bounds are
            # not caught here even though the message says "strictly".
            raise ValueError("`bounds` must be strictly increasing along dim=-1.")

        super().__init__(
            loc=loc,
            covariance_matrix=covariance_matrix,
            precision_matrix=precision_matrix,
            scale_tril=scale_tril,
            validate_args=validate_args,
        )
        self.bounds = bounds
        self._solver = solver
        self._sampler = sampler

    def log_prob(self, value: Tensor) -> Tensor:
        r"""Approximates the true log probability.

        Args:
            value: Points at which to evaluate the log density; the last
                dimension is the event dimension.

        Returns:
            The untruncated log density minus the log partition for points
            strictly inside the bounds, and `-inf` for all other points.
        """
        inbounds = torch.logical_and(
            (self.bounds[..., 0] < value).all(-1),
            (self.bounds[..., 1] > value).all(-1),
        )
        if not inbounds.any():
            # Nothing is inside the box, so avoid touching the solver. The
            # result follows `value`'s dtype/device and the broadcast shape the
            # `torch.where` branch below would produce.
            return torch.full(
                torch.broadcast_shapes(inbounds.shape, self.batch_shape),
                -float("inf"),
                dtype=value.dtype,
                device=value.device,
            )

        neg_inf = get_constants_like(-float("inf"), value)
        return torch.where(
            inbounds,
            super().log_prob(value) - self.log_partition,
            neg_inf,
        )

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:  # noqa: B008
        r"""Draw samples from the Truncated Multivariate Normal.

        Args:
            sample_shape: The shape of the samples.

        Returns:
            The (sample_shape x batch_shape x event_shape) tensor of samples.
        """
        # Accept plain tuples/lists as well; an empty shape has numel() == 1.
        sample_shape = torch.Size(sample_shape)
        # The sampler is set up without `loc` (see `sampler`), so the mean is
        # added back after reshaping the draws.
        draws = self.sampler.draw(n=sample_shape.numel())
        return self.loc + draws.view(*sample_shape, -1)

    @property
    def log_partition(self) -> Tensor:
        r"""The solver's `log_prob`, used as the log normalizing constant."""
        return self.solver.log_prob

    @property
    def solver(self) -> MVNXPB:
        r"""The MVNXPB solver, created and solved on first access if needed."""
        if self._solver is None:
            self._solver = MVNXPB(
                covariance_matrix=self.covariance_matrix,
                # Shift the bounds by the mean before handing them over.
                bounds=self.bounds - self.loc.unsqueeze(-1),
            )
            self._solver.solve()
        return self._solver

    @property
    def sampler(self) -> LinearEllipticalSliceSampler:
        r"""The slice sampler, created on first access if none was supplied."""
        if self._sampler is None:
            eye = torch.eye(
                self.scale_tril.shape[-1],
                dtype=self.scale_tril.dtype,
                device=self.scale_tril.device,
            )
            # Box constraints on the centered variable, stacked as `(A, b)`
            # with A = [-I; I] and b = [loc - lower; upper - loc].
            A = torch.concat([-eye, eye])
            b = torch.concat(
                [
                    self.loc - self.bounds[..., 0],
                    self.bounds[..., 1] - self.loc,
                ],
                dim=-1,
            ).unsqueeze(-1)
            self._sampler = LinearEllipticalSliceSampler(
                inequality_constraints=(A, b),
                covariance_root=self.scale_tril,
            )
        return self._sampler

    def expand(
        self,
        batch_shape: Sequence[int],
        _instance: TruncatedMultivariateNormal | None = None,
    ) -> TruncatedMultivariateNormal:
        r"""Returns a new distribution with batch dimensions expanded.

        Args:
            batch_shape: The desired expanded batch shape.
            _instance: Optional pre-allocated instance to populate.

        Returns:
            A TruncatedMultivariateNormal with expanded parameters and bounds.
            The sampler is reset and rebuilt lazily; an existing solver is
            expanded alongside.
        """
        new = self._get_checked_instance(TruncatedMultivariateNormal, _instance)
        super().expand(batch_shape=batch_shape, _instance=new)
        new.bounds = self.bounds.expand(*new.batch_shape, *self.event_shape, 2)
        new._sampler = None  # does not implement `expand`
        new._solver = (
            None if self._solver is None else self._solver.expand(*batch_shape)
        )
        return new

    def __repr__(self) -> str:
        # Drop the closing parenthesis of the parent repr and append the bounds.
        return super().__repr__()[:-1] + f", bounds: {self.bounds.shape})"