Source code for botorch.generation.sampling

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

Sampling-based generation strategies.

A SamplingStrategy returns samples from the input points (i.e. Tensors in feature
space), rather than the value for a set of tensors, as acquisition functions do.
The q-batch dimension has similar semantics as for acquisition functions in that the
points across the q-batch are considered jointly for sampling (where as for
q-acquisition functions we evaluate the joint value of the q-batch).

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Optional

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.objective import (
from botorch.generation.utils import _flip_sub_unique
from botorch.models.model import Model
from botorch.utils.sampling import batched_multinomial
from botorch.utils.transforms import standardize
from torch import Tensor
from torch.nn import Module

[docs]class SamplingStrategy(Module, ABC): r"""Abstract base class for sampling-based generation strategies."""
[docs] @abstractmethod def forward(self, X: Tensor, num_samples: int = 1, **kwargs: Any) -> Tensor: r"""Sample according to the SamplingStrategy. Args: X: A `batch_shape x N x d`-dim Tensor from which to sample (in the `N` dimension). num_samples: The number of samples to draw. kwargs: Additional implementation-specific kwargs. Returns: A `batch_shape x num_samples x d`-dim Tensor of samples from `X`, where `X[..., i, :]` is the `i`-th sample. """ pass # pragma: no cover
[docs]class MaxPosteriorSampling(SamplingStrategy): r"""Sample from a set of points according to their max posterior value. Example: >>> MPS = MaxPosteriorSampling(model) # model w/ feature dim d=3 >>> X = torch.rand(2, 100, 3) >>> sampled_X = MPS(X, num_samples=5) """ def __init__( self, model: Model, objective: Optional[MCAcquisitionObjective] = None, posterior_transform: Optional[PosteriorTransform] = None, replacement: bool = True, ) -> None: r"""Constructor for the SamplingStrategy base class. Args: model: A fitted model. objective: The MCAcquisitionObjective under which the samples are evaluated. Defaults to `IdentityMCObjective()`. posterior_transform: An optional PosteriorTransform. replacement: If True, sample with replacement. """ super().__init__() self.model = model if objective is None: objective = IdentityMCObjective() elif not isinstance(objective, MCAcquisitionObjective): # TODO: Clean up once ScalarizedObjective is removed. if posterior_transform is not None: raise RuntimeError( "A ScalarizedObjective (DEPRECATED) and a posterior transform " "are not supported at the same time. Use only a posterior " "transform instead." ) else: posterior_transform = ScalarizedPosteriorTransform( weights=objective.weights, offset=objective.offset ) objective = IdentityMCObjective() self.objective = objective self.posterior_transform = posterior_transform self.replacement = replacement
[docs] def forward( self, X: Tensor, num_samples: int = 1, observation_noise: bool = False ) -> Tensor: r"""Sample from the model posterior. Args: X: A `batch_shape x N x d`-dim Tensor from which to sample (in the `N` dimension) according to the maximum posterior value under the objective. num_samples: The number of samples to draw. observation_noise: If True, sample with observation noise. Returns: A `batch_shape x num_samples x d`-dim Tensor of samples from `X`, where `X[..., i, :]` is the `i`-th sample. """ posterior = self.model.posterior( X, observation_noise=observation_noise, posterior_transform=self.posterior_transform, ) # num_samples x batch_shape x N x m samples = posterior.rsample(sample_shape=torch.Size([num_samples])) obj = self.objective(samples, X=X) # num_samples x batch_shape x N if self.replacement: # if we allow replacement then things are simple(r) idcs = torch.argmax(obj, dim=-1) else: # if we need to deduplicate we have to do some tensor acrobatics # first we get the indices associated w/ the num_samples top samples _, idcs_full = torch.topk(obj, num_samples, dim=-1) # generate some indices to smartly index into the lower triangle of # idcs_full (broadcasting across batch dimensions) ridx, cindx = torch.tril_indices(num_samples, num_samples) # pick the unique indices in order - since we look at the lower triangle # of the index matrix and we don't sort, this achieves deduplication sub_idcs = idcs_full[ridx, ..., cindx] if sub_idcs.ndim == 1: idcs = _flip_sub_unique(sub_idcs, num_samples) elif sub_idcs.ndim == 2: # TODO: Find a better way to do this n_b = sub_idcs.size(-1) idcs = torch.stack( [_flip_sub_unique(sub_idcs[:, i], num_samples) for i in range(n_b)], dim=-1, ) else: # TODO: Find a general way to do this efficiently. raise NotImplementedError( "MaxPosteriorSampling without replacement for more than a single " "batch dimension is not yet implemented." ) # idcs is num_samples x batch_shape, to index into X we need to permute for it # to have shape batch_shape x num_samples if idcs.ndim > 1: idcs = idcs.permute(*range(1, idcs.ndim), 0) # in order to use gather, we need to repeat the index tensor d times idcs = idcs.unsqueeze(-1).expand(*idcs.shape, X.size(-1)) # now if the model is batched batch_shape will not necessarily be the # batch_shape of X, so we expand X to the proper shape Xe = X.expand(*obj.shape[1:], X.size(-1)) # finally we can gather along the N dimension return torch.gather(Xe, -2, idcs)
[docs]class BoltzmannSampling(SamplingStrategy): r"""Sample from a set of points according to a tempered acquisition value. Given an acquisition function `acq_func`, this sampling strategies draws samples from a `batch_shape x N x d`-dim tensor `X` according to a multinomial distribution over its indices given by weight(X[..., i, :]) ~ exp(eta * standardize(acq_func(X[..., i, :]))) where `standardize(Y)` standardizes `Y` to zero mean and unit variance. As the temperature parameter `eta -> 0`, this approaches uniform sampling, while as `eta -> infty`, this approaches selecting the maximizer(s) of the acquisition function `acq_func`. Example: >>> UCB = UpperConfidenceBound(model, beta=0.1) >>> BMUCB = BoltzmannSampling(UCB, eta=0.5) >>> X = torch.rand(2, 100, 3) >>> sampled_X = BMUCB(X, num_samples=5) """ def __init__( self, acq_func: AcquisitionFunction, eta: float = 1.0, replacement: bool = True ) -> None: r"""Boltzmann Acquisition Value Sampling. Args: acq_func: The acquisition function; to be evaluated in batch at the individual points of a q-batch (not jointly, as is the case for acquisition functions). Can be analytic or Monte-Carlo. eta: The temperature parameter in the softmax. replacement: If True, sample with replacement. """ super().__init__() self.acq_func = acq_func self.eta = eta self.replacement = replacement
[docs] def forward(self, X: Tensor, num_samples: int = 1) -> Tensor: r"""Sample from a tempered value of the acquisition function value. Args: X: A `batch_shape x N x d`-dim Tensor from which to sample (in the `N` dimension) according to the maximum posterior value under the objective. Note that if a batched model is used in the underlying acquisition function, then its batch shape must be broadcastable to `batch_shape`. num_samples: The number of samples to draw. Returns: A `batch_shape x num_samples x d`-dim Tensor of samples from `X`, where `X[..., i, :]` is the `i`-th sample. """ # TODO: Can we get the model batch shape property from the model? # we move the `N` dimension to the front for evaluating the acquisition function # so that X_eval has shape `N x batch_shape x 1 x d` X_eval = X.permute(-2, *range(X.ndim - 2), -1).unsqueeze(-2) acqval = self.acq_func(X_eval) # N x batch_shape # now move the `N` dimension back (this is the number of categories) acqval = acqval.permute(*range(1, X.ndim - 1), 0) # batch_shape x N weights = torch.exp(self.eta * standardize(acqval)) # batch_shape x N idcs = batched_multinomial( weights=weights, num_samples=num_samples, replacement=self.replacement ) # now do some gathering acrobatics to select the right elements from X return torch.gather(X, -2, idcs.unsqueeze(-1).expand(*idcs.shape, X.size(-1)))