# Source code for botorch.optim.fit

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Tools for model fitting."""

from __future__ import annotations

from collections.abc import Sequence

from functools import partial
from typing import Any, Callable, Optional, Union
from warnings import warn

from botorch.exceptions.warnings import OptimizationWarning
from botorch.optim.closures import get_loss_closure_with_grads
from botorch.optim.core import (
    OptimizationResult,
    OptimizationStatus,
    scipy_minimize,
    torch_minimize,
)
from botorch.optim.stopping import ExpMAStoppingCriterion
from botorch.optim.utils import get_parameters_and_bounds, TorchAttr
from botorch.utils.types import DEFAULT
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from numpy import ndarray
from torch import Tensor
from torch.nn import Module
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

# Type aliases for the model-fitting utilities.

# Maps a parameter name to a `(lower, upper)` pair of optional float bounds.
# NOTE(review): `None` presumably means "unbounded on that side" — confirm against
# `get_parameters_and_bounds`.
TBoundsDict = dict[str, tuple[Optional[float], Optional[float]]]
# Callable taking an ndarray, an MLL, and a name -> `TorchAttr` dict, and returning a
# `(float, ndarray)` pair (presumably objective value and gradient — TODO confirm).
TScipyObjective = Callable[
    [ndarray, MarginalLogLikelihood, dict[str, TorchAttr]], tuple[float, ndarray]
]
# Callable taking a Module, optional bounds, and an optional set of names, and
# returning an ndarray, a name -> `TorchAttr` dict, and an optional ndarray
# (presumably flattened values / metadata / bounds — verify against implementations).
TModToArray = Callable[
    [Module, Optional[TBoundsDict], Optional[set[str]]],
    tuple[ndarray, dict[str, TorchAttr], Optional[ndarray]],
]
# Callable taking a Module, an ndarray, and a name -> `TorchAttr` dict, and returning
# a Module (presumably the inverse direction of `TModToArray` — confirm).
TArrayToMod = Callable[[Module, ndarray, dict[str, TorchAttr]], Module]


def fit_gpytorch_mll_scipy(
    mll: MarginalLogLikelihood,
    parameters: Optional[dict[str, Tensor]] = None,
    bounds: Optional[dict[str, tuple[Optional[float], Optional[float]]]] = None,
    closure: Optional[Callable[[], tuple[Tensor, Sequence[Optional[Tensor]]]]] = None,
    closure_kwargs: Optional[dict[str, Any]] = None,
    method: str = "L-BFGS-B",
    options: Optional[dict[str, Any]] = None,
    callback: Optional[Callable[[dict[str, Tensor], OptimizationResult], None]] = None,
    timeout_sec: Optional[float] = None,
) -> OptimizationResult:
    r"""Generic scipy.optimize-based fitting routine for GPyTorch MLLs.

    The model and likelihood in mll must already be in train mode.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        parameters: Optional dictionary of parameters to be optimized. Defaults
            to all parameters of `mll` that require gradients.
        bounds: A dictionary of user-specified bounds for `parameters`. Used to update
            default parameter bounds obtained from `mll`; user-specified entries take
            precedence over the defaults.
        closure: Callable that returns a tensor and an iterable of gradient tensors.
            Responsible for setting the `grad` attributes of `parameters`. If no closure
            is provided, one will be obtained by calling `get_loss_closure_with_grads`.
        closure_kwargs: Keyword arguments passed to `closure`.
        method: Solver type, passed along to scipy.minimize.
        options: Dictionary of solver options, passed along to scipy.minimize.
        callback: Optional callback taking `parameters` and an OptimizationResult as its
            sole arguments.
        timeout_sec: Timeout in seconds after which to terminate the fitting loop
            (note that timing out can result in bad fits!).

    Returns:
        The final OptimizationResult.

    Warns:
        OptimizationWarning: If the returned result's status is not
            `OptimizationStatus.SUCCESS`; the warning includes the solver's message.
    """
    # Resolve `parameters` and update default bounds. Merging with `{**a, **b}` lets
    # user-supplied bounds override the defaults reported for `mll`.
    _parameters, _bounds = get_parameters_and_bounds(mll)
    bounds = _bounds if bounds is None else {**_bounds, **bounds}
    if parameters is None:
        parameters = {n: p for n, p in _parameters.items() if p.requires_grad}

    if closure is None:
        closure = get_loss_closure_with_grads(mll, parameters=parameters)

    # Bind the extra keyword arguments onto the closure.
    if closure_kwargs is not None:
        closure = partial(closure, **closure_kwargs)

    result = scipy_minimize(
        closure=closure,
        parameters=parameters,
        bounds=bounds,
        method=method,
        options=options,
        callback=callback,
        timeout_sec=timeout_sec,
    )
    # Surface non-successful terminations (including timeouts) to the caller
    # without raising; the result is still returned.
    if result.status != OptimizationStatus.SUCCESS:
        warn(
            f"`scipy_minimize` terminated with status {result.status}, displaying"
            f" original message from `scipy.optimize.minimize`: {result.message}",
            OptimizationWarning,
        )
    return result
def fit_gpytorch_mll_torch(
    mll: MarginalLogLikelihood,
    parameters: Optional[dict[str, Tensor]] = None,
    bounds: Optional[dict[str, tuple[Optional[float], Optional[float]]]] = None,
    closure: Optional[Callable[[], tuple[Tensor, Sequence[Optional[Tensor]]]]] = None,
    closure_kwargs: Optional[dict[str, Any]] = None,
    step_limit: Optional[int] = None,
    stopping_criterion: Optional[Callable[[Tensor], bool]] = DEFAULT,  # pyre-ignore [9]
    optimizer: Union[Optimizer, Callable[..., Optimizer]] = Adam,
    scheduler: Optional[Union[_LRScheduler, Callable[..., _LRScheduler]]] = None,
    callback: Optional[Callable[[dict[str, Tensor], OptimizationResult], None]] = None,
    timeout_sec: Optional[float] = None,
) -> OptimizationResult:
    r"""Generic torch.optim-based fitting routine for GPyTorch MLLs.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        parameters: Optional dictionary of parameters to be optimized. Defaults
            to all parameters of `mll` that require gradients.
        bounds: A dictionary of user-specified bounds for `parameters`. Used to update
            default parameter bounds obtained from `mll`; user-specified entries take
            precedence over the defaults.
        closure: Callable that returns a tensor and an iterable of gradient tensors.
            Responsible for setting the `grad` attributes of `parameters`. If no closure
            is provided, one will be obtained by calling `get_loss_closure_with_grads`.
        closure_kwargs: Keyword arguments passed to `closure`.
        step_limit: Optional upper bound on the number of optimization steps.
        stopping_criterion: A StoppingCriterion for the optimization loop. If not
            specified, a default-constructed `ExpMAStoppingCriterion()` is used; an
            explicit `None` is forwarded to `torch_minimize` unchanged.
        optimizer: A `torch.optim.Optimizer` instance or a factory that takes
            a list of parameters and returns an `Optimizer` instance.
        scheduler: A `torch.optim.lr_scheduler._LRScheduler` instance or a factory
            that takes an `Optimizer` instance and returns an `_LRScheduler`.
        callback: Optional callback taking `parameters` and an OptimizationResult as its
            sole arguments.
        timeout_sec: Timeout in seconds after which to terminate the fitting loop
            (note that timing out can result in bad fits!).

    Returns:
        The final OptimizationResult, as returned by `torch_minimize`.
    """
    # `DEFAULT` is a sentinel meaning "argument not passed"; compare by identity so a
    # user-supplied criterion's `__eq__` is never invoked.
    if stopping_criterion is DEFAULT:
        stopping_criterion = ExpMAStoppingCriterion()

    # Resolve `parameters` and update default bounds.
    param_dict, bounds_dict = get_parameters_and_bounds(mll)
    if parameters is None:
        parameters = {n: p for n, p in param_dict.items() if p.requires_grad}

    if closure is None:
        closure = get_loss_closure_with_grads(mll, parameters=parameters)

    # Bind the extra keyword arguments onto the closure.
    if closure_kwargs is not None:
        closure = partial(closure, **closure_kwargs)

    # User-supplied bounds override the defaults reported for `mll`.
    merged_bounds = bounds_dict if bounds is None else {**bounds_dict, **bounds}

    return torch_minimize(
        closure=closure,
        parameters=parameters,
        bounds=merged_bounds,
        optimizer=optimizer,
        scheduler=scheduler,
        step_limit=step_limit,
        stopping_criterion=stopping_criterion,
        callback=callback,
        timeout_sec=timeout_sec,
    )