#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Sampler modules to be used with MC-evaluated acquisition functions.
"""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

import torch
from botorch.exceptions import UnsupportedError
from botorch.posteriors import Posterior
from botorch.utils.sampling import draw_sobol_normal_samples, manual_seed
from torch import Tensor
from torch.nn import Module
from torch.quasirandom import SobolEngine

class MCSampler(Module, ABC):
    r"""Abstract base class for Samplers.

    Subclasses must implement the `_construct_base_samples` method.

    Attributes:
        resample: If `True`, re-draw samples in each `forward` evaluation -
            this results in stochastic acquisition functions (and thus should
            not be used with deterministic optimization algorithms).
        collapse_batch_dims: If True, collapse the t-batch dimensions of the
            produced samples to size 1. This is useful for preventing sampling
            variance across t-batches.

    Example:
        This method is usually not called directly, but via the sampler's
        `__call__` method:
        >>> posterior = model.posterior(test_X)
        >>> samples = sampler(posterior)
    """

    def forward(self, posterior: Posterior) -> Tensor:
        r"""Draws MC samples from the posterior.

        Args:
            posterior: The Posterior to sample from.

        Returns:
            The samples drawn from the posterior.
        """
        base_sample_shape = self._get_base_sample_shape(posterior=posterior)
        self._construct_base_samples(posterior=posterior, shape=base_sample_shape)
        samples = posterior.rsample(
            sample_shape=self.sample_shape, base_samples=self.base_samples
        )
        return samples

    def _get_base_sample_shape(self, posterior: Posterior) -> torch.Size:
        r"""Get the shape of the base samples.

        Args:
            posterior: The Posterior to sample from.

        Returns:
            The shape of the base samples expected by the posterior. If
            `collapse_batch_dims=True`, the t-batch dimensions of the base
            samples are collapsed to size 1. This is useful to prevent sampling
            variance across t-batches.
        """
        event_shape = posterior.event_shape
        if self.collapse_batch_dims:
            event_shape = torch.Size([1 for _ in event_shape[:-2]]) + event_shape[-2:]
        return self.sample_shape + event_shape

    @property
    def sample_shape(self) -> torch.Size:
        r"""The shape of a single sample."""
        return self._sample_shape

    @abstractmethod
    def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> None:
        r"""Generate base samples (if necessary).

        This function will generate a new set of base samples and register the
        `base_samples` buffer if one of the following is true:

        - `resample=True`
        - the MCSampler has no `base_samples` attribute.
        - `shape` is different than `self.base_samples.shape` (if
          `collapse_batch_dims=True`, then batch dimensions will be
          automatically broadcast as necessary).

        Args:
            posterior: The Posterior for which to generate base samples.
            shape: The shape of the base samples to construct.
        """
        pass  # pragma: no cover
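

# The following is a minimal illustrative sketch (not part of the library API)
# of the shape arithmetic performed in `MCSampler._get_base_sample_shape`:
# with `collapse_batch_dims=True`, all t-batch dimensions of the posterior's
# `event_shape` are collapsed to size 1 before `sample_shape` is prepended.
# The concrete sizes below are hypothetical.
def _example_base_sample_shape() -> torch.Size:
    sample_shape = torch.Size([16])          # e.g. num_samples = 16
    event_shape = torch.Size([5, 3, 2, 1])   # e.g. 5 x 3 t-batches, q = 2, o = 1
    # collapse the t-batch dimensions (all but the last two) to size 1
    collapsed = torch.Size([1 for _ in event_shape[:-2]]) + event_shape[-2:]
    # prepend the sample shape -> torch.Size([16, 1, 1, 2, 1])
    return sample_shape + collapsed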


class IIDNormalSampler(MCSampler):
    r"""Sampler for MC base samples using iid `N(0,1)` samples.

    Example:
        >>> sampler = IIDNormalSampler(1000, seed=1234)
        >>> posterior = model.posterior(test_X)
        >>> samples = sampler(posterior)
    """

    def __init__(
        self,
        num_samples: int,
        resample: bool = False,
        seed: Optional[int] = None,
        collapse_batch_dims: bool = True,
    ) -> None:
        r"""Sampler for MC base samples using iid `N(0,1)` samples.

        Args:
            num_samples: The number of samples to use.
            resample: If `True`, re-draw samples in each `forward` evaluation -
                this results in stochastic acquisition functions (and thus should
                not be used with deterministic optimization algorithms).
            seed: The seed for the RNG. If omitted, use a random seed.
            collapse_batch_dims: If True, collapse the t-batch dimensions to
                size 1. This is useful for preventing sampling variance across
                t-batches.
        """
        super().__init__()
        self._sample_shape = torch.Size([num_samples])
        self.collapse_batch_dims = collapse_batch_dims
        self.resample = resample
        self.seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()

    def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> None:
        r"""Generate iid `N(0,1)` base samples (if necessary).

        This function will generate a new set of base samples and set the
        `base_samples` buffer if one of the following is true:

        - `resample=True`
        - the MCSampler has no `base_samples` attribute.
        - `shape` is different than `self.base_samples.shape` (if
          `collapse_batch_dims=True`, then batch dimensions will be
          automatically broadcast as necessary).

        Args:
            posterior: The Posterior for which to generate base samples.
            shape: The shape of the base samples to construct.
        """
        if (
            self.resample
            or not hasattr(self, "base_samples")
            or self.base_samples.shape[-2:] != shape[-2:]
            or (not self.collapse_batch_dims and shape != self.base_samples.shape)
        ):
            with manual_seed(seed=self.seed):
                base_samples = torch.randn(
                    shape, device=posterior.device, dtype=posterior.dtype
                )
            self.seed += 1
            self.register_buffer("base_samples", base_samples)
        elif self.collapse_batch_dims and shape != self.base_samples.shape:
            self.base_samples = self.base_samples.view(shape)
        if self.base_samples.device != posterior.device:
            self.to(device=posterior.device)  # pragma: nocover
        if self.base_samples.dtype != posterior.dtype:
            self.to(dtype=posterior.dtype)
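

# A minimal sketch (not part of the library API) mirroring what
# `IIDNormalSampler._construct_base_samples` does internally: draw standard
# Normal base samples of the requested shape under `manual_seed`, so that the
# same seed reproduces the same draw. The shape below is hypothetical.
def _example_iid_base_samples(seed: int = 1234) -> Tensor:
    shape = torch.Size([4, 1, 2, 1])  # e.g. 4 samples, collapsed t-batch, q = 2, o = 1
    with manual_seed(seed=seed):
        # iid N(0, 1) base samples, analogous to the `base_samples` buffer
        return torch.randn(shape)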


class SobolQMCNormalSampler(MCSampler):
    r"""Sampler for quasi-MC base samples using Sobol sequences.

    Example:
        >>> sampler = SobolQMCNormalSampler(1000, seed=1234)
        >>> posterior = model.posterior(test_X)
        >>> samples = sampler(posterior)
    """

    def __init__(
        self,
        num_samples: int,
        resample: bool = False,
        seed: Optional[int] = None,
        collapse_batch_dims: bool = True,
    ) -> None:
        r"""Sampler for quasi-MC base samples using Sobol sequences.

        Args:
            num_samples: The number of samples to use.
            resample: If `True`, re-draw samples in each `forward` evaluation -
                this results in stochastic acquisition functions (and thus should
                not be used with deterministic optimization algorithms).
            seed: The seed for the RNG. If omitted, use a random seed.
            collapse_batch_dims: If True, collapse the t-batch dimensions to
                size 1. This is useful for preventing sampling variance across
                t-batches.
        """
        super().__init__()
        self._sample_shape = torch.Size([num_samples])
        self.collapse_batch_dims = collapse_batch_dims
        self.resample = resample
        self.seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()

    def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> None:
        r"""Generate quasi-random Normal base samples (if necessary).

        This function will generate a new set of base samples and set the
        `base_samples` buffer if one of the following is true:

        - `resample=True`
        - the MCSampler has no `base_samples` attribute.
        - `shape` is different than `self.base_samples.shape` (if
          `collapse_batch_dims=True`, then batch dimensions will be
          automatically broadcast as necessary).

        Args:
            posterior: The Posterior for which to generate base samples.
            shape: The shape of the base samples to construct.
        """
        if (
            self.resample
            or not hasattr(self, "base_samples")
            or self.base_samples.shape[-2:] != shape[-2:]
            or (not self.collapse_batch_dims and shape != self.base_samples.shape)
        ):
            output_dim = shape[-2:].numel()
            if output_dim > SobolEngine.MAXDIM:
                raise UnsupportedError(
                    "SobolQMCSampler only supports dimensions "
                    f"`q * o <= {SobolEngine.MAXDIM}`. Requested: {output_dim}"
                )
            base_samples = draw_sobol_normal_samples(
                d=output_dim,
                n=shape[:-2].numel(),
                device=posterior.device,
                dtype=posterior.dtype,
                seed=self.seed,
            )
            self.seed += 1
            base_samples = base_samples.view(shape)
            self.register_buffer("base_samples", base_samples)
        elif self.collapse_batch_dims and shape != posterior.event_shape:
            self.base_samples = self.base_samples.view(shape)
        if self.base_samples.device != posterior.device:
            self.to(device=posterior.device)  # pragma: nocover
        if self.base_samples.dtype != posterior.dtype:
            self.to(dtype=posterior.dtype)
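

# A minimal sketch (not part of the library API) mirroring the Sobol path in
# `SobolQMCNormalSampler._construct_base_samples`: draw `n` quasi-random
# `N(0,1)` points of dimension `d = q * o` and reshape them to the base-sample
# shape expected by the posterior. The sizes below are hypothetical.
def _example_sobol_base_samples(seed: int = 1234) -> Tensor:
    num_samples, q, o = 8, 2, 1  # hypothetical sample count and output shape
    # `draw_sobol_normal_samples` returns an `n x d` tensor of qMC N(0,1) draws
    base_samples = draw_sobol_normal_samples(d=q * o, n=num_samples, seed=seed)
    # reshape to `sample_shape x 1 x q x o` (collapsed t-batch dimension)
    return base_samples.view(torch.Size([num_samples, 1, q, o]))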