# Source code for botorch.sampling.samplers

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Sampler modules to be used with MC-evaluated acquisition functions.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module
from torch.quasirandom import SobolEngine

from ..exceptions import UnsupportedError
from ..posteriors import Posterior
from ..utils.sampling import draw_sobol_normal_samples, manual_seed


class MCSampler(Module, ABC):
    r"""Abstract base class for Samplers.

    A sampler produces Monte-Carlo samples from a `Posterior` by passing a set
    of base samples to `posterior.rsample`. Concrete subclasses must implement
    `_construct_base_samples`, which is responsible for (re)registering the
    `base_samples` buffer. They must also set the attributes listed below,
    which the methods of this class read.

    Attributes:
        _sample_shape: The `torch.Size` exposed through `sample_shape`.
        resample: If `True`, re-draw samples in each `forward` evaluation -
            this results in stochastic acquisition functions (and thus should
            not be used with deterministic optimization algorithms).
        collapse_batch_dims: If True, collapse the t-batch dimensions of the
            produced samples to size 1. This is useful for preventing sampling
            variance across t-batches.

    Example:
        Samplers are normally invoked through `__call__` rather than by
        calling `forward` directly:

        >>> posterior = model.posterior(test_X)
        >>> samples = sampler(posterior)
    """

    def forward(self, posterior: Posterior) -> Tensor:
        r"""Draw MC samples from `posterior`.

        Args:
            posterior: The Posterior to sample from.

        Returns:
            The samples drawn from the posterior via `posterior.rsample`,
            using this sampler's `sample_shape` and `base_samples`.
        """
        # The subclass hook (re)registers `self.base_samples` for this shape.
        self._construct_base_samples(
            posterior=posterior,
            shape=self._get_base_sample_shape(posterior=posterior),
        )
        return posterior.rsample(
            sample_shape=self.sample_shape, base_samples=self.base_samples
        )

    def _get_base_sample_shape(self, posterior: Posterior) -> torch.Size:
        r"""Compute the shape of the base samples expected by `posterior`.

        Args:
            posterior: The Posterior to sample from.

        Returns:
            `sample_shape + posterior.event_shape`. When
            `collapse_batch_dims=True`, every dimension of the event shape
            except the last two is replaced by 1. This prevents sampling
            variance across t-batches.
        """
        event_shape = posterior.event_shape
        if not self.collapse_batch_dims:
            return self.sample_shape + event_shape
        num_batch_dims = len(event_shape[:-2])
        collapsed_event_shape = torch.Size([1] * num_batch_dims) + event_shape[-2:]
        return self.sample_shape + collapsed_event_shape

    @property
    def sample_shape(self) -> torch.Size:
        r"""The shape of a single sample."""
        return self._sample_shape

    @abstractmethod
    def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> None:
        r"""Generate base samples (if necessary).

        Implementations should generate a new set of base samples and register
        the `base_samples` buffer if one of the following is true:
            - `resample=True`
            - the MCSampler has no `base_samples` attribute.
            - `shape` is different from `self.base_samples.shape` (if
              `collapse_batch_dims=True`, then batch dimensions will be
              automatically broadcast as necessary).

        Args:
            posterior: The Posterior for which to generate base samples.
            shape: The shape of the base samples to construct.
        """
class IIDNormalSampler(MCSampler):
    r"""Sampler for MC base samples using iid N(0,1) samples.

    Example:
        >>> sampler = IIDNormalSampler(1000, seed=1234)
        >>> posterior = model.posterior(test_X)
        >>> samples = sampler(posterior)
    """

    def __init__(
        self,
        num_samples: int,
        resample: bool = False,
        seed: Optional[int] = None,
        collapse_batch_dims: bool = True,
    ) -> None:
        r"""Sampler for MC base samples using iid `N(0,1)` samples.

        Args:
            num_samples: The number of samples to use.
            resample: If `True`, re-draw samples in each `forward` evaluation -
                this results in stochastic acquisition functions (and thus
                should not be used with deterministic optimization algorithms).
            seed: The seed for the RNG. If omitted, a seed in [0, 1000000) is
                drawn from torch's global RNG.
            collapse_batch_dims: If True, collapse the t-batch dimensions to
                size 1. This is useful for preventing sampling variance across
                t-batches.
        """
        super().__init__()
        self._sample_shape = torch.Size([num_samples])
        self.collapse_batch_dims = collapse_batch_dims
        self.resample = resample
        if seed is None:
            seed = torch.randint(0, 1000000, (1,)).item()
        self.seed = seed

    def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> None:
        r"""Generate iid `N(0,1)` base samples (if necessary).

        A fresh set of base samples is drawn and registered as the
        `base_samples` buffer when any of the following holds:
            - `resample=True`
            - no `base_samples` buffer exists yet
            - the last two dims of `shape` differ from those of
              `self.base_samples`
            - `collapse_batch_dims=False` and `shape` differs from
              `self.base_samples.shape`

        Otherwise, if `collapse_batch_dims=True` and `shape` still differs
        from the stored shape, the existing samples are reshaped with
        `view(shape)` instead of being redrawn. Finally, the module is moved
        to the posterior's device and dtype if they differ.

        Args:
            posterior: The Posterior for which to generate base samples.
            shape: The shape of the base samples to construct.
        """
        stale = self.resample or not hasattr(self, "base_samples")
        if not stale:
            current_shape = self.base_samples.shape
            stale = current_shape[-2:] != shape[-2:] or (
                not self.collapse_batch_dims and shape != current_shape
            )
        if stale:
            with manual_seed(seed=self.seed):
                draws = torch.randn(
                    shape, device=posterior.device, dtype=posterior.dtype
                )
            # Advance the seed so that the next redraw produces new samples.
            self.seed += 1
            self.register_buffer("base_samples", draws)
        elif self.collapse_batch_dims and shape != self.base_samples.shape:
            self.base_samples = self.base_samples.view(shape)
        if self.base_samples.device != posterior.device:
            self.to(device=posterior.device)  # pragma: nocover
        if self.base_samples.dtype != posterior.dtype:
            self.to(dtype=posterior.dtype)
class SobolQMCNormalSampler(MCSampler):
    r"""Sampler for quasi-MC base samples using Sobol sequences.

    Example:
        >>> sampler = SobolQMCNormalSampler(1000, seed=1234)
        >>> posterior = model.posterior(test_X)
        >>> samples = sampler(posterior)
    """

    def __init__(
        self,
        num_samples: int,
        resample: bool = False,
        seed: Optional[int] = None,
        collapse_batch_dims: bool = True,
    ) -> None:
        r"""Sampler for quasi-MC base samples using Sobol sequences.

        Args:
            num_samples: The number of samples to use.
            resample: If `True`, re-draw samples in each `forward` evaluation -
                this results in stochastic acquisition functions (and thus
                should not be used with deterministic optimization algorithms).
            seed: The seed for the RNG. If omitted, a seed in [0, 1000000) is
                drawn from torch's global RNG.
            collapse_batch_dims: If True, collapse the t-batch dimensions to
                size 1. This is useful for preventing sampling variance across
                t-batches.
        """
        super().__init__()
        self._sample_shape = torch.Size([num_samples])
        self.collapse_batch_dims = collapse_batch_dims
        self.resample = resample
        self.seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()

    def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> None:
        r"""Generate quasi-random Normal base samples (if necessary).

        A new set of base samples is generated and registered as the
        `base_samples` buffer if one of the following is true:
            - `resample=True`
            - the MCSampler has no `base_samples` attribute.
            - the last two dims of `shape` differ from those of
              `self.base_samples`.
            - `collapse_batch_dims=False` and `shape` differs from
              `self.base_samples.shape`.

        Otherwise, if `collapse_batch_dims=True` and `shape` differs from the
        stored shape, the existing samples are reshaped with `view(shape)`.
        Finally, the module is moved to the posterior's device and dtype if
        they differ.

        Args:
            posterior: The Posterior for which to generate base samples.
            shape: The shape of the base samples to construct.

        Raises:
            UnsupportedError: If the product of the last two dims of `shape`
                (`q * o`) exceeds `SobolEngine.MAXDIM`.
        """
        if (
            self.resample
            or not hasattr(self, "base_samples")
            or self.base_samples.shape[-2:] != shape[-2:]
            or (not self.collapse_batch_dims and shape != self.base_samples.shape)
        ):
            # Each Sobol point spans the flattened trailing (q * o) dims.
            output_dim = shape[-2:].numel()
            if output_dim > SobolEngine.MAXDIM:
                raise UnsupportedError(
                    "SobolQMCSampler only supports dimensions "
                    f"`q * o <= {SobolEngine.MAXDIM}`. Requested: {output_dim}"
                )
            base_samples = draw_sobol_normal_samples(
                d=output_dim,
                n=shape[:-2].numel(),
                device=posterior.device,
                dtype=posterior.dtype,
                seed=self.seed,
            )
            # Advance the seed so that the next redraw produces new samples.
            self.seed += 1
            base_samples = base_samples.view(shape)
            self.register_buffer("base_samples", base_samples)
        elif self.collapse_batch_dims and shape != self.base_samples.shape:
            # Compare against the stored samples, not `posterior.event_shape`.
            # `shape` includes `sample_shape`, so it never equals the event
            # shape, and the old check re-viewed the buffer on every call.
            # Here the last two dims already match; the view assumes that
            # the batch dims were collapsed to size 1, so numel agrees.
            self.base_samples = self.base_samples.view(shape)
        if self.base_samples.device != posterior.device:
            self.to(device=posterior.device)  # pragma: nocover
        if self.base_samples.dtype != posterior.dtype:
            self.to(dtype=posterior.dtype)