Multi-Information Source BO with Augmented Gaussian Processes
- Contributors: andreaponti5
- Last updated: Jan 29, 2024
- BoTorch version: 0.9.5(dev)
In this tutorial, we show how to perform Multiple Information Source Bayesian Optimization in BoTorch based on the Augmented Gaussian Process (AGP) and the Augmented UCB (AUCB) acquisition function proposed in [1]. The key idea of the AGP is to fit a GP model for each information source and to augment the observations on the high fidelity source with those from cheaper sources that can be considered reliable. The GP model fitted on this augmented set of observations is the AGP. The AUCB is a modification of the standard UCB, computed on the AGP and designed to also account for the source-specific query cost.
We empirically show that AGP-based Multiple Information Source Bayesian Optimization usually performs better than other multi-fidelity approaches [2].
[1] Candelieri, A., & Archetti, F. (2021). Sparsifying to optimize over multiple information sources: an augmented Gaussian process based algorithm. Structural and Multidisciplinary Optimization, 64, 239-255.
[2] The arXiv preprint will be available soon.
!pip install matplotlib
import os
import matplotlib.pyplot as plt
import torch
from gpytorch import ExactMarginalLogLikelihood
import botorch
from botorch import fit_gpytorch_mll
from botorch.acquisition import InverseCostWeightedUtility, qMultiFidelityMaxValueEntropy
from botorch_community.acquisition.augmented_multisource import AugmentedUpperConfidenceBound
from botorch.models import AffineFidelityCostModel, SingleTaskMultiFidelityGP
from botorch_community.models.gp_regression_multisource import SingleTaskAugmentedGP, get_random_x_for_agp
from botorch.models.transforms import Standardize
from botorch.optim import optimize_acqf, optimize_acqf_mixed
from botorch.test_functions.multi_fidelity import AugmentedBranin
tkwargs = {
    "dtype": torch.double,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
}
SMOKE_TEST = os.environ.get("SMOKE_TEST", False)
N_ITER = 10 if SMOKE_TEST else 50
SEED = 3
Problem setup
We consider the augmented Branin multi-fidelity synthetic test problem. Note that "augmented" here is unrelated to the AGP: it means that the Branin test function has been extended with an additional dimension representing the fidelity parameter.
The test function takes the form $f(x, s)$ where $x \in [-5, 10] \times [0, 15]$ and $s \in [0, 1]$. The target fidelity is 1.0, which means that our goal is to solve $\max_x f(x, 1.0)$ by making use of cheaper evaluations $f(x, s)$ for $s < 1.0$. In this example, we'll assume that the cost function takes the form $5.0 + s$, illustrating a situation where the fixed cost is $5.0$.
Since we are in a multiple information source setting, three different sources are considered, with $s = 0.5$, $s = 0.75$ and $s = 1.0$, respectively.
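To make the role of the fidelity parameter concrete, the following framework-free sketch re-implements the augmented Branin function in plain Python. It assumes the standard Branin constants $b = 5.1/(4\pi^2)$, $c = 5/\pi$, $r = 6$, $t = 1/(8\pi)$ and the fidelity-dependent coefficient $b - 0.1(1 - s)$ documented for BoTorch's AugmentedBranin; it is an illustration, not the library code. Since the tutorial builds the problem with negate=True, problem(...) returns the negative of these values.

```python
import math

# Standard Branin constants (b, c, r, t).
B = 5.1 / (4 * math.pi**2)
C = 5 / math.pi
R = 6.0
T = 1 / (8 * math.pi)


def augmented_branin(x1: float, x2: float, s: float) -> float:
    """Augmented Branin (to be minimized); s only perturbs the x1**2 coefficient."""
    t1 = x2 - (B - 0.1 * (1 - s)) * x1**2 + C * x1 - R
    return t1**2 + 10 * (1 - T) * math.cos(x1) + 10


# Evaluate one of the Branin minimizers, (pi, 2.275), on each source.
x_star = (math.pi, 2.275)
for s in (0.5, 0.75, 1.0):
    print(f"s = {s:.2f}  f(x*, s) = {augmented_branin(*x_star, s):.4f}")
```

At $s = 1$ the function reduces to the standard Branin, whose minimum value is about 0.397887; lower fidelities shift the landscape, so the same point is no longer optimal on the cheaper sources.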
problem = AugmentedBranin(negate=True).to(**tkwargs)
fidelities = torch.tensor([0.5, 0.75, 1.0], **tkwargs)
n_sources = fidelities.shape[0]
bounds = torch.tensor([[-5, 0, 0], [10, 15, n_sources - 1]], **tkwargs)
target_fidelities = {n_sources - 1: 1.0}
cost_model = AffineFidelityCostModel(fidelity_weights=target_fidelities, fixed_cost=5.0)
cost_aware_utility = InverseCostWeightedUtility(cost_model=cost_model)
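AffineFidelityCostModel charges a fixed cost plus a weighted sum of the fidelity columns of the input, and InverseCostWeightedUtility divides the acquisition utility by that cost. The plain-Python sketch below mirrors these two computations under that reading; affine_cost and inverse_cost_weighted are hypothetical helpers, not BoTorch functions. Note that the key of target_fidelities is the column index of the fidelity parameter in $(x_1, x_2, s)$, which is 2 and happens to coincide with n_sources - 1.

```python
def affine_cost(x, fidelity_weights, fixed_cost):
    """fixed_cost + sum of weight * x[column] over the fidelity columns."""
    return fixed_cost + sum(w * x[col] for col, w in fidelity_weights.items())


def inverse_cost_weighted(utility, cost):
    """Cost-aware utility: utility per unit of query cost."""
    return utility / cost


fidelities = [0.5, 0.75, 1.0]
fidelity_weights = {2: 1.0}  # column 2 of (x1, x2, s) holds the fidelity s

# Cost of querying the point (x1, x2) = (0.0, 0.0) at each fidelity.
costs = [affine_cost([0.0, 0.0, s], fidelity_weights, fixed_cost=5.0) for s in fidelities]

# Same per-source costs as the dictionary later passed to the AUCB.
aucb_costs = {i: s + 5.0 for i, s in enumerate(fidelities)}

# With equal utility, the cheapest source gets the largest cost-aware value.
weighted = [inverse_cost_weighted(1.0, c) for c in costs]
print(costs, aucb_costs, weighted)
```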
Model initialization
We use a SingleTaskAugmentedGP to implement our AGP.
At each Bayesian Optimization iteration, the set of observations from the ground-truth source $\bar{s}$ (i.e., the highest fidelity and most expensive source) is temporarily augmented with observations from the other, cheaper sources, but only with those that can be considered reliable. Specifically, an observation $y_s(x)$ from a cheap source $s$ is considered reliable if it satisfies the following inequality:
$$\left|\mu_{\bar{s}}(x) - y_s(x)\right| \leq m \, \sigma_{\bar{s}}(x)$$
where $\mu_{\bar{s}}(x)$ and $\sigma_{\bar{s}}(x)$ are, respectively, the posterior mean and standard deviation of the GP model fitted on the high fidelity observations only, and $m$ is a technical parameter that makes the augmentation process more conservative ($m \to 0$) or more inclusive ($m \to \infty$). As reported in [1], a suitable value for this parameter is $m = 1$.
After the set of observations is augmented, the AGP is fitted through SingleTaskAugmentedGP.
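The augmentation rule can be sketched in plain Python as a filter over the cheap observations. This is a minimal illustration of the inequality as stated, assuming the ground-truth GP posterior is available as a callable; augment_observations and toy_posterior are hypothetical names, and SingleTaskAugmentedGP performs this step internally, possibly with different details.

```python
def augment_observations(true_obs, cheap_obs, gt_posterior, m=1.0):
    """Ground-truth observations plus the cheap observations deemed reliable.

    true_obs, cheap_obs: lists of (x, y) pairs.
    gt_posterior: callable x -> (mean, std) of the GP fitted on true_obs only.
    m: larger values make the augmentation more inclusive.
    """
    reliable = []
    for x, y in cheap_obs:
        mean, std = gt_posterior(x)
        # keep the cheap observation only if it lies within m posterior
        # standard deviations of the ground-truth GP mean
        if abs(mean - y) <= m * std:
            reliable.append((x, y))
    return true_obs + reliable


def toy_posterior(x):
    """Hypothetical ground-truth posterior: mean x**2, constant std 0.1."""
    return x**2, 0.1


true_obs = [(0.5, 0.25)]
cheap_obs = [(0.0, 0.05), (1.0, 1.5), (2.0, 3.95)]
augmented = augment_observations(true_obs, cheap_obs, toy_posterior, m=1.0)
print(augmented)
```

With $m = 1$ the observation at $x = 1$ is discarded because it is five posterior standard deviations away from the ground-truth mean; with $m = 0$ no cheap point survives, and with a large $m$ all of them do.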
def generate_initial_data(n):
    # the last column of train_x holds the source index (0, 1, 2)
    train_x = get_random_x_for_agp(n, bounds, 1)
    xs = train_x[..., :-1]
    # map each source index to its fidelity value before evaluating the problem
    fids = fidelities[train_x[..., -1].int()].reshape(-1, 1)
    train_obj = problem(torch.cat((xs, fids), dim=1)).unsqueeze(-1)
    return train_x, train_obj

def initialize_model(train_x, train_obj, m):
    # m controls how inclusive the augmentation with cheap observations is
    model = SingleTaskAugmentedGP(
        train_x, train_obj, m=m, outcome_transform=Standardize(m=1),
    )
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    return mll, model
Define a helper function that performs the essential BO step
This helper function optimizes the acquisition function and returns the candidate point, with its source coordinate rounded to the nearest source index.
The UCB acquisition function has been modified to deal with both the discrepancy between information sources and the source-specific query cost.
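Formally, the AUCB acquisition function at a generic iteration $t$ is defined as:
$$\text{AUCB}^{(t)}(x, s) = \frac{\hat{\mu}^{(t)}(x) + \sqrt{\beta^{(t)}}\,\hat{\sigma}^{(t)}(x) - y^{+}}{c_s \left(1 + \left|\hat{\mu}^{(t)}(x) - \mu_s^{(t)}(x)\right|\right)}$$
where $\hat{\mu}^{(t)}(x)$ and $\hat{\sigma}^{(t)}(x)$ are the posterior mean and standard deviation of the AGP, $\beta^{(t)}$ is the UCB exploration parameter, and $y^+$ is the best (i.e., highest) value in the augmented set of observations, so that the numerator is the optimistic improvement with respect to $y^+$. In the denominator, $c_s$ is the query cost for source $s$, and $|\hat{\mu}^{(t)}(x) - \mu_s^{(t)}(x)|$ is a discrepancy measure between the predictions provided by the AGP and by the GP on source $s$, respectively, at input $x$ (the 1 is added to avoid division by zero). The next query is the pair $(x, s)$ that maximizes $\text{AUCB}^{(t)}$.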
For more information, please refer to [1].
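As a worked example, the plain-Python sketch below transcribes the AUCB expression (the optimistic improvement of the AGP upper confidence bound over $y^+$, divided by $c_s$ times one plus the AGP/source discrepancy) for a single input $x$. The posterior values are hypothetical numbers chosen for illustration, and the function is not the code of AugmentedUpperConfidenceBound, which may differ in details such as the handling of negative improvements.

```python
import math


def aucb(agp_mean, agp_std, source_mean, best_f, beta, cost):
    """AUCB of a single (x, s) pair for a maximization problem."""
    # optimistic improvement of the AGP upper confidence bound over y+
    improvement = agp_mean + math.sqrt(beta) * agp_std - best_f
    # discrepancy between the AGP and the GP fitted on source s
    discrepancy = abs(agp_mean - source_mean)
    return improvement / (cost * (1.0 + discrepancy))


# Hypothetical posterior values at one input x.
agp_mean, agp_std, best_f, beta = 1.0, 0.5, 1.5, 4.0
source_means = {0: 0.2, 1: 1.0, 2: 1.0}
costs = {0: 5.5, 1: 5.75, 2: 6.0}

scores = {
    s: aucb(agp_mean, agp_std, source_means[s], best_f, beta, costs[s])
    for s in source_means
}
best_source = max(scores, key=scores.get)
print(scores, best_source)
```

Source 0 is the cheapest, but its GP disagrees with the AGP and is penalized; source 1 agrees with the AGP and is cheaper than the ground truth, so it wins.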
def optimize_aucb(acqf):
    candidate, value = optimize_acqf(
        acq_function=acqf,
        bounds=bounds,
        q=1,
        num_restarts=5,
        raw_samples=128,
    )
    # the source coordinate is optimized as a continuous variable:
    # round it to the nearest source index
    new_x = candidate.detach()
    new_x[:, -1] = torch.round(new_x[:, -1], decimals=0)
    return new_x
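During acquisition optimization the source index is relaxed to a continuous coordinate in $[0, n_{\text{sources}} - 1]$ and only rounded afterwards. The sketch below, with a hypothetical helper snap_to_source, shows how a relaxed value maps to a source index and its fidelity; it relies on the fact that Python's round(), like torch.round, breaks ties by rounding half to even.

```python
fidelities = [0.5, 0.75, 1.0]


def snap_to_source(relaxed_source: float) -> tuple:
    """Map a continuous source coordinate to (source index, fidelity value)."""
    # round half to even, matching torch.round
    idx = round(relaxed_source)
    # defensive clamp to the valid index range
    idx = min(max(idx, 0), len(fidelities) - 1)
    return idx, fidelities[idx]


for value in (0.5, 1.49, 1.5, 1.99):
    print(value, snap_to_source(value))
```

The index is what gets stored in train_x, while the fidelity value is what the test problem is evaluated at.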
Perform a few steps of multi-fidelity BO
First, let's generate some initial random data and fit a surrogate model.
torch.manual_seed(SEED)
train_x, train_obj = generate_initial_data(n=5)
We can now use the helper functions above to run a few iterations of BO.
cumulative_cost = 0.0
with botorch.settings.validate_input_scaling(False):
    for it in range(N_ITER):
        mll, model = initialize_model(train_x, train_obj, m=1)
        fit_gpytorch_mll(mll)
        acqf = AugmentedUpperConfidenceBound(
            model,
            beta=3,
            maximize=True,
            best_f=train_obj[torch.where(train_x[:, -1] == 0)].min(),
            cost={i: fid + 5.0 for i, fid in enumerate(fidelities)},
        )
        new_x = optimize_aucb(acqf)
        # force a ground-truth query while the model holds fewer
        # ground-truth points than max_n_cheap_points
        if model.n_true_points < model.max_n_cheap_points:
            new_x[:, -1] = fidelities.shape[0] - 1
        # train_x stores the source index; the problem is evaluated at the matching fidelity
        train_x = torch.cat([train_x, new_x])
        new_x[:, -1] = fidelities[new_x[:, -1].int()]
        new_obj = problem(new_x).unsqueeze(-1)
        train_obj = torch.cat([train_obj, new_obj])
        print(
            f"Iter {it};"
            f"\t Fid = {new_x[0].tolist()[-1]:.2f};"
            f"\t Obj = {new_obj[0][0].tolist():.4f};"
        )
Iter 0; Fid = 1.00; Obj = -200.3252;
Iter 1; Fid = 1.00; Obj = -38.1094;
Iter 2; Fid = 1.00; Obj = -38.1090;
Iter 3; Fid = 1.00; Obj = -38.1093;
Iter 4; Fid = 1.00; Obj = -38.1093;
Iter 5; Fid = 0.75; Obj = -10.3835;
Iter 6; Fid = 1.00; Obj = -38.1093;
Iter 7; Fid = 1.00; Obj = -36.9691;
Iter 8; Fid = 1.00; Obj = -38.0601;
Iter 9; Fid = 1.00; Obj = -36.4893;
Iter 10; Fid = 1.00; Obj = -34.5676;
Iter 11; Fid = 1.00; Obj = -27.2647;
Iter 12; Fid = 1.00; Obj = -23.7780;
Iter 13; Fid = 1.00; Obj = -50.3585;
Iter 14; Fid = 1.00; Obj = -35.5112;
Iter 15; Fid = 1.00; Obj = -17.9104;
Iter 16; Fid = 1.00; Obj = -12.9752;
Iter 17; Fid = 1.00; Obj = -13.3390;
Iter 18; Fid = 1.00; Obj = -8.9100;
Iter 19; Fid = 1.00; Obj = -3.2656;
Iter 20; Fid = 1.00; Obj = -2.7450;
Iter 21; Fid = 1.00; Obj = -5.7930;
Iter 22; Fid = 1.00; Obj = -0.6055;
Iter 23; Fid = 1.00; Obj = -0.5995;
Iter 24; Fid = 1.00; Obj = -2.8311;
Iter 25; Fid = 1.00; Obj = -1.0225;
Iter 26; Fid = 1.00; Obj = -2.8477;
Iter 27; Fid = 1.00; Obj = -1.0333;
Iter 28; Fid = 1.00; Obj = -0.4881;
Iter 29; Fid = 1.00; Obj = -3.4325;
Iter 30; Fid = 1.00; Obj = -0.8336;
Iter 31; Fid = 1.00; Obj = -0.4630;
Iter 32; Fid = 1.00; Obj = -2.3562;
Iter 33; Fid = 1.00; Obj = -4.9150;
Iter 34; Fid = 1.00; Obj = -0.4038;
Iter 35; Fid = 1.00; Obj = -1.9965;
Iter 36; Fid = 1.00; Obj = -0.5458;
Iter 37; Fid = 1.00; Obj = -5.4650;
Iter 38; Fid = 1.00; Obj = -0.8234;
Iter 39; Fid = 1.00; Obj = -0.6905;
Iter 40; Fid = 1.00; Obj = -0.8766;
Iter 41; Fid = 1.00; Obj = -0.9116;
Iter 42; Fid = 1.00; Obj = -3.7283;
Iter 43; Fid = 1.00; Obj = -1.8566;
Iter 44; Fid = 1.00; Obj = -1.5902;
Iter 45; Fid = 1.00; Obj = -1.2975;
Iter 46; Fid = 1.00; Obj = -1.7442;
Iter 47; Fid = 1.00; Obj = -6.0570;
Iter 48; Fid = 1.00; Obj = -2.5479;
Iter 49; Fid = 1.00; Obj = -1.4998;
Comparison to MES
def initialize_mes_model(train_x, train_obj, data_fidelity):
    model = SingleTaskMultiFidelityGP(
        train_x,
        train_obj,
        outcome_transform=Standardize(m=1),
        data_fidelity=data_fidelity,
    )
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    return mll, model

def optimize_mes_and_get_observation(mes_acq, fixed_features_list):
    candidates, acq_value = optimize_acqf_mixed(
        acq_function=mes_acq,
        bounds=problem.bounds,
        q=1,
        num_restarts=5,
        raw_samples=128,
        fixed_features_list=fixed_features_list,
    )
    # observe new values
    cost = cost_model(candidates).sum()
    new_x = candidates.detach()
    new_obj = problem(new_x).unsqueeze(-1)
    return new_x, new_obj, cost
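optimize_acqf_mixed handles the discrete fidelity by running one optimization per entry of fixed_features_list (here, one per fidelity) and returning the best candidate overall. The sketch below mimics that strategy with a grid search in place of BoTorch's gradient-based multi-start optimizer; toy_acqf is a hypothetical stand-in for the cost-weighted MES acquisition.

```python
def optimize_over_fixed_fidelities(acqf, grid, fidelities):
    """Maximize acqf(x, s) over x for each fixed s and keep the overall best."""
    best = None
    for s in fidelities:
        # inner problem: the fidelity is fixed, only x is optimized
        x_s = max(grid, key=lambda x: acqf(x, s))
        value = acqf(x_s, s)
        if best is None or value > best[0]:
            best = (value, x_s, s)
    return best


def toy_acqf(x, s):
    """Hypothetical cost-weighted acquisition: peaks at x = s, divided by 5 + s."""
    return (1.0 - (x - s) ** 2) / (5.0 + s)


grid = [i / 100 for i in range(101)]
best_value, best_x, best_s = optimize_over_fixed_fidelities(toy_acqf, grid, [0.5, 0.75, 1.0])
print(best_value, best_x, best_s)
```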
train_x_mes = torch.clone(train_x[:10])
train_x_mes[:, -1] = fidelities[train_x_mes[:, -1].int()]
train_obj_mes = torch.clone(train_obj[:10])
candidate_set = torch.rand(
1000, problem.bounds.size(1), device=problem.bounds.device, dtype=problem.bounds.dtype
)
candidate_set = problem.bounds[0] + (problem.bounds[1] - problem.bounds[0]) * candidate_set
cumulative_cost = 0.0
with botorch.settings.validate_input_scaling(False):
    for it in range(N_ITER):
        mll, model = initialize_mes_model(train_x_mes, train_obj_mes, data_fidelity=2)
        fit_gpytorch_mll(mll)
        acqf = qMultiFidelityMaxValueEntropy(
            model, candidate_set, cost_aware_utility=cost_aware_utility
        )
        new_x, new_obj, cost = optimize_mes_and_get_observation(
            acqf, fixed_features_list=[{2: fid} for fid in fidelities]
        )
        train_x_mes = torch.cat([train_x_mes, new_x])
        train_obj_mes = torch.cat([train_obj_mes, new_obj])
        cumulative_cost += cost
        print(
            f"Iter {it};"
            f"\t Fid = {new_x[0].tolist()[-1]:.2f};"
            f"\t Obj = {new_obj[0][0].tolist():.4f};"
        )
Iter 0; Fid = 0.50; Obj = -85.5566;
Iter 1; Fid = 0.50; Obj = -7.7401;
Iter 2; Fid = 0.50; Obj = -85.9927;
Iter 3; Fid = 0.50; Obj = -10.4996;
Iter 4; Fid = 0.50; Obj = -90.1834;
Iter 5; Fid = 0.50; Obj = -15.9023;
Iter 6; Fid = 0.50; Obj = -121.0723;
Iter 7; Fid = 0.50; Obj = -96.2876;
Iter 8; Fid = 0.50; Obj = -8.3718;
Iter 9; Fid = 0.50; Obj = -11.9013;
Iter 10; Fid = 0.50; Obj = -15.0995;
Iter 11; Fid = 0.50; Obj = -7.6864;
Iter 12; Fid = 0.50; Obj = -6.4228;
Iter 13; Fid = 0.50; Obj = -51.4467;
Iter 14; Fid = 0.50; Obj = -22.0845;
Iter 15; Fid = 0.50; Obj = -5.9313;
Iter 16; Fid = 0.50; Obj = -130.6765;
Iter 17; Fid = 0.50; Obj = -181.9663;
Iter 18; Fid = 0.50; Obj = -88.9639;
Iter 19; Fid = 0.50; Obj = -21.3274;
Iter 20; Fid = 0.50; Obj = -8.2452;
Iter 21; Fid = 0.50; Obj = -14.1170;
Iter 22; Fid = 0.50; Obj = -8.3733;
Iter 23; Fid = 0.50; Obj = -33.3266;
Iter 24; Fid = 0.50; Obj = -22.8821;
Iter 25; Fid = 0.50; Obj = -1.2653;
Iter 26; Fid = 0.50; Obj = -6.2855;
Iter 27; Fid = 0.50; Obj = -129.2954;
Iter 28; Fid = 0.50; Obj = -27.6052;
Iter 29; Fid = 0.50; Obj = -41.9348;
Iter 30; Fid = 0.50; Obj = -22.8847;
Iter 31; Fid = 0.50; Obj = -2.2557;
Iter 32; Fid = 0.50; Obj = -13.0811;
Iter 33; Fid = 0.50; Obj = -9.5395;
Iter 34; Fid = 0.50; Obj = -13.6012;
Iter 35; Fid = 0.50; Obj = -85.0653;
Iter 36; Fid = 0.50; Obj = -79.1870;
Iter 37; Fid = 0.50; Obj = -7.1699;
Iter 38; Fid = 0.50; Obj = -8.6771;
Iter 39; Fid = 1.00; Obj = -1.5034;
Iter 40; Fid = 0.50; Obj = -122.4489;
Iter 41; Fid = 0.50; Obj = -24.0085;
Iter 42; Fid = 0.50; Obj = -2.1577;
Iter 43; Fid = 1.00; Obj = -4.0370;
Iter 44; Fid = 0.50; Obj = -13.9348;
Iter 45; Fid = 1.00; Obj = -5.6706;
Iter 46; Fid = 1.00; Obj = -3.8358;
Iter 47; Fid = 1.00; Obj = -1.0238;
Iter 48; Fid = 0.50; Obj = -289.3283;
Iter 49; Fid = 1.00; Obj = -3.0379;
Plot results
mapping_fid = dict(zip(range(fidelities.shape[0]), fidelities.tolist()))
cost_AGP = torch.cumsum(torch.tensor([mapping_fid[int(source)] for source in train_x[:, -1].tolist()]), dim=0)
cost_MES = torch.cumsum(train_x_mes[:, -1], dim=0)
train_obj[torch.where(train_x[:, -1] != fidelities.shape[0] - 1)] = train_obj.min()
best_seen_AGP = torch.cummax(train_obj, dim=0)[0]
train_obj_mes[torch.where(train_x_mes[:, -1] != 1)[0]] = train_obj_mes.min()
best_seen_MES = torch.cummax(train_obj_mes, dim=0)[0]
fig, ax = plt.subplots(1, 1, figsize=(4, 3), dpi=200)
ax.plot(
    cost_AGP.cpu()[9:],
    best_seen_AGP.cpu()[9:],
    label="AGP",
)
ax.plot(
    cost_MES.cpu()[9:],
    best_seen_MES.cpu()[9:],
    label="MES",
)
ax.set_title("Branin", fontsize="12")
ax.set_xlabel("Total cost", fontsize="10")
ax.set_ylabel("Best seen", fontsize="10")
ax.tick_params(labelsize=10)
ax.legend(loc="lower right", fontsize="7", frameon=True, ncol=1)
plt.tight_layout()
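The plotting cell builds each curve in two steps: the x-axis accumulates the fidelity value of each query (the fixed cost of 5.0 is left out for both methods), and observations from cheaper sources are overwritten with the overall minimum so that only ground-truth evaluations can improve the "best seen" value. The plain-Python sketch below mirrors that logic on made-up data; best_seen_curve is a hypothetical helper, and it assumes sources are stored as indices, as in the AGP run (the MES run stores fidelity values directly).

```python
from itertools import accumulate


def best_seen_curve(sources, values, fidelities, target_source):
    """Cumulative query cost and best target-source value seen so far."""
    # x-axis: running sum of the fidelity value of each query (fixed cost excluded)
    costs = list(accumulate(fidelities[s] for s in sources))
    # observations from cheaper sources are replaced by the overall minimum,
    # so they never improve the best-seen value
    floor = min(values)
    masked = [v if s == target_source else floor for s, v in zip(sources, values)]
    best = list(accumulate(masked, max))
    return costs, best


costs, best = best_seen_curve(
    sources=[0, 2, 1, 2],
    values=[-1.0, -5.0, -0.5, -2.0],
    fidelities=[0.5, 0.75, 1.0],
    target_source=2,
)
print(costs, best)
```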