Source code for botorch.sampling.base
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
The base class for sampler modules to be used with MC-evaluated acquisition functions.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Tuple
import torch
from botorch.exceptions.errors import InputDataError
from botorch.posteriors import Posterior
from torch import Tensor
from torch.nn import Module
KWARGS_DEPRECATED_MSG = (
"The {} argument of `MCSampler`s has been deprecated and will raise an "
"error in a future version."
)
KWARG_ERR_MSG = (
"`MCSampler`s no longer support the `{}` argument. "
"Consider using `{}` for similar functionality."
)
[docs]
class MCSampler(Module, ABC):
r"""Abstract base class for Samplers.
Subclasses must implement the `forward` method.
Example:
This method is usually not called directly, but via the sampler's
`__call__` method:
>>> posterior = model.posterior(test_X)
>>> samples = sampler(posterior)
"""
def __init__(
self,
sample_shape: torch.Size,
seed: Optional[int] = None,
) -> None:
r"""Abstract base class for samplers.
Args:
sample_shape: The `sample_shape` of the samples to generate. The full shape
of the samples is given by `posterior._extended_shape(sample_shape)`.
seed: An optional seed to use for sampling.
"""
super().__init__()
if not isinstance(sample_shape, torch.Size):
raise InputDataError(
"Expected `sample_shape` to be a `torch.Size` object, "
f"got {sample_shape}."
)
self.sample_shape = sample_shape
self.seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()
self.register_buffer("base_samples", None)
[docs]
@abstractmethod
def forward(self, posterior: Posterior) -> Tensor:
r"""Draws MC samples from the posterior.
Args:
posterior: The posterior to sample from.
Returns:
The samples drawn from the posterior.
"""
pass # pragma no cover
def _get_batch_range(self, posterior: Posterior) -> Tuple[int, int]:
r"""Get the t-batch range of the posterior with an optional override.
In rare cases, e.g., in `qMultiStepLookahead`, we may want to override the
`batch_range` of the posterior. If this behavior is desired, one can set
`batch_range_override` attribute on the samplers.
Args:
posterior: The posterior to sample from.
Returns:
The t-batch range to use for collapsing the base samples.
"""
if hasattr(self, "batch_range_override"):
return self.batch_range_override
return posterior.batch_range
def _get_collapsed_shape(self, posterior: Posterior) -> torch.Size:
r"""Get the shape of the base samples with the t-batches collapsed.
Args:
posterior: The posterior to sample from.
Returns:
The collapsed shape of the base samples expected by the posterior. The
t-batch dimensions of the base samples are collapsed to size 1. This is
useful to prevent sampling variance across t-batches.
"""
base_sample_shape = posterior.base_sample_shape
batch_start, batch_end = self._get_batch_range(posterior)
base_sample_shape = (
base_sample_shape[:batch_start]
+ torch.Size([1 for _ in base_sample_shape[batch_start:batch_end]])
+ base_sample_shape[batch_end:]
)
return self.sample_shape + base_sample_shape
def _get_extended_base_sample_shape(self, posterior: Posterior) -> torch.Size:
r"""Get the shape of the base samples expected by the posterior.
Args:
posterior: The posterior to sample from.
Returns:
The extended shape of the base samples expected by the posterior.
"""
return self.sample_shape + posterior.base_sample_shape
def _update_base_samples(
self, posterior: Posterior, base_sampler: MCSampler
) -> None:
r"""Update the sampler to use the original base samples for X_baseline.
This is used in CachedCholeskyAcquisitionFunctions to ensure consistency.
Args:
posterior: The posterior for which the base samples are constructed.
base_sampler: The base sampler to retrieve the base samples from.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `_update_base_samples`."
)
def _instance_check(self, base_sampler):
r"""Check that `base_sampler` is an instance of `self.__class__`."""
if not isinstance(base_sampler, self.__class__):
raise RuntimeError(
"Expected `base_sampler` to be an instance of "
f"{self.__class__.__name__}. Got {base_sampler}."
)