#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Utilities for building model-based closures."""
from __future__ import annotations
from itertools import chain, repeat
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
from botorch.optim.closures.core import ForwardBackwardClosure
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
from botorch.utils.types import NoneType
from gpytorch.mlls import (
ExactMarginalLogLikelihood,
MarginalLogLikelihood,
SumMarginalLogLikelihood,
)
from torch import Tensor
from torch.utils.data import DataLoader
# Dispatcher that powers `get_loss_closure`. Factories in this file register on
# a four-slot signature: (mll, likelihood type, model type, data loader).
# NOTE(review): `type_bypassing_encoder` presumably lets arguments that are
# already types (e.g. `type(mll.likelihood)`) be dispatched on as-is -- confirm
# against `botorch.utils.dispatcher`.
GetLossClosure = Dispatcher("get_loss_closure", encoder=type_bypassing_encoder)
# Dispatcher that powers `get_loss_closure_with_grads`; it is called with the
# same four positional slots plus keyword arguments describing the backward pass.
GetLossClosureWithGrads = Dispatcher(
"get_loss_closure_with_grads", encoder=type_bypassing_encoder
)
def get_loss_closure(
    mll: MarginalLogLikelihood,
    data_loader: Optional[DataLoader] = None,
    **kwargs: Any,
) -> Callable[[], Tensor]:
    r"""Public API for GetLossClosure dispatcher.

    This method, and the dispatcher that powers it, acts as a clearing house
    for factory functions that define how `mll` is evaluated.

    Users may specify custom evaluation routines by registering a factory function
    with GetLossClosure. These factories should be registered using the type signature

        `Type[MarginalLogLikelihood], Type[Likelihood], Type[Model], Type[DataLoader]`.

    The final argument, Type[DataLoader], is optional. Evaluation routines that obtain
    training data from, e.g., `mll.model` should register this argument as `type(None)`.

    Args:
        mll: A MarginalLogLikelihood instance whose negative defines the loss.
        data_loader: An optional DataLoader instance for cases where training
            data is passed in rather than obtained from `mll.model`.
        **kwargs: Additional keyword arguments forwarded unchanged to the
            dispatched factory function.

    Returns:
        A closure that takes zero positional arguments and returns the negated
        value of `mll`.
    """
    # Dispatch on the mll instance itself plus the *types* of its likelihood and
    # model; `data_loader` (possibly None) selects between externally provided
    # and internally managed training data.
    return GetLossClosure(
        mll, type(mll.likelihood), type(mll.model), data_loader, **kwargs
    )
def get_loss_closure_with_grads(
    mll: MarginalLogLikelihood,
    parameters: Dict[str, Tensor],
    data_loader: Optional[DataLoader] = None,
    backward: Callable[[Tensor], None] = Tensor.backward,
    reducer: Optional[Callable[[Tensor], Tensor]] = Tensor.sum,
    context_manager: Optional[Callable] = None,
    **kwargs: Any,
) -> Callable[[], Tuple[Tensor, Tuple[Tensor, ...]]]:
    r"""Public API for GetLossClosureWithGrads dispatcher.

    In most cases, this method simply adds a backward pass to a loss closure obtained by
    calling `get_loss_closure`. For further details, see `get_loss_closure`.

    Args:
        mll: A MarginalLogLikelihood instance whose negative defines the loss.
        parameters: A dictionary of tensors whose `grad` fields are to be returned.
        data_loader: An optional DataLoader instance for cases where training
            data is passed in rather than obtained from `mll.model`.
        backward: Callable used to run the backward pass on the (reduced) loss.
            Defaults to `Tensor.backward`.
        reducer: Optional callable used to reduce the output of the forward pass.
            Defaults to `Tensor.sum`.
        context_manager: An optional ContextManager used to wrap each forward-backward
            pass. The value, including the default `None`, is forwarded unchanged
            to the dispatched factory, which decides how `None` is interpreted.
            NOTE(review): earlier documentation described both a `zero_grad_ctx`
            default (zeroing the gradients of `parameters` upon entry) and `None`
            as an alias for `nullcontext`; confirm the actual behavior against
            `ForwardBackwardClosure`.
        **kwargs: Additional keyword arguments forwarded unchanged to the
            dispatched factory function.

    Returns:
        A closure that takes zero positional arguments and returns the reduced and
        negated value of `mll` along with the gradients of `parameters`.
    """
    # Same four dispatch slots as `get_loss_closure`; everything describing the
    # backward pass travels as keyword arguments.
    return GetLossClosureWithGrads(
        mll,
        type(mll.likelihood),
        type(mll.model),
        data_loader,
        parameters=parameters,
        reducer=reducer,
        backward=backward,
        context_manager=context_manager,
        **kwargs,
    )
@GetLossClosureWithGrads.register(object, object, object, object)
def _get_loss_closure_with_grads_fallback(
    mll: MarginalLogLikelihood,
    _: object,
    __: object,
    data_loader: Optional[DataLoader],
    parameters: Dict[str, Tensor],
    reducer: Callable[[Tensor], Tensor] = Tensor.sum,
    backward: Callable[[Tensor], None] = Tensor.backward,
    context_manager: Callable = None,  # pyre-ignore [9]
    **kwargs: Any,
) -> ForwardBackwardClosure:
    r"""Catch-all factory: pair a plain loss closure with a backward pass.

    The forward pass is whatever `get_loss_closure` returns for `mll` and
    `data_loader`; it is handed to a `ForwardBackwardClosure` together with the
    backward-pass settings.

    Args:
        mll: The MarginalLogLikelihood whose negative defines the loss.
        _: Unused dispatch slot (likelihood type).
        __: Unused dispatch slot (model type).
        data_loader: Optional DataLoader, forwarded to `get_loss_closure`.
        parameters: Dictionary of tensors passed on as `parameters`.
        reducer: Callable passed on as `reducer`.
        backward: Callable passed on as `backward`.
        context_manager: Passed on unchanged as `context_manager`.
        **kwargs: Extra keyword arguments forwarded to `get_loss_closure`.

    Returns:
        A `ForwardBackwardClosure` wrapping the loss closure.
    """
    return ForwardBackwardClosure(
        forward=get_loss_closure(mll, data_loader=data_loader, **kwargs),
        parameters=parameters,
        backward=backward,
        reducer=reducer,
        context_manager=context_manager,
    )
@GetLossClosure.register(MarginalLogLikelihood, object, object, DataLoader)
def _get_loss_closure_fallback_external(
    mll: MarginalLogLikelihood,
    _: object,
    __: object,
    data_loader: DataLoader,
    **ignore: Any,
) -> Callable[[], Tensor]:
    r"""Fallback loss closure with externally provided data.

    Each call of the returned closure consumes the next batch produced by
    `data_loader`; once the loader is exhausted it is iterated again from the
    start, so the closure can be called indefinitely. The first
    `len(mll.model.train_inputs)` entries of a batch are passed to the model and
    the remaining entries are passed to `mll` after the model output.

    Args:
        mll: The MarginalLogLikelihood whose negative defines the loss.
        _: Unused dispatch slot (likelihood type).
        __: Unused dispatch slot (model type).
        data_loader: The DataLoader that generates the training batches.
        **ignore: Unused keyword arguments, accepted for signature compatibility.

    Returns:
        A closure that returns the negated value of `mll` on the next batch.
        Keyword arguments given to the closure are forwarded to `mll`. The
        closure raises a TypeError if a batch is not a Sequence, and a
        ValueError if `data_loader` yields no batches at all (an endlessly
        restarted empty loader would otherwise never return).
    """
    exhausted = object()  # sentinel marking the end of a pass over the loader
    batch_iterator = None  # created lazily, on the first call of the closure

    def next_batch() -> Any:
        """Return the next batch, restarting `data_loader` when it runs out."""
        nonlocal batch_iterator
        if batch_iterator is not None:
            batch = next(batch_iterator, exhausted)
            if batch is not exhausted:
                return batch

        # Either the first call or the end of a pass: start a fresh pass.
        batch_iterator = iter(data_loader)
        batch = next(batch_iterator, exhausted)
        if batch is exhausted:
            # A fresh pass that yields nothing means the loader is empty;
            # restarting it again could never succeed.
            batch_iterator = None
            raise ValueError("`data_loader` did not yield any batches.")
        return batch

    def closure(**kwargs: Any) -> Tensor:
        batch = next_batch()
        if not isinstance(batch, Sequence):
            raise TypeError(
                "Expected `data_loader` to generate a batch of tensors, "
                f"but found {type(batch)}."
            )

        # Leading entries are model inputs; the rest go to `mll`.
        num_inputs = len(mll.model.train_inputs)
        model_output = mll.model(*batch[:num_inputs])
        log_likelihood = mll(model_output, *batch[num_inputs:], **kwargs)
        return -log_likelihood

    return closure
@GetLossClosure.register(MarginalLogLikelihood, object, object, NoneType)
def _get_loss_closure_fallback_internal(
    mll: MarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
) -> Callable[[], Tensor]:
    r"""Fallback loss closure with internally managed data.

    The training inputs and targets are read from `mll.model` on every call of
    the returned closure; no DataLoader is involved.

    Args:
        mll: The MarginalLogLikelihood whose negative defines the loss.
        _: Unused dispatch slot (likelihood type).
        __: Unused dispatch slot (model type).
        ___: Unused dispatch slot (always None: no DataLoader was given).
        **ignore: Unused keyword arguments, accepted for signature compatibility.

    Returns:
        A closure that returns the negated value of `mll`. Keyword arguments
        given to the closure are forwarded to `mll`.
    """

    def closure(**kwargs: Any) -> Tensor:
        model = mll.model
        prediction = model(*model.train_inputs)
        return -mll(prediction, model.train_targets, **kwargs)

    return closure
@GetLossClosure.register(ExactMarginalLogLikelihood, object, object, NoneType)
def _get_loss_closure_exact_internal(
    mll: ExactMarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
) -> Callable[[], Tensor]:
    r"""ExactMarginalLogLikelihood loss closure with internally managed data.

    Differs from the generic internal fallback in that the model's training
    inputs are additionally passed to `mll` as trailing positional arguments.
    NOTE(review): these extra arguments are presumably consumed by the
    likelihood -- confirm against gpytorch's ExactMarginalLogLikelihood.

    Args:
        mll: The ExactMarginalLogLikelihood whose negative defines the loss.
        _: Unused dispatch slot (likelihood type).
        __: Unused dispatch slot (model type).
        ___: Unused dispatch slot (always None: no DataLoader was given).
        **ignore: Unused keyword arguments, accepted for signature compatibility.

    Returns:
        A closure that returns the negated value of `mll`. Keyword arguments
        given to the closure are forwarded to `mll`.
    """

    def closure(**kwargs: Any) -> Tensor:
        model = mll.model
        prediction = model(*model.train_inputs)
        value = mll(prediction, model.train_targets, *model.train_inputs, **kwargs)
        return -value

    return closure
@GetLossClosure.register(SumMarginalLogLikelihood, object, object, NoneType)
def _get_loss_closure_sum_internal(
    mll: SumMarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
) -> Callable[[], Tensor]:
    r"""SumMarginalLogLikelihood loss closure with internally managed data.

    Like the exact variant, the training inputs are passed to `mll` as trailing
    positional arguments, but each element of `mll.model.train_inputs` is first
    converted to a list.
    NOTE(review): presumably each element holds the inputs of one sub-model --
    confirm against gpytorch's SumMarginalLogLikelihood.

    Args:
        mll: The SumMarginalLogLikelihood whose negative defines the loss.
        _: Unused dispatch slot (likelihood type).
        __: Unused dispatch slot (model type).
        ___: Unused dispatch slot (always None: no DataLoader was given).
        **ignore: Unused keyword arguments, accepted for signature compatibility.

    Returns:
        A closure that returns the negated value of `mll`. Keyword arguments
        given to the closure are forwarded to `mll`.
    """

    def closure(**kwargs: Any) -> Tensor:
        model = mll.model
        prediction = model(*model.train_inputs)
        inputs_as_lists = [list(inputs) for inputs in model.train_inputs]
        return -mll(prediction, model.train_targets, *inputs_as_lists, **kwargs)

    return closure