Source code for botorch.acquisition.multi_objective.joint_entropy_search

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Acquisition functions for joint entropy search for Bayesian optimization (JES).

References:

.. [Tu2022]
    B. Tu, A. Gandy, N. Kantas and B.Shafei. Joint Entropy Search for Multi-Objective
    Bayesian Optimization. Advances in Neural Information Processing Systems, 35.
    2022.

"""

from __future__ import annotations

from abc import abstractmethod
from math import pi

import torch
from botorch import settings
from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
from botorch.exceptions.errors import UnsupportedError
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.utils import fantasize as fantasize_flag
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from torch import Tensor
from torch.distributions import Normal


[docs] class LowerBoundMultiObjectiveEntropySearch(AcquisitionFunction, MCSamplerMixin): r"""Abstract base class for the lower bound multi-objective entropy search acquisition functions. """ def __init__( self, model: Model, pareto_sets: Tensor, pareto_fronts: Tensor, hypercell_bounds: Tensor, X_pending: Tensor | None = None, estimation_type: str = "LB", num_samples: int = 64, ) -> None: r"""Lower bound multi-objective entropy search acquisition function. Args: model: A fitted batch model with 'M' number of outputs. pareto_sets: A `num_pareto_samples x num_pareto_points x d`-dim Tensor containing the sampled Pareto optimal sets of inputs. pareto_fronts: A `num_pareto_samples x num_pareto_points x M`-dim Tensor containing the sampled Pareto optimal sets of outputs. hypercell_bounds: A `num_pareto_samples x 2 x J x M`-dim Tensor containing the hyper-rectangle bounds for integration, where `J` is the number of hyper-rectangles. In the unconstrained case, this gives the partition of the dominated space. In the constrained case, this gives the partition of the feasible dominated space union the infeasible space. X_pending: A `m x d`-dim Tensor of `m` design points that have been submitted for function evaluation, but have not yet been evaluated. estimation_type: A string to determine which entropy estimate is computed: "0", "LB", "LB2", or "MC". num_samples: The number of Monte Carlo samples for the Monte Carlo estimate. """ super().__init__(model=model) sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_samples])) MCSamplerMixin.__init__(self, sampler=sampler) # Batch GP models (e.g. fantasized models) are not currently supported if isinstance(model, ModelListGP): train_X = model.models[0].train_inputs[0] else: train_X = model.train_inputs[0] if (model.num_outputs > 1 and train_X.ndim > 3) or ( model.num_outputs == 1 and train_X.ndim > 2 ): raise NotImplementedError( "Batch GP models (e.g. fantasized models) are not supported." ) self.initial_model = model if (pareto_sets is not None and pareto_sets.ndim != 3) or ( pareto_fronts is not None and pareto_fronts.ndim != 3 ): raise UnsupportedError( "The Pareto set and front should have a shape of " "`num_pareto_samples x num_pareto_points x input_dim` and " "`num_pareto_samples x num_pareto_points x num_objectives`, " "respectively" ) else: self.pareto_sets = pareto_sets self.pareto_fronts = pareto_fronts if hypercell_bounds.ndim != 4: raise UnsupportedError( "The hypercell_bounds should have a shape of " "`num_pareto_samples x 2 x num_boxes x num_objectives`." ) else: self.hypercell_bounds = hypercell_bounds self.num_pareto_samples = hypercell_bounds.shape[0] self.estimation_type = estimation_type estimation_types = ["0", "LB", "LB2", "MC"] if estimation_type not in estimation_types: raise NotImplementedError( "Currently the only supported estimation type are: " + ", ".join(f'"{h}"' for h in estimation_types) + "." ) self.set_X_pending(X_pending) @abstractmethod def _compute_posterior_statistics( self, X: Tensor ) -> dict[str, GPyTorchPosterior | Tensor]: r"""Compute the posterior statistics. Args: X: A `batch_shape x q x d`-dim Tensor of inputs. Returns: A dictionary containing the posterior variables used to estimate the entropy. - "initial_entropy": A `batch_shape`-dim Tensor containing the entropy of the Gaussian random variable `p(Y| X, D_n)`. - "posterior_mean": A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the posterior mean at the input `X`. - "posterior_variance": A `batch_shape x num_pareto_samples x q x 1 x M` -dim Tensor containing the posterior variance at the input `X` excluding the observation noise. - "observation_noise": A `batch_shape x num_pareto_samples x q x 1 x M` -dim Tensor containing the observation noise at the input `X`. - "posterior_with_noise": The posterior distribution at `X` which includes the observation noise. This is used to compute the marginal log-probabilities with respect to `p(y| x, D_n)` for `x` in `X`. """ pass # pragma: no cover @abstractmethod def _compute_monte_carlo_variables( self, posterior: GPyTorchPosterior ) -> tuple[Tensor, Tensor]: r"""Compute the samples and log-probability associated with a posterior distribution. Args: posterior: A posterior distribution. Returns: A two-element tuple containing: - samples: A `num_mc_samples x batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the Monte Carlo samples. - samples_log_prob: A `num_mc_samples x batch_shape x num_pareto_samples x q`-dim Tensor containing the log-probabilities of the Monte Carlo samples. """ pass # pragma: no cover def _compute_lower_bound_information_gain(self, X: Tensor) -> Tensor: r"""Evaluates the lower bound information gain at the design points `X`. Args: X: A `batch_shape x q x d`-dim Tensor of `batch_shape` t-batches with `q` `d`-dim design points each. Returns: A `batch_shape`-dim Tensor of acquisition values at the given design points `X`. """ posterior_statistics = self._compute_posterior_statistics(X) initial_entropy = posterior_statistics["initial_entropy"] post_mean = posterior_statistics["posterior_mean"] post_var = posterior_statistics["posterior_variance"] obs_noise = posterior_statistics["observation_noise"] # Estimate the expected conditional entropy. # `batch_shape x q` dim Tensor of entropy estimates if self.estimation_type == "0": conditional_entropy = _compute_entropy_noiseless( hypercell_bounds=self.hypercell_bounds, mean=post_mean, variance=post_var, observation_noise=obs_noise, ) elif self.estimation_type == "LB": conditional_entropy = _compute_entropy_upper_bound( hypercell_bounds=self.hypercell_bounds, mean=post_mean, variance=post_var, observation_noise=obs_noise, only_diagonal=False, ) elif self.estimation_type == "LB2": conditional_entropy = _compute_entropy_upper_bound( hypercell_bounds=self.hypercell_bounds, mean=post_mean, variance=post_var, observation_noise=obs_noise, only_diagonal=True, ) elif self.estimation_type == "MC": posterior_with_noise = posterior_statistics["posterior_with_noise"] samples, samples_log_prob = self._compute_monte_carlo_variables( posterior_with_noise ) conditional_entropy = _compute_entropy_monte_carlo( hypercell_bounds=self.hypercell_bounds, mean=post_mean, variance=post_var, observation_noise=obs_noise, samples=samples, samples_log_prob=samples_log_prob, ) # Sum over the batch. return initial_entropy - conditional_entropy.sum(dim=-1)
[docs] @abstractmethod def forward(self, X: Tensor) -> Tensor: r"""Compute lower bound multi-objective entropy search at the design points `X`. Args: X: A `batch_shape x q x d`-dim Tensor of `batch_shape` t-batches with `q` `d`-dim design points each. Returns: A `batch_shape`-dim Tensor of acquisition values at the given design points `X`. """ pass # pragma: no cover
[docs] class qLowerBoundMultiObjectiveJointEntropySearch( LowerBoundMultiObjectiveEntropySearch ): r"""The acquisition function for the multi-objective joint entropy search, where the batches `q > 1` are supported through the lower bound formulation. This acquisition function computes the mutual information between the observation at a candidate point `X` and the Pareto optimal input-output pairs. See [Tu2022]_ for a discussion on the estimation procedure. NOTES: (i) The estimated acquisition value could be negative. (ii) The lower bound batch acquisition function might not be monotone in the sense that adding more elements to the batch does not necessarily increase the acquisition value. Specifically, the acquisition value can become smaller when more inputs are added. """ def __init__( self, model: Model, pareto_sets: Tensor, pareto_fronts: Tensor, hypercell_bounds: Tensor, X_pending: Tensor | None = None, estimation_type: str = "LB", num_samples: int = 64, ) -> None: r"""Lower bound multi-objective joint entropy search acquisition function. Args: model: A fitted batch model with 'M' number of outputs. pareto_sets: A `num_pareto_samples x num_pareto_points x d`-dim Tensor containing the sampled Pareto optimal sets of inputs. pareto_fronts: A `num_pareto_samples x num_pareto_points x M`-dim Tensor containing the sampled Pareto optimal sets of outputs. hypercell_bounds: A `num_pareto_samples x 2 x J x M`-dim Tensor containing the hyper-rectangle bounds for integration. In the unconstrained case, this gives the partition of the dominated space. In the constrained case, this gives the partition of the feasible dominated space union the infeasible space. X_pending: A `m x d`-dim Tensor of `m` design points that have been submitted for function evaluation, but have not yet been evaluated. estimation_type: A string to determine which entropy estimate is computed: "0", "LB", "LB2", or "MC". num_samples: The number of Monte Carlo samples used for the Monte Carlo estimate. """ super().__init__( model=model, pareto_sets=pareto_sets, pareto_fronts=pareto_fronts, hypercell_bounds=hypercell_bounds, X_pending=X_pending, estimation_type=estimation_type, num_samples=num_samples, ) # Condition the model on the sampled pareto optimal points. # TODO: Apparently, we need to make a call to the posterior otherwise # we run into a gpytorch runtime error: # "Fantasy observations can only be added after making predictions with a # model so that all test independent caches exist." with fantasize_flag(): with settings.propagate_grads(False): _ = self.initial_model.posterior( self.pareto_sets, observation_noise=False ) # Condition with observation noise. self.conditional_model = self.initial_model.condition_on_observations( X=self.initial_model.transform_inputs(self.pareto_sets), Y=self.pareto_fronts, ) def _compute_posterior_statistics( self, X: Tensor ) -> dict[str, Tensor | GPyTorchPosterior]: r"""Compute the posterior statistics. Args: X: A `batch_shape x q x d`-dim Tensor of inputs. Returns: A dictionary containing the posterior variables used to estimate the entropy. - "initial_entropy": A `batch_shape`-dim Tensor containing the entropy of the Gaussian random variable `p(Y| X, D_n)`. - "posterior_mean": A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the posterior mean at the input `X`. - "posterior_variance": A `batch_shape x num_pareto_samples x q x 1 x M` -dim Tensor containing the posterior variance at the input `X` excluding the observation noise. - "observation_noise": A `batch_shape x num_pareto_samples x q x 1 x M` -dim Tensor containing the observation noise at the input `X`. - "posterior_with_noise": The posterior distribution at `X` which includes the observation noise. This is used to compute the marginal log-probabilities with respect to `p(y| x, D_n)` for `x` in `X`. """ tkwargs = {"dtype": X.dtype, "device": X.device} CLAMP_LB = torch.finfo(tkwargs["dtype"]).eps # Compute the prior entropy term depending on `X`. initial_posterior_plus_noise = self.initial_model.posterior( X, observation_noise=True ) # Additional constant term. add_term = ( 0.5 * self.model.num_outputs * (1 + torch.log(2 * pi * torch.ones(1, **tkwargs))) ) # The variance initially has shape `batch_shape x (q*M) x (q*M)` # prior_entropy has shape `batch_shape`. initial_entropy = add_term + 0.5 * torch.logdet( initial_posterior_plus_noise.mvn.covariance_matrix ) posterior_statistics = {"initial_entropy": initial_entropy} # Compute the posterior entropy term. conditional_posterior_with_noise = self.conditional_model.posterior( X.unsqueeze(-2).unsqueeze(-3), observation_noise=True ) # `batch_shape x num_pareto_samples x q x 1 x M` post_mean = conditional_posterior_with_noise.mean.swapaxes(-4, -3) post_var_with_noise = conditional_posterior_with_noise.variance.clamp_min( CLAMP_LB ).swapaxes(-4, -3) # TODO: This computes the observation noise via a second evaluation of the # posterior. This step could be done better. conditional_posterior = self.conditional_model.posterior( X.unsqueeze(-2).unsqueeze(-3), observation_noise=False ) # `batch_shape x num_pareto_samples x q x 1 x M` post_var = conditional_posterior.variance.clamp_min(CLAMP_LB).swapaxes(-4, -3) obs_noise = (post_var_with_noise - post_var).clamp_min(CLAMP_LB) posterior_statistics["posterior_mean"] = post_mean posterior_statistics["posterior_variance"] = post_var posterior_statistics["observation_noise"] = obs_noise posterior_statistics["posterior_with_noise"] = conditional_posterior_with_noise return posterior_statistics def _compute_monte_carlo_variables( self, posterior: GPyTorchPosterior ) -> tuple[Tensor, Tensor]: r"""Compute the samples and log-probability associated with the posterior distribution that conditions on the Pareto optimal points. Args: posterior: The conditional posterior distribution at an input `X`, where we have also conditioned over the `num_pareto_samples` of optimal points. Note that this posterior includes the observation noise. Returns: A two-element tuple containing - samples: A `num_mc_samples x batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the Monte Carlo samples. - samples_log_probs: A `num_mc_samples x batch_shape x num_pareto_samples x q`-dim Tensor containing the log-probabilities of the Monte Carlo samples. """ # `num_mc_samples x batch_shape x q x num_pareto_samples x 1 x M` samples = self.get_posterior_samples(posterior) # `num_mc_samples x batch_shape x q x num_pareto_samples` if self.model.num_outputs == 1: samples_log_prob = posterior.mvn.log_prob(samples.squeeze(-1)) else: samples_log_prob = posterior.mvn.log_prob(samples) # Swap axes to get the correct shape: # samples:`num_mc_samples x batch_shape x num_pareto_samples x q x 1 x M` # log prob:`num_mc_samples x batch_shape x num_pareto_samples x q` return samples.swapaxes(-4, -3), samples_log_prob.swapaxes(-2, -1)
[docs] @concatenate_pending_points @t_batch_mode_transform() def forward(self, X: Tensor) -> Tensor: r"""Evaluates qLowerBoundMultiObjectiveJointEntropySearch at the design points `X`. Args: X: A `batch_shape x q x d`-dim Tensor of `batch_shape` t-batches with `q` `d`-dim design points each. Returns: A `batch_shape`-dim Tensor of acquisition values at the given design points `X`. """ return self._compute_lower_bound_information_gain(X)
def _compute_entropy_noiseless( hypercell_bounds: Tensor, mean: Tensor, variance: Tensor, observation_noise: Tensor, ) -> Tensor: r"""Computes the entropy estimate at the design points `X` assuming noiseless observations. This is used for the JES-0 and MES-0 estimate. Args: hypercell_bounds: A `num_pareto_samples x 2 x J x M` -dim Tensor containing the box decomposition bounds, where `J = max(num_boxes)`. mean: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the posterior mean at X. variance: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the posterior variance at X excluding observation noise. observation_noise: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the observation noise at X. Returns: A `batch_shape x q`-dim Tensor of entropy estimate at the given design points `X`. """ tkwargs = {"dtype": hypercell_bounds.dtype, "device": hypercell_bounds.device} CLAMP_LB = torch.finfo(tkwargs["dtype"]).eps variance_plus_noise = variance + observation_noise # Standardize the box decomposition bounds and compute normal quantities. # `batch_shape x num_pareto_samples x q x 2 x J x M` g = (hypercell_bounds.unsqueeze(-4) - mean.unsqueeze(-2)) / torch.sqrt( variance.unsqueeze(-2) ) normal = Normal(torch.zeros_like(g), torch.ones_like(g)) gcdf = normal.cdf(g) gpdf = torch.exp(normal.log_prob(g)) g_times_gpdf = g * gpdf # Compute the differences between the upper and lower terms. Wjm = (gcdf[..., 1, :, :] - gcdf[..., 0, :, :]).clamp_min(CLAMP_LB) Vjm = g_times_gpdf[..., 1, :, :] - g_times_gpdf[..., 0, :, :] # Compute W. Wj = torch.exp(torch.sum(torch.log(Wjm), dim=-1, keepdims=True)) W = torch.sum(Wj, dim=-2, keepdims=True).clamp_max(1.0) # Compute the sum of ratios. ratios = 0.5 * (Wj * (Vjm / Wjm)) / W # `batch_shape x num_pareto_samples x q x 1 x 1` ratio_term = torch.sum(ratios, dim=(-2, -1), keepdims=True) # Compute the logarithm of the variance. log_term = 0.5 * torch.log(variance_plus_noise).sum(-1, keepdims=True) # `batch_shape x num_pareto_samples x q x 1 x 1` log_term = log_term + torch.log(W) # Additional constant term. M_plus_K = mean.shape[-1] add_term = 0.5 * M_plus_K * (1 + torch.log(torch.ones(1, **tkwargs) * 2 * pi)) # `batch_shape x num_pareto_samples x q` entropy = add_term + (log_term - ratio_term).squeeze(-1).squeeze(-1) return entropy.mean(-2) def _compute_entropy_upper_bound( hypercell_bounds: Tensor, mean: Tensor, variance: Tensor, observation_noise: Tensor, only_diagonal: bool = False, ) -> Tensor: r"""Computes the entropy upper bound at the design points `X`. This is used for the JES-LB and MES-LB estimate. If `only_diagonal` is True, then this computes the entropy estimate for the JES-LB2 and MES-LB2. Args: hypercell_bounds: A `num_pareto_samples x 2 x J x M` -dim Tensor containing the box decomposition bounds, where `J` = max(num_boxes). mean: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the posterior mean at X. variance: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the posterior variance at X excluding observation noise. observation_noise: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the observation noise at X. only_diagonal: If true, we only compute the diagonal elements of the variance. Returns: A `batch_shape x q`-dim Tensor of entropy estimate at the given design points `X`. """ tkwargs = {"dtype": hypercell_bounds.dtype, "device": hypercell_bounds.device} CLAMP_LB = torch.finfo(tkwargs["dtype"]).eps variance_plus_noise = variance + observation_noise # Standardize the box decomposition bounds and compute normal quantities. # `batch_shape x num_pareto_samples x q x 2 x J x M` g = (hypercell_bounds.unsqueeze(-4) - mean.unsqueeze(-2)) / torch.sqrt( variance.unsqueeze(-2) ) normal = Normal(torch.zeros_like(g), torch.ones_like(g)) gcdf = normal.cdf(g) gpdf = torch.exp(normal.log_prob(g)) g_times_gpdf = g * gpdf # Compute the differences between the upper and lower terms. Wjm = (gcdf[..., 1, :, :] - gcdf[..., 0, :, :]).clamp_min(CLAMP_LB) Vjm = g_times_gpdf[..., 1, :, :] - g_times_gpdf[..., 0, :, :] Gjm = gpdf[..., 1, :, :] - gpdf[..., 0, :, :] # Compute W. Wj = torch.exp(torch.sum(torch.log(Wjm), dim=-1, keepdims=True)) W = torch.sum(Wj, dim=-2, keepdims=True).clamp_max(1.0) Cjm = Gjm / Wjm # First moment: Rjm = Cjm * Wj / W # `batch_shape x num_pareto_samples x q x 1 x M mom1 = mean - torch.sqrt(variance) * Rjm.sum(-2, keepdims=True) # diagonal weighted sum # `batch_shape x num_pareto_samples x q x 1 x M diag_weighted_sum = (Wj * variance * Vjm / Wjm / W).sum(-2, keepdims=True) if only_diagonal: # `batch_shape x num_pareto_samples x q x 1 x M` mean_squared = mean.pow(2) cross_sum = -2 * (mean * torch.sqrt(variance) * Rjm).sum(-2, keepdims=True) # `batch_shape x num_pareto_samples x q x 1 x M` mom2 = variance_plus_noise - diag_weighted_sum + cross_sum + mean_squared var = (mom2 - mom1.pow(2)).clamp_min(CLAMP_LB) # `batch_shape x num_pareto_samples x q log_det_term = 0.5 * torch.log(var).sum(dim=-1).squeeze(-1) else: # First moment x First moment # `batch_shape x num_pareto_samples x q x 1 x M x M cross_mom1 = torch.einsum("...i,...j->...ij", mom1, mom1) # Second moment: # `batch_shape x num_pareto_samples x q x 1 x M x M # firstly compute the general terms mom2_cross1 = -torch.einsum( "...i,...j->...ij", mean, torch.sqrt(variance) * Cjm ) mom2_cross2 = -torch.einsum( "...i,...j->...ji", mean, torch.sqrt(variance) * Cjm ) mom2_mean_squared = torch.einsum("...i,...j->...ij", mean, mean) mom2_weighted_sum = ( (mom2_cross1 + mom2_cross2) * Wj.unsqueeze(-1) / W.unsqueeze(-1) ).sum(-3, keepdims=True) mom2_weighted_sum = mom2_weighted_sum + mom2_mean_squared # Compute the additional off-diagonal terms. mom2_off_diag = torch.einsum( "...i,...j->...ij", torch.sqrt(variance) * Cjm, torch.sqrt(variance) * Cjm ) mom2_off_diag_sum = (mom2_off_diag * Wj.unsqueeze(-1) / W.unsqueeze(-1)).sum( -3, keepdims=True ) # Compute the diagonal terms and subtract the diagonal computed before. init_diag = torch.diagonal(mom2_off_diag_sum, dim1=-2, dim2=-1) diag_weighted_sum = torch.diag_embed( variance_plus_noise - diag_weighted_sum - init_diag ) mom2 = mom2_weighted_sum + mom2_off_diag_sum + diag_weighted_sum # Compute the variance var = (mom2 - cross_mom1).squeeze(-3) # Jitter the diagonal. # The jitter is probably not needed here at all. jitter_diag = 1e-6 * torch.diag_embed(torch.ones(var.shape[:-1], **tkwargs)) log_det_term = 0.5 * torch.logdet(var + jitter_diag) # Additional terms. M_plus_K = mean.shape[-1] add_term = 0.5 * M_plus_K * (1 + torch.log(torch.ones(1, **tkwargs) * 2 * pi)) # `batch_shape x num_pareto_samples x q entropy = add_term + log_det_term return entropy.mean(-2) def _compute_entropy_monte_carlo( hypercell_bounds: Tensor, mean: Tensor, variance: Tensor, observation_noise: Tensor, samples: Tensor, samples_log_prob: Tensor, ) -> Tensor: r"""Computes the Monte Carlo entropy at the design points `X`. This is used for the JES-MC and MES-MC estimate. Args: hypercell_bounds: A `num_pareto_samples x 2 x J x M`-dim Tensor containing the box decomposition bounds, where `J` = max(num_boxes). mean: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the posterior mean at X. variance: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the posterior variance at X excluding observation noise. observation_noise: A `batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the observation noise at X. samples: A `num_mc_samples x batch_shape x num_pareto_samples x q x 1 x M`-dim Tensor containing the noisy samples at `X` from the posterior conditioned on the Pareto optimal points. samples_log_prob: A `num_mc_samples x batch_shape x num_pareto_samples x q`-dim Tensor containing the log probability densities of the samples. Returns: A `batch_shape x q`-dim Tensor of entropy estimate at the given design points `X`. """ tkwargs = {"dtype": hypercell_bounds.dtype, "device": hypercell_bounds.device} CLAMP_LB = torch.finfo(tkwargs["dtype"]).eps variance_plus_noise = variance + observation_noise #################################################################### # Standardize the box decomposition bounds and compute normal quantities. # `batch_shape x num_pareto_samples x q x 2 x J x M` g = (hypercell_bounds.unsqueeze(-4) - mean.unsqueeze(-2)) / torch.sqrt( variance.unsqueeze(-2) ) # `batch_shape x num_pareto_samples x q x 1 x M` rho = torch.sqrt(variance / variance_plus_noise) # Compute the initial normal quantities. normal = Normal(torch.zeros_like(g), torch.ones_like(g)) gcdf = normal.cdf(g) # Compute the differences between the upper and lower terms. Wjm = (gcdf[..., 1, :, :] - gcdf[..., 0, :, :]).clamp_min(CLAMP_LB) # Compute W. Wj = torch.exp(torch.sum(torch.log(Wjm), dim=-1, keepdims=True)) # `batch_shape x num_pareto_samples x q x 1 x 1` W = torch.sum(Wj, dim=-2, keepdims=True).clamp_max(1.0) #################################################################### g = g.unsqueeze(0) rho = rho.unsqueeze(0).unsqueeze(-2) # `num_mc_samples x batch_shape x num_pareto_samples x q x 1 x 1 x M` z = ((samples - mean) / torch.sqrt(variance_plus_noise)).unsqueeze(-2) # `num_mc_samples x batch_shape x num_pareto_samples x q x 2 x J x M` # Clamping here is important because `1 - rho^2 = 0` at an input where # observation noise is zero. g_new = (g - rho * z) / torch.sqrt((1 - rho * rho).clamp_min(CLAMP_LB)) # Compute the initial normal quantities. normal_new = Normal(torch.zeros_like(g_new), torch.ones_like(g_new)) gcdf_new = normal_new.cdf(g_new) # Compute the differences between the upper and lower terms. Wjm_new = (gcdf_new[..., 1, :, :] - gcdf_new[..., 0, :, :]).clamp_min(CLAMP_LB) # Compute W+. Wj_new = torch.exp(torch.sum(torch.log(Wjm_new), dim=-1, keepdims=True)) # `num_mc_samples x batch_shape x num_pareto_samples x q x 1 x 1` W_new = torch.sum(Wj_new, dim=-2, keepdims=True).clamp_max(1.0) #################################################################### # W_ratio = W+ / W W_ratio = torch.exp(torch.log(W_new) - torch.log(W).unsqueeze(0)) samples_log_prob = samples_log_prob.unsqueeze(-1).unsqueeze(-1) # Compute the Monte Carlo average: - E[W_ratio * log(W+ p(y))] + log(W) log_term = torch.log(W_new) + samples_log_prob mc_estimate = -(W_ratio * log_term).mean(0) # `batch_shape x num_pareto_samples x q entropy = (mc_estimate + torch.log(W)).squeeze(-1).squeeze(-1) # An alternative Monte Carlo estimate: - E[W_ratio * log(W_ratio p(y))] # log_term = torch.log(W_ratio) + samples_log_prob # mc_estimate = - (W_ratio * log_term).mean(0) # # `batch_shape x num_pareto_samples x q # entropy = mc_estimate.squeeze(-1).squeeze(-1) return entropy.mean(-2)