Source code for botorch.optim.homotopy

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
from torch import Tensor
from torch.nn import Parameter


[docs] class FixedHomotopySchedule: """Homotopy schedule with a fixed list of values.""" def __init__(self, values: list[float]) -> None: r"""Initialize FixedHomotopySchedule. Args: values: A list of values used in homotopy """ self._values = values self.idx = 0 @property def num_steps(self) -> int: return len(self._values) @property def value(self) -> float: return self._values[self.idx] @property def should_stop(self) -> bool: return self.idx == len(self._values)
[docs] def restart(self) -> None: self.idx = 0
[docs] def step(self) -> None: self.idx += 1
[docs] class LinearHomotopySchedule(FixedHomotopySchedule): """Linear homotopy schedule.""" def __init__(self, start: float, end: float, num_steps: int) -> None: r"""Initialize LinearHomotopySchedule. Args: start: start value of homotopy end: end value of homotopy num_steps: number of steps in the homotopy schedule. """ super().__init__( values=torch.linspace(start, end, num_steps, dtype=torch.double).tolist() )
[docs] class LogLinearHomotopySchedule(FixedHomotopySchedule): """Log-linear homotopy schedule.""" def __init__(self, start: float, end: float, num_steps: int): r"""Initialize LogLinearHomotopySchedule. Args: start: start value of homotopy end: end value of homotopy num_steps: number of steps in the homotopy schedule. """ super().__init__( values=torch.logspace( math.log10(start), math.log10(end), num_steps, dtype=torch.double ).tolist() )
[docs] @dataclass class HomotopyParameter: r"""Homotopy parameter. The parameter is expected to either be a torch parameter or a torch tensor which may correspond to a buffer of a module. The parameter has a corresponding schedule. """ parameter: Union[Parameter, Tensor] schedule: FixedHomotopySchedule
[docs] class Homotopy: """Generic homotopy class. This class is designed to be used in `optimize_acqf_homotopy`. Given a set of homotopy parameters and corresponding schedules we step through the homotopies until we have solved the final problem. We additionally support passing in a list of callbacks that will be executed each time `step`, `reset`, and `restart` are called. """ def __init__( self, homotopy_parameters: list[HomotopyParameter], callbacks: Optional[list[Callable]] = None, ) -> None: r"""Initialize the homotopy. Args: homotopy_parameters: List of homotopy parameters callbacks: Optional list of callbacks that are executed each time `restart`, `reset`, or `step` are called. These may be used to, e.g., reinitialize the acquisition function which is needed when using qNEHVI. """ self._homotopy_parameters = homotopy_parameters self._callbacks = callbacks or [] self._original_values = [ hp.parameter.item() for hp in self._homotopy_parameters ] assert all( isinstance(hp.parameter, Parameter) or isinstance(hp.parameter, Tensor) for hp in self._homotopy_parameters ) # Assume the same number of steps for now assert len({h.schedule.num_steps for h in self._homotopy_parameters}) == 1 # Initialize the homotopy parameters self.restart() def _execute_callbacks(self) -> None: """Execute the callbacks.""" for callback in self._callbacks: callback() @property def should_stop(self) -> bool: """Returns true if all schedules have reached the end.""" return all(h.schedule.should_stop for h in self._homotopy_parameters)
[docs] def restart(self) -> None: """Restart the homotopy to use the initial value in the schedule.""" for hp in self._homotopy_parameters: hp.schedule.restart() hp.parameter.data.fill_(hp.schedule.value) self._execute_callbacks()
[docs] def reset(self) -> None: """Reset the homotopy parameter to their original values.""" for hp, val in zip(self._homotopy_parameters, self._original_values): hp.parameter.data.fill_(val) self._execute_callbacks()
[docs] def step(self) -> None: """Take a step according to the schedules.""" for hp in self._homotopy_parameters: hp.schedule.step() if not hp.schedule.should_stop: hp.parameter.data.fill_(hp.schedule.value) self._execute_callbacks()