Source code for botorch.posteriors.fully_bayesian

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from functools import lru_cache
from typing import Callable, List, Optional

import torch
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import PosteriorList
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from torch import Tensor


MCMC_DIM = -3  # Location of the MCMC batch dimension
TOL = 1e-6  # Bisection tolerance


def batched_bisect(
    f: Callable, target: float, bounds: Tensor, tol: float = TOL, max_steps: int = 32
):
    r"""Batched bisection with a fixed number of steps.

    Args:
        f: Target function that takes a `(b1 x ... x bk)`-dim tensor and returns a
            `(b1 x ... x bk)`-dim tensor.
        target: Scalar target value of type float.
        bounds: Lower and upper bounds, of size `2 x b1 x ... x bk`.
        tol: We terminate if all elements are within `tol` of the `target`.
        max_steps: Maximum number of bisection steps.

    Returns:
        Tensor X of size `b1 x ... x bk` such that `f(X) = target`.
    """
    # Make sure the target is actually contained in the interval
    f1, f2 = f(bounds[0]), f(bounds[1])
    if not ((f1 <= target) & (target <= f2)).all():
        raise ValueError(
            "The target is not contained in the interval specified by the bounds"
        )
    bounds = bounds.clone()  # Will be modified in-place
    center = bounds.mean(dim=0)
    f_center = f(center)
    for _ in range(max_steps):
        go_left = f_center > target
        bounds[1, go_left] = center[go_left]
        bounds[0, ~go_left] = center[~go_left]
        center = bounds.mean(dim=0)
        f_center = f(center)
        # Check convergence
        if (f_center - target).abs().max() <= tol:
            return center
    return center
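

# Example (illustrative sketch, not part of the module): `batched_bisect` inverts a
# monotonically increasing function elementwise over a batch. The function, batch
# size, and bounds below are assumptions chosen for demonstration only.
#
# >>> f = lambda x: x.pow(2)  # elementwise, increasing on [0, 4]
# >>> bounds = torch.stack([torch.zeros(3), torch.full((3,), 4.0)])  # 2 x 3
# >>> batched_bisect(f=f, target=2.0, bounds=bounds)  # ~sqrt(2) for each element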


class FullyBayesianPosterior(GPyTorchPosterior):
    r"""A posterior for a fully Bayesian model.

    The MCMC batch dimension that corresponds to the models in the mixture is located
    at `MCMC_DIM` (defined at the top of this file). Note that while each MCMC sample
    corresponds to a Gaussian posterior, the fully Bayesian posterior is a mixture of
    Gaussian distributions. We provide convenience properties/methods for computing
    the mean, variance, median, and quantiles of this mixture.
    """

    def __init__(self, mvn: MultivariateNormal) -> None:
        r"""A posterior for a fully Bayesian model.

        Args:
            mvn: A GPyTorch MultivariateNormal (single-output case)
        """
        super().__init__(mvn=mvn)
        self._mean = mvn.mean if self._is_mt else mvn.mean.unsqueeze(-1)
        self._variance = mvn.variance if self._is_mt else mvn.variance.unsqueeze(-1)

    @property
    @lru_cache(maxsize=None)
    def mixture_mean(self) -> Tensor:
        r"""The posterior mean for the mixture of models."""
        return self._mean.mean(dim=MCMC_DIM)

    @property
    @lru_cache(maxsize=None)
    def mixture_variance(self) -> Tensor:
        r"""The posterior variance for the mixture of models."""
        num_mcmc_samples = self.mean.shape[MCMC_DIM]
        # Law of total variance: E[Var[f | model]] + Var[E[f | model]], where the
        # expectation is over the MCMC samples (t2 + t3 gives the second term).
        t1 = self._variance.sum(dim=MCMC_DIM) / num_mcmc_samples
        t2 = self._mean.pow(2).sum(dim=MCMC_DIM) / num_mcmc_samples
        t3 = -(self._mean.sum(dim=MCMC_DIM) / num_mcmc_samples).pow(2)
        return t1 + t2 + t3

    @property
    @lru_cache(maxsize=None)
    def mixture_median(self) -> Tensor:
        r"""The posterior median for the mixture of models."""
        return self.mixture_quantile(q=0.5)

    @lru_cache(maxsize=None)
    def mixture_quantile(self, q: float) -> Tensor:
        r"""The posterior quantiles for the mixture of models."""
        if not isinstance(q, float):
            raise ValueError("q is expected to be a float.")
        if q <= 0 or q >= 1:
            raise ValueError("q is expected to be in the range (0, 1).")
        q_tensor = torch.tensor(q).to(self.mean)
        dist = torch.distributions.Normal(loc=self.mean, scale=self.variance.sqrt())
        if self.mean.shape[MCMC_DIM] == 1:
            # Analytical solution
            return dist.icdf(q_tensor).squeeze(MCMC_DIM)
        low = dist.icdf(q_tensor).min(dim=MCMC_DIM).values - TOL
        high = dist.icdf(q_tensor).max(dim=MCMC_DIM).values + TOL
        bounds = torch.cat((low.unsqueeze(0), high.unsqueeze(0)), dim=0)
        return batched_bisect(
            f=lambda x: dist.cdf(x.unsqueeze(MCMC_DIM)).mean(dim=MCMC_DIM),
            target=q,
            bounds=bounds,
        )
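

# Example (illustrative sketch): querying mixture statistics on a
# `FullyBayesianPosterior`. The model and input names are hypothetical; any model
# whose `posterior` call returns this class behaves the same way.
#
# >>> posterior = fully_bayesian_model.posterior(test_X)
# >>> posterior.mean.shape               # per-sample means; MCMC dim at MCMC_DIM
# >>> posterior.mixture_mean             # average of the Gaussian means
# >>> posterior.mixture_variance         # law-of-total-variance combination
# >>> posterior.mixture_quantile(q=0.9)  # solved via batched bisection on the CDF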


class FullyBayesianPosteriorList(PosteriorList):
    r"""A Posterior represented by a list of independent Posteriors.

    This posterior should only be used when at least one posterior is a
    `FullyBayesianPosterior`. Posteriors that aren't of type `FullyBayesianPosterior`
    are automatically reshaped to match the size of the fully Bayesian posteriors
    to allow mixing, e.g., deterministic and fully Bayesian models.

    Args:
        *posteriors: A variable number of single-outcome posteriors.

    Example:
        >>> p_1 = model_1.posterior(test_X)
        >>> p_2 = model_2.posterior(test_X)
        >>> p_12 = FullyBayesianPosteriorList(p_1, p_2)
    """

    def _get_mcmc_batch_dimension(self) -> int:
        """Return the number of MCMC samples in the corresponding batch dimension."""
        mcmc_samples = [
            p.mean.shape[MCMC_DIM]
            for p in self.posteriors
            if isinstance(p, FullyBayesianPosterior)
        ]
        if len(set(mcmc_samples)) > 1:
            raise NotImplementedError(
                "All MCMC batch dimensions must have the same size, got shapes: "
                f"{mcmc_samples}."
            )
        return mcmc_samples[0]

    @staticmethod
    def _reshape_tensor(X: Tensor, mcmc_samples: int) -> Tensor:
        """Reshape a tensor without an MCMC batch dimension to match the shape."""
        X = X.unsqueeze(MCMC_DIM)
        return X.expand(*X.shape[:MCMC_DIM], mcmc_samples, *X.shape[MCMC_DIM + 1 :])

    def _reshape_and_cat(self, Xs: List[Tensor]):
        r"""Reshape and cat a list of tensors."""
        mcmc_samples = self._get_mcmc_batch_dimension()
        return torch.cat(
            [
                x
                if isinstance(p, FullyBayesianPosterior)
                else self._reshape_tensor(x, mcmc_samples=mcmc_samples)
                for x, p in zip(Xs, self.posteriors)
            ],
            dim=-1,
        )

    @property
    def event_shape(self) -> torch.Size:
        r"""The event shape (i.e. the shape of a single sample)."""
        fully_bayesian_posteriors = [
            p for p in self.posteriors if isinstance(p, FullyBayesianPosterior)
        ]
        event_shape = fully_bayesian_posteriors[0].event_shape
        if not all(event_shape == p.event_shape for p in fully_bayesian_posteriors):
            # Make sure all fully Bayesian posteriors have the same event shape
            raise NotImplementedError(
                f"`{self.__class__.__name__}.event_shape` is only supported if all "
                "constituent posteriors have the same `event_shape`."
            )
        event_shapes = [event_shape for _ in self.posteriors]
        batch_shapes = [es[:-1] for es in event_shapes]
        return batch_shapes[0] + torch.Size([es[-1] for es in event_shapes])

    @property
    def mean(self) -> Tensor:
        r"""The mean of the posterior as a `(b) x n x m`-dim Tensor."""
        return self._reshape_and_cat(Xs=[p.mean for p in self.posteriors])

    @property
    def variance(self) -> Tensor:
        r"""The variance of the posterior as a `(b) x n x m`-dim Tensor."""
        return self._reshape_and_cat(Xs=[p.variance for p in self.posteriors])

    def rsample(
        self,
        sample_shape: Optional[torch.Size] = None,
        base_samples: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.
            base_samples: An (optional) Tensor of `N(0, I)` base samples of
                appropriate dimension, typically obtained from a `Sampler`.
                This is used for deterministic optimization.

        Returns:
            A `sample_shape x event`-dim Tensor of samples from the posterior.
        """
        samples = super()._rsample(
            sample_shape=sample_shape, base_samples=base_samples
        )
        return self._reshape_and_cat(Xs=samples)
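

# Example (illustrative sketch): mixing a fully Bayesian posterior with one that
# lacks an MCMC batch dimension. The non-fully-Bayesian posterior is broadcast to
# the MCMC batch size (via `_reshape_tensor`) before concatenation along the output
# dimension. The model names and shapes below are hypothetical.
#
# >>> p_fb = saas_model.posterior(test_X)  # mean shape: num_mcmc x n x 1
# >>> p_det = det_model.posterior(test_X)  # mean shape: n x 1
# >>> p_list = FullyBayesianPosteriorList(p_fb, p_det)
# >>> p_list.mean.shape                    # num_mcmc x n x 2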