Source code for botorch.utils.probability.mvnxpb

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Bivariate conditioning algorithm for approximating Gaussian probabilities,
see [Genz2016numerical]_ and [Trinh2015bivariate]_.

.. [Trinh2015bivariate]
    G. Trinh and A. Genz. Bivariate conditioning approximations for
    multivariate normal probabilities. Statistics and Computing, 2015.

.. [Genz2016numerical]
    A. Genz and G. Tring. Numerical Computation of Multivariate Normal Probabilities
    using Bivariate Conditioning. Monte Carlo and Quasi-Monte Carlo Methods, 2016.

.. [Gibson1994monte]
    GJ. Gibson, CA Galsbey, and DA Elston. Monte Carlo evaluation of multivariate normal
    integrals and sensitivity to variate ordering. Advances in Numerical Methods and
    Applications. 1994.
"""

from __future__ import annotations

from typing import Any, TypedDict
from warnings import warn

import torch
from botorch.utils.probability.bvn import bvn, bvnmom
from botorch.utils.probability.linalg import (
    augment_cholesky,
    block_matrix_concat,
    PivotedCholesky,
)
from botorch.utils.probability.utils import (
    case_dispatcher,
    get_constants_like,
    ndtr as Phi,
    phi,
    STANDARDIZED_RANGE,
    swap_along_dim_,
)
from botorch.utils.safe_math import log as safe_log, mul as safe_mul
from linear_operator.utils.cholesky import psd_safe_cholesky
from linear_operator.utils.errors import NotPSDError
from torch import LongTensor, Tensor
from torch.nn.functional import pad


[docs] class mvnxpbState(TypedDict): step: int perm: LongTensor bounds: Tensor piv_chol: PivotedCholesky plug_ins: Tensor log_prob: Tensor log_prob_extra: Tensor | None
[docs] class MVNXPB: r"""An algorithm for approximating Gaussian probabilities `P(X \in bounds)`, where `X ~ N(0, covariance_matrix)`. """ def __init__(self, covariance_matrix: Tensor, bounds: Tensor) -> None: r"""Initializes an MVNXPB instance. Args: covariance_matrix: Covariance matrices of shape `batch_shape x [n, n]`. bounds: Tensor of lower and upper bounds, `batch_shape x [n, 2]`. These bounds are standardized internally and clipped to STANDARDIZED_RANGE. """ *batch_shape, _, n = covariance_matrix.shape device = covariance_matrix.device dtype = covariance_matrix.dtype perm = torch.arange(0, n, device=device).expand(*batch_shape, n).contiguous() # Standardize covariance matrices and bounds var = covariance_matrix.diagonal(dim1=-2, dim2=-1).unsqueeze(-1) std = var.sqrt() istd = var.rsqrt() matrix = istd * covariance_matrix * istd.transpose(-1, -2) # Clip first to avoid differentiating through `istd * inf` bounds = istd * bounds.clip(*(std * lim for lim in STANDARDIZED_RANGE)) # Initialize partial pivoted Cholesky piv_chol = PivotedCholesky( step=0, perm=perm.clone(), diag=std.squeeze(-1).clone(), tril=matrix.tril(), ) self.step = 0 self.perm = perm self.bounds = bounds self.piv_chol = piv_chol self.plug_ins = torch.full( batch_shape + [n], float("nan"), device=device, dtype=dtype ) self.log_prob = torch.zeros(batch_shape, device=device, dtype=dtype) self.log_prob_extra: Tensor | None = None
[docs] @classmethod def build( cls, step: int, perm: Tensor, bounds: Tensor, piv_chol: PivotedCholesky, plug_ins: Tensor, log_prob: Tensor, log_prob_extra: Tensor | None = None, ) -> MVNXPB: r"""Creates an MVNXPB instance from raw arguments. Unlike MVNXPB.__init__, this methods does not preprocess or copy terms. Args: step: Integer used to track the solver's progress. bounds: Tensor of lower and upper bounds, `batch_shape x [n, 2]`. piv_chol: A PivotedCholesky instance for the system. plug_ins: Tensor of plug-in estimators used to update lower and upper bounds on random variables that have yet to be integrated out. log_prob: Tensor of log probabilities. log_prob_extra: Tensor of conditional log probabilities for the next random variable. Used when integrating over an odd number of random variables. """ new = cls.__new__(cls) new.step = step new.perm = perm new.bounds = bounds new.piv_chol = piv_chol new.plug_ins = plug_ins new.log_prob = log_prob new.log_prob_extra = log_prob_extra return new
[docs] def solve(self, num_steps: int | None = None, eps: float = 1e-10) -> Tensor: r"""Runs the MVNXPB solver instance for a fixed number of steps. Calculates a bivariate conditional approximation to P(X \in bounds), where X ~ N(0, Σ). For details, see [Genz2016numerical] or [Trinh2015bivariate]_. """ if self.step > self.piv_chol.step: raise ValueError("Invalid state: solver ran ahead of matrix decomposition.") # Unpack some terms start = self.step bounds = self.bounds piv_chol = self.piv_chol L = piv_chol.tril y = self.plug_ins # Subtract marginal log probability of final term from previous result if # it did not fit in a block. ndim = y.shape[-1] if ndim > start and start % 2: self.log_prob = self.log_prob - self.log_prob_extra self.log_prob_extra = None # Iteratively compute bivariate conditional approximation zero = get_constants_like(0, L) # needed when calling `torch.where` below num_steps = num_steps or ndim - start for i in range(start, start + num_steps): should_update_chol = self.step == piv_chol.step # Determine next pivot element if should_update_chol: pivot = self.select_pivot() else: # pivot using order specified by precomputed pivoted Cholesky step mask = self.perm[..., i:] == piv_chol.perm[..., i : i + 1] pivot = i + torch.nonzero(mask, as_tuple=True)[-1] if pivot is not None and torch.any(pivot > i): self.pivot_(pivot=pivot) # Compute whitened bounds conditional on preceding plug-ins Lii = L[..., i, i].clone() if should_update_chol: Lii = Lii.clip(min=0).sqrt() # conditional stddev inv_Lii = Lii.reciprocal() bounds_i = bounds[..., i, :].clone() if i != 0: bounds_i = bounds_i - torch.sum( L[..., i, :i].clone() * y[..., :i].clone(), dim=-1, keepdim=True ) lb, ub = (inv_Lii.unsqueeze(-1) * bounds_i).unbind(dim=-1) # Initialize `i`-th plug-in value as univariate conditional expectation Phi_i = Phi(ub) - Phi(lb) small = Phi_i <= i * eps y[..., i] = case_dispatcher( # used to select next pivot out=(phi(lb) - phi(ub)) / Phi_i, cases=( # fallback cases for enhanced numerical stability (lambda: small & (lb < -9), lambda m: ub[m]), (lambda: small & (lb > 9), lambda m: lb[m]), (lambda: small, lambda m: 0.5 * (lb[m] + ub[m])), ), ) # Maybe finalize the current block if i and i % 2: h = i - 1 blk = slice(h, i + 1) Lhh = L[..., h, h].clone() Lih = L[..., i, h].clone() std_i = (Lii.square() + Lih.square()).sqrt() istds = 1 / torch.stack([Lhh, std_i], -1) blk_bounds = bounds[..., blk, :].clone() if i > 1: blk_bounds = blk_bounds - ( L[..., blk, : i - 1].clone() @ y[..., : i - 1, None].clone() ) blk_lower, blk_upper = ( pair.unbind(-1) # pair of bounds for `yh` and `yi` for pair in safe_mul(istds.unsqueeze(-1), blk_bounds).unbind(-1) ) blk_corr = Lhh * Lih * istds.prod(-1) blk_prob = bvn(blk_corr, *blk_lower, *blk_upper) zh, zi = bvnmom(blk_corr, *blk_lower, *blk_upper, p=blk_prob) # Replace 1D expectations with 2D ones `L[blk, blk]^{-1} y[..., blk]` mask = blk_prob > zero y[..., h] = torch.where(mask, zh, zero) y[..., i] = torch.where(mask, inv_Lii * (std_i * zi - Lih * zh), zero) # Update running approximation to log probability self.log_prob = self.log_prob + safe_log(blk_prob) self.step += 1 if should_update_chol: piv_chol.update_(eps=eps) # Factor in univariate probability if final term fell outside of a block. if self.step % 2: self.log_prob_extra = safe_log(Phi_i) self.log_prob = self.log_prob + self.log_prob_extra return self.log_prob
[docs] def select_pivot(self) -> LongTensor | None: r"""GGE variable prioritization strategy from [Gibson1994monte]_. Returns the index of the random variable least likely to satisfy its bounds when conditioning on the previously integrated random variables `X[:t - 1]` attaining the values of plug-in estimators `y[:t - 1]`. Equivalently, ``` argmin_{i = t, ..., n} P(X[i] \in bounds[i] | X[:t-1] = y[:t -1]), ``` where `t` denotes the current step.""" i = self.piv_chol.step L = self.piv_chol.tril bounds = self.bounds if i: bounds = bounds[..., i:, :] - L[..., i:, :i] @ self.plug_ins[..., :i, None] inv_stddev = torch.diagonal(L, dim1=-2, dim2=-1)[..., i:].clip(min=0).rsqrt() probs_1d = Phi(inv_stddev.unsqueeze(-1) * bounds).diff(dim=-1).squeeze(-1) return i + torch.argmin(probs_1d, dim=-1)
[docs] def pivot_(self, pivot: LongTensor) -> None: r"""Swap random variables at `pivot` and `step` positions.""" step = self.step if self.piv_chol.step == step: self.piv_chol.pivot_(pivot) elif self.step > self.piv_chol.step: raise ValueError for tnsr in (self.perm, self.bounds): swap_along_dim_(tnsr, i=self.step, j=pivot, dim=pivot.ndim)
def __getitem__(self, key: Any) -> MVNXPB: return self.build( step=self.step, perm=self.perm[key], bounds=self.bounds[key], piv_chol=self.piv_chol[key], plug_ins=self.plug_ins[key], log_prob=self.log_prob[key], log_prob_extra=( None if self.log_prob_extra is None else self.log_prob_extra[key] ), )
[docs] def concat(self, other: MVNXPB, dim: int) -> MVNXPB: if not isinstance(other, MVNXPB): raise TypeError( f"Expected `other` to be {type(self)} typed but was {type(other)}." ) batch_ndim = self.log_prob.ndim if dim > batch_ndim or dim < -batch_ndim: raise ValueError(f"`dim={dim}` is not a valid batch dimension.") state_dict = self.asdict() for key, _other in other.asdict().items(): _self = state_dict.get(key) if _self is None and _other is None: continue if type(_self) is not type(_other): raise TypeError( f"Concatenation failed: `self.{key}` has type {type(_self)}, " f"but `other.{key}` is of type {type(_self)}." ) if isinstance(_self, PivotedCholesky): state_dict[key] = _self.concat(_other, dim=dim) elif isinstance(_self, Tensor): state_dict[key] = torch.concat((_self, _other), dim=dim) elif _self != _other: raise ValueError( f"Concatenation failed: `self.{key}` does not equal `other.{key}`." ) return self.build(**state_dict)
[docs] def expand(self, *sizes: int) -> MVNXPB: state_dict = self.asdict() state_dict["piv_chol"] = state_dict["piv_chol"].expand(*sizes) for name, ndim in { "bounds": 2, "perm": 1, "plug_ins": 1, "log_prob": 0, "log_prob_extra": 0, }.items(): src = state_dict[name] if isinstance(src, Tensor): state_dict[name] = src.expand( sizes + src.shape[-ndim:] if ndim else sizes ) return self.build(**state_dict)
[docs] def augment( self, covariance_matrix: Tensor, bounds: Tensor, cross_covariance_matrix: Tensor, disable_pivoting: bool = False, jitter: float | None = None, max_tries: int | None = None, ) -> MVNXPB: r"""Augment an `n`-dimensional MVNXPB instance to include `m` additional random variables. """ n = self.perm.shape[-1] m = covariance_matrix.shape[-1] if n != self.piv_chol.step: raise NotImplementedError( "Augmentation of incomplete solutions not implemented yet." ) var = covariance_matrix.diagonal(dim1=-2, dim2=-1).unsqueeze(-1) std = var.sqrt() istd = var.rsqrt() Kmn = istd * cross_covariance_matrix if self.piv_chol.diag is None: diag = pad(std.squeeze(-1), (cross_covariance_matrix.shape[-1], 0), value=1) else: Kmn = Kmn * (1 / self.piv_chol.diag).unsqueeze(-2) diag = torch.concat([self.piv_chol.diag, std.squeeze(-1)], -1) # Augment partial pivoted Cholesky factor Kmm = istd * covariance_matrix * istd.transpose(-1, -2) Lnn = self.piv_chol.tril try: L = augment_cholesky(Laa=Lnn, Kba=Kmn, Kbb=Kmm, jitter=jitter) except NotPSDError: warn("Joint covariance matrix not positive definite, attempting recovery.") Knn = Lnn @ Lnn.transpose(-1, -2) Knm = Kmn.transpose(-1, -2) K = block_matrix_concat(blocks=((Knn, Knm), (Kmn, Kmm))) L = psd_safe_cholesky(K, jitter=jitter, max_tries=max_tries) if not disable_pivoting: Lmm = L[..., n:, n:].clone() L[..., n:, n:] = (Lmm @ Lmm.transpose(-2, -1)).tril() _bounds = istd * bounds.clip(*(std * lim for lim in STANDARDIZED_RANGE)) _perm = torch.arange(n, n + m, dtype=self.perm.dtype, device=self.perm.device) _perm = _perm.expand(covariance_matrix.shape[:-2] + (m,)) piv_chol = PivotedCholesky( step=n + m if disable_pivoting else n, tril=L.contiguous(), perm=torch.cat([self.piv_chol.perm, _perm], dim=-1).contiguous(), diag=diag, ) return self.build( step=self.step, perm=torch.cat([self.perm, _perm], dim=-1), bounds=torch.cat([self.bounds, _bounds], dim=-2), piv_chol=piv_chol, plug_ins=pad(self.plug_ins, (0, m), value=float("nan")), log_prob=self.log_prob, log_prob_extra=self.log_prob_extra, )
[docs] def detach(self) -> MVNXPB: state_dict = self.asdict() for key, obj in state_dict.items(): if isinstance(obj, (PivotedCholesky, Tensor)): state_dict[key] = obj.detach() return self.build(**state_dict)
[docs] def clone(self) -> MVNXPB: state_dict = self.asdict() for key, obj in state_dict.items(): if isinstance(obj, (PivotedCholesky, Tensor)): state_dict[key] = obj.clone() return self.build(**state_dict)
[docs] def asdict(self) -> mvnxpbState: return mvnxpbState( step=self.step, perm=self.perm, bounds=self.bounds, piv_chol=self.piv_chol, plug_ins=self.plug_ins, log_prob=self.log_prob, log_prob_extra=self.log_prob_extra, )