#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
r"""
Utilities for acquisition functions.
"""

from typing import Callable, Optional

from torch import Tensor

from . import analytic, monte_carlo
from ..models.model import Model
from ..sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
from ..utils.transforms import squeeze_last_dim
from .acquisition import AcquisitionFunction
from .monte_carlo import MCAcquisitionFunction
from .objective import MCAcquisitionObjective


def get_acquisition_function(
    acquisition_function_name: str,
    model: Model,
    objective: MCAcquisitionObjective,
    X_observed: Tensor,
    X_pending: Optional[Tensor] = None,
    mc_samples: int = 500,
    qmc: bool = True,
    seed: Optional[int] = None,
    **kwargs,
) -> MCAcquisitionFunction:
r"""Convenience function for initializing botorch acquisition functions.
Args:
acquisition_function_name: Name of the acquisition function.
model: A fitted model.
objective: A MCAcquisitionObjective.
X_observed: A `m1 x d`-dim Tensor of `m1` design points that have
already been observed.
X_pending: A `m2 x d`-dim Tensor of `m2` design points whose evaluation
is pending.
mc_samples: The number of samples to use for (q)MC evaluation of the
acquisition function.
qmc: If True, use quasi-Monte-Carlo sampling (instead of iid).
seed: If provided, perform deterministic optimization (i.e. the
function to optimize is fixed and not stochastic).
Returns:
The requested acquisition function.
    Example:
        >>> model = SingleTaskGP(train_X, train_Y)
        >>> obj = LinearMCObjective(weights=torch.tensor([1.0, 2.0]))
        >>> acqf = get_acquisition_function("qEI", model, obj, train_X)
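        >>> # a further sketch (illustrative values): "qUCB" additionally
        >>> # requires a `beta` keyword argument, and qmc=False falls back
        >>> # to iid normal sampling
        >>> acqf_ucb = get_acquisition_function(
        >>>     "qUCB", model, obj, train_X, beta=0.2, mc_samples=256, qmc=False
        >>> )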
"""
    # initialize the sampler
    if qmc:
        sampler = SobolQMCNormalSampler(num_samples=mc_samples, seed=seed)
    else:
        sampler = IIDNormalSampler(num_samples=mc_samples, seed=seed)
    # instantiate and return the requested acquisition function
    if acquisition_function_name == "qEI":
        # incumbent value: best posterior mean objective over the observed points
        best_f = objective(model.posterior(X_observed).mean).max().item()
        return monte_carlo.qExpectedImprovement(
            model=model,
            best_f=best_f,
            sampler=sampler,
            objective=objective,
            X_pending=X_pending,
        )
    elif acquisition_function_name == "qPI":
        best_f = objective(model.posterior(X_observed).mean).max().item()
        return monte_carlo.qProbabilityOfImprovement(
            model=model,
            best_f=best_f,
            sampler=sampler,
            objective=objective,
            X_pending=X_pending,
            tau=kwargs.get("tau", 1e-3),
        )
    elif acquisition_function_name == "qNEI":
        return monte_carlo.qNoisyExpectedImprovement(
            model=model,
            X_baseline=X_observed,
            sampler=sampler,
            objective=objective,
            X_pending=X_pending,
        )
    elif acquisition_function_name == "qSR":
        return monte_carlo.qSimpleRegret(
            model=model, sampler=sampler, objective=objective, X_pending=X_pending
        )
    elif acquisition_function_name == "qUCB":
        if "beta" not in kwargs:
            raise ValueError("`beta` must be specified in kwargs for qUCB.")
        return monte_carlo.qUpperConfidenceBound(
            model=model,
            beta=kwargs["beta"],
            sampler=sampler,
            objective=objective,
            X_pending=X_pending,
        )
    raise NotImplementedError(
        f"Unknown acquisition function {acquisition_function_name}"
    )


def get_infeasible_cost(
    X: Tensor, model: Model, objective: Callable[[Tensor], Tensor] = squeeze_last_dim
) -> float:
r"""Get infeasible cost for a model and objective.
Computes an infeasible cost `M` such that `-M < min_x f(x)` almost always,
so that feasible points are preferred.
Args:
X: A `m x d` Tensor of `m` design points to use in evaluating the
minimum. These points should cover the design space well. The more
points the better the estimate, at the expense of added computation.
model: A fitted botorch model.
objective: The objective with which to evaluate the model output.
Returns:
The infeasible cost `M` value.
    Example:
        >>> model = SingleTaskGP(train_X, train_Y)
        >>> objective = lambda Y: Y[..., -1] ** 2
        >>> M = get_infeasible_cost(train_X, model, objective)
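        >>> # a possible follow-up (sketch): `M` is typically passed as the
        >>> # `infeasible_cost` of a ConstrainedMCObjective; the constraint
        >>> # callable below is illustrative only
        >>> constrained_obj = ConstrainedMCObjective(
        >>>     objective=objective,
        >>>     constraints=[lambda Z: Z[..., 0]],
        >>>     infeasible_cost=M,
        >>> )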
"""
    posterior = model.posterior(X)
    # a 6-sigma lower confidence bound on the objective serves as an
    # (almost sure) lower bound on its minimum over the design points
    lb = objective(posterior.mean - 6 * posterior.variance.clamp_min(0).sqrt()).min()
    M = -lb.clamp_max(0.0)
    return M.item()


def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
r"""Determine whether a given acquisition function is non-negative.
Args:
acq_function: The `AcquisitionFunction` instance.
Returns:
True if `acq_function` is non-negative, False if not, or if the behavior
is unknown (for custom acquisition functions).
    Example:
        >>> qEI = qExpectedImprovement(model, best_f=0.1)
        >>> is_nonnegative(qEI)  # returns True
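        >>> # conversely (sketch): acquisition functions outside the known
        >>> # non-negative set, e.g. qUpperConfidenceBound, return False
        >>> qUCB = qUpperConfidenceBound(model, beta=0.2)
        >>> is_nonnegative(qUCB)  # returns False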
"""
    return isinstance(
        acq_function,
        (
            analytic.ExpectedImprovement,
            analytic.ConstrainedExpectedImprovement,
            analytic.ProbabilityOfImprovement,
            analytic.NoisyExpectedImprovement,
            monte_carlo.qExpectedImprovement,
            monte_carlo.qNoisyExpectedImprovement,
            monte_carlo.qProbabilityOfImprovement,
        ),
    )