# Source code for botorch.utils.gp_sampling
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from copy import deepcopy
from typing import Optional
import torch
from botorch.models.model import Model
from botorch.utils.sampling import manual_seed
from torch import Tensor
from torch.nn import Module
class GPDraw(Module):
    r"""Convenience wrapper for sampling a function from a GP prior.

    This wrapper implicitly defines the GP sample as a self-updating function by
    keeping track of the evaluated points and respective base samples used during
    the evaluation.

    This does not yet support multi-output models.
    """

    def __init__(self, model: Model, seed: Optional[int] = None) -> None:
        r"""Construct a GP function sampler.

        Args:
            model: The Model defining the GP prior. A deep copy of it is stored
                on this module.
            seed: Integer seed used when drawing base samples. If `None`, a
                random integer in `[0, 1000000)` is drawn instead.
        """
        super().__init__()
        self._model = deepcopy(model)
        if seed is None:
            seed = torch.randint(0, 1000000, (1,)).item()
        # The seed is kept as a registered buffer (a 0-dim tensor) and is
        # incremented on every call to `forward`.
        self.register_buffer("_seed", torch.tensor(seed))

    @property
    def Xs(self) -> Optional[Tensor]:
        """A `(batch_shape) x n_eval x d`-dim tensor of locations at which the GP
        was evaluated (or `None` if the sample has never been evaluated).
        """
        # `_Xs` is only registered by the first call to `forward`.
        return getattr(self, "_Xs", None)

    @property
    def Ys(self) -> Optional[Tensor]:
        """A `(batch_shape) x n_eval x m`-dim tensor of associated function values
        (or `None` if the sample has never been evaluated).

        The trailing dimension is whatever `posterior.rsample` produces
        (presumably `m = 1`, since multi-output models are not supported).
        """
        # `_Ys` is only registered by the first call to `forward`.
        return getattr(self, "_Ys", None)

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the GP sample function at a set of points X.

        The posterior is computed jointly over all previously evaluated points
        and `X`. The base samples stored from earlier calls are re-used and
        fresh ones are drawn only for `X`.

        Args:
            X: A `batch_shape x n x d`-dim tensor of points

        Returns:
            The value of the GP sample at the `n` points.
        """
        prev_Xs = self.Xs
        first_eval = prev_Xs is None
        # first time, there are no previous evaluation points
        X_eval = X if first_eval else torch.cat([prev_Xs, X], dim=-2)
        posterior = self._model.posterior(X=X_eval)
        base_sample_shape = posterior.base_sample_shape
        # re-use old samples: only draw new base samples for the `n` new points
        bs_shape = base_sample_shape[:-2] + X.shape[-2:-1] + base_sample_shape[-1:]
        with manual_seed(seed=int(self._seed)):
            new_base_samples = torch.randn(bs_shape, device=X.device, dtype=X.dtype)
        # advance the seed so that the next call draws different base samples
        seed = self._seed + 1
        if first_eval:
            base_samples = new_base_samples
        else:
            base_samples = torch.cat([self._base_samples, new_base_samples], dim=-2)
        # TODO: Deduplicate repeated evaluations / deal with numerical degeneracies
        # that could lead to non-deterministic evaluations. We could use SVD- or
        # eigendecomposition-based sampling, but we probably don't want to use this
        # by default for performance reasons.
        Ys = posterior.rsample(torch.Size(), base_samples=base_samples)
        self.register_buffer("_Xs", X_eval)
        self.register_buffer("_Ys", Ys)
        self.register_buffer("_seed", seed)
        self.register_buffer("_base_samples", base_samples)
        # Use an explicit non-negative start index: a negative slice `-n:` with
        # `n == 0` would return *all* stored values instead of an empty tensor.
        n_new = X.size(-2)
        return Ys[..., Ys.size(-2) - n_new :, :]