Source code for botorch.sampling.stochastic_samplers

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Samplers to enable use cases that are not base sample driven, such as
stochastic optimization of acquisition functions.
"""

from __future__ import annotations

import torch
from botorch.posteriors import Posterior
from botorch.sampling.base import MCSampler
from torch import Tensor


class ForkedRNGSampler(MCSampler):
    r"""Replicable sampler for posteriors lacking base-sample support.

    Each call to `forward` runs `posterior.rsample` inside a
    `torch.random.fork_rng` context that has been seeded with `self.seed`.
    As a result, repeated calls draw the same samples, and the global torch
    RNG state is restored once sampling finishes.

    NOTE: This is not a drop-in substitute for base-sample-driven sampling.
    Its outputs are not replicable across the batch dimensions. Batch
    evaluation of an acquisition function on repeated candidates therefore
    yields a different acquisition value for each candidate. This behavior
    is incompatible with Sample Average Approximation.
    """

    def forward(self, posterior: Posterior) -> Tensor:
        r"""Draw MC samples from `posterior` under a forked, seeded RNG.

        Args:
            posterior: The posterior to sample from.

        Returns:
            The samples drawn from the posterior.
        """
        seed = self.seed
        with torch.random.fork_rng():
            # Seed only the forked RNG state; the caller's state is
            # restored when the context exits.
            torch.manual_seed(seed)
            samples = posterior.rsample(sample_shape=self.sample_shape)
        return samples
class StochasticSampler(MCSampler):
    r"""Sampler that forwards directly to `posterior.rsample`.

    Every call produces fresh random samples. This makes the sampler suited
    only to stochastic optimization of acquisition functions, e.g. via
    `gen_candidates_torch`. Do not use it with `optimize_acqf`, whose
    optimizers are deterministic under the hood.

    NOTE: The `seed` option has no effect on this sampler.
    """

    def forward(self, posterior: Posterior) -> Tensor:
        r"""Draw MC samples from `posterior` using the ambient torch RNG.

        Args:
            posterior: The posterior to sample from.

        Returns:
            The samples drawn from the posterior.
        """
        shape = self.sample_shape
        return posterior.rsample(sample_shape=shape)