Using a custom BoTorch model with Ax
In this tutorial, we illustrate how to use a custom BoTorch model within Ax's botorch_modular API. This allows us to harness the convenience of Ax for running Bayesian optimization loops while maintaining full flexibility in modeling.
Acquisition functions and their optimizers can be swapped out in much the same fashion; see, for example, the tutorial on implementing a custom acquisition function.
If you want to do something non-standard, or would like full insight into every aspect of the implementation, please see this tutorial for how to write your own full optimization loop in BoTorch.
import sys
import plotly.io as pio
if 'google.colab' in sys.modules:
    pio.renderers.default = "colab"
    # Ax is published on PyPI as "ax-platform"
    %pip install botorch ax-platform
else:
    pio.renderers.default = "png"
import os
from contextlib import contextmanager, nullcontext
from ax.utils.testing.mock import mock_botorch_optimize_context_manager
SMOKE_TEST = os.environ.get("SMOKE_TEST")
NUM_EVALS = 10 if SMOKE_TEST else 30
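Because os.environ.get returns a string (or None), any non-empty value of SMOKE_TEST, including "0" or "false", switches the notebook into smoke-test mode. A minimal sketch of that logic, using a hypothetical helper num_evals_for that mirrors the two lines above:

```python
import os
from typing import Mapping, Optional


def num_evals_for(env: Mapping[str, str]) -> int:
    """Mirror the tutorial's logic: any non-empty SMOKE_TEST value means 10 evaluations."""
    smoke_test: Optional[str] = env.get("SMOKE_TEST")
    return 10 if smoke_test else 30


# Unset or empty -> full run; any non-empty string (even "0") -> smoke test.
print(num_evals_for({}))                   # 30
print(num_evals_for({"SMOKE_TEST": ""}))   # 30
print(num_evals_for({"SMOKE_TEST": "0"}))  # 10
print(num_evals_for(os.environ))
```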
Implementing the custom model
For this tutorial, we implement a very simple GPyTorch ExactGP model that uses an RBF kernel (with ARD) and infers a homoskedastic noise level.
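With ARD (automatic relevance determination), each input dimension gets its own lengthscale. The covariance that ScaleKernel(RBFKernel(ard_num_dims=d)) computes can be written k(x, x') = σ² exp(-½ Σ_d (x_d - x'_d)² / ℓ_d²), where σ² is the output scale. A plain-Python sketch of that formula; the function name and the hyperparameter values are illustrative, since GPyTorch learns σ² and the ℓ_d during fitting:

```python
import math
from typing import Sequence


def ard_rbf(x: Sequence[float], y: Sequence[float],
            lengthscales: Sequence[float], outputscale: float = 1.0) -> float:
    """Scaled RBF kernel with one lengthscale per input dimension (ARD)."""
    if not (len(x) == len(y) == len(lengthscales)):
        raise ValueError("x, y and lengthscales must have the same length")
    sq_dist = sum(((xi - yi) / ell) ** 2 for xi, yi, ell in zip(x, y, lengthscales))
    return outputscale * math.exp(-0.5 * sq_dist)


# A long lengthscale makes the kernel nearly insensitive to that dimension.
print(ard_rbf([0.0, 0.0], [1.0, 0.0], lengthscales=[1.0, 100.0]))  # exp(-0.5), about 0.6065
print(ard_rbf([0.0, 0.0], [0.0, 1.0], lengthscales=[1.0, 100.0]))  # about 1.0
```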
Model definition is straightforward. Here we implement a GPyTorch ExactGP that inherits from GPyTorchModel; together, these two superclasses add all the API calls that BoTorch expects in its various modules.
Note: BoTorch allows implementing any custom model that follows the Model API. For more information, please see the Model Documentation.
from typing import Optional
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from torch import Tensor
class SimpleCustomGP(ExactGP, GPyTorchModel):
    _num_outputs = 1  # informs the GPyTorchModel API that this is a single-output model
    def __init__(self, train_X, train_Y, train_Yvar: Optional[Tensor] = None):
        # train_Yvar is accepted for API compatibility but ignored: the
        # GaussianLikelihood infers a single (homoskedastic) noise level.
        # ExactGP expects single-output targets of shape `n`, hence the squeeze.
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        # One lengthscale per input dimension (ARD), scaled by a learned outputscale.
        self.covar_module = ScaleKernel(
            base_kernel=RBFKernel(ard_num_dims=train_X.shape[-1]),
        )
        # Match the device and dtype of the training data.
        self.to(train_X)
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
Instantiate a BoTorchModel in Ax
A BoTorchModel in Ax encapsulates both the surrogate (which Ax calls a Surrogate and BoTorch calls a Model) and an acquisition function. Here, we only specify the custom surrogate and let Ax choose the default acquisition function.
Most models should work with the base Surrogate in Ax, except for BoTorch ModelListGP, which works with ListSurrogate. Note that the Model (e.g., the SimpleCustomGP) must implement construct_inputs, as this is used to construct the inputs required for instantiating a Model instance from the experiment data. SimpleCustomGP does not define it explicitly: it inherits the default from BoTorch's Model base class, which forwards the training data as train_X, train_Y and, when noise observations are available, train_Yvar. This is why the constructor above accepts a train_Yvar argument.
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_modular.utils import ModelConfig
ax_model = BoTorchModel(
    surrogate=Surrogate(
        surrogate_spec=SurrogateSpec(
            model_configs=[
                ModelConfig(
                    botorch_model_class=SimpleCustomGP,
                    input_transform_classes=None,
                )
            ]
        )
    ),
)
Combine with a ModelBridge
Models in Ax require a ModelBridge to interface with Experiments. A ModelBridge takes the inputs supplied by the Experiment and converts them to the inputs expected by the Model. For a BoTorchModel, we use TorchModelBridge. The Modular BoTorch interface creates the BoTorchModel and the TorchModelBridge in a single step, as follows:
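from ax.modelbridge.registry import Models
# `experiment` and `data` stand for an existing Ax Experiment and its fetched Data
# (e.g., `exp` and `exp.fetch_data()` in the Developer API section below).
model_bridge = Models.BOTORCH_MODULAR(
    experiment=experiment,
    data=data,
    surrogate_spec=SurrogateSpec(
        model_configs=[ModelConfig(botorch_model_class=SimpleCustomGP, input_transform_classes=None)]
    ),
    # Optional, will use default if unspecified
    # botorch_acqf_class=qLogNoisyExpectedImprovement,
)
# Generate one candidate (a GeneratorRun), which can then be attached to a new trial
generator_run = model_bridge.gen(1)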
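Conceptually, part of a ModelBridge's job is to turn the Experiment's parameter dictionaries into the numeric arrays a BoTorch model consumes, and to map candidates back. The plain-Python sketch below illustrates only one such step, rescaling range parameters to the unit cube (similar in spirit to Ax's UnitX transform); the helpers to_unit_cube and from_unit_cube are hypothetical, and the real ModelBridge applies several more transforms.

```python
from typing import Dict, List, Tuple

Bounds = Dict[str, Tuple[float, float]]


def to_unit_cube(params: Dict[str, float], bounds: Bounds) -> List[float]:
    """Map a parameter dict to a feature vector in [0, 1]^d (order follows `bounds`)."""
    return [(params[name] - lo) / (hi - lo) for name, (lo, hi) in bounds.items()]


def from_unit_cube(z: List[float], bounds: Bounds) -> Dict[str, float]:
    """Inverse map: a point in [0, 1]^d back to a parameter dict."""
    return {name: lo + zi * (hi - lo) for zi, (name, (lo, hi)) in zip(z, bounds.items())}


branin_bounds: Bounds = {"x1": (-5.0, 10.0), "x2": (0.0, 15.0)}
z = to_unit_cube({"x1": 2.5, "x2": 15.0}, branin_bounds)
print(z)                                 # [0.5, 1.0]
print(from_unit_cube(z, branin_bounds))  # {'x1': 2.5, 'x2': 15.0}
```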
Using the custom model in Ax to optimize the Branin function
We will demonstrate this with both the Service API (simpler, easier to use) and the Developer API (advanced, more customizable).
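For reference, the Branin function that botorch.test_functions.Branin evaluates is f(x1, x2) = a(x2 - b·x1² + c·x1 - r)² + s(1 - t)cos(x1) + s, with a = 1, b = 5.1/(4π²), c = 5/π, r = 6, s = 10 and t = 1/(8π), on x1 ∈ [-5, 10], x2 ∈ [0, 15]. Its global minimum, 0.397887, is attained at three points, (-π, 12.275), (π, 2.275) and (9.42478, 2.475), which is why the BoTorch trials in the logs cluster around three separate regions. A dependency-free sketch for sanity-checking logged values (the name branin_reference is illustrative):

```python
import math


def branin_reference(x1: float, x2: float) -> float:
    """Standard (unnormalized) Branin function, as minimized in this tutorial."""
    a, r, s = 1.0, 6.0, 10.0
    b = 5.1 / (4.0 * math.pi ** 2)
    c = 5.0 / math.pi
    t = 1.0 / (8.0 * math.pi)
    return a * (x2 - b * x1 ** 2 + c * x1 - r) ** 2 + s * (1.0 - t) * math.cos(x1) + s


# The three global minimizers all give f of about 0.397887.
for x1, x2 in [(-math.pi, 12.275), (math.pi, 2.275), (9.42478, 2.475)]:
    print(round(branin_reference(x1, x2), 6))
```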
Optimization with Ax's Service API
A detailed tutorial on the Service API can be found here.
In order to customize the way the candidates are created in the Service API, we need to construct a new GenerationStrategy and pass it into AxClient.
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
gs = GenerationStrategy(
    steps=[
        # Quasi-random initialization: the first 5 trials come from Sobol.
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,
        ),
        # BO with the custom surrogate; -1 means no limit on the number of trials.
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,
            model_kwargs={
                "surrogate_spec": SurrogateSpec(
                    model_configs=[ModelConfig(botorch_model_class=SimpleCustomGP, input_transform_classes=None)]
                ),
            },
        ),
    ]
)
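The two GenerationSteps above encode a simple schedule: Sobol generates the first five trials, and num_trials=-1 lets the BoTorch step generate every trial after that. A plain-Python sketch of just that bookkeeping; the helper model_for_trial and the (name, num_trials) tuples are hypothetical and not how Ax implements GenerationStrategy. It reproduces the switch from "using model Sobol" to "using model BoTorch" at trial 5 in the logs below.

```python
from typing import List, Tuple

# (model name, number of trials); -1 means "no limit", as in GenerationStep.
Schedule = List[Tuple[str, int]]


def model_for_trial(trial_index: int, steps: Schedule) -> str:
    """Return which step's model would generate the given (0-based) trial."""
    start = 0
    for name, num_trials in steps:
        if num_trials == -1 or trial_index < start + num_trials:
            return name
        start += num_trials
    raise ValueError(f"no generation step covers trial {trial_index}")


steps: Schedule = [("Sobol", 5), ("BoTorch", -1)]
print([model_for_trial(i, steps) for i in range(7)])
# ['Sobol', 'Sobol', 'Sobol', 'Sobol', 'Sobol', 'BoTorch', 'BoTorch']
```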
Setting up the experiment
In order to use the GenerationStrategy we just created, we will pass it into the AxClient.
import torch
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from botorch.test_functions import Branin
ax_client = AxClient(generation_strategy=gs)
ax_client.create_experiment(
    name="branin_test_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            "bounds": [-5.0, 10.0],
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 15.0],
        },
    ],
    objectives={
        "branin": ObjectiveProperties(minimize=True),
    },
)
branin = Branin()
def evaluate(parameters):
    x = torch.tensor([[parameters.get(f"x{i+1}") for i in range(2)]])
    # Report (mean, SEM); a NaN SEM tells Ax that the noise level is unknown.
    return {"branin": (branin(x).item(), float("nan"))}
Out: [INFO 12-02 15:16:19] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the verbose_logging argument to False. Note that float values in the logs are rounded to 6 decimal points.
[INFO 12-02 15:16:19] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-02 15:16:19] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-02 15:16:19] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).
Running the BO loop
The next cell sets up a context manager solely to speed up the testing of the notebook in SMOKE_TEST mode. You can safely ignore this cell and the use of the context manager throughout the tutorial.
if SMOKE_TEST:
    fast_smoke_test = mock_botorch_optimize_context_manager
else:
    fast_smoke_test = nullcontext
torch.manual_seed(0)
Out: <torch._C.Generator at 0x12db71a50>
with fast_smoke_test():
    for i in range(NUM_EVALS):
        parameters, trial_index = ax_client.get_next_trial()
        ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))
Out: /opt/anaconda3/envs/botorch/lib/python3.10/site-packages/ax/modelbridge/cross_validation.py:439: UserWarning:
Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.
[INFO 12-02 15:16:19] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 0.62583, 'x2': 14.359564} using model Sobol.
[INFO 12-02 15:16:19] ax.service.ax_client: Completed trial 0 with data: {'branin': (104.365417, nan)}.
/opt/anaconda3/envs/botorch/lib/python3.10/site-packages/ax/modelbridge/cross_validation.py:439: UserWarning:
Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.
[INFO 12-02 15:16:19] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 3.166217, 'x2': 3.867106} using model Sobol.
[INFO 12-02 15:16:19] ax.service.ax_client: Completed trial 1 with data: {'branin': (2.996862, nan)}.
/opt/anaconda3/envs/botorch/lib/python3.10/site-packages/ax/modelbridge/cross_validation.py:439: UserWarning:
Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.
[INFO 12-02 15:16:19] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 9.560105, 'x2': 10.718323} using model Sobol.
[INFO 12-02 15:16:19] ax.service.ax_client: Completed trial 2 with data: {'branin': (66.530624, nan)}.
/opt/anaconda3/envs/botorch/lib/python3.10/site-packages/ax/modelbridge/cross_validation.py:439: UserWarning:
Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.
[INFO 12-02 15:16:19] ax.service.ax_client: Generated new trial 3 with parameters {'x1': -3.878664, 'x2': 0.117947} using model Sobol.
[INFO 12-02 15:16:19] ax.service.ax_client: Completed trial 3 with data: {'branin': (198.850861, nan)}.
/opt/anaconda3/envs/botorch/lib/python3.10/site-packages/ax/modelbridge/cross_validation.py:439: UserWarning:
Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.
[INFO 12-02 15:16:19] ax.service.ax_client: Generated new trial 4 with parameters {'x1': -2.362858, 'x2': 8.855021} using model Sobol.
[INFO 12-02 15:16:19] ax.service.ax_client: Completed trial 4 with data: {'branin': (5.811776, nan)}.
[INFO 12-02 15:16:20] ax.service.ax_client: Generated new trial 5 with parameters {'x1': 2.562432, 'x2': 4.925782} using model BoTorch.
[INFO 12-02 15:16:20] ax.service.ax_client: Completed trial 5 with data: {'branin': (6.611189, nan)}.
[INFO 12-02 15:16:20] ax.service.ax_client: Generated new trial 6 with parameters {'x1': 5.504102, 'x2': 4.955757} using model BoTorch.
[INFO 12-02 15:16:20] ax.service.ax_client: Completed trial 6 with data: {'branin': (31.288765, nan)}.
[INFO 12-02 15:16:21] ax.service.ax_client: Generated new trial 7 with parameters {'x1': -2.306774, 'x2': 4.433857} using model BoTorch.
[INFO 12-02 15:16:21] ax.service.ax_client: Completed trial 7 with data: {'branin': (38.658497, nan)}.
[INFO 12-02 15:16:21] ax.service.ax_client: Generated new trial 8 with parameters {'x1': -1.597063, 'x2': 7.317824} using model BoTorch.
[INFO 12-02 15:16:21] ax.service.ax_client: Completed trial 8 with data: {'branin': (12.161115, nan)}.
[INFO 12-02 15:16:21] ax.service.ax_client: Generated new trial 9 with parameters {'x1': -5.0, 'x2': 9.07374} using model BoTorch.
[INFO 12-02 15:16:21] ax.service.ax_client: Completed trial 9 with data: {'branin': (78.554581, nan)}.
[INFO 12-02 15:16:22] ax.service.ax_client: Generated new trial 10 with parameters {'x1': 0.768745, 'x2': 6.898975} using model BoTorch.
[INFO 12-02 15:16:22] ax.service.ax_client: Completed trial 10 with data: {'branin': (21.088478, nan)}.
[INFO 12-02 15:16:22] ax.service.ax_client: Generated new trial 11 with parameters {'x1': 1.619218, 'x2': 0.603319} using model BoTorch.
[INFO 12-02 15:16:22] ax.service.ax_client: Completed trial 11 with data: {'branin': (19.510229, nan)}.
[INFO 12-02 15:16:22] ax.service.ax_client: Generated new trial 12 with parameters {'x1': 10.0, 'x2': 0.0} using model BoTorch.
[INFO 12-02 15:16:22] ax.service.ax_client: Completed trial 12 with data: {'branin': (10.960894, nan)}.
[INFO 12-02 15:16:22] ax.service.ax_client: Generated new trial 13 with parameters {'x1': 7.386909, 'x2': 0.0} using model BoTorch.
[INFO 12-02 15:16:22] ax.service.ax_client: Completed trial 13 with data: {'branin': (15.994156, nan)}.
[INFO 12-02 15:16:23] ax.service.ax_client: Generated new trial 14 with parameters {'x1': 4.187552, 'x2': 0.0} using model BoTorch.
[INFO 12-02 15:16:23] ax.service.ax_client: Completed trial 14 with data: {'branin': (7.750666, nan)}.
[INFO 12-02 15:16:23] ax.service.ax_client: Generated new trial 15 with parameters {'x1': -3.93692, 'x2': 15.0} using model BoTorch.
[INFO 12-02 15:16:23] ax.service.ax_client: Completed trial 15 with data: {'branin': (3.813743, nan)}.
[INFO 12-02 15:16:23] ax.service.ax_client: Generated new trial 16 with parameters {'x1': -3.32116, 'x2': 12.384737} using model BoTorch.
[INFO 12-02 15:16:23] ax.service.ax_client: Completed trial 16 with data: {'branin': (0.658538, nan)}.
[INFO 12-02 15:16:24] ax.service.ax_client: Generated new trial 17 with parameters {'x1': 10.0, 'x2': 3.669781} using model BoTorch.
[INFO 12-02 15:16:24] ax.service.ax_client: Completed trial 17 with data: {'branin': (2.387794, nan)}.
[INFO 12-02 15:16:24] ax.service.ax_client: Generated new trial 18 with parameters {'x1': 9.341568, 'x2': 2.543665} using model BoTorch.
[INFO 12-02 15:16:24] ax.service.ax_client: Completed trial 18 with data: {'branin': (0.450142, nan)}.
[INFO 12-02 15:16:25] ax.service.ax_client: Generated new trial 19 with parameters {'x1': 3.076194, 'x2': 2.421518} using model BoTorch.
[INFO 12-02 15:16:25] ax.service.ax_client: Completed trial 19 with data: {'branin': (0.427432, nan)}.
[INFO 12-02 15:16:25] ax.service.ax_client: Generated new trial 20 with parameters {'x1': 9.536534, 'x2': 2.494184} using model BoTorch.
[INFO 12-02 15:16:25] ax.service.ax_client: Completed trial 20 with data: {'branin': (0.463669, nan)}.
[INFO 12-02 15:16:26] ax.service.ax_client: Generated new trial 21 with parameters {'x1': -3.365433, 'x2': 15.0} using model BoTorch.
[INFO 12-02 15:16:26] ax.service.ax_client: Completed trial 21 with data: {'branin': (5.392394, nan)}.
[INFO 12-02 15:16:26] ax.service.ax_client: Generated new trial 22 with parameters {'x1': 9.496421, 'x2': 2.820387} using model BoTorch.
[INFO 12-02 15:16:26] ax.service.ax_client: Completed trial 22 with data: {'branin': (0.50334, nan)}.
[INFO 12-02 15:16:27] ax.service.ax_client: Generated new trial 23 with parameters {'x1': 3.217649, 'x2': 2.439988} using model BoTorch.
[INFO 12-02 15:16:27] ax.service.ax_client: Completed trial 23 with data: {'branin': (0.475622, nan)}.
[INFO 12-02 15:16:27] ax.service.ax_client: Generated new trial 24 with parameters {'x1': 9.487907, 'x2': 2.245164} using model BoTorch.
[INFO 12-02 15:16:27] ax.service.ax_client: Completed trial 24 with data: {'branin': (0.497444, nan)}.
[INFO 12-02 15:16:28] ax.service.ax_client: Generated new trial 25 with parameters {'x1': 9.635134, 'x2': 2.550312} using model BoTorch.
[INFO 12-02 15:16:28] ax.service.ax_client: Completed trial 25 with data: {'branin': (0.62118, nan)}.
[INFO 12-02 15:16:29] ax.service.ax_client: Generated new trial 26 with parameters {'x1': -3.229368, 'x2': 12.309249} using model BoTorch.
[INFO 12-02 15:16:29] ax.service.ax_client: Completed trial 26 with data: {'branin': (0.466428, nan)}.
[INFO 12-02 15:16:30] ax.service.ax_client: Generated new trial 27 with parameters {'x1': 2.970411, 'x2': 2.46022} using model BoTorch.
[INFO 12-02 15:16:30] ax.service.ax_client: Completed trial 27 with data: {'branin': (0.540529, nan)}.
[INFO 12-02 15:16:31] ax.service.ax_client: Generated new trial 28 with parameters {'x1': 9.462445, 'x2': 2.493521} using model BoTorch.
[INFO 12-02 15:16:31] ax.service.ax_client: Completed trial 28 with data: {'branin': (0.404879, nan)}.
[INFO 12-02 15:16:32] ax.service.ax_client: Generated new trial 29 with parameters {'x1': 9.568048, 'x2': 2.731275} using model BoTorch.
[INFO 12-02 15:16:32] ax.service.ax_client: Completed trial 29 with data: {'branin': (0.513895, nan)}.
Viewing the evaluated trials
ax_client.get_trials_data_frame()
|  | trial_index | arm_name | trial_status | generation_method | branin | x1 | x2 |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | COMPLETED | Sobol | 104.365 | 0.62583 | 14.3596 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | 2.99686 | 3.16622 | 3.86711 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | 66.5306 | 9.56011 | 10.7183 |
| 3 | 3 | 3_0 | COMPLETED | Sobol | 198.851 | -3.87866 | 0.117947 |
| 4 | 4 | 4_0 | COMPLETED | Sobol | 5.81178 | -2.36286 | 8.85502 |
| 5 | 5 | 5_0 | COMPLETED | BoTorch | 6.61119 | 2.56243 | 4.92578 |
| 6 | 6 | 6_0 | COMPLETED | BoTorch | 31.2888 | 5.5041 | 4.95576 |
| 7 | 7 | 7_0 | COMPLETED | BoTorch | 38.6585 | -2.30677 | 4.43386 |
| 8 | 8 | 8_0 | COMPLETED | BoTorch | 12.1611 | -1.59706 | 7.31782 |
| 9 | 9 | 9_0 | COMPLETED | BoTorch | 78.5546 | -5 | 9.07374 |
| 10 | 10 | 10_0 | COMPLETED | BoTorch | 21.0885 | 0.768745 | 6.89898 |
| 11 | 11 | 11_0 | COMPLETED | BoTorch | 19.5102 | 1.61922 | 0.603319 |
| 12 | 12 | 12_0 | COMPLETED | BoTorch | 10.9609 | 10 | 0 |
| 13 | 13 | 13_0 | COMPLETED | BoTorch | 15.9942 | 7.38691 | 0 |
| 14 | 14 | 14_0 | COMPLETED | BoTorch | 7.75067 | 4.18755 | 0 |
| 15 | 15 | 15_0 | COMPLETED | BoTorch | 3.81374 | -3.93692 | 15 |
| 16 | 16 | 16_0 | COMPLETED | BoTorch | 0.658538 | -3.32116 | 12.3847 |
| 17 | 17 | 17_0 | COMPLETED | BoTorch | 2.38779 | 10 | 3.66978 |
| 18 | 18 | 18_0 | COMPLETED | BoTorch | 0.450142 | 9.34157 | 2.54366 |
| 19 | 19 | 19_0 | COMPLETED | BoTorch | 0.427432 | 3.07619 | 2.42152 |
| 20 | 20 | 20_0 | COMPLETED | BoTorch | 0.463669 | 9.53653 | 2.49418 |
| 21 | 21 | 21_0 | COMPLETED | BoTorch | 5.39239 | -3.36543 | 15 |
| 22 | 22 | 22_0 | COMPLETED | BoTorch | 0.50334 | 9.49642 | 2.82039 |
| 23 | 23 | 23_0 | COMPLETED | BoTorch | 0.475622 | 3.21765 | 2.43999 |
| 24 | 24 | 24_0 | COMPLETED | BoTorch | 0.497444 | 9.48791 | 2.24516 |
| 25 | 25 | 25_0 | COMPLETED | BoTorch | 0.62118 | 9.63513 | 2.55031 |
| 26 | 26 | 26_0 | COMPLETED | BoTorch | 0.466428 | -3.22937 | 12.3092 |
| 27 | 27 | 27_0 | COMPLETED | BoTorch | 0.540529 | 2.97041 | 2.46022 |
| 28 | 28 | 28_0 | COMPLETED | BoTorch | 0.404879 | 9.46245 | 2.49352 |
| 29 | 29 | 29_0 | COMPLETED | BoTorch | 0.513895 | 9.56805 | 2.73128 |
parameters, values = ax_client.get_best_parameters()
print(f"Best parameters: {parameters}")
print(f"Corresponding mean: {values[0]}, covariance: {values[1]}")
Out: Best parameters: {'x1': 9.536533793530687, 'x2': 2.494184288965757}
Corresponding mean: {'branin': 0.41573114009047885}, covariance: {'branin': {'branin': 0.025883709201637482}}
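Note that get_best_parameters reports the parameters of trial 20 with a mean of 0.4157, whereas trial 20's raw observation in the table is 0.463669 and the lowest raw observation is trial 28 (0.404879). This suggests the reported best point and its mean come from the surrogate's predictions rather than from the raw data. To look up the best raw observation yourself, a plain-Python sketch over (trial_index, value) pairs copied from the table above (best_observed is a hypothetical helper):

```python
from typing import List, Tuple


def best_observed(rows: List[Tuple[int, float]]) -> Tuple[int, float]:
    """Return (trial_index, value) of the lowest observed objective (minimization)."""
    if not rows:
        raise ValueError("no observations")
    return min(rows, key=lambda row: row[1])


# A few (trial_index, branin) pairs from the table above.
rows = [(18, 0.450142), (19, 0.427432), (20, 0.463669), (26, 0.466428), (28, 0.404879)]
print(best_observed(rows))  # (28, 0.404879)
```

With the DataFrame returned by get_trials_data_frame(), the pandas equivalent is df.loc[df["branin"].idxmin()].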
Plotting the response surface and optimization progress
from ax.utils.notebook.plotting import render
render(ax_client.get_contour_plot())
Out: [INFO 12-02 15:16:32] ax.service.ax_client: Retrieving contour plot with parameter 'x1' on X-axis and 'x2' on Y-axis, for metric 'branin'. Remaining parameters are affixed to the middle of their range.
best_parameters, values = ax_client.get_best_parameters()
best_parameters, values[0]
Out: ({'x1': 9.536533793530687, 'x2': 2.494184288965757},
{'branin': 0.41573114009047885})
render(ax_client.get_optimization_trace(objective_optimum=0.397887))
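Roughly speaking, the optimization trace shows the best objective value found so far after each trial, and objective_optimum=0.397887 marks Branin's known minimum for comparison. A plain-Python sketch of the underlying quantities, the running minimum and the simple regret (gap to the optimum); the helper names are hypothetical:

```python
from itertools import accumulate
from typing import List

BRANIN_OPTIMUM = 0.397887


def best_so_far(values: List[float]) -> List[float]:
    """Running minimum of observed objective values (minimization)."""
    return list(accumulate(values, min))


def simple_regret(values: List[float], optimum: float = BRANIN_OPTIMUM) -> List[float]:
    """Gap between the incumbent and the known optimum after each trial."""
    return [v - optimum for v in best_so_far(values)]


# Observed values of trials 0-5 from the table above.
observed = [104.365417, 2.996862, 66.530624, 198.850861, 5.811776, 6.611189]
print(best_so_far(observed))
print(simple_regret(observed))
```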
Optimization with the Developer API
A detailed tutorial on the Developer API can be found here.
Set up the Experiment in Ax
We need 3 inputs for an Ax Experiment:
- A search space to optimize over;
- An optimization config specifying the objective / metrics to optimize, and optional outcome constraints;
- A runner that handles the deployment of trials. For a synthetic optimization problem, such as here, this only returns simple metadata about the trial.
import pandas as pd
import torch
from ax import (
    Data,
    Experiment,
    Metric,
    Objective,
    OptimizationConfig,
    ParameterType,
    RangeParameter,
    Runner,
    SearchSpace,
)
from ax.utils.common.result import Ok
from botorch.test_functions import Branin
branin_func = Branin()
class BraninMetric(Metric):
    def fetch_trial_data(self, trial, **kwargs):
        records = []
        for arm_name, arm in trial.arms_by_name.items():
            params = arm.parameters
            tensor_params = torch.tensor([params["x1"], params["x2"]])
            records.append(
                {
                    "arm_name": arm_name,
                    "metric_name": self.name,
                    "trial_index": trial.index,
                    # Convert the 0-dim tensor to a Python float for the DataFrame.
                    "mean": branin_func(tensor_params).item(),
                    # NaN SEM: the observation noise is unknown and will be inferred.
                    "sem": float("nan"),
                }
            )
        return Ok(value=Data(df=pd.DataFrame.from_records(records)))
search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name="x1", parameter_type=ParameterType.FLOAT, lower=-5, upper=10
        ),
        RangeParameter(
            name="x2", parameter_type=ParameterType.FLOAT, lower=0, upper=15
        ),
    ]
)
optimization_config = OptimizationConfig(
    objective=Objective(
        metric=BraninMetric(name="branin_metric", lower_is_better=True),
        minimize=True,
    )
)
class MyRunner(Runner):
    def run(self, trial):
        # Synthetic problem: nothing to dispatch, just return some metadata.
        trial_metadata = {"name": str(trial.index)}
        return trial_metadata
exp = Experiment(
    name="branin_experiment",
    search_space=search_space,
    optimization_config=optimization_config,
    runner=MyRunner(),
)
Run the BO loop
First, we use the Sobol generator to create 5 (quasi-)random initial points in the search space. Ax controls objective evaluations via Trials.
- We generate a Trial using a generator run, e.g., Sobol below. A Trial specifies relevant metadata as well as the parameters to be evaluated. At this point, the Trial is at the CANDIDATE stage.
- We run the Trial using Trial.run(). In our example, this serves to mark the Trial as RUNNING. In an advanced application, this can be used to dispatch the Trial for evaluation on a remote server.
- Once the Trial is done running, we mark it as COMPLETED. This tells the Experiment that it can fetch the Trial data.
A Trial supports evaluation of a single parameterization. For parallel evaluations, see BatchTrial.
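The CANDIDATE, RUNNING, COMPLETED lifecycle described above is essentially a small state machine. The sketch below models only those three states with hypothetical ToyTrial and Status classes; it is deliberately simplified, since real Ax trials have further states (such as FAILED and ABANDONED) and their own transition rules.

```python
from enum import Enum


class Status(Enum):
    CANDIDATE = "CANDIDATE"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"


# Allowed transitions in the simplified lifecycle described above.
ALLOWED = {
    Status.CANDIDATE: {Status.RUNNING},
    Status.RUNNING: {Status.COMPLETED},
    Status.COMPLETED: set(),
}


class ToyTrial:
    """Hypothetical stand-in for an Ax Trial; tracks status only."""

    def __init__(self, index: int, parameters: dict) -> None:
        self.index = index
        self.parameters = parameters
        self.status = Status.CANDIDATE  # newly generated trials start here

    def _move_to(self, new: Status) -> None:
        if new not in ALLOWED[self.status]:
            raise RuntimeError(f"cannot go from {self.status.name} to {new.name}")
        self.status = new

    def run(self) -> "ToyTrial":
        self._move_to(Status.RUNNING)
        return self

    def mark_completed(self) -> "ToyTrial":
        self._move_to(Status.COMPLETED)
        return self


toy = ToyTrial(0, {"x1": 0.5, "x2": 2.0})
toy.run()
toy.mark_completed()
print(toy.status)  # Status.COMPLETED
```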
from ax.modelbridge.registry import Models
sobol = Models.SOBOL(exp.search_space)
for i in range(5):
    trial = exp.new_trial(generator_run=sobol.gen(1))
    trial.run()
    trial.mark_completed()
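Quasi-random sequences spread initial points more evenly over the search space than independent uniform draws. The sketch below swaps in a simpler low-discrepancy sequence, a 2-D Halton sequence with bases 2 and 3, rather than the Sobol sequence Ax uses, and scales the points to the Branin bounds; radical_inverse and halton_points are hypothetical helpers.

```python
from typing import List, Sequence, Tuple


def radical_inverse(i: int, base: int) -> float:
    """Van der Corput radical inverse: reflect the base-`base` digits of i about the point."""
    result, f = 0.0, 1.0 / base
    while i > 0:
        i, digit = divmod(i, base)
        result += digit * f
        f /= base
    return result


def halton_points(n: int, bounds: Sequence[Tuple[float, float]],
                  bases: Sequence[int] = (2, 3)) -> List[List[float]]:
    """First n Halton points (skipping the origin), scaled to the box `bounds`."""
    if len(bounds) > len(bases):
        raise ValueError("need one base per dimension")
    return [
        [lo + radical_inverse(i, b) * (hi - lo) for (lo, hi), b in zip(bounds, bases)]
        for i in range(1, n + 1)
    ]


for point in halton_points(5, [(-5.0, 10.0), (0.0, 15.0)]):
    print(point)
```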
Once the initial (quasi-)random stage is completed, we can use our SimpleCustomGP with the default acquisition function chosen by Ax to run the BO loop.
with fast_smoke_test():
    for i in range(NUM_EVALS - 5):
        # Refit the surrogate on all data collected so far, then generate one candidate.
        model_bridge = Models.BOTORCH_MODULAR(
            experiment=exp,
            data=exp.fetch_data(),
            surrogate_spec=SurrogateSpec(
                model_configs=[ModelConfig(botorch_model_class=SimpleCustomGP, input_transform_classes=None)]
            ),
        )
        trial = exp.new_trial(generator_run=model_bridge.gen(1))
        trial.run()
        trial.mark_completed()
Out: /opt/anaconda3/envs/botorch/lib/python3.10/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning:
A not p.d., added jitter of 1.0e-08 to the diagonal
/Users/saitcakmak/botorch/botorch/optim/optimize.py:576: RuntimeWarning:
Optimization failed in gen_candidates_scipy with the following warning(s):
[OptimizationWarning('Optimization failed within scipy.optimize.minimize with status 2 and message ABNORMAL_TERMINATION_IN_LNSRCH.')]
Trying again with a new set of initial conditions.
View the trials attached to the Experiment.
exp.trials
Out: {0: Trial(experiment_name='branin_experiment', index=0, status=TrialStatus.COMPLETED, arm=Arm(name='0_0', parameters={'x1': -3.001730814576149, 'x2': 10.747691988945007})),
1: Trial(experiment_name='branin_experiment', index=1, status=TrialStatus.COMPLETED, arm=Arm(name='1_0', parameters={'x1': 5.38613950368017, 'x2': 3.3136763563379645})),
2: Trial(experiment_name='branin_experiment', index=2, status=TrialStatus.COMPLETED, arm=Arm(name='2_0', parameters={'x1': 7.6478424202650785, 'x2': 13.54916384909302})),
3: Trial(experiment_name='branin_experiment', index=3, status=TrialStatus.COMPLETED, arm=Arm(name='3_0', parameters={'x1': -0.9862332977354527, 'x2': 6.144560454413295})),
4: Trial(experiment_name='branin_experiment', index=4, status=TrialStatus.COMPLETED, arm=Arm(name='4_0', parameters={'x1': 0.9829527512192726, 'x2': 11.45068815909326})),
5: Trial(experiment_name='branin_experiment', index=5, status=TrialStatus.COMPLETED, arm=Arm(name='5_0', parameters={'x1': -5.0, 'x2': 10.681135484818025})),
6: Trial(experiment_name='branin_experiment', index=6, status=TrialStatus.COMPLETED, arm=Arm(name='6_0', parameters={'x1': 0.9771782267390936, 'x2': 0.0})),
7: Trial(experiment_name='branin_experiment', index=7, status=TrialStatus.COMPLETED, arm=Arm(name='7_0', parameters={'x1': 10.0, 'x2': 0.0})),
8: Trial(experiment_name='branin_experiment', index=8, status=TrialStatus.COMPLETED, arm=Arm(name='8_0', parameters={'x1': -5.0, 'x2': 3.5936134273014724})),
9: Trial(experiment_name='branin_experiment', index=9, status=TrialStatus.COMPLETED, arm=Arm(name='9_0', parameters={'x1': 7.451191634850801, 'x2': 0.0})),
10: Trial(experiment_name='branin_experiment', index=10, status=TrialStatus.COMPLETED, arm=Arm(name='10_0', parameters={'x1': -2.989790970817857, 'x2': 14.507362023858969})),
11: Trial(experiment_name='branin_experiment', index=11, status=TrialStatus.COMPLETED, arm=Arm(name='11_0', parameters={'x1': -2.206908749062847, 'x2': 11.361635411408425})),
12: Trial(experiment_name='branin_experiment', index=12, status=TrialStatus.COMPLETED, arm=Arm(name='12_0', parameters={'x1': 10.0, 'x2': 3.854754557203238})),
13: Trial(experiment_name='branin_experiment', index=13, status=TrialStatus.COMPLETED, arm=Arm(name='13_0', parameters={'x1': 8.879169356973378, 'x2': 2.600188907289197})),
14: Trial(experiment_name='branin_experiment', index=14, status=TrialStatus.COMPLETED, arm=Arm(name='14_0', parameters={'x1': 10.0, 'x2': 2.3822346114805195})),
15: Trial(experiment_name='branin_experiment', index=15, status=TrialStatus.COMPLETED, arm=Arm(name='15_0', parameters={'x1': -4.3121547230008925, 'x2': 15.0})),
16: Trial(experiment_name='branin_experiment', index=16, status=TrialStatus.COMPLETED, arm=Arm(name='16_0', parameters={'x1': -3.329959894641452, 'x2': 12.717573331024465})),
17: Trial(experiment_name='branin_experiment', index=17, status=TrialStatus.COMPLETED, arm=Arm(name='17_0', parameters={'x1': 2.7242595961882197, 'x2': 2.9602941886720284})),
18: Trial(experiment_name='branin_experiment', index=18, status=TrialStatus.COMPLETED, arm=Arm(name='18_0', parameters={'x1': 3.649650277984131, 'x2': 0.0})),
19: Trial(experiment_name='branin_experiment', index=19, status=TrialStatus.COMPLETED, arm=Arm(name='19_0', parameters={'x1': 9.46343663901485, 'x2': 2.5467470674060593})),
20: Trial(experiment_name='branin_experiment', index=20, status=TrialStatus.COMPLETED, arm=Arm(name='20_0', parameters={'x1': 3.0104763473247154, 'x2': 1.9945479954698144})),
21: Trial(experiment_name='branin_experiment', index=21, status=TrialStatus.COMPLETED, arm=Arm(name='21_0', parameters={'x1': -3.203719761808281, 'x2': 12.37041540884107})),
22: Trial(experiment_name='branin_experiment', index=22, status=TrialStatus.COMPLETED, arm=Arm(name='22_0', parameters={'x1': 9.476541680994133, 'x2': 2.9065640930047554})),
23: Trial(experiment_name='branin_experiment', index=23, status=TrialStatus.COMPLETED, arm=Arm(name='23_0', parameters={'x1': 2.734598996972391, 'x2': 2.1218033122913322})),
24: Trial(experiment_name='branin_experiment', index=24, status=TrialStatus.COMPLETED, arm=Arm(name='24_0', parameters={'x1': 3.2195726883400404, 'x2': 2.136557689490957})),
25: Trial(experiment_name='branin_experiment', index=25, status=TrialStatus.COMPLETED, arm=Arm(name='25_0', parameters={'x1': 9.42953989277149, 'x2': 2.2711877578623016})),
26: Trial(experiment_name='branin_experiment', index=26, status=TrialStatus.COMPLETED, arm=Arm(name='26_0', parameters={'x1': 3.1637517854500388, 'x2': 2.5217242210404827})),
27: Trial(experiment_name='branin_experiment', index=27, status=TrialStatus.COMPLETED, arm=Arm(name='27_0', parameters={'x1': -3.1115663937782836, 'x2': 12.125534390814288})),
28: Trial(experiment_name='branin_experiment', index=28, status=TrialStatus.COMPLETED, arm=Arm(name='28_0', parameters={'x1': 9.517886960473454, 'x2': 2.660053049906566})),
29: Trial(experiment_name='branin_experiment', index=29, status=TrialStatus.COMPLETED, arm=Arm(name='29_0', parameters={'x1': 3.3039667073151158, 'x2': 1.9687756455552108}))}
View the evaluation data about these trials.
| | arm_name | metric_name | mean | sem | trial_index |
|---|---|---|---|---|---|
| 0 | 0_0 | branin_metric | 1.9166 | nan | 0 |
| 1 | 1_0 | branin_metric | 20.5632 | nan | 1 |
| 2 | 2_0 | branin_metric | 159.956 | nan | 2 |
| 3 | 3_0 | branin_metric | 17.7035 | nan | 3 |
| 4 | 4_0 | branin_metric | 62.8011 | nan | 4 |
| 5 | 5_0 | branin_metric | 55.0547 | nan | 5 |
| 6 | 6_0 | branin_metric | 36.2389 | nan | 6 |
| 7 | 7_0 | branin_metric | 10.9609 | nan | 7 |
| 8 | 8_0 | branin_metric | 197.514 | nan | 8 |
| 9 | 9_0 | branin_metric | 15.489 | nan | 9 |
| 10 | 10_0 | branin_metric | 7.23819 | nan | 10 |
| 11 | 11_0 | branin_metric | 5.78416 | nan | 11 |
| 12 | 12_0 | branin_metric | 2.6687 | nan | 12 |
| 13 | 13_0 | branin_metric | 2.09117 | nan | 13 |
| 14 | 14_0 | branin_metric | 2.32844 | nan | 14 |
| 15 | 15_0 | branin_metric | 6.329 | nan | 15 |
| 16 | 16_0 | branin_metric | 0.567952 | nan | 16 |
| 17 | 17_0 | branin_metric | 1.3358 | nan | 17 |
| 18 | 18_0 | branin_metric | 5.26697 | nan | 18 |
| 19 | 19_0 | branin_metric | 0.406578 | nan | 19 |
| 20 | 20_0 | branin_metric | 0.628474 | nan | 20 |
| 21 | 21_0 | branin_metric | 0.419371 | nan | 21 |
| 22 | 22_0 | branin_metric | 0.560946 | nan | 22 |
| 23 | 23_0 | branin_metric | 1.42429 | nan | 23 |
| 24 | 24_0 | branin_metric | 0.433216 | nan | 24 |
| 25 | 25_0 | branin_metric | 0.441191 | nan | 25 |
| 26 | 26_0 | branin_metric | 0.46991 | nan | 26 |
| 27 | 27_0 | branin_metric | 0.408209 | nan | 27 |
| 28 | 28_0 | branin_metric | 0.450586 | nan | 28 |
| 29 | 29_0 | branin_metric | 0.55768 | nan | 29 |
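As a quick sanity check on the table, the sketch below re-evaluates the best arm (19_0) by hand. It assumes that `branin_metric` computes the standard 2-D Branin function with the usual constants (a = 1, b = 5.1/(4π²), c = 5/π, r = 6, s = 10, t = 1/(8π)); this is an assumption, though it is consistent with the optimum 0.397887 used in the plot below, which Branin attains at (π, 2.275) among other points. The `branin` helper here is a plain-Python illustration, not the function Ax calls internally.

```python
import math


def branin(x1: float, x2: float) -> float:
    """Standard Branin function (assumed to match branin_metric), x1 in [-5, 10], x2 in [0, 15]."""
    a = 1.0
    b = 5.1 / (4 * math.pi ** 2)
    c = 5.0 / math.pi
    r = 6.0
    s = 10.0
    t = 1.0 / (8 * math.pi)
    return a * (x2 - b * x1 ** 2 + c * x1 - r) ** 2 + s * (1 - t) * math.cos(x1) + s


# Parameters of arm 19_0, copied from the trial listing above
best_value = branin(9.46343663901485, 2.5467470674060593)
# One of Branin's three global minimizers
optimum_value = branin(math.pi, 2.275)
print(best_value, optimum_value)
```

If the assumption holds, `best_value` should agree with the `mean` reported for trial 19 to the precision shown in the table.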
Plot results
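Ax provides utilities for plotting the results. Below, we plot the best objective value observed so far after each trial, together with the known global minimum of the Branin function (0.397887).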
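import numpy as np
from ax.plot.trace import optimization_trace_single_method
from ax.utils.notebook.plotting import render
# Observed objective value of each trial, shaped (1, num_trials)
objective_means = np.array([[trial.objective_mean for trial in exp.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    # Running minimum: the best value found after each trial
    y=np.minimum.accumulate(objective_means, axis=1),
    # Known global minimum of the Branin function
    optimum=0.397887,
)
render(best_objective_plot)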
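The curve in this plot is just the running minimum of the observed objective values. The sketch below reproduces it with NumPy directly from the `mean` column of the evaluation table (values copied verbatim), without needing the `exp` object. The simple-regret line is an added illustration, not something the tutorial itself computes; `BRANIN_OPTIMUM` is the same value passed as `optimum` above.

```python
import numpy as np

# "mean" column of the evaluation table above, in trial-index order (0..29)
means = np.array([
    1.9166, 20.5632, 159.956, 17.7035, 62.8011, 55.0547, 36.2389, 10.9609,
    197.514, 15.489, 7.23819, 5.78416, 2.6687, 2.09117, 2.32844, 6.329,
    0.567952, 1.3358, 5.26697, 0.406578, 0.628474, 0.419371, 0.560946,
    1.42429, 0.433216, 0.441191, 0.46991, 0.408209, 0.450586, 0.55768,
])
BRANIN_OPTIMUM = 0.397887

# Best (lowest) value observed after each trial: the plotted trace
best_so_far = np.minimum.accumulate(means)
# Simple regret: gap between the best observed value and the known optimum
simple_regret = best_so_far - BRANIN_OPTIMUM
# Index of the trial that produced the best observed value
best_trial = int(np.argmin(means))

print(best_trial, best_so_far[-1], simple_regret[-1])
```

Passing the 2-D array with `axis=1`, as the plotting code does, gives the same running minimum; the 2-D shape is only what `optimization_trace_single_method` expects for its `y` argument.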