# Source code for botorch.optim.homotopy

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Parameter

class HomotopySchedule(ABC):
    """Abstract interface for a sequence of values consumed by a homotopy.

    A concrete schedule reports its length, the value at its current position,
    and whether it has been advanced past its final value; it can also be
    rewound to the beginning or advanced by one position.
    """

    @property
    @abstractmethod
    def num_steps(self) -> int:
        """Total number of values held by the schedule."""

    @property
    @abstractmethod
    def value(self) -> Any:
        """Value at the schedule's current position."""

    @property
    @abstractmethod
    def should_stop(self) -> bool:
        """True once the schedule has been advanced beyond its last value."""

    @abstractmethod
    def restart(self) -> None:
        """Rewind the schedule so it starts again from its first value."""

    @abstractmethod
    def step(self) -> None:
        """Advance the schedule to the next problem to solve."""
class FixedHomotopySchedule(HomotopySchedule):
    """Homotopy schedule with a fixed list of values."""

    def __init__(self, values: List[Any]) -> None:
        r"""Initialize FixedHomotopySchedule.

        Args:
            values: A list of values used in homotopy, in the order they are
                visited by `step`.
        """
        self._values = values
        # Position of the current value in `self._values`.
        self.idx = 0

    @property
    def num_steps(self) -> int:
        """Number of values in the schedule."""
        return len(self._values)

    @property
    def value(self) -> Any:
        """Current value; raises IndexError once stepped past the end."""
        return self._values[self.idx]

    @property
    def should_stop(self) -> bool:
        """Return True once the schedule has been stepped past its last value."""
        # `>=` rather than `==`: stepping again after the end must not make the
        # schedule look active again (its `value` would then raise IndexError).
        return self.idx >= len(self._values)

    def restart(self) -> None:
        """Restart the schedule from its first value."""
        self.idx = 0

    def step(self) -> None:
        """Move to the next value in the schedule."""
        self.idx += 1
class LinearHomotopySchedule(FixedHomotopySchedule):
    """Homotopy schedule whose values are evenly spaced between two endpoints."""

    def __init__(self, start: float, end: float, num_steps: int) -> None:
        r"""Initialize LinearHomotopySchedule.

        Args:
            start: First value of the schedule.
            end: Last value of the schedule.
            num_steps: How many values the schedule contains.
        """
        # Build the evenly spaced grid in double precision, then hand plain
        # Python floats to the fixed-list schedule.
        grid = torch.linspace(start, end, num_steps, dtype=torch.double)
        super().__init__(values=grid.tolist())
class LogLinearHomotopySchedule(FixedHomotopySchedule):
    """Log-linear homotopy schedule.

    Values are evenly spaced in log10-space between `start` and `end`.
    """

    def __init__(self, start: float, end: float, num_steps: int) -> None:
        r"""Initialize LogLinearHomotopySchedule.

        Args:
            start: start value of homotopy; must be strictly positive.
            end: end value of homotopy; must be strictly positive.
            num_steps: number of steps in the homotopy schedule.

        Raises:
            ValueError: If `start` or `end` is not strictly positive (the
                logarithm is undefined there).
        """
        # Fail with an explicit message instead of math.log10's opaque
        # "math domain error" (same exception type as before).
        if start <= 0 or end <= 0:
            raise ValueError(
                "LogLinearHomotopySchedule requires strictly positive `start` and "
                f"`end`, got start={start} and end={end}."
            )
        super().__init__(
            values=torch.logspace(
                math.log10(start), math.log10(end), num_steps, dtype=torch.double
            ).tolist()
        )
@dataclass
class HomotopyParameter:
    r"""Pairing of a tensor with the homotopy schedule that belongs to it.

    `parameter` may be a torch `Parameter` or a plain `Tensor` (for instance a
    tensor registered as a module buffer); `schedule` is the schedule that
    corresponds to it.
    """

    # Tensor controlled by the homotopy (a Parameter or, e.g., a buffer).
    parameter: Union[Parameter, Tensor]
    # Schedule paired with `parameter`.
    schedule: HomotopySchedule
class Homotopy:
    """Generic homotopy class.

    This class is designed to be used in `optimize_acqf_homotopy`. Given a set
    of homotopy parameters and corresponding schedules we step through the
    homotopies until we have solved the final problem. We additionally support
    passing in a list of callbacks that will be executed each time `step`,
    `reset`, and `restart` are called.
    """

    def __init__(
        self,
        homotopy_parameters: List[HomotopyParameter],
        callbacks: Optional[List[Callable]] = None,
    ) -> None:
        r"""Initialize the homotopy.

        Args:
            homotopy_parameters: List of homotopy parameters. Each parameter's
                tensor is overwritten in place with values from its schedule.
            callbacks: Optional list of callbacks that are executed each time
                `restart`, `reset`, or `step` are called. These may be used to,
                e.g., reinitialize the acquisition function which is needed
                when using qNEHVI.
        """
        self._homotopy_parameters = homotopy_parameters
        self._callbacks = callbacks or []
        # Validate types *before* reading values so bad input fails on this
        # check rather than on `.item()`. `Parameter` subclasses `Tensor`, so a
        # single isinstance check covers both accepted kinds.
        assert all(
            isinstance(hp.parameter, Tensor) for hp in self._homotopy_parameters
        ), "Homotopy parameters must be torch Parameters or Tensors."
        # Assume the same number of steps for now
        assert (
            len({hp.schedule.num_steps for hp in self._homotopy_parameters}) == 1
        ), "All homotopy schedules must have the same number of steps."
        # Snapshot the pre-homotopy values so `reset` can restore them.
        self._original_values = [
            hp.parameter.item() for hp in self._homotopy_parameters
        ]
        # Initialize the homotopy parameters
        self.restart()

    def _execute_callbacks(self) -> None:
        """Execute the callbacks."""
        for callback in self._callbacks:
            callback()

    @property
    def should_stop(self) -> bool:
        """Returns true if all schedules have reached the end."""
        return all(hp.schedule.should_stop for hp in self._homotopy_parameters)

    def restart(self) -> None:
        """Restart the homotopy to use the initial value in the schedule."""
        for hp in self._homotopy_parameters:
            hp.schedule.restart()
            # `.data.fill_` writes in place without recording an autograd op.
            hp.parameter.data.fill_(hp.schedule.value)
        self._execute_callbacks()

    def reset(self) -> None:
        """Reset the homotopy parameters to their original values."""
        for hp, val in zip(self._homotopy_parameters, self._original_values):
            hp.parameter.data.fill_(val)
        # Run callbacks once per reset, not once per parameter.
        self._execute_callbacks()

    def step(self) -> None:
        """Take a step according to the schedules."""
        for hp in self._homotopy_parameters:
            hp.schedule.step()
            # Past the end there is no schedule value to read; keep the last one.
            if not hp.schedule.should_stop:
                hp.parameter.data.fill_(hp.schedule.value)
        # Run callbacks once per step, after every parameter has been updated.
        self._execute_callbacks()