#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Analytic Acquisition Functions for Multi-objective Bayesian optimization.

References

.. [Yang2019]
    Yang, K., Emmerich, M., Deutz, A. et al. Efficient computation of expected
    hypervolume improvement using box decomposition algorithms. J Glob Optim 75,
    3–34 (2019)

"""
from __future__ import annotations
from abc import abstractmethod
from itertools import product
from typing import List, Optional
import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.errors import UnsupportedError
from botorch.models.model import Model
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
NondominatedPartitioning,
)
from botorch.utils.transforms import t_batch_mode_transform
from torch import Tensor
from torch.distributions import Normal
class MultiObjectiveAnalyticAcquisitionFunction(AcquisitionFunction):
    r"""Abstract base class for analytic multi-objective acquisition functions.

    Stores the model (via the parent constructor) and an optional
    `PosteriorTransform`. Subclasses must implement `forward`.
    """

    def __init__(
        self,
        model: Model,
        posterior_transform: Optional[PosteriorTransform] = None,
    ) -> None:
        r"""Set up the shared state of an analytic multi-objective acquisition.

        Args:
            model: A fitted model.
            posterior_transform: A PosteriorTransform (optional).

        Raises:
            UnsupportedError: If `posterior_transform` is neither `None` nor a
                `PosteriorTransform` instance.
        """
        super().__init__(model=model)
        transform_is_valid = posterior_transform is None or isinstance(
            posterior_transform, PosteriorTransform
        )
        # Guard clause: reject anything that is not a PosteriorTransform before
        # storing it on the instance.
        if not transform_is_valid:
            raise UnsupportedError(
                "Only a posterior_transform of type PosteriorTransform is "
                "supported for Multi-Objective analytic acquisition functions."
            )
        self.posterior_transform = posterior_transform

    @abstractmethod
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the acquisition function on the candidate set `X`.

        Args:
            X: A `batch_shape x 1 x d`-dim Tensor of t-batches with `1` `d`-dim
                design point each.

        Returns:
            A Tensor with shape `batch_shape'`, where `batch_shape'` is the
            broadcasted batch shape of model and input `X`.
        """
        pass  # pragma: no cover

    def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
        r"""Reject pending points; they are not supported here.

        Args:
            X_pending: Ignored.

        Raises:
            UnsupportedError: Always.
        """
        raise UnsupportedError(
            "Analytic acquisition functions do not account for X_pending yet."
        )
class ExpectedHypervolumeImprovement(MultiObjectiveAnalyticAcquisitionFunction):
    r"""Analytic Expected Hypervolume Improvement (EHVI) for `m >= 2` outcomes.

    The value is assembled from the per-cell, per-outcome `psi` and `nu` terms
    of [Yang2019]_ evaluated on the hypercell bounds supplied by the
    partitioning.
    """

    def __init__(
        self,
        model: Model,
        ref_point: List[float],
        partitioning: NondominatedPartitioning,
        posterior_transform: Optional[PosteriorTransform] = None,
        **kwargs,
    ) -> None:
        r"""Expected Hypervolume Improvement supporting m>=2 outcomes.

        This computes EHVI using the algorithm from [Yang2019]_, but
        additionally computes gradients via auto-differentiation as proposed by
        [Daulton2020qehvi]_.

        Note: this is currently inefficient in two ways due to the binary
        partitioning algorithm that we use for the box decomposition:

            - We have more boxes in our decomposition
            - If we used a box decomposition that used `inf` as the upper bound
                for the last dimension *in all hypercells*, then we could reduce
                the number of terms we need to compute from 2^m to 2^(m-1).
                [Yang2019]_ do this by using DKLV17 and LKF17 for the box
                decomposition.

        TODO: Use DKLV17 and LKF17 for the box decomposition as in [Yang2019]_
            for greater efficiency.

        TODO: Add support for outcome constraints.

        Example:
            >>> model = SingleTaskGP(train_X, train_Y)
            >>> ref_point = [0.0, 0.0]
            >>> EHVI = ExpectedHypervolumeImprovement(model, ref_point, partitioning)
            >>> ehvi = EHVI(test_X)

        Args:
            model: A fitted model.
            ref_point: A list with `m` elements representing the reference point
                (in the outcome space) w.r.t. to which compute the hypervolume.
                It is compared against `partitioning.pareto_Y`, so it must be
                expressed in the same outcome space as the Pareto front.
            partitioning: A `NondominatedPartitioning` module that provides the
                non-dominated front and a partitioning of the non-dominated
                space in hyper-rectangles.
            posterior_transform: A `PosteriorTransform`.
            **kwargs: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If `len(ref_point)` differs from
                `partitioning.num_outcomes`, or if the Pareto front is non-empty
                and none of its points is strictly better than `ref_point` in
                every outcome.
        """
        # TODO: we could refactor this __init__ logic into a
        # HypervolumeAcquisitionFunction Mixin
        if len(ref_point) != partitioning.num_outcomes:
            raise ValueError(
                "The length of the reference point must match the number of outcomes. "
                f"Got ref_point with {len(ref_point)} elements, but expected "
                f"{partitioning.num_outcomes}."
            )
        ref_point = torch.tensor(
            ref_point,
            dtype=partitioning.pareto_Y.dtype,
            device=partitioning.pareto_Y.device,
        )
        better_than_ref = (partitioning.pareto_Y > ref_point).all(dim=1)
        if not better_than_ref.any() and partitioning.pareto_Y.shape[0] > 0:
            raise ValueError(
                "At least one pareto point must be better than the reference point."
            )
        super().__init__(model=model, posterior_transform=posterior_transform, **kwargs)
        self.register_buffer("ref_point", ref_point)
        self.partitioning = partitioning
        cell_bounds = self.partitioning.get_hypercell_bounds()
        self.register_buffer("cell_lower_bounds", cell_bounds[0])
        self.register_buffer("cell_upper_bounds", cell_bounds[1])
        # Indexing tensor of shape `2^m x m`: every way of picking, per outcome,
        # either row 0 (psi_diff) or row 1 (nu) of the stacked factors.
        cross_product_indices = torch.tensor(
            list(product([0, 1], repeat=ref_point.shape[0])),
            dtype=torch.long,
            device=ref_point.device,
        )
        # Registered as a buffer so that `.to(device)` / `.cuda()` move it
        # together with the cell bounds; otherwise `gather` in `forward` fails
        # with a device mismatch after the module is moved. It is fully
        # determined by `m`, so it is kept out of the state dict.
        self.register_buffer(
            "_cross_product_indices", cross_product_indices, persistent=False
        )
        self.normal = Normal(0, 1)

    def psi(self, lower: Tensor, upper: Tensor, mu: Tensor, sigma: Tensor) -> Tensor:
        r"""Compute Psi function.

        For each cell i and outcome k:

            Psi(lower_{i,k}, upper_{i,k}, mu_k, sigma_k) = (
                sigma_k * PDF((upper_{i,k} - mu_k) / sigma_k) + (
                    mu_k - lower_{i,k}
                ) * (1 - CDF((upper_{i,k} - mu_k) / sigma_k))
            )

        See Equation 19 in [Yang2019]_ for more details.

        Args:
            lower: A `num_cells x m`-dim tensor of lower cell bounds
            upper: A `num_cells x m`-dim tensor of upper cell bounds
            mu: A `batch_shape x 1 x m`-dim tensor of means
            sigma: A `batch_shape x 1 x m`-dim tensor of standard deviations
                (clamped).

        Returns:
            A `batch_shape x num_cells x m`-dim tensor of values.
        """
        u = (upper - mu) / sigma
        # The standard normal PDF is obtained by exponentiating its log-density.
        pdf_u = self.normal.log_prob(u).exp()
        return sigma * pdf_u + (mu - lower) * (1 - self.normal.cdf(u))

    def nu(self, lower: Tensor, upper: Tensor, mu: Tensor, sigma: Tensor) -> Tensor:
        r"""Compute Nu function.

        For each cell i and outcome k:

            nu(lower_{i,k}, upper_{i,k}, mu_k, sigma_k) = (
                upper_{i,k} - lower_{i,k}
            ) * (1 - CDF((upper_{i,k} - mu_k) / sigma_k))

        See Equation 25 in [Yang2019]_ for more details.

        Args:
            lower: A `num_cells x m`-dim tensor of lower cell bounds
            upper: A `num_cells x m`-dim tensor of upper cell bounds
            mu: A `batch_shape x 1 x m`-dim tensor of means
            sigma: A `batch_shape x 1 x m`-dim tensor of standard deviations
                (clamped).

        Returns:
            A `batch_shape x num_cells x m`-dim tensor of values.
        """
        return (upper - lower) * (1 - self.normal.cdf((upper - mu) / sigma))

    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate analytic EHVI on the candidate set `X`.

        Args:
            X: A `batch_shape x 1 x d`-dim tensor of t-batches with one
                `d`-dim design point each.

        Returns:
            A tensor of EHVI values, obtained by summing the products of the
            per-outcome factors over all `2^m` terms and over all hypercells.
        """
        posterior = self.model.posterior(
            X, posterior_transform=self.posterior_transform
        )
        mu = posterior.mean
        # Clamp the variance away from zero before the square root so that the
        # divisions by sigma in `psi` / `nu` stay finite.
        sigma = posterior.variance.clamp_min(1e-9).sqrt()
        # clamp here, since upper_bounds will contain `inf`s, which
        # are not differentiable
        cell_upper_bounds = self.cell_upper_bounds.clamp_max(
            1e10 if X.dtype == torch.double else 1e8
        )
        # Compute psi(lower_k, upper_k, mu_k, sigma_k) for every outcome k
        psi_lu = self.psi(
            lower=self.cell_lower_bounds, upper=cell_upper_bounds, mu=mu, sigma=sigma
        )
        # Compute psi(lower_k, lower_k, mu_k, sigma_k) for every outcome k
        psi_ll = self.psi(
            lower=self.cell_lower_bounds,
            upper=self.cell_lower_bounds,
            mu=mu,
            sigma=sigma,
        )
        # Compute nu(lower_k, upper_k, mu_k, sigma_k) for every outcome k
        nu = self.nu(
            lower=self.cell_lower_bounds, upper=cell_upper_bounds, mu=mu, sigma=sigma
        )
        # compute the difference psi_ll - psi_lu
        psi_diff = psi_ll - psi_lu
        # Stacking adds a new second-to-last dimension of size 2:
        # index 0 holds psi_diff, index 1 holds nu.
        stacked_factors = torch.stack([psi_diff, nu], dim=-2)
        # Take the cross product of psi_diff and nu across all outcomes
        # e.g. for m = 2
        # for each batch and cell, compute
        # [psi_diff_0, psi_diff_1]
        # [psi_diff_0, nu_1]
        # [nu_0, psi_diff_1]
        # [nu_0, nu_1]
        # this tensor has shape: `batch_shape x num_cells x 2^m x m`
        all_factors = stacked_factors.gather(
            dim=-2,
            index=self._cross_product_indices.expand(
                stacked_factors.shape[:-2] + self._cross_product_indices.shape
            ),
        )
        # compute product for all 2^m terms,
        # sum across all terms and hypercells
        return all_factors.prod(dim=-1).sum(dim=-1).sum(dim=-1)