BO with Binary Black-box Constraint
- Contributors: Fanjin Wang
- Last updated: Jan 26, 2025
- BoTorch version: 0.12.1.dev91+g0df552162.d20250124
In this notebook, we show how to implement BO under unknown constraints, where feasibility is learned by a classification model trained alongside a GP surrogate of the objective.
To add some context, the candidates recommended by BO to assist laboratory experiments may lead to undesired outcomes, such as failed experiments or infeasible protocols [1]. In these cases, the underlying feasible region can only be discovered through experimentation, so the unknown constraint can be treated as a latent quantity modeled by a surrogate classification model. Note that this setting, in which we only obtain binary information about whether a proposed candidate is feasible, is different from the setting in which we observe numerical values of an outcome that is subject to some constraint.
The present code is also inspired by the implementation in [2].
Set dtype and device
import torch
tkwargs = {
    "dtype": torch.double,
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
}
Problem setup
We begin by setting up a classical synthetic problem, the Townsend function presented in [3]. The objective to be maximized is:

$$f(x_1, x_2) = \cos^2\big((x_1 - 0.1)\,x_2\big) + x_1 \sin(3 x_1 + x_2),$$

subject to the constraint $c(x_1, x_2) > 0$, where

$$c(x_1, x_2) = \left(2\cos t - \tfrac{1}{2}\cos 2t - \tfrac{1}{4}\cos 3t - \tfrac{1}{8}\cos 4t\right)^2 + (2\sin t)^2 - x_1^2 - x_2^2,$$

with $t = \operatorname{atan2}(x_1, x_2)$. The binary feasibility is defined by:

$$y_{\text{con}}(x_1, x_2) = \mathbb{1}\big[c(x_1, x_2) > 0\big].$$
Here, we follow a natural representation where $y_{\text{con}} = 1$ indicates a feasible condition. We will train a classification model to predict the feasibility of a point. Note that in BoTorch's implementation, negative values indicate feasibility, so we will need to convert accordingly later when feeding the feasibility into the pipeline.

Note that we essentially 'throw away' the information contained in the value of $c$ by applying a binary mask to generate $y_{\text{con}}$. This is for illustration purposes as part of this tutorial; in many real-world applications the latent value is not directly observable and only binary information (experiment success or failure) is available.
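To make the sign conventions concrete, here is a minimal sketch (with hypothetical latent values) of the binarization and one possible conversion to BoTorch's negative-means-feasible convention:

import torch

c = torch.tensor([0.3, -1.2, 0.8])  # hypothetical latent constraint values (feasible if > 0)
y_con = (c > 0).float()             # natural labels: 1 = feasible, 0 = infeasible
botorch_style = 0.5 - y_con         # one possible conversion: negative values indicate feasibility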
[3]: Townsend, A. (2014). Constrained optimization in Chebfun. https://www.chebfun.org/examples/opt/ConstrainedOptimization.html
class Townsend:
    def __init__(self):
        self.dim = 2
        self.lb = torch.tensor([-2.25, -2.5], **tkwargs)  # per-dimension lower bounds
        self.ub = torch.tensor([2.25, 1.75], **tkwargs)  # per-dimension upper bounds
        self.bounds = torch.stack([self.lb, self.ub])
        self._optimal_value = 1.660
        self.name = "Townsend"

    def __call__(self, x):
        return self.objective(x)

    def is_feasible(self, x):
        x1, x2 = x[..., 0], x[..., 1]
        t = torch.atan2(x1, x2)
        c = (
            (2 * torch.cos(t) - 0.5 * torch.cos(2 * t) - 0.25 * torch.cos(3 * t) - 0.125 * torch.cos(4 * t)) ** 2
            + (2 * torch.sin(t)) ** 2
            - x1 ** 2
            - x2 ** 2
        )
        y_con = (c > 0).float()  # binarize the feasibility
        return y_con

    def objective(self, x):
        x1, x2 = x[..., 0], x[..., 1]
        return torch.cos((x1 - 0.1) * x2) ** 2 + x1 * torch.sin(3 * x1 + x2)

townsend = Townsend()
We plot the landscape of the Townsend function as a reference. The infeasible region is masked out.
# Plot the townsend function and constraint
import matplotlib.pyplot as plt
import numpy as np
def plot_townsend(ax):
    x = np.linspace(-2.5, 2.5, 100)
    y = np.linspace(-2.5, 2.5, 100)
    X, Y = np.meshgrid(x, y)
    obj = townsend(torch.tensor(np.stack([X, Y], axis=-1), **tkwargs)).cpu().numpy()
    con = townsend.is_feasible(torch.tensor(np.stack([X, Y], axis=-1), **tkwargs)).cpu().numpy()
    # mask out the infeasible region (con == 0)
    obj[con == 0] = np.nan
    c = ax.contourf(X, Y, obj, levels=20, cmap='Blues')
    ax.set_xlabel("X1")
    ax.set_ylabel("X2")
    ax.set_title("Townsend Problem")
    plt.colorbar(c, ax=ax, orientation='vertical')
    return ax
# Plot the townsend function and constraint
fig, ax = plt.subplots(1, 1, figsize=(5, 4))
plot_townsend(ax)
plt.tight_layout()
plt.show()
Generate Training Data
from botorch.utils.sampling import draw_sobol_samples
def generate_initial_data(n):
    # generate training data within the problem bounds
    train_x = draw_sobol_samples(bounds=townsend.bounds, n=n, q=1).squeeze(1)
    train_obj = townsend(train_x).unsqueeze(-1)
    train_con = townsend.is_feasible(train_x)
    return train_x, train_obj, train_con
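For example, drawing a handful of initial points gives tensors with the following shapes (a quick illustrative check, not part of the tutorial's pipeline):

train_x, train_obj, train_con = generate_initial_data(n=6)
print(train_x.shape, train_obj.shape, train_con.shape)
# torch.Size([6, 2]) torch.Size([6, 1]) torch.Size([6])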
Define Classification Model
We use an approximate GP implemented in GPyTorch as the surrogate for the unknown constraint. The latent function is modeled by a GP, and the observations are modeled by a Bernoulli likelihood. After training, we extract the probability from the Bernoulli distribution as the feasibility prediction.
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
from gpytorch.kernels.scale_kernel import ScaleKernel
from botorch.models.gpytorch import GPyTorchModel
class GP_vi(ApproximateGP, GPyTorchModel):
    def __init__(self, train_x, train_y):
        # store the training data so that BoTorch model utilities can access it
        self.train_inputs = (train_x,)
        self.train_targets = train_y
        # variational inference with one inducing point per training input
        variational_distribution = CholeskyVariationalDistribution(train_x.size(0))
        variational_strategy = VariationalStrategy(
            self, train_x, variational_distribution
        )
        super(GP_vi, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = ScaleKernel(gpytorch.kernels.RBFKernel())
        self.likelihood = gpytorch.likelihoods.BernoulliLikelihood()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
Model initialization
We initialize the model with a SingleTaskGP model for the objective and the custom GP_vi model for the feasibility modeling.
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.models import SingleTaskGP, ModelListGP
from botorch.fit import fit_gpytorch_mll
from botorch.models.transforms.outcome import Standardize
from botorch.utils.transforms import unnormalize, normalize
from functools import partial
def initialize_model(train_x, train_obj, train_con):
    '''Initialize the model for the problem.'''
    train_x = normalize(train_x, bounds=townsend.bounds)
    model_obj = SingleTaskGP(
        train_X=train_x,
        train_Y=train_obj,
        outcome_transform=Standardize(m=1),
    )
    mll_obj = ExactMarginalLogLikelihood(model_obj.likelihood, model_obj)
    fit_gpytorch_mll(mll_obj)
    model_con = GP_vi(train_x, train_con)
    mll_con = gpytorch.mlls.VariationalELBO(
        model_con.likelihood, model_con, num_data=train_con.size(0)
    )
    # make sure the GPyTorch model is in double precision
    model_con.double()
    mll_con.double()
    fit_gpytorch_mll(mll_con)
    model = ModelListGP(model_obj, model_con)
    return model
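As a quick illustration (a hypothetical usage sketch, not part of the BO loop below), the returned ModelListGP bundles both surrogates, and the feasibility classifier can be queried for probabilities on the normalized inputs:

train_x, train_obj, train_con = generate_initial_data(n=6)
model = initialize_model(train_x, train_obj, train_con)
model_con = model.models[1]  # the GP_vi feasibility classifier
with torch.no_grad():
    X_norm = normalize(train_x, bounds=townsend.bounds)
    probs = model_con.likelihood(model_con(X_norm)).probs  # P(feasible) in [0, 1]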
We further set up the qLogExpectedImprovement acquisition function and the method to optimize it and obtain observations from the Townsend function. The constraint is passed to the acquisition function via the constraints argument; see here for more details. The helper functions pass_obj and pass_con are used to pass the objective and constraint values to the acquisition function. The fat argument of the acquisition function is set to None for the constraint, indicating that no transformation should be applied, since the constraint already outputs values in the interval [0,1].
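For intuition, this weighting mirrors the classic constrained-EI construction, in which the EI is multiplied by the probability of feasibility; in log space the two terms add, which is why pass_con below adds 1e-8 before the log is taken. A minimal numeric sketch with hypothetical values:

import torch

log_ei = torch.tensor([-1.0, -3.0])   # hypothetical log-EI values at two candidates
prob_feas = torch.tensor([0.9, 0.1])  # hypothetical feasibility probabilities
log_cei = log_ei + torch.log(prob_feas + 1e-8)  # 1e-8 guards against log(0)
print(log_cei)  # tensor([-1.1054, -5.3026])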
from botorch.acquisition import qLogExpectedImprovement
from botorch.acquisition.objective import GenericMCObjective
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.optim import optimize_acqf
BATCH_SIZE = 1
NUM_RESTARTS = 10
RAW_SAMPLES = 512
def optimize_acqf_and_get_observation(model, train_obj, train_con):
    """
    Optimizes the acquisition function, and returns a new candidate and observation.
    """
    # best_f is the best feasible objective value observed so far
    best_f = np.ma.masked_array(train_obj, mask=~train_con.bool()).max().item()
    # bounds of the normalized input space, matching the model's dtype/device
    standard_bounds = torch.stack(
        [torch.zeros(townsend.dim), torch.ones(townsend.dim)]
    ).to(**tkwargs)
    acqf = qLogExpectedImprovement(
        model=model,
        best_f=best_f,
        sampler=SobolQMCNormalSampler(sample_shape=torch.Size([1024])),
        objective=GenericMCObjective(pass_obj),
        constraints=[partial(pass_con, model_con=model.models[1])],
        fat=[None],
    )
    # optimize the acquisition function in the normalized space
    candidates, _ = optimize_acqf(
        acq_function=acqf,
        bounds=standard_bounds,
        q=BATCH_SIZE,
        num_restarts=NUM_RESTARTS,
        raw_samples=RAW_SAMPLES,
    )
    # observe new values
    new_x = unnormalize(candidates.detach(), townsend.bounds)
    new_obj = townsend(new_x)
    new_con = townsend.is_feasible(new_x)
    return new_x, new_obj, new_con, acqf
def pass_obj(Z, X=None):
    '''
    Directly pass the objective samples to the acquisition function.
    '''
    return Z[..., 0]

def pass_con(Z, model_con, X=None):
    '''
    Pass the constraint samples to the acquisition function as feasibility probabilities.
    '''
    y_con = Z[..., 1]  # latent constraint samples
    prob = model_con.likelihood(y_con).probs  # probability that the constraint is satisfied
    return prob + 1e-8  # add a small value to avoid log(0), as qLogEI is used
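Note that GPyTorch's BernoulliLikelihood maps latent function samples through the standard normal CDF, so pass_con effectively returns Phi(f) evaluated on the GP samples. A quick standalone check of that mapping:

import torch
from gpytorch.likelihoods import BernoulliLikelihood

likelihood = BernoulliLikelihood()
f = torch.tensor([-2.0, 0.0, 2.0])  # latent constraint function samples
print(likelihood(f).probs)  # tensor([0.0228, 0.5000, 0.9772]), i.e., Phi(f)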
We also define a helper function to visualize the BO process by plotting the acquisition function value, the predicted constraint probability, and the posterior mean of the objective model (the surrogate landscape before the constraint is applied).
def plot_helper(model, train_x, new_x, acqf, axes):
    with torch.no_grad():
        x = np.linspace(-2.5, 2.5, 100)
        y = np.linspace(-2.5, 2.5, 100)
        X, Y = np.meshgrid(x, y)
        Z = torch.tensor(np.stack([X, Y], axis=-1)).to(**tkwargs)
        Z = normalize(Z, bounds=townsend.bounds)
        Z = Z.reshape(-1, 2).unsqueeze(1)
        # get the acquisition function value
        acq_values = acqf(Z).cpu().numpy()
        # get the constraint probability
        model_con = model.models[1]
        prob = model_con.likelihood(model_con(Z)).probs.cpu().numpy()
        # get the posterior mean of the objective model
        mean_values = model.models[0](Z).mean.cpu().numpy()
        # plot the townsend function
        plot_townsend(axes[0])
        c_acqf = axes[1].contourf(X, Y, acq_values.reshape(100, 100), levels=20, cmap='Blues')
        # plot the constraint probability and set the colorbar to 0-1
        c_prob = axes[2].contourf(X, Y, prob.reshape(100, 100), levels=20, cmap='RdYlGn', vmin=0, vmax=1)
        c_mean = axes[3].contourf(X, Y, mean_values.reshape(100, 100), levels=20, cmap='Oranges')
        # plot the current observations
        for ax in axes:
            ax.scatter(train_x[:, 0].cpu(), train_x[:, 1].cpu(), color='grey', label='Observations', alpha=0.5)
            ax.scatter(new_x[:, 0].cpu(), new_x[:, 1].cpu(), marker='*', color='red', label='New Point')
            ax.set_xlabel("X1")
            ax.set_ylabel("X2")
        axes[1].set_title("Acquisition Function")
        axes[2].set_title("Constraint Probability")
        axes[3].set_title("Objective Posterior Mean")
        # add colorbars
        plt.colorbar(c_acqf, ax=axes[1])
        plt.colorbar(c_prob, ax=axes[2])
        plt.colorbar(c_mean, ax=axes[3])
        plt.tight_layout()
Perform BO loop
We initialize the BO loop with 6 quasi-random points from a Sobol sequence and then perform 25 iterations of BO. The acquisition function value, the predicted constraint probability, and the objective posterior mean are plotted every 5 iterations.
INIT_DATA_SIZE = 2 * (townsend.dim + 1)
N_BATCH = 25
VERBOSE = True

# generate initial training data
train_x, train_obj, train_con = generate_initial_data(n=INIT_DATA_SIZE)
# use the same initial data for the random (Sobol) baseline
train_x_rand, train_obj_rand, train_con_rand = (train_x, train_obj, train_con)
# initialize the model
model = initialize_model(train_x, train_obj, train_con)

# store a list of regrets at each step
optimal = torch.tensor(townsend._optimal_value, **tkwargs)
regrets_model = []
regrets_rand = []

for iteration in range(N_BATCH + 1):
    # append the regrets
    best_f_model = np.ma.masked_array(train_obj, mask=~train_con.bool()).max().item()
    regrets_model.append(optimal - best_f_model)
    best_f_rand = np.ma.masked_array(train_obj_rand, mask=~train_con_rand.bool()).max().item()
    regrets_rand.append(optimal - best_f_rand)
    # optimize the acquisition function and get a new observation
    new_x, new_obj, new_con, acqf = optimize_acqf_and_get_observation(model, train_obj, train_con)
    # print the current regrets, and plot the acquisition function every 5 iterations
    if VERBOSE:
        print(
            f"Iteration {iteration}: \n log Regret CEI = {torch.log(regrets_model[-1]):.2f}"
            f" \n log Regret Sobol = {torch.log(regrets_rand[-1]):.2f}"
        )
        if iteration % 5 == 0:
            fig, axes = plt.subplots(1, 4, figsize=(15, 3))
            plot_helper(model, train_x, new_x, acqf, axes)
            plt.show()
            # clear the axes
            plt.pause(0.1)
            plt.close(fig)
    # include the new observation in the training data
    train_x = torch.cat([train_x, new_x])
    train_obj = torch.cat([train_obj, new_obj.unsqueeze(-1)])
    train_con = torch.cat([train_con, new_con])
    # update the model
    model = initialize_model(train_x, train_obj, train_con)
    # draw a quasi-random point for the Sobol baseline
    new_x_rand, new_obj_rand, new_con_rand = generate_initial_data(n=BATCH_SIZE)
    train_x_rand = torch.cat([train_x_rand, new_x_rand])
    train_obj_rand = torch.cat([train_obj_rand, new_obj_rand])
    train_con_rand = torch.cat([train_con_rand, new_con_rand])
Iteration 0:
log Regret CEI = -1.13
log Regret Sobol = -1.13
Iteration 1:
log Regret CEI = -1.13
log Regret Sobol = -1.13
Iteration 2:
log Regret CEI = -1.13
log Regret Sobol = -1.55
Iteration 3:
log Regret CEI = -1.13
log Regret Sobol = -1.55
Iteration 4:
log Regret CEI = -2.31
log Regret Sobol = -1.55
Iteration 5:
log Regret CEI = -2.31
log Regret Sobol = -1.55
Iteration 6:
log Regret CEI = -2.31
log Regret Sobol = -1.55
Iteration 7:
log Regret CEI = -2.31
log Regret Sobol = -1.55
Iteration 8:
log Regret CEI = -2.31
log Regret Sobol = -1.55
Iteration 9:
log Regret CEI = -2.31
log Regret Sobol = -1.55
Iteration 10:
log Regret CEI = -2.34
log Regret Sobol = -1.55
Iteration 11:
log Regret CEI = -2.34
log Regret Sobol = -1.55
Iteration 12:
log Regret CEI = -2.34
log Regret Sobol = -1.55
Iteration 13:
log Regret CEI = -2.34
log Regret Sobol = -1.55
Iteration 14:
log Regret CEI = -2.34
log Regret Sobol = -1.55
Iteration 15:
log Regret CEI = -2.34
log Regret Sobol = -1.55
Iteration 16:
log Regret CEI = -2.34
log Regret Sobol = -1.55
Iteration 17:
log Regret CEI = -2.34
log Regret Sobol = -1.55
Iteration 18:
log Regret CEI = -2.34
log Regret Sobol = -1.55
Iteration 19:
log Regret CEI = -2.34
log Regret Sobol = -1.55
Iteration 20:
log Regret CEI = -2.34
log Regret Sobol = -1.55
Iteration 21:
log Regret CEI = -3.78
log Regret Sobol = -1.55
Iteration 22:
log Regret CEI = -3.78
log Regret Sobol = -1.55
Iteration 23:
log Regret CEI = -3.78
log Regret Sobol = -1.55
Iteration 24:
log Regret CEI = -3.78
log Regret Sobol = -1.55
Iteration 25:
log Regret CEI = -3.78
log Regret Sobol = -1.55
# plot the regret curves (moved to CPU in case a GPU device is used)
plt.plot(torch.stack(regrets_model).cpu(), label='CEI')
plt.plot(torch.stack(regrets_rand).cpu(), label='Sobol')
plt.yscale("log")
plt.xlabel("Number of Iterations")
plt.ylabel("Regret")
plt.title("Regret of CEI and Sobol")
plt.legend()
plt.show()