#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Utilities for building model-based closures."""
from __future__ import annotations
from collections.abc import Sequence
from itertools import chain, repeat
from types import NoneType
from typing import Any, Callable, Optional
from botorch.optim.closures.core import ForwardBackwardClosure
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
from gpytorch.mlls import (
ExactMarginalLogLikelihood,
MarginalLogLikelihood,
SumMarginalLogLikelihood,
)
from torch import Tensor
from torch.utils.data import DataLoader
GetLossClosure = Dispatcher("get_loss_closure", encoder=type_bypassing_encoder)
GetLossClosureWithGrads = Dispatcher(
"get_loss_closure_with_grads", encoder=type_bypassing_encoder
)
[docs]
def get_loss_closure(
mll: MarginalLogLikelihood,
data_loader: Optional[DataLoader] = None,
**kwargs: Any,
) -> Callable[[], Tensor]:
r"""Public API for GetLossClosure dispatcher.
This method, and the dispatcher that powers it, acts as a clearing house
for factory functions that define how `mll` is evaluated.
Users may specify custom evaluation routines by registering a factory function
with GetLossClosure. These factories should be registered using the type signature
`Type[MarginalLogLikeLihood], Type[Likelihood], Type[Model], Type[DataLoader]`.
The final argument, Type[DataLoader], is optional. Evaluation routines that obtain
training data from, e.g., `mll.model` should register this argument as `type(None)`.
Args:
mll: A MarginalLogLikelihood instance whose negative defines the loss.
data_loader: An optional DataLoader instance for cases where training
data is passed in rather than obtained from `mll.model`.
Returns:
A closure that takes zero positional arguments and returns the negated
value of `mll`.
"""
return GetLossClosure(
mll, type(mll.likelihood), type(mll.model), data_loader, **kwargs
)
[docs]
def get_loss_closure_with_grads(
mll: MarginalLogLikelihood,
parameters: dict[str, Tensor],
data_loader: Optional[DataLoader] = None,
backward: Callable[[Tensor], None] = Tensor.backward,
reducer: Optional[Callable[[Tensor], Tensor]] = Tensor.sum,
context_manager: Optional[Callable] = None,
**kwargs: Any,
) -> Callable[[], tuple[Tensor, tuple[Tensor, ...]]]:
r"""Public API for GetLossClosureWithGrads dispatcher.
In most cases, this method simply adds a backward pass to a loss closure obtained by
calling `get_loss_closure`. For further details, see `get_loss_closure`.
Args:
mll: A MarginalLogLikelihood instance whose negative defines the loss.
parameters: A dictionary of tensors whose `grad` fields are to be returned.
reducer: Optional callable used to reduce the output of the forward pass.
data_loader: An optional DataLoader instance for cases where training
data is passed in rather than obtained from `mll.model`.
context_manager: An optional ContextManager used to wrap each forward-backward
pass. Defaults to a `zero_grad_ctx` that zeroes the gradients of
`parameters` upon entry. None may be passed as an alias for `nullcontext`.
Returns:
A closure that takes zero positional arguments and returns the reduced and
negated value of `mll` along with the gradients of `parameters`.
"""
return GetLossClosureWithGrads(
mll,
type(mll.likelihood),
type(mll.model),
data_loader,
parameters=parameters,
reducer=reducer,
backward=backward,
context_manager=context_manager,
**kwargs,
)
@GetLossClosureWithGrads.register(object, object, object, object)
def _get_loss_closure_with_grads_fallback(
mll: MarginalLogLikelihood,
_likelihood_type: object,
_model_type: object,
data_loader: Optional[DataLoader],
parameters: dict[str, Tensor],
reducer: Callable[[Tensor], Tensor] = Tensor.sum,
backward: Callable[[Tensor], None] = Tensor.backward,
context_manager: Callable = None, # pyre-ignore [9]
**kwargs: Any,
) -> ForwardBackwardClosure:
r"""Wraps a `loss_closure` with a ForwardBackwardClosure."""
loss_closure = get_loss_closure(mll, data_loader=data_loader, **kwargs)
return ForwardBackwardClosure(
forward=loss_closure,
backward=backward,
parameters=parameters,
reducer=reducer,
context_manager=context_manager,
)
@GetLossClosure.register(MarginalLogLikelihood, object, object, DataLoader)
def _get_loss_closure_fallback_external(
mll: MarginalLogLikelihood,
_likelihood_type: object,
_model_type: object,
data_loader: DataLoader,
**ignore: Any,
) -> Callable[[], Tensor]:
r"""Fallback loss closure with externally provided data."""
batch_generator = chain.from_iterable(iter(data_loader) for _ in repeat(None))
def closure(**kwargs: Any) -> Tensor:
batch = next(batch_generator)
if not isinstance(batch, Sequence):
raise TypeError(
"Expected `data_loader` to generate a batch of tensors, "
f"but found {type(batch)}."
)
num_inputs = len(mll.model.train_inputs)
model_output = mll.model(*batch[:num_inputs])
log_likelihood = mll(model_output, *batch[num_inputs:], **kwargs)
return -log_likelihood
return closure
@GetLossClosure.register(MarginalLogLikelihood, object, object, NoneType)
def _get_loss_closure_fallback_internal(
mll: MarginalLogLikelihood, _: object, __: object, ___: None, **ignore: Any
) -> Callable[[], Tensor]:
r"""Fallback loss closure with internally managed data."""
def closure(**kwargs: Any) -> Tensor:
model_output = mll.model(*mll.model.train_inputs)
log_likelihood = mll(model_output, mll.model.train_targets, **kwargs)
return -log_likelihood
return closure
@GetLossClosure.register(ExactMarginalLogLikelihood, object, object, NoneType)
def _get_loss_closure_exact_internal(
mll: ExactMarginalLogLikelihood, _: object, __: object, ___: None, **ignore: Any
) -> Callable[[], Tensor]:
r"""ExactMarginalLogLikelihood loss closure with internally managed data."""
def closure(**kwargs: Any) -> Tensor:
model = mll.model
# The inputs will get transformed in forward here.
model_output = model(*model.train_inputs)
log_likelihood = mll(
model_output,
model.train_targets,
# During model training, the model inputs get transformed in the forward
# pass. The train_inputs property is not transformed yet, so we need to
# transform it before passing it to the likelihood for consistency.
*(model.transform_inputs(X=t_in) for t_in in model.train_inputs),
**kwargs,
)
return -log_likelihood
return closure
@GetLossClosure.register(SumMarginalLogLikelihood, object, object, NoneType)
def _get_loss_closure_sum_internal(
mll: SumMarginalLogLikelihood, _: object, __: object, ___: None, **ignore: Any
) -> Callable[[], Tensor]:
r"""SumMarginalLogLikelihood loss closure with internally managed data."""
def closure(**kwargs: Any) -> Tensor:
model = mll.model
# The inputs will get transformed in forward here.
model_output = model(*model.train_inputs)
log_likelihood = mll(
model_output,
model.train_targets,
# During model training, the model inputs get transformed in the forward
# pass. The train_inputs property is not transformed yet, so we need to
# transform it before passing it to the likelihood for consistency.
*(
(model.transform_inputs(X=t_in) for t_in in sub_t_in)
for sub_t_in in model.train_inputs
),
**kwargs,
)
return -log_likelihood
return closure