#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import time
import warnings
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import numpy as np
from botorch.exceptions.errors import OptimizationTimeoutError
from scipy import optimize
[docs]
def minimize_with_timeout(
    fun: Callable[..., Union[float, Tuple[float, np.ndarray]]],
    x0: np.ndarray,
    args: Tuple[Any, ...] = (),
    method: Optional[str] = None,
    jac: Optional[Union[str, Callable, bool]] = None,
    hess: Optional[Union[str, Callable, optimize.HessianUpdateStrategy]] = None,
    hessp: Optional[Callable] = None,
    bounds: Optional[Union[Sequence[Tuple[float, float]], optimize.Bounds]] = None,
    constraints=(),  # left untyped: forwarded as-is to scipy.optimize.minimize
    tol: Optional[float] = None,
    callback: Optional[Callable] = None,
    options: Optional[Dict[str, Any]] = None,
    timeout_sec: Optional[float] = None,
) -> optimize.OptimizeResult:
    r"""Wrapper around scipy.optimize.minimize to support timeout.

    This method calls scipy.optimize.minimize with all arguments forwarded
    verbatim. The only difference is that if provided a `timeout_sec` argument,
    it will automatically stop the optimization after the timeout is reached.

    Internally, this is achieved by automatically constructing a wrapper callback
    method that is injected to the scipy.optimize.minimize call and that keeps
    track of the runtime and the optimization variables at the current iteration.
    The timeout is only checked when the callback is invoked, so a single
    iteration is never interrupted part-way.

    Args:
        fun: The objective, called as `fun(x, *args)`. It may return either the
            objective value or a sequence whose first element is the objective
            value (e.g. `(value, gradient)`).
        x0: The initial guess.
        args: Extra positional arguments passed to `fun`.
        method: Name of the solver. A custom callable is not supported when
            both `callback` and `timeout_sec` are given.
        jac: Forwarded to `scipy.optimize.minimize`.
        hess: Forwarded to `scipy.optimize.minimize`.
        hessp: Forwarded to `scipy.optimize.minimize`.
        bounds: Forwarded to `scipy.optimize.minimize`.
        constraints: Forwarded to `scipy.optimize.minimize`.
        tol: Forwarded to `scipy.optimize.minimize`.
        callback: Optional user callback. With `method="trust-constr"` it is
            called as `callback(xk, state)`, otherwise as `callback(xk)`.
        options: Forwarded to `scipy.optimize.minimize`.
        timeout_sec: Time budget in seconds. `None` disables the timeout. Any
            other value (including `0`) enables it.

    Returns:
        The `OptimizeResult` produced by scipy, or, if the timeout was hit, an
        `OptimizeResult` with `success=False`, `status=1`, the last iterate as
        `x`, the objective at that iterate as `fun`, and the number of callback
        invocations as `nit`.

    Raises:
        NotImplementedError: If `method` is a callable while both `callback`
            and `timeout_sec` are provided.
    """
    # Defined unconditionally so the timeout handler below can always read it.
    callback_data = {"num_iterations": 0}  # updated from within the callback

    if timeout_sec is not None:
        start_time = time.monotonic()

        def timeout_callback(xk: np.ndarray) -> bool:
            runtime = time.monotonic() - start_time
            callback_data["num_iterations"] += 1
            if runtime > timeout_sec:
                raise OptimizationTimeoutError(current_x=xk, runtime=runtime)
            return False

        if callback is not None and callable(method):
            raise NotImplementedError(
                "Custom callable not supported for `method` argument."
            )

        # scipy matches method names case-insensitively, so do the same here.
        is_trust_constr = isinstance(method, str) and method.lower() == "trust-constr"

        if is_trust_constr:
            # trust-constr invokes callbacks as `callback(xk, state)`, so a
            # two-argument wrapper is needed even without a user callback.
            def wrapped_callback(xk: np.ndarray, state: optimize.OptimizeResult):
                if callback is not None:
                    # Run the user callback first so that it always executes;
                    # a truthy return value requests termination.
                    stop = callback(xk, state)
                    if stop:
                        return stop
                return timeout_callback(xk)

        elif callback is None:
            wrapped_callback = timeout_callback
        else:

            def wrapped_callback(xk: np.ndarray) -> None:
                timeout_callback(xk)
                callback(xk)

    else:
        wrapped_callback = callback

    try:
        # Escalate warnings matching "Method .* cannot handle" to errors, but
        # only for the duration of this call: `catch_warnings` restores the
        # global filter list on exit instead of leaking the filter.
        with warnings.catch_warnings():
            warnings.filterwarnings("error", message="Method .* cannot handle")
            return optimize.minimize(
                fun=fun,
                x0=x0,
                args=args,
                method=method,
                jac=jac,
                hess=hess,
                hessp=hessp,
                bounds=bounds,
                constraints=constraints,
                tol=tol,
                callback=wrapped_callback,
                options=options,
            )
    except OptimizationTimeoutError as e:
        msg = f"Optimization timed out after {e.runtime} seconds."
        fun_out = fun(e.current_x, *args)
        try:
            # `fun` may return e.g. `(value, gradient)`; keep only the value.
            current_fun, *_ = fun_out
        except TypeError:
            # `fun` returned a plain (non-iterable) scalar.
            current_fun = fun_out
        return optimize.OptimizeResult(
            fun=current_fun,
            x=e.current_x,
            nit=callback_data["num_iterations"],
            success=False,  # same as when maxiter is reached
            status=1,  # same as when L-BFGS-B reaches maxiter
            message=msg,
        )