# Source code for botorch.optim.homotopy
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Callable, Optional, Union
import torch
from torch import Tensor
from torch.nn import Parameter
class FixedHomotopySchedule:
    """Homotopy schedule with a fixed list of values."""

    def __init__(self, values: list[float]) -> None:
        r"""Initialize FixedHomotopySchedule.

        Args:
            values: A list of values used in homotopy. The values are copied, so
                later mutation of the caller's list does not alter the schedule.
        """
        self._values = list(values)
        self.idx = 0

    @property
    def num_steps(self) -> int:
        """Number of values in the schedule."""
        return len(self._values)

    @property
    def value(self) -> float:
        """Value at the current position of the schedule.

        Raises:
            IndexError: If the schedule has been stepped past its last value.
        """
        return self._values[self.idx]

    @property
    def should_stop(self) -> bool:
        """Whether the schedule has been stepped past its last value."""
        # `>=` rather than `==`: an extra `step()` beyond the end must still
        # report the schedule as finished instead of silently "un-finishing".
        return self.idx >= len(self._values)

    def restart(self) -> None:
        """Move back to the first value of the schedule."""
        self.idx = 0

    def step(self) -> None:
        """Advance to the next value of the schedule."""
        self.idx += 1
class LinearHomotopySchedule(FixedHomotopySchedule):
    """Homotopy schedule with evenly spaced values between two endpoints."""

    def __init__(self, start: float, end: float, num_steps: int) -> None:
        r"""Build a schedule of ``num_steps`` linearly spaced values.

        Args:
            start: start value of homotopy (the first value of the schedule).
            end: end value of homotopy.
            num_steps: number of steps in the homotopy schedule.
        """
        # Build the grid in double precision, then pass plain Python floats on.
        grid = torch.linspace(
            start=start, end=end, steps=num_steps, dtype=torch.double
        )
        super().__init__(values=grid.tolist())
class LogLinearHomotopySchedule(FixedHomotopySchedule):
    """Log-linear homotopy schedule.

    The values are spaced evenly in log10 space between ``start`` and ``end``.
    """

    def __init__(self, start: float, end: float, num_steps: int) -> None:
        r"""Initialize LogLinearHomotopySchedule.

        Args:
            start: start value of homotopy; must be strictly positive.
            end: end value of homotopy; must be strictly positive.
            num_steps: number of steps in the homotopy schedule.

        Raises:
            ValueError: If ``start`` or ``end`` is not strictly positive, since
                both are mapped through ``log10``.
        """
        # Fail with an explicit message instead of an opaque "math domain error"
        # from `math.log10` (the exception type, ValueError, is unchanged).
        if start <= 0 or end <= 0:
            raise ValueError(
                "LogLinearHomotopySchedule requires strictly positive `start` and "
                f"`end`, got start={start} and end={end}."
            )
        super().__init__(
            values=torch.logspace(
                math.log10(start), math.log10(end), num_steps, dtype=torch.double
            ).tolist()
        )
@dataclass
class HomotopyParameter:
    r"""A tensor paired with the schedule that drives its value.

    ``parameter`` may be a ``torch.nn.Parameter`` or a plain ``torch.Tensor``
    (for instance a buffer registered on a module); ``schedule`` provides the
    sequence of values associated with it.
    """

    # The tensor (parameter or buffer) to be varied along the homotopy.
    parameter: Union[Parameter, Tensor]
    # The schedule supplying the values for `parameter`.
    schedule: FixedHomotopySchedule
class Homotopy:
    """Generic homotopy class.

    This class is designed to be used in `optimize_acqf_homotopy`. Given a set of
    homotopy parameters and corresponding schedules we step through the homotopies
    until we have solved the final problem. We additionally support passing in a list
    of callbacks that will be executed each time `step`, `reset`, and `restart` are
    called.
    """

    def __init__(
        self,
        homotopy_parameters: list[HomotopyParameter],
        callbacks: Optional[list[Callable]] = None,
    ) -> None:
        r"""Initialize the homotopy.

        The current parameter values are recorded (so `reset` can restore them),
        then every parameter is set to the first value of its schedule via
        `restart`, which also runs the callbacks once.

        Args:
            homotopy_parameters: List of homotopy parameters. Each ``parameter`` must
                be a tensor; its current value is read with ``.item()``, so it must
                hold a single element. All schedules must have the same number of
                steps, and at least one parameter must be given.
            callbacks: Optional list of callbacks that are executed each time
                `restart`, `reset`, or `step` are called. These may be used to, e.g.,
                reinitialize the acquisition function which is needed when using qNEHVI.
        """
        self._homotopy_parameters = homotopy_parameters
        self._callbacks = callbacks or []
        # Validate types *before* reading values: otherwise `.item()` below fails
        # first with an unrelated AttributeError on non-tensor inputs.
        # `Parameter` subclasses `Tensor`, so one isinstance check covers both.
        assert all(
            isinstance(hp.parameter, Tensor) for hp in self._homotopy_parameters
        ), "Each homotopy parameter must be a `Parameter` or a `Tensor`."
        # Assume the same number of steps for now
        assert (
            len({hp.schedule.num_steps for hp in self._homotopy_parameters}) == 1
        ), (
            "Expected at least one homotopy parameter, and all homotopy schedules "
            "to have the same number of steps."
        )
        # Snapshot the current values so that `reset` can restore them.
        self._original_values = [
            hp.parameter.item() for hp in self._homotopy_parameters
        ]
        # Initialize the homotopy parameters
        self.restart()

    def _execute_callbacks(self) -> None:
        """Execute the callbacks in the order they were given."""
        for callback in self._callbacks:
            callback()

    @property
    def should_stop(self) -> bool:
        """Returns true if all schedules have reached the end."""
        return all(hp.schedule.should_stop for hp in self._homotopy_parameters)

    def restart(self) -> None:
        """Restart the homotopy to use the initial value in the schedule."""
        for hp in self._homotopy_parameters:
            hp.schedule.restart()
            hp.parameter.data.fill_(hp.schedule.value)
        self._execute_callbacks()

    def reset(self) -> None:
        """Reset the homotopy parameters to their original values.

        The position of each schedule is left unchanged.
        """
        for hp, val in zip(self._homotopy_parameters, self._original_values):
            hp.parameter.data.fill_(val)
        self._execute_callbacks()

    def step(self) -> None:
        """Take a step according to the schedules.

        A parameter whose schedule has finished keeps its current value.
        """
        for hp in self._homotopy_parameters:
            hp.schedule.step()
            if not hp.schedule.should_stop:
                hp.parameter.data.fill_(hp.schedule.value)
        self._execute_callbacks()