# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Acquisition functions for Bayesian active learning. This includes:
BALD [Houlsby2011bald]_ and its batch version [kirsch2019batchbald]_.
References
.. [kirsch2019batchbald]
Andreas Kirsch, Joost van Amersfoort, Yarin Gal.
BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian
Active Learning.
In Proceedings of the Annual Conference on Neural Information
Processing Systems (NeurIPS), 2019.
"""
from __future__ import annotations
import warnings
from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
from botorch.acquisition.objective import PosteriorTransform
from botorch.models import ModelListGP
from botorch.models.fully_bayesian import MCMC_DIM, SaasFullyBayesianSingleTaskGP
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.utils.transforms import (
    concatenate_pending_points,
    is_fully_bayesian,
    t_batch_mode_transform,
)
from gpytorch.distributions.multitask_multivariate_normal import (
    MultitaskMultivariateNormal,
)
from torch import Tensor
FULLY_BAYESIAN_ERROR_MSG = (
    "Fully Bayesian acquisition functions require a SaasFullyBayesianSingleTaskGP "
    "or a ModelListGP of SaasFullyBayesianSingleTaskGPs to run."
)

NEGATIVE_INFOGAIN_WARNING = (
    "Information gain is negative. This is likely due to a poor Monte Carlo "
    "estimation of the entropies, or extremely high or extremely low correlation "
    "in the data."  # both of these cases result in (near-)zero information gain
)
def check_negative_info_gain(info_gain: Tensor) -> None:
    r"""Check if the (expected) information gain is negative, and warn if so."""
    if info_gain.lt(0).any():
        warnings.warn(NEGATIVE_INFOGAIN_WARNING, RuntimeWarning, stacklevel=2)
class FullyBayesianAcquisitionFunction(AcquisitionFunction):
    def __init__(self, model: Model):
        """Base class for acquisition functions that require a fully Bayesian
        model treatment.

        Args:
            model: A fully Bayesian single-outcome model.
        """
        if is_fully_bayesian(model):
            super().__init__(model)
        else:
            raise RuntimeError(FULLY_BAYESIAN_ERROR_MSG)
class qBayesianActiveLearningByDisagreement(
    FullyBayesianAcquisitionFunction, MCSamplerMixin
):
    def __init__(
        self,
        model: ModelListGP | SaasFullyBayesianSingleTaskGP,
        sampler: MCSampler | None = None,
        posterior_transform: PosteriorTransform | None = None,
        X_pending: Tensor | None = None,
    ) -> None:
        """
        Batch implementation [kirsch2019batchbald]_ of BALD [Houlsby2011bald]_,
        which maximizes the mutual information between the next observation and
        the hyperparameters of the model, computed by Monte Carlo integration.

        Args:
            model: A fully Bayesian model (SaasFullyBayesianSingleTaskGP).
            sampler: The sampler used for drawing samples to approximate the
                entropy of the Gaussian mixture posterior.
            posterior_transform: A PosteriorTransform. If using a multi-output
                model, a PosteriorTransform that transforms the multi-output
                posterior into a single-output posterior is required.
            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
                that have been submitted for evaluation but not yet evaluated.
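
        Example:
            A minimal usage sketch. It assumes user-provided `train_X`,
            `train_Y`, and `bounds` tensors; the optimizer settings below are
            illustrative, not prescriptive.

            >>> from botorch.fit import fit_fully_bayesian_model_nuts
            >>> from botorch.optim import optimize_acqf
            >>> model = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
            >>> fit_fully_bayesian_model_nuts(model)
            >>> qBALD = qBayesianActiveLearningByDisagreement(model)
            >>> candidate, acq_value = optimize_acqf(
            ...     qBALD, bounds=bounds, q=1, num_restarts=10, raw_samples=512
            ... )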
"""
super().__init__(model=model)
MCSamplerMixin.__init__(self, sampler=sampler)
self.set_X_pending(X_pending)
self.posterior_transform = posterior_transform
    @concatenate_pending_points
    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate qBayesianActiveLearningByDisagreement on the candidate set `X`.

        A Monte Carlo-estimated information gain is computed over the Gaussian
        mixture marginal posterior and the Gaussian conditional posteriors to
        obtain the BALD value on the candidate set `X`.
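
        For exposition (this mirrors the computation below), the estimator is

        .. math::
            \hat{H}\Big[\frac{1}{M} \sum_{m=1}^{M} p(y \mid x, \theta_m)\Big]
            - \frac{1}{M} \sum_{m=1}^{M} \hat{H}\big[p(y \mid x, \theta_m)\big],

        where :math:`\theta_1, \dots, \theta_M` are the hyperparameter draws of
        the fully Bayesian model and :math:`\hat{H}` denotes a Monte Carlo
        entropy estimate.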

        Args:
            X: `batch_shape x q x D`-dim Tensor of input points.

        Returns:
            A `batch_shape x 1`-dim Tensor of BALD values.
        """
        posterior = self.model.posterior(
            X, observation_noise=True, posterior_transform=self.posterior_transform
        )
        if isinstance(posterior.mvn, MultitaskMultivariateNormal):
            # The default MultitaskMultivariateNormal conversion for
            # GaussianMixturePosteriors does not interleave, which puts the task
            # and data covariances in an unintended order: inter-task
            # block-diagonal rather than inter-data block-diagonal, the latter
            # being the default for GaussianMixturePosteriors.
            posterior.mvn._interleaved = True
        # Draw samples from the mixture posterior.
        # samples: num_samples x batch_shape x num_models x q x num_outputs
        samples = self.get_posterior_samples(posterior=posterior)
        # Estimate the entropy of `num_samples` samples from `num_models` models
        # by evaluating the log_prob of each sample under the mixture posterior
        # (which consists of M models), i.e., on the order of N * M^2 evaluations.
        # Make room and move the model dim to the front, squeeze the num_outputs
        # dim. prev_samples: num_models x num_samples x batch_shape x 1 x q
        prev_samples = samples.unsqueeze(0).transpose(0, MCMC_DIM).squeeze(-1)
        # Average the probabilities over the models in the mixture - dim (-2) is
        # broadcast with the num_models dim of the posterior, so all samples are
        # queried on all models. posterior.mvn takes q-dimensional input by
        # default, which removes the q-dim.
        # component_sample_probs: num_models x num_samples x batch_shape x num_models
        component_sample_probs = posterior.mvn.log_prob(prev_samples).exp()
        # Average over the mixture components.
        mixture_sample_probs = component_sample_probs.mean(dim=-1, keepdim=True)
        # The previous (marginal) entropy is the average over the model and
        # sample dims.
        prev_entropy = -mixture_sample_probs.log().mean(dim=[0, 1])
        # The posterior entropy is an average of entropies over Gaussians, so no
        # mixture is involved.
        post_entropy = -posterior.mvn.log_prob(samples.squeeze(-1)).mean(0)
        # The BALD acquisition is defined as an expectation over a fully Bayesian
        # model, so the mean over models is computed here rather than outside of
        # the forward pass.
        bald = (prev_entropy - post_entropy).mean(-1, keepdim=True)
        check_negative_info_gain(bald)
        return bald
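

if __name__ == "__main__":
    # A self-contained sketch (illustrative only, not part of the BoTorch API)
    # of the Monte Carlo entropy estimation used in `forward` above, on a toy
    # 1-d Gaussian mixture: draw samples from each of `num_models` components,
    # score every sample under every component, and compare the mixture entropy
    # to the average per-component entropy. All names and settings here are
    # assumptions chosen for the demo.
    import torch
    from torch.distributions import Normal

    num_models, num_samples = 4, 4096
    components = Normal(torch.randn(num_models), torch.rand(num_models) + 0.5)
    # samples: num_samples x num_models (one column per component)
    samples = components.sample(torch.Size([num_samples]))
    # Score all samples under all components by broadcasting.
    # log_probs: num_samples x num_models x num_models
    log_probs = components.log_prob(samples.unsqueeze(-1))
    # Mixture entropy estimate: average the component probabilities, then -E[log p].
    mixture_entropy = -log_probs.exp().mean(dim=-1).log().mean()
    # Conditional entropy estimate: each sample scored under its own component.
    cond_entropy = -components.log_prob(samples).mean()
    info_gain = mixture_entropy - cond_entropy  # BALD-style information gain
    check_negative_info_gain(info_gain)
    print(f"estimated information gain: {info_gain.item():.4f}")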