Source code for botorch.sampling.list_sampler

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
A `SamplerList` for sampling from a `PosteriorList`.
"""

from __future__ import annotations

import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.posteriors.posterior_list import PosteriorList
from botorch.sampling.base import MCSampler
from torch import Tensor
from torch.nn import ModuleList


[docs] class ListSampler(MCSampler): def __init__(self, *samplers: MCSampler) -> None: r"""A list of samplers for sampling from a `PosteriorList`. Args: samplers: A variable number of samplers. This should include a sampler for each posterior. """ super(MCSampler, self).__init__() self.samplers = ModuleList(samplers) self._validate_samplers() def _validate_samplers(self) -> None: r"""Checks that the samplers share the same sample shape.""" sample_shapes = [s.sample_shape for s in self.samplers] if not all(sample_shapes[0] == ss for ss in sample_shapes): raise UnsupportedError( "ListSampler requires all samplers to have the same sample shape." ) @property def sample_shape(self) -> torch.Size: r"""The sample shape of the underlying samplers.""" self._validate_samplers() return self.samplers[0].sample_shape
[docs] def forward(self, posterior: PosteriorList) -> Tensor: r"""Samples from the posteriors and concatenates the samples. Args: posterior: A `PosteriorList` to sample from. Returns: The samples drawn from the posterior. """ samples_list = [ s(posterior=p) for s, p in zip(self.samplers, posterior.posteriors) ] return posterior._reshape_and_cat(tensors=samples_list)
def _update_base_samples( self, posterior: PosteriorList, base_sampler: ListSampler ) -> None: r"""Update the sampler to use the original base samples for X_baseline. This is used in CachedCholeskyAcquisitionFunctions to ensure consistency. Args: posterior: The posterior for which the base samples are constructed. base_sampler: The base sampler to retrieve the base samples from. """ self._instance_check(base_sampler=base_sampler) for s, p, bs in zip(self.samplers, posterior.posteriors, base_sampler.samplers): s._update_base_samples(posterior=p, base_sampler=bs)