# Source code for botorch.optim.closures.model_closures

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Utilities for building model-based closures."""

from __future__ import annotations

from itertools import chain, repeat
from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple

from botorch.optim.closures.core import ForwardBackwardClosure
from botorch.optim.utils import TNone
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
from gpytorch.mlls import (
    ExactMarginalLogLikelihood,
    MarginalLogLikelihood,
    SumMarginalLogLikelihood,
)
from torch import Tensor
from torch.utils.data import DataLoader

# Dispatcher behind `get_loss_closure`; it is called with
# (mll, type(mll.likelihood), type(mll.model), data_loader).
GetLossClosure = Dispatcher(
    "get_loss_closure",
    encoder=type_bypassing_encoder,
)

# Dispatcher behind `get_loss_closure_with_grads`; called with the same four
# positional arguments as `GetLossClosure`, plus gradient-related kwargs.
GetLossClosureWithGrads = Dispatcher(
    "get_loss_closure_with_grads",
    encoder=type_bypassing_encoder,
)


def get_loss_closure(
    mll: MarginalLogLikelihood,
    data_loader: Optional[DataLoader] = None,
    **kwargs: Any,
) -> Callable[[], Tensor]:
    r"""Public API for GetLossClosure dispatcher.

    This method, and the dispatcher that powers it, acts as a clearing house
    for factory functions that define how `mll` is evaluated.

    Users may specify custom evaluation routines by registering a factory
    function with GetLossClosure. These factories should be registered using
    the type signature

        `Type[MarginalLogLikelihood], Type[Likelihood], Type[Model], Type[DataLoader]`.

    The final argument, Type[DataLoader], is optional. Evaluation routines
    that obtain training data from, e.g., `mll.model` should register this
    argument as `type(None)`.

    Args:
        mll: A MarginalLogLikelihood instance whose negative defines the loss.
        data_loader: An optional DataLoader instance for cases where training
            data is passed in rather than obtained from `mll.model`.
        **kwargs: Additional keyword arguments forwarded unchanged to the
            registered factory.

    Returns:
        A closure that takes zero positional arguments and returns the negated
        value of `mll`.
    """
    # The likelihood and model are handed to the dispatcher as *types*, while
    # `mll` and `data_loader` are passed as instances.
    return GetLossClosure(
        mll, type(mll.likelihood), type(mll.model), data_loader, **kwargs
    )
def get_loss_closure_with_grads(
    mll: MarginalLogLikelihood,
    parameters: Dict[str, Tensor],
    data_loader: Optional[DataLoader] = None,
    backward: Callable[[Tensor], None] = Tensor.backward,
    reducer: Optional[Callable[[Tensor], Tensor]] = Tensor.sum,
    context_manager: Optional[Callable] = None,
    **kwargs: Any,
) -> Callable[[], Tuple[Tensor, Tuple[Tensor, ...]]]:
    r"""Public API for GetLossClosureWithGrads dispatcher.

    In most cases, this method simply adds a backward pass to a loss closure
    obtained by calling `get_loss_closure`. For further details, see
    `get_loss_closure`.

    Args:
        mll: A MarginalLogLikelihood instance whose negative defines the loss.
        parameters: A dictionary of tensors whose `grad` fields are to be
            returned.
        data_loader: An optional DataLoader instance for cases where training
            data is passed in rather than obtained from `mll.model`.
        backward: Callable applied to the (reduced) loss to run the backward
            pass. Defaults to `Tensor.backward`.
        reducer: Optional callable used to reduce the output of the forward
            pass. Defaults to `Tensor.sum`.
        context_manager: An optional ContextManager used to wrap each
            forward-backward pass. It is forwarded unchanged to the registered
            factory. With the default fallback factory, `None` (the default)
            lets `ForwardBackwardClosure` choose its own default, which is
            expected to be a `zero_grad_ctx` that zeroes the gradients of
            `parameters` upon entry.
        **kwargs: Additional keyword arguments forwarded to the registered
            factory.

    Returns:
        A closure that takes zero positional arguments and returns the reduced
        and negated value of `mll` along with the gradients of `parameters`.
    """
    return GetLossClosureWithGrads(
        mll,
        type(mll.likelihood),
        type(mll.model),
        data_loader,
        parameters=parameters,
        reducer=reducer,
        backward=backward,
        context_manager=context_manager,
        **kwargs,
    )
@GetLossClosureWithGrads.register(object, object, object, object)
def _get_loss_closure_with_grads_fallback(
    mll: MarginalLogLikelihood,
    _: object,
    __: object,
    data_loader: Optional[DataLoader],
    parameters: Dict[str, Tensor],
    reducer: Optional[Callable[[Tensor], Tensor]] = Tensor.sum,
    backward: Callable[[Tensor], None] = Tensor.backward,
    context_manager: Optional[Callable] = None,
    **kwargs: Any,
) -> ForwardBackwardClosure:
    r"""Wraps a `loss_closure` with a ForwardBackwardClosure.

    Args:
        mll: A MarginalLogLikelihood instance whose negative defines the loss.
        _: The likelihood type (used only for dispatch).
        __: The model type (used only for dispatch).
        data_loader: An optional DataLoader supplying training batches.
        parameters: A dictionary of tensors whose gradients are returned.
        reducer: Optional callable used to reduce the forward-pass output.
        backward: Callable used to run the backward pass.
        context_manager: Forwarded unchanged to `ForwardBackwardClosure`.
        **kwargs: Forwarded to `get_loss_closure`.

    Returns:
        A ForwardBackwardClosure whose forward pass is the loss closure
        obtained from `get_loss_closure`.
    """
    loss_closure = get_loss_closure(mll, data_loader=data_loader, **kwargs)
    return ForwardBackwardClosure(
        forward=loss_closure,
        backward=backward,
        parameters=parameters,
        reducer=reducer,
        context_manager=context_manager,
    )


def _cycle_data_loader(data_loader: DataLoader) -> Iterator[Any]:
    r"""Yields batches from `data_loader` indefinitely.

    A fresh iterator over `data_loader` is created at the start of every pass,
    so per-epoch behavior of the loader (e.g. reshuffling) is preserved.

    Raises:
        ValueError: If a complete pass over `data_loader` yields no batches.
            Without this check, requesting the next batch would spin forever.
    """
    while True:
        pass_was_empty = True
        for batch in data_loader:
            pass_was_empty = False
            yield batch
        if pass_was_empty:
            raise ValueError("`data_loader` did not produce any batches.")


@GetLossClosure.register(MarginalLogLikelihood, object, object, DataLoader)
def _get_loss_closure_fallback_external(
    mll: MarginalLogLikelihood,
    _: object,
    __: object,
    data_loader: DataLoader,
    **ignore: Any,
) -> Callable[[], Tensor]:
    r"""Fallback loss closure with externally provided data.

    Each call of the returned closure draws the next batch from `data_loader`,
    restarting the loader once it is exhausted. The first
    `len(mll.model.train_inputs)` entries of the batch are fed to `mll.model`.
    The remaining entries are passed to `mll` after the model output.

    Raises (from the closure):
        TypeError: If a batch is not a `Sequence`.
        ValueError: If `data_loader` yields no batches at all.
    """
    batch_generator = _cycle_data_loader(data_loader)

    def closure(**kwargs: Any) -> Tensor:
        batch = next(batch_generator)
        if not isinstance(batch, Sequence):
            raise TypeError(
                "Expected `data_loader` to generate a batch of tensors, "
                f"but found {type(batch)}."
            )

        num_inputs = len(mll.model.train_inputs)
        model_output = mll.model(*batch[:num_inputs])
        log_likelihood = mll(model_output, *batch[num_inputs:], **kwargs)
        return -log_likelihood

    return closure


@GetLossClosure.register(MarginalLogLikelihood, object, object, TNone)
def _get_loss_closure_fallback_internal(
    mll: MarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any
) -> Callable[[], Tensor]:
    r"""Fallback loss closure with internally managed data.

    The returned closure evaluates `mll.model` on `mll.model.train_inputs`.
    It passes the output to `mll` together with `mll.model.train_targets` and
    returns the negated result.
    """

    def closure(**kwargs: Any) -> Tensor:
        model_output = mll.model(*mll.model.train_inputs)
        log_likelihood = mll(model_output, mll.model.train_targets, **kwargs)
        return -log_likelihood

    return closure


@GetLossClosure.register(ExactMarginalLogLikelihood, object, object, TNone)
def _get_loss_closure_exact_internal(
    mll: ExactMarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any
) -> Callable[[], Tensor]:
    r"""ExactMarginalLogLikelihood loss closure with internally managed data.

    Unlike the generic fallback, this closure also forwards
    `mll.model.train_inputs` to `mll` as extra positional arguments.
    """

    def closure(**kwargs: Any) -> Tensor:
        model_output = mll.model(*mll.model.train_inputs)
        log_likelihood = mll(
            model_output, mll.model.train_targets, *mll.model.train_inputs, **kwargs
        )
        return -log_likelihood

    return closure


@GetLossClosure.register(SumMarginalLogLikelihood, object, object, TNone)
def _get_loss_closure_sum_internal(
    mll: SumMarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any
) -> Callable[[], Tensor]:
    r"""SumMarginalLogLikelihood loss closure with internally managed data.

    Each entry of `mll.model.train_inputs` is converted to a `list` before
    being forwarded to `mll` as an extra positional argument.
    """

    def closure(**kwargs: Any) -> Tensor:
        model_output = mll.model(*mll.model.train_inputs)
        # NOTE(review): presumably one entry per sub-model; the list
        # conversion matches what SumMarginalLogLikelihood expects — confirm.
        log_likelihood = mll(
            model_output,
            mll.model.train_targets,
            *map(list, mll.model.train_inputs),
            **kwargs,
        )
        return -log_likelihood

    return closure