In this tutorial, we show how to implement Trust Region Bayesian Optimization (TuRBO) [1] in a closed loop in BoTorch.
This implementation uses one trust region (TuRBO-1) and supports either parallel expected improvement (qEI) or Thompson sampling (TS). We optimize the $20D$ Ackley function on the domain $[-5, 10]^{20}$ and show that TuRBO-1 outperforms qEI as well as Sobol.
Since botorch assumes a maximization problem, we will attempt to maximize $-f(x)$ to achieve $\max_x -f(x)=0$.
import os
import math
from dataclasses import dataclass
import torch
from botorch.acquisition import qExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.generation import MaxPosteriorSampling
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.test_functions import Ackley
from botorch.utils.transforms import unnormalize
from torch.quasirandom import SobolEngine
import gpytorch
from gpytorch.constraints import Interval
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.priors import HorseshoePrior
# Run on the GPU when one is available; every tensor in this notebook uses
# double precision.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float64
# Raw value of the SMOKE_TEST environment variable (None when it is unset).
SMOKE_TEST = os.getenv("SMOKE_TEST")
I0128 075923.913 _utils_internal.py:179] NCCL_DEBUG env var is set to None I0128 075923.914 _utils_internal.py:197] NCCL_DEBUG is forced to WARN from None
The goal is to minimize the popular Ackley function:
$f(x_1,\ldots,x_d) = -20\exp\left(-0.2 \sqrt{\frac{1}{d} \sum_{j=1}^d x_j^2} \right) -\exp \left( \frac{1}{d} \sum_{j=1}^d \cos(2 \pi x_j) \right) + 20 + e$
over the domain $[-5, 10]^{20}$. The global optimal value of $0$ is attained at $x_1 = \ldots = x_d = 0$.
As mentioned above, since botorch assumes a maximization problem, we instead maximize $-f(x)$.
# Negated 20-D Ackley, so that maximizing `fun` minimizes Ackley. Its bounds
# are overwritten in place with the box [-5, 10]^20.
fun = Ackley(dim=20, negate=True).to(dtype=dtype, device=device)
fun.bounds[0].fill_(-5)
fun.bounds[1].fill_(10)
dim = fun.dim
lb, ub = fun.bounds

# Optimization settings.
batch_size = 4
n_init = 2 * dim  # size of the initial Sobol design: two points per dimension
max_cholesky_size = float("inf")  # Always use Cholesky
def eval_objective(x):
    """Evaluate the objective at a point given on the unit cube.

    ``x`` is first mapped from [0, 1]^d to ``fun.bounds`` with ``unnormalize``,
    then passed to ``fun``; the result of ``fun`` is returned unchanged.
    """
    x_original_domain = unnormalize(x, fun.bounds)
    return fun(x_original_domain)
TuRBO needs to maintain a state, which includes the length of the trust region, success and failure counters, success and failure tolerance, etc.
In this tutorial we store the state in a dataclass and update the state of TuRBO after each batch evaluation.
Note: These settings assume that the domain has been scaled to $[0, 1]^d$ and that the same batch size is used for each iteration.
@dataclass
class TurboState:
    """Bookkeeping for a single TuRBO trust region.

    ``length`` is the current trust-region side length on the unit cube.
    The success/failure counters count consecutive improving/non-improving
    batches and are compared against the matching tolerances. All of these
    fields are read and updated by ``update_state``. ``failure_tolerance`` is
    derived from ``dim`` and ``batch_size`` in ``__post_init__``.
    """

    dim: int
    batch_size: int
    length: float = 0.8
    length_min: float = 0.5**7
    length_max: float = 1.6
    failure_counter: int = 0
    failure_tolerance: int = float("nan")  # Placeholder; set in __post_init__
    success_counter: int = 0
    success_tolerance: int = 10  # Note: The original paper uses 3
    best_value: float = -float("inf")
    restart_triggered: bool = False

    def __post_init__(self):
        # For positive batch_size this is ceil(max(4, dim) / batch_size).
        floor_per_batch = 4.0 / self.batch_size
        dim_per_batch = float(self.dim) / self.batch_size
        self.failure_tolerance = math.ceil(max(floor_per_batch, dim_per_batch))
def update_state(state, Y_next):
    """Update the TuRBO ``state`` in place after evaluating a new batch.

    A batch is a success if its best value exceeds ``state.best_value`` by
    more than a relative margin of 1e-3. Any finite value beats an initial
    ``best_value`` of ``-inf``. After ``success_tolerance`` consecutive
    successes the trust region length doubles (capped at ``length_max``).
    After ``failure_tolerance`` consecutive failures it halves.
    ``restart_triggered`` is set once ``length`` drops below ``length_min``.

    Args:
        state: a ``TurboState`` (or any object with the same attributes).
        Y_next: tensor of objective values observed for the new batch.

    Returns:
        The same ``state`` object, after mutation.
    """
    y_next_best = Y_next.max().item()

    # With best_value == -inf the relative margin evaluates to
    # -inf + 1e-3 * inf == nan, and every comparison with nan is False, so the
    # first batch would always be counted as a failure. Skip the margin then.
    if math.isfinite(state.best_value):
        threshold = state.best_value + 1e-3 * math.fabs(state.best_value)
    else:
        threshold = state.best_value

    if y_next_best > threshold:
        state.success_counter += 1
        state.failure_counter = 0
    else:
        state.success_counter = 0
        state.failure_counter += 1

    if state.success_counter == state.success_tolerance:  # Expand trust region
        state.length = min(2.0 * state.length, state.length_max)
        state.success_counter = 0
    elif state.failure_counter == state.failure_tolerance:  # Shrink trust region
        state.length /= 2.0
        state.failure_counter = 0

    state.best_value = max(state.best_value, y_next_best)
    if state.length < state.length_min:
        state.restart_triggered = True
    return state
# Build a fresh state for this problem and show its initial settings
# (print() of a dataclass uses its generated repr).
state = TurboState(batch_size=batch_size, dim=dim)
print(repr(state))
TurboState(dim=20, batch_size=4, length=0.8, length_min=0.0078125, length_max=1.6, failure_counter=0, failure_tolerance=5, success_counter=0, success_tolerance=10, best_value=-inf, restart_triggered=False)
This generates an initial set of Sobol points that we use to start the BO loop.
def get_initial_points(dim, n_pts, seed=0):
    """Draw ``n_pts`` scrambled Sobol points on the unit cube ``[0, 1]^dim``.

    The Sobol engine is seeded with ``seed``, so the design is reproducible.
    The points are cast to the module-level ``dtype`` and moved to ``device``.
    """
    engine = SobolEngine(dimension=dim, scramble=True, seed=seed)
    return engine.draw(n=n_pts).to(dtype=dtype, device=device)
Given the current `state` and a probabilistic (GP) `model` built from observations `X` and `Y`, we generate a new batch of points.
This method works on the domain $[0, 1]^d$, so make sure not to pass in observations from the true domain. `unnormalize` is called before the true function is evaluated, which first maps the points back to the original domain.
We support either TS or qEI, which can be selected via the `acqf` argument.
def generate_batch(
    state,
    model,  # GP model
    X,  # Evaluated points on the domain [0, 1]^d
    Y,  # Function values
    batch_size,
    n_candidates=None,  # Number of candidates for Thompson sampling
    num_restarts=10,
    raw_samples=512,
    acqf="ts",  # "ei" or "ts"
):
    """Propose ``batch_size`` new points inside the current trust region.

    The trust region is a box centred on the best observed point
    ``X[Y.argmax()]``. Its per-dimension side lengths are ``state.length``
    scaled by the model's lengthscales, which are normalized to unit geometric
    mean, and the box is clipped to [0, 1]^d.

    Args:
        state: TuRBO state; only ``state.length`` is read.
        model: fitted GP exposing ``covar_module.base_kernel.lengthscale``.
        X: evaluated inputs on the unit cube (``n x d``).
        Y: observed values for ``X`` (larger is better).
        batch_size: number of points to return.
        n_candidates: candidate-set size for Thompson sampling; defaults to
            ``min(5000, max(2000, 200 * d))``.
        num_restarts: passed to ``optimize_acqf`` (qEI only).
        raw_samples: passed to ``optimize_acqf`` (qEI only).
        acqf: ``"ts"`` for Thompson sampling or ``"ei"`` for qExpectedImprovement.

    Returns:
        Tensor of ``batch_size`` new points in [0, 1]^d.
    """
    assert acqf in ("ts", "ei")
    assert X.min() >= 0.0 and X.max() <= 1.0 and torch.all(torch.isfinite(Y))
    dim = X.shape[-1]
    if n_candidates is None:
        n_candidates = min(5000, max(2000, 200 * dim))

    # Scale the TR to be proportional to the lengthscales
    x_center = X[Y.argmax(), :].clone()
    # reshape(-1) rather than squeeze(): squeeze() yields a 0-d tensor when
    # d == 1, and len() of a 0-d tensor raises.
    weights = model.covar_module.base_kernel.lengthscale.detach().reshape(-1)
    weights = weights / weights.mean()
    # Unit geometric mean: before clipping, the TR volume is state.length ** d.
    weights = weights / torch.prod(weights.pow(1.0 / len(weights)))
    tr_lb = torch.clamp(x_center - weights * state.length / 2.0, 0.0, 1.0)
    tr_ub = torch.clamp(x_center + weights * state.length / 2.0, 0.0, 1.0)

    if acqf == "ts":
        sobol = SobolEngine(dim, scramble=True)
        pert = sobol.draw(n_candidates).to(dtype=X.dtype, device=X.device)
        pert = tr_lb + (tr_ub - tr_lb) * pert

        # Create a perturbation mask: each coordinate is perturbed with
        # probability min(20 / d, 1).
        prob_perturb = min(20.0 / dim, 1.0)
        mask = (
            torch.rand(n_candidates, dim, dtype=X.dtype, device=X.device)
            <= prob_perturb
        )
        # Rows with no perturbed coordinate get one chosen uniformly at random.
        # torch.randint's upper bound is exclusive, so pass `dim` (not
        # `dim - 1`) so the last coordinate can be picked too.
        ind = torch.where(mask.sum(dim=1) == 0)[0]
        mask[ind, torch.randint(0, dim, size=(len(ind),), device=X.device)] = 1

        # Create candidate points from the perturbations and the mask
        X_cand = x_center.expand(n_candidates, dim).clone()
        X_cand[mask] = pert[mask]

        # Sample on the candidate points
        thompson_sampling = MaxPosteriorSampling(model=model, replacement=False)
        with torch.no_grad():  # We don't need gradients when using TS
            X_next = thompson_sampling(X_cand, num_samples=batch_size)

    elif acqf == "ei":
        # Incumbent is the best of the Y passed in (not a module-level global).
        ei = qExpectedImprovement(model, Y.max())
        X_next, _ = optimize_acqf(
            ei,
            bounds=torch.stack([tr_lb, tr_ub]),
            q=batch_size,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
        )

    return X_next
This simple loop runs one instance of TuRBO-1 with Thompson sampling until convergence.
TuRBO-1 is a local optimizer that can be used for a fixed evaluation budget in a multi-start fashion. Once TuRBO converges, `state.restart_triggered` will be set to `True` and the run should be aborted. If you want to run more evaluations with TuRBO, simply generate a new set of initial points and then keep generating batches until convergence or until the evaluation budget has been exceeded. It's important to note that evaluations from previous instances are discarded when TuRBO restarts.
NOTE: We use a `SingleTaskGP` with a noise constraint to keep the noise from getting too large, as the problem is noise-free.
# Initial Sobol design on the unit cube; eval_objective maps each point to the
# true domain before evaluating it. Y_turbo is an n x 1 column of observations.
X_turbo = get_initial_points(dim, n_init)
Y_turbo = torch.tensor(
    [eval_objective(x) for x in X_turbo], dtype=dtype, device=device
).unsqueeze(-1)
# Fresh TuRBO state for this run (initial length, zeroed counters, best=-inf).
state = TurboState(dim, batch_size=batch_size)

# Much cheaper optimizer settings when SMOKE_TEST is set to a non-empty value.
NUM_RESTARTS = 10 if not SMOKE_TEST else 2
RAW_SAMPLES = 512 if not SMOKE_TEST else 4
N_CANDIDATES = min(5000, max(2000, 200 * dim)) if not SMOKE_TEST else 4

# Seed torch's global RNG for the random draws made inside the loop.
torch.manual_seed(0)

# update_state sets restart_triggered once the TR length drops below length_min.
while not state.restart_triggered:  # Run until TuRBO converges
    # Fit a GP model
    # Standardize the observations (zero mean, unit standard deviation).
    train_Y = (Y_turbo - Y_turbo.mean()) / Y_turbo.std()
    # Noise is constrained to [1e-8, 1e-3] since the objective is noise-free.
    likelihood = GaussianLikelihood(noise_constraint=Interval(1e-8, 1e-3))
    covar_module = ScaleKernel(  # ARD Matern-5/2; lengthscales *constrained* (not a prior) to [0.005, 4.0] — presumably the TuRBO paper's bounds, TODO confirm
        MaternKernel(
            nu=2.5, ard_num_dims=dim, lengthscale_constraint=Interval(0.005, 4.0)
        )
    )
    model = SingleTaskGP(
        X_turbo, train_Y, covar_module=covar_module, likelihood=likelihood
    )
    mll = ExactMarginalLogLikelihood(model.likelihood, model)

    # Do the fitting and acquisition function optimization inside the Cholesky context
    with gpytorch.settings.max_cholesky_size(max_cholesky_size):
        # Fit the model
        fit_gpytorch_mll(mll)

        # Create a batch (Thompson sampling inside the current trust region)
        X_next = generate_batch(
            state=state,
            model=model,
            X=X_turbo,
            Y=train_Y,
            batch_size=batch_size,
            n_candidates=N_CANDIDATES,
            num_restarts=NUM_RESTARTS,
            raw_samples=RAW_SAMPLES,
            acqf="ts",
        )

    # Evaluate the new batch on the true (unnormalized) domain.
    Y_next = torch.tensor(
        [eval_objective(x) for x in X_next], dtype=dtype, device=device
    ).unsqueeze(-1)

    # Update state (counters, TR length, best value, restart flag)
    state = update_state(state=state, Y_next=Y_next)

    # Append data
    X_turbo = torch.cat((X_turbo, X_next), dim=0)
    Y_turbo = torch.cat((Y_turbo, Y_next), dim=0)

    # Print current status
    print(
        f"{len(X_turbo)}) Best value: {state.best_value:.2e}, TR length: {state.length:.2e}"
    )
44) Best value: -1.17e+01, TR length: 8.00e-01 48) Best value: -1.17e+01, TR length: 8.00e-01 52) Best value: -1.11e+01, TR length: 8.00e-01 56) Best value: -1.04e+01, TR length: 8.00e-01 60) Best value: -1.04e+01, TR length: 8.00e-01 64) Best value: -9.41e+00, TR length: 8.00e-01 68) Best value: -9.41e+00, TR length: 8.00e-01 72) Best value: -9.41e+00, TR length: 8.00e-01 76) Best value: -9.41e+00, TR length: 8.00e-01 80) Best value: -9.41e+00, TR length: 8.00e-01 84) Best value: -8.82e+00, TR length: 8.00e-01 88) Best value: -8.82e+00, TR length: 8.00e-01 92) Best value: -8.82e+00, TR length: 8.00e-01 96) Best value: -8.36e+00, TR length: 8.00e-01 100) Best value: -7.93e+00, TR length: 8.00e-01 104) Best value: -7.93e+00, TR length: 8.00e-01 108) Best value: -7.93e+00, TR length: 8.00e-01 112) Best value: -7.93e+00, TR length: 8.00e-01 116) Best value: -7.93e+00, TR length: 8.00e-01 120) Best value: -7.93e+00, TR length: 4.00e-01 124) Best value: -6.72e+00, TR length: 4.00e-01 128) Best value: -6.34e+00, TR length: 4.00e-01 132) Best value: -6.01e+00, TR length: 4.00e-01 136) Best value: -5.51e+00, TR length: 4.00e-01 140) Best value: -5.51e+00, TR length: 4.00e-01 144) Best value: -5.51e+00, TR length: 4.00e-01 148) Best value: -5.45e+00, TR length: 4.00e-01 152) Best value: -5.27e+00, TR length: 4.00e-01 156) Best value: -5.27e+00, TR length: 4.00e-01 160) Best value: -5.27e+00, TR length: 4.00e-01 164) Best value: -5.27e+00, TR length: 4.00e-01 168) Best value: -5.06e+00, TR length: 4.00e-01 172) Best value: -5.06e+00, TR length: 4.00e-01 176) Best value: -5.06e+00, TR length: 4.00e-01 180) Best value: -5.06e+00, TR length: 4.00e-01 184) Best value: -5.06e+00, TR length: 4.00e-01 188) Best value: -5.06e+00, TR length: 2.00e-01 192) Best value: -4.32e+00, TR length: 2.00e-01 196) Best value: -3.86e+00, TR length: 2.00e-01 200) Best value: -3.69e+00, TR length: 2.00e-01 204) Best value: -3.69e+00, TR length: 2.00e-01 208) Best value: -3.69e+00, TR length: 
2.00e-01 212) Best value: -3.69e+00, TR length: 2.00e-01 216) Best value: -3.41e+00, TR length: 2.00e-01 220) Best value: -3.41e+00, TR length: 2.00e-01 224) Best value: -3.41e+00, TR length: 2.00e-01 228) Best value: -3.41e+00, TR length: 2.00e-01 232) Best value: -3.41e+00, TR length: 2.00e-01 236) Best value: -3.41e+00, TR length: 1.00e-01 240) Best value: -2.95e+00, TR length: 1.00e-01 244) Best value: -2.65e+00, TR length: 1.00e-01 248) Best value: -2.65e+00, TR length: 1.00e-01 252) Best value: -2.65e+00, TR length: 1.00e-01 256) Best value: -2.54e+00, TR length: 1.00e-01 260) Best value: -2.54e+00, TR length: 1.00e-01 264) Best value: -2.54e+00, TR length: 1.00e-01 268) Best value: -2.38e+00, TR length: 1.00e-01 272) Best value: -2.38e+00, TR length: 1.00e-01 276) Best value: -2.38e+00, TR length: 1.00e-01 280) Best value: -2.38e+00, TR length: 1.00e-01 284) Best value: -2.38e+00, TR length: 1.00e-01 288) Best value: -2.38e+00, TR length: 5.00e-02 292) Best value: -1.91e+00, TR length: 5.00e-02 296) Best value: -1.91e+00, TR length: 5.00e-02 300) Best value: -1.91e+00, TR length: 5.00e-02 304) Best value: -1.31e+00, TR length: 5.00e-02 308) Best value: -1.31e+00, TR length: 5.00e-02 312) Best value: -1.31e+00, TR length: 5.00e-02 316) Best value: -1.31e+00, TR length: 5.00e-02 320) Best value: -1.31e+00, TR length: 5.00e-02 324) Best value: -1.31e+00, TR length: 2.50e-02 328) Best value: -1.31e+00, TR length: 2.50e-02 332) Best value: -1.31e+00, TR length: 2.50e-02 336) Best value: -1.31e+00, TR length: 2.50e-02 340) Best value: -1.31e+00, TR length: 2.50e-02 344) Best value: -1.31e+00, TR length: 1.25e-02 348) Best value: -1.14e+00, TR length: 1.25e-02 352) Best value: -1.14e+00, TR length: 1.25e-02 356) Best value: -1.14e+00, TR length: 1.25e-02 360) Best value: -1.14e+00, TR length: 1.25e-02 364) Best value: -1.12e+00, TR length: 1.25e-02 368) Best value: -1.12e+00, TR length: 1.25e-02 372) Best value: -1.12e+00, TR length: 1.25e-02 376) Best value: 
-1.12e+00, TR length: 1.25e-02 380) Best value: -1.12e+00, TR length: 1.25e-02 384) Best value: -1.12e+00, TR length: 6.25e-03
As a baseline, we compare TuRBO to qEI.
# qEI baseline: same seed, initial design and evaluation budget as the TuRBO
# run, but qEI is optimized globally over the whole unit cube.
torch.manual_seed(0)
X_ei = get_initial_points(dim, n_init)
Y_ei = torch.tensor(
    [eval_objective(x) for x in X_ei], dtype=dtype, device=device
).unsqueeze(-1)

# Search box for qEI: the full unit cube [0, 1]^dim (loop-invariant).
unit_cube_bounds = torch.stack(
    (
        torch.zeros(dim, dtype=dtype, device=device),
        torch.ones(dim, dtype=dtype, device=device),
    )
)

while len(Y_ei) < len(Y_turbo):
    # Standardize outputs and fit a default SingleTaskGP with a small noise cap.
    train_Y = (Y_ei - Y_ei.mean()) / Y_ei.std()
    likelihood = GaussianLikelihood(noise_constraint=Interval(1e-8, 1e-3))
    model = SingleTaskGP(X_ei, train_Y, likelihood=likelihood)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_mll(mll)

    # Create a batch by maximizing qEI w.r.t. the best standardized observation.
    ei = qExpectedImprovement(model, train_Y.max())
    candidate, acq_value = optimize_acqf(
        ei,
        bounds=unit_cube_bounds,
        q=batch_size,
        num_restarts=NUM_RESTARTS,
        raw_samples=RAW_SAMPLES,
    )

    # Evaluate the batch on the true domain and append it to the data.
    Y_next = torch.tensor(
        [eval_objective(x) for x in candidate], dtype=dtype, device=device
    ).unsqueeze(-1)
    X_ei = torch.cat((X_ei, candidate), dim=0)
    Y_ei = torch.cat((Y_ei, Y_next), dim=0)

    # Print current status
    print(f"{len(X_ei)}) Best value: {Y_ei.max().item():.2e}")
44) Best value: -1.15e+01 48) Best value: -1.05e+01 52) Best value: -1.02e+01 56) Best value: -9.78e+00 60) Best value: -9.30e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
64) Best value: -9.30e+00 68) Best value: -8.58e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
72) Best value: -8.58e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
76) Best value: -8.58e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
80) Best value: -8.58e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
84) Best value: -8.58e+00 88) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
92) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
96) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
100) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
104) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
108) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
112) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
116) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
120) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
124) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
128) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
132) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
136) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
140) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
144) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
148) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
152) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
156) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
160) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
164) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
168) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
172) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
176) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
180) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
184) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
188) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
192) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
196) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
200) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
204) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
208) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
212) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
216) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
220) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
224) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
228) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
232) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
236) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
240) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
244) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
248) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
252) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
256) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
260) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
264) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
268) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
272) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
276) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
280) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
284) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
288) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
292) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
296) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
300) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
304) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
308) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
312) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
316) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
320) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
324) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
328) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
332) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
336) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
340) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
344) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
348) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
352) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
356) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
360) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
364) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
368) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
372) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
376) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
380) Best value: -8.19e+00
384) Best value: -8.19e+00
/mnt/xarfuse/uid-25696/50ffb8fb-seed-nspid4026533582_cgpid14602017-ns-4026533579/botorch/optim/initializers.py:208: BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly.
# Quasi-random baseline: draw one scrambled Sobol point per TuRBO-1 evaluation
# (so all methods share the same evaluation budget), cast the points to the
# working dtype/device, and evaluate the objective at each of them.
X_Sobol = SobolEngine(dim, scramble=True, seed=0).draw(len(X_turbo))
X_Sobol = X_Sobol.to(dtype=dtype, device=device)
Y_Sobol = torch.tensor(
    [eval_objective(x) for x in X_Sobol],
    dtype=dtype,
    device=device,
).unsqueeze(-1)  # add a trailing dimension of size 1 for the objective values
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
%matplotlib inline
# Compare the optimization traces of TuRBO-1, qEI and the Sobol baseline.
names = ["TuRBO-1", "EI", "Sobol"]
runs = [Y_turbo, Y_ei, Y_Sobol]
fig, ax = plt.subplots(figsize=(8, 6))

# Plot the best value found so far (running maximum) for each method; the
# legend below labels the curves in the same order as `names`.
for run in runs:
    fx = np.maximum.accumulate(run.cpu())
    plt.plot(fx, marker="", lw=3)

# Dashed horizontal reference line at the test function's optimal value.
plt.plot([0, len(Y_turbo)], [fun.optimal_value, fun.optimal_value], "k--", lw=3)
# The original called xlabel twice, so the y-axis label was overwritten and
# never shown; the value axis is the y-axis.
plt.ylabel("Function value", fontsize=18)
plt.xlabel("Number of evaluations", fontsize=18)
plt.title("20D Ackley", fontsize=24)
plt.xlim([0, len(Y_turbo)])
plt.ylim([-15, 1])

plt.grid(True)
plt.tight_layout()
# Legend placed in figure coordinates, centred along the bottom edge.
plt.legend(
    names + ["Global optimal value"],
    loc="lower center",
    bbox_to_anchor=(0, -0.08, 1, 1),
    bbox_transform=plt.gcf().transFigure,
    ncol=4,
    fontsize=16,
)
plt.show()
I0128 080236.236 font_manager.py:1349] generated new fontManager