Source code for botorch.sampling.list_sampler
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
A `ListSampler` for sampling from a `PosteriorList`.
"""
from __future__ import annotations
import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.posteriors.posterior_list import PosteriorList
from botorch.sampling.base import MCSampler
from torch import Tensor
from torch.nn import ModuleList
[docs]
class ListSampler(MCSampler):
    r"""Sampler that draws from a `PosteriorList` using one sampler per member.

    The wrapped samplers are paired positionally with the sub-posteriors of the
    `PosteriorList`, and the per-posterior samples are then concatenated by the
    posterior list itself.
    """

    def __init__(self, *samplers: MCSampler) -> None:
        r"""A list of samplers for sampling from a `PosteriorList`.

        Args:
            samplers: A variable number of samplers, one for each posterior in
                the `PosteriorList` that will be sampled. They must all share
                the same `sample_shape`.

        Raises:
            UnsupportedError: If the samplers do not all have the same sample
                shape.
        """
        # `super(MCSampler, self)` starts the MRO lookup *after* `MCSampler`,
        # so `MCSampler.__init__` is not run here; `sample_shape` is supplied
        # by the property below instead of being stored on the instance.
        super(MCSampler, self).__init__()
        self.samplers = ModuleList(samplers)
        self._validate_samplers()

    def _validate_samplers(self) -> None:
        r"""Ensure every wrapped sampler uses one common sample shape.

        Raises:
            UnsupportedError: If any sampler's `sample_shape` differs from the
                first sampler's.
        """
        shapes = [sampler.sample_shape for sampler in self.samplers]
        # `shapes[0]` is only touched once a second sampler is being compared,
        # so an empty sampler list passes this check.
        has_mismatch = any(shapes[0] != shape for shape in shapes[1:])
        if has_mismatch:
            raise UnsupportedError(
                "ListSampler requires all samplers to have the same sample shape."
            )

    @property
    def sample_shape(self) -> torch.Size:
        r"""The sample shape shared by all wrapped samplers.

        Validation runs on every access, so a mismatch introduced after
        construction is reported here as well.

        Raises:
            UnsupportedError: If the wrapped samplers disagree on sample shape.
        """
        self._validate_samplers()
        # NOTE(review): with no samplers, indexing the empty `ModuleList`
        # raises `IndexError` — confirm that is the intended failure mode.
        first_sampler = self.samplers[0]
        return first_sampler.sample_shape

    def forward(self, posterior: PosteriorList) -> Tensor:
        r"""Draw samples from every sub-posterior and concatenate them.

        Args:
            posterior: The `PosteriorList` to sample from. Its `posteriors`
                are matched positionally with `self.samplers`.

        Returns:
            The per-posterior samples, combined via
            `posterior._reshape_and_cat`.
        """
        samples_list = []
        # NOTE(review): `zip` stops at the shorter sequence, so a mismatch
        # between the number of samplers and sub-posteriors is not reported;
        # surplus entries on either side are silently ignored.
        for sampler, sub_posterior in zip(self.samplers, posterior.posteriors):
            samples_list.append(sampler(posterior=sub_posterior))
        return posterior._reshape_and_cat(tensors=samples_list)

    def _update_base_samples(
        self, posterior: PosteriorList, base_sampler: ListSampler
    ) -> None:
        r"""Reuse the base samples of `base_sampler` for each wrapped sampler.

        This is used in CachedCholeskyAcquisitionFunctions to keep the base
        samples for X_baseline consistent.

        Args:
            posterior: The `PosteriorList` for which the base samples are
                constructed; its sub-posteriors are matched positionally with
                the wrapped samplers.
            base_sampler: The `ListSampler` whose per-member samplers provide
                the base samples to reuse.
        """
        # Inherited check on `base_sampler` before any member is updated.
        self._instance_check(base_sampler=base_sampler)
        members = zip(self.samplers, posterior.posteriors, base_sampler.samplers)
        for own_sampler, sub_posterior, reference_sampler in members:
            own_sampler._update_base_samples(
                posterior=sub_posterior, base_sampler=reference_sampler
            )