Source code for botorch.acquisition.cached_cholesky

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Abstract class for acquisition functions leveraging a cached Cholesky
decomposition of the posterior covariance over f(X_baseline).
"""

from __future__ import annotations

import warnings

import torch
from botorch.acquisition.acquisition import MCSamplerMixin
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.higher_order_gp import HigherOrderGP
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from botorch.sampling.base import MCSampler
from botorch.utils.low_rank import extract_batch_covar, sample_cached_cholesky
from gpytorch.distributions.multitask_multivariate_normal import (
    MultitaskMultivariateNormal,
)
from linear_operator.utils.errors import NanError, NotPSDError
from torch import Tensor


[docs] def supports_cache_root(model: Model) -> bool: r"""Checks if a model supports the cache_root functionality. The two criteria are that the model is not multi-task and the model produces a GPyTorchPosterior. """ if isinstance(model, ModelListGP): return all(supports_cache_root(m) for m in model.models) # Multi task models and non-GPyTorch models are not supported. if isinstance( model, (MultiTaskGP, KroneckerMultiTaskGP, HigherOrderGP) ) or not isinstance(model, GPyTorchModel): return False # Models that return a TransformedPosterior are not supported. if hasattr(model, "outcome_transform") and (not model.outcome_transform._is_linear): return False return True
def _get_cache_root_not_supported_message(model_cls: type) -> str: msg = ( "`cache_root` is only supported for GPyTorchModels that " "are not MultiTask models and don't produce a " f"TransformedPosterior. Got a model of type {model_cls}. Setting " "`cache_root = False`." ) return msg
[docs] class CachedCholeskyMCSamplerMixin(MCSamplerMixin): r"""Abstract Mixin class for acquisition functions using a cached Cholesky. Specifically, this is for acquisition functions that require sampling from the posterior P(f(X_baseline, X) | D). The Cholesky of the posterior covariance over f(X_baseline) is cached. """ def __init__( self, model: Model, cache_root: bool = False, sampler: MCSampler | None = None, ) -> None: r"""Set class attributes and perform compatibility checks. Args: model: A model. cache_root: A boolean indicating whether to cache the Cholesky. This might be overridden in the model is not compatible. sampler: An optional MCSampler object. """ MCSamplerMixin.__init__(self, sampler=sampler) if cache_root and not supports_cache_root(model): warnings.warn( _get_cache_root_not_supported_message(type(model)), RuntimeWarning, stacklevel=3, ) cache_root = False self._cache_root = cache_root def _compute_root_decomposition( self, posterior: Posterior, ) -> Tensor: r"""Cache Cholesky of the posterior covariance over f(X_baseline). Because `LinearOperator.root_decomposition` is decorated with LinearOperator's @cached decorator, this function is doing a lot implicitly: 1) Check if a root decomposition has already been cached to `lazy_covar`. Note that it will not have been if `posterior.mvn` is a `MultitaskMultivariateNormal`, since we construct `lazy_covar` in that case. 2) If the root decomposition has not been found in the cache, compute it. 3) Write it to the cache of `lazy_covar`. Note that this will become inaccessible if `posterior.mvn` is a `MultitaskMultivariateNormal`, since in that case `lazy_covar`'s scope is only this function. Args: posterior: The posterior over f(X_baseline). """ if isinstance(posterior.distribution, MultitaskMultivariateNormal): lazy_covar = extract_batch_covar(posterior.distribution) else: lazy_covar = posterior.distribution.lazy_covariance_matrix lazy_covar_root = lazy_covar.root_decomposition() return lazy_covar_root.root.to_dense() def _get_f_X_samples(self, posterior: GPyTorchPosterior, q_in: int) -> Tensor: r"""Get posterior samples at the `q_in` new points from the joint posterior. Args: posterior: The joint posterior is over (X_baseline, X). q_in: The number of new points in the posterior. See `_set_sampler` for more information. Returns: A `sample_shape x batch_shape x q x m`-dim tensor of posterior samples at the new points. """ # Technically we should make sure that we add a consistent nugget to the # cached covariance (and box decompositions) and the new block. # But recomputing box decompositions every time the jitter changes would # be quite slow. if self._cache_root and hasattr(self, "_baseline_L"): try: return sample_cached_cholesky( posterior=posterior, baseline_L=self._baseline_L, q=q_in, base_samples=self.sampler.base_samples, sample_shape=self.sampler.sample_shape, ) except (NanError, NotPSDError): warnings.warn( "Low-rank cholesky updates failed due NaNs or due to an " "ill-conditioned covariance matrix. " "Falling back to standard sampling.", BotorchWarning, stacklevel=3, ) # TODO: improve efficiency for multi-task models samples = self.get_posterior_samples(posterior) if isinstance(self.model, HigherOrderGP): # Select the correct q-batch dimension for HOGP. q_dim = -self.model._num_dimensions q_idcs = ( torch.arange(-q_in, 0, device=samples.device) + samples.shape[q_dim] ) return samples.index_select(q_dim, q_idcs) else: return samples[..., -q_in:, :] def _set_sampler( self, q_in: int, posterior: Posterior, ) -> None: r"""Update the sampler to use the original base samples for X_baseline. Args: q_in: The effective input batch size. This is typically equal to the q-batch size of `X`. However, if using a one-to-many input transform, e.g., `InputPerturbation` with `n_w` perturbations, the posterior will have `n_w` points on the q-batch for each point on the q-batch of `X`. In which case, `q_in = q * n_w` is used. posterior: The posterior. """ if self.q_in != q_in and self.base_sampler is not None: self.sampler._update_base_samples( posterior=posterior, base_sampler=self.base_sampler ) self.q_in = q_in