Source code for botorch.sampling.get_sampler

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torch
from botorch.logging import logger
from botorch.posteriors.ensemble import EnsemblePosterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from botorch.posteriors.posterior_list import PosteriorList
from botorch.posteriors.torch import TorchPosterior
from botorch.posteriors.transformed import TransformedPosterior
from botorch.sampling.base import MCSampler
from botorch.sampling.index_sampler import IndexSampler
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import (
    IIDNormalSampler,
    NormalMCSampler,
    SobolQMCNormalSampler,
)
from botorch.utils.dispatcher import Dispatcher
from gpytorch.distributions import MultivariateNormal
from torch.distributions import Distribution
from torch.quasirandom import SobolEngine


def _posterior_to_distribution_encoder(
    posterior: Posterior,
) -> type[Distribution] | type[Posterior]:
    r"""Compute the dispatch key used by `GetSampler` for a posterior.

    A `TorchPosterior` is keyed by the type of its wrapped `distribution`
    attribute; any other posterior is keyed by its own type.

    Args:
        posterior: The posterior whose dispatch key is requested.

    Returns:
        The class object used to look up the registered sampler factory.
    """
    keyed_object = (
        posterior.distribution if isinstance(posterior, TorchPosterior) else posterior
    )
    return type(keyed_object)


# Dispatcher behind `get_sampler`. Arguments are encoded with
# `_posterior_to_distribution_encoder` before lookup, so handlers registered
# via `@GetSampler.register(...)` may key on either a distribution type (for
# `TorchPosterior` inputs) or a posterior type (for everything else).
GetSampler = Dispatcher("get_sampler", encoder=_posterior_to_distribution_encoder)


def get_sampler(
    posterior: Posterior,
    sample_shape: torch.Size,
    *,
    seed: int | None = None,
) -> MCSampler:
    r"""Get the sampler for the given posterior.

    The sampler can be used as `sampler(posterior)` to produce samples
    suitable for use in acquisition function optimization via SAA.

    Dispatch goes through `GetSampler`: for a `TorchPosterior` the lookup is
    keyed on the type of its `distribution`, otherwise on the posterior's
    own type. Any `Posterior` is accepted here (not only `TorchPosterior`),
    since handlers exist for e.g. `PosteriorList`, `TransformedPosterior`
    and `EnsemblePosterior`.

    Args:
        posterior: A `Posterior` to get the sampler for.
        sample_shape: The sample shape of the samples produced by the given
            sampler. The full shape of the resulting samples is given by
            `posterior._extended_shape(sample_shape)`.
        seed: Seed used to initialize sampler.

    Returns:
        The `MCSampler` object for the given posterior.

    Raises:
        NotImplementedError: If no sampler is registered for the posterior's
            dispatch key (raised by the catch-all handler on `object`).
    """
    return GetSampler(posterior, sample_shape=sample_shape, seed=seed)
@GetSampler.register(MultivariateNormal)
def _get_sampler_mvn(
    posterior: GPyTorchPosterior,
    sample_shape: torch.Size,
    *,
    seed: int | None = None,
) -> NormalMCSampler:
    r"""Return a Sobol QMC normal sampler for a `MultivariateNormal` posterior.

    The base-sample dimensionality is the collapsed shape with the leading
    `sample_shape` dims removed. If it exceeds `SobolEngine.MAXDIM`, a warning
    is logged and an `IIDNormalSampler` (same `sample_shape`/`seed`) is
    returned instead.
    """
    sobol_sampler = SobolQMCNormalSampler(sample_shape=sample_shape, seed=seed)
    full_shape = sobol_sampler._get_collapsed_shape(posterior=posterior)
    # Only the trailing (non-sample) dims are checked against the Sobol limit.
    n_dims = full_shape[len(sample_shape) :].numel()
    if n_dims <= SobolEngine.MAXDIM:
        return sobol_sampler
    logger.warning(
        f"Output dim {n_dims} is too large for the "
        "Sobol engine. Using IIDNormalSampler instead."
    )
    return IIDNormalSampler(sample_shape=sample_shape, seed=seed)


@GetSampler.register(TransformedPosterior)
def _get_sampler_derived(
    posterior: TransformedPosterior,
    sample_shape: torch.Size,
    *,
    seed: int | None = None,
) -> MCSampler:
    r"""Delegate to `get_sampler` on the posterior wrapped in `_posterior`."""
    wrapped = posterior._posterior
    return get_sampler(posterior=wrapped, sample_shape=sample_shape, seed=seed)


@GetSampler.register(PosteriorList)
def _get_sampler_list(
    posterior: PosteriorList, sample_shape: torch.Size, *, seed: int | None = None
) -> MCSampler:
    r"""Wrap one sampler per entry of `posterior.posteriors` in a `ListSampler`.

    Each sub-sampler comes from `get_sampler` with the shared `sample_shape`
    and `seed`, in the same order as the sub-posteriors.
    """
    return ListSampler(
        *(
            get_sampler(posterior=sub_posterior, sample_shape=sample_shape, seed=seed)
            for sub_posterior in posterior.posteriors
        )
    )


@GetSampler.register(EnsemblePosterior)
def _get_sampler_ensemble(
    posterior: EnsemblePosterior,
    sample_shape: torch.Size,
    seed: int | None = None,
) -> MCSampler:
    r"""Return an `IndexSampler` built from `sample_shape` and `seed` only."""
    return IndexSampler(sample_shape=sample_shape, seed=seed)


@GetSampler.register(object)
def _not_found_error(
    posterior: Posterior,
    sample_shape: torch.Size,
    seed: int | None = None,
) -> None:
    r"""Catch-all handler registered on `object`; always raises.

    Presumably reached only when no more specific registration matches the
    encoded posterior type (depends on the dispatcher's resolution order).

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = (
        f"A registered `MCSampler` for posterior {posterior} is not found. You can "
        "implement and register one using `@GetSampler.register`."
    )
    raise NotImplementedError(message)