Source code for botorch.optim.utils.timeout

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import time
import warnings
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
from botorch.exceptions.errors import OptimizationTimeoutError
from botorch.exceptions.warnings import OptimizationWarning
from scipy import optimize


def minimize_with_timeout(
    fun: Callable[[np.ndarray, *Any], float],
    x0: np.ndarray,
    args: Tuple[Any, ...] = (),
    method: Optional[str] = None,
    jac: Optional[Union[str, Callable, bool]] = None,
    hess: Optional[Union[str, Callable, optimize.HessianUpdateStrategy]] = None,
    hessp: Optional[Callable] = None,
    bounds: Optional[Union[Sequence[Tuple[float, float]], optimize.Bounds]] = None,
    constraints=(),  # accepted types vary by method; intentionally left untyped
    tol: Optional[float] = None,
    callback: Optional[Callable] = None,
    options: Optional[Dict[str, Any]] = None,
    timeout_sec: Optional[float] = None,
) -> optimize.OptimizeResult:
    r"""Wrapper around scipy.optimize.minimize to support timeout.

    This method calls scipy.optimize.minimize with all arguments forwarded
    verbatim. The only difference is that if provided a `timeout_sec` argument,
    it will automatically stop the optimization after the timeout is reached.

    Internally, this is achieved by automatically constructing a wrapper callback
    method that is injected to the scipy.optimize.minimize call and that keeps
    track of the runtime and the optimization variables at the current iteration.

    Args:
        fun: The objective. If ``jac is True`` it must return a sequence whose
            first element is the objective value (e.g. ``(value, gradient)``);
            otherwise it returns the objective value itself.
        x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback,
        options: Forwarded verbatim to ``scipy.optimize.minimize``.
        timeout_sec: Timeout in seconds. If falsy (``None`` or ``0``), no timeout
            is enforced. The runtime is only checked from within the per-iteration
            callback, so the actual runtime may exceed this by up to one iteration.

    Returns:
        The ``OptimizeResult`` returned by ``scipy.optimize.minimize``. If the
        timeout is reached, an ``OptimizationWarning`` is emitted and a result
        with ``success=False`` and ``status=1`` is returned instead. That result's
        ``x`` is the iterate at which the timeout was detected, ``fun`` is the
        objective re-evaluated there, and ``nit`` is the number of times the
        timeout callback ran.

    Raises:
        NotImplementedError: If ``timeout_sec`` is set and both a ``callback``
            and a custom callable ``method`` are provided.
    """
    # Updated from within the timeout callback below. Initialized unconditionally
    # so the timeout handler can never hit an undefined name.
    callback_data: Dict[str, int] = {"num_iterations": 0}

    if timeout_sec:
        start_time = time.monotonic()

        def timeout_callback(xk: np.ndarray) -> bool:
            runtime = time.monotonic() - start_time
            callback_data["num_iterations"] += 1
            if runtime > timeout_sec:
                raise OptimizationTimeoutError(current_x=xk, runtime=runtime)
            return False

        # scipy matches method names case-insensitively, so do the same here.
        is_trust_constr = (
            isinstance(method, str) and method.lower() == "trust-constr"
        )

        if is_trust_constr:
            # trust-constr always invokes the callback as ``callback(xk, state)``,
            # even when the user supplied none, so the single-argument
            # ``timeout_callback`` must never be passed to it directly.
            def wrapped_callback(
                xk: np.ndarray, state: optimize.OptimizeResult
            ) -> bool:
                # Order matters: the user callback always runs first; if it asks
                # to stop, honor that without running the timeout check.
                if callback is not None:
                    stop = callback(xk, state)
                    if stop:
                        return stop
                return timeout_callback(xk=xk)

        elif callback is None:
            wrapped_callback = timeout_callback

        elif callable(method):
            raise NotImplementedError(
                "Custom callable not supported for `method` argument."
            )

        else:

            def wrapped_callback(xk: np.ndarray) -> None:
                timeout_callback(xk=xk)
                callback(xk)

    else:
        wrapped_callback = callback

    try:
        return optimize.minimize(
            fun=fun,
            x0=x0,
            args=args,
            method=method,
            jac=jac,
            hess=hess,
            hessp=hessp,
            bounds=bounds,
            constraints=constraints,
            tol=tol,
            callback=wrapped_callback,
            options=options,
        )
    except OptimizationTimeoutError as e:
        msg = f"Optimization timed out after {e.runtime} seconds."
        warnings.warn(msg, OptimizationWarning)

        # Re-evaluate the objective at the last iterate. scipy interprets
        # ``jac is True`` as "fun returns (value, gradient)"; otherwise ``fun``
        # returns the value directly and must not be unpacked.
        fun_out = fun(e.current_x, *args)
        current_fun = fun_out[0] if jac is True else fun_out

        return optimize.OptimizeResult(
            fun=current_fun,
            x=e.current_x,
            nit=callback_data["num_iterations"],
            success=False,  # same as when maxiter is reached
            status=1,  # same as when L-BFGS-B reaches maxiter
            message=msg,
        )