# Source code for botorch.test_utils.mock

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Utilities for speeding up optimization in tests.

"""

from __future__ import annotations

from collections.abc import Generator
from contextlib import contextmanager, ExitStack
from functools import wraps
from typing import Any, Callable
from unittest import mock

from botorch.optim.initializers import (
    gen_batch_initial_conditions,
    gen_one_shot_kg_initial_conditions,
)
from botorch.optim.utils.timeout import minimize_with_timeout
from scipy.optimize import OptimizeResult
from torch import Tensor


@contextmanager
def mock_optimize_context_manager(
    force: bool = False,
) -> Generator[None, None, None]:
    """A context manager that uses mocks to speed up optimization for testing.

    Currently, the primary tactic is to force the underlying scipy methods to
    stop after just one iteration. While active, it also:

    * reduces `num_restarts` / `raw_samples` to 2 / 4 for the initial-condition
      generators used by `botorch.optim.optimize`, and
    * sets `MAX_ITER_ALTER`, `MAX_ITER_DISCRETE` and `MAX_ITER_CONT` in
      `botorch.optim.optimize_mixed` to 1.

    Args:
        force: If True, will not raise an AssertionError if no mocks are
            called. USE RESPONSIBLY.

    Yields:
        None. The patches are active for the duration of the `with` block.

    Raises:
        AssertionError: If `force` is False and none of the
            `minimize_with_timeout` or initial-condition mocks were called
            while the context was active.
    """

    def one_iteration_minimize(*args: Any, **kwargs: Any) -> OptimizeResult:
        # Work on a copy of the options so the forced `maxiter=1` cannot leak
        # into a dict owned by the caller (e.g. one reused across tests), and
        # use `.get` so a call that omits `options` entirely is also handled.
        options = dict(kwargs.get("options") or {})
        options["maxiter"] = 1
        kwargs["options"] = options
        return minimize_with_timeout(*args, **kwargs)

    def minimal_gen_ics(*args: Any, **kwargs: Any) -> Tensor:
        kwargs["num_restarts"] = 2
        kwargs["raw_samples"] = 4

        return gen_batch_initial_conditions(*args, **kwargs)

    def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Tensor | None:
        kwargs["num_restarts"] = 2
        kwargs["raw_samples"] = 4

        return gen_one_shot_kg_initial_conditions(*args, **kwargs)

    with ExitStack() as es:
        # Note this `minimize_with_timeout` is defined in optim.utils.timeout;
        # this mock only has an effect when calling a function used in
        # `botorch.generation.gen`, such as `gen_candidates_scipy`.
        mock_generation = es.enter_context(
            mock.patch(
                "botorch.generation.gen.minimize_with_timeout",
                wraps=one_iteration_minimize,
            )
        )

        # Similarly, works when calling a function defined in `optim.core`,
        # such as `scipy_minimize` and `torch_minimize`.
        mock_fit = es.enter_context(
            mock.patch(
                "botorch.optim.core.minimize_with_timeout",
                wraps=one_iteration_minimize,
            )
        )

        # Works when calling a function in `optim.optimize` such as
        # `optimize_acqf`.
        mock_gen_ics = es.enter_context(
            mock.patch(
                "botorch.optim.optimize.gen_batch_initial_conditions",
                wraps=minimal_gen_ics,
            )
        )

        # Works when calling a function in `optim.optimize` such as
        # `optimize_acqf`.
        mock_gen_os_ics = es.enter_context(
            mock.patch(
                "botorch.optim.optimize.gen_one_shot_kg_initial_conditions",
                wraps=minimal_gen_os_ics,
            )
        )

        # Reduce default number of iterations in `optimize_acqf_mixed_alternating`.
        for name in [
            "MAX_ITER_ALTER",
            "MAX_ITER_DISCRETE",
            "MAX_ITER_CONT",
        ]:
            es.enter_context(mock.patch(f"botorch.optim.optimize_mixed.{name}", new=1))

        yield

        # Only reached when the `with` body exits normally. The constant
        # patches above are not call-tracked, so they are not part of the check.
        tracked_mocks = (mock_generation, mock_fit, mock_gen_ics, mock_gen_os_ics)
        if not force and not any(mock_.called for mock_ in tracked_mocks):
            raise AssertionError(
                "No mocks were called in the context manager. Please remove unused "
                "mock_optimize_context_manager()."
            )
def mock_optimize(f: Callable) -> Callable:
    """Decorator form of `mock_optimize_context_manager`.

    Args:
        f: The callable (typically a test method) to wrap.

    Returns:
        A wrapper carrying `f`'s metadata (via `functools.wraps`). Each call
        runs `f` inside `mock_optimize_context_manager()` and returns `f`'s
        result.
    """

    def _run_with_fast_optimization(*args: Any, **kwargs: Any) -> Any:
        with mock_optimize_context_manager():
            result = f(*args, **kwargs)
        return result

    return wraps(f)(_run_with_fast_optimization)