# Source code for botorch.optim.closures.core

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Core methods for building closures in torch and interfacing with numpy."""

from __future__ import annotations

from collections.abc import Callable, Sequence

from functools import partial
from typing import Any

import numpy.typing as npt

import torch
from botorch.optim.utils import (
    _handle_numerical_errors,
    get_tensors_as_ndarray_1d,
    set_tensors_from_ndarray_1d,
)
from botorch.optim.utils.numpy_utils import as_ndarray
from botorch.utils.context_managers import zero_grad_ctx
from numpy import float64 as np_float64, full as np_full, zeros as np_zeros
from torch import Tensor


class ForwardBackwardClosure:
    r"""Wrapper for fused forward and backward closures.

    Each call runs `forward`, optionally reduces its output with `reducer`, calls
    `backward` on the (reduced) value, and returns that value together with the
    `grad` fields of `parameters`. The whole sequence executes inside
    `context_manager()`.
    """

    def __init__(
        self,
        forward: Callable[..., Tensor],
        parameters: dict[str, Tensor],
        backward: Callable[[Tensor], None] = Tensor.backward,
        reducer: Callable[[Tensor], Tensor] | None = torch.sum,
        callback: Callable[[Tensor, Sequence[Tensor | None]], None] | None = None,
        context_manager: Callable | None = None,
    ) -> None:
        r"""Initializes a ForwardBackwardClosure instance.

        Args:
            forward: Callable that returns a tensor. Keyword arguments passed to
                `__call__` are forwarded to it.
            parameters: A dictionary of tensors whose `grad` fields are to be
                returned.
            backward: Callable that takes the (reduced) output of `forward` and
                sets the `grad` attributes of tensors in `parameters`.
            reducer: Optional callable used to reduce the output of the forward
                pass. When `None`, the output of `forward` is used as-is.
            callback: Optional callable that takes the reduced output of `forward`
                and the gradients of `parameters` as positional arguments.
            context_manager: A ContextManager used to wrap each forward-backward
                call. When passed as `None`, `context_manager` defaults to a
                `zero_grad_ctx` that zeroes the gradients of `parameters` upon
                entry.
        """
        if context_manager is None:
            context_manager = partial(zero_grad_ctx, parameters)

        self.forward = forward
        self.backward = backward
        self.parameters = parameters
        self.reducer = reducer
        self.callback = callback
        self.context_manager = context_manager

    def __call__(self, **kwargs: Any) -> tuple[Tensor, tuple[Tensor | None, ...]]:
        r"""Runs one fused forward-backward pass.

        Args:
            **kwargs: Keyword arguments passed through to `forward`.

        Returns:
            A tuple `(value, grads)`, where `value` is the (reduced) output of
            `forward` and `grads` contains the `grad` attribute (possibly `None`)
            of each tensor in `parameters`, in dictionary order.
        """
        with self.context_manager():
            values = self.forward(**kwargs)
            value = values if self.reducer is None else self.reducer(values)
            self.backward(value)

            grads = tuple(param.grad for param in self.parameters.values())
            # Test against None instead of truthiness: a callable object that
            # defines `__bool__`/`__len__` must still be invoked.
            if self.callback is not None:
                self.callback(value, grads)

            return value, grads
class NdarrayOptimizationClosure:
    r"""Adds stateful behavior and a numpy.ndarray-typed API to a closure with an
    expected return type Tuple[Tensor, Union[Tensor, Sequence[Optional[Tensor]]]].

    Calling an instance optionally loads a new state, evaluates the wrapped
    closure, and returns its value and a flattened float64 gradient vector as
    ndarrays.
    """

    def __init__(
        self,
        closure: Callable[..., tuple[Tensor, Sequence[Tensor | None]]],
        parameters: dict[str, Tensor],
        as_array: Callable[[Tensor], npt.NDArray] = None,  # pyre-ignore [9]
        get_state: Callable[[], npt.NDArray] = None,  # pyre-ignore [9]
        set_state: Callable[[npt.NDArray], None] = None,  # pyre-ignore [9]
        fill_value: float = 0.0,
        persistent: bool = True,
    ) -> None:
        r"""Initializes a NdarrayOptimizationClosure instance.

        Args:
            closure: Callable returning `(value, grads)`, e.g. a
                ForwardBackwardClosure instance. Keyword arguments passed to
                `__call__` are forwarded to it.
            parameters: A dictionary of tensors representing the closure's state.
                Expected to correspond with the first `len(parameters)` optional
                gradient tensors returned by `closure`.
            as_array: Callable used to convert tensors to ndarrays. When passed as
                `None`, defaults to `as_ndarray` with `dtype=float64`.
            get_state: Callable that returns the closure's state as an ndarray.
                When passed as `None`, defaults to calling
                `get_tensors_as_ndarray_1d` on `parameters` while passing
                `as_array` (if given by the user).
            set_state: Callable that takes a 1-dimensional ndarray and sets the
                closure's state. When passed as `None`, `set_state` defaults to
                calling `set_tensors_from_ndarray_1d` with `parameters` and a
                given ndarray.
            fill_value: Fill value for parameters whose gradients are None. In
                most cases, `fill_value` should either be zero or NaN.
            persistent: Boolean specifying whether an ndarray should be retained
                as a persistent buffer for gradients.
        """
        if get_state is None:
            # Note: Numpy supports copying data between ndarrays with different
            # dtypes. Hence, our default behavior need not coerce the ndarray
            # representations of tensors in `parameters` to float64 when copying
            # over data.
            _as_array = as_ndarray if as_array is None else as_array
            get_state = partial(
                get_tensors_as_ndarray_1d,
                tensors=parameters,
                dtype=np_float64,
                as_array=_as_array,
            )

        if as_array is None:  # per the note, do this after resolving `get_state`
            as_array = partial(as_ndarray, dtype=np_float64)

        if set_state is None:
            set_state = partial(set_tensors_from_ndarray_1d, parameters)

        self.closure = closure
        self.parameters = parameters

        # Keep the resolved converter: either the user-supplied `as_array` or the
        # float64-coercing default constructed above.
        self.as_array = as_array
        self._get_state = get_state
        self._set_state = set_state

        self.fill_value = fill_value
        self.persistent = persistent
        self._gradient_ndarray: npt.NDArray | None = None

    def __call__(
        self, state: npt.NDArray | None = None, **kwargs: Any
    ) -> tuple[npt.NDArray, npt.NDArray]:
        r"""Evaluates the closure and returns ndarray-typed results.

        Args:
            state: Optional 1-d ndarray; when given, it is loaded via `set_state`
                before the closure runs.
            **kwargs: Keyword arguments passed through to `closure`.

        Returns:
            A tuple `(value, grads)`. `value` is `closure`'s value converted by
            `as_array`. `grads` is a 1-d float64 ndarray of length equal to the
            total number of elements in `parameters`, with gradient entries that
            are `None` left at `fill_value`. When `persistent` is True, the same
            `grads` buffer is reused (and overwritten) across calls.

            If a RuntimeError is raised while evaluating or converting, it is
            handed to `_handle_numerical_errors` along with the current state,
            and that helper's `(value, grads)` result is returned instead.
        """
        if state is not None:
            self.state = state

        try:
            value_tensor, grad_tensors = self.closure(**kwargs)
            value = self.as_array(value_tensor)
            grads = self._get_gradient_ndarray(fill_value=self.fill_value)
            index = 0
            for param, grad in zip(self.parameters.values(), grad_tensors):
                size = param.numel()
                if grad is not None:
                    # `reshape` (unlike `view`) also handles non-contiguous grads.
                    grads[index : index + size] = self.as_array(grad.reshape(-1))
                index += size
        except RuntimeError as e:
            value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)

        return value, grads

    @property
    def state(self) -> npt.NDArray:
        r"""The closure's current state, as returned by `get_state`."""
        return self._get_state()

    @state.setter
    def state(self, state: npt.NDArray) -> None:
        self._set_state(state)

    def _get_gradient_ndarray(self, fill_value: float | None = None) -> npt.NDArray:
        r"""Returns a 1-d float64 buffer sized to hold all parameter gradients.

        If `persistent` is True and a buffer already exists, that buffer is
        reused and refilled with `fill_value` (unless it is None). Otherwise a
        new buffer is allocated and, when `persistent`, cached for reuse.
        """
        if self.persistent and self._gradient_ndarray is not None:
            if fill_value is not None:
                self._gradient_ndarray.fill(fill_value)
            return self._gradient_ndarray

        size = sum(param.numel() for param in self.parameters.values())
        array = (
            np_zeros(size, dtype=np_float64)
            if fill_value is None or fill_value == 0.0
            else np_full(size, fill_value, dtype=np_float64)
        )
        if self.persistent:
            self._gradient_ndarray = array

        return array