# Source code for botorch.sampling.get_sampler
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import NoReturn

import torch
from botorch.logging import logger
from botorch.posteriors.ensemble import EnsemblePosterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from botorch.posteriors.posterior_list import PosteriorList
from botorch.posteriors.torch import TorchPosterior
from botorch.posteriors.transformed import TransformedPosterior
from botorch.sampling.base import MCSampler
from botorch.sampling.index_sampler import IndexSampler
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import (
    IIDNormalSampler,
    NormalMCSampler,
    SobolQMCNormalSampler,
)
from botorch.utils.dispatcher import Dispatcher
from gpytorch.distributions import MultivariateNormal
from torch.distributions import Distribution
from torch.quasirandom import SobolEngine
def _posterior_to_distribution_encoder(
    posterior: Posterior,
) -> type[Distribution] | type[Posterior]:
    r"""Compute the dispatch key used by `GetSampler` for a posterior.

    For a `TorchPosterior` the key is the type of its wrapped
    `distribution`; for every other posterior it is the posterior's own type.

    Args:
        posterior: The posterior whose sampler is being looked up.

    Returns:
        The class to dispatch on.
    """
    # Torch-backed posteriors are keyed by their distribution so that, e.g.,
    # any posterior wrapping a `MultivariateNormal` resolves to the same handler.
    keyed_object = (
        posterior.distribution
        if isinstance(posterior, TorchPosterior)
        else posterior
    )
    return type(keyed_object)
# Dispatcher backing `get_sampler`. Handlers are attached with
# `@GetSampler.register(<type>)`, and the lookup key for a posterior is
# produced by `_posterior_to_distribution_encoder`.
GetSampler = Dispatcher(
    "get_sampler",
    encoder=_posterior_to_distribution_encoder,
)
# [docs]
def get_sampler(
    posterior: Posterior,
    sample_shape: torch.Size,
    *,
    seed: int | None = None,
) -> MCSampler:
    r"""Get the sampler for the given posterior.

    The sampler can be used as `sampler(posterior)` to produce samples
    suitable for use in acquisition function optimization via SAA.

    The concrete sampler is selected by the `GetSampler` dispatcher. The
    dispatch key is the type of the posterior, or, for a `TorchPosterior`,
    the type of its underlying `distribution`.

    Args:
        posterior: A `Posterior` to get the sampler for. This is not
            restricted to `TorchPosterior`; e.g. `PosteriorList`,
            `TransformedPosterior` and `EnsemblePosterior` have handlers.
        sample_shape: The sample shape of the samples produced by the
            given sampler. The full shape of the resulting samples is
            given by `posterior._extended_shape(sample_shape)`.
        seed: Seed used to initialize sampler.

    Returns:
        The `MCSampler` object for the given posterior.

    Raises:
        NotImplementedError: If no sampler is registered for the posterior's
            dispatch key.
    """
    return GetSampler(posterior, sample_shape=sample_shape, seed=seed)
@GetSampler.register(MultivariateNormal)
def _get_sampler_mvn(
    posterior: GPyTorchPosterior,
    sample_shape: torch.Size,
    *,
    seed: int | None = None,
) -> NormalMCSampler:
    r"""Pick a normal sampler for a posterior backed by a `MultivariateNormal`.

    A `SobolQMCNormalSampler` is returned by default. If the collapsed base
    sample shape, with the leading `sample_shape` dims removed, has more
    elements than `SobolEngine.MAXDIM`, a warning is logged and an
    `IIDNormalSampler` is returned instead.
    """
    qmc_sampler = SobolQMCNormalSampler(sample_shape=sample_shape, seed=seed)
    # Strip the leading sample dims; the remaining element count is what is
    # compared against the Sobol engine's dimension limit.
    n_sample_dims = len(sample_shape)
    base_dims = qmc_sampler._get_collapsed_shape(posterior=posterior)[n_sample_dims:]
    output_dim = base_dims.numel()
    if output_dim <= SobolEngine.MAXDIM:
        return qmc_sampler
    logger.warning(
        f"Output dim {output_dim} is too large for the "
        "Sobol engine. Using IIDNormalSampler instead."
    )
    return IIDNormalSampler(sample_shape=sample_shape, seed=seed)
@GetSampler.register(TransformedPosterior)
def _get_sampler_derived(
    posterior: TransformedPosterior,
    sample_shape: torch.Size,
    *,
    seed: int | None = None,
) -> MCSampler:
    r"""Get the sampler for a `TransformedPosterior`.

    The lookup is delegated to the wrapped posterior (`posterior._posterior`),
    so the transformed posterior uses whatever sampler its base posterior
    would get.
    """
    wrapped_posterior = posterior._posterior
    return get_sampler(wrapped_posterior, sample_shape, seed=seed)
@GetSampler.register(PosteriorList)
def _get_sampler_list(
    posterior: PosteriorList,
    sample_shape: torch.Size,
    *,
    seed: int | None = None,
) -> MCSampler:
    r"""Get a `ListSampler` holding one sampler per sub-posterior.

    Each sub-sampler is obtained through `get_sampler` with the same
    `sample_shape` and `seed`, in the order of `posterior.posteriors`.
    """
    # NOTE(review): every sub-sampler receives the identical `seed`; confirm
    # that sharing base-sample seeds across sub-posteriors is intended.
    per_posterior_samplers = []
    for sub_posterior in posterior.posteriors:
        per_posterior_samplers.append(
            get_sampler(sub_posterior, sample_shape, seed=seed)
        )
    return ListSampler(*per_posterior_samplers)
@GetSampler.register(EnsemblePosterior)
def _get_sampler_ensemble(
    posterior: EnsemblePosterior,
    sample_shape: torch.Size,
    seed: int | None = None,
) -> MCSampler:
    r"""Get an `IndexSampler` for an `EnsemblePosterior`.

    The sampler is configured only from `sample_shape` and `seed`; the
    `posterior` argument is used for dispatch and not read here.
    """
    index_sampler = IndexSampler(sample_shape=sample_shape, seed=seed)
    return index_sampler
@GetSampler.register(object)
def _not_found_error(
    posterior: Posterior,
    sample_shape: torch.Size,
    seed: int | None = None,
) -> NoReturn:
    r"""Fallback handler for posteriors without a registered sampler.

    Registered for `object`, so it is reached when no more specific
    registration matches the posterior's dispatch key.

    Args:
        posterior: The posterior for which no sampler was found.
        sample_shape: Unused; present to match the handler signature.
        seed: Unused; present to match the handler signature.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        f"A registered `MCSampler` for posterior {posterior} is not found. You can "
        "implement and register one using `@GetSampler.register`."
    )