#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Model fitting routines."""
from __future__ import annotations
import logging
from functools import partial
from itertools import filterfalse
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage
from botorch.exceptions.errors import ModelFittingError, UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
from botorch.models.approximate_gp import ApproximateGPyTorchModel
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.optim.closures import get_loss_closure_with_grads
from botorch.optim.core import _LBFGSB_MAXITER_MAXFUN_REGEX
from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch
from botorch.optim.utils import (
_warning_handler_template,
get_parameters,
sample_all_priors,
)
from botorch.settings import debug
from botorch.utils.context_managers import (
module_rollback_ctx,
parameter_rollback_ctx,
TensorCheckpoint,
)
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from linear_operator.utils.errors import NotPSDError
from pyro.infer.mcmc import MCMC, NUTS
from torch import device, Tensor
from torch.nn import Parameter
from torch.utils.data import DataLoader
def _debug_warn(w: WarningMessage) -> bool:
if _LBFGSB_MAXITER_MAXFUN_REGEX.search(str(w.message)):
return True
# TODO: Better handle cases where warning handling logic
# affects both debug and rethrow functions.
return False
def _rethrow_warn(w: WarningMessage) -> bool:
if not issubclass(w.category, OptimizationWarning):
return True
if "Optimization timed out after" in str(w.message):
return True
return False
DEFAULT_WARNING_HANDLER = partial(
_warning_handler_template,
debug=_debug_warn,
rethrow=_rethrow_warn,
)
FitGPyTorchMLL = Dispatcher("fit_gpytorch_mll", encoder=type_bypassing_encoder)
[docs]
def fit_gpytorch_mll(
mll: MarginalLogLikelihood,
closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None,
optimizer: Optional[Callable] = None,
closure_kwargs: Optional[Dict[str, Any]] = None,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> MarginalLogLikelihood:
r"""Clearing house for fitting models passed as GPyTorch MarginalLogLikelihoods.
Args:
mll: A GPyTorch MarginalLogLikelihood instance.
closure: Forward-backward closure for obtaining objective values and gradients.
Responsible for setting parameters' `grad` attributes. If no closure is
provided, one will be obtained by calling `get_loss_closure_with_grads`.
optimizer: User specified optimization algorithm. When `optimizer is None`,
this keyword argument is omitted when calling the dispatcher.
closure_kwargs: Keyword arguments passed when calling `closure`.
optimizer_kwargs: A dictionary of keyword arguments passed when
calling `optimizer`.
**kwargs: Keyword arguments passed down through the dispatcher to
fit subroutines. Unexpected keywords are ignored.
Returns:
The `mll` instance. If fitting succeeded, then `mll` will be in evaluation mode,
i.e. `mll.training == False`. Otherwise, `mll` will be in training mode.
"""
if optimizer is not None: # defer to per-method defaults
kwargs["optimizer"] = optimizer
return FitGPyTorchMLL(
mll,
type(mll.likelihood),
type(mll.model),
closure=closure,
closure_kwargs=closure_kwargs,
optimizer_kwargs=optimizer_kwargs,
**kwargs,
)
@FitGPyTorchMLL.register(MarginalLogLikelihood, object, object)
def _fit_fallback(
mll: MarginalLogLikelihood,
_: Type[object],
__: Type[object],
*,
closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None,
optimizer: Optional[Callable] = fit_gpytorch_mll_scipy,
closure_kwargs: Optional[Dict[str, Any]] = None,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
max_attempts: int = 5,
warning_handler: Callable[[WarningMessage], bool] = DEFAULT_WARNING_HANDLER,
caught_exception_types: Tuple[Type[BaseException], ...] = (NotPSDError,),
**ignore: Any,
) -> MarginalLogLikelihood:
r"""Generic fallback method for fitting Gaussian processes.
Attempts to fit a model using the provided optimizer, then determines whether or
not to retry by evaluating a given policy on emitted warning messages. The first
attempt is run using the initialized parameter values; subsequent attempts begin
by resampling tunable parameters.
Args:
closure: Forward-backward closure for obtaining objective values and gradients.
Responsible for setting parameters' `grad` attributes. If no closure is
provided, one will be obtained by calling `get_loss_closure_with_grads`.
optimizer: The underlying optimization algorithm to run.
closure_kwargs: Keyword arguments passed to `closure`.
optimizer_kwargs: Keyword arguments passed to `optimizer`.
max_attempts: The maximum number of fit attempts allowed. The attempt budget
is NOT shared between calls to this method.
warning_handler: A function used to filter warnings produced when calling
`optimizer`. Any unfiltered warnings (those for which `warning_handler`
returns `False`) will be rethrown and trigger a model fitting retry.
caught_exception_types: A tuple of exception types whose instances should
be redirected to `logging.DEBUG`.
**ignore: This function ignores unrecognized keyword arguments.
Returns:
The `mll` instance. If fitting succeeded, then `mll` will be in evaluation mode,
i.e. `mll.training == False`. Otherwise, `mll` will be in training mode.
"""
# Setup
optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
params_nograd: Dict[str, Parameter] = None # pyre-ignore [9]
ckpt_nograd: Dict[str, TensorCheckpoint] = None # pyre-ignore [9]
ckpt: Dict[str, TensorCheckpoint] = None # pyre-ignore [9]
# Build closure
mll.train()
if closure is None:
closure = get_loss_closure_with_grads(
mll, parameters=get_parameters(mll, requires_grad=True)
)
if closure_kwargs is not None:
closure = partial(closure, **closure_kwargs)
# Attempt to fit the model
for attempt in range(1, 1 + max_attempts):
# Wrap with rollback contextmanager so that each loop iteration reloads the
# original state_dict upon exiting (unless we clear `ckpt`).
with module_rollback_ctx(mll, checkpoint=ckpt, device=device("cpu")) as ckpt:
if attempt > 1: # resample free parameters
if params_nograd is None:
params_nograd = get_parameters(mll, requires_grad=False)
if ckpt_nograd is None: # reuse primary checkpoint
ckpt_nograd = {name: ckpt[name] for name in params_nograd}
with parameter_rollback_ctx(params_nograd, checkpoint=ckpt_nograd):
sample_all_priors(mll.model)
try:
# Fit the model
with catch_warnings(record=True) as warning_list, debug(True):
simplefilter("always", category=OptimizationWarning)
optimizer(mll, closure=closure, **optimizer_kwargs)
# Resolved warnings and determine whether or not to retry
done = True
for w in filterfalse(warning_handler, warning_list):
warn_explicit(str(w.message), w.category, w.filename, w.lineno)
done = False
if done:
ckpt.clear() # do not rollback upon exiting
return mll.eval()
# Ensure mll is in the right mode if fitting failed
mll = mll if mll.training else mll.train()
logging.log(
logging.DEBUG,
f"Fit attempt #{attempt} of {max_attempts} triggered retry policy"
f"{'.' if attempt == max_attempts else '; retrying...'}",
)
except caught_exception_types as err:
logging.log(
logging.DEBUG,
f"Fit attempt #{attempt} of {max_attempts} failed with exception: "
f"{err}",
)
msg = "All attempts to fit the model have failed."
if debug.off():
msg = msg + " For more information, try enabling botorch.settings.debug mode."
raise ModelFittingError(msg)
@FitGPyTorchMLL.register(SumMarginalLogLikelihood, object, ModelListGP)
def _fit_list(
mll: SumMarginalLogLikelihood,
_: Type[Likelihood],
__: Type[ModelListGP],
**kwargs: Any,
) -> SumMarginalLogLikelihood:
r"""Fitting routine for lists of independent Gaussian processes.
Args:
**kwargs: Passed to each of `mll.mlls`.
Returns:
The `mll` instance. If fitting succeeded for all of `mll.mlls`, then `mll` will
be in evaluation mode, i.e. `mll.training == False`. Otherwise, `mll` will be in
training mode.
"""
mll.train()
for sub_mll in mll.mlls:
fit_gpytorch_mll(sub_mll, **kwargs)
return mll.eval() if not any(sub_mll.training for sub_mll in mll.mlls) else mll
@FitGPyTorchMLL.register(_ApproximateMarginalLogLikelihood, object, object)
def _fit_fallback_approximate(
mll: _ApproximateMarginalLogLikelihood,
_: Type[Likelihood],
__: Type[ApproximateGPyTorchModel],
*,
closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None,
data_loader: Optional[DataLoader] = None,
optimizer: Optional[Callable] = None,
full_batch_limit: int = 1024,
**kwargs: Any,
) -> _ApproximateMarginalLogLikelihood:
r"""Fallback method for fitting approximate Gaussian processes.
Args:
closure: Forward-backward closure for obtaining objective values and gradients.
Responsible for setting parameters' `grad` attributes. If no closure is
provided, one will be obtained by calling `get_loss_closure_with_grads`.
optimizer: The underlying optimization algorithm to run. Default to
`fit_gpytorch_mll_scipy` when `closure=None` and the model's internal
training set has no more than `full_batch_cutoff` observations; otherwise,
defaults to `fit_gpytorch_mll_torch`.
data_loader: An optional DataLoader to pass to `get_loss_closure_with_grads`.
May only be provided when `closure=None`.
full_batch_limit: Threshold for determining the default choice of `optimizer`
when `closure=None`.
**kwargs: Keyword arguments passed to `_fit_fallback`.
"""
if data_loader is not None:
if closure is not None:
raise UnsupportedError(
"Only one of `data_loader` or `closure` may be passed."
)
closure = get_loss_closure_with_grads(
mll=mll,
data_loader=data_loader,
parameters=get_parameters(mll, requires_grad=True),
)
if optimizer is None:
optimizer = (
fit_gpytorch_mll_scipy
if closure is None and len(mll.model.train_targets) <= full_batch_limit
else fit_gpytorch_mll_torch
)
return _fit_fallback(mll, _, __, closure=closure, optimizer=optimizer, **kwargs)
[docs]
def fit_fully_bayesian_model_nuts(
model: Union[SaasFullyBayesianSingleTaskGP, SaasFullyBayesianMultiTaskGP],
max_tree_depth: int = 6,
warmup_steps: int = 512,
num_samples: int = 256,
thinning: int = 16,
disable_progbar: bool = False,
jit_compile: bool = False,
) -> None:
r"""Fit a fully Bayesian model using the No-U-Turn-Sampler (NUTS)
Args:
model: SaasFullyBayesianSingleTaskGP to be fitted.
max_tree_depth: Maximum tree depth for NUTS
warmup_steps: The number of burn-in steps for NUTS.
num_samples: The number of MCMC samples. Note that with thinning,
num_samples / thinning samples are retained.
thinning: The amount of thinning. Every nth sample is retained.
disable_progbar: A boolean indicating whether to print the progress
bar and diagnostics during MCMC.
jit_compile: Whether to use jit. Using jit may be ~2X faster (rough estimate),
but it will also increase the memory usage and sometimes result in runtime
errors, e.g., https://github.com/pyro-ppl/pyro/issues/3136.
Example:
>>> gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
>>> fit_fully_bayesian_model_nuts(gp)
"""
model.train()
# Do inference with NUTS
nuts = NUTS(
model.pyro_model.sample,
jit_compile=jit_compile,
full_mass=True,
ignore_jit_warnings=True,
max_tree_depth=max_tree_depth,
)
mcmc = MCMC(
nuts,
warmup_steps=warmup_steps,
num_samples=num_samples,
disable_progbar=disable_progbar,
)
mcmc.run()
# Get final MCMC samples from the Pyro model
mcmc_samples = model.pyro_model.postprocess_mcmc_samples(
mcmc_samples=mcmc.get_samples()
)
for k, v in mcmc_samples.items():
mcmc_samples[k] = v[::thinning]
# Load the MCMC samples back into the BoTorch model
model.load_mcmc_samples(mcmc_samples)
model.eval()