Source code for botorch.utils.context_managers

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Utilities for optimization.
"""

from __future__ import annotations

from collections.abc import Generator, Iterable

from contextlib import contextmanager
from typing import Any, Callable, NamedTuple, Optional, Union

from torch import device as Device, dtype as Dtype, Tensor
from torch.nn import Module


class TensorCheckpoint(NamedTuple):
    r"""Saved copy of a tensor's values together with restore metadata.

    Instances unpack positionally as ``(values, device, dtype)``; the rollback
    context managers cast ``values`` back via
    ``values.to(device=device, dtype=dtype)`` when restoring.
    """

    # Snapshot of the tensor data to roll back to.
    values: Tensor
    # Device recorded for the original tensor (optional).
    device: Optional[Device] = None
    # Dtype recorded for the original tensor (optional).
    dtype: Optional[Dtype] = None
@contextmanager
def delattr_ctx(
    instance: object, *attrs: str, enforce_hasattr: bool = False
) -> Generator[None, None, None]:
    r"""Contextmanager for temporarily deleting attributes.

    Each named attribute present on ``instance`` is deleted on entry and set
    back to its previous value on exit (also when the body raises).

    Args:
        instance: The object whose attributes are temporarily deleted.
        *attrs: Names of the attributes to delete for the duration of the context.
        enforce_hasattr: If True, raise a ``ValueError`` when one of ``attrs`` is
            not present on ``instance``. If False, missing attributes are skipped.

    Raises:
        ValueError: If ``enforce_hasattr`` is True and an attribute is missing.
            Attributes already deleted at that point are restored first.
    """
    cache: dict[str, Any] = {}
    try:
        for key in attrs:
            if hasattr(instance, key):
                value = getattr(instance, key)
                delattr(instance, key)
                # Record the value only after deletion succeeded, so a failed
                # `delattr` (e.g. the name resolves to a class-level attribute)
                # is not "restored" on exit as a new instance attribute.
                cache[key] = value
            elif enforce_hasattr:
                raise ValueError(
                    f"Attribute {key} missing from {type(instance)} instance."
                )
        yield
    finally:
        for key, cached_val in cache.items():
            setattr(instance, key, cached_val)
@contextmanager
def parameter_rollback_ctx(
    parameters: dict[str, Tensor],
    checkpoint: Optional[dict[str, TensorCheckpoint]] = None,
    **tkwargs: Any,
) -> Generator[dict[str, TensorCheckpoint], None, None]:
    r"""Contextmanager that exits by rolling back a set of named tensors.

    Args:
        parameters: Dictionary mapping names to the tensors whose values should be
            restored on exit. Restoration writes in place via ``param.data.copy_``.
        checkpoint: Optional cache of values and tensor metadata, keyed by the
            names in ``parameters``, specifying the rollback state (for all of the
            tensors or some subset thereof). Names absent from the checkpoint are
            left untouched on exit.
        **tkwargs: Keyword arguments passed to ``torch.Tensor.to`` when copying
            data from each tensor in ``parameters`` to the internally created
            checkpoint. Only adhered to when the ``checkpoint`` argument is None.

    Yields:
        A dictionary of TensorCheckpoints for ``parameters``. Any in-place changes
        to the checkpoint will be observed at rollback time. If the checkpoint is
        cleared, no rollback will occur.
    """
    # Create copies of the original values
    if checkpoint is None:
        checkpoint = {
            name: TensorCheckpoint(
                values=param.detach().to(**tkwargs).clone(),
                device=param.device,
                dtype=param.dtype,
            )
            for name, param in parameters.items()
        }

    try:
        # yield the checkpoint dictionary to the user
        yield checkpoint
    finally:
        # restore original values of tracked parameters; an empty checkpoint
        # disables the rollback entirely
        if checkpoint:
            for name, param in parameters.items():
                if name in checkpoint:
                    values, device, dtype = checkpoint[name]
                    # cast back to the recorded device/dtype before copying in place
                    param.data.copy_(values.to(device=device, dtype=dtype))
@contextmanager
def module_rollback_ctx(
    module: Module,
    name_filter: Optional[Callable[[str], bool]] = None,
    checkpoint: Optional[dict[str, TensorCheckpoint]] = None,
    **tkwargs: Any,
) -> Generator[dict[str, TensorCheckpoint], None, None]:
    r"""Contextmanager that exits by rolling back a module's state_dict.

    Args:
        module: Module whose ``state_dict`` entries are restored on exit.
        name_filter: Optional predicate on ``state_dict`` keys; when given, only
            keys for which it returns True are captured in the internally built
            checkpoint.
        checkpoint: Optional cache of values and tensor metadata specifying the
            rollback state for the module (or some subset thereof).
        **tkwargs: Keyword arguments passed to ``torch.Tensor.to`` when copying
            data from each tensor in ``module.state_dict()`` to the internally
            created checkpoint. Only adhered to when ``checkpoint`` is None.

    Yields:
        A dictionary of TensorCheckpoints for the module's state_dict. In-place
        edits to this dictionary are honored at rollback time; clearing it
        disables the rollback.
    """
    if checkpoint is None:
        # Snapshot every (filtered) state_dict entry along with its metadata.
        checkpoint = {}
        for key, tensor in module.state_dict().items():
            if name_filter is not None and not name_filter(key):
                continue
            checkpoint[key] = TensorCheckpoint(
                values=tensor.detach().to(**tkwargs).clone(),
                device=tensor.device,
                dtype=tensor.dtype,
            )

    try:
        yield checkpoint
    finally:
        if checkpoint:
            current = module.state_dict()
            for key, (saved, device, dtype) in checkpoint.items():
                restored = saved.to(device=device, dtype=dtype)
                target = current.get(key)
                if target is not None:
                    # Existing entry: overwrite its contents in place.
                    target[...] = restored
                else:
                    # Entry not present in the current state_dict: add it.
                    current[key] = restored
            module.load_state_dict(current)
@contextmanager
def zero_grad_ctx(
    parameters: Union[dict[str, Tensor], Iterable[Tensor]],
    zero_on_enter: bool = True,
    zero_on_exit: bool = False,
) -> Generator[None, None, None]:
    r"""Contextmanager that zeros the ``.grad`` of the given tensors.

    Tensors whose ``.grad`` is None are skipped.

    Args:
        parameters: Either a dictionary whose values are the tensors to process,
            or an iterable of tensors. One-shot iterators (e.g. generators such as
            ``module.parameters()``) are materialized on entry so they can be
            traversed both on enter and on exit.
        zero_on_enter: If True, zero gradients when entering the context.
        zero_on_exit: If True, zero gradients when leaving the context, including
            when the body raises.
    """
    if not isinstance(parameters, dict) and iter(parameters) is parameters:
        # A one-shot iterator would be exhausted by the first pass, silently
        # turning the second zeroing pass into a no-op; cache its items.
        parameters = tuple(parameters)

    def zero_() -> None:
        for param in (
            parameters.values() if isinstance(parameters, dict) else parameters
        ):
            if param.grad is not None:
                param.grad.zero_()

    if zero_on_enter:
        zero_()
    try:
        yield
    finally:
        if zero_on_exit:
            zero_()