#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Acquisition function for predictive entropy search for multi-objective Bayesian
optimization (PES). The code does not support constraint handling.
NOTE: The PES acquisition might not be differentiable. As a result, we recommend
optimizing the acquisition function using finite differences.
References:
.. [Garrido-Merchan2019]
E. Garrido-Merchan and D. Hernandez-Lobato. Predictive Entropy Search for
Multi-objective Bayesian Optimization with Constraints. Neurocomputing. 2019.
The computation follows the procedure described in the supplementary material:
https://www.sciencedirect.com/science/article/abs/pii/S0925231219308525
"""
from __future__ import annotations
from typing import Any, Optional, Tuple
import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.exceptions import InputDataError
from botorch.exceptions.errors import UnsupportedError
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.utils import check_no_nans
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from torch import Tensor
from torch.distributions import Normal
class qMultiObjectivePredictiveEntropySearch(AcquisitionFunction):
r"""The acquisition function for Predictive Entropy Search. The code supports
both single and multiple objectives as well as batching.
This acquisition function approximates the mutual information between the
observation at a candidate point `X` and the Pareto optimal input using the
moment-matching procedure known as expectation propagation (EP).
See the Appendix of [Garrido-Merchan2019]_ for the description of the EP
procedure.
IMPORTANT NOTES:
(i) The PES acquisition function estimated using EP is sometimes not
differentiable, and therefore we advise using a finite-difference estimate of
the gradient as opposed to the gradients identified using automatic
differentiation, which occasionally outputs `nan` values.
The source of this non-differentiability is in the `_update_damping` function, which
finds the damping factor `a` that is used to update the EP parameters via
`a * param_new + (1 - a) * param_old`. The damping factor has to ensure
that the updated covariance matrix, `a * cov_f_new + (1 - a) * cov_f_old`, remains
positive semi-definite. We follow the original paper, which identifies
`a` via a successive halving scheme, i.e. we check `a=1`, then `a=0.5`, etc. This
procedure means `a` is a function of the test input `X`, and this function is not
differentiable in `X`.
(ii) EP could potentially fail for a number of reasons:
(a) When the sampled Pareto optimal points `x_p` are poor compared to the
training or testing data `x_n`.
(b) When the training or testing data `x_n` are close to the Pareto optimal
points `x_p`.
(c) When the convergence threshold is set too small.
Problem (a) occurs because we have to compute the variable
`alpha = (mean(x_n) - mean(x_p)) / std(x_n - x_p)`, which becomes very
large when `x_n` is better than `x_p` with high probability. This leads to a
log(0) error when we compute `log(1 - cdf(alpha))`. We have preemptively
clamped some values depending on `alpha` in order to mitigate this.
Problem (b) occurs because we have to compute matrix inverses for the
two-dimensional marginals (x_n, x_p). To address this, we manually add jitter
to the diagonal of the covariance matrix, i.e. `ep_jitter` when training and
`test_jitter` when testing. The default choice is not always appropriate
because the same jitter is used for the inversion of both the covariance and
the precision matrix, which are on different scales.
TODO: come up with strategy to adaptively update the jitter.
Problem (c) occurs because a smaller threshold usually means that more EP
iterations are required. Running too many EP iterations could lead to
invertibility problems such as in problem (b). Setting a larger threshold
or reducing the number of EP iterations could alleviate this.
(iii) The estimated acquisition value could be negative.
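Example:
A minimal usage sketch, assuming `model` is a fitted multi-output GP and
`pareto_sets` is a `num_pareto_samples x P x d`-dim tensor of sampled Pareto
optimal inputs (how these samples are obtained is not shown here):
>>> pes = qMultiObjectivePredictiveEntropySearch(model, pareto_sets)
>>> value = pes(test_X)  # `test_X` is a `batch_shape x q x d`-dim tensor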
"""
def __init__(
self,
model: Model,
pareto_sets: Tensor,
maximize: bool = True,
X_pending: Optional[Tensor] = None,
max_ep_iterations: int = 250,
ep_jitter: float = 1e-4,
test_jitter: float = 1e-4,
threshold: float = 1e-2,
**kwargs: Any,
) -> None:
r"""Multi-objective predictive entropy search acquisition function.
Args:
model: A fitted batched model with `M` outputs.
pareto_sets: A `num_pareto_samples x P x d`-dim tensor containing the
Pareto optimal set of inputs, where `P` is the number of Pareto
optimal points. The points in each sample have to be discrete,
otherwise expectation propagation will fail.
maximize: If true, we consider a maximization problem.
X_pending: A `m x d`-dim Tensor of `m` design points that have been
submitted for function evaluation, but have not yet been evaluated.
max_ep_iterations: The maximum number of expectation propagation
iterations. (The minimum number of iterations is set at 3.)
ep_jitter: The amount of jitter added for the matrix inversion that
occurs during the expectation propagation update in the training
phase.
test_jitter: The amount of jitter added for the matrix inversion that
occurs during the expectation propagation update in the testing
phase.
threshold: The convergence threshold for expectation propagation. This
assesses the relative change in the mean and covariance. We default
to one percent change i.e. `threshold = 1e-2`.
"""
super().__init__(model=model)
self.model = model
self.maximize = maximize
self.set_X_pending(X_pending)
if model.num_outputs > 1 or isinstance(model, ModelListGP):
train_X = self.model.train_inputs[0][0]
else:
train_X = self.model.train_inputs[0]
# Batch GP models (e.g. fantasized models) are not currently supported
if train_X.ndim > 2:
raise NotImplementedError(
"Batch GP models (e.g. fantasized models) are not supported."
)
if pareto_sets.ndim != 3 or pareto_sets.shape[-1] != train_X.shape[-1]:
raise UnsupportedError(
"The Pareto set should have a shape of "
"`num_pareto_samples x num_pareto_points x input_dim`."
)
else:
self.pareto_sets = pareto_sets
# add the pareto set to the existing training data
self.num_pareto_samples = pareto_sets.shape[0]
self.augmented_X = torch.cat(
[train_X.repeat(self.num_pareto_samples, 1, 1), self.pareto_sets], dim=-2
)
self.max_ep_iterations = max_ep_iterations
self.ep_jitter = ep_jitter
self.test_jitter = test_jitter
self.threshold = threshold
self._expectation_propagation()
def _expectation_propagation(self) -> None:
r"""Perform expectation propagation to obtain the covariance factors that
depend on the Pareto sets.
The updates are performed in the natural parameter space. For a multivariate
normal distribution with mean mu and covariance Sigma, we call Sigma^{-1}
the natural covariance and Sigma^{-1} mu the natural mean.
"""
###########################################################################
# INITIALIZATION
###########################################################################
M = self.model.num_outputs
if self.model.num_outputs > 1 or isinstance(self.model, ModelListGP):
train_X = self.model.train_inputs[0][0]
else:
train_X = self.model.train_inputs[0]
tkwargs = {"dtype": train_X.dtype, "device": train_X.device}
N = len(train_X)
num_pareto_samples = self.num_pareto_samples
P = self.pareto_sets.shape[-2]
# initialize the predictive natural mean and variances
(
pred_nat_mean,
pred_nat_cov,
pred_mean,
pred_cov,
) = _initialize_predictive_matrices(
X=self.augmented_X,
model=self.model,
observation_noise=False,
jitter=self.ep_jitter,
natural=True,
)
pred_f_mean = pred_mean[..., 0:M, :]
pred_f_nat_mean = pred_nat_mean[..., 0:M, :]
pred_f_cov = pred_cov[..., 0:M, :, :]
pred_f_nat_cov = pred_nat_cov[..., 0:M, :, :]
# initialize the marginals
# `num_pareto_samples x M x (N + P)`
mean_f = pred_f_mean.clone()
nat_mean_f = pred_f_nat_mean.clone()
# `num_pareto_samples x M x (N + P) x (N + P)`
cov_f = pred_f_cov.clone()
nat_cov_f = pred_f_nat_cov.clone()
# initialize omega, the factor which encodes the fact that the Pareto points
# are optimal in the feasible space, i.e. no point in the feasible space
# should dominate the Pareto efficient points.
# `num_pareto_samples x M x (N + P) x P x 2`
omega_f_nat_mean = torch.zeros((num_pareto_samples, M, N + P, P, 2), **tkwargs)
# `num_pareto_samples x M x (N + P) x P x 2 x 2`
omega_f_nat_cov = torch.zeros(
(num_pareto_samples, M, N + P, P, 2, 2), **tkwargs
)
###########################################################################
# EXPECTATION PROPAGATION
###########################################################################
damping = torch.ones(num_pareto_samples, M, **tkwargs)
iteration = 0
while (torch.sum(damping) > 0) and (iteration < self.max_ep_iterations):
# Compute the new natural mean and covariance
####################################################################
# OBJECTIVE FUNCTION: OMEGA UPDATE
####################################################################
omega_f_nat_mean_new, omega_f_nat_cov_new = _safe_update_omega(
mean_f=mean_f,
cov_f=cov_f,
omega_f_nat_mean=omega_f_nat_mean,
omega_f_nat_cov=omega_f_nat_cov,
N=N,
P=P,
M=M,
maximize=self.maximize,
jitter=self.ep_jitter,
)
####################################################################
# OBJECTIVE FUNCTION: MARGINAL UPDATE
####################################################################
nat_mean_f_new, nat_cov_f_new = _update_marginals(
pred_f_nat_mean=pred_f_nat_mean,
pred_f_nat_cov=pred_f_nat_cov,
omega_f_nat_mean=omega_f_nat_mean_new,
omega_f_nat_cov=omega_f_nat_cov_new,
N=N,
P=P,
)
########################################################################
# OBJECTIVE FUNCTION: DAMPING UPDATE
########################################################################
# update damping of objectives
damping, cholesky_nat_cov_f = _update_damping(
nat_cov=nat_cov_f,
nat_cov_new=nat_cov_f_new,
damping_factor=damping,
jitter=self.ep_jitter,
)
check_no_nans(cholesky_nat_cov_f)
########################################################################
# OBJECTIVE FUNCTION: DAMPED UPDATE
########################################################################
# Damp update of omega
omega_f_nat_mean = _damped_update(
old_factor=omega_f_nat_mean,
new_factor=omega_f_nat_mean_new,
damping_factor=damping,
)
omega_f_nat_cov = _damped_update(
old_factor=omega_f_nat_cov,
new_factor=omega_f_nat_cov_new,
damping_factor=damping,
)
# update the mean and covariance
nat_mean_f = _damped_update(
old_factor=nat_mean_f, new_factor=nat_mean_f_new, damping_factor=damping
)
nat_cov_f = _damped_update(
old_factor=nat_cov_f, new_factor=nat_cov_f_new, damping_factor=damping
)
# compute cholesky inverse
cov_f_new = torch.cholesky_inverse(cholesky_nat_cov_f)
mean_f_new = torch.einsum("...ij,...j->...i", cov_f_new, nat_mean_f)
check_no_nans(cov_f_new)
########################################################################
# OBJECTIVE FUNCTION: CONVERGENCE UPDATE
########################################################################
# Set the damping to zero when the change in the mean and
# covariance is less than the threshold
damping, delta_mean_f, delta_cov_f = _update_damping_when_converged(
mean_old=mean_f,
mean_new=mean_f_new,
cov_old=cov_f,
cov_new=cov_f_new,
damping_factor=damping,
threshold=self.threshold,
iteration=iteration,
)
cov_f = cov_f_new
mean_f = mean_f_new
iteration = iteration + 1
############################################################################
# SAVE OMEGA FACTORS
############################################################################
check_no_nans(omega_f_nat_mean)
check_no_nans(omega_f_nat_cov)
# save the omega factors for the forward pass
self._omega_f_nat_mean = omega_f_nat_mean
self._omega_f_nat_cov = omega_f_nat_cov
def _compute_information_gain(self, X: Tensor) -> Tensor:
r"""Evaluate qMultiObjectivePredictiveEntropySearch on the candidate set `X`.
Args:
X: A `batch_shape x q x d`-dim Tensor of t-batches with `q` `d`-dim
design points each.
Returns:
A `batch_shape'`-dim Tensor of Predictive Entropy Search values at the
given design points `X`.
"""
tkwargs = {"dtype": X.dtype, "device": X.device}
batch_shape = X.shape[0:-2]
q = X.shape[-2]
M = self.model.num_outputs
if M > 1 or isinstance(self.model, ModelListGP):
N = len(self.model.train_inputs[0][0])
else:
N = len(self.model.train_inputs[0])
P = self.pareto_sets.shape[-2]
num_pareto_samples = self.num_pareto_samples
###########################################################################
# AUGMENT X WITH THE SAMPLED PARETO SET
###########################################################################
new_shape = batch_shape + torch.Size([num_pareto_samples]) + X.shape[-2:]
expanded_X = X.unsqueeze(-3).expand(new_shape)
expanded_ps = self.pareto_sets.expand(X.shape[0:-2] + self.pareto_sets.shape)
# `batch_shape x num_pareto_samples x (q + P) x d`
aug_X = torch.cat([expanded_X, expanded_ps], dim=-2)
###########################################################################
# COMPUTE THE POSTERIORS AND OBSERVATION NOISE
###########################################################################
# compute predictive distribution without observation noise
(
pred_nat_mean,
pred_nat_cov,
pred_mean,
pred_cov,
) = _initialize_predictive_matrices(
X=aug_X,
model=self.model,
observation_noise=False,
jitter=self.test_jitter,
natural=True,
)
pred_f_mean = pred_mean[..., 0:M, :]
pred_f_nat_mean = pred_nat_mean[..., 0:M, :]
pred_f_cov = pred_cov[..., 0:M, :, :]
pred_f_nat_cov = pred_nat_cov[..., 0:M, :, :]
(_, _, _, pred_cov_noise) = _initialize_predictive_matrices(
X=aug_X,
model=self.model,
observation_noise=True,
jitter=self.test_jitter,
natural=False,
)
pred_f_cov_noise = pred_cov_noise[..., 0:M, :, :]
observation_noise = pred_f_cov_noise - pred_f_cov
###########################################################################
# INITIALIZE THE EP FACTORS
###########################################################################
# `batch_shape x num_pareto_samples x M x (q + P) x P x 2`
omega_f_nat_mean = torch.zeros(
batch_shape + torch.Size([num_pareto_samples, M, q + P, P, 2]), **tkwargs
)
# `batch_shape x num_pareto_samples x M x (q + P) x P x 2 x 2`
omega_f_nat_cov = torch.zeros(
batch_shape + torch.Size([num_pareto_samples, M, q + P, P, 2, 2]), **tkwargs
)
###########################################################################
# RUN EP ONCE
###########################################################################
# run update omega once
omega_f_nat_mean, omega_f_nat_cov = _safe_update_omega(
mean_f=pred_f_mean,
cov_f=pred_f_cov,
omega_f_nat_mean=omega_f_nat_mean,
omega_f_nat_cov=omega_f_nat_cov,
N=q,
P=P,
M=M,
maximize=self.maximize,
jitter=self.test_jitter,
)
###########################################################################
# ADD THE CACHE FACTORS BACK
###########################################################################
omega_f_nat_mean, omega_f_nat_cov = _augment_factors_with_cached_factors(
q=q,
N=N,
omega_f_nat_mean=omega_f_nat_mean,
cached_omega_f_nat_mean=self._omega_f_nat_mean,
omega_f_nat_cov=omega_f_nat_cov,
cached_omega_f_nat_cov=self._omega_f_nat_cov,
)
###########################################################################
# COMPUTE THE MARGINAL
###########################################################################
nat_mean_f, nat_cov_f = _update_marginals(
pred_f_nat_mean=pred_f_nat_mean,
pred_f_nat_cov=pred_f_nat_cov,
omega_f_nat_mean=omega_f_nat_mean,
omega_f_nat_cov=omega_f_nat_cov,
N=q,
P=P,
)
###########################################################################
# COMPUTE THE DAMPED UPDATE
###########################################################################
# update damping of objectives
damping = torch.ones(
batch_shape + torch.Size([num_pareto_samples, M]), **tkwargs
)
damping, cholesky_nat_cov_f_new = _update_damping(
nat_cov=pred_f_nat_cov,
nat_cov_new=nat_cov_f,
damping_factor=damping,
jitter=self.test_jitter,
)
# invert matrix
cov_f_new = torch.cholesky_inverse(cholesky_nat_cov_f_new)
check_no_nans(cov_f_new)
###########################################################################
# COMPUTE THE LOG DETERMINANTS
###########################################################################
# compute the initial log determinant term
log_det_pred_f_cov_noise = _compute_log_determinant(cov=pred_f_cov_noise, q=q)
# compute the post log determinant term
log_det_cov_f = _compute_log_determinant(cov=cov_f_new + observation_noise, q=q)
###########################################################################
# COMPUTE THE ACQUISITION FUNCTION
###########################################################################
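# The information gain is estimated as a difference of Gaussian entropies:
# 0.5 * logdet(pred_cov + noise) - 0.5 * logdet(cond_cov + noise), where the
# second covariance has been conditioned on the sampled Pareto sets via EP.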
q_pes_f = log_det_pred_f_cov_noise - log_det_cov_f
check_no_nans(q_pes_f)
return 0.5 * q_pes_f
@concatenate_pending_points
@t_batch_mode_transform()
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate qMultiObjectivePredictiveEntropySearch on the candidate set `X`.
Args:
X: A `batch_shape x q x d`-dim Tensor of t-batches with `q` `d`-dim
design points each.
Returns:
A `batch_shape'`-dim Tensor of acquisition values at the given design
points `X`.
"""
return self._compute_information_gain(X)
def log_cdf_robust(x: Tensor) -> Tensor:
r"""Computes the logarithm of the normal cumulative density robustly. This uses
the approximation log(1-z) ~ -z when z is small:
if x > 5:
log(cdf(x)) = log(1-cdf(-x)) approx -cdf(-x)
else:
log(cdf(x)).
Args:
x: a `x_shape`-dim Tensor.
Returns
A `x_shape`-dim Tensor.
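Example:
A small sketch of the two branches (input values are illustrative):
>>> x = torch.tensor([-6.0, 0.0, 6.0])
>>> log_cdf_robust(x)  # uses log(cdf(x)) for x < 5 and -cdf(-x) otherwise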
"""
CLAMP_LB = torch.finfo(x.dtype).eps
NEG_INF = torch.finfo(x.dtype).min
normal = Normal(torch.zeros_like(x), torch.ones_like(x))
cdf_x = normal.cdf(x)
neg_cdf_neg_x = -normal.cdf(-x)
log_cdf_x = torch.where(x < 5, torch.log(cdf_x), neg_cdf_neg_x)
return log_cdf_x.clamp(NEG_INF, -CLAMP_LB)
def _initialize_predictive_matrices(
X: Tensor,
model: Model,
observation_noise: bool = True,
jitter: float = 1e-4,
natural: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
r"""Initializes the natural predictive mean and covariance matrix. For a
multivariate normal distribution with mean mu and covariance Sigma, the natural
mean is Sigma^{-1} mu and the natural covariance is Sigma^{-1}.
Args:
X: A `batch_shape x R x d`-dim Tensor.
model: The fitted model.
observation_noise: If true, the posterior is computed with observation noise.
jitter: The jitter added to the covariance matrix.
natural: If true, we compute the natural statistics as well.
Return:
A four-element tuple containing
- pred_nat_mean: A `batch_shape x num_outputs x R`-dim Tensor containing the
predictive natural mean vectors.
- pred_nat_cov: A `batch_shape x num_outputs x R x R`-dim Tensor containing
the predictive natural covariance matrices.
- pred_mean: A `batch_shape x num_outputs x R`-dim Tensor containing the
predictive mean vectors.
- pred_cov: A `batch_shape x num_outputs x R x R`-dim Tensor containing the
predictive covariance matrices.
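Example:
A sketch of the returned shapes and the natural-parameter relation, assuming
`model` is a fitted two-output model and `X` is a `5 x d`-dim tensor (both are
placeholders, not constructed here):
>>> nat_mean, nat_cov, mean, cov = _initialize_predictive_matrices(X, model)
>>> # `mean` is `2 x 5`-dim, `cov` is `2 x 5 x 5`-dim, and, up to jitter,
>>> # `nat_cov` is the inverse of `cov` while `nat_mean = nat_cov @ mean`.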
"""
tkwargs = {"dtype": X.dtype, "device": X.device}
# compute the predictive mean and covariances at X
posterior = model.posterior(X, observation_noise=observation_noise)
# `batch_shape x (R * num_outputs) x (R * num_outputs)`
init_pred_cov = posterior.mvn.covariance_matrix
num_outputs = model.num_outputs
R = int(init_pred_cov.shape[-1] / num_outputs)
pred_cov = [
init_pred_cov[..., (m * R) : ((m + 1) * R), (m * R) : ((m + 1) * R)].unsqueeze(
-1
)
for m in range(num_outputs)
]
# `batch_shape x R x R x num_outputs` (before swap axes)
# `batch_shape x num_outputs x R * R`
pred_cov = torch.cat(pred_cov, axis=-1).swapaxes(-2, -1).swapaxes(-3, -2)
identity = torch.diag_embed(torch.ones(pred_cov.shape[:-1], **tkwargs))
pred_cov = pred_cov + jitter * identity
# `batch_shape x num_outputs x R`
pred_mean = posterior.mean.swapaxes(-2, -1)
#############################################################
if natural:
# natural parameters
# `batch_shape x num_outputs x R x R`
cholesky_pred_cov, _ = torch.linalg.cholesky_ex(pred_cov)
pred_nat_cov = torch.cholesky_inverse(cholesky_pred_cov)
# `batch_shape x num_outputs x R`
pred_nat_mean = torch.einsum("...ij,...j->...i", pred_nat_cov, pred_mean)
return pred_nat_mean, pred_nat_cov, pred_mean, pred_cov
else:
return None, None, pred_mean, pred_cov
def _get_omega_f_contribution(
mean: Tensor, cov: Tensor, N: int, P: int, M: int
) -> Tuple[Tensor, Tensor]:
r"""Extract the mean vector and covariance matrix corresponding to the `2 x 2`
multivariate normal blocks in the objective model between the points in `X` and
the Pareto optimal set.
[There is likely a more efficient way to do this.]
Args:
mean: A `batch_shape x M x (N + P)`-dim Tensor containing the mean vector
for the objectives.
cov: A `batch_shape x M x (N + P) x (N + P)`-dim Tensor containing
the covariance matrix for the objectives.
N: The number of design points.
P: The number of Pareto optimal points.
M: The number of objectives.
Return:
A two-element tuple containing
- mean_fX_fS: A `batch_shape x M x (N + P) x P x 2`-dim Tensor containing the
means of the inputs and Pareto optimal points.
- cov_fX_fS: A `batch_shape x M x (N + P) x P x 2 x 2`-dim Tensor containing
the covariances between the inputs and Pareto optimal points.
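Example:
An indexing sketch: for output `m`, input index `i`, and Pareto index `p`, the
slice `cov_fX_fS[..., m, i, p, :, :]` is the `2 x 2` joint covariance of the
pair `(f_m(x_i), f_m(x_p))`, and `mean_fX_fS[..., m, i, p, :]` is its mean.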
"""
tkwargs = {"dtype": mean.dtype, "device": mean.device}
batch_shape = mean.shape[:-2]
# `batch_shape x M x (N + P) x P x 2 x 2`
cov_fX_fS = torch.zeros(batch_shape + torch.Size([M, N + P, P, 2, 2]), **tkwargs)
# `batch_shape x M x (N + P) x P x 2`
mean_fX_fS = torch.zeros(batch_shape + torch.Size([M, N + P, P, 2]), **tkwargs)
# `batch_shape x M x (N + P) x P`
mean_fX_fS[..., 0] = mean.unsqueeze(-1).expand(mean.shape + torch.Size([P]))
# `batch_shape x M x (N + P) x P`
mean_fX_fS[..., 1] = (
mean[..., N:].unsqueeze(-2).expand(mean.shape + torch.Size([P]))
)
# `batch_shape x M x (N + P) x P`
cov_fX_fS[..., 0, 0] = (
cov[..., range(N + P), range(N + P)]
.unsqueeze(-1)
.expand(batch_shape + torch.Size([M, N + P, P]))
)
# `batch_shape x M x (N + P) x P`
cov_fX_fS[..., 1, 1] = (
cov[..., range(N, N + P), range(N, N + P)]
.unsqueeze(-2)
.expand(batch_shape + torch.Size([M, N + P, P]))
)
for p in range(P):
# `batch_shape x M x (N + P)`
cov_p = cov[..., range(N + P), N + p]
cov_fX_fS[..., p, 0, 1] = cov_p
cov_fX_fS[..., p, 1, 0] = cov_p
return mean_fX_fS, cov_fX_fS
def _replace_pareto_diagonal(A: Tensor) -> Tensor:
"""Replace the pareto diagonal with identity matricx.
The Pareto diagonal of the omega factor shouldn't be updated because does not
contribute anything: `omega(x_p, x_p) = 1` for any pareto optimal input `x_p`.
Args:
A: a `batch_shape x M x (N + P) x P x 2 x 2`-dim Tensor.
Returns:
A `batch_shape x M x (N + P) x P x 2 x 2`-dim Tensor, where the Pareto
diagonal is padded with identity matrices.
"""
tkwargs = {"dtype": A.dtype, "device": A.device}
batch_shape = A.shape[:-5]
P = A.shape[-3]
N = A.shape[-4] - P
M = A.shape[-5]
identity = torch.diag_embed(torch.ones(batch_shape + torch.Size([M, 2]), **tkwargs))
for p in range(P):
A[..., N + p, p, :, :] = identity
return A
def _update_omega(
mean_f: Tensor,
cov_f: Tensor,
omega_f_nat_mean: Tensor,
omega_f_nat_cov: Tensor,
N: int,
P: int,
M: int,
maximize: bool = True,
jitter: float = 1e-6,
) -> Tuple[Tensor, Tensor]:
r"""Computes the new omega factors by matching the moments.
Args:
mean_f: A `batch_shape x M x (N + P)`-dim Tensor containing the mean vector
for the objectives.
cov_f: A `batch_shape x M x (N + P) x (N + P)`-dim Tensor containing the
covariance matrix for the objectives.
omega_f_nat_mean: A `batch_shape x M x (N + P) x P x 2`-dim Tensor containing
the omega natural mean factors for the objective matrix.
omega_f_nat_cov: A `batch_shape x M x (N + P) x P x 2 x 2`-dim Tensor
containing the omega natural covariance factors for the objective matrix.
N: The number of design points.
P: The number of Pareto optimal points.
M: The number of objectives.
maximize: If true, we consider the Pareto maximum domination relation.
jitter: The jitter for the matrix inverse.
Return:
A two-element tuple containing
- omega_f_nat_mean_new: A `batch_shape x M x (N + P) x P x 2` containing the
new omega natural mean factors for the objective matrix.
- omega_f_nat_cov_new: A `batch_shape x M x (N + P) x P x 2 x 2` containing
the new omega natural covariance factors for the objective matrix.
"""
tkwargs = {"dtype": mean_f.dtype, "device": mean_f.device}
CLAMP_LB = torch.finfo(tkwargs["dtype"]).eps
NEG_INF = torch.finfo(tkwargs["dtype"]).min
weight = 1.0 if maximize else -1.0
###############################################################################
# EXTRACT THE NECESSARY COMPONENTS
###############################################################################
# `batch_shape x M x (N + P) x P x 2`-dim mean
# `batch_shape x M x (N + P) x P x 2 x 2`-dim covariance
mean_fX_fS, cov_fX_fS = _get_omega_f_contribution(mean_f, cov_f, N, P, M)
identity = torch.diag_embed(torch.ones(cov_fX_fS.shape[:-1], **tkwargs))
# remove the Pareto diagonal
cov_fX_fS = _replace_pareto_diagonal(cov_fX_fS + jitter * identity)
nat_cov_fX_fS = torch.inverse(cov_fX_fS)
nat_mean_fX_fS = torch.einsum("...ij,...j->...i", nat_cov_fX_fS, mean_fX_fS)
###############################################################################
# COMPUTE THE CAVITIES
###############################################################################
# cavity distribution
# natural parameters
cav_nat_mean_f = nat_mean_fX_fS - omega_f_nat_mean
cav_nat_cov_f = nat_cov_fX_fS - omega_f_nat_cov
# transform to standard parameters
# remove the Pareto diagonal
cav_nat_cov_f = _replace_pareto_diagonal(cav_nat_cov_f)
identity = torch.diag_embed(torch.ones(cav_nat_cov_f.shape[:-1], **tkwargs))
cav_cov_f = torch.inverse(cav_nat_cov_f + jitter * identity)
cav_mean_f = torch.einsum("...ij,...j->...i", cav_cov_f, cav_nat_mean_f)
###############################################################################
# COMPUTE THE NORMALIZATION CONSTANT
###############################################################################
# `batch_shape x M x (N + P) x P`
# Equation 29
cav_var_fX_minus_fS = (
cav_cov_f[..., 0, 0] + cav_cov_f[..., 1, 1] - 2 * cav_cov_f[..., 0, 1]
).clamp_min(CLAMP_LB)
cav_std_fX_minus_fS = torch.sqrt(cav_var_fX_minus_fS).clamp_min(CLAMP_LB)
# `batch_shape x M x (N + P) x P`
cav_mean_fX_minus_fS = weight * (cav_mean_f[..., 0] - cav_mean_f[..., 1])
# Equation 30
cav_alpha = cav_mean_fX_minus_fS / cav_std_fX_minus_fS
# compute alpha pdf and cdf
normal_alpha = Normal(torch.zeros_like(cav_alpha), torch.ones_like(cav_alpha))
# `batch_shape x M x (N + P) x P`
cav_alpha_log_cdf = log_cdf_robust(cav_alpha)
# `batch_shape x M x (N + P) x P`
cav_alpha_log_pdf = normal_alpha.log_prob(cav_alpha).clamp_min(NEG_INF)
# `batch_shape x (N + P) x P`
cav_sum_alpha_log_cdf = torch.sum(cav_alpha_log_cdf, dim=-3).clamp_min(NEG_INF)
# compute normalization constant Z
# Equation 35
cav_log_zeta = torch.log1p(-torch.exp(cav_sum_alpha_log_cdf)).clamp_min(NEG_INF)
# Need to clamp log values to prevent `exp(-inf) = nan`
cav_logZ = cav_log_zeta
# Equation 40 [first bit]
# `batch_shape x (N + P) x P`
cav_log_rho = -cav_logZ + cav_sum_alpha_log_cdf
# Equation 40 [second bit]
# `batch_shape x M x (N + P) x P`
cav_log_rho = cav_log_rho.unsqueeze(-3) - cav_alpha_log_cdf + cav_alpha_log_pdf
cav_rho = -torch.exp(cav_log_rho).clamp(NEG_INF, -NEG_INF)
###############################################################################
# COMPUTE THE PARTIAL DERIVATIVES
###############################################################################
# `batch_shape x M x (N + P) x P x 2`
# Final vector: `[1, -1]`
ones_mean = torch.ones(cav_mean_f.shape, **tkwargs)
ones_mean[..., 1] = -ones_mean[..., 1]
# `batch_shape x M x (N + P) x P x 2 x 2`
# Final matrix: `[[1, -1], [-1, 1]]`
ones_cov = torch.ones(cav_cov_f.shape, **tkwargs)
ones_cov[..., 0, 1] = -ones_cov[..., 0, 1]
ones_cov[..., 1, 0] = -ones_cov[..., 1, 0]
# first partial derivative of log Z with respect to the mean
# assuming maximization (this is also where the sign will change)
# Equation 41
cav_dlogZ_dm = cav_rho / cav_std_fX_minus_fS
cav_dlogZ_dm = weight * cav_dlogZ_dm.unsqueeze(-1) * ones_mean
# second partial derivative of log Z with respect to the mean
# Equation 42
cav_d2logZ_dm2 = -cav_rho * (cav_rho + cav_alpha) / cav_var_fX_minus_fS
cav_d2logZ_dm2 = cav_d2logZ_dm2.unsqueeze(-1).unsqueeze(-1) * ones_cov
###############################################################################
# COMPUTE THE NEW MEAN AND COVARIANCE
###############################################################################
# compute the new mean and covariance
cav_updated_mean_f = cav_mean_f + torch.einsum(
"...ij,...j->...i", cav_cov_f, cav_dlogZ_dm
)
cav_updated_cov_f = cav_cov_f + torch.einsum(
"...ij,...jk,...kl->...il", cav_cov_f, cav_d2logZ_dm2, cav_cov_f
)
# transform to natural parameters
# remove the Pareto diagonal
cav_updated_cov_f = _replace_pareto_diagonal(cav_updated_cov_f)
identity = torch.diag_embed(torch.ones(cav_updated_cov_f.shape[:-1], **tkwargs))
cav_updated_nat_cov_f = torch.inverse(cav_updated_cov_f + jitter * identity)
cav_updated_nat_mean_f = torch.einsum(
"...ij,...j->...i", cav_updated_nat_cov_f, cav_updated_mean_f
)
# match the moments to compute the gain
omega_f_nat_mean_new = cav_updated_nat_mean_f - cav_nat_mean_f
omega_f_nat_cov_new = cav_updated_nat_cov_f - cav_nat_cov_f
# it is also possible to calculate the update directly as in the original paper:
# identity = torch.diag_embed(torch.ones(cav_d2logZ_dm2.shape[:-1], **tkwargs))
# denominator = torch.inverse(cav_cov_f @ cav_d2logZ_dm2 + identity)
# omega_f_nat_cov_new = - cav_d2logZ_dm2 @ denominator
# omega_f_nat_mean_new = torch.einsum(
# '...ij,...j->...i', denominator,
# cav_dlogZ_dm - torch.einsum('...ij,...j->...i', cav_d2logZ_dm2, cav_mean_f)
# )
return omega_f_nat_mean_new, omega_f_nat_cov_new
def _safe_update_omega(
mean_f: Tensor,
cov_f: Tensor,
omega_f_nat_mean: Tensor,
omega_f_nat_cov: Tensor,
N: int,
P: int,
M: int,
maximize: bool = True,
jitter: float = 1e-6,
) -> Tuple[Tensor, Tensor]:
r"""Try to update the new omega factors by matching the moments. If the update
is not possible then this returns the initial omega factors.
Args:
mean_f: A `batch_shape x M x (N + P)`-dim Tensor containing the mean vector
for the objectives.
cov_f: A `batch_shape x M x (N + P) x (N + P)`-dim Tensor containing the
covariance matrix for the objectives.
omega_f_nat_mean: A `batch_shape x M x (N + P) x P x 2`-dim Tensor containing
the omega natural mean factors for the objective matrix.
omega_f_nat_cov: A `batch_shape x M x (N + P) x P x 2 x 2`-dim Tensor
containing the omega natural covariance factors for the objective
matrix.
N: The number of design points.
P: The number of Pareto optimal points.
M: The number of objectives.
maximize: If true, we consider the Pareto maximum domination relation.
jitter: The jitter for the matrix inverse.
Return:
A two-element tuple containing
- omega_f_nat_mean_new: A `batch_shape x M x (N + P) x P x 2` containing the
new omega natural mean factors for the objective matrix.
- omega_f_nat_cov_new: A `batch_shape x M x (N + P) x P x 2 x 2` containing
the new omega natural covariance factors for the objective matrix.
"""
try:
omega_f_nat_mean_new, omega_f_nat_cov_new = _update_omega(
mean_f=mean_f,
cov_f=cov_f,
omega_f_nat_mean=omega_f_nat_mean,
omega_f_nat_cov=omega_f_nat_cov,
N=N,
P=P,
M=M,
maximize=maximize,
jitter=jitter,
)
check_no_nans(omega_f_nat_mean_new)
check_no_nans(omega_f_nat_cov_new)
return omega_f_nat_mean_new, omega_f_nat_cov_new
except (RuntimeError, InputDataError):
return omega_f_nat_mean, omega_f_nat_cov
def _update_marginals(
pred_f_nat_mean: Tensor,
pred_f_nat_cov: Tensor,
omega_f_nat_mean: Tensor,
omega_f_nat_cov: Tensor,
N: int,
P: int,
) -> Tuple[Tensor, Tensor]:
r"""Computes the new marginal by summing up all the natural factors.
Args:
pred_f_nat_mean: A `batch_shape x M x (N + P)`-dim Tensor containing the
natural predictive mean matrix for the objectives.
pred_f_nat_cov: A `batch_shape x M x (N + P) x (N + P)`-dim Tensor containing
the natural predictive covariance matrix for the objectives.
omega_f_nat_mean: A `batch_shape x M x (N + P) x P x 2`-dim Tensor containing
the omega natural mean factors for the objective matrix.
omega_f_nat_cov: A `batch_shape x M x (N + P) x P x 2 x 2`-dim Tensor
containing the omega natural covariance factors for the objective matrix.
N: The number of design points.
P: The number of Pareto optimal points.
Returns:
A two-element tuple containing
- nat_mean_f: A `batch_shape x M x (N + P)`-dim Tensor containing the updated
natural mean matrix for the objectives.
- nat_cov_f: A `batch_shape x M x (N + P) x (N + P)`-dim Tensor containing
the updated natural predictive covariance matrix for the objectives.
"""
# `batch_shape x M x (N + P)`
nat_mean_f = pred_f_nat_mean.clone()
# `batch_shape x M x (N + P) x (N + P)`
nat_cov_f = pred_f_nat_cov.clone()
################################################################################
# UPDATE THE OBJECTIVES
################################################################################
# remove Pareto diagonal
# zero out the diagonal
omega_f_nat_mean[..., range(N, N + P), range(P), :] = 0
omega_f_nat_cov[..., range(N, N + P), range(P), :, :] = 0
# `batch_shape x M x (N + P)`
# sum over the pareto dim
nat_mean_f = nat_mean_f + omega_f_nat_mean[..., 0].sum(dim=-1)
# `batch_shape x M x P`
# sum over the data dim
nat_mean_f[..., N:] = nat_mean_f[..., N:] + omega_f_nat_mean[..., 1].sum(dim=-2)
# `batch_shape x M x (N + P)`
nat_cov_f[..., range(N + P), range(N + P)] = nat_cov_f[
..., range(N + P), range(N + P)
] + omega_f_nat_cov[..., 0, 0].sum(dim=-1)
# `batch_shape x M x P`
nat_cov_f[..., range(N, N + P), range(N, N + P)] = nat_cov_f[
..., range(N, N + P), range(N, N + P)
] + omega_f_nat_cov[..., 1, 1].sum(dim=-2)
for p in range(P):
# `batch_shape x M x (N + P)`
nat_cov_f[..., range(N + P), N + p] = (
nat_cov_f[..., range(N + P), N + p] + omega_f_nat_cov[..., p, 0, 1]
)
# `batch_shape x M x (N + P)`
nat_cov_f[..., N + p, range(N + P)] = (
nat_cov_f[..., N + p, range(N + P)] + omega_f_nat_cov[..., p, 1, 0]
)
return nat_mean_f, nat_cov_f
def _damped_update(
old_factor: Tensor,
new_factor: Tensor,
damping_factor: Tensor,
) -> Tensor:
r"""Computes the damped updated for natural factor.
Args:
old_factor: A `batch_shape x param_shape`-dim Tensor containing the old
natural factor.
new_factor: A `batch_shape x param_shape`-dim Tensor containing the new
natural factor.
damping_factor: A `batch_shape`-dim Tensor containing the damping factor.
Returns:
A `batch_shape x param_shape`-dim Tensor containing the updated natural
factor.
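Example:
A broadcasting sketch: the `batch_shape`-dim damping factor is broadcast over
the trailing parameter dimensions.
>>> old = torch.zeros(4, 3, 5)
>>> new = torch.ones(4, 3, 5)
>>> damping = torch.full((4, 3), 0.25)
>>> _damped_update(old, new, damping)  # every entry equals 0.25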
"""
bs = damping_factor.shape
fs = old_factor.shape
df = damping_factor
for _ in range(len(fs[len(bs) :])):
df = df.unsqueeze(-1)
return df * new_factor + (1 - df) * old_factor
def _update_damping(
nat_cov: Tensor,
nat_cov_new: Tensor,
damping_factor: Tensor,
jitter: Tensor,
) -> Tuple[Tensor, Tensor]:
r"""Updates the damping factor whilst ensuring the covariance matrix is positive
definite by trying a Cholesky decomposition.
Args:
nat_cov: A `batch_shape x R x R`-dim Tensor containing the old natural
covariance matrix.
nat_cov_new: A `batch_shape x R x R`-dim Tensor containing the new natural
covariance matrix.
damping_factor: A `batch_shape`-dim Tensor containing the damping factor.
jitter: The amount of jitter added before matrix inversion.
Returns:
A two-element tuple containing
- A `batch_shape`-dim Tensor containing the updated damping factor.
- A `batch_shape x R x R`-dim Tensor containing the Cholesky factor.
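Example:
A sketch of the failure-free case: when the damped natural covariance is
already positive definite, the damping factor is returned unchanged together
with the Cholesky factor of the damped matrix.
>>> nat_cov = torch.eye(3).expand(2, 3, 3)
>>> nat_cov_new = 2.0 * torch.eye(3).expand(2, 3, 3)
>>> df, chol = _update_damping(nat_cov, nat_cov_new, torch.ones(2), jitter=1e-4)
>>> # `df` is still all ones and `chol` is the Cholesky factor of `nat_cov_new`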
"""
tkwargs = {"dtype": nat_cov.dtype, "device": nat_cov.device}
df = damping_factor
jitter = jitter * torch.diag_embed(torch.ones(nat_cov.shape[:-1], **tkwargs))
_, info = torch.linalg.cholesky_ex(nat_cov + jitter)
if torch.sum(info) > 1:
raise ValueError(
"The previous covariance is not positive semi-definite. "
"This usually happens if the predictive covariance is "
"ill-conditioned and the added jitter is insufficient."
)
damped_nat_cov = _damped_update(
old_factor=nat_cov, new_factor=nat_cov_new, damping_factor=df
)
cholesky_factor, info = torch.linalg.cholesky_ex(damped_nat_cov)
contains_nans = torch.any(torch.isnan(cholesky_factor)).item()
run = 0
while torch.sum(info) > 1 or contains_nans:
# propose an alternate damping factor which is half the original
df_alt = 0.5 * df
# hard threshold at 1e-3
df_alt = torch.where(
df_alt > 1e-3, df_alt, torch.zeros(df_alt.shape, **tkwargs)
)
# only change the damping factor where psd failure occurs
df_new = torch.where(info == 0, df, df_alt)
# new damped covariance
damped_nat_cov = _damped_update(nat_cov, nat_cov_new, df_new)
# try cholesky decomposition
cholesky_factor, info = torch.linalg.cholesky_ex(damped_nat_cov + jitter)
contains_nans = torch.any(torch.isnan(cholesky_factor)).item()
df = df_new
run = run + 1
return df, cholesky_factor
def _update_damping_when_converged(
mean_old: Tensor,
mean_new: Tensor,
cov_old: Tensor,
cov_new: Tensor,
damping_factor: Tensor,
iteration: int,
threshold: float = 1e-3,
) -> Tuple[Tensor, Tensor, Tensor]:
r"""Set the damping factor to 0 once converged. Convergence is determined by the
relative change in the entries of the mean and covariance matrix.
Args:
mean_old: A `batch_shape x R`-dim Tensor containing the old natural mean
matrix for the objective.
mean_new: A `batch_shape x R`-dim Tensor containing the new natural mean
matrix for the objective.
cov_old: A `batch_shape x R x R`-dim Tensor containing the old natural
covariance matrix for the objective.
cov_new: A `batch_shape x R x R`-dim Tensor containing the new natural
covariance matrix for the objective.
damping_factor: A `batch_shape`-dim Tensor containing the damping factor.
iteration: The current iteration number.
threshold: The convergence threshold.
Returns:
A three-element tuple containing
- A `batch_shape`-dim Tensor containing the updated damping factor.
- A `batch_shape x R`-dim Tensor containing the change in the mean.
- A `batch_shape x R x R`-dim Tensor containing the change in the covariance.
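Example:
A sketch of the converged case: after the first few iterations, batch entries
whose mean and covariance no longer change have their damping set to zero.
>>> mean = torch.ones(3, 4)
>>> cov = torch.eye(4).expand(3, 4, 4)
>>> df, dm, dc = _update_damping_when_converged(
...     mean, mean.clone(), cov, cov.clone(), torch.ones(3), iteration=5
... )
>>> # `df` is all zeros because nothing changed between the iterations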
"""
df = damping_factor.clone()
delta_mean = mean_new - mean_old
delta_cov = cov_new - cov_old
am = torch.amax(abs(mean_old), dim=-1)
ac = torch.amax(abs(cov_old), dim=(-2, -1))
if iteration > 2:
mask_mean = torch.amax(abs(delta_mean), dim=-1) < threshold * am
mask_cov = torch.amax(abs(delta_cov), dim=(-2, -1)) < threshold * ac
mask = torch.logical_and(mask_mean, mask_cov)
df[mask] = 0
return df, delta_mean, delta_cov
def _augment_factors_with_cached_factors(
q: int,
N: int,
omega_f_nat_mean: Tensor,
cached_omega_f_nat_mean: Tensor,
omega_f_nat_cov: Tensor,
cached_omega_f_nat_cov: Tensor,
) -> Tuple[Tensor, Tensor]:
r"""Incorporate the cached Pareto updated factors in the forward call and
augment them with the previously computed factors.
Args:
q: The batch size.
N: The number of training points.
omega_f_nat_mean: A `batch_shape x num_pareto_samples x M x (q + P) x P x 2`
-dim Tensor containing the omega natural mean for the objective at `X`.
cached_omega_f_nat_mean: A `num_pareto_samples x M x (N + P) x P x 2`-dim
Tensor containing the cached omega natural mean for the objective, computed
during training.
omega_f_nat_cov: A `batch_shape x num_pareto_samples x M x (q + P) x P x 2
x 2` -dim Tensor containing the omega natural covariance for the
objective at `X`.
cached_omega_f_nat_cov: A `num_pareto_samples x M x (N + P) x P x 2 x 2`-dim
Tensor containing the cached omega natural covariance for the objective,
computed during training.
Returns:
A two-element tuple containing
- omega_f_nat_mean_new: A `batch_shape x num_pareto_samples x M x (q + P)
x P x 2`-dim Tensor containing the omega natural mean for the objective
at `X`.
- omega_f_nat_cov_new: A `batch_shape x num_pareto_samples x M x (q + P) x
P x 2 x 2`-dim Tensor containing the omega natural covariance for the
objective at `X`.
"""
##############################################################################
# omega_f_nat_mean
##############################################################################
# retrieve the natural mean contribution of the Pareto block omega(x_p, x_p) for
# the objective
exp_cached_omega_f_nat_mean = cached_omega_f_nat_mean[..., N:, :, :].expand(
omega_f_nat_mean[..., q:, :, :].shape
)
omega_f_nat_mean[..., q:, :, :] = exp_cached_omega_f_nat_mean
##############################################################################
# omega_f_nat_cov
##############################################################################
# retrieve the natural covariance contribution of the Pareto block
# omega(x_p, x_p) for the objective
exp_omega_f_nat_cov = cached_omega_f_nat_cov[..., N:, :, :, :].expand(
omega_f_nat_cov[..., q:, :, :, :].shape
)
omega_f_nat_cov[..., q:, :, :, :] = exp_omega_f_nat_cov
return omega_f_nat_mean, omega_f_nat_cov
def _compute_log_determinant(cov: Tensor, q: int) -> Tensor:
r"""Computes the sum of the log determinants of a block diagonal covariance
matrices averaged over the Pareto samples.
Args:
cov: A `batch_shape x num_pareto_samples x num_outputs x (q + P) x (q + P)`
-dim Tensor containing the covariance matrices.
q: The batch size.
Return:
log_det_cov: A `batch_shape`-dim Tensor containing the log determinants
summed over the outputs and averaged over the Pareto samples.
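Example:
A sketch with identity covariance blocks, for which each log determinant is zero:
>>> cov = torch.eye(6).expand(4, 3, 2, 6, 6)  # batch=4, samples=3, outputs=2
>>> _compute_log_determinant(cov, q=5)  # a `4`-dim tensor of zeros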
"""
log_det_cov = torch.logdet(cov[..., 0:q, 0:q])
check_no_nans(log_det_cov)
return log_det_cov.sum(dim=-1).mean(dim=-1)