#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Linear Elliptical Slice Sampler.
References
.. [Gessner2020]
    A. Gessner, O. Kanjilal, and P. Hennig. Integrals over Gaussians under
    linear domain constraints. AISTATS 2020.
.. [Wu2024]
    K. Wu and J. Gardner. A Fast, Robust Elliptical Slice Sampling Implementation
    for Linearly Truncated Multivariate Normal Distributions. arXiv:2407.10449. 2024.
This implementation is based (with multiple changes / optimizations) on
the following implementations of the algorithm in [Gessner2020]_:
- https://github.com/alpiges/LinConGauss
- https://github.com/wjmaddox/pytorch_ess
In addition, the active intervals (from which the angle is sampled) are computed using
the improved algorithm described in [Wu2024]_:
https://github.com/kayween/linear-ess
This implementation differs from the original implementations in the following ways:
1) Support for fixed feature equality constraints.
2) Support for non-standard Normal distributions.
3) Numerical stability improvements, especially relevant for high-dimensional cases.
4) Support for running multiple Markov chains in parallel.
"""
from __future__ import annotations
import math
from typing import List, Optional, Tuple, Union
import torch
from botorch.utils.sampling import PolytopeSampler
from linear_operator.operators import DiagLinearOperator, LinearOperator
from torch import Tensor
_twopi = 2.0 * math.pi
class LinearEllipticalSliceSampler(PolytopeSampler):
r"""Linear Elliptical Slice Sampler.
Ideas:
- Optimize computations if possible, potentially with torch.compile.
- Extend fixed features constraint to general linear equality constraints.
"""
def __init__(
self,
inequality_constraints: Optional[Tuple[Tensor, Tensor]] = None,
bounds: Optional[Tensor] = None,
interior_point: Optional[Tensor] = None,
fixed_indices: Optional[Union[List[int], Tensor]] = None,
mean: Optional[Tensor] = None,
covariance_matrix: Optional[Union[Tensor, LinearOperator]] = None,
covariance_root: Optional[Union[Tensor, LinearOperator]] = None,
check_feasibility: bool = False,
burnin: int = 0,
thinning: int = 0,
num_chains: int = 1,
) -> None:
r"""Initialize LinearEllipticalSliceSampler.
Args:
inequality_constraints: Tensors `(A, b)` describing inequality constraints
`A @ x <= b`, where `A` is an `n_ineq_con x d`-dim Tensor and `b` is
an `n_ineq_con x 1`-dim Tensor, with `n_ineq_con` the number of
inequalities and `d` the dimension of the sample space. If omitted,
must provide `bounds` instead.
bounds: A `2 x d`-dim tensor of box bounds. If omitted, must provide
`inequality_constraints` instead.
            interior_point: A `d x 1`-dim Tensor representing a point in the (relative)
interior of the polytope. If omitted, an interior point is determined
automatically by solving a Linear Program. Note: It is crucial that
the point lie in the interior of the feasible set (rather than on the
boundary), otherwise the sampler will produce invalid samples.
fixed_indices: Integer list or `d`-dim Tensor representing the indices of
dimensions that are constrained to be fixed to the values specified in
the `interior_point`, which is required to be passed in conjunction with
`fixed_indices`.
mean: The `d x 1`-dim mean of the MVN distribution (if omitted, use zero).
covariance_matrix: The `d x d`-dim covariance matrix of the MVN
distribution (if omitted, use the identity).
covariance_root: A `d x d`-dim root of the covariance matrix such that
covariance_root @ covariance_root.T = covariance_matrix. NOTE: This
matrix is assumed to be lower triangular. covariance_root can only be
passed in conjunction with fixed_indices if covariance_root is a
                DiagLinearOperator. Otherwise, the factorization would need to be
                re-computed, as we need to perform triangular solves in `_standardize`.
check_feasibility: If True, raise an error if the sampling results in an
infeasible sample. This creates some overhead and so is switched off
by default.
burnin: Number of samples to generate upon initialization to warm up the
sampler.
thinning: Number of samples to skip before returning a sample in `draw`.
num_chains: Number of Markov chains to run in parallel.
        This sampler samples from a multivariate Normal `N(mean, covariance_matrix)`
subject to linear domain constraints `A x <= b` (intersected with box bounds,
if provided).
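
        Example:
            A minimal usage sketch (assumes a 2-dim standard Normal truncated
            to the unit box; the values here are purely illustrative):

            >>> bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
            >>> sampler = LinearEllipticalSliceSampler(bounds=bounds, thinning=4)
            >>> samples = sampler.draw(n=100)  # `100 x 2`-dim tensor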
"""
if interior_point is not None and interior_point.ndim == 1:
interior_point = interior_point.unsqueeze(-1)
if mean is not None and mean.ndim == 1:
mean = mean.unsqueeze(-1)
super().__init__(
inequality_constraints=inequality_constraints,
# TODO: Support equality constraints?
interior_point=interior_point,
bounds=bounds,
)
tkwargs = {"device": self.x0.device, "dtype": self.x0.dtype}
if covariance_matrix is not None and covariance_root is not None:
raise ValueError(
"Provide either covariance_matrix or covariance_root, not both."
)
# can't unpack inequality constraints directly if bounds are passed
A, b = self.A, self.b
self._Az, self._bz = A, b
self._is_fixed, self._not_fixed = None, None
if fixed_indices is not None:
mean, covariance_matrix, covariance_root = (
self._fixed_features_initialization(
A=A,
b=b,
interior_point=interior_point,
fixed_indices=fixed_indices,
mean=mean,
covariance_matrix=covariance_matrix,
covariance_root=covariance_root,
)
)
self._mean = mean
# Have to delay factorization until after fixed features initialization.
if covariance_matrix is not None: # implies root is None
covariance_root, info = torch.linalg.cholesky_ex(covariance_matrix)
not_psd = torch.any(info)
if not_psd:
raise ValueError(
"Covariance matrix is not positive definite. "
"Currently only non-degenerate distributions are supported."
)
self._covariance_root = covariance_root
# Rewrite the constraints as a system that constrains a standard Normal.
self._standardization_initialization()
# state of the sampler ("current point")
self._x = self.x0.clone()
self._z = self._transform(self._x)
# Expand the shape to (d, num_chains) for running parallel Markov chains.
if num_chains > 1:
self._z = self._z.expand(-1, num_chains).clone()
        # We will need the following repeatedly; allocate them once.
self.zeros = torch.zeros((num_chains, 1), **tkwargs)
self.ones = torch.ones((num_chains, 1), **tkwargs)
self.indices_batch = torch.arange(
num_chains, dtype=torch.int64, device=tkwargs["device"]
)
self.check_feasibility = check_feasibility
self._lifetime_samples = 0
if burnin > 0:
self.thinning = 0
self.draw(burnin)
self.thinning = thinning
def _fixed_features_initialization(
self,
A: Tensor,
b: Tensor,
interior_point: Optional[Tensor],
fixed_indices: Union[List[int], Tensor],
mean: Optional[Tensor],
covariance_matrix: Optional[Tensor],
covariance_root: Optional[Tensor],
    ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
"""Modifies the constraint system (A, b) due to fixed indices and assigns
the modified constraints system to `self._Az`, `self._bz`. NOTE: Needs to be
called prior to `self._standardization_initialization` in the constructor.
covariance_root and fixed_indices can both not be None only if covariance_root
is a DiagLinearOperator. Otherwise, the covariance matrix would need to be
refactorized.
Returns:
Tuple of `mean` and `covariance_matrix` tensors of the non-fixed dimensions.
"""
if interior_point is None:
raise ValueError(
"If `fixed_indices` are provided, an interior point must also be "
"provided in order to infer feasible values of the fixed features."
)
root_is_diag = isinstance(covariance_root, DiagLinearOperator)
if covariance_root is not None and not root_is_diag:
root_is_diag = (covariance_root.diag().diag() == covariance_root).all()
if root_is_diag: # convert the diagonal root to a DiagLinearOperator
covariance_root = DiagLinearOperator(covariance_root.diagonal())
else: # otherwise, fail
raise ValueError(
"Provide either covariance_root or fixed_indices, not both."
)
d = interior_point.shape[0]
is_fixed, not_fixed = get_index_tensors(fixed_indices=fixed_indices, d=d)
self._is_fixed = is_fixed
self._not_fixed = not_fixed
# Transforming constraint system to incorporate fixed features:
# A @ x - b = (A[:, fixed] @ x[fixed] + A[:, not fixed] @ x[not fixed]) - b
# = A[:, not fixed] @ x[not fixed] - (b - A[:, fixed] @ x[fixed])
# = Az @ z - bz
self._Az = A[:, not_fixed]
self._bz = b - A[:, is_fixed] @ interior_point[is_fixed]
if mean is not None:
mean = mean[not_fixed]
if covariance_matrix is not None: # subselect active dimensions
covariance_matrix = covariance_matrix[
not_fixed.unsqueeze(-1), not_fixed.unsqueeze(0)
]
if root_is_diag: # in the special case of diagonal root, can subselect
covariance_root = DiagLinearOperator(covariance_root.diagonal()[not_fixed])
return mean, covariance_matrix, covariance_root
def _standardization_initialization(self) -> None:
"""For non-standard mean and covariance, we're going to rewrite the problem as
sampling from a standard normal distribution subject to modified constraints.
A @ x - b = A @ (covar_root @ z + mean) - b
= (A @ covar_root) @ z - (b - A @ mean)
= _Az @ z - _bz
NOTE: We need to standardize bz before Az in the following, because it relies
on the untransformed Az. We can't simply use A instead because Az might have
been subject to the fixed features transformation.
"""
if self._mean is not None:
self._bz = self._bz - self._Az @ self._mean
if self._covariance_root is not None:
self._Az = self._Az @ self._covariance_root
@property
def lifetime_samples(self) -> int:
"""The total number of samples generated by the sampler during its lifetime."""
return self._lifetime_samples
def draw(self, n: int = 1) -> Tensor:
r"""Draw samples.
Args:
n: The number of samples.
Returns:
A `(n * num_chains) x d`-dim tensor of `n * num_chains` samples.
"""
samples = []
for _ in range(n):
for _ in range(self.thinning):
self.step()
samples.append(self.step())
return torch.cat(samples, dim=-1).transpose(-1, -2)
def step(self) -> Tensor:
r"""Take a step, return the new sample, update the internal state.
Returns:
A `d x num_chains`-dim tensor, where each column is a sample from a Markov
chain.
"""
nu = torch.randn_like(self._z)
theta = self._draw_angle(nu=nu)
self._z = z = self._get_cart_coords(nu=nu, theta=theta)
self._x = x = self._untransform(z)
self._lifetime_samples += 1
if self.check_feasibility and (not self._is_feasible(self._x).all()):
Axmb = self.A @ self._x - self.b
violated_indices = Axmb > 0
raise RuntimeError(
"Sampling resulted in infeasible point. \n\t- Number "
f"of violated constraints: {violated_indices.sum()}."
f"\n\t- Magnitude of violations: {Axmb[violated_indices]}"
"\n\t- If the error persists, please report this bug on GitHub."
)
return x
def _draw_angle(self, nu: Tensor) -> Tensor:
r"""Draw the rotation angle.
Args:
nu: A `d x num_chains`-dim tensor (the "new" direction, drawn from N(0, I)).
Returns:
A `num_chains`-dim Tensor containing the rotation angle (radians).
"""
left, right = self._find_active_intersection_angles(nu)
left, right = self._trim_intervals(left, right)
# If left[i, j] <= right[i, j], then [left[i, j], right[i, j]] is an active
# interval. On the other hand, if left[i, j] > right[i, j], then they are both
# dummy variables and should be discarded. Thus, we clamp their difference so
# that they do not contribute to the cumulative length.
csum = right.sub(left).clamp(min=0.0).cumsum(dim=-1)
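        # Draw a uniform variable over the total active arc length of each chain,
        # then invert the cumulative interval lengths to recover the angle.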
u = csum[:, -1] * torch.rand(
right.size(-2), dtype=right.dtype, device=right.device
)
# The returned index i satisfies csum[i - 1] < u <= csum[i]
idx = torch.searchsorted(csum, u.unsqueeze(-1)).squeeze(-1)
# Do a zero padding so that padded_csum[i] = csum[i - 1]
padded_csum = torch.cat([self.zeros, csum], dim=-1)
return u - padded_csum[self.indices_batch, idx] + left[self.indices_batch, idx]
def _get_cart_coords(self, nu: Tensor, theta: Tensor) -> Tensor:
r"""Determine location on the ellipse in Cartesian coordinates.
Args:
nu: A `d x num_chains`-dim tensor (the "new" direction, drawn from N(0, I)).
theta: A `num_chains`-dim tensor of angles.
Returns:
A `d x num_chains`-dim tensor of samples from the domain in Cartesian
coordinates.
"""
return self._z * torch.cos(theta) + nu * torch.sin(theta)
def _trim_intervals(self, left: Tensor, right: Tensor) -> Tuple[Tensor, Tensor]:
"""Trim the intervals by a small positive constant. This encourages the Markov
chain to stay in the interior of the domain.
"""
gap = torch.clamp(right - left, min=0.0)
eps = gap.mul(0.25).clamp(max=1e-6 if gap.dtype == torch.float32 else 1e-12)
return left + eps, right - eps
def _find_active_intersection_angles(self, nu: Tensor) -> Tuple[Tensor, Tensor]:
"""Construct the active intersection angles.
Args:
nu: A `d x num_chains`-dim tensor (the "new" direction, drawn from N(0, I)).
Returns:
A tuple (left, right) of two tensors of size `num_chains x m` representing
the active intersection angles. For the i-th Markov chain and the j-th
constraint, a pair of angles left[i, j] and right[i, j] is active if and
only if left[i, j] <= right[i, j]. If left[i, j] > right[i, j], they are
inactive and should be ignored.
"""
alpha, beta = self._find_intersection_angles(nu)
# It's easier to put `num_chains` as the first dimension,
# because `torch.searchsorted` only supports searching in the last dimension
alpha, beta = alpha.T, beta.T
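        # Each [alpha[j], beta[j]] is the arc on which the j-th constraint is
        # violated (dummy arcs have alpha[j] = beta[j]). After sorting the left end
        # points, the running maximum of the matching right end points tracks how
        # far the union of violated arcs extends; the gaps that remain between
        # `cummax` and the next sorted left end point are the feasible (active) arcs.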
srted, indices = torch.sort(alpha, descending=False)
cummax = beta[self.indices_batch.unsqueeze(-1), indices].cummax(dim=-1).values
        srted = torch.cat([srted, self.ones * _twopi], dim=-1)
cummax = torch.cat([self.zeros, cummax], dim=-1)
return cummax, srted
def _find_intersection_angles(self, nu: Tensor) -> Tuple[Tensor, Tensor]:
"""Compute all 2 * m intersections of the ellipse and the domain, where
`m = n_ineq_con` is the number of inequality constraints defining the domain.
If the i-th linear inequality constraint has no intersection with the ellipse,
we will create two dummy intersection angles alpha_i = beta_i = 0.
Args:
nu: A `d x num_chains`-dim tensor (the "new" direction, drawn from N(0, I)).
Returns:
A tuple of two tensors with the same size `m x num_chains`. The first tensor
represents the smaller intersection angles. The second tensor represents the
larger intersection angles.
"""
p = self._Az @ self._z
q = self._Az @ nu
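        # A point on the ellipse is z(theta) = z * cos(theta) + nu * sin(theta), so
        # the i-th constraint reads p_i * cos(theta) + q_i * sin(theta) <= b_i, or
        # equivalently radius_i * cos(theta - arctan2(q_i, p_i)) <= b_i. The ellipse
        # thus crosses the i-th constraint boundary at the angles
        # arctan2(q_i, p_i) +/- arccos(b_i / radius_i), whenever b_i / radius_i < 1.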
radius = torch.sqrt(p**2 + q**2)
ratio = self._bz / radius
has_solution = ratio < 1.0
arccos = torch.arccos(ratio)
arccos[~has_solution] = 0.0
arctan = torch.arctan2(q, p)
theta1 = arctan + arccos
theta2 = arctan - arccos
# translate every angle to [0, 2 * pi]
theta1 = theta1 + theta1.lt(0.0) * _twopi
theta2 = theta2 + theta2.lt(0.0) * _twopi
alpha = torch.minimum(theta1, theta2)
beta = torch.maximum(theta1, theta2)
return alpha, beta
def _is_feasible(self, points: Tensor, transformed: bool = False) -> Tensor:
r"""Returns a Boolean tensor indicating whether the `points` are feasible,
i.e. they satisfy `A @ points <= b`, where `(A, b)` are the tensors passed
as the `inequality_constraints` to the constructor of the sampler.
Args:
points: A `d x M`-dim tensor of points.
            transformed: Whether points are assumed to be transformed by a change of
basis, which means feasibility should be computed based on the
transformed constraint system (_Az, _bz), instead of (A, b).
Returns:
An `M`-dim binary tensor where `True` indicates that the associated
point is feasible.
"""
A, b = (self._Az, self._bz) if transformed else (self.A, self.b)
return (A @ points <= b).all(dim=0)
def _transform(self, x: Tensor) -> Tensor:
"""Transforms the input so that it is equivalent to a standard Normal variable
constrained with the modified system constraints (self._Az, self._bz).
Args:
x: The input tensor to be transformed, usually `d x 1`-dimensional.
Returns:
A `d x 1`-dimensional tensor of transformed values subject to the modified
system of constraints.
"""
if self._not_fixed is not None:
x = x[self._not_fixed]
return self._standardize(x)
def _untransform(self, z: Tensor) -> Tensor:
"""The inverse transform of the `_transform`, i.e. maps `z` back to the original
space where it is subject to the original constraint system (self.A, self.b).
Args:
z: The transformed tensor to be un-transformed, usually `d x 1`-dimensional.
Returns:
A `d x 1`-dimensional tensor of un-transformed values subject to the
original system of constraints.
"""
if self._is_fixed is None:
return self._unstandardize(z)
else:
x = self._x.clone() # _x already contains the fixed values
x[self._not_fixed] = self._unstandardize(z)
return x
def _standardize(self, x: Tensor) -> Tensor:
"""_transform helper standardizing the input `x`, which is assumed to be a
`d x 1`-dim Tensor, or a `len(self._not_fixed) x 1`-dim if there are fixed
indices.
"""
z = x
if self._mean is not None:
z = z - self._mean
root = self._covariance_root
if root is not None:
z = torch.linalg.solve_triangular(root, z, upper=False)
return z
def _unstandardize(self, z: Tensor) -> Tensor:
"""_untransform helper un-standardizing the input `z`, which is assumed to be a
`d x 1`-dim Tensor, or a `len(self._not_fixed) x 1`-dim if there are fixed
indices.
"""
x = z
if self._covariance_root is not None:
x = self._covariance_root @ x
if self._mean is not None:
x = x + self._mean
return x
def get_index_tensors(
fixed_indices: Union[List[int], Tensor], d: int
) -> Tuple[Tensor, Tensor]:
"""Converts `fixed_indices` to a `d`-dim integral Tensor that is True at indices
that are contained in `fixed_indices` and False otherwise.
Args:
fixed_indices: A list or Tensoro of integer indices to fix.
d: The dimensionality of the Tensors to be indexed.
Returns:
A Tuple of integral Tensors partitioning [1, d] into indices that are fixed
(first tensor) and non-fixed (second tensor).
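    Example:
        A small illustration (hypothetical inputs):

        >>> get_index_tensors(fixed_indices=[0, 2], d=4)
        (tensor([0, 2]), tensor([1, 3]))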
"""
is_fixed = torch.as_tensor(fixed_indices)
dtype, device = is_fixed.dtype, is_fixed.device
dims = torch.arange(d, dtype=dtype, device=device)
    not_fixed = torch.tensor(
        [i for i in dims if i not in is_fixed], dtype=dtype, device=device
    )
return is_fixed, not_fixed