Source code for botorch.fit
#!/usr/bin/env python3
r"""
Utilities for model fitting.
"""
from typing import Any, Callable
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from .optim.fit import fit_gpytorch_scipy
def fit_gpytorch_model(
    mll: MarginalLogLikelihood, optimizer: Callable = fit_gpytorch_scipy, **kwargs: Any
) -> MarginalLogLikelihood:
    r"""Optimize the hyperparameters of a gpytorch model through its MLL.

    The MLL is switched to training mode, handed to `optimizer`, and the
    object the optimizer gives back is switched to evaluation mode before
    being returned. Optimizer functions are in botorch.optim.fit.

    Args:
        mll: MarginalLogLikelihood to be maximized.
        optimizer: The optimizer function. It is invoked as
            `optimizer(mll, track_iterations=False, **kwargs)` and must
            return a two-element result whose first element is the fitted
            MarginalLogLikelihood; the second element is discarded.
        kwargs: Arguments passed along to the optimizer function. Do not
            include `track_iterations`: it is always supplied as `False`
            here, so passing it again raises a `TypeError` (duplicate
            keyword argument).

    Returns:
        The MarginalLogLikelihood returned by the optimizer, in eval mode.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        >>> fit_gpytorch_model(mll)
    """
    # Training mode must be active while the optimizer runs.
    mll.train()
    fit_result = optimizer(mll, track_iterations=False, **kwargs)
    # Only the fitted MLL is of interest; the second element is ignored.
    fitted_mll, _ = fit_result
    fitted_mll.eval()
    return fitted_mll