Source code for botorch.utils.probability.truncated_multivariate_normal

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from collections.abc import Sequence

import torch
from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler
from botorch.utils.probability.mvnxpb import MVNXPB
from botorch.utils.probability.utils import get_constants_like
from torch import Tensor
from torch.distributions.multivariate_normal import MultivariateNormal


class TruncatedMultivariateNormal(MultivariateNormal):
    def __init__(
        self,
        loc: Tensor,
        covariance_matrix: Tensor | None = None,
        precision_matrix: Tensor | None = None,
        scale_tril: Tensor | None = None,
        bounds: Tensor | None = None,
        solver: MVNXPB | None = None,
        sampler: LinearEllipticalSliceSampler | None = None,
        validate_args: bool | None = None,
    ):
        r"""Initializes an instance of a TruncatedMultivariateNormal distribution.

        Let `x ~ N(loc, K)` be an `n`-dimensional Gaussian random vector. This class
        represents the distribution of the truncated Multivariate normal random
        vector `x | a <= x <= b`.

        Args:
            loc: A mean vector for the distribution, `batch_shape x event_shape`.
            covariance_matrix: Covariance matrix distribution parameter.
            precision_matrix: Inverse covariance matrix distribution parameter.
            scale_tril: Lower triangular, square-root covariance matrix distribution
                parameter.
            bounds: A `batch_shape x event_shape x 2` tensor of strictly increasing
                bounds for `x` so that `bounds[..., 0] < bounds[..., 1]` everywhere.
                Although the signature gives it a `None` default, it is required.
            solver: A pre-solved MVNXPB instance used to approximate the log
                partition. Built lazily from the covariance and bounds if omitted.
            sampler: A LinearEllipticalSliceSampler instance used for sample
                generation. Built lazily from `scale_tril` and bounds if omitted.
            validate_args: Optional argument to super().__init__.

        Raises:
            SyntaxError: If `bounds` is not provided.
            ValueError: If `bounds.shape[-1] != 2`, or if any lower bound is not
                strictly smaller than its upper bound.
        """
        if bounds is None:
            raise SyntaxError("Missing required argument `bounds`.")
        elif bounds.shape[-1] != 2:
            raise ValueError(
                f"Expected bounds.shape[-1] to be 2 but bounds shape is {bounds.shape}"
            )
        elif torch.ge(*bounds.unbind(dim=-1)).any():
            # Reject `lower >= upper` (not only `lower > upper`): a zero-width
            # interval can never satisfy the strict inequalities used in
            # `log_prob`, so the truncated distribution would have empty support.
            raise ValueError("`bounds` must be strictly increasing along dim=-1.")

        super().__init__(
            loc=loc,
            covariance_matrix=covariance_matrix,
            precision_matrix=precision_matrix,
            scale_tril=scale_tril,
            validate_args=validate_args,
        )
        self.bounds = bounds
        self._solver = solver
        self._sampler = sampler

    def log_prob(self, value: Tensor) -> Tensor:
        r"""Approximates the true log probability.

        The untruncated Gaussian log density is shifted by `log_partition`, which
        is the (approximate) log probability mass inside the bounds as computed by
        the MVNXPB solver. Points that are not strictly inside the bounds get
        `-inf`.

        Args:
            value: A `sample_shape x batch_shape x event_shape` tensor of points.

        Returns:
            A tensor of log probabilities with the broadcast sample/batch shape of
            `value` and the distribution, in `value`'s dtype and on its device.
        """
        neg_inf = get_constants_like(-float("inf"), value)
        # A point is in the support only if every coordinate is strictly inside
        # its interval; `.all(-1)` reduces over the event dimension.
        inbounds = torch.logical_and(
            (self.bounds[..., 0] < value).all(-1),
            (self.bounds[..., 1] > value).all(-1),
        )
        if inbounds.any():
            return torch.where(
                inbounds,
                super().log_prob(value) - self.log_partition,
                neg_inf,
            )
        # Every point is outside the support: skip the log-partition computation.
        # Build the result with `inbounds`' (broadcast) shape and `value`'s
        # dtype/device so it matches what the `torch.where` branch would return.
        return torch.full(
            inbounds.shape, -float("inf"), dtype=value.dtype, device=value.device
        )

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:  # noqa: B008
        r"""Draw samples from the Truncated Multivariate Normal.

        Args:
            sample_shape: The shape of the samples. Any sequence of ints accepted
                by `torch.Size` (e.g. a tuple or list) works.

        Returns:
            A `sample_shape x event_shape` tensor of samples, with `loc` added via
            broadcasting. NOTE(review): the sampler draws a flat batch of points,
            so batched distributions are presumably not supported here — confirm.
        """
        # Normalize so tuple/list inputs (as commonly passed to `.sample`) work.
        sample_shape = torch.Size(sample_shape)
        # `torch.Size().numel()` is 1, so an empty shape draws a single sample.
        num_samples = sample_shape.numel()
        # The sampler draws the centered variable `x - loc`; shift back by `loc`.
        return self.loc + self.sampler.draw(n=num_samples).view(*sample_shape, -1)

    @property
    def log_partition(self) -> Tensor:
        r"""Approximate log probability mass of the untruncated Gaussian in bounds."""
        return self.solver.log_prob

    @property
    def solver(self) -> MVNXPB:
        r"""The MVNXPB solver for the log partition, built and solved on first use."""
        if self._solver is None:
            # The solver works with a zero-mean Gaussian, so center the bounds.
            self._solver = MVNXPB(
                covariance_matrix=self.covariance_matrix,
                bounds=self.bounds - self.loc.unsqueeze(-1),
            )
            self._solver.solve()
        return self._solver

    @property
    def sampler(self) -> LinearEllipticalSliceSampler:
        r"""The elliptical slice sampler for the centered variable, built lazily."""
        if self._sampler is None:
            eye = torch.eye(
                self.scale_tril.shape[-1],
                dtype=self.scale_tril.dtype,
                device=self.scale_tril.device,
            )
            # With A = [-I; I] and b = [loc - lower; upper - loc], the constraint
            # `A @ z <= b` (assuming the sampler's convention) encodes
            # `lower - loc <= z <= upper - loc` for the centered variable z.
            A = torch.concat([-eye, eye])
            b = torch.concat(
                [
                    self.loc - self.bounds[..., 0],
                    self.bounds[..., 1] - self.loc,
                ],
                dim=-1,
            ).unsqueeze(-1)
            self._sampler = LinearEllipticalSliceSampler(
                inequality_constraints=(A, b),
                covariance_root=self.scale_tril,
            )
        return self._sampler

    def expand(
        self,
        batch_shape: Sequence[int],
        _instance: TruncatedMultivariateNormal | None = None,
    ) -> TruncatedMultivariateNormal:
        r"""Return a new distribution with batch dimensions expanded to `batch_shape`.

        The bounds are expanded alongside the Gaussian parameters, and an existing
        solver is expanded too. The sampler is dropped and rebuilt lazily.
        """
        new = self._get_checked_instance(TruncatedMultivariateNormal, _instance)
        super().expand(batch_shape=batch_shape, _instance=new)

        new.bounds = self.bounds.expand(*new.batch_shape, *self.event_shape, 2)
        new._sampler = None  # does not implement `expand`
        new._solver = (
            None if self._solver is None else self._solver.expand(*batch_shape)
        )
        return new

    def __repr__(self) -> str:
        # Reuse the parent's "Name(...)" string, inserting the bounds shape
        # before the closing parenthesis.
        return super().__repr__()[:-1] + f", bounds: {self.bounds.shape})"