# Source code for botorch.optim.homotopy

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# See the LICENSE file in the root directory of this source tree.

from __future__ import annotations

import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Parameter

class HomotopySchedule(ABC):
    """Abstract interface that every homotopy schedule implements.

    A schedule reports how many steps it has, exposes its current value, can be
    advanced one step at a time, and can be restarted from the beginning.
    """

    @property
    @abstractmethod
    def num_steps(self) -> int:
        """Total number of steps the schedule contains."""

    @property
    @abstractmethod
    def value(self) -> Any:
        """Value the schedule currently points at."""

    @property
    @abstractmethod
    def should_stop(self) -> bool:
        """Whether the schedule has been incremented past its end."""

    @abstractmethod
    def restart(self) -> None:
        """Rewind the schedule so that it starts from the beginning again."""

    @abstractmethod
    def step(self) -> None:
        """Advance to the next problem in the schedule."""

class FixedHomotopySchedule(HomotopySchedule):
    """Homotopy schedule with a fixed list of values.

    The schedule keeps an integer cursor ``idx`` into the list: ``step`` advances
    the cursor by one and ``restart`` moves it back to zero.
    """

    def __init__(self, values: List[Any]) -> None:
        r"""Initialize FixedHomotopySchedule.

        Args:
            values: A list of values used in homotopy
        """
        self._values = values
        self.idx = 0

    @property
    def num_steps(self) -> int:
        """Number of values in the schedule."""
        return len(self._values)

    @property
    def value(self) -> Any:
        """Value at the current cursor position.

        Indexing is not guarded: once the cursor has moved past the last value
        (i.e. ``should_stop`` is true) reading this property fails with the
        container's out-of-range error (``IndexError`` for a list).
        """
        return self._values[self.idx]

    @property
    def should_stop(self) -> bool:
        """Return true if we have incremented past the end of the schedule."""
        # `>=` rather than `==`: an extra `step()` after the end must not make
        # the schedule report that it is active again.
        return self.idx >= len(self._values)

    def restart(self) -> None:
        """Restart the schedule to start from the first value."""
        self.idx = 0

    def step(self) -> None:
        """Move the cursor to the next value."""
        self.idx += 1

class LinearHomotopySchedule(FixedHomotopySchedule):
    """Homotopy schedule whose values are produced by ``torch.linspace``."""

    def __init__(self, start: float, end: float, num_steps: int) -> None:
        r"""Initialize LinearHomotopySchedule.

        Args:
            start: Value at which the homotopy starts.
            end: Value at which the homotopy ends.
            num_steps: How many values the schedule contains.
        """
        # Build the grid in double precision, then hand the base class a plain
        # Python list of floats.
        grid = torch.linspace(start, end, num_steps, dtype=torch.double)
        super().__init__(values=grid.tolist())

class LogLinearHomotopySchedule(FixedHomotopySchedule):
    """Log-linear homotopy schedule.

    The values are produced by ``torch.logspace`` between ``log10(start)`` and
    ``log10(end)``, so both endpoints must be strictly positive.
    """

    def __init__(self, start: float, end: float, num_steps: int) -> None:
        r"""Initialize LogLinearHomotopySchedule.

        Args:
            start: start value of homotopy (must be positive)
            end: end value of homotopy (must be positive)
            num_steps: number of steps in the homotopy schedule.

        Raises:
            ValueError: If ``start`` or ``end`` is not strictly positive.
        """
        # `math.log10` would reject these inputs anyway, but only with an opaque
        # "math domain error"; fail early with a message naming the arguments.
        if start <= 0 or end <= 0:
            raise ValueError(
                "LogLinearHomotopySchedule requires strictly positive `start` "
                f"and `end`, but got start={start} and end={end}."
            )
        log_start, log_end = math.log10(start), math.log10(end)
        super().__init__(
            values=torch.logspace(
                log_start, log_end, num_steps, dtype=torch.double
            ).tolist()
        )

@dataclass
class HomotopyParameter:
    r"""Pairing of a tensor-valued parameter with its homotopy schedule.

    The parameter is expected to be either a torch parameter or a torch tensor
    (the latter may correspond to a buffer of a module). Each parameter comes
    with the schedule that belongs to it.
    """

    # The torch parameter or tensor associated with the schedule.
    parameter: Union[Parameter, Tensor]
    # The schedule corresponding to `parameter`.
    schedule: HomotopySchedule

class Homotopy:
    """Generic homotopy class.

    This class is designed to be used in optimize_acqf_homotopy. Given a set of
    homotopy parameters and corresponding schedules we step through the homotopies
    until we have solved the final problem. We additionally support passing in a list
    of callbacks that will be executed each time step, reset, and restart are
    called.
    """

    def __init__(
        self,
        homotopy_parameters: List[HomotopyParameter],
        callbacks: Optional[List[Callable]] = None,
    ) -> None:
        r"""Initialize the homotopy.

        Args:
            homotopy_parameters: List of homotopy parameters
            callbacks: Optional list of callbacks that are executed each time
                restart, reset, or step are called. These may be used to, e.g.,
                reinitialize the acquisition function which is needed when using qNEHVI.

        Raises:
            AssertionError: If a parameter is neither a ``Parameter`` nor a
                ``Tensor``, or if the schedules do not all share one number of
                steps (which is also the case for an empty list).
        """
        # Validate before touching the parameters: reading `.item()` first would
        # turn a wrong parameter type into an unrelated AttributeError.
        assert all(
            isinstance(hp.parameter, (Parameter, Tensor)) for hp in homotopy_parameters
        ), "Every homotopy parameter must be a torch Parameter or Tensor."
        # Assume the same number of steps for now
        assert len({hp.schedule.num_steps for hp in homotopy_parameters}) == 1, (
            "Expected at least one homotopy parameter, all of whose schedules "
            "have the same number of steps."
        )
        self._homotopy_parameters = homotopy_parameters
        self._callbacks = callbacks or []
        # Remember the values the parameters had on entry so `reset` can put
        # them back later.
        self._original_values = [
            hp.parameter.item() for hp in self._homotopy_parameters
        ]
        # Initialize the homotopy parameters
        self.restart()

    def _execute_callbacks(self) -> None:
        """Execute the callbacks, in the order they were given."""
        for callback in self._callbacks:
            callback()

    @property
    def should_stop(self) -> bool:
        """Returns true if all schedules have reached the end."""
        return all(h.schedule.should_stop for h in self._homotopy_parameters)

    def restart(self) -> None:
        """Restart the homotopy to use the initial value in the schedule."""
        for hp in self._homotopy_parameters:
            hp.schedule.restart()
            hp.parameter.data.fill_(hp.schedule.value)
        self._execute_callbacks()

    def reset(self) -> None:
        """Reset the homotopy parameters to their original values.

        The schedules themselves are left where they are; only the parameter
        tensors are overwritten.
        """
        for hp, val in zip(self._homotopy_parameters, self._original_values):
            hp.parameter.data.fill_(val)
        self._execute_callbacks()

    def step(self) -> None:
        """Take a step according to the schedules."""
        for hp in self._homotopy_parameters:
            hp.schedule.step()
            # A schedule that has run past its end has no current value, so the
            # parameter keeps the last value that was written to it.
            if not hp.schedule.should_stop:
                hp.parameter.data.fill_(hp.schedule.value)
        self._execute_callbacks()