#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Model fitting routines."""
from __future__ import annotations
import logging
from contextlib import nullcontext
from re import compile, Pattern
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
from warnings import catch_warnings, simplefilter, warn, WarningMessage
from botorch.exceptions.errors import ModelFittingError, UnsupportedError
from botorch.exceptions.warnings import BotorchWarning, OptimizationWarning
from botorch.models.converter import batched_to_model_list, model_list_to_batched
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.optim.fit import fit_gpytorch_scipy
from botorch.optim.utils import (
allclose_mll,
del_attribute_ctx,
parameter_rollback_ctx,
requires_grad_ctx,
sample_all_priors,
state_rollback_ctx,
Tkwargs,
)
from botorch.settings import debug
from botorch.utils.dispatcher import Dispatcher, MDNotImplementedError
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from linear_operator.utils.errors import NotPSDError
from pyro.infer.mcmc import MCMC, NUTS
from torch import device, mean, Tensor
# Signature expected of an optimizer passed to the legacy `fit_gpytorch_model` API:
# a callable that accepts a MarginalLogLikelihood and returns a 2-tuple whose first
# element is a MarginalLogLikelihood and whose second element is unconstrained.
OptimizerType = Callable[[MarginalLogLikelihood], Tuple[MarginalLogLikelihood, Any]]
# Default patterns consumed by `DEFAULT_WARNING_FILTER`: maps a logging level to a
# compiled regular expression. A warning whose message matches a pattern is logged
# at the associated level and treated as resolved.
DEFAULT_LOGGING_PATTERNS: Dict[int, Pattern] = {
    logging.DEBUG: compile(  # catch warning corresponding to `maxiter` and `maxfun`
        "TOTAL NO. of (ITERATIONS REACHED LIMIT|f AND g EVALUATIONS EXCEEDS LIMIT)"
    )
}
def DEFAULT_WARNING_FILTER(
    w: WarningMessage,
    logging_patterns: Dict[int, Pattern] = DEFAULT_LOGGING_PATTERNS,
) -> bool:
    r"""Default warning resolution policy.

    A warning is reported as unresolved only when it is an `OptimizationWarning`
    whose message matches none of the logging patterns.

    Args:
        w: The recorded warning to classify.
        logging_patterns: Dictionary mapping logging levels to regular expressions.
            Patterns are tried in dictionary iteration order; the first one that
            matches the warning message wins and the message is logged at the
            associated level.

    Returns:
        True if the warning is unresolved, False otherwise.
    """
    for log_level, regex in logging_patterns.items():
        if regex.search(str(w.message)):
            # Known, benign message: demote it to a log record.
            logging.log(log_level, w.message)
            return False

    if issubclass(w.category, OptimizationWarning):
        # Unmatched OptimizationWarnings are the only unresolved case.
        return True

    # Any other category is re-emitted for the caller but counted as resolved.
    warn(w.message, w.category)
    return False
# Dispatcher for `fit_gpytorch_mll`
def _type_bypassing_encoder(arg: Any) -> Type:
    """Encode a dispatch argument as a type.

    Objects that already are types are returned unchanged, which allows callers to
    pass pre-encoded arguments; every other object is encoded as its own type.
    """
    if isinstance(arg, type):
        return arg
    return type(arg)
# Registry of fitting routines used by `fit_gpytorch_mll`. Arguments are encoded with
# `_type_bypassing_encoder`, so callers may pass either instances or types.
dispatcher = Dispatcher("fit_gpytorch_mll", encoder=_type_bypassing_encoder)
def fit_gpytorch_mll(
    mll: MarginalLogLikelihood,
    optimizer: Optional[Callable] = None,
    optimizer_kwargs: Optional[dict] = None,
    **kwargs: Any,
) -> MarginalLogLikelihood:
    r"""Clearing house for fitting models passed as GPyTorch MarginalLogLikelihoods.

    The call is routed through `dispatcher` using `mll` together with the types of
    `mll.likelihood` and `mll.model`.

    Args:
        mll: A GPyTorch MarginalLogLikelihood instance.
        optimizer: User specified optimization algorithm. When `optimizer is None`,
            this keyword argument is omitted when calling the dispatcher.
        optimizer_kwargs: A dictionary of keyword arguments passed when
            calling `optimizer`.
        **kwargs: Keyword arguments passed down through the dispatcher to
            fit subroutines. Unexpected keywords are ignored.

    Returns:
        The `mll` instance. If fitting succeeded, then `mll` will be in evaluation mode,
        i.e. `mll.training == False`. Otherwise, `mll` will be in training mode.
    """
    # Forward `optimizer` only when it was given, so that each registered fit
    # routine keeps its own default.
    forwarded = kwargs if optimizer is None else {**kwargs, "optimizer": optimizer}
    likelihood_type = type(mll.likelihood)
    model_type = type(mll.model)
    return dispatcher(
        mll,
        likelihood_type,
        model_type,
        optimizer_kwargs=optimizer_kwargs,
        **forwarded,
    )
def fit_gpytorch_model(
    mll: MarginalLogLikelihood,
    optimizer: Optional[OptimizerType] = None,
    optimizer_kwargs: Optional[dict] = None,
    exclude: Optional[Iterable[str]] = None,
    max_retries: Optional[int] = None,
    **kwargs: Any,
) -> MarginalLogLikelihood:
    r"""Convenience method for fitting GPyTorch models using legacy API. For more
    details, see `fit_gpytorch_mll`.

    Args:
        mll: A GPyTorch MarginalLogLikelihood instance.
        optimizer: User specified optimization algorithm. When `optimizer is None`,
            this keyword argument is omitted when calling the dispatcher from inside
            `fit_gpytorch_mll`.
        optimizer_kwargs: Keyword arguments passed when calling `optimizer`. The
            dictionary is copied before legacy keywords are merged into it, so the
            caller's object is never modified.
        exclude: Legacy argument for specifying parameters `x` that should be held fixed
            during optimization. Internally, used to temporarily set `x.requires_grad`
            to False.
        max_retries: Legacy name for `max_attempts`. When `max_retries is None`,
            this keyword argument is omitted when calling `fit_gpytorch_mll`.
        **kwargs: Forwarded to `fit_gpytorch_mll`, except for the legacy keywords
            `bounds`, `options`, `track_iterations` and `approx_mll`, which are moved
            into `optimizer_kwargs`.

    Returns:
        The value returned by `fit_gpytorch_mll`. If fitting raised a
        `ModelFittingError`, the error is converted into a `RuntimeWarning` and the
        `mll` that was passed in is returned.

    Raises:
        SyntaxError: If a legacy keyword is given both directly and inside
            `optimizer_kwargs` with a different object.
    """
    warn(
        "`fit_gpytorch_model` is marked for deprecation, consider using "
        "`fit_gpytorch_mll` instead.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the code calling this wrapper
    )
    if max_retries is not None:
        kwargs["max_attempts"] = max_retries

    # Shallow-copy so that merging legacy keywords below cannot leak into the
    # dictionary owned by the caller.
    optimizer_kwargs = {} if optimizer_kwargs is None else dict(optimizer_kwargs)
    for key in ("bounds", "options", "track_iterations", "approx_mll"):
        if key not in kwargs:
            continue

        val = kwargs.pop(key)
        if key in optimizer_kwargs and val is not optimizer_kwargs[key]:
            raise SyntaxError(f"keyword argument repeated: {key}")

        optimizer_kwargs[key] = val

    with (
        nullcontext()
        if exclude is None
        else requires_grad_ctx(mll, assignments={name: False for name in exclude})
    ):
        try:
            mll = fit_gpytorch_mll(
                mll,
                optimizer=optimizer,
                optimizer_kwargs=optimizer_kwargs,
                **kwargs,
            )
        except ModelFittingError as err:
            # Legacy behavior: report a failed fit as a warning instead of raising.
            warn(str(err), RuntimeWarning)

    return mll
@dispatcher.register(MarginalLogLikelihood, object, object)
def _fit_fallback(
    mll: MarginalLogLikelihood,
    _: Type[object],
    __: Type[object],
    *,
    optimizer: Optional[Callable] = fit_gpytorch_scipy,
    optimizer_kwargs: Optional[dict] = None,
    max_attempts: int = 5,
    warning_filter: Callable[[WarningMessage], bool] = DEFAULT_WARNING_FILTER,
    caught_exception_types: Tuple[Type[BaseException], ...] = (NotPSDError,),
    **ignore: Any,
) -> MarginalLogLikelihood:
    r"""Generic fallback method for fitting Gaussian processes.

    Runs `optimizer` on `mll` up to `max_attempts` times. After each run the warnings
    recorded during optimization are passed through `warning_filter`; if any of them
    is reported as unresolved, the attempt is discarded and another one is made. The
    first attempt starts from the current parameter values, later attempts start by
    resampling parameters from their priors.

    Args:
        mll: The MarginalLogLikelihood to be fit.
        _: Type of the likelihood, used for dispatching only.
        __: Type of the model, used for dispatching only.
        optimizer: The underlying optimization algorithm to run.
        optimizer_kwargs: Keyword arguments passed when calling `optimizer`.
        max_attempts: The maximum number of fit attempts allowed. The attempt budget
            is NOT shared between calls to this method.
        warning_filter: A function used to filter warnings produced when calling
            `optimizer`. Any unfiltered warnings will be rethrown and trigger a
            model fitting retry.
        caught_exception_types: A tuple of exception types whose instances should
            be redirected to `logging.DEBUG`.
        **ignore: This function ignores unrecognized keyword arguments.

    Returns:
        The `mll` instance in evaluation mode, i.e. `mll.training == False`.

    Raises:
        ModelFittingError: If none of the attempts succeeded.
    """
    if optimizer_kwargs is None:
        optimizer_kwargs = {}

    # lazy CPU-based checkpoint
    ckpt: Optional[Dict[str, Tuple[Tensor, Tkwargs]]] = None
    # subset of the checkpoint for fixed parameters
    ckpt_nograd: Optional[Dict[str, Tuple[Tensor, Tkwargs]]] = None

    mll.train()
    for attempt in range(1, max_attempts + 1):
        # Wrap with rollback contextmanager so each loop iteration reloads the
        # original state_dict upon exiting (unless `ckpt` is cleared).
        with state_rollback_ctx(mll, checkpoint=ckpt, device=device("cpu")) as ckpt:
            if ckpt_nograd is None:
                # Reuse cached values from the primary checkpoint
                ckpt_nograd = {
                    name: ckpt[name]
                    for name, param in mll.named_parameters()
                    if not param.requires_grad
                }

            if attempt > 1:
                # Maybe resample parameters that require gradients
                with parameter_rollback_ctx(mll, checkpoint=ckpt_nograd):
                    sample_all_priors(mll.model)

            try:
                # Fit the model while recording every OptimizationWarning
                with catch_warnings(record=True) as warning_list, debug(True):
                    simplefilter("always", category=OptimizationWarning)
                    mll, _ = optimizer(mll, **optimizer_kwargs)

                # Re-emit unresolved warnings; any of them triggers a retry. The
                # filter is applied lazily, one warning at a time.
                needs_retry = False
                for unresolved in filter(warning_filter, warning_list):
                    warn(unresolved.message, unresolved.category)
                    needs_retry = True

                if not needs_retry:
                    ckpt.clear()  # do not rollback upon exiting
                    return mll.eval()

                # Ensure mll is in the right mode if fitting failed
                if not mll.training:
                    mll = mll.train()

                logging.log(
                    logging.DEBUG,
                    f"Fit attempt #{attempt} of {max_attempts} triggered retry policy"
                    f"{'.' if attempt == max_attempts else '; retrying...'}",
                )
            except caught_exception_types as err:
                logging.log(
                    logging.DEBUG,
                    f"Fit attempt #{attempt} of {max_attempts} failed with exception: "
                    f"{err}",
                )

    raise ModelFittingError("All attempts to fit the model have failed.")
@dispatcher.register(SumMarginalLogLikelihood, Likelihood, ModelListGP)
def _fit_list(
    mll: SumMarginalLogLikelihood,
    _: Type[Likelihood],
    __: Type[ModelListGP],
    **kwargs: Any,
) -> SumMarginalLogLikelihood:
    r"""Fitting routine for lists of independent Gaussian processes.

    Each element of `mll.mlls` is fit on its own via `fit_gpytorch_mll`.

    Args:
        mll: The SumMarginalLogLikelihood whose `mlls` are fit one by one.
        _: Type of the likelihood, used for dispatching only.
        __: Type of the model, used for dispatching only.
        **kwargs: Passed to each of `mll.mlls`.

    Returns:
        The `mll` instance. If fitting succeeded for all of `mll.mlls`, then `mll` will
        be in evaluation mode, i.e. `mll.training == False`. Otherwise, `mll` will be in
        training mode.
    """
    mll.train()
    for sub_mll in mll.mlls:
        fit_gpytorch_mll(sub_mll, **kwargs)

    # A sub-MLL that is still in training mode means its fit did not succeed.
    if any(sub_mll.training for sub_mll in mll.mlls):
        return mll
    return mll.eval()
@dispatcher.register(MarginalLogLikelihood, Likelihood, BatchedMultiOutputGPyTorchModel)
def _fit_multioutput_independent(
    mll: MarginalLogLikelihood,
    _: Type[Likelihood],
    __: Type[BatchedMultiOutputGPyTorchModel],
    *,
    sequential: bool = True,
    **kwargs: Any,
) -> MarginalLogLikelihood:
    r"""Fitting routine for multioutput Gaussian processes.

    Tries to unpack the batched model into a list of independent submodels, fit
    those, and load the resulting state back into `mll`. Both the unpacking and
    the repacking are validated by comparing MLLs with `allclose_mll`.

    Args:
        mll: The MarginalLogLikelihood to be fit.
        _: Type of the likelihood, used for dispatching only.
        __: Type of the model, used for dispatching only.
        sequential: Boolean specifying whether or not an attempt should be made to
            fit the model as a collection of independent GPs. Only relevant for
            certain types of GPs with independent outputs, see `batched_to_model_list`.
        **kwargs: Passed to the next method unaltered.

    Returns:
        The `mll` instance. If fitting succeeded, then `mll` will be in evaluation mode,
        i.e. `mll.training == False`. Otherwise, `mll` will be in training mode.

    Raises:
        MDNotImplementedError: If `sequential` is False, the model has a single
            output, `mll.likelihood` is not the model's own likelihood, or the
            independent fit failed with an `AttributeError`, `RuntimeError` or
            `UnsupportedError`. In the last case a `BotorchWarning` is emitted first.
    """
    if (  # incompatible models
        not sequential
        or mll.model.num_outputs == 1
        or mll.likelihood is not getattr(mll.model, "likelihood", None)
    ):
        raise MDNotImplementedError  # defer to generic

    # TODO: Unpacking of OutcomeTransforms not yet supported. Targets are often
    # pre-transformed in __init__, so try fitting with outcome_transform hidden
    mll.train()
    with del_attribute_ctx(mll.model, "outcome_transform"):
        try:
            # Attempt to unpack batched model into a list of independent submodels
            unpacked_model = batched_to_model_list(mll.model)
            unpacked_mll = SumMarginalLogLikelihood(  # avg. over MLLs internally
                unpacked_model.likelihood, unpacked_model
            )
            # The original MLL is reduced with `mean` before it is compared against
            # the unpacked one.
            if not allclose_mll(a=mll, b=unpacked_mll, transform_a=mean):
                raise RuntimeError(  # validate model unpacking
                    "Training loss of unpacked model differs from that of the original."
                )

            # Fit submodels independently
            unpacked_mll = fit_gpytorch_mll(unpacked_mll, **kwargs)

            # Repackage submodels and copy over state_dict
            repacked_model = model_list_to_batched(unpacked_mll.model.train())
            repacked_mll = type(mll)(repacked_model.likelihood, repacked_model)
            with state_rollback_ctx(mll, device=device("cpu")) as ckpt:
                mll.load_state_dict(repacked_mll.state_dict())
                if not allclose_mll(a=mll, b=repacked_mll):
                    raise RuntimeError(  # validate model repacking
                        "Training loss of repacked model differs from that of the "
                        "original."
                    )
                ckpt.clear()  # do not rollback when exiting
                return mll.eval()  # DONE!

        except (AttributeError, RuntimeError, UnsupportedError) as err:
            # Report the failure, then hand control back to the dispatcher.
            msg = f"Failed to independently fit submodels with exception: {err}"
            warn(
                f"{msg.rstrip('.')}. Deferring to generic dispatch...",
                BotorchWarning,
            )
            raise MDNotImplementedError
def fit_fully_bayesian_model_nuts(
    model: Union[SaasFullyBayesianSingleTaskGP, SaasFullyBayesianMultiTaskGP],
    max_tree_depth: int = 6,
    warmup_steps: int = 512,
    num_samples: int = 256,
    thinning: int = 16,
    disable_progbar: bool = False,
    jit_compile: bool = False,
) -> None:
    r"""Fit a fully Bayesian model using the No-U-Turn-Sampler (NUTS)

    The model is put into training mode, sampled with NUTS, loaded with the thinned
    samples and finally put into evaluation mode.

    Args:
        model: The SaasFullyBayesianSingleTaskGP or SaasFullyBayesianMultiTaskGP
            to be fitted.
        max_tree_depth: Maximum tree depth for NUTS
        warmup_steps: The number of burn-in steps for NUTS.
        num_samples: The number of MCMC samples. Note that with thinning,
            num_samples / thinning samples are retained.
        thinning: The amount of thinning. Every nth sample is retained.
        disable_progbar: A boolean indicating whether to disable the progress
            bar and diagnostics during MCMC.
        jit_compile: Whether to use jit. Using jit may be ~2X faster (rough estimate),
            but it will also increase the memory usage and sometimes result in runtime
            errors, e.g., https://github.com/pyro-ppl/pyro/issues/3136.

    Example:
        >>> gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
        >>> fit_fully_bayesian_model_nuts(gp)
    """
    model.train()

    # Do inference with NUTS
    kernel = NUTS(
        model.pyro_model.sample,
        jit_compile=jit_compile,
        full_mass=True,
        ignore_jit_warnings=True,
        max_tree_depth=max_tree_depth,
    )
    sampler = MCMC(
        kernel,
        warmup_steps=warmup_steps,
        num_samples=num_samples,
        disable_progbar=disable_progbar,
    )
    sampler.run()

    # Get final MCMC samples from the Pyro model
    raw_samples = sampler.get_samples()
    mcmc_samples = model.pyro_model.postprocess_mcmc_samples(mcmc_samples=raw_samples)

    # Thin in place: keep every `thinning`-th sample of each entry
    for name, draws in mcmc_samples.items():
        mcmc_samples[name] = draws[::thinning]

    # Load the MCMC samples back into the BoTorch model
    model.load_mcmc_samples(mcmc_samples)
    model.eval()