# Source code for botorch.fit

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Model fitting routines."""

from __future__ import annotations

import logging
from contextlib import nullcontext
from re import compile, Pattern
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
from warnings import catch_warnings, simplefilter, warn, WarningMessage

from botorch.exceptions.errors import ModelFittingError, UnsupportedError
from botorch.exceptions.warnings import BotorchWarning, OptimizationWarning
from botorch.models.converter import batched_to_model_list, model_list_to_batched
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP

from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.optim.fit import fit_gpytorch_scipy
from botorch.optim.utils import (
    allclose_mll,
    del_attribute_ctx,
    parameter_rollback_ctx,
    requires_grad_ctx,
    sample_all_priors,
    state_rollback_ctx,
    Tkwargs,
)
from botorch.settings import debug
from botorch.utils.dispatcher import Dispatcher, MDNotImplementedError
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from linear_operator.utils.errors import NotPSDError
from pyro.infer.mcmc import MCMC, NUTS
from torch import device, mean, Tensor

# Signature of optimizers accepted by the fitting routines below: consume a
# MarginalLogLikelihood and return the (fitted) MLL plus optimizer-specific
# auxiliary output.
OptimizerType = Callable[[MarginalLogLikelihood], Tuple[MarginalLogLikelihood, Any]]

# Warning-message patterns that `DEFAULT_WARNING_FILTER` logs at the mapped
# level instead of treating as unresolved (i.e. retry-triggering) warnings.
DEFAULT_LOGGING_PATTERNS: Dict[int, Pattern] = {
    logging.DEBUG: compile(  # catch warning corresponding to `maxiter` and `maxfun`
        "TOTAL NO. of (ITERATIONS REACHED LIMIT|f AND g EVALUATIONS EXCEEDS LIMIT)"
    )
}


def DEFAULT_WARNING_FILTER(
    w: WarningMessage,
    logging_patterns: Dict[int, Pattern] = DEFAULT_LOGGING_PATTERNS,
) -> bool:
    r"""Default warning resolution policy.

    A warning is "resolved" (return value ``False``) when it either matches one
    of the given logging patterns or is not an ``OptimizationWarning``; only an
    unmatched ``OptimizationWarning`` is reported as unresolved, which triggers
    a model fitting retry upstream.

    Args:
        w: Candidate for filtering.
        logging_patterns: Dictionary mapping logging levels to regular
            expressions. Warning messages are compared against these
            expressions and matches are awarded first-come-first-serve when
            iterating through the dictionary.

    Returns:
        Boolean indicating whether the warning is unresolved.
    """
    message = str(w.message)
    for level, pattern in logging_patterns.items():
        if pattern.search(message) is not None:
            # Known benign warning: record it at the configured level.
            logging.log(level, w.message)
            return False

    if issubclass(w.category, OptimizationWarning):
        # Unmatched OptimizationWarning: report as unresolved (retry).
        return True

    # Any other warning is re-emitted for the caller but counts as resolved.
    warn(w.message, w.category)
    return False
# Dispatcher for `fit_gpytorch_mll`


def _type_bypassing_encoder(arg: Any) -> Type:
    # Arguments that are already types pass through untouched, so callers may
    # pre-encode dispatch keys; anything else is encoded as its runtime type.
    if isinstance(arg, type):
        return arg
    return type(arg)


dispatcher = Dispatcher("fit_gpytorch_mll", encoder=_type_bypassing_encoder)
def fit_gpytorch_mll(
    mll: MarginalLogLikelihood,
    optimizer: Optional[Callable] = None,
    optimizer_kwargs: Optional[dict] = None,
    **kwargs: Any,
) -> MarginalLogLikelihood:
    r"""Clearing house for fitting models passed as GPyTorch
    MarginalLogLikelihoods.

    The call is routed through a dispatcher keyed on the types of the MLL, its
    likelihood, and its model, so specialized fit routines are selected
    automatically.

    Args:
        mll: A GPyTorch MarginalLogLikelihood instance.
        optimizer: User specified optimization algorithm. When ``None``, this
            keyword argument is omitted when calling the dispatcher so that
            each dispatch target falls back to its own default.
        optimizer_kwargs: A dictionary of keyword arguments passed when
            calling `optimizer`.
        **kwargs: Keyword arguments passed down through the dispatcher to fit
            subroutines. Unexpected keywords are ignored.

    Returns:
        The `mll` instance. If fitting succeeded, then `mll` will be in
        evaluation mode, i.e. `mll.training == False`. Otherwise, `mll` will
        be in training mode.
    """
    if optimizer is not None:  # defer to per-method defaults when omitted
        kwargs["optimizer"] = optimizer

    likelihood_type = type(mll.likelihood)
    model_type = type(mll.model)
    return dispatcher(
        mll,
        likelihood_type,
        model_type,
        optimizer_kwargs=optimizer_kwargs,
        **kwargs,
    )
def fit_gpytorch_model(
    mll: MarginalLogLikelihood,
    optimizer: Optional[OptimizerType] = None,
    optimizer_kwargs: Optional[dict] = None,
    exclude: Optional[Iterable[str]] = None,
    max_retries: Optional[int] = None,
    **kwargs: Any,
) -> MarginalLogLikelihood:
    r"""Convenience method for fitting GPyTorch models using legacy API.

    Deprecated shim around `fit_gpytorch_mll`: translates legacy keyword names
    and swallows `ModelFittingError` as a `RuntimeWarning`.

    Args:
        mll: A GPyTorch MarginalLogLikelihood instance.
        optimizer: User specified optimization algorithm. When `optimizer is
            None`, this keyword argument is omitted when calling the
            dispatcher from inside `fit_gpytorch_mll`.
        optimizer_kwargs: Keyword arguments passed to `optimizer`; legacy
            optimizer keywords found in `**kwargs` are folded in here.
        exclude: Legacy argument for specifying parameters `x` that should be
            held fixed during optimization. Internally, used to temporarily
            set `x.requires_grad` to False.
        max_retries: Legacy name for `max_attempts`. When `max_retries is
            None`, this keyword argument is omitted when calling
            `fit_gpytorch_mll`.

    Returns:
        The `mll` instance (evaluation mode on success, training mode if
        fitting failed).
    """
    warn(
        "`fit_gpytorch_model` is marked for deprecation, consider using "
        "`fit_gpytorch_mll` instead.",
        DeprecationWarning,
    )
    if max_retries is not None:
        kwargs["max_attempts"] = max_retries

    if optimizer_kwargs is None:
        optimizer_kwargs = {}
    # Migrate legacy optimizer keywords into `optimizer_kwargs`, rejecting
    # conflicting duplicates the same way Python rejects repeated keywords.
    for key in ("bounds", "options", "track_iterations", "approx_mll"):
        if key in kwargs:
            val = kwargs.pop(key)
            if key in optimizer_kwargs and val is not optimizer_kwargs[key]:
                raise SyntaxError(f"keyword argument repeated: {key}")
            optimizer_kwargs[key] = val

    if exclude is None:
        ctx = nullcontext()
    else:
        # Temporarily freeze the excluded parameters during optimization.
        ctx = requires_grad_ctx(mll, assignments={name: False for name in exclude})

    with ctx:
        try:
            mll = fit_gpytorch_mll(
                mll,
                optimizer=optimizer,
                optimizer_kwargs=optimizer_kwargs,
                **kwargs,
            )
        except ModelFittingError as err:
            warn(str(err), RuntimeWarning)

    return mll
@dispatcher.register(MarginalLogLikelihood, object, object)
def _fit_fallback(
    mll: MarginalLogLikelihood,
    _: Type[object],
    __: Type[object],
    *,
    optimizer: Optional[Callable] = fit_gpytorch_scipy,
    optimizer_kwargs: Optional[dict] = None,
    max_attempts: int = 5,
    warning_filter: Callable[[WarningMessage], bool] = DEFAULT_WARNING_FILTER,
    caught_exception_types: Tuple[Type[BaseException], ...] = (NotPSDError,),
    **ignore: Any,
) -> MarginalLogLikelihood:
    r"""Generic fallback method for fitting Gaussian processes.

    Attempts to fit a model using the provided optimizer, then determines whether or
    not to retry by evaluating a given policy on emitted warning messages. The first
    attempt is run using the initialized parameter values; subsequent attempts begin
    by resampling tunable parameters.

    Args:
        optimizer: The underlying optimization algorithm to run.
        optimizer_kwargs: Keyword arguments passed when calling `optimizer`.
        max_attempts: The maximum number of fit attempts allowed. The attempt budget
            is NOT shared between calls to this method.
        warning_filter: A function used to filter warnings produced when calling
            `optimizer`. Any unfiltered warnings will be rethrown and trigger a
            model fitting retry.
        caught_exception_types: A tuple of exception types whose instances should
            be redirected to `logging.DEBUG`.
        **ignore: This function ignores unrecognized keyword arguments.

    Returns:
        The `mll` instance. If fitting succeeded, then `mll` will be in evaluation
        mode, i.e. `mll.training == False`. Otherwise, `mll` will be in training mode.

    Raises:
        ModelFittingError: If every attempt fails (or raises one of
            `caught_exception_types`).
    """
    # Checkpoints are created lazily by `state_rollback_ctx` on first entry.
    ckpt: Optional[Dict[str, Tuple[Tensor, Tkwargs]]] = None  # lazy CPU-based checkpoint
    ckpt_nograd: Optional[Dict[str, Tuple[Tensor, Tkwargs]]] = None  # subset for fixed parameters
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs

    mll.train()
    for attempt in range(1, 1 + max_attempts):
        # Wrap with rollback contextmanager so each loop iteration reloads the original
        # state_dict upon exiting (unless `ckpt` is cleared).
        with state_rollback_ctx(mll, checkpoint=ckpt, device=device("cpu")) as ckpt:
            if ckpt_nograd is None:
                ckpt_nograd = {  # reuse cached values from primary checkpoint
                    k: ckpt[k] for k, v in mll.named_parameters() if not v.requires_grad
                }

            if attempt > 1:  # maybe resample parameters that require gradients
                # Fixed parameters are rolled back afterwards so priors only
                # perturb the tunable ones.
                with parameter_rollback_ctx(mll, checkpoint=ckpt_nograd):
                    sample_all_priors(mll.model)

            try:  # Fit the model
                with catch_warnings(record=True) as warning_list, debug(True):
                    simplefilter("always", category=OptimizationWarning)
                    mll, _ = optimizer(mll, **optimizer_kwargs)

                # Resolve warning messages and determine whether or not to retry
                done = True
                for unresolved_warning in filter(warning_filter, warning_list):
                    warn(unresolved_warning.message, unresolved_warning.category)
                    done = False

                if done:
                    ckpt.clear()  # do not rollback upon exiting
                    return mll.eval()

                # Ensure mll is in the right mode if fitting failed
                mll = mll if mll.training else mll.train()
                logging.log(
                    logging.DEBUG,
                    f"Fit attempt #{attempt} of {max_attempts} triggered retry policy"
                    f"{'.' if attempt == max_attempts else '; retrying...'}",
                )

            except caught_exception_types as err:
                # Treated as a soft failure: log and fall through to retry.
                logging.log(
                    logging.DEBUG,
                    f"Fit attempt #{attempt} of {max_attempts} failed with exception: "
                    f"{err}",
                )

    raise ModelFittingError("All attempts to fit the model have failed.")


@dispatcher.register(SumMarginalLogLikelihood, Likelihood, ModelListGP)
def _fit_list(
    mll: SumMarginalLogLikelihood,
    _: Type[Likelihood],
    __: Type[ModelListGP],
    **kwargs: Any,
) -> SumMarginalLogLikelihood:
    r"""Fitting routine for lists of independent Gaussian processes.

    Each submodel's MLL is dispatched through `fit_gpytorch_mll` individually.

    Args:
        **kwargs: Passed to each of `mll.mlls`.

    Returns:
        The `mll` instance. If fitting succeeded for all of `mll.mlls`, then `mll`
        will be in evaluation mode, i.e. `mll.training == False`. Otherwise, `mll`
        will be in training mode.
    """
    mll.train()
    for sub_mll in mll.mlls:
        fit_gpytorch_mll(sub_mll, **kwargs)

    # Only switch to eval mode if every submodel was fitted successfully.
    return mll.eval() if not any(sub_mll.training for sub_mll in mll.mlls) else mll


@dispatcher.register(MarginalLogLikelihood, Likelihood, BatchedMultiOutputGPyTorchModel)
def _fit_multioutput_independent(
    mll: MarginalLogLikelihood,
    _: Type[Likelihood],
    __: Type[BatchedMultiOutputGPyTorchModel],
    *,
    sequential: bool = True,
    **kwargs: Any,
) -> MarginalLogLikelihood:
    r"""Fitting routine for multioutput Gaussian processes.

    Args:
        sequential: Boolean specifying whether or not an attempt should be made to
            fit the model as a collection of independent GPs. Only relevant for
            certain types of GPs with independent outputs, see
            `batched_to_model_list`.
        **kwargs: Passed to the next method unaltered.

    Returns:
        The `mll` instance. If fitting succeeded, then `mll` will be in evaluation
        mode, i.e. `mll.training == False`. Otherwise, `mll` will be in training
        mode.

    Raises:
        MDNotImplementedError: If the model is incompatible with the sequential
            strategy or any unpack/fit/repack step fails; the dispatcher then
            falls back to the next (generic) method.
    """
    if (  # incompatible models
        not sequential
        or mll.model.num_outputs == 1
        or mll.likelihood is not getattr(mll.model, "likelihood", None)
    ):
        raise MDNotImplementedError  # defer to generic

    # TODO: Unpacking of OutcomeTransforms not yet supported. Targets are often
    # pre-transformed in __init__, so try fitting with outcome_transform hidden
    mll.train()
    with del_attribute_ctx(mll.model, "outcome_transform"):
        try:
            # Attempt to unpack batched model into a list of independent submodels
            unpacked_model = batched_to_model_list(mll.model)
            unpacked_mll = SumMarginalLogLikelihood(  # avg. over MLLs internally
                unpacked_model.likelihood, unpacked_model
            )
            if not allclose_mll(a=mll, b=unpacked_mll, transform_a=mean):
                raise RuntimeError(  # validate model unpacking
                    "Training loss of unpacked model differs from that of the original."
                )

            # Fit submodels independently
            unpacked_mll = fit_gpytorch_mll(unpacked_mll, **kwargs)

            # Repackage submodels and copy over state_dict
            repacked_model = model_list_to_batched(unpacked_mll.model.train())
            repacked_mll = type(mll)(repacked_model.likelihood, repacked_model)
            # Load the repacked state into the original `mll`, rolling back if
            # the repacked model's training loss fails to match.
            with state_rollback_ctx(mll, device=device("cpu")) as ckpt:
                mll.load_state_dict(repacked_mll.state_dict())
                if not allclose_mll(a=mll, b=repacked_mll):
                    raise RuntimeError(  # validate model repacking
                        "Training loss of repacked model differs from that of the "
                        "original."
                    )
                ckpt.clear()  # do not rollback when exiting
                return mll.eval()  # DONE!

        except (AttributeError, RuntimeError, UnsupportedError) as err:
            msg = f"Failed to independently fit submodels with exception: {err}"
            warn(
                f"{msg.rstrip('.')}. Deferring to generic dispatch...",
                BotorchWarning,
            )
            raise MDNotImplementedError
def fit_fully_bayesian_model_nuts(
    model: Union[SaasFullyBayesianSingleTaskGP, SaasFullyBayesianMultiTaskGP],
    max_tree_depth: int = 6,
    warmup_steps: int = 512,
    num_samples: int = 256,
    thinning: int = 16,
    disable_progbar: bool = False,
    jit_compile: bool = False,
) -> None:
    r"""Fit a fully Bayesian model using the No-U-Turn-Sampler (NUTS)

    Runs Pyro's NUTS/MCMC on the model's Pyro representation, thins the
    resulting samples, and loads them back into the BoTorch model in place.

    Args:
        model: SaasFullyBayesianSingleTaskGP to be fitted.
        max_tree_depth: Maximum tree depth for NUTS
        warmup_steps: The number of burn-in steps for NUTS.
        num_samples: The number of MCMC samples. Note that with thinning,
            num_samples / thinning samples are retained.
        thinning: The amount of thinning. Every nth sample is retained.
        disable_progbar: A boolean indicating whether to print the progress
            bar and diagnostics during MCMC.
        jit_compile: Whether to use jit. Using jit may be ~2X faster (rough
            estimate), but it will also increase the memory usage and
            sometimes result in runtime errors, e.g.,
            https://github.com/pyro-ppl/pyro/issues/3136.

    Example:
        >>> gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
        >>> fit_fully_bayesian_model_nuts(gp)
    """
    model.train()

    # Do inference with NUTS
    kernel = NUTS(
        model.pyro_model.sample,
        jit_compile=jit_compile,
        full_mass=True,
        ignore_jit_warnings=True,
        max_tree_depth=max_tree_depth,
    )
    sampler = MCMC(
        kernel,
        warmup_steps=warmup_steps,
        num_samples=num_samples,
        disable_progbar=disable_progbar,
    )
    sampler.run()

    # Get final MCMC samples from the Pyro model, then thin them.
    raw_samples = model.pyro_model.postprocess_mcmc_samples(
        mcmc_samples=sampler.get_samples()
    )
    thinned_samples = {
        name: values[::thinning] for name, values in raw_samples.items()
    }

    # Load the MCMC samples back into the BoTorch model
    model.load_mcmc_samples(thinned_samples)
    model.eval()