
Using a custom BoTorch model with Ax

In this tutorial, we illustrate how to use a custom BoTorch model within Ax's botorch_modular API. This allows us to harness the convenience of Ax for running Bayesian Optimization loops while maintaining full flexibility in modeling.

Acquisition functions and their optimizers can be swapped out in much the same fashion. See, for example, the tutorial on implementing a custom acquisition function.

If you want to do something non-standard, or would like to have full insight into every aspect of the implementation, please see this tutorial for how to write your own full optimization loop in BoTorch.

# Install dependencies if we are running in colab
import sys
import plotly.io as pio
if 'google.colab' in sys.modules:
    pio.renderers.default = "colab"
    # Ax is published on PyPI as `ax-platform` (`ax` is an unrelated package)
    %pip install botorch ax-platform
else:
    # Ax uses Plotly to produce interactive plots. These are great for viewing and analysis,
    # though they also lead to large file sizes, which is not ideal for files living in GH.
    # Changing the default to `png` strips the interactive components to get around this.
    pio.renderers.default = "png"

import os
from contextlib import contextmanager, nullcontext

from ax.utils.testing.mock import mock_botorch_optimize_context_manager

SMOKE_TEST = os.environ.get("SMOKE_TEST")
NUM_EVALS = 10 if SMOKE_TEST else 30

Implementing the custom model

For this tutorial, we implement a very simple GPyTorch ExactGP model that uses an RBF kernel (with ARD) and infers a homoskedastic noise level.

Model definition is straightforward. Here, SimpleCustomGP subclasses both GPyTorch's ExactGP and BoTorch's GPyTorchModel; together, these two superclasses provide all the API calls that BoTorch expects in its various modules.

Note: BoTorch allows implementing any custom model that follows the Model API. For more information, please see the Model Documentation.

from typing import Optional

from botorch.models.gpytorch import GPyTorchModel
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from torch import Tensor


class SimpleCustomGP(ExactGP, GPyTorchModel):

    _num_outputs = 1  # to inform GPyTorchModel API

    def __init__(self, train_X, train_Y, train_Yvar: Optional[Tensor] = None):
        # NOTE: This ignores train_Yvar and uses inferred noise instead.
        # squeeze output dim before passing train_Y to ExactGP
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(
            base_kernel=RBFKernel(ard_num_dims=train_X.shape[-1]),
        )
        self.to(train_X)  # make sure we're on the right device/dtype

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
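For intuition about what SimpleCustomGP computes, here is a dependency-free sketch of the exact GP posterior for the same model structure: a constant mean, a ScaleKernel over an ARD RBF kernel (one lengthscale per input dimension), and one noise variance shared by all observations. It is an illustration, not GPyTorch's implementation. Here the hyperparameters (constant, lengthscales, outputscale, noise) are passed in by hand, whereas GPyTorch learns them by maximizing the marginal log likelihood. All function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def scaled_ard_rbf(x: Vector, z: Vector, lengthscales: Vector, outputscale: float) -> float:
    # outputscale * exp(-0.5 * sum_d ((x_d - z_d) / l_d)^2), i.e. ScaleKernel(RBF with ARD).
    sq_dist = sum(((xi - zi) / li) ** 2 for xi, zi, li in zip(x, z, lengthscales))
    return outputscale * math.exp(-0.5 * sq_dist)


def cholesky(a: List[List[float]]) -> List[List[float]]:
    # Lower-triangular L with L @ L^T == a (a must be symmetric positive definite).
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def cho_solve(L: List[List[float]], b: Vector) -> List[float]:
    # Solve (L @ L^T) x = b with forward, then backward substitution.
    n = len(L)
    y = [0.0] * n
    for i in range(n):
        y[i] = (b[i] - sum(L[i][k] * y[k] for k in range(i))) / L[i][i]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (y[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x


def gp_posterior(
    train_x: List[Vector],
    train_y: List[float],
    test_x: List[Vector],
    constant: float,
    lengthscales: Vector,
    outputscale: float,
    noise: float,
) -> Tuple[List[float], List[float]]:
    # Posterior mean and variance of the latent function under a constant mean
    # and homoskedastic Gaussian noise (the same noise variance for every point).
    n = len(train_x)
    K = [
        [
            scaled_ard_rbf(train_x[i], train_x[j], lengthscales, outputscale)
            + (noise if i == j else 0.0)
            for j in range(n)
        ]
        for i in range(n)
    ]
    L = cholesky(K)
    alpha = cho_solve(L, [y - constant for y in train_y])
    means, variances = [], []
    for x in test_x:
        k_star = [scaled_ard_rbf(x, xi, lengthscales, outputscale) for xi in train_x]
        means.append(constant + sum(k * a for k, a in zip(k_star, alpha)))
        v = cho_solve(L, k_star)
        prior_var = scaled_ard_rbf(x, x, lengthscales, outputscale)
        variances.append(prior_var - sum(k * vi for k, vi in zip(k_star, v)))
    return means, variances
```

With a very small noise value, the posterior mean nearly interpolates the training targets. Far from the data, it falls back to the constant mean, with variance equal to the outputscale.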

Instantiate a BoTorchModel in Ax

A BoTorchModel in Ax encapsulates both the surrogate (which Ax calls a Surrogate and BoTorch calls a Model) and an acquisition function. Here, we only specify the custom surrogate and let Ax choose the default acquisition function.

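Most models should work with the base Surrogate in Ax, except for BoTorch's ModelListGP, which works with ListSurrogate. Note that the Model class (e.g., SimpleCustomGP) must implement construct_inputs, which Ax uses to build the keyword arguments needed to instantiate the Model from the experiment data. SimpleCustomGP does not define construct_inputs explicitly. Instead, it inherits the default implementation from BoTorch's Model base class, which passes train_X and train_Y (plus train_Yvar, when observation noise is known).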
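The construct_inputs contract can be illustrated with a stdlib-only sketch. The model class translates a training dataset into constructor keyword arguments, and the caller then instantiates the class with them. ToyDataset, ToyModel and build_model are hypothetical. BoTorch's real classmethod receives a SupervisedDataset and returns a dictionary of tensors.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class ToyDataset:
    # Hypothetical stand-in for the training data Ax hands to BoTorch.
    X: List[List[float]]
    Y: List[List[float]]
    Yvar: Optional[List[List[float]]] = None


class ToyModel:
    def __init__(self, train_X, train_Y, train_Yvar=None):
        # Mirrors SimpleCustomGP's signature: train_Yvar is accepted but optional.
        self.train_X = train_X
        self.train_Y = train_Y
        self.train_Yvar = train_Yvar

    @classmethod
    def construct_inputs(cls, training_data: ToyDataset) -> Dict[str, Any]:
        # Translate the dataset into keyword arguments for __init__.
        inputs: Dict[str, Any] = {"train_X": training_data.X, "train_Y": training_data.Y}
        if training_data.Yvar is not None:
            inputs["train_Yvar"] = training_data.Yvar
        return inputs


def build_model(model_cls, training_data: ToyDataset):
    # Roughly the surrogate's job: ask the class for its inputs, then instantiate it.
    return model_cls(**model_cls.construct_inputs(training_data))
```

In this sketch, the caller never needs to know the model's constructor signature. If a custom model's __init__ needed extra arguments, overriding construct_inputs would be the place to supply them.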

from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_modular.utils import ModelConfig

ax_model = BoTorchModel(
    surrogate=Surrogate(
        surrogate_spec=SurrogateSpec(
            model_configs=[
                ModelConfig(
                    # The model class to use
                    botorch_model_class=SimpleCustomGP,
                    # Optional, MLL class with which to optimize model parameters
                    # mll_class=ExactMarginalLogLikelihood,
                    # Optional, dictionary of keyword arguments to model constructor
                    # model_options={}
                    # Passing in `None` to disable the default set of input transforms
                    # constructed in Ax, since the model doesn't support transforms.
                    input_transform_classes=None,
                )
            ]
        )
    ),
    # Optional, acquisition function class to use - see custom acquisition tutorial
    # botorch_acqf_class=qExpectedImprovement,
)

Combine with a ModelBridge

Models in Ax require a ModelBridge to interface with Experiments. A ModelBridge takes the inputs supplied by the Experiment and converts them to the inputs expected by the Model. For a BoTorchModel, we use TorchModelBridge. The modular BoTorch interface, Models.BOTORCH_MODULAR, creates both the BoTorchModel and the TorchModelBridge in a single step, as follows:

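from ax.modelbridge.registry import Models

# `experiment` and `data` are assumed to be an existing Ax Experiment and its Data.
model_bridge = Models.BOTORCH_MODULAR(
    experiment=experiment,
    data=data,
    surrogate_spec=SurrogateSpec(
        model_configs=[ModelConfig(botorch_model_class=SimpleCustomGP, input_transform_classes=None)]
    ),
    # Optional, will use default if unspecified
    # botorch_acqf_class=qLogNoisyExpectedImprovement,
)
# `gen` returns a GeneratorRun; attach it to the experiment to create a trial
generator_run = model_bridge.gen(1)
trial = experiment.new_trial(generator_run=generator_run)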
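One common kind of conversion between experiment space and model space is rescaling each range parameter to [0, 1]. The sketch below shows that round trip in plain Python and is illustrative only. In Ax, this work is done by configurable transforms, and whether normalization happens in the bridge or in a BoTorch input transform depends on the Ax version and configuration (this tutorial passes input_transform_classes=None). The helper names are hypothetical.

```python
from typing import Dict, List, Tuple

Bounds = Dict[str, Tuple[float, float]]


def to_unit_cube(params: Dict[str, float], bounds: Bounds, order: List[str]) -> List[float]:
    # Map each parameter from [lower, upper] to [0, 1], in a fixed column order.
    return [(params[name] - bounds[name][0]) / (bounds[name][1] - bounds[name][0]) for name in order]


def from_unit_cube(x: List[float], bounds: Bounds, order: List[str]) -> Dict[str, float]:
    # Inverse map: a model-space point back to an Ax-style parameterization.
    return {name: bounds[name][0] + xi * (bounds[name][1] - bounds[name][0]) for name, xi in zip(order, x)}


# The Branin search space used later in this tutorial.
BOUNDS: Bounds = {"x1": (-5.0, 10.0), "x2": (0.0, 15.0)}
ORDER = ["x1", "x2"]
```

For example, the center of the Branin box, {"x1": 2.5, "x2": 7.5}, maps to [0.5, 0.5] and back.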

Using the custom model in Ax to optimize the Branin function

We will demonstrate this with both the Service API (simpler, easier to use) and the Developer API (advanced, more customizable).

Optimization with Ax's Service API

A detailed tutorial on the Service API can be found here.

In order to customize the way the candidates are created in the Service API, we need to construct a new GenerationStrategy and pass it into AxClient.

from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models


gs = GenerationStrategy(
    steps=[
        # Quasi-random initialization step
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,  # How many trials should be produced from this generation step
        ),
        # Bayesian optimization step using the custom surrogate model
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,  # No limitation on how many trials should be produced from this step
            # For `BOTORCH_MODULAR`, we pass in kwargs to specify what surrogate or acquisition function to use.
            model_kwargs={
                "surrogate_spec": SurrogateSpec(
                    model_configs=[ModelConfig(botorch_model_class=SimpleCustomGP, input_transform_classes=None)]
                ),
            },
        ),
    ]
)
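Here is a simplified sketch of how the two GenerationSteps divide up trial indices. The first 5 trials come from Sobol, and every later trial comes from BOTORCH_MODULAR, which matches the model names in the optimization log. The function is hypothetical. It ignores extra conditions that a real GenerationStep can impose, such as requiring a minimum number of observed trials before moving on.

```python
from typing import List, Tuple


def model_for_trial(trial_index: int, steps: List[Tuple[str, int]]) -> str:
    # Each step is (model_name, num_trials); num_trials == -1 means "unlimited".
    start = 0
    for name, num_trials in steps:
        if num_trials == -1 or trial_index < start + num_trials:
            return name
        start += num_trials
    raise ValueError("No generation step left for this trial index")


STEPS = [("Sobol", 5), ("BoTorch", -1)]
```

Because the last step is unlimited, the schedule never runs out. A strategy whose steps all have finite num_trials would eventually have no step left.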

Setting up the experiment

In order to use the GenerationStrategy we just created, we will pass it into the AxClient.

import torch
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from botorch.test_functions import Branin


# Initialize the client - AxClient offers a convenient API to control the experiment
ax_client = AxClient(generation_strategy=gs)
# Set up the experiment
ax_client.create_experiment(
    name="branin_test_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            # It is crucial to use floats for the bounds, i.e., 0.0 rather than 0.
            # Otherwise, the parameter would be inferred as an integer range.
            "bounds": [-5.0, 10.0],
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 15.0],
        },
    ],
    objectives={
        "branin": ObjectiveProperties(minimize=True),
    },
)
# Set up a function to evaluate the trials
branin = Branin()


def evaluate(parameters):
    x = torch.tensor([[parameters.get(f"x{i+1}") for i in range(2)]])
    # The GaussianLikelihood used by our model infers an observation noise level,
    # so we pass an sem value of NaN to indicate that observation noise is unknown
    return {"branin": (branin(x).item(), float("nan"))}
Out:

[INFO 12-02 15:16:19] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the verbose_logging argument to False. Note that float values in the logs are rounded to 6 decimal points.

[INFO 12-02 15:16:19] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.

[INFO 12-02 15:16:19] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.

[INFO 12-02 15:16:19] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).
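To sanity-check logged values without BoTorch, you can write the standard two-dimensional Branin function in plain Python. We assume that botorch.test_functions.Branin implements this standard definition. The value at (10, 0) agrees with trial 12 in the log (10.960894). evaluate_plain is a hypothetical helper that returns the same {metric_name: (mean, sem)} shape as evaluate.

```python
import math


def branin(x1: float, x2: float) -> float:
    # Standard Branin constants.
    a = 1.0
    b = 5.1 / (4.0 * math.pi ** 2)
    c = 5.0 / math.pi
    r = 6.0
    s = 10.0
    t = 1.0 / (8.0 * math.pi)
    return a * (x2 - b * x1 ** 2 + c * x1 - r) ** 2 + s * (1.0 - t) * math.cos(x1) + s


def evaluate_plain(parameters):
    # Same return shape as `evaluate`; a NaN sem marks the noise level as unknown.
    return {"branin": (branin(parameters["x1"], parameters["x2"]), float("nan"))}
```

At (π, 2.275), one of Branin's three global minimizers, the function evaluates to about 0.397887. This is the value later passed as objective_optimum.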

Running the BO loop

The next cell sets up a context manager solely to speed up the testing of the notebook in SMOKE_TEST mode. You can safely ignore this cell and its use throughout the tutorial.

if SMOKE_TEST:
    fast_smoke_test = mock_botorch_optimize_context_manager
else:
    fast_smoke_test = nullcontext

# Set a seed for reproducible tutorial output
torch.manual_seed(0)
Out:

<torch._C.Generator at 0x12db71a50>
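The SMOKE_TEST switch above follows a common pattern. You choose either a real context manager or contextlib.nullcontext (a no-op) once, and then always write `with chosen():`. Here is a stdlib-only sketch of the same pattern. In it, fake_fast_mode is a hypothetical stand-in for mock_botorch_optimize_context_manager that only records when it is entered and exited.

```python
import os
from contextlib import contextmanager, nullcontext
from typing import Callable, List, Mapping

EVENTS: List[str] = []


@contextmanager
def fake_fast_mode():
    # Hypothetical stand-in: records entry and exit instead of mocking optimizers.
    EVENTS.append("enter")
    try:
        yield
    finally:
        EVENTS.append("exit")


def pick_context(env: Mapping[str, str] = os.environ) -> Callable:
    # Any non-empty SMOKE_TEST value selects the fast path, mirroring
    # `SMOKE_TEST = os.environ.get("SMOKE_TEST")` followed by `if SMOKE_TEST:`.
    return fake_fast_mode if env.get("SMOKE_TEST") else nullcontext
```

Because the check relies on truthiness, setting SMOKE_TEST=0 also enables smoke-test mode. Only an unset or empty variable disables it.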

with fast_smoke_test():
    for i in range(NUM_EVALS):
        parameters, trial_index = ax_client.get_next_trial()
        # Local evaluation here can be replaced with deployment to external system.
        ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))
Out:

/opt/anaconda3/envs/botorch/lib/python3.10/site-packages/ax/modelbridge/cross_validation.py:439: UserWarning:

Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

[INFO 12-02 15:16:19] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 0.62583, 'x2': 14.359564} using model Sobol.

[INFO 12-02 15:16:19] ax.service.ax_client: Completed trial 0 with data: {'branin': (104.365417, nan)}.

/opt/anaconda3/envs/botorch/lib/python3.10/site-packages/ax/modelbridge/cross_validation.py:439: UserWarning:

Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

[INFO 12-02 15:16:19] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 3.166217, 'x2': 3.867106} using model Sobol.

[INFO 12-02 15:16:19] ax.service.ax_client: Completed trial 1 with data: {'branin': (2.996862, nan)}.

/opt/anaconda3/envs/botorch/lib/python3.10/site-packages/ax/modelbridge/cross_validation.py:439: UserWarning:

Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

[INFO 12-02 15:16:19] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 9.560105, 'x2': 10.718323} using model Sobol.

[INFO 12-02 15:16:19] ax.service.ax_client: Completed trial 2 with data: {'branin': (66.530624, nan)}.

/opt/anaconda3/envs/botorch/lib/python3.10/site-packages/ax/modelbridge/cross_validation.py:439: UserWarning:

Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

[INFO 12-02 15:16:19] ax.service.ax_client: Generated new trial 3 with parameters {'x1': -3.878664, 'x2': 0.117947} using model Sobol.

[INFO 12-02 15:16:19] ax.service.ax_client: Completed trial 3 with data: {'branin': (198.850861, nan)}.

/opt/anaconda3/envs/botorch/lib/python3.10/site-packages/ax/modelbridge/cross_validation.py:439: UserWarning:

Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

[INFO 12-02 15:16:19] ax.service.ax_client: Generated new trial 4 with parameters {'x1': -2.362858, 'x2': 8.855021} using model Sobol.

[INFO 12-02 15:16:19] ax.service.ax_client: Completed trial 4 with data: {'branin': (5.811776, nan)}.

[INFO 12-02 15:16:20] ax.service.ax_client: Generated new trial 5 with parameters {'x1': 2.562432, 'x2': 4.925782} using model BoTorch.

[INFO 12-02 15:16:20] ax.service.ax_client: Completed trial 5 with data: {'branin': (6.611189, nan)}.

[INFO 12-02 15:16:20] ax.service.ax_client: Generated new trial 6 with parameters {'x1': 5.504102, 'x2': 4.955757} using model BoTorch.

[INFO 12-02 15:16:20] ax.service.ax_client: Completed trial 6 with data: {'branin': (31.288765, nan)}.

[INFO 12-02 15:16:21] ax.service.ax_client: Generated new trial 7 with parameters {'x1': -2.306774, 'x2': 4.433857} using model BoTorch.

[INFO 12-02 15:16:21] ax.service.ax_client: Completed trial 7 with data: {'branin': (38.658497, nan)}.

[INFO 12-02 15:16:21] ax.service.ax_client: Generated new trial 8 with parameters {'x1': -1.597063, 'x2': 7.317824} using model BoTorch.

[INFO 12-02 15:16:21] ax.service.ax_client: Completed trial 8 with data: {'branin': (12.161115, nan)}.

[INFO 12-02 15:16:21] ax.service.ax_client: Generated new trial 9 with parameters {'x1': -5.0, 'x2': 9.07374} using model BoTorch.

[INFO 12-02 15:16:21] ax.service.ax_client: Completed trial 9 with data: {'branin': (78.554581, nan)}.

[INFO 12-02 15:16:22] ax.service.ax_client: Generated new trial 10 with parameters {'x1': 0.768745, 'x2': 6.898975} using model BoTorch.

[INFO 12-02 15:16:22] ax.service.ax_client: Completed trial 10 with data: {'branin': (21.088478, nan)}.

[INFO 12-02 15:16:22] ax.service.ax_client: Generated new trial 11 with parameters {'x1': 1.619218, 'x2': 0.603319} using model BoTorch.

[INFO 12-02 15:16:22] ax.service.ax_client: Completed trial 11 with data: {'branin': (19.510229, nan)}.

[INFO 12-02 15:16:22] ax.service.ax_client: Generated new trial 12 with parameters {'x1': 10.0, 'x2': 0.0} using model BoTorch.

[INFO 12-02 15:16:22] ax.service.ax_client: Completed trial 12 with data: {'branin': (10.960894, nan)}.

[INFO 12-02 15:16:22] ax.service.ax_client: Generated new trial 13 with parameters {'x1': 7.386909, 'x2': 0.0} using model BoTorch.

[INFO 12-02 15:16:22] ax.service.ax_client: Completed trial 13 with data: {'branin': (15.994156, nan)}.

[INFO 12-02 15:16:23] ax.service.ax_client: Generated new trial 14 with parameters {'x1': 4.187552, 'x2': 0.0} using model BoTorch.

[INFO 12-02 15:16:23] ax.service.ax_client: Completed trial 14 with data: {'branin': (7.750666, nan)}.

[INFO 12-02 15:16:23] ax.service.ax_client: Generated new trial 15 with parameters {'x1': -3.93692, 'x2': 15.0} using model BoTorch.

[INFO 12-02 15:16:23] ax.service.ax_client: Completed trial 15 with data: {'branin': (3.813743, nan)}.

[INFO 12-02 15:16:23] ax.service.ax_client: Generated new trial 16 with parameters {'x1': -3.32116, 'x2': 12.384737} using model BoTorch.

[INFO 12-02 15:16:23] ax.service.ax_client: Completed trial 16 with data: {'branin': (0.658538, nan)}.

[INFO 12-02 15:16:24] ax.service.ax_client: Generated new trial 17 with parameters {'x1': 10.0, 'x2': 3.669781} using model BoTorch.

[INFO 12-02 15:16:24] ax.service.ax_client: Completed trial 17 with data: {'branin': (2.387794, nan)}.

[INFO 12-02 15:16:24] ax.service.ax_client: Generated new trial 18 with parameters {'x1': 9.341568, 'x2': 2.543665} using model BoTorch.

[INFO 12-02 15:16:24] ax.service.ax_client: Completed trial 18 with data: {'branin': (0.450142, nan)}.

[INFO 12-02 15:16:25] ax.service.ax_client: Generated new trial 19 with parameters {'x1': 3.076194, 'x2': 2.421518} using model BoTorch.

[INFO 12-02 15:16:25] ax.service.ax_client: Completed trial 19 with data: {'branin': (0.427432, nan)}.

[INFO 12-02 15:16:25] ax.service.ax_client: Generated new trial 20 with parameters {'x1': 9.536534, 'x2': 2.494184} using model BoTorch.

[INFO 12-02 15:16:25] ax.service.ax_client: Completed trial 20 with data: {'branin': (0.463669, nan)}.

[INFO 12-02 15:16:26] ax.service.ax_client: Generated new trial 21 with parameters {'x1': -3.365433, 'x2': 15.0} using model BoTorch.

[INFO 12-02 15:16:26] ax.service.ax_client: Completed trial 21 with data: {'branin': (5.392394, nan)}.

[INFO 12-02 15:16:26] ax.service.ax_client: Generated new trial 22 with parameters {'x1': 9.496421, 'x2': 2.820387} using model BoTorch.

[INFO 12-02 15:16:26] ax.service.ax_client: Completed trial 22 with data: {'branin': (0.50334, nan)}.

[INFO 12-02 15:16:27] ax.service.ax_client: Generated new trial 23 with parameters {'x1': 3.217649, 'x2': 2.439988} using model BoTorch.

[INFO 12-02 15:16:27] ax.service.ax_client: Completed trial 23 with data: {'branin': (0.475622, nan)}.

[INFO 12-02 15:16:27] ax.service.ax_client: Generated new trial 24 with parameters {'x1': 9.487907, 'x2': 2.245164} using model BoTorch.

[INFO 12-02 15:16:27] ax.service.ax_client: Completed trial 24 with data: {'branin': (0.497444, nan)}.

[INFO 12-02 15:16:28] ax.service.ax_client: Generated new trial 25 with parameters {'x1': 9.635134, 'x2': 2.550312} using model BoTorch.

[INFO 12-02 15:16:28] ax.service.ax_client: Completed trial 25 with data: {'branin': (0.62118, nan)}.

[INFO 12-02 15:16:29] ax.service.ax_client: Generated new trial 26 with parameters {'x1': -3.229368, 'x2': 12.309249} using model BoTorch.

[INFO 12-02 15:16:29] ax.service.ax_client: Completed trial 26 with data: {'branin': (0.466428, nan)}.

[INFO 12-02 15:16:30] ax.service.ax_client: Generated new trial 27 with parameters {'x1': 2.970411, 'x2': 2.46022} using model BoTorch.

[INFO 12-02 15:16:30] ax.service.ax_client: Completed trial 27 with data: {'branin': (0.540529, nan)}.

[INFO 12-02 15:16:31] ax.service.ax_client: Generated new trial 28 with parameters {'x1': 9.462445, 'x2': 2.493521} using model BoTorch.

[INFO 12-02 15:16:31] ax.service.ax_client: Completed trial 28 with data: {'branin': (0.404879, nan)}.

[INFO 12-02 15:16:32] ax.service.ax_client: Generated new trial 29 with parameters {'x1': 9.568048, 'x2': 2.731275} using model BoTorch.

[INFO 12-02 15:16:32] ax.service.ax_client: Completed trial 29 with data: {'branin': (0.513895, nan)}.

Viewing the evaluated trials

ax_client.get_trials_data_frame()
    trial_index  arm_name  trial_status  generation_method     branin         x1         x2
 0            0       0_0     COMPLETED              Sobol    104.365    0.62583    14.3596
 1            1       1_0     COMPLETED              Sobol    2.99686    3.16622    3.86711
 2            2       2_0     COMPLETED              Sobol    66.5306    9.56011    10.7183
 3            3       3_0     COMPLETED              Sobol    198.851   -3.87866   0.117947
 4            4       4_0     COMPLETED              Sobol    5.81178   -2.36286    8.85502
 5            5       5_0     COMPLETED            BoTorch    6.61119    2.56243    4.92578
 6            6       6_0     COMPLETED            BoTorch    31.2888     5.5041    4.95576
 7            7       7_0     COMPLETED            BoTorch    38.6585   -2.30677    4.43386
 8            8       8_0     COMPLETED            BoTorch    12.1611   -1.59706    7.31782
 9            9       9_0     COMPLETED            BoTorch    78.5546         -5    9.07374
10           10      10_0     COMPLETED            BoTorch    21.0885   0.768745    6.89898
11           11      11_0     COMPLETED            BoTorch    19.5102    1.61922   0.603319
12           12      12_0     COMPLETED            BoTorch    10.9609         10          0
13           13      13_0     COMPLETED            BoTorch    15.9942    7.38691          0
14           14      14_0     COMPLETED            BoTorch    7.75067    4.18755          0
15           15      15_0     COMPLETED            BoTorch    3.81374   -3.93692         15
16           16      16_0     COMPLETED            BoTorch   0.658538   -3.32116    12.3847
17           17      17_0     COMPLETED            BoTorch    2.38779         10    3.66978
18           18      18_0     COMPLETED            BoTorch   0.450142    9.34157    2.54366
19           19      19_0     COMPLETED            BoTorch   0.427432    3.07619    2.42152
20           20      20_0     COMPLETED            BoTorch   0.463669    9.53653    2.49418
21           21      21_0     COMPLETED            BoTorch    5.39239   -3.36543         15
22           22      22_0     COMPLETED            BoTorch    0.50334    9.49642    2.82039
23           23      23_0     COMPLETED            BoTorch   0.475622    3.21765    2.43999
24           24      24_0     COMPLETED            BoTorch   0.497444    9.48791    2.24516
25           25      25_0     COMPLETED            BoTorch    0.62118    9.63513    2.55031
26           26      26_0     COMPLETED            BoTorch   0.466428   -3.22937    12.3092
27           27      27_0     COMPLETED            BoTorch   0.540529    2.97041    2.46022
28           28      28_0     COMPLETED            BoTorch   0.404879    9.46245    2.49352
29           29      29_0     COMPLETED            BoTorch   0.513895    9.56805    2.73128
parameters, values = ax_client.get_best_parameters()
print(f"Best parameters: {parameters}")
print(f"Corresponding mean: {values[0]}, covariance: {values[1]}")
Out:

Best parameters: {'x1': 9.536533793530687, 'x2': 2.494184288965757}

Corresponding mean: {'branin': 0.41573114009047885}, covariance: {'branin': {'branin': 0.025883709201637482}}
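The reported best parameters are those of trial 20 (observed value 0.463669). However, trial 28 has the lowest observed value (0.404879), and the reported mean (about 0.416) matches neither observation. This pattern is consistent with get_best_parameters choosing among evaluated arms by the model's predicted mean rather than by raw observations. As we understand it, AxClient does this by default when a model is available. To find the best raw observation instead, a plain-Python lookup over the logged values is enough. The helper below is hypothetical and not part of Ax.

```python
from typing import List

# Observed Branin values for trials 0-29, copied from the trials table above.
OBSERVED: List[float] = [
    104.365, 2.99686, 66.5306, 198.851, 5.81178,
    6.61119, 31.2888, 38.6585, 12.1611, 78.5546,
    21.0885, 19.5102, 10.9609, 15.9942, 7.75067,
    3.81374, 0.658538, 2.38779, 0.450142, 0.427432,
    0.463669, 5.39239, 0.50334, 0.475622, 0.497444,
    0.62118, 0.466428, 0.540529, 0.404879, 0.513895,
]


def best_observed_trial(values: List[float]) -> int:
    # Trial index with the lowest observed value (the objective is minimized).
    return min(range(len(values)), key=values.__getitem__)
```

On noiseless data, the two criteria can still disagree, because the model smooths across nearby points.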

Plotting the response surface and optimization progress

from ax.utils.notebook.plotting import render

render(ax_client.get_contour_plot())
Out:

[INFO 12-02 15:16:32] ax.service.ax_client: Retrieving contour plot with parameter 'x1' on X-axis and 'x2' on Y-axis, for metric 'branin'. Remaining parameters are affixed to the middle of their range.

best_parameters, values = ax_client.get_best_parameters()
best_parameters, values[0]
Out:

({'x1': 9.536533793530687, 'x2': 2.494184288965757},

{'branin': 0.41573114009047885})

render(ax_client.get_optimization_trace(objective_optimum=0.397887))
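Roughly speaking, the optimization trace plots the best objective value found so far at each iteration. The objective_optimum=0.397887 argument (Branin's known global minimum) is drawn as a reference. The bookkeeping behind such a plot fits in a few lines of standard-library Python. best_so_far and simple_regret are hypothetical helpers, not Ax functions.

```python
from itertools import accumulate
from typing import List

BRANIN_OPTIMUM = 0.397887  # the value passed as objective_optimum above


def best_so_far(values: List[float]) -> List[float]:
    # Running minimum: the best objective value seen after each trial.
    return list(accumulate(values, min))


def simple_regret(values: List[float], optimum: float = BRANIN_OPTIMUM) -> List[float]:
    # Gap between the best value found so far and the known optimum.
    return [v - optimum for v in best_so_far(values)]
```

Applied to the 30 logged values, the running minimum reaches 0.404879 at trial 28, about 0.007 above the optimum.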

Optimization with the Developer API

A detailed tutorial on the Developer API can be found here.

Set up the Experiment in Ax

We need three inputs for an Ax Experiment:

  • A search space to optimize over;
  • An optimization config specifying the objective / metrics to optimize, and optional outcome constraints;
  • A runner that handles the deployment of trials. For a synthetic optimization problem such as this one, the runner only returns simple metadata about the trial.
import pandas as pd
import torch
from ax import (
    Data,
    Experiment,
    Metric,
    Objective,
    OptimizationConfig,
    ParameterType,
    RangeParameter,
    Runner,
    SearchSpace,
)
from ax.utils.common.result import Ok
from botorch.test_functions import Branin


branin_func = Branin()

# For our purposes, the metric is a wrapper that structures the function output.
class BraninMetric(Metric):
    def fetch_trial_data(self, trial, **kwargs):
        records = []
        for arm_name, arm in trial.arms_by_name.items():
            params = arm.parameters
            tensor_params = torch.tensor([params["x1"], params["x2"]])
            records.append(
                {
                    "arm_name": arm_name,
                    "metric_name": self.name,
                    "trial_index": trial.index,
                    # Convert the 0-dim tensor to a plain float for the DataFrame
                    "mean": branin_func(tensor_params).item(),
                    # SEM (observation noise) - NaN indicates unknown
                    "sem": float("nan"),
                }
            )
        return Ok(value=Data(df=pd.DataFrame.from_records(records)))


# Search space defines the parameters, their types, and acceptable values.
search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name="x1", parameter_type=ParameterType.FLOAT, lower=-5, upper=10
        ),
        RangeParameter(
            name="x2", parameter_type=ParameterType.FLOAT, lower=0, upper=15
        ),
    ]
)

optimization_config = OptimizationConfig(
    objective=Objective(
        metric=BraninMetric(name="branin_metric", lower_is_better=True),
        minimize=True,  # This is optional since we specified `lower_is_better=True`
    )
)


class MyRunner(Runner):
    def run(self, trial):
        trial_metadata = {"name": str(trial.index)}
        return trial_metadata


exp = Experiment(
    name="branin_experiment",
    search_space=search_space,
    optimization_config=optimization_config,
    runner=MyRunner(),
)

Run the BO loop

First, we use the Sobol generator to create 5 quasi-random initial points in the search space. Ax controls objective evaluations via Trials.

  • We generate a Trial using a generator run, e.g., Sobol below. A Trial specifies relevant metadata as well as the parameters to be evaluated. At this point, the Trial is at the CANDIDATE stage.
  • We run the Trial using Trial.run(). In our example, this serves to mark the Trial as RUNNING. In an advanced application, this can be used to dispatch the Trial for evaluation on a remote server.
  • Once the Trial is done running, we mark it as COMPLETED. This tells the Experiment that it can fetch the Trial data.

A Trial supports evaluation of a single parameterization. For parallel evaluations, see BatchTrial.
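The trial lifecycle described above can be modeled as a small state machine. This is a hypothetical, simplified sketch (ToyTrial and Status are not Ax classes). Ax's real TrialStatus has additional states, such as FAILED and ABANDONED, and its own transition rules.

```python
from enum import Enum
from typing import Dict, Set


class Status(Enum):
    CANDIDATE = "CANDIDATE"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"


# Allowed transitions in this simplified model.
ALLOWED: Dict[Status, Set[Status]] = {
    Status.CANDIDATE: {Status.RUNNING},
    Status.RUNNING: {Status.COMPLETED},
    Status.COMPLETED: set(),
}


class ToyTrial:
    def __init__(self, parameters):
        self.parameters = parameters
        self.status = Status.CANDIDATE  # a freshly generated trial starts as a candidate

    def _move(self, new: Status) -> None:
        if new not in ALLOWED[self.status]:
            raise ValueError(f"cannot go from {self.status.name} to {new.name}")
        self.status = new

    def run(self) -> "ToyTrial":
        # In Ax, this is where a Runner would dispatch the evaluation.
        self._move(Status.RUNNING)
        return self

    def mark_completed(self) -> "ToyTrial":
        # After this, the experiment may fetch the trial's data.
        self._move(Status.COMPLETED)
        return self
```

Encoding the transitions in one table keeps invalid sequences out, such as completing a trial that never ran.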

from ax.modelbridge.registry import Models


sobol = Models.SOBOL(exp.search_space)

for i in range(5):
    trial = exp.new_trial(generator_run=sobol.gen(1))
    trial.run()
    trial.mark_completed()
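A Sobol generator takes more code than fits here, so this sketch swaps in a Halton sequence to illustrate quasi-random initialization. Halton is a different low-discrepancy construction. Both spread points evenly over the search space, whereas independent uniform draws can leave gaps or clusters. This is not what Models.SOBOL does, since Ax uses its own Sobol generator. The helper names are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def radical_inverse(index: int, base: int) -> float:
    # Van der Corput radical inverse: mirror the base-`base` digits of index around the point.
    result, frac = 0.0, 1.0 / base
    while index > 0:
        index, digit = divmod(index, base)
        result += digit * frac
        frac /= base
    return result


def halton_points(
    n: int, bounds: Dict[str, Tuple[float, float]], bases: Sequence[int] = (2, 3)
) -> List[Dict[str, float]]:
    # One coprime base per dimension; start at index 1 (index 0 maps to the lower corner).
    points = []
    for i in range(1, n + 1):
        point = {}
        for (name, (lo, hi)), base in zip(bounds.items(), bases):
            point[name] = lo + radical_inverse(i, base) * (hi - lo)
        points.append(point)
    return points


BOUNDS = {"x1": (-5.0, 10.0), "x2": (0.0, 15.0)}
```

The first point is {"x1": 2.5, "x2": 5.0}. Subsequent points fill the box progressively more finely.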

Once the initial (quasi-) random stage is completed, we can use our SimpleCustomGP with the default acquisition function chosen by Ax to run the BO loop.

with fast_smoke_test():
    for i in range(NUM_EVALS - 5):
        model_bridge = Models.BOTORCH_MODULAR(
            experiment=exp,
            data=exp.fetch_data(),
            surrogate_spec=SurrogateSpec(
                model_configs=[ModelConfig(SimpleCustomGP, input_transform_classes=None)]
            ),
        )
        trial = exp.new_trial(generator_run=model_bridge.gen(1))
        trial.run()
        trial.mark_completed()
Out:

/opt/anaconda3/envs/botorch/lib/python3.10/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning:

A not p.d., added jitter of 1.0e-08 to the diagonal

/Users/saitcakmak/botorch/botorch/optim/optimize.py:576: RuntimeWarning:

Optimization failed in gen_candidates_scipy with the following warning(s):

[OptimizationWarning('Optimization failed within scipy.optimize.minimize with status 2 and message ABNORMAL_TERMINATION_IN_LNSRCH.')]

Trying again with a new set of initial conditions.

View the trials attached to the Experiment.

exp.trials
Out:

{0: Trial(experiment_name='branin_experiment', index=0, status=TrialStatus.COMPLETED, arm=Arm(name='0_0', parameters={'x1': -3.001730814576149, 'x2': 10.747691988945007})),

1: Trial(experiment_name='branin_experiment', index=1, status=TrialStatus.COMPLETED, arm=Arm(name='1_0', parameters={'x1': 5.38613950368017, 'x2': 3.3136763563379645})),

2: Trial(experiment_name='branin_experiment', index=2, status=TrialStatus.COMPLETED, arm=Arm(name='2_0', parameters={'x1': 7.6478424202650785, 'x2': 13.54916384909302})),

3: Trial(experiment_name='branin_experiment', index=3, status=TrialStatus.COMPLETED, arm=Arm(name='3_0', parameters={'x1': -0.9862332977354527, 'x2': 6.144560454413295})),

4: Trial(experiment_name='branin_experiment', index=4, status=TrialStatus.COMPLETED, arm=Arm(name='4_0', parameters={'x1': 0.9829527512192726, 'x2': 11.45068815909326})),

5: Trial(experiment_name='branin_experiment', index=5, status=TrialStatus.COMPLETED, arm=Arm(name='5_0', parameters={'x1': -5.0, 'x2': 10.681135484818025})),

6: Trial(experiment_name='branin_experiment', index=6, status=TrialStatus.COMPLETED, arm=Arm(name='6_0', parameters={'x1': 0.9771782267390936, 'x2': 0.0})),

7: Trial(experiment_name='branin_experiment', index=7, status=TrialStatus.COMPLETED, arm=Arm(name='7_0', parameters={'x1': 10.0, 'x2': 0.0})),

8: Trial(experiment_name='branin_experiment', index=8, status=TrialStatus.COMPLETED, arm=Arm(name='8_0', parameters={'x1': -5.0, 'x2': 3.5936134273014724})),

9: Trial(experiment_name='branin_experiment', index=9, status=TrialStatus.COMPLETED, arm=Arm(name='9_0', parameters={'x1': 7.451191634850801, 'x2': 0.0})),

10: Trial(experiment_name='branin_experiment', index=10, status=TrialStatus.COMPLETED, arm=Arm(name='10_0', parameters={'x1': -2.989790970817857, 'x2': 14.507362023858969})),

11: Trial(experiment_name='branin_experiment', index=11, status=TrialStatus.COMPLETED, arm=Arm(name='11_0', parameters={'x1': -2.206908749062847, 'x2': 11.361635411408425})),

12: Trial(experiment_name='branin_experiment', index=12, status=TrialStatus.COMPLETED, arm=Arm(name='12_0', parameters={'x1': 10.0, 'x2': 3.854754557203238})),

13: Trial(experiment_name='branin_experiment', index=13, status=TrialStatus.COMPLETED, arm=Arm(name='13_0', parameters={'x1': 8.879169356973378, 'x2': 2.600188907289197})),

14: Trial(experiment_name='branin_experiment', index=14, status=TrialStatus.COMPLETED, arm=Arm(name='14_0', parameters={'x1': 10.0, 'x2': 2.3822346114805195})),

15: Trial(experiment_name='branin_experiment', index=15, status=TrialStatus.COMPLETED, arm=Arm(name='15_0', parameters={'x1': -4.3121547230008925, 'x2': 15.0})),

16: Trial(experiment_name='branin_experiment', index=16, status=TrialStatus.COMPLETED, arm=Arm(name='16_0', parameters={'x1': -3.329959894641452, 'x2': 12.717573331024465})),

17: Trial(experiment_name='branin_experiment', index=17, status=TrialStatus.COMPLETED, arm=Arm(name='17_0', parameters={'x1': 2.7242595961882197, 'x2': 2.9602941886720284})),

18: Trial(experiment_name='branin_experiment', index=18, status=TrialStatus.COMPLETED, arm=Arm(name='18_0', parameters={'x1': 3.649650277984131, 'x2': 0.0})),

19: Trial(experiment_name='branin_experiment', index=19, status=TrialStatus.COMPLETED, arm=Arm(name='19_0', parameters={'x1': 9.46343663901485, 'x2': 2.5467470674060593})),

20: Trial(experiment_name='branin_experiment', index=20, status=TrialStatus.COMPLETED, arm=Arm(name='20_0', parameters={'x1': 3.0104763473247154, 'x2': 1.9945479954698144})),

21: Trial(experiment_name='branin_experiment', index=21, status=TrialStatus.COMPLETED, arm=Arm(name='21_0', parameters={'x1': -3.203719761808281, 'x2': 12.37041540884107})),

22: Trial(experiment_name='branin_experiment', index=22, status=TrialStatus.COMPLETED, arm=Arm(name='22_0', parameters={'x1': 9.476541680994133, 'x2': 2.9065640930047554})),

23: Trial(experiment_name='branin_experiment', index=23, status=TrialStatus.COMPLETED, arm=Arm(name='23_0', parameters={'x1': 2.734598996972391, 'x2': 2.1218033122913322})),

24: Trial(experiment_name='branin_experiment', index=24, status=TrialStatus.COMPLETED, arm=Arm(name='24_0', parameters={'x1': 3.2195726883400404, 'x2': 2.136557689490957})),

25: Trial(experiment_name='branin_experiment', index=25, status=TrialStatus.COMPLETED, arm=Arm(name='25_0', parameters={'x1': 9.42953989277149, 'x2': 2.2711877578623016})),

26: Trial(experiment_name='branin_experiment', index=26, status=TrialStatus.COMPLETED, arm=Arm(name='26_0', parameters={'x1': 3.1637517854500388, 'x2': 2.5217242210404827})),

27: Trial(experiment_name='branin_experiment', index=27, status=TrialStatus.COMPLETED, arm=Arm(name='27_0', parameters={'x1': -3.1115663937782836, 'x2': 12.125534390814288})),

28: Trial(experiment_name='branin_experiment', index=28, status=TrialStatus.COMPLETED, arm=Arm(name='28_0', parameters={'x1': 9.517886960473454, 'x2': 2.660053049906566})),

29: Trial(experiment_name='branin_experiment', index=29, status=TrialStatus.COMPLETED, arm=Arm(name='29_0', parameters={'x1': 3.3039667073151158, 'x2': 1.9687756455552108}))}

View the evaluation data for these trials.

exp.fetch_data().df
    arm_name    metric_name      mean  sem  trial_index
0        0_0  branin_metric    1.9166  nan            0
1        1_0  branin_metric   20.5632  nan            1
2        2_0  branin_metric   159.956  nan            2
3        3_0  branin_metric   17.7035  nan            3
4        4_0  branin_metric   62.8011  nan            4
5        5_0  branin_metric   55.0547  nan            5
6        6_0  branin_metric   36.2389  nan            6
7        7_0  branin_metric   10.9609  nan            7
8        8_0  branin_metric   197.514  nan            8
9        9_0  branin_metric    15.489  nan            9
10      10_0  branin_metric   7.23819  nan           10
11      11_0  branin_metric   5.78416  nan           11
12      12_0  branin_metric    2.6687  nan           12
13      13_0  branin_metric   2.09117  nan           13
14      14_0  branin_metric   2.32844  nan           14
15      15_0  branin_metric     6.329  nan           15
16      16_0  branin_metric  0.567952  nan           16
17      17_0  branin_metric    1.3358  nan           17
18      18_0  branin_metric   5.26697  nan           18
19      19_0  branin_metric  0.406578  nan           19
20      20_0  branin_metric  0.628474  nan           20
21      21_0  branin_metric  0.419371  nan           21
22      22_0  branin_metric  0.560946  nan           22
23      23_0  branin_metric   1.42429  nan           23
24      24_0  branin_metric  0.433216  nan           24
25      25_0  branin_metric  0.441191  nan           25
26      26_0  branin_metric   0.46991  nan           26
27      27_0  branin_metric  0.408209  nan           27
28      28_0  branin_metric  0.450586  nan           28
29      29_0  branin_metric   0.55768  nan           29
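The `mean` column appears to hold noiseless values of the Branin function. The sketch below uses the textbook Branin definition and recomputes two of the logged arms from the `exp.trials` output. The definition is an assumption: the tutorial's metric is defined outside this section, but the 0.397887 optimum passed to `optimization_trace_single_method` in the plotting cell matches this form. `branin` here is a hypothetical standalone helper, not an Ax import.

```python
import math


def branin(x1: float, x2: float) -> float:
    """Textbook Branin-Hoo function with its usual constants."""
    a = 1.0
    b = 5.1 / (4.0 * math.pi ** 2)
    c = 5.0 / math.pi
    r = 6.0
    s = 10.0
    t = 1.0 / (8.0 * math.pi)
    return a * (x2 - b * x1 ** 2 + c * x1 - r) ** 2 + s * (1.0 - t) * math.cos(x1) + s


# Parameters of arms 5_0 and 29_0, copied from the exp.trials output above.
print(branin(-5.0, 10.681135484818025))  # compare with 55.0547 in the table
print(branin(3.3039667073151158, 1.9687756455552108))  # compare with 0.55768
# (pi, 2.275) is one of Branin's three global minimizers.
print(branin(math.pi, 2.275))
```

If the recomputed values match the table, the gap between the best observed `mean` and 0.397887 in the plot is pure optimization error rather than observation noise.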

Plot results

We can use convenient Ax utilities for plotting the results.

import numpy as np
from ax.plot.trace import optimization_trace_single_method
from ax.utils.notebook.plotting import render


# `optimization_trace_single_method` expects a 2-d array of means, because it is designed to average
# means from multiple optimization runs, so we wrap our single run's objective means in another array.
objective_means = np.array([[trial.objective_mean for trial in exp.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    # Best-so-far trace: running minimum along the trial axis.
    y=np.minimum.accumulate(objective_means, axis=1),
    optimum=0.397887,  # Known minimum objective for Branin function.
)
render(best_objective_plot)
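The plotting cell reduces the per-trial means to a best-so-far trace with `np.minimum.accumulate`. The same running minimum can be computed with the standard library's `itertools.accumulate`. The sketch applies it to the `mean` column of the `exp.fetch_data().df` output above, with values copied from that table in trial order. It mirrors only the data preparation, not the Ax plotting call. `BRANIN_OPTIMUM` is simply the `optimum` argument from the cell, and the printed gap is best-so-far minus that value.

```python
from itertools import accumulate

# `mean` column of exp.fetch_data().df above, in trial order (trial_index 0..29).
means = [
    1.9166, 20.5632, 159.956, 17.7035, 62.8011, 55.0547, 36.2389, 10.9609, 197.514, 15.489,
    7.23819, 5.78416, 2.6687, 2.09117, 2.32844, 6.329, 0.567952, 1.3358, 5.26697, 0.406578,
    0.628474, 0.419371, 0.560946, 1.42429, 0.433216, 0.441191, 0.46991, 0.408209, 0.450586,
    0.55768,
]

# Running minimum: entry i is the best objective observed in trials 0..i,
# the same quantity np.minimum.accumulate(..., axis=1) computes along each row.
best_so_far = list(accumulate(means, min))

BRANIN_OPTIMUM = 0.397887  # value passed as `optimum` in the plotting cell
for trial_index, best in enumerate(best_so_far):
    print(f"{trial_index:2d}  best={best:.6f}  gap={best - BRANIN_OPTIMUM:.6f}")
```

On this run, the trace stays flat through the Sobol trials and the first few model-based trials. It drops sharply at trial 16 and reaches its final value at trial 19.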