# Source code for botorch.optim.utils.timeout

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import time
import warnings
from collections.abc import Callable, Sequence
from typing import Any

import numpy.typing as npt
from botorch.exceptions.errors import OptimizationTimeoutError
from scipy import optimize


def minimize_with_timeout(
    fun: Callable[[npt.NDArray, ...], float],
    x0: npt.NDArray,
    args: tuple[Any, ...] = (),
    method: str | None = None,
    jac: str | Callable | bool | None = None,
    hess: str | Callable | optimize.HessianUpdateStrategy | None = None,
    hessp: Callable | None = None,
    bounds: Sequence[tuple[float, float]] | optimize.Bounds | None = None,
    constraints=(),  # scipy's constraint specs are impractical to type precisely
    tol: float | None = None,
    callback: Callable | None = None,
    options: dict[str, Any] | None = None,
    timeout_sec: float | None = None,
) -> optimize.OptimizeResult:
    r"""Wrapper around scipy.optimize.minimize to support timeout.

    This method calls scipy.optimize.minimize with all arguments forwarded
    verbatim. The only difference is that if provided a `timeout_sec` argument,
    it will automatically stop the optimization after the timeout is reached.

    Internally, this is achieved by automatically constructing a wrapper callback
    method that is injected to the scipy.optimize.minimize call and that keeps
    track of the runtime and the optimization variables at the current iteration.

    Args:
        fun: The objective, called as `fun(x, *args)`. If `jac is True` it must
            return a `(value, gradient)` pair, otherwise the scalar value.
        x0, args, method, jac, hess, hessp, bounds, constraints, tol, options:
            Forwarded verbatim to `scipy.optimize.minimize`.
        callback: Optional per-iteration callback, forwarded to
            `scipy.optimize.minimize` (wrapped together with the timeout check
            when `timeout_sec` is given). For `method="trust-constr"` it is
            called as `callback(xk, state)`, otherwise as `callback(xk)`.
        timeout_sec: Wall-clock budget in seconds, measured with
            `time.monotonic()`. The budget is checked once per callback
            invocation, so the optimization stops at the first iteration after
            the budget is exceeded. If None, no timeout is applied.

    Returns:
        The result of `scipy.optimize.minimize`. On timeout, an `OptimizeResult`
        whose `x` is the iterate at which the timeout was detected, `fun` is the
        objective value there, `nit` is the number of timeout checks performed,
        `success=False`, `status=1` and `message` reports the elapsed runtime.

    Raises:
        NotImplementedError: If `timeout_sec` and `callback` are both given and
            `method` is a callable.
        Warning: Any warning whose message matches "Method .* cannot handle"
            (e.g. a method that cannot handle the supplied bounds or
            constraints) is raised as an exception for the duration of the call.
    """
    # Defined unconditionally so the timeout handler below can always read it.
    callback_data = {"num_iterations": 0}  # updated from within the callbacks
    # scipy matches method names case-insensitively, so do the same here.
    is_trust_constr = isinstance(method, str) and method.lower() == "trust-constr"

    if timeout_sec is not None:
        start_time = time.monotonic()

        def timeout_callback(xk: npt.NDArray) -> bool:
            runtime = time.monotonic() - start_time
            callback_data["num_iterations"] += 1
            if runtime > timeout_sec:
                raise OptimizationTimeoutError(current_x=xk, runtime=runtime)
            return False

        if callback is not None and callable(method):
            raise NotImplementedError(
                "Custom callable not supported for `method` argument."
            )

        if is_trust_constr:
            # trust-constr always invokes its callback as `callback(xk, state)`,
            # so the wrapper needs that signature even without a user callback.
            def wrapped_callback(
                xk: npt.NDArray, state: optimize.OptimizeResult
            ) -> bool:
                # Run the user callback first so it executes even if the timeout
                # check raises; a truthy return value stops trust-constr early.
                if callback is not None:
                    stop = callback(xk, state)
                    if stop:
                        return stop
                return timeout_callback(xk=xk)

        elif callback is None:
            wrapped_callback = timeout_callback

        else:

            def wrapped_callback(xk: npt.NDArray) -> None:
                timeout_callback(xk=xk)
                callback(xk)

    else:
        wrapped_callback = callback

    try:
        # Scope the filter to this call: a bare `warnings.filterwarnings` would
        # permanently change warning handling for the whole process.
        with warnings.catch_warnings():
            warnings.filterwarnings("error", message="Method .* cannot handle")
            return optimize.minimize(
                fun=fun,
                x0=x0,
                args=args,
                method=method,
                jac=jac,
                hess=hess,
                hessp=hessp,
                bounds=bounds,
                constraints=constraints,
                tol=tol,
                callback=wrapped_callback,
                options=options,
            )
    except OptimizationTimeoutError as e:
        msg = f"Optimization timed out after {e.runtime} seconds."
        fun_output = fun(e.current_x, *args)
        # With `jac=True` scipy expects `fun` to return `(value, gradient)`;
        # otherwise `fun` returns the objective value itself.
        current_fun = fun_output[0] if jac is True else fun_output

        return optimize.OptimizeResult(
            fun=current_fun,
            x=e.current_x,
            nit=callback_data["num_iterations"],
            success=False,  # same as when maxiter is reached
            status=1,  # same as when L-BFGS-B reaches maxiter
            message=msg,
        )