Source code for botorch.optim.stopping

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod

import torch
from torch import Tensor


class StoppingCriterion(ABC):
    r"""Abstract base class for deciding when an optimization has converged.

    Stopping criteria are implemented as objects rather than plain functions so
    that they can keep track of function values seen in earlier optimization steps.

    :meta private:
    """

    def __call__(self, fvals: Tensor) -> bool:
        r"""Shorthand for :meth:`evaluate`; lets a criterion be called directly."""
        criterion_met = self.evaluate(fvals)
        return criterion_met

    @abstractmethod
    def evaluate(self, fvals: Tensor) -> bool:
        r"""Evaluate the stopping criterion.

        Args:
            fvals: Tensor of function values for the current iteration. When it
                holds more than one element, the criterion is evaluated
                element-wise and True is returned only if it holds for every
                element.

        Returns:
            Stopping indicator (if True, stop the optimization).
        """
        ...  # pragma: no cover


class ExpMAStoppingCriterion(StoppingCriterion):
    r"""Exponential moving average stopping criterion.

    Computes an exponentially weighted moving average over window length `n_window`
    and checks whether the relative decrease in this moving average between steps
    is less than a provided tolerance level. That is, in iteration `i`, it computes

        v[i,j] := fvals[i - n_window + 1 + j] * w[j]

    for all `j = 0, ..., n_window - 1`. The weights are the exponentials of
    `n_window` evenly spaced points from `-eta` to `0`, normalized to sum to one,
    i.e. `w[j] ∝ exp(-eta * (1 - j / (n_window - 1)))` (a single weight of 1 when
    `n_window == 1`), so the most recent value receives the largest weight.
    Letting `ma[i] := sum_j(v[i,j])`, the criterion evaluates to `True` whenever

        (ma[i-1] - ma[i]) / abs(ma[i-1]) < rel_tol (if minimize=True)
        (ma[i] - ma[i-1]) / abs(ma[i-1]) < rel_tol (if minimize=False)

    It also evaluates to `True` once `maxiter` evaluations have been performed.
    """

    def __init__(
        self,
        maxiter: int = 10000,
        minimize: bool = True,
        n_window: int = 10,
        eta: float = 1.0,
        rel_tol: float = 1e-5,
    ) -> None:
        r"""Exponential moving average stopping criterion.

        Args:
            maxiter: Maximum number of iterations.
            minimize: If True, assume minimization.
            n_window: The size of the exponential moving average window.
            eta: The exponential decay factor in the weights.
            rel_tol: Relative tolerance for termination.

        Raises:
            ValueError: If `n_window` is smaller than 1.
        """
        # With n_window < 1 the weights are empty and `history[-0:]` keeps the
        # whole history, so the buffer would grow without bound.
        if n_window < 1:
            raise ValueError(f"n_window must be a positive integer, got {n_window}.")
        self.maxiter = maxiter
        self.minimize = minimize
        self.n_window = n_window
        self.rel_tol = rel_tol
        self.iter = 0
        weights = torch.exp(torch.linspace(-eta, 0, self.n_window))
        self.weights = weights / weights.sum()
        self._prev_fvals = None

    def evaluate(self, fvals: Tensor) -> bool:
        r"""Evaluate the stopping criterion.

        Args:
            fvals: tensor containing function values for the current iteration. If
                `fvals` contains more than one element, then the stopping criterion is
                evaluated element-wise and True is returned if the stopping criterion is
                true for all elements. `fvals` may have any shape, but that shape must
                stay the same across calls.

        TODO: add support for utilizing gradient information

        Returns:
            Stopping indicator (if True, stop the optimization).
        """
        self.iter += 1
        # `>=` rather than `==`: keeps `maxiter` a hard cap even if the caller keeps
        # evaluating after the criterion has already signalled to stop.
        if self.iter >= self.maxiter:
            return True

        # Only the values are needed; detaching avoids keeping the autograd graphs
        # of past iterations alive inside the history buffer.
        current = fvals.detach().unsqueeze(0)
        if self._prev_fvals is None:
            self._prev_fvals = current
        else:
            # Retain at most `n_window + 1` entries: two overlapping windows.
            self._prev_fvals = torch.cat(
                [self._prev_fvals[-self.n_window :], current]
            )

        if self._prev_fvals.size(0) < self.n_window + 1:
            return False

        # Shape the weights as (n_window, 1, ..., 1) so they broadcast along the
        # history dimension for `fvals` of any shape (a plain `unsqueeze(-1)` only
        # lines up correctly for 0-d or 1-d `fvals`).
        weights = self.weights.to(fvals).reshape(
            -1, *([1] * (self._prev_fvals.ndim - 1))
        )
        # TODO: Update the exp moving average efficiently
        prev_ma = (self._prev_fvals[:-1] * weights).sum(dim=0)
        ma = (self._prev_fvals[1:] * weights).sum(dim=0)

        # TODO: Handle approx. zero losses (normalize by min/max loss range)
        rel_delta = (prev_ma - ma) / prev_ma.abs()
        if not self.minimize:
            rel_delta = -rel_delta
        # Stop only when every element's relative improvement is below tolerance.
        return bool(torch.max(rel_delta) < self.rel_tol)