{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using a custom botorch model with Ax\n",
"\n",
"In this tutorial, we illustrate how to use a custom BoTorch model within Ax's `SimpleExperiment` API. This allows us to harness the convenience of Ax for running Bayesian Optimization loops, while at the same time maintaining full flexibility in terms of the modeling.\n",
"\n",
"Acquisition functions and strategies for optimizing acquisitions can be swapped out in much the same fashion. See for example the tutorial for [Implementing a custom acquisition function](./custom_acquisition).\n",
"\n",
"If you want to do something non-standard, or would like to have full insight into every aspect of the implementation, please see [this tutorial](./closed_loop_botorch_only) for how to write your own full optimization loop in BoTorch."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Implementing the custom model\n",
"\n",
"For this tutorial, we implement a very simple gpytorch Exact GP Model that uses an RBF kernel (with ARD) and infers a (homoskedastic) noise level.\n",
"\n",
"Model definition is straightforward - here we implement a gpytorch `ExactGP` that also inherits from `GPyTorchModel` -- this adds all the api calls that botorch expects in its various modules. \n",
"\n",
"*Note:* botorch also allows implementing other custom models as long as they follow the minimal `Model` API. For more information, please see the [Model Documentation](../docs/models)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from botorch.models.gpytorch import GPyTorchModel\n",
"from gpytorch.distributions import MultivariateNormal\n",
"from gpytorch.means import ConstantMean\n",
"from gpytorch.models import ExactGP\n",
"from gpytorch.kernels import RBFKernel, ScaleKernel\n",
"from gpytorch.likelihoods import GaussianLikelihood\n",
"from gpytorch.mlls import ExactMarginalLogLikelihood\n",
"from gpytorch.priors import GammaPrior\n",
"\n",
"\n",
"class SimpleCustomGP(ExactGP, GPyTorchModel):\n",
"\n",
" _num_outputs = 1 # to inform GPyTorchModel API\n",
" \n",
" def __init__(self, train_X, train_Y):\n",
" # squeeze output dim before passing train_Y to ExactGP\n",
" super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())\n",
" self.mean_module = ConstantMean()\n",
" self.covar_module = ScaleKernel(\n",
" base_kernel=RBFKernel(ard_num_dims=train_X.shape[-1]),\n",
" )\n",
" self.to(train_X) # make sure we're on the right device/dtype\n",
" \n",
" def forward(self, x):\n",
" mean_x = self.mean_module(x)\n",
" covar_x = self.covar_module(x)\n",
" return MultivariateNormal(mean_x, covar_x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Define a factory function to be used with Ax's BotorchModel\n",
"\n",
"Ax's `BotorchModel` internally breaks down the different components of Bayesian Optimization (model generation & fitting, defining acquisition functions, and optimizing them) into a functional api. \n",
"\n",
"Depending on which of these components we want to modify, we can pass in an associated custom factory function to the `BotorchModel` constructor. In order to use a custom model, we have to implement a model factory function that, given data according to Ax's api specification, instantiates and fits a BoTorch Model object.\n",
"\n",
"The call signature of this factory function is the following: \n",
"\n",
"```python\n",
"def get_and_fit_gpytorch_model(\n",
" Xs: List[Tensor],\n",
" Ys: List[Tensor],\n",
" Yvars: List[Tensor],\n",
" state_dict: Optional[Dict[str, Tensor]] = None,\n",
" **kwargs: Any,\n",
") -> Model:\n",
"```\n",
"\n",
"where\n",
"- the `i`-th element of `Xs` are the training features for the i-th outcome as an `n_i x d` tensor (in our simple example, we only have one outcome)\n",
"- similarly, the `i`-th element of `Ys` and `Yvars` are the observations and associated observation variances for the `i`-th outcome as `n_i x 1` tensors\n",
"- `state_dict` is an optional PyTorch module state dict that can be used to initialize the model's parameters to pre-specified values\n",
"\n",
"The function must return a botorch `Model` object. What happens inside the function is up to you.\n",
"\n",
"Using botorch's `fit_gpytorch_model` utility function, model-fitting is straightforward for this simple model (you may have to use your own custom model fitting loop when working with more complex models - see the tutorial for [Fitting a model with torch.optim](fit_model_with_torch_optimizer)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from botorch.fit import fit_gpytorch_model\n",
"\n",
"def _get_and_fit_simple_custom_gp(Xs, Ys, **kwargs):\n",
" model = SimpleCustomGP(Xs[0], Ys[0])\n",
" mll = ExactMarginalLogLikelihood(model.likelihood, model)\n",
" fit_gpytorch_model(mll)\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set up the optimization problem in Ax"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ax's `SimpleExperiment` API requires an evaluation function that is able to compute all the metrics required in the experiment. This function needs to accept a set of parameter values as a dictionary. It should produce a dictionary of metric names to tuples of mean and standard error for those metrics.\n",
"\n",
"For this tutorial, we use the Branin function, a simple synthetic benchmark function in two dimensions. In an actual application, this could be arbitrarily complicated - e.g. this function could run some costly simulation, conduct some A/B tests, or kick off some ML model training job with the given parameters). "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import numpy as np\n",
"\n",
"def branin(parameterization, *args):\n",
" x1, x2 = parameterization[\"x1\"], parameterization[\"x2\"]\n",
" y = (x2 - 5.1 / (4 * np.pi ** 2) * x1 ** 2 + 5 * x1 / np.pi - 6) ** 2\n",
" y += 10 * (1 - 1 / (8 * np.pi)) * np.cos(x1) + 10\n",
" # let's add some synthetic observation noise\n",
" y += random.normalvariate(0, 0.1)\n",
" return {\"branin\": (y, 0.0)}"
]
},
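{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on the formula above, here is a minimal noise-free sketch that uses only the standard library. The helper name `branin_noiseless` is hypothetical (it is not part of Ax or BoTorch). The Branin function has a known global minimum value of about 0.397887, attained for example at `(x1, x2) = (pi, 2.275)`.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def branin_noiseless(x1, x2):\n",
"    # same formula as `branin` above, without the synthetic noise term\n",
"    y = (x2 - 5.1 / (4 * math.pi ** 2) * x1 ** 2 + 5 * x1 / math.pi - 6) ** 2\n",
"    y += 10 * (1 - 1 / (8 * math.pi)) * math.cos(x1) + 10\n",
"    return y\n",
"\n",
"\n",
"# evaluate at one of the known global minimizers\n",
"y_min = branin_noiseless(math.pi, 2.275)\n",
"```\n",
"\n",
"Values found by the optimization loop can be compared against this reference, keeping in mind that the observations in this tutorial are noisy."
]
},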
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to define a search space for our experiment that defines the parameters and the set of feasible values."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from ax import ParameterType, RangeParameter, SearchSpace\n",
"\n",
"search_space = SearchSpace(\n",
" parameters=[\n",
" RangeParameter(\n",
" name=\"x1\", parameter_type=ParameterType.FLOAT, lower=-5, upper=10\n",
" ),\n",
" RangeParameter(\n",
" name=\"x2\", parameter_type=ParameterType.FLOAT, lower=0, upper=15\n",
" ),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Third, we make a `SimpleExperiment` — note that the `objective_name` needs to be one of the metric names returned by the evaluation function."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from ax import SimpleExperiment\n",
"\n",
"exp = SimpleExperiment(\n",
" name=\"test_branin\",\n",
" search_space=search_space,\n",
" evaluation_function=branin,\n",
" objective_name=\"branin\",\n",
" minimize=True,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use the Sobol generator to create 5 (quasi-) random initial point in the search space. Calling `batch_trial` will cause Ax to evaluate the underlying `branin` function at the generated points, and automatically keep track of the results."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BatchTrial(experiment_name='test_branin', index=0, status=TrialStatus.CANDIDATE)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from ax.modelbridge import get_sobol\n",
"\n",
"sobol = get_sobol(exp.search_space)\n",
"exp.new_batch_trial(generator_run=sobol.gen(5))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To run our custom botorch model inside the Ax optimization loop, we can use the `get_botorch` factory function from `ax.modelbridge.factory`. Any keyword arguments given to this function are passed through to the `BotorchModel` constructor. To use our custom model, we just need to pass our newly minted `_get_and_fit_simple_custom_gp` function to `get_botorch` using the `model_constructor` argument.\n",
"\n",
"**Note:** `get_botorch` by default automatically applies a number of parameter transformations (e.g. to normalize input data or standardize output data). This is typically what you want for standard use cases with continuous parameters. If your model expects raw parameters, make sure to pass in `transforms=[]` to avoid any transformations to take place. See the [Ax documentation](https://ax.dev/docs/models.html#transforms) for additional information on how transformations in Ax work."
]
},
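{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the note above concrete, here is a rough sketch of the two kinds of transformations it mentions, written in plain Python. This is an illustration only and not Ax's implementation; the helper names `normalize` and `standardize` are hypothetical, and the use of the sample standard deviation is an assumption.\n",
"\n",
"```python\n",
"import statistics\n",
"\n",
"\n",
"def normalize(x, lower, upper):\n",
"    # map a parameter value from [lower, upper] to the unit interval [0, 1]\n",
"    return (x - lower) / (upper - lower)\n",
"\n",
"\n",
"def standardize(ys):\n",
"    # shift and scale observations to zero mean and unit sample standard deviation\n",
"    mean = statistics.mean(ys)\n",
"    std = statistics.stdev(ys)\n",
"    return [(y - mean) / std for y in ys]\n",
"\n",
"\n",
"# x1 ranges over [-5, 10] in the search space defined above\n",
"x1_unit = normalize(2.5, lower=-5, upper=10)\n",
"ys_std = standardize([1.0, 2.0, 3.0])\n",
"```\n",
"\n",
"A model that is fit on transformed data like this never sees the raw parameter ranges or outcome scale, which is why a model expecting raw parameters needs the transformations switched off."
]
},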
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Run the optimization loop\n",
"\n",
"We're ready to run the Bayesian Optimization loop."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running optimization batch 1/5...\n",
"Running optimization batch 2/5...\n",
"Running optimization batch 3/5...\n",
"Running optimization batch 4/5...\n",
"Running optimization batch 5/5...\n",
"Done!\n"
]
}
],
"source": [
"from ax.modelbridge.factory import get_botorch\n",
"\n",
"for i in range(5):\n",
" print(f\"Running optimization batch {i+1}/5...\")\n",
" model = get_botorch(\n",
" experiment=exp,\n",
" data=exp.eval(),\n",
" search_space=exp.search_space,\n",
" model_constructor=_get_and_fit_simple_custom_gp,\n",
" )\n",
" batch = exp.new_trial(generator_run=model.gen(1))\n",
" \n",
"print(\"Done!\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}