#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Acquisition functions for joint entropy search for Bayesian optimization (JES).

References:

.. [Tu2022]
    B. Tu, A. Gandy, N. Kantas and B. Shafei. Joint Entropy Search for
    Multi-Objective Bayesian Optimization. Advances in Neural Information
    Processing Systems, 35. 2022.

"""
from __future__ import annotations
from abc import abstractmethod
from math import pi
from typing import Optional, Tuple, Union
import torch
from botorch import settings
from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
from botorch.exceptions.errors import UnsupportedError
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.utils import fantasize as fantasize_flag
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from torch import Tensor
from torch.distributions import Normal
class LowerBoundMultiObjectiveEntropySearch(AcquisitionFunction, MCSamplerMixin):
    r"""Abstract base class for the lower bound multi-objective entropy search
    acquisition functions.

    The class validates and stores the sampled Pareto sets/fronts and the
    hyper-rectangle bounds, and implements the generic information-gain
    computation. Subclasses provide the posterior statistics, the Monte Carlo
    variables and `forward`.
    """

    def __init__(
        self,
        model: Model,
        pareto_sets: Tensor,
        pareto_fronts: Tensor,
        hypercell_bounds: Tensor,
        X_pending: Optional[Tensor] = None,
        estimation_type: str = "LB",
        num_samples: int = 64,
    ) -> None:
        r"""Lower bound multi-objective entropy search acquisition function.

        Args:
            model: A fitted batch model with 'M' number of outputs.
            pareto_sets: A `num_pareto_samples x num_pareto_points x d`-dim Tensor
                containing the sampled Pareto optimal sets of inputs.
            pareto_fronts: A `num_pareto_samples x num_pareto_points x M`-dim Tensor
                containing the sampled Pareto optimal sets of outputs.
            hypercell_bounds: A `num_pareto_samples x 2 x J x M`-dim Tensor
                containing the hyper-rectangle bounds for integration, where `J` is
                the number of hyper-rectangles. In the unconstrained case, this gives
                the partition of the dominated space. In the constrained case, this
                gives the partition of the feasible dominated space union the
                infeasible space.
            X_pending: A `m x d`-dim Tensor of `m` design points that have been
                submitted for function evaluation, but have not yet been evaluated.
            estimation_type: A string to determine which entropy estimate is
                computed: "0", "LB", "LB2", or "MC".
            num_samples: The number of Monte Carlo samples for the Monte Carlo
                estimate.

        Raises:
            NotImplementedError: If the training inputs of `model` carry batch
                dimensions, or if `estimation_type` is not one of the supported
                values.
            UnsupportedError: If `pareto_sets` / `pareto_fronts` (when given) are
                not 3-dimensional, or `hypercell_bounds` is not 4-dimensional.
        """
        super().__init__(model=model)
        sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_samples]))
        MCSamplerMixin.__init__(self, sampler=sampler)

        # Batch GP models (e.g. fantasized models) are not currently supported.
        # For a model list, the first sub-model's training inputs are inspected.
        if isinstance(model, ModelListGP):
            train_X = model.models[0].train_inputs[0]
        else:
            train_X = model.train_inputs[0]
        if (model.num_outputs > 1 and train_X.ndim > 3) or (
            model.num_outputs == 1 and train_X.ndim > 2
        ):
            raise NotImplementedError(
                "Batch GP models (e.g. fantasized models) are not supported."
            )

        self.initial_model = model

        # `None` is tolerated for the Pareto set/front; only the rank of a
        # provided tensor is validated here.
        if (pareto_sets is not None and pareto_sets.ndim != 3) or (
            pareto_fronts is not None and pareto_fronts.ndim != 3
        ):
            raise UnsupportedError(
                "The Pareto set and front should have a shape of "
                "`num_pareto_samples x num_pareto_points x input_dim` and "
                "`num_pareto_samples x num_pareto_points x num_objectives`, "
                "respectively"
            )
        self.pareto_sets = pareto_sets
        self.pareto_fronts = pareto_fronts

        if hypercell_bounds.ndim != 4:
            raise UnsupportedError(
                "The hypercell_bounds should have a shape of "
                "`num_pareto_samples x 2 x num_boxes x num_objectives`."
            )
        self.hypercell_bounds = hypercell_bounds
        self.num_pareto_samples = hypercell_bounds.shape[0]

        self.estimation_type = estimation_type
        estimation_types = ["0", "LB", "LB2", "MC"]
        if estimation_type not in estimation_types:
            raise NotImplementedError(
                "Currently the only supported estimation type are: "
                + ", ".join(f'"{h}"' for h in estimation_types)
                + "."
            )
        self.set_X_pending(X_pending)

    @abstractmethod
    def _compute_posterior_statistics(
        self, X: Tensor
    ) -> dict[str, Union[GPyTorchPosterior, Tensor]]:
        r"""Compute the posterior statistics.

        Args:
            X: A `batch_shape x q x d`-dim Tensor of inputs.

        Returns:
            A dictionary containing the posterior variables used to estimate the
            entropy.

            - "initial_entropy": A `batch_shape`-dim Tensor containing the entropy of
                the Gaussian random variable `p(Y| X, D_n)`.
            - "posterior_mean": A `batch_shape x num_pareto_samples x q x 1 x M`-dim
                Tensor containing the posterior mean at the input `X`.
            - "posterior_variance": A `batch_shape x num_pareto_samples x q x 1 x M`
                -dim Tensor containing the posterior variance at the input `X`
                excluding the observation noise.
            - "observation_noise": A `batch_shape x num_pareto_samples x q x 1 x M`
                -dim Tensor containing the observation noise at the input `X`.
            - "posterior_with_noise": The posterior distribution at `X` which
                includes the observation noise. This is used to compute the marginal
                log-probabilities with respect to `p(y| x, D_n)` for `x` in `X`.
        """
        pass  # pragma: no cover

    @abstractmethod
    def _compute_monte_carlo_variables(
        self, posterior: GPyTorchPosterior
    ) -> Tuple[Tensor, Tensor]:
        r"""Compute the samples and log-probability associated with a posterior
        distribution.

        Args:
            posterior: A posterior distribution.

        Returns:
            A two-element tuple containing:

            - samples: A `num_mc_samples x batch_shape x num_pareto_samples x q x 1
                x M`-dim Tensor containing the Monte Carlo samples.
            - samples_log_prob: A `num_mc_samples x batch_shape x num_pareto_samples
                x q`-dim Tensor containing the log-probabilities of the Monte Carlo
                samples.
        """
        pass  # pragma: no cover

    def _compute_lower_bound_information_gain(self, X: Tensor) -> Tensor:
        r"""Evaluates the lower bound information gain at the design points `X`.

        The result is the initial entropy minus the estimated conditional
        entropy, where the per-point conditional entropies are summed over the
        `q` points of each batch.

        Args:
            X: A `batch_shape x q x d`-dim Tensor of `batch_shape` t-batches with `q`
                `d`-dim design points each.

        Returns:
            A `batch_shape`-dim Tensor of acquisition values at the given design
            points `X`.

        Raises:
            NotImplementedError: If `self.estimation_type` is not one of
                "0", "LB", "LB2" or "MC" (only possible if the attribute was
                modified after construction).
        """
        posterior_statistics = self._compute_posterior_statistics(X)
        initial_entropy = posterior_statistics["initial_entropy"]
        # Arguments shared by every entropy estimator.
        common_kwargs = {
            "hypercell_bounds": self.hypercell_bounds,
            "mean": posterior_statistics["posterior_mean"],
            "variance": posterior_statistics["posterior_variance"],
            "observation_noise": posterior_statistics["observation_noise"],
        }

        # Estimate the expected conditional entropy.
        # `batch_shape x q` dim Tensor of entropy estimates
        if self.estimation_type == "0":
            conditional_entropy = _compute_entropy_noiseless(**common_kwargs)
        elif self.estimation_type in ("LB", "LB2"):
            # "LB2" only uses the diagonal of the moment-matched covariance.
            conditional_entropy = _compute_entropy_upper_bound(
                only_diagonal=self.estimation_type == "LB2",
                **common_kwargs,
            )
        elif self.estimation_type == "MC":
            posterior_with_noise = posterior_statistics["posterior_with_noise"]
            samples, samples_log_prob = self._compute_monte_carlo_variables(
                posterior_with_noise
            )
            conditional_entropy = _compute_entropy_monte_carlo(
                samples=samples,
                samples_log_prob=samples_log_prob,
                **common_kwargs,
            )
        else:
            # Guards against the attribute being changed after `__init__`
            # validated it; otherwise `conditional_entropy` would be unbound.
            raise NotImplementedError(
                f'Unsupported estimation type: "{self.estimation_type}".'
            )

        # Sum over the batch.
        return initial_entropy - conditional_entropy.sum(dim=-1)

    @abstractmethod
    def forward(self, X: Tensor) -> Tensor:
        r"""Compute lower bound multi-objective entropy search at the design points
        `X`.

        Args:
            X: A `batch_shape x q x d`-dim Tensor of `batch_shape` t-batches with `q`
                `d`-dim design points each.

        Returns:
            A `batch_shape`-dim Tensor of acquisition values at the given design
            points `X`.
        """
        pass  # pragma: no cover
class qLowerBoundMultiObjectiveJointEntropySearch(
    LowerBoundMultiObjectiveEntropySearch
):
    r"""The acquisition function for the multi-objective joint entropy search, where
    the batches `q > 1` are supported through the lower bound formulation.

    This acquisition function computes the mutual information between the observation
    at a candidate point `X` and the Pareto optimal input-output pairs.

    See [Tu2022]_ for a discussion on the estimation procedure.

    NOTES:
    (i) The estimated acquisition value could be negative.

    (ii) The lower bound batch acquisition function might not be monotone in the
    sense that adding more elements to the batch does not necessarily increase the
    acquisition value. Specifically, the acquisition value can become smaller when
    more inputs are added.
    """

    def __init__(
        self,
        model: Model,
        pareto_sets: Tensor,
        pareto_fronts: Tensor,
        hypercell_bounds: Tensor,
        X_pending: Optional[Tensor] = None,
        estimation_type: str = "LB",
        num_samples: int = 64,
    ) -> None:
        r"""Lower bound multi-objective joint entropy search acquisition function.

        Args:
            model: A fitted batch model with 'M' number of outputs.
            pareto_sets: A `num_pareto_samples x num_pareto_points x d`-dim Tensor
                containing the sampled Pareto optimal sets of inputs.
            pareto_fronts: A `num_pareto_samples x num_pareto_points x M`-dim Tensor
                containing the sampled Pareto optimal sets of outputs.
            hypercell_bounds: A `num_pareto_samples x 2 x J x M`-dim Tensor
                containing the hyper-rectangle bounds for integration. In the
                unconstrained case, this gives the partition of the dominated space.
                In the constrained case, this gives the partition of the feasible
                dominated space union the infeasible space.
            X_pending: A `m x d`-dim Tensor of `m` design points that have been
                submitted for function evaluation, but have not yet been evaluated.
            estimation_type: A string to determine which entropy estimate is
                computed: "0", "LB", "LB2", or "MC".
            num_samples: The number of Monte Carlo samples used for the Monte Carlo
                estimate.
        """
        super().__init__(
            model=model,
            pareto_sets=pareto_sets,
            pareto_fronts=pareto_fronts,
            hypercell_bounds=hypercell_bounds,
            X_pending=X_pending,
            estimation_type=estimation_type,
            num_samples=num_samples,
        )

        # Condition the model on the sampled pareto optimal points.
        # TODO: Apparently, we need to make a call to the posterior otherwise
        # we run into a gpytorch runtime error:
        # "Fantasy observations can only be added after making predictions with a
        # model so that all test independent caches exist."
        with fantasize_flag():
            # The warm-up posterior call only populates caches; its result is
            # discarded.
            with settings.propagate_grads(False):
                _ = self.initial_model.posterior(
                    self.pareto_sets, observation_noise=False
                )
            # Condition with observation noise.
            self.conditional_model = self.initial_model.condition_on_observations(
                X=self.initial_model.transform_inputs(self.pareto_sets),
                Y=self.pareto_fronts,
            )

    def _compute_posterior_statistics(
        self, X: Tensor
    ) -> dict[str, Union[Tensor, GPyTorchPosterior]]:
        r"""Compute the posterior statistics.

        Args:
            X: A `batch_shape x q x d`-dim Tensor of inputs.

        Returns:
            A dictionary containing the posterior variables used to estimate the
            entropy.

            - "initial_entropy": A `batch_shape`-dim Tensor containing the entropy of
                the Gaussian random variable `p(Y| X, D_n)`.
            - "posterior_mean": A `batch_shape x num_pareto_samples x q x 1 x M`-dim
                Tensor containing the posterior mean at the input `X`.
            - "posterior_variance": A `batch_shape x num_pareto_samples x q x 1 x M`
                -dim Tensor containing the posterior variance at the input `X`
                excluding the observation noise.
            - "observation_noise": A `batch_shape x num_pareto_samples x q x 1 x M`
                -dim Tensor containing the observation noise at the input `X`.
            - "posterior_with_noise": The posterior distribution at `X` which
                includes the observation noise. This is used to compute the marginal
                log-probabilities with respect to `p(y| x, D_n)` for `x` in `X`.
        """
        tkwargs = {"dtype": X.dtype, "device": X.device}
        # Smallest positive increment of the dtype; used as a floor for variances.
        CLAMP_LB = torch.finfo(tkwargs["dtype"]).eps

        # Compute the prior entropy term depending on `X`.
        initial_posterior_plus_noise = self.initial_model.posterior(
            X, observation_noise=True
        )

        # Additional constant term.
        # NOTE(review): the constant is scaled by `M` while the covariance below
        # is `(q*M) x (q*M)`; for `q > 1` this differs from the joint Gaussian
        # entropy constant by an amount that does not depend on `X`. Confirm
        # whether this is intended.
        add_term = (
            0.5
            * self.model.num_outputs
            * (1 + torch.log(2 * pi * torch.ones(1, **tkwargs)))
        )
        # The variance initially has shape `batch_shape x (q*M) x (q*M)`
        # prior_entropy has shape `batch_shape`.
        initial_entropy = add_term + 0.5 * torch.logdet(
            initial_posterior_plus_noise.mvn.covariance_matrix
        )
        posterior_statistics = {"initial_entropy": initial_entropy}

        # Compute the posterior entropy term.
        # The two singleton dimensions are inserted so that every point of `X`
        # is evaluated on its own against each conditioned model.
        X_expanded = X.unsqueeze(-2).unsqueeze(-3)
        conditional_posterior_with_noise = self.conditional_model.posterior(
            X_expanded, observation_noise=True
        )

        # `batch_shape x num_pareto_samples x q x 1 x M`
        post_mean = conditional_posterior_with_noise.mean.swapaxes(-4, -3)
        post_var_with_noise = conditional_posterior_with_noise.variance.clamp_min(
            CLAMP_LB
        ).swapaxes(-4, -3)

        # TODO: This computes the observation noise via a second evaluation of the
        #   posterior. This step could be done better.
        conditional_posterior = self.conditional_model.posterior(
            X_expanded, observation_noise=False
        )

        # `batch_shape x num_pareto_samples x q x 1 x M`
        post_var = conditional_posterior.variance.clamp_min(CLAMP_LB).swapaxes(-4, -3)
        # Noise is recovered as the difference of the two variances, floored
        # so that it stays strictly positive.
        obs_noise = (post_var_with_noise - post_var).clamp_min(CLAMP_LB)

        posterior_statistics["posterior_mean"] = post_mean
        posterior_statistics["posterior_variance"] = post_var
        posterior_statistics["observation_noise"] = obs_noise
        posterior_statistics["posterior_with_noise"] = conditional_posterior_with_noise

        return posterior_statistics

    def _compute_monte_carlo_variables(
        self, posterior: GPyTorchPosterior
    ) -> Tuple[Tensor, Tensor]:
        r"""Compute the samples and log-probability associated with the posterior
        distribution that conditions on the Pareto optimal points.

        Args:
            posterior: The conditional posterior distribution at an input `X`, where
                we have also conditioned over the `num_pareto_samples` of optimal
                points. Note that this posterior includes the observation noise.

        Returns:
            A two-element tuple containing

            - samples: A `num_mc_samples x batch_shape x num_pareto_samples x q x 1
                x M`-dim Tensor containing the Monte Carlo samples.
            - samples_log_probs: A `num_mc_samples x batch_shape x num_pareto_samples
                x q`-dim Tensor containing the log-probabilities of the Monte Carlo
                samples.
        """
        # `num_mc_samples x batch_shape x q x num_pareto_samples x 1 x M`
        samples = self.get_posterior_samples(posterior)

        # `num_mc_samples x batch_shape x q x num_pareto_samples`
        # For a single output the trailing output dimension is dropped before
        # evaluating the log-density.
        if self.model.num_outputs == 1:
            samples_log_prob = posterior.mvn.log_prob(samples.squeeze(-1))
        else:
            samples_log_prob = posterior.mvn.log_prob(samples)

        # Swap axes to get the correct shape:
        # samples:`num_mc_samples x batch_shape x num_pareto_samples x q x 1 x M`
        # log prob:`num_mc_samples x batch_shape x num_pareto_samples x q`
        return samples.swapaxes(-4, -3), samples_log_prob.swapaxes(-2, -1)

    @concatenate_pending_points
    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluates qLowerBoundMultiObjectiveJointEntropySearch at the design
        points `X`.

        Args:
            X: A `batch_shape x q x d`-dim Tensor of `batch_shape` t-batches with `q`
                `d`-dim design points each.

        Returns:
            A `batch_shape`-dim Tensor of acquisition values at the given design
            points `X`.
        """
        return self._compute_lower_bound_information_gain(X)
def _compute_entropy_noiseless(
    hypercell_bounds: Tensor,
    mean: Tensor,
    variance: Tensor,
    observation_noise: Tensor,
) -> Tensor:
    r"""Computes the entropy estimate at the design points `X` assuming noiseless
    observations. This is used for the JES-0 and MES-0 estimate.

    The box probabilities are computed from the noise-free variance, whereas the
    log-variance term uses the variance including the observation noise.

    Args:
        hypercell_bounds: A `num_pareto_samples x 2 x J x M` -dim Tensor containing
            the box decomposition bounds, where `J = max(num_boxes)`.
        mean: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing
            the posterior mean at X.
        variance: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor
            containing the posterior variance at X excluding observation noise.
        observation_noise: A `batch_shape x num_pareto_samples x q x 1 x M`-dim
            Tensor containing the observation noise at X.

    Returns:
        A `batch_shape x q`-dim Tensor of entropy estimate at the given design points
        `X`.
    """
    tkwargs = {"dtype": hypercell_bounds.dtype, "device": hypercell_bounds.device}
    # Floor for box probabilities so that the logarithms below stay finite.
    CLAMP_LB = torch.finfo(tkwargs["dtype"]).eps

    variance_plus_noise = variance + observation_noise

    # Standardize the box decomposition bounds and compute normal quantities.
    # `batch_shape x num_pareto_samples x q x 2 x J x M`
    g = (hypercell_bounds.unsqueeze(-4) - mean.unsqueeze(-2)) / torch.sqrt(
        variance.unsqueeze(-2)
    )
    normal = Normal(torch.zeros_like(g), torch.ones_like(g))
    gcdf = normal.cdf(g)
    gpdf = torch.exp(normal.log_prob(g))
    g_times_gpdf = g * gpdf

    # Compute the differences between the upper and lower terms.
    # Index 1 / 0 along the `2` dimension are the upper / lower bounds.
    Wjm = (gcdf[..., 1, :, :] - gcdf[..., 0, :, :]).clamp_min(CLAMP_LB)
    Vjm = g_times_gpdf[..., 1, :, :] - g_times_gpdf[..., 0, :, :]

    # Compute W: the product over objectives is evaluated in log-space, then the
    # boxes are summed. The total probability is capped at one.
    Wj = torch.exp(torch.sum(torch.log(Wjm), dim=-1, keepdim=True))
    W = torch.sum(Wj, dim=-2, keepdim=True).clamp_max(1.0)

    # Compute the sum of ratios.
    ratios = 0.5 * (Wj * (Vjm / Wjm)) / W
    # `batch_shape x num_pareto_samples x q x 1 x 1`
    ratio_term = torch.sum(ratios, dim=(-2, -1), keepdim=True)

    # Compute the logarithm of the variance.
    log_term = 0.5 * torch.log(variance_plus_noise).sum(-1, keepdim=True)

    # `batch_shape x num_pareto_samples x q x 1 x 1`
    log_term = log_term + torch.log(W)

    # Additional constant term.
    M_plus_K = mean.shape[-1]
    add_term = 0.5 * M_plus_K * (1 + torch.log(torch.ones(1, **tkwargs) * 2 * pi))

    # `batch_shape x num_pareto_samples x q`
    entropy = add_term + (log_term - ratio_term).squeeze(-1).squeeze(-1)

    # Average over the Pareto samples.
    return entropy.mean(-2)
def _compute_entropy_upper_bound(
    hypercell_bounds: Tensor,
    mean: Tensor,
    variance: Tensor,
    observation_noise: Tensor,
    only_diagonal: bool = False,
) -> Tensor:
    r"""Computes the entropy upper bound at the design points `X`. This is used for
    the JES-LB and MES-LB estimate. If `only_diagonal` is True, then this computes
    the entropy estimate for the JES-LB2 and MES-LB2.

    The bound is the entropy of a Gaussian whose first two moments are computed
    from the box-truncated distribution: the full `M x M` covariance is used by
    default, and only the per-objective variances when `only_diagonal` is True.

    Args:
        hypercell_bounds: A `num_pareto_samples x 2 x J x M` -dim Tensor containing
            the box decomposition bounds, where `J` = max(num_boxes).
        mean: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing
            the posterior mean at X.
        variance: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor
            containing the posterior variance at X excluding observation noise.
        observation_noise: A `batch_shape x num_pareto_samples x q x 1 x M`-dim
            Tensor containing the observation noise at X.
        only_diagonal: If true, we only compute the diagonal elements of the variance.

    Returns:
        A `batch_shape x q`-dim Tensor of entropy estimate at the given design points
        `X`.
    """
    tkwargs = {"dtype": hypercell_bounds.dtype, "device": hypercell_bounds.device}
    # Floor for probabilities / variances so that the logarithms stay finite.
    CLAMP_LB = torch.finfo(tkwargs["dtype"]).eps

    variance_plus_noise = variance + observation_noise
    # `batch_shape x num_pareto_samples x q x 1 x M`
    std = torch.sqrt(variance)

    # Standardize the box decomposition bounds and compute normal quantities.
    # `batch_shape x num_pareto_samples x q x 2 x J x M`
    g = (hypercell_bounds.unsqueeze(-4) - mean.unsqueeze(-2)) / std.unsqueeze(-2)
    normal = Normal(torch.zeros_like(g), torch.ones_like(g))
    gcdf = normal.cdf(g)
    gpdf = torch.exp(normal.log_prob(g))
    g_times_gpdf = g * gpdf

    # Compute the differences between the upper and lower terms.
    # Index 1 / 0 along the `2` dimension are the upper / lower bounds.
    Wjm = (gcdf[..., 1, :, :] - gcdf[..., 0, :, :]).clamp_min(CLAMP_LB)
    Vjm = g_times_gpdf[..., 1, :, :] - g_times_gpdf[..., 0, :, :]
    Gjm = gpdf[..., 1, :, :] - gpdf[..., 0, :, :]

    # Compute W: product over objectives (in log-space), then sum over boxes.
    Wj = torch.exp(torch.sum(torch.log(Wjm), dim=-1, keepdim=True))
    W = torch.sum(Wj, dim=-2, keepdim=True).clamp_max(1.0)

    Cjm = Gjm / Wjm

    # First moment:
    Rjm = Cjm * Wj / W
    # `batch_shape x num_pareto_samples x q x 1 x M`
    mom1 = mean - std * Rjm.sum(-2, keepdim=True)

    # diagonal weighted sum
    # `batch_shape x num_pareto_samples x q x 1 x M`
    diag_weighted_sum = (Wj * variance * Vjm / Wjm / W).sum(-2, keepdim=True)

    if only_diagonal:
        # `batch_shape x num_pareto_samples x q x 1 x M`
        mean_squared = mean.pow(2)
        cross_sum = -2 * (mean * std * Rjm).sum(-2, keepdim=True)
        # `batch_shape x num_pareto_samples x q x 1 x M`
        mom2 = variance_plus_noise - diag_weighted_sum + cross_sum + mean_squared
        var = (mom2 - mom1.pow(2)).clamp_min(CLAMP_LB)

        # `batch_shape x num_pareto_samples x q`
        log_det_term = 0.5 * torch.log(var).sum(dim=-1).squeeze(-1)
    else:
        # `batch_shape x num_pareto_samples x q x J x M`
        scaled_Cjm = std * Cjm
        # Per-box weights broadcast against the trailing `M x M` dimensions.
        box_weight = Wj.unsqueeze(-1) / W.unsqueeze(-1)

        # First moment x First moment
        # `batch_shape x num_pareto_samples x q x 1 x M x M`
        cross_mom1 = torch.einsum("...i,...j->...ij", mom1, mom1)

        # Second moment:
        # `batch_shape x num_pareto_samples x q x 1 x M x M`
        # firstly compute the general terms
        mom2_cross1 = -torch.einsum("...i,...j->...ij", mean, scaled_Cjm)
        # Same outer product with the output indices transposed.
        mom2_cross2 = -torch.einsum("...i,...j->...ji", mean, scaled_Cjm)
        mom2_mean_squared = torch.einsum("...i,...j->...ij", mean, mean)

        mom2_weighted_sum = ((mom2_cross1 + mom2_cross2) * box_weight).sum(
            -3, keepdim=True
        )
        mom2_weighted_sum = mom2_weighted_sum + mom2_mean_squared

        # Compute the additional off-diagonal terms.
        mom2_off_diag = torch.einsum("...i,...j->...ij", scaled_Cjm, scaled_Cjm)
        mom2_off_diag_sum = (mom2_off_diag * box_weight).sum(-3, keepdim=True)

        # Compute the diagonal terms and subtract the diagonal computed before.
        init_diag = torch.diagonal(mom2_off_diag_sum, dim1=-2, dim2=-1)
        diag_weighted_sum = torch.diag_embed(
            variance_plus_noise - diag_weighted_sum - init_diag
        )
        mom2 = mom2_weighted_sum + mom2_off_diag_sum + diag_weighted_sum
        # Compute the variance
        var = (mom2 - cross_mom1).squeeze(-3)

        # Jitter the diagonal.
        # The jitter is probably not needed here at all.
        jitter_diag = 1e-6 * torch.diag_embed(torch.ones(var.shape[:-1], **tkwargs))
        log_det_term = 0.5 * torch.logdet(var + jitter_diag)

    # Additional terms.
    M_plus_K = mean.shape[-1]
    add_term = 0.5 * M_plus_K * (1 + torch.log(torch.ones(1, **tkwargs) * 2 * pi))

    # `batch_shape x num_pareto_samples x q`
    entropy = add_term + log_det_term

    # Average over the Pareto samples.
    return entropy.mean(-2)
def _compute_entropy_monte_carlo(
    hypercell_bounds: Tensor,
    mean: Tensor,
    variance: Tensor,
    observation_noise: Tensor,
    samples: Tensor,
    samples_log_prob: Tensor,
) -> Tensor:
    r"""Computes the Monte Carlo entropy at the design points `X`. This is used for
    the JES-MC and MES-MC estimate.

    Args:
        hypercell_bounds: A `num_pareto_samples x 2 x J x M`-dim Tensor containing
            the box decomposition bounds, where `J` = max(num_boxes).
        mean: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing
            the posterior mean at X.
        variance: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor
            containing the posterior variance at X excluding observation noise.
        observation_noise: A `batch_shape x num_pareto_samples x q x 1 x M`-dim
            Tensor containing the observation noise at X.
        samples: A `num_mc_samples x batch_shape x num_pareto_samples x q x 1 x M`-dim
            Tensor containing the noisy samples at `X` from the posterior conditioned
            on the Pareto optimal points.
        samples_log_prob: A `num_mc_samples x batch_shape x num_pareto_samples
            x q`-dim Tensor containing the log probability densities of the samples.

    Returns:
        A `batch_shape x q`-dim Tensor of entropy estimate at the given design points
        `X`.
    """
    tkwargs = {"dtype": hypercell_bounds.dtype, "device": hypercell_bounds.device}
    # Floor for probabilities so that the logarithms below stay finite.
    CLAMP_LB = torch.finfo(tkwargs["dtype"]).eps

    variance_plus_noise = variance + observation_noise

    ####################################################################
    # Standardize the box decomposition bounds and compute normal quantities.
    # `batch_shape x num_pareto_samples x q x 2 x J x M`
    g = (hypercell_bounds.unsqueeze(-4) - mean.unsqueeze(-2)) / torch.sqrt(
        variance.unsqueeze(-2)
    )
    # `batch_shape x num_pareto_samples x q x 1 x M`
    rho = torch.sqrt(variance / variance_plus_noise)

    # Compute the initial normal quantities.
    normal = Normal(torch.zeros_like(g), torch.ones_like(g))
    gcdf = normal.cdf(g)

    # Compute the differences between the upper and lower terms.
    Wjm = (gcdf[..., 1, :, :] - gcdf[..., 0, :, :]).clamp_min(CLAMP_LB)

    # Compute W.
    Wj = torch.exp(torch.sum(torch.log(Wjm), dim=-1, keepdim=True))
    # `batch_shape x num_pareto_samples x q x 1 x 1`
    W = torch.sum(Wj, dim=-2, keepdim=True).clamp_max(1.0)

    ####################################################################
    # Add a leading dimension so the quantities broadcast against the
    # Monte Carlo sample dimension.
    g = g.unsqueeze(0)
    rho = rho.unsqueeze(0).unsqueeze(-2)
    # `num_mc_samples x batch_shape x num_pareto_samples x q x 1 x 1 x M`
    z = ((samples - mean) / torch.sqrt(variance_plus_noise)).unsqueeze(-2)
    # `num_mc_samples x batch_shape x num_pareto_samples x q x 2 x J x M`
    # Clamping here is important because `1 - rho^2 = 0` at an input where
    # observation noise is zero.
    g_new = (g - rho * z) / torch.sqrt((1 - rho * rho).clamp_min(CLAMP_LB))

    # Compute the initial normal quantities.
    normal_new = Normal(torch.zeros_like(g_new), torch.ones_like(g_new))
    gcdf_new = normal_new.cdf(g_new)

    # Compute the differences between the upper and lower terms.
    Wjm_new = (gcdf_new[..., 1, :, :] - gcdf_new[..., 0, :, :]).clamp_min(CLAMP_LB)

    # Compute W+.
    Wj_new = torch.exp(torch.sum(torch.log(Wjm_new), dim=-1, keepdim=True))
    # `num_mc_samples x batch_shape x num_pareto_samples x q x 1 x 1`
    W_new = torch.sum(Wj_new, dim=-2, keepdim=True).clamp_max(1.0)

    ####################################################################
    # W_ratio = W+ / W, evaluated in log-space.
    W_ratio = torch.exp(torch.log(W_new) - torch.log(W).unsqueeze(0))
    samples_log_prob = samples_log_prob.unsqueeze(-1).unsqueeze(-1)

    # Compute the Monte Carlo average: - E[W_ratio * log(W+ p(y))] + log(W)
    # (An alternative estimate is - E[W_ratio * log(W_ratio p(y))], i.e. using
    # `log(W_ratio)` in place of `log(W+)` and dropping the `+ log(W)` term.)
    log_term = torch.log(W_new) + samples_log_prob
    mc_estimate = -(W_ratio * log_term).mean(0)
    # `batch_shape x num_pareto_samples x q`
    entropy = (mc_estimate + torch.log(W)).squeeze(-1).squeeze(-1)

    # Average over the Pareto samples.
    return entropy.mean(-2)