# Source code for botorch.sampling.base

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
The base class for sampler modules to be used with MC-evaluated acquisition functions.
"""

from __future__ import annotations

from abc import ABC, abstractmethod

import torch
from botorch.exceptions.errors import InputDataError
from botorch.posteriors import Posterior
from torch import Tensor
from torch.nn import Module


# Warning template for sampler keyword arguments that are deprecated but still
# accepted. Format with the name of the deprecated argument.
KWARGS_DEPRECATED_MSG = (
    "The {} argument of `MCSampler`s has been deprecated "
    "and will raise an error in a future version."
)
# Error template for sampler keyword arguments that are no longer accepted.
# Format with (removed argument name, suggested replacement).
KWARG_ERR_MSG = (
    "`MCSampler`s no longer support "
    "the `{}` argument. Consider using `{}` "
    "for similar functionality."
)


class MCSampler(Module, ABC):
    r"""Abstract base class for Samplers.

    Subclasses must implement the `forward` method.

    Example:
        Samplers are usually not invoked through `forward` directly, but via
        the sampler's `__call__` method:

        >>> posterior = model.posterior(test_X)
        >>> samples = sampler(posterior)
    """

    def __init__(
        self,
        sample_shape: torch.Size,
        seed: int | None = None,
    ) -> None:
        r"""Abstract base class for samplers.

        Args:
            sample_shape: The `sample_shape` of the samples to generate. The full
                shape of the samples is given by
                `posterior._extended_shape(sample_shape)`.
            seed: An optional seed to use for sampling. If omitted, an integer
                seed in `[0, 1000000)` is drawn from torch's global RNG.

        Raises:
            InputDataError: If `sample_shape` is not a `torch.Size` instance.
        """
        super().__init__()
        if not isinstance(sample_shape, torch.Size):
            raise InputDataError(
                "Expected `sample_shape` to be a `torch.Size` object, "
                f"got {sample_shape}."
            )
        self.sample_shape = sample_shape
        # Fix the seed at construction time so that repeated use of this
        # sampler instance is tied to one concrete seed value.
        self.seed = (
            seed if seed is not None else torch.randint(0, 1000000, (1,)).item()
        )
        # Registered as an (initially empty) buffer so that, once assigned
        # (presumably by subclasses), the base samples are moved by `.to()`
        # and included in `state_dict()` like any other module buffer.
        self.register_buffer("base_samples", None)

    @abstractmethod
    def forward(self, posterior: Posterior) -> Tensor:
        r"""Draws MC samples from the posterior.

        Args:
            posterior: The posterior to sample from.

        Returns:
            The samples drawn from the posterior.
        """
        pass  # pragma no cover

    def _get_batch_range(self, posterior: Posterior) -> tuple[int, int]:
        r"""Get the t-batch range of the posterior with an optional override.

        In rare cases, e.g., in `qMultiStepLookahead`, we may want to override the
        `batch_range` of the posterior. If this behavior is desired, one can set
        `batch_range_override` attribute on the samplers.

        Args:
            posterior: The posterior to sample from.

        Returns:
            The t-batch range to use for collapsing the base samples.
        """
        if hasattr(self, "batch_range_override"):
            return self.batch_range_override
        return posterior.batch_range

    def _get_collapsed_shape(self, posterior: Posterior) -> torch.Size:
        r"""Get the shape of the base samples with the t-batches collapsed.

        Args:
            posterior: The posterior to sample from.

        Returns:
            The collapsed shape of the base samples expected by the posterior.
            The t-batch dimensions of the base samples are collapsed to size 1.
            This is useful to prevent sampling variance across t-batches.
        """
        base_sample_shape = posterior.base_sample_shape
        batch_start, batch_end = self._get_batch_range(posterior)
        # Count the batch dims via slicing (not `batch_end - batch_start`) so
        # negative indices in the batch range are handled correctly.
        num_batch_dims = len(base_sample_shape[batch_start:batch_end])
        collapsed_shape = (
            base_sample_shape[:batch_start]
            + torch.Size([1] * num_batch_dims)
            + base_sample_shape[batch_end:]
        )
        return self.sample_shape + collapsed_shape

    def _get_extended_base_sample_shape(self, posterior: Posterior) -> torch.Size:
        r"""Get the shape of the base samples expected by the posterior.

        Args:
            posterior: The posterior to sample from.

        Returns:
            The extended shape of the base samples expected by the posterior,
            i.e. `sample_shape` followed by `posterior.base_sample_shape`.
        """
        return self.sample_shape + posterior.base_sample_shape

    def _update_base_samples(
        self, posterior: Posterior, base_sampler: MCSampler
    ) -> None:
        r"""Update the sampler to use the original base samples for X_baseline.

        This is used in CachedCholeskyAcquisitionFunctions to ensure consistency.

        Args:
            posterior: The posterior for which the base samples are constructed.
            base_sampler: The base sampler to retrieve the base samples from.

        Raises:
            NotImplementedError: Always, in this base class; subclasses that
                support base-sample sharing must override this method.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement `_update_base_samples`."
        )

    def _instance_check(self, base_sampler: MCSampler) -> None:
        r"""Check that `base_sampler` is an instance of `self.__class__`.

        Raises:
            RuntimeError: If `base_sampler` is not an instance of this
                sampler's class.
        """
        if not isinstance(base_sampler, self.__class__):
            raise RuntimeError(
                "Expected `base_sampler` to be an instance of "
                f"{self.__class__.__name__}. Got {base_sampler}."
            )