#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Abstract class for acquisition functions leveraging a cached Cholesky
decomposition of the posterior covariance over f(X_baseline).
"""
from __future__ import annotations
import warnings
import torch
from botorch.acquisition.acquisition import MCSamplerMixin
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.higher_order_gp import HigherOrderGP
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from botorch.sampling.base import MCSampler
from botorch.utils.low_rank import extract_batch_covar, sample_cached_cholesky
from gpytorch.distributions.multitask_multivariate_normal import (
MultitaskMultivariateNormal,
)
from linear_operator.utils.errors import NanError, NotPSDError
from torch import Tensor
[docs]
def supports_cache_root(model: Model) -> bool:
r"""Checks if a model supports the cache_root functionality.
The two criteria are that the model is not multi-task and the model
produces a GPyTorchPosterior.
"""
if isinstance(model, ModelListGP):
return all(supports_cache_root(m) for m in model.models)
# Multi task models and non-GPyTorch models are not supported.
if isinstance(
model, (MultiTaskGP, KroneckerMultiTaskGP, HigherOrderGP)
) or not isinstance(model, GPyTorchModel):
return False
# Models that return a TransformedPosterior are not supported.
if hasattr(model, "outcome_transform") and (not model.outcome_transform._is_linear):
return False
return True
def _get_cache_root_not_supported_message(model_cls: type) -> str:
msg = (
"`cache_root` is only supported for GPyTorchModels that "
"are not MultiTask models and don't produce a "
f"TransformedPosterior. Got a model of type {model_cls}. Setting "
"`cache_root = False`."
)
return msg
[docs]
class CachedCholeskyMCSamplerMixin(MCSamplerMixin):
r"""Abstract Mixin class for acquisition functions using a cached Cholesky.
Specifically, this is for acquisition functions that require sampling from
the posterior P(f(X_baseline, X) | D). The Cholesky of the posterior
covariance over f(X_baseline) is cached.
"""
def __init__(
self,
model: Model,
cache_root: bool = False,
sampler: MCSampler | None = None,
) -> None:
r"""Set class attributes and perform compatibility checks.
Args:
model: A model.
cache_root: A boolean indicating whether to cache the Cholesky.
This might be overridden in the model is not compatible.
sampler: An optional MCSampler object.
"""
MCSamplerMixin.__init__(self, sampler=sampler)
if cache_root and not supports_cache_root(model):
warnings.warn(
_get_cache_root_not_supported_message(type(model)),
RuntimeWarning,
stacklevel=3,
)
cache_root = False
self._cache_root = cache_root
def _compute_root_decomposition(
self,
posterior: Posterior,
) -> Tensor:
r"""Cache Cholesky of the posterior covariance over f(X_baseline).
Because `LinearOperator.root_decomposition` is decorated with LinearOperator's
@cached decorator, this function is doing a lot implicitly:
1) Check if a root decomposition has already been cached to `lazy_covar`.
Note that it will not have been if `posterior.mvn` is a
`MultitaskMultivariateNormal`, since we construct `lazy_covar` in that
case.
2) If the root decomposition has not been found in the cache, compute it.
3) Write it to the cache of `lazy_covar`. Note that this will become
inaccessible if `posterior.mvn` is a `MultitaskMultivariateNormal`,
since in that case `lazy_covar`'s scope is only this function.
Args:
posterior: The posterior over f(X_baseline).
"""
if isinstance(posterior.distribution, MultitaskMultivariateNormal):
lazy_covar = extract_batch_covar(posterior.distribution)
else:
lazy_covar = posterior.distribution.lazy_covariance_matrix
lazy_covar_root = lazy_covar.root_decomposition()
return lazy_covar_root.root.to_dense()
def _get_f_X_samples(self, posterior: GPyTorchPosterior, q_in: int) -> Tensor:
r"""Get posterior samples at the `q_in` new points from the joint posterior.
Args:
posterior: The joint posterior is over (X_baseline, X).
q_in: The number of new points in the posterior. See `_set_sampler` for
more information.
Returns:
A `sample_shape x batch_shape x q x m`-dim tensor of posterior
samples at the new points.
"""
# Technically we should make sure that we add a consistent nugget to the
# cached covariance (and box decompositions) and the new block.
# But recomputing box decompositions every time the jitter changes would
# be quite slow.
if self._cache_root and hasattr(self, "_baseline_L"):
try:
return sample_cached_cholesky(
posterior=posterior,
baseline_L=self._baseline_L,
q=q_in,
base_samples=self.sampler.base_samples,
sample_shape=self.sampler.sample_shape,
)
except (NanError, NotPSDError):
warnings.warn(
"Low-rank cholesky updates failed due NaNs or due to an "
"ill-conditioned covariance matrix. "
"Falling back to standard sampling.",
BotorchWarning,
stacklevel=3,
)
# TODO: improve efficiency for multi-task models
samples = self.get_posterior_samples(posterior)
if isinstance(self.model, HigherOrderGP):
# Select the correct q-batch dimension for HOGP.
q_dim = -self.model._num_dimensions
q_idcs = (
torch.arange(-q_in, 0, device=samples.device) + samples.shape[q_dim]
)
return samples.index_select(q_dim, q_idcs)
else:
return samples[..., -q_in:, :]
def _set_sampler(
self,
q_in: int,
posterior: Posterior,
) -> None:
r"""Update the sampler to use the original base samples for X_baseline.
Args:
q_in: The effective input batch size. This is typically equal to the
q-batch size of `X`. However, if using a one-to-many input transform,
e.g., `InputPerturbation` with `n_w` perturbations, the posterior will
have `n_w` points on the q-batch for each point on the q-batch of `X`.
In which case, `q_in = q * n_w` is used.
posterior: The posterior.
"""
if self.q_in != q_in and self.base_sampler is not None:
self.sampler._update_base_samples(
posterior=posterior, base_sampler=self.base_sampler
)
self.q_in = q_in