Source code for botorch.sampling.samplers

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Sampler modules to be used with MC-evaluated acquisition functions.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch
from botorch.exceptions import UnsupportedError
from botorch.posteriors import Posterior
from botorch.utils.sampling import draw_sobol_normal_samples, manual_seed
from torch import Tensor
from torch.nn import Module
from torch.quasirandom import SobolEngine


class MCSampler(Module, ABC):
    r"""Abstract base class for Samplers.

    Concrete samplers must implement `_construct_base_samples`, which is
    invoked on every `forward` call before the posterior is sampled. This
    class reads `_sample_shape` and `collapse_batch_dims` but does not assign
    them itself; subclasses are expected to set both.

    Attributes:
        resample: If `True`, re-draw samples in each `forward` evaluation -
            this results in stochastic acquisition functions (and thus should
            not be used with deterministic optimization algorithms).
        collapse_batch_dims: If True, collapse the t-batch dimensions of the
            produced samples to size 1. This is useful for preventing sampling
            variance across t-batches.

    Example:
        `forward` is usually not called directly, but via the sampler's
        `__call__` method:
        >>> posterior = model.posterior(test_X)
        >>> samples = sampler(posterior)
    """

    def __init__(self, batch_range: Tuple[int, int] = (0, -2)) -> None:
        r"""Initialize the sampler's t-batch range and empty base samples.

        Args:
            batch_range: The range of t-batch dimensions used by
                collapse_batch_dims. The t-batch dims are
                batch_range[0]:batch_range[1]. By default, this is (0, -2),
                for the case where the non-batch dimensions are -2 (q) and
                -1 (d) and all dims in the front are t-batch dims.
        """
        super().__init__()
        # Goes through the property setter, which registers the buffers.
        self.batch_range = batch_range
        self.register_buffer("base_samples", None)

    @property
    def batch_range(self) -> Tuple[int, int]:
        r"""The t-batch range, read back from the `_batch_range` buffer."""
        start_end = self._batch_range.tolist()
        return tuple(start_end)

    @batch_range.setter
    def batch_range(self, batch_range: Tuple[int, int]):
        r"""Set the t-batch range and clear base samples.

        Nothing happens if a range has already been set and it compares equal
        to the new one; otherwise the range is stored as a long tensor buffer
        and `base_samples` is reset to `None`.

        Args:
            batch_range: The range of t-batch dimensions used by
                collapse_batch_dims. The t-batch dims are
                batch_range[0]:batch_range[1].
        """
        already_set = hasattr(self, "_batch_range")
        if already_set and self.batch_range == batch_range:
            return
        # A changed range invalidates any previously constructed base samples.
        range_tensor = torch.tensor(batch_range, dtype=torch.long)
        self.register_buffer("_batch_range", range_tensor)
        self.register_buffer("base_samples", None)

    def forward(self, posterior: Posterior) -> Tensor:
        r"""Draws MC samples from the posterior.

        Args:
            posterior: The Posterior to sample from.

        Returns:
            The samples drawn from the posterior.
        """
        shape = self._get_base_sample_shape(posterior=posterior)
        # Subclass hook: leaves the samples to use in `self.base_samples`.
        self._construct_base_samples(posterior=posterior, shape=shape)
        return posterior.rsample(
            sample_shape=self.sample_shape, base_samples=self.base_samples
        )

    def _get_base_sample_shape(self, posterior: Posterior) -> torch.Size:
        r"""Get the shape of the base samples.

        Args:
            posterior: The Posterior to sample from.

        Returns:
            `sample_shape` followed by the posterior's event shape. If
            `collapse_batch_dims=True`, the event-shape dimensions inside
            `batch_range` are each replaced by size 1, which prevents
            sampling variance across t-batches.
        """
        event_shape = posterior.event_shape
        if not self.collapse_batch_dims:
            return self.sample_shape + event_shape
        start, end = self.batch_range
        num_batch_dims = len(event_shape[start:end])
        collapsed = (
            event_shape[:start]
            + torch.Size([1] * num_batch_dims)
            + event_shape[end:]
        )
        return self.sample_shape + collapsed

    @property
    def sample_shape(self) -> torch.Size:
        r"""The shape of a single sample"""
        return self._sample_shape

    @abstractmethod
    def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> None:
        r"""Generate base samples (if necessary).

        Implementations are expected to generate a new set of base samples
        and register the `base_samples` buffer if one of the following holds:
        - `resample=True`
        - the MCSampler has no `base_samples` yet.
        - `shape` is different than `self.base_samples.shape` (if
            `collapse_batch_dims=True`, the batch dimensions are broadcast
            as necessary instead).

        Args:
            posterior: The Posterior for which to generate base samples.
            shape: The shape of the base samples to construct.
        """
        pass  # pragma: no cover
class IIDNormalSampler(MCSampler):
    r"""Sampler for MC base samples using iid N(0,1) samples.

    Example:
        >>> sampler = IIDNormalSampler(1000, seed=1234)
        >>> posterior = model.posterior(test_X)
        >>> samples = sampler(posterior)
    """

    def __init__(
        self,
        num_samples: int,
        resample: bool = False,
        seed: Optional[int] = None,
        collapse_batch_dims: bool = True,
        batch_range: Tuple[int, int] = (0, -2),
    ) -> None:
        r"""Sampler for MC base samples using iid `N(0,1)` samples.

        Args:
            num_samples: The number of samples to use.
            resample: If `True`, re-draw samples in each `forward` evaluation -
                this results in stochastic acquisition functions (and thus
                should not be used with deterministic optimization algorithms).
            seed: The seed for the RNG. If omitted, a seed is drawn at random
                from `[0, 1000000)`.
            collapse_batch_dims: If True, collapse the t-batch dimensions to
                size 1. This is useful for preventing sampling variance across
                t-batches.
            batch_range: The range of t-batch dimensions used by
                collapse_batch_dims. The t-batch dims are
                batch_range[0]:batch_range[1]. By default, this is (0, -2),
                for the case where the non-batch dimensions are -2 (q) and
                -1 (d) and all dims in the front are t-batch dims.
        """
        super().__init__(batch_range=batch_range)
        self._sample_shape = torch.Size([num_samples])
        self.collapse_batch_dims = collapse_batch_dims
        self.resample = resample
        if seed is None:
            seed = torch.randint(0, 1000000, (1,)).item()
        self.seed = seed

    def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> None:
        r"""Generate iid `N(0,1)` base samples (if necessary).

        New base samples are drawn and stored in the `base_samples` buffer if
        one of the following is true:
        - `resample=True`
        - there are no base samples yet, or their shape differs from `shape`
            outside of the t-batch range.
        - `collapse_batch_dims=False` and `shape` differs from
            `self.base_samples.shape`.

        Otherwise, with `collapse_batch_dims=True`, existing base samples are
        only re-viewed to `shape` when the shapes differ. In every case the
        sampler is moved to the posterior's device / dtype if they differ.

        Args:
            posterior: The Posterior for which to generate base samples.
            shape: The shape of the base samples to construct.
        """
        needs_redraw = (
            self.resample
            or _check_shape_changed(self.base_samples, self.batch_range, shape)
            or (not self.collapse_batch_dims and shape != self.base_samples.shape)
        )
        if needs_redraw:
            with manual_seed(seed=self.seed):
                fresh_samples = torch.randn(
                    shape, device=posterior.device, dtype=posterior.dtype
                )
            # Advance the seed so the next redraw yields different samples.
            self.seed += 1
            self.register_buffer("base_samples", fresh_samples)
        elif self.collapse_batch_dims and shape != self.base_samples.shape:
            # Only the (size-1) batch dims differ, so a view is sufficient.
            self.base_samples = self.base_samples.view(shape)
        if self.base_samples.device != posterior.device:
            self.to(device=posterior.device)  # pragma: nocover
        if self.base_samples.dtype != posterior.dtype:
            self.to(dtype=posterior.dtype)
class SobolQMCNormalSampler(MCSampler):
    r"""Sampler for quasi-MC base samples using Sobol sequences.

    Example:
        >>> sampler = SobolQMCNormalSampler(1000, seed=1234)
        >>> posterior = model.posterior(test_X)
        >>> samples = sampler(posterior)
    """

    def __init__(
        self,
        num_samples: int,
        resample: bool = False,
        seed: Optional[int] = None,
        collapse_batch_dims: bool = True,
        batch_range: Tuple[int, int] = (0, -2),
    ) -> None:
        r"""Sampler for quasi-MC base samples using Sobol sequences.

        Args:
            num_samples: The number of samples to use.
            resample: If `True`, re-draw samples in each `forward` evaluation -
                this results in stochastic acquisition functions (and thus
                should not be used with deterministic optimization algorithms).
            seed: The seed for the RNG. If omitted, a seed is drawn at random
                from `[0, 1000000)`.
            collapse_batch_dims: If True, collapse the t-batch dimensions to
                size 1. This is useful for preventing sampling variance across
                t-batches.
            batch_range: The range of t-batch dimensions used by
                collapse_batch_dims. The t-batch dims are
                batch_range[0]:batch_range[1]. By default, this is (0, -2),
                for the case where the non-batch dimensions are -2 (q) and
                -1 (d) and all dims in the front are t-batch dims.
        """
        super().__init__(batch_range=batch_range)
        self._sample_shape = torch.Size([num_samples])
        self.collapse_batch_dims = collapse_batch_dims
        self.resample = resample
        self.seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()

    def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> None:
        r"""Generate quasi-random Normal base samples (if necessary).

        New base samples are drawn and stored in the `base_samples` buffer if
        one of the following is true:
        - `resample=True`
        - there are no base samples yet, or their shape differs from `shape`
            outside of the t-batch range.
        - `collapse_batch_dims=False` and `shape` differs from
            `self.base_samples.shape`.

        Otherwise, with `collapse_batch_dims=True`, existing base samples are
        only re-viewed to `shape` when the shapes differ. In every case the
        sampler is moved to the posterior's device / dtype if they differ.

        Args:
            posterior: The Posterior for which to generate base samples.
            shape: The shape of the base samples to construct.

        Raises:
            UnsupportedError: If the number of elements outside of the t-batch
                range exceeds `SobolEngine.MAXDIM`.
        """
        if (
            self.resample
            or _check_shape_changed(self.base_samples, self.batch_range, shape)
            or (not self.collapse_batch_dims and shape != self.base_samples.shape)
        ):
            batch_start, batch_end = self.batch_range
            # Dimension of the Sobol points: all elements outside batch_range.
            output_dim = (shape[:batch_start] + shape[batch_end:]).numel()
            if output_dim > SobolEngine.MAXDIM:
                raise UnsupportedError(
                    "SobolQMCSampler only supports dimensions "
                    f"`q * o <= {SobolEngine.MAXDIM}`. Requested: {output_dim}"
                )
            base_samples = draw_sobol_normal_samples(
                d=output_dim,
                n=shape[batch_start:batch_end].numel(),
                device=posterior.device,
                dtype=posterior.dtype,
                seed=self.seed,
            )
            # Advance the seed so the next redraw yields different samples.
            self.seed += 1
            base_samples = base_samples.view(shape)
            self.register_buffer("base_samples", base_samples)
        elif self.collapse_batch_dims and shape != self.base_samples.shape:
            # Compare against the stored samples' shape (as IIDNormalSampler
            # does) rather than `posterior.event_shape`: `shape` carries the
            # extra leading sample dimension, so it never equals the event
            # shape and the buffer would be re-viewed on every call.
            self.base_samples = self.base_samples.view(shape)
        if self.base_samples.device != posterior.device:
            self.to(device=posterior.device)  # pragma: nocover
        if self.base_samples.dtype != posterior.dtype:
            self.to(dtype=posterior.dtype)
def _check_shape_changed(
    base_samples: Optional[Tensor], batch_range: Tuple[int, int], shape: torch.Size
) -> bool:
    r"""Check if the base samples shape matches a given shape in non batch dims.

    Args:
        base_samples: The current base samples, or `None` if none have been
            generated yet.
        batch_range: The range of t-batch dimensions to ignore for the shape
            check; only dims before `batch_range[0]` and from `batch_range[1]`
            onwards are compared.
        shape: The shape to compare.

    Returns:
        A bool indicating whether the shape changed. Always `True` when
        `base_samples` is `None`.
    """
    if base_samples is None:
        return True
    lead, trail = batch_range[0], batch_range[1]
    current = base_samples.shape
    trailing_matches = current[trail:] == shape[trail:]
    leading_matches = current[:lead] == shape[:lead]
    return not (trailing_matches and leading_matches)