In this tutorial, we illustrate how to use a custom BoTorch model within Ax's botorch_modular API. This allows us to harness the convenience of Ax for running Bayesian optimization loops while maintaining full flexibility in modeling.
Acquisition functions and their optimizers can be swapped out in much the same fashion; see, for example, the tutorial on implementing a custom acquisition function.
If you want to do something non-standard, or would like full insight into every aspect of the implementation, please see this tutorial for how to write your own full optimization loop in BoTorch.
import os
from contextlib import nullcontext

from ax.utils.testing.mock import fast_botorch_optimize_context_manager
import plotly.io as pio
# Ax uses Plotly to produce interactive plots. These are great for viewing and analysis,
# though they also lead to large file sizes, which is not ideal for files living in GH.
# Changing the default to `png` strips the interactive components to get around this.
pio.renderers.default = "png"
SMOKE_TEST = os.environ.get("SMOKE_TEST")
NUM_EVALS = 10 if SMOKE_TEST else 30
For this tutorial, we implement a very simple GPyTorch ExactGP model that uses an RBF kernel (with ARD) and infers a homoskedastic noise level.

Model definition is straightforward. Here we implement a GPyTorch ExactGP that also inherits from GPyTorchModel; together, these two superclasses add all the API calls that BoTorch expects in its various modules.
Note: BoTorch allows implementing any custom model that follows the Model API. For more information, please see the Model documentation.
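To make that requirement concrete, here is a minimal skeleton of the lower-level route (the class name is illustrative, and the body is intentionally left unimplemented; the rest of this tutorial takes the simpler GPyTorchModel path instead):

from botorch.models.model import Model
from botorch.posteriors import Posterior


class MyBareBonesModel(Model):
    # Illustrative skeleton only: a Model must expose `num_outputs` and
    # implement `posterior`, which evaluates the model at the query points X.
    @property
    def num_outputs(self) -> int:
        return 1

    def posterior(self, X, output_indices=None, observation_noise=False, **kwargs) -> Posterior:
        # Return a BoTorch Posterior, e.g., a GPyTorchPosterior wrapping a
        # gpytorch MultivariateNormal evaluated at X.
        raise NotImplementedError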
from typing import Optional
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from torch import Tensor
class SimpleCustomGP(ExactGP, GPyTorchModel):
    _num_outputs = 1  # to inform the GPyTorchModel API

    def __init__(self, train_X, train_Y, train_Yvar: Optional[Tensor] = None):
        # NOTE: This ignores train_Yvar and uses inferred noise instead.
        # Squeeze output dim before passing train_Y to ExactGP.
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(
            base_kernel=RBFKernel(ard_num_dims=train_X.shape[-1]),
        )
        self.to(train_X)  # make sure we're on the right device/dtype

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
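Before plugging the model into Ax, it can be helpful to sanity-check it directly in BoTorch. The following optional snippet fits the model on synthetic data (the data here is made up purely for illustration):

import torch
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(10, 2, dtype=torch.float64)
train_Y = 1 - (train_X - 0.5).norm(dim=-1, keepdim=True)  # toy objective
gp = SimpleCustomGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)
# `posterior` is provided by GPyTorchModel; the mean has an explicit output dim.
print(gp.posterior(train_X).mean.shape)  # torch.Size([10, 1])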
BoTorchModel in Ax

A BoTorchModel in Ax encapsulates both the surrogate (which Ax calls a Surrogate and BoTorch calls a Model) and an acquisition function. Here, we will only specify the custom surrogate and let Ax choose the default acquisition function.

Most models should work with the base Surrogate in Ax, except for BoTorch ModelListGP, which works with ListSurrogate.
Note that the Model (e.g., the SimpleCustomGP) must implement construct_inputs, as this is used to construct the inputs required for instantiating a Model instance from the experiment data.
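SimpleCustomGP inherits a default construct_inputs from BoTorch's Model, which maps the training data to the train_X / train_Y (and, when observed noise is present, train_Yvar) constructor arguments. If your constructor takes other arguments, you can override it; the subclass below is a hypothetical sketch of where that customization would go:

from botorch.utils.datasets import SupervisedDataset


class SimpleCustomGPWithOptions(SimpleCustomGP):
    # Hypothetical subclass, shown only to illustrate the hook; the inherited
    # default already suffices for SimpleCustomGP's (train_X, train_Y) signature.
    @classmethod
    def construct_inputs(cls, training_data: SupervisedDataset, **kwargs):
        inputs = super().construct_inputs(training_data=training_data, **kwargs)
        # Filter or add keys here to match your `__init__` signature.
        return inputs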
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate

ax_model = BoTorchModel(
    surrogate=Surrogate(
        # The model class to use
        botorch_model_class=SimpleCustomGP,
        # Optional, MLL class with which to optimize model parameters
        # mll_class=ExactMarginalLogLikelihood,
        # Optional, dictionary of keyword arguments to model constructor
        # model_options={}
    ),
    # Optional, acquisition function class to use - see custom acquisition tutorial
    # botorch_acqf_class=qExpectedImprovement,
)
ModelBridge

Models in Ax require a ModelBridge to interface with Experiments. A ModelBridge takes the inputs supplied by the Experiment and converts them to the inputs expected by the Model. For a BoTorchModel, we use TorchModelBridge. The Modular BoTorch interface creates the BoTorchModel and the TorchModelBridge in a single step, as follows (this snippet assumes an existing experiment and data):
from ax.modelbridge.registry import Models

model_bridge = Models.BOTORCH_MODULAR(
    experiment=experiment,
    data=data,
    surrogate=Surrogate(SimpleCustomGP),
    # Optional, will use default if unspecified
    # botorch_acqf_class=qLogNoisyExpectedImprovement,
)
# To generate a trial
trial = model_bridge.gen(1)
We will demonstrate this with both the Service API (simpler, easier to use) and the Developer API (advanced, more customizable).
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models

gs = GenerationStrategy(
    steps=[
        # Quasi-random initialization step
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,  # How many trials should be produced from this generation step
        ),
        # Bayesian optimization step using the custom surrogate
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,  # No limit on how many trials should be produced from this step
            # For `BOTORCH_MODULAR`, we pass in kwargs to specify what surrogate or acquisition function to use.
            model_kwargs={
                "surrogate": Surrogate(SimpleCustomGP),
            },
        ),
    ]
)
In order to use the GenerationStrategy we just created, we will pass it into the AxClient.
import torch
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from botorch.test_functions import Branin

# Initialize the client - AxClient offers a convenient API to control the experiment
ax_client = AxClient(generation_strategy=gs)
# Setup the experiment
ax_client.create_experiment(
    name="branin_test_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            # It is crucial to use floats for the bounds, i.e., 0.0 rather than 0.
            # Otherwise, the parameter would be inferred as an integer range.
            "bounds": [-5.0, 10.0],
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 15.0],
        },
    ],
    objectives={
        "branin": ObjectiveProperties(minimize=True),
    },
)
# Setup a function to evaluate the trials
branin = Branin()


def evaluate(parameters):
    x = torch.tensor([[parameters.get(f"x{i+1}") for i in range(2)]])
    # The GaussianLikelihood used by our model infers an observation noise level,
    # so we pass an SEM value of NaN to indicate that observation noise is unknown.
    return {"branin": (branin(x).item(), float("nan"))}
The next cell sets up a context manager solely to speed up the testing of the notebook in SMOKE_TEST mode. You can safely ignore this cell and the use of the context manager throughout the tutorial.
if SMOKE_TEST:
    fast_smoke_test = fast_botorch_optimize_context_manager
else:
    fast_smoke_test = nullcontext

# Set a seed for reproducible tutorial output
torch.manual_seed(0)
<torch._C.Generator at 0x7f6bc3d613f0>
with fast_smoke_test():
    for i in range(NUM_EVALS):
        parameters, trial_index = ax_client.get_next_trial()
        # Local evaluation here can be replaced with deployment to external system.
        ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))
ax_client.get_trials_data_frame()
  | trial_index | arm_name | trial_status | generation_method | branin | x1 | x2 |
---|---|---|---|---|---|---|---|
0 | 0 | 0_0 | COMPLETED | Sobol | 104.365417 | 0.625830 | 14.359564 |
1 | 1 | 1_0 | COMPLETED | Sobol | 122.930588 | 1.458453 | 14.529696 |
2 | 2 | 2_0 | COMPLETED | Sobol | 7.390774 | 9.140571 | 4.816245 |
3 | 3 | 3_0 | COMPLETED | Sobol | 61.271847 | 8.215463 | 9.037606 |
4 | 4 | 4_0 | COMPLETED | Sobol | 145.015671 | 9.934317 | 14.913237 |
5 | 5 | 5_0 | COMPLETED | BoTorch | 16.665333 | 7.924335 | 4.209484 |
6 | 6 | 6_0 | COMPLETED | BoTorch | 36.188915 | -2.860520 | 5.658734 |
7 | 7 | 7_0 | COMPLETED | BoTorch | 28.414114 | -1.552397 | 13.052576 |
8 | 8 | 8_0 | COMPLETED | BoTorch | 106.399498 | 8.005993 | 11.430362 |
9 | 9 | 9_0 | COMPLETED | BoTorch | 27.105524 | 0.326089 | 2.664659 |
10 | 10 | 10_0 | COMPLETED | BoTorch | 22.409479 | -4.674472 | 12.688425 |
11 | 11 | 11_0 | COMPLETED | BoTorch | 271.716095 | -5.000000 | 1.094121 |
12 | 12 | 12_0 | COMPLETED | BoTorch | 155.017776 | -4.371402 | 3.250981 |
13 | 13 | 13_0 | COMPLETED | BoTorch | 17.016251 | -0.802942 | 7.950013 |
14 | 14 | 14_0 | COMPLETED | BoTorch | 3.365896 | -2.905859 | 10.071743 |
15 | 15 | 15_0 | COMPLETED | BoTorch | 10.960894 | 10.000000 | 0.000000 |
16 | 16 | 16_0 | COMPLETED | BoTorch | 2.059195 | 10.000000 | 3.343624 |
17 | 17 | 17_0 | COMPLETED | BoTorch | 20.728100 | 6.154240 | 0.000000 |
18 | 18 | 18_0 | COMPLETED | BoTorch | 0.601567 | 9.225543 | 2.194878 |
19 | 19 | 19_0 | COMPLETED | BoTorch | 0.541902 | 9.596565 | 2.675480 |
20 | 20 | 20_0 | COMPLETED | BoTorch | 3.603829 | -3.830353 | 15.000000 |
21 | 21 | 21_0 | COMPLETED | BoTorch | 0.518112 | -3.279040 | 12.435520 |
22 | 22 | 22_0 | COMPLETED | BoTorch | 0.400917 | 9.449460 | 2.485637 |
23 | 23 | 23_0 | COMPLETED | BoTorch | 2.305150 | 3.475702 | 3.202017 |
24 | 24 | 24_0 | COMPLETED | BoTorch | 0.961923 | 3.395999 | 1.580014 |
25 | 25 | 25_0 | COMPLETED | BoTorch | 0.439405 | -3.167333 | 12.141147 |
26 | 26 | 26_0 | COMPLETED | BoTorch | 1.402626 | 3.598832 | 2.080916 |
27 | 27 | 27_0 | COMPLETED | BoTorch | 0.453665 | 9.517319 | 2.675371 |
28 | 28 | 28_0 | COMPLETED | BoTorch | 0.413040 | 3.197775 | 2.229931 |
29 | 29 | 29_0 | COMPLETED | BoTorch | 0.571966 | 9.587401 | 2.833282 |
parameters, values = ax_client.get_best_parameters()
print(f"Best parameters: {parameters}")
print(f"Corresponding mean: {values[0]}, covariance: {values[1]}")
Best parameters: {'x1': 9.517319461327038, 'x2': 2.675371257532727}
Corresponding mean: {'branin': 0.3263206610090492}, covariance: {'branin': {'branin': 0.06906659192184665}}
from ax.utils.notebook.plotting import render
render(ax_client.get_contour_plot())
[INFO 11-22 08:58:03] ax.service.ax_client: Retrieving contour plot with parameter 'x1' on X-axis and 'x2' on Y-axis, for metric 'branin'. Remaining parameters are affixed to the middle of their range.
best_parameters, values = ax_client.get_best_parameters()
best_parameters, values[0]
({'x1': 9.517319461327038, 'x2': 2.675371257532727}, {'branin': 0.3263206610090492})
render(ax_client.get_optimization_trace(objective_optimum=0.397887))
A detailed tutorial on the Service API can be found here.
We need 3 inputs for an Ax Experiment: a search space to optimize over, an optimization config specifying the objective and metrics, and a runner that handles the deployment of trials.
import pandas as pd
import torch
from ax import (
    Data,
    Experiment,
    Metric,
    Objective,
    OptimizationConfig,
    ParameterType,
    RangeParameter,
    Runner,
    SearchSpace,
)
from ax.utils.common.result import Ok
from botorch.test_functions import Branin

branin_func = Branin()


# For our purposes, the metric is a wrapper that structures the function output.
class BraninMetric(Metric):
    def fetch_trial_data(self, trial):
        records = []
        for arm_name, arm in trial.arms_by_name.items():
            params = arm.parameters
            tensor_params = torch.tensor([params["x1"], params["x2"]])
            records.append(
                {
                    "arm_name": arm_name,
                    "metric_name": self.name,
                    "trial_index": trial.index,
                    "mean": branin_func(tensor_params),
                    "sem": float("nan"),  # SEM (observation noise) - NaN indicates unknown
                }
            )
        return Ok(value=Data(df=pd.DataFrame.from_records(records)))
# Search space defines the parameters, their types, and acceptable values.
search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name="x1", parameter_type=ParameterType.FLOAT, lower=-5, upper=10
        ),
        RangeParameter(
            name="x2", parameter_type=ParameterType.FLOAT, lower=0, upper=15
        ),
    ]
)

optimization_config = OptimizationConfig(
    objective=Objective(
        metric=BraninMetric(name="branin_metric", lower_is_better=True),
        minimize=True,  # This is optional since we specified `lower_is_better=True`
    )
)


class MyRunner(Runner):
    def run(self, trial):
        trial_metadata = {"name": str(trial.index)}
        return trial_metadata


exp = Experiment(
    name="branin_experiment",
    search_space=search_space,
    optimization_config=optimization_config,
    runner=MyRunner(),
)
First, we use the Sobol generator to create 5 (quasi-) random initial points in the search space. Ax controls objective evaluations via Trials:

1. We generate a Trial using a generator run, e.g., Sobol below. A Trial specifies relevant metadata as well as the parameters to be evaluated. At this point, the Trial is at the CANDIDATE stage.
2. We run the Trial using Trial.run(). In our example, this serves to mark the Trial as RUNNING. In an advanced application, this can be used to dispatch the Trial for evaluation on a remote server.
3. Once the Trial is done running, we mark it as COMPLETED. This tells the Experiment that it can fetch the Trial data.

A Trial supports evaluation of a single parameterization. For parallel evaluations, see BatchTrial.
from ax.modelbridge.registry import Models

sobol = Models.SOBOL(exp.search_space)

for i in range(5):
    trial = exp.new_trial(generator_run=sobol.gen(1))
    trial.run()
    trial.mark_completed()
Once the initial (quasi-) random stage is completed, we can use our SimpleCustomGP with the default acquisition function chosen by Ax to run the BO loop.
with fast_smoke_test():
    for i in range(NUM_EVALS - 5):
        model_bridge = Models.BOTORCH_MODULAR(
            experiment=exp,
            data=exp.fetch_data(),
            surrogate=Surrogate(SimpleCustomGP),
        )
        trial = exp.new_trial(generator_run=model_bridge.gen(1))
        trial.run()
        trial.mark_completed()
View the trials attached to the Experiment.
exp.trials
{0: Trial(experiment_name='branin_experiment', index=0, status=TrialStatus.COMPLETED, arm=Arm(name='0_0', parameters={'x1': -0.4329736530780792, 'x2': 12.617264986038208})), 1: Trial(experiment_name='branin_experiment', index=1, status=TrialStatus.COMPLETED, arm=Arm(name='1_0', parameters={'x1': 2.9818817181512713, 'x2': 2.856269543990493})), 2: Trial(experiment_name='branin_experiment', index=2, status=TrialStatus.COMPLETED, arm=Arm(name='2_0', parameters={'x1': 6.599703226238489, 'x2': 8.393928492441773})), 3: Trial(experiment_name='branin_experiment', index=3, status=TrialStatus.COMPLETED, arm=Arm(name='3_0', parameters={'x1': -4.985555941238999, 'x2': 6.132936486974359})), 4: Trial(experiment_name='branin_experiment', index=4, status=TrialStatus.COMPLETED, arm=Arm(name='4_0', parameters={'x1': -1.767810033634305, 'x2': 10.938136987388134})), 5: Trial(experiment_name='branin_experiment', index=5, status=TrialStatus.COMPLETED, arm=Arm(name='5_0', parameters={'x1': 10.0, 'x2': 0.0})), 6: Trial(experiment_name='branin_experiment', index=6, status=TrialStatus.COMPLETED, arm=Arm(name='6_0', parameters={'x1': 4.826118509694027, 'x2': 15.0})), 7: Trial(experiment_name='branin_experiment', index=7, status=TrialStatus.COMPLETED, arm=Arm(name='7_0', parameters={'x1': 7.881232756494741, 'x2': 1.667689434902604})), 8: Trial(experiment_name='branin_experiment', index=8, status=TrialStatus.COMPLETED, arm=Arm(name='8_0', parameters={'x1': -5.0, 'x2': 2.2314214877327094})), 9: Trial(experiment_name='branin_experiment', index=9, status=TrialStatus.COMPLETED, arm=Arm(name='9_0', parameters={'x1': 1.524281688383832, 'x2': 6.8288515949282935})), 10: Trial(experiment_name='branin_experiment', index=10, status=TrialStatus.COMPLETED, arm=Arm(name='10_0', parameters={'x1': -5.0, 'x2': 14.195328959491395})), 11: Trial(experiment_name='branin_experiment', index=11, status=TrialStatus.COMPLETED, arm=Arm(name='11_0', parameters={'x1': 4.954093857651461, 'x2': 0.0})), 12: Trial(experiment_name='branin_experiment', index=12, status=TrialStatus.COMPLETED, arm=Arm(name='12_0', parameters={'x1': 10.0, 'x2': 4.292254247708062})), 13: Trial(experiment_name='branin_experiment', index=13, status=TrialStatus.COMPLETED, arm=Arm(name='13_0', parameters={'x1': 1.784995401645844, 'x2': 2.48105215133849})), 14: Trial(experiment_name='branin_experiment', index=14, status=TrialStatus.COMPLETED, arm=Arm(name='14_0', parameters={'x1': 10.0, 'x2': 2.636078735605977})), 15: Trial(experiment_name='branin_experiment', index=15, status=TrialStatus.COMPLETED, arm=Arm(name='15_0', parameters={'x1': 3.5206582595405216, 'x2': 2.734373821999437})), 16: Trial(experiment_name='branin_experiment', index=16, status=TrialStatus.COMPLETED, arm=Arm(name='16_0', parameters={'x1': 10.0, 'x2': 15.0})), 17: Trial(experiment_name='branin_experiment', index=17, status=TrialStatus.COMPLETED, arm=Arm(name='17_0', parameters={'x1': -3.432540455981013, 'x2': 11.981022952446304})), 18: Trial(experiment_name='branin_experiment', index=18, status=TrialStatus.COMPLETED, arm=Arm(name='18_0', parameters={'x1': -3.889312984711564, 'x2': 11.70726330370977})), 19: Trial(experiment_name='branin_experiment', index=19, status=TrialStatus.COMPLETED, arm=Arm(name='19_0', parameters={'x1': -3.3829822918223034, 'x2': 13.579969027820502})), 20: Trial(experiment_name='branin_experiment', index=20, status=TrialStatus.COMPLETED, arm=Arm(name='20_0', parameters={'x1': -3.298671582798891, 'x2': 12.874831586512578})), 21: Trial(experiment_name='branin_experiment', index=21, 
status=TrialStatus.COMPLETED, arm=Arm(name='21_0', parameters={'x1': 9.50823416510418, 'x2': 2.8940046336419702})), 22: Trial(experiment_name='branin_experiment', index=22, status=TrialStatus.COMPLETED, arm=Arm(name='22_0', parameters={'x1': 3.074374349475095, 'x2': 2.3894152080730544})), 23: Trial(experiment_name='branin_experiment', index=23, status=TrialStatus.COMPLETED, arm=Arm(name='23_0', parameters={'x1': 9.588438538916012, 'x2': 2.6513939215449205})), 24: Trial(experiment_name='branin_experiment', index=24, status=TrialStatus.COMPLETED, arm=Arm(name='24_0', parameters={'x1': -3.229034722434058, 'x2': 12.655684947578434})), 25: Trial(experiment_name='branin_experiment', index=25, status=TrialStatus.COMPLETED, arm=Arm(name='25_0', parameters={'x1': 3.105126694005836, 'x2': 2.6126209051274367})), 26: Trial(experiment_name='branin_experiment', index=26, status=TrialStatus.COMPLETED, arm=Arm(name='26_0', parameters={'x1': -3.3716682849456543, 'x2': 13.126453746040617})), 27: Trial(experiment_name='branin_experiment', index=27, status=TrialStatus.COMPLETED, arm=Arm(name='27_0', parameters={'x1': 3.0652809834567183, 'x2': 1.994860653228331})), 28: Trial(experiment_name='branin_experiment', index=28, status=TrialStatus.COMPLETED, arm=Arm(name='28_0', parameters={'x1': 9.400941016992647, 'x2': 3.04534723557459})), 29: Trial(experiment_name='branin_experiment', index=29, status=TrialStatus.COMPLETED, arm=Arm(name='29_0', parameters={'x1': -3.8481443576407575, 'x2': 15.0}))}
View the evaluation data about these trials.
exp.fetch_data().df
  | arm_name | metric_name | mean | sem | trial_index |
---|---|---|---|---|---|
0 | 0_0 | branin_metric | 53.572651 | NaN | 0 |
1 | 1_0 | branin_metric | 0.725682 | NaN | 1 |
2 | 2_0 | branin_metric | 71.991264 | NaN | 2 |
3 | 3_0 | branin_metric | 133.872314 | NaN | 3 |
4 | 4_0 | branin_metric | 11.081824 | NaN | 4 |
5 | 5_0 | branin_metric | 10.960894 | NaN | 5 |
6 | 6_0 | branin_metric | 198.016434 | NaN | 6 |
7 | 7_0 | branin_metric | 9.773301 | NaN | 7 |
8 | 8_0 | branin_metric | 236.403839 | NaN | 8 |
9 | 9_0 | branin_metric | 19.176552 | NaN | 9 |
10 | 10_0 | branin_metric | 21.676010 | NaN | 10 |
11 | 11_0 | branin_metric | 13.951876 | NaN | 11 |
12 | 12_0 | branin_metric | 3.605428 | NaN | 12 |
13 | 13_0 | branin_metric | 9.146261 | NaN | 13 |
14 | 14_0 | branin_metric | 2.077741 | NaN | 14 |
15 | 15_0 | branin_metric | 1.621861 | NaN | 15 |
16 | 16_0 | branin_metric | 145.872208 | NaN | 16 |
17 | 17_0 | branin_metric | 1.809717 | NaN | 17 |
18 | 18_0 | branin_metric | 8.897885 | NaN | 18 |
19 | 19_0 | branin_metric | 1.190839 | NaN | 19 |
20 | 20_0 | branin_metric | 0.564129 | NaN | 20 |
21 | 21_0 | branin_metric | 0.552207 | NaN | 21 |
22 | 22_0 | branin_metric | 0.423343 | NaN | 22 |
23 | 23_0 | branin_metric | 0.527413 | NaN | 23 |
24 | 24_0 | branin_metric | 0.463321 | NaN | 24 |
25 | 25_0 | branin_metric | 0.499759 | NaN | 25 |
26 | 26_0 | branin_metric | 0.735993 | NaN | 26 |
27 | 27_0 | branin_metric | 0.541708 | NaN | 27 |
28 | 28_0 | branin_metric | 0.749163 | NaN | 28 |
29 | 29_0 | branin_metric | 3.622983 | NaN | 29 |
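To identify the best observed point without the Service API's get_best_parameters, a simple lookup over the fetched data suffices (this is a convenience sketch, not a dedicated Ax utility):

df = exp.fetch_data().df
best_row = df.loc[df["mean"].idxmin()]  # we are minimizing, so take the smallest mean
best_arm = exp.arms_by_name[best_row["arm_name"]]
print(best_arm.parameters, best_row["mean"])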
We can use convenient Ax utilities for plotting the results.
import numpy as np
from ax.plot.trace import optimization_trace_single_method

# `optimization_trace_single_method` expects a 2-d array of means, because it averages
# means from multiple optimization runs, so we wrap our best-objectives array in another array.
objective_means = np.array([[trial.objective_mean for trial in exp.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    y=np.minimum.accumulate(objective_means, axis=1),
    optimum=0.397887,  # Known minimum objective for the Branin function.
)
render(best_objective_plot)