#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Utilities for acquisition functions.
"""
from __future__ import annotations
from typing import Callable, List, Optional, Union
import torch
from botorch.acquisition import logei, monte_carlo
from botorch.acquisition.multi_objective import (
logei as moo_logei,
monte_carlo as moo_monte_carlo,
)
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.acquisition.utils import compute_best_feasible_objective
from botorch.models.model import Model
from botorch.sampling.get_sampler import get_sampler
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
FastNondominatedPartitioning,
NondominatedPartitioning,
)
from torch import Tensor
[docs]
def get_acquisition_function(
acquisition_function_name: str,
model: Model,
objective: MCAcquisitionObjective,
X_observed: Tensor,
posterior_transform: Optional[PosteriorTransform] = None,
X_pending: Optional[Tensor] = None,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
eta: Optional[Union[Tensor, float]] = 1e-3,
mc_samples: int = 512,
seed: Optional[int] = None,
*,
# optional parameters that are only needed for certain acquisition functions
tau: float = 1e-3,
prune_baseline: bool = True,
marginalize_dim: Optional[int] = None,
cache_root: bool = True,
beta: Optional[float] = None,
ref_point: Union[None, List[float], Tensor] = None,
Y: Optional[Tensor] = None,
alpha: float = 0.0,
) -> monte_carlo.MCAcquisitionFunction:
r"""Convenience function for initializing botorch acquisition functions.
Args:
acquisition_function_name: Name of the acquisition function.
model: A fitted model.
objective: A MCAcquisitionObjective.
X_observed: A `m1 x d`-dim Tensor of `m1` design points that have
already been observed.
posterior_transform: A PosteriorTransform (optional).
X_pending: A `m2 x d`-dim Tensor of `m2` design points whose evaluation
is pending.
constraints: A list of callables, each mapping a Tensor of dimension
`sample_shape x batch-shape x q x m` to a Tensor of dimension
`sample_shape x batch-shape x q`, where negative values imply
feasibility. Used for all acquisition functions except qSR and qUCB.
eta: The temperature parameter for the sigmoid function used for the
differentiable approximation of the constraints. In case of a float the
same eta is used for every constraint in constraints. In case of a
tensor the length of the tensor must match the number of provided
constraints. The i-th constraint is then estimated with the i-th
eta value. Used for all acquisition functions except qSR and qUCB.
mc_samples: The number of samples to use for (q)MC evaluation of the
acquisition function.
seed: If provided, perform deterministic optimization (i.e. the
function to optimize is fixed and not stochastic).
Returns:
The requested acquisition function.
Example:
>>> model = SingleTaskGP(train_X, train_Y)
>>> obj = LinearMCObjective(weights=torch.tensor([1.0, 2.0]))
>>> acqf = get_acquisition_function("qEI", model, obj, train_X)
"""
# initialize the sampler
sampler = get_sampler(
posterior=model.posterior(X_observed[:1]),
sample_shape=torch.Size([mc_samples]),
seed=seed,
)
if posterior_transform is not None and acquisition_function_name in [
"qEHVI",
"qNEHVI",
"qLogEHVI",
"qLogNEHVI",
]:
raise NotImplementedError(
"PosteriorTransforms are not yet implemented for multi-objective "
"acquisition functions."
)
# instantiate and return the requested acquisition function
if acquisition_function_name in ("qEI", "qLogEI", "qPI"):
# Since these are the non-noisy variants, use the posterior mean at the observed
# inputs directly to compute the best feasible value without sampling.
Y = model.posterior(X_observed, posterior_transform=posterior_transform).mean
obj = objective(samples=Y, X=X_observed)
best_f = compute_best_feasible_objective(
samples=Y,
obj=obj,
constraints=constraints,
model=model,
objective=objective,
posterior_transform=posterior_transform,
X_baseline=X_observed,
)
if acquisition_function_name in ["qEI", "qLogEI"]:
acqf_class = (
monte_carlo.qExpectedImprovement
if acquisition_function_name == "qEI"
else logei.qLogExpectedImprovement
)
return acqf_class(
model=model,
best_f=best_f,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
constraints=constraints,
eta=eta,
)
elif acquisition_function_name == "qPI":
return monte_carlo.qProbabilityOfImprovement(
model=model,
best_f=best_f,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
tau=tau,
constraints=constraints,
eta=eta,
)
elif acquisition_function_name in ["qNEI", "qLogNEI"]:
acqf_class = (
monte_carlo.qNoisyExpectedImprovement
if acquisition_function_name == "qNEI"
else logei.qLogNoisyExpectedImprovement
)
return acqf_class(
model=model,
X_baseline=X_observed,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
prune_baseline=prune_baseline,
marginalize_dim=marginalize_dim,
cache_root=cache_root,
constraints=constraints,
eta=eta,
)
elif acquisition_function_name == "qSR":
return monte_carlo.qSimpleRegret(
model=model,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
)
elif acquisition_function_name == "qUCB":
if beta is None:
raise ValueError("`beta` must be not be None for qUCB.")
return monte_carlo.qUpperConfidenceBound(
model=model,
beta=beta,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
)
elif acquisition_function_name in ["qEHVI", "qLogEHVI"]:
if Y is None:
raise ValueError(f"`Y` must not be None for {acquisition_function_name}")
if ref_point is None:
raise ValueError(
f"`ref_point` must not be None for {acquisition_function_name}"
)
# get feasible points
if constraints is not None:
feas = torch.stack([c(Y) <= 0 for c in constraints], dim=-1).all(dim=-1)
Y = Y[feas]
obj = objective(Y)
if alpha > 0:
partitioning = NondominatedPartitioning(
ref_point=torch.as_tensor(ref_point, dtype=Y.dtype, device=Y.device),
Y=obj,
alpha=alpha,
)
else:
partitioning = FastNondominatedPartitioning(
ref_point=torch.as_tensor(ref_point, dtype=Y.dtype, device=Y.device),
Y=obj,
)
acqf_class = (
moo_monte_carlo.qExpectedHypervolumeImprovement
if acquisition_function_name == "qEHVI"
else moo_logei.qLogExpectedHypervolumeImprovement
)
return acqf_class(
model=model,
ref_point=ref_point,
partitioning=partitioning,
sampler=sampler,
objective=objective,
constraints=constraints,
eta=eta,
X_pending=X_pending,
)
elif acquisition_function_name in ["qNEHVI", "qLogNEHVI"]:
if ref_point is None:
raise ValueError(
f"`ref_point` must not be None for {acquisition_function_name}"
)
acqf_class = (
moo_monte_carlo.qNoisyExpectedHypervolumeImprovement
if acquisition_function_name == "qNEHVI"
else moo_logei.qLogNoisyExpectedHypervolumeImprovement
)
return acqf_class(
model=model,
ref_point=ref_point,
X_baseline=X_observed,
sampler=sampler,
objective=objective,
constraints=constraints,
eta=eta,
prune_baseline=prune_baseline,
alpha=alpha,
X_pending=X_pending,
marginalize_dim=marginalize_dim,
cache_root=cache_root,
)
raise NotImplementedError(
f"Unknown acquisition function {acquisition_function_name}"
)