#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import torch
from botorch.exceptions.errors import BotorchError
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions.multitask_multivariate_normal import (
MultitaskMultivariateNormal,
)
from gpytorch.lazy import BlockDiagLazyTensor
from gpytorch.lazy.lazy_tensor import LazyTensor
from gpytorch.utils.cholesky import psd_safe_cholesky
from gpytorch.utils.errors import NanError
from torch import Tensor
def _reshape_base_samples(
base_samples: Tensor, sample_shape: torch.Size, posterior: GPyTorchPosterior
) -> Tensor:
r"""Manipulate shape of base_samples to match `MultivariateNormal.rsample`.
This ensure that base_samples are used in the same way as in
gpytorch.distributions.MultivariateNormal. For CBD, it is important to ensure
that the same base samples are used for the in-sample points here and in the
cached box decompositions.
Args:
base_samples: The base samples.
sample_shape: The sample shape.
posterior: The joint posterior is over (X_baseline, X).
Returns:
Reshaped and expanded base samples.
"""
loc = posterior.mvn.loc
peshape = posterior.event_shape
base_samples = base_samples.view(
sample_shape + torch.Size([1 for _ in range(loc.ndim - 1)]) + peshape[-2:]
).expand(sample_shape + loc.shape[:-1] + peshape[-2:])
base_samples = base_samples.reshape(
-1, *loc.shape[:-1], posterior.mvn.lazy_covariance_matrix.shape[-1]
)
base_samples = base_samples.permute(*range(1, loc.dim() + 1), 0)
return base_samples.reshape(
*peshape[:-2],
peshape[-1],
peshape[-2],
*sample_shape,
)
def sample_cached_cholesky(
    posterior: GPyTorchPosterior,
    baseline_L: Tensor,
    q: int,
    base_samples: Tensor,
    sample_shape: torch.Size,
    max_tries: int = 6,
) -> Tensor:
    r"""Get posterior samples at the `q` new points from the joint multi-output posterior.

    The cached Cholesky factor of the baseline block is extended by the `q`
    new bottom rows, which are then applied to the base samples.

    Args:
        posterior: The joint posterior is over (X_baseline, X).
        baseline_L: The baseline lower triangular cholesky factor.
        q: The number of new points in X.
        base_samples: The base samples.
        sample_shape: The sample shape.
        max_tries: The number of tries for computing the Cholesky
            decomposition with increasing jitter.

    Returns:
        A `sample_shape x batch_shape x q x m`-dim tensor of posterior
        samples at the new points.

    Raises:
        NanError: If the computed samples contain NaNs or infs.
    """
    # compute bottom left covariance block
    if isinstance(posterior.mvn, MultitaskMultivariateNormal):
        lazy_covar = extract_batch_covar(mt_mvn=posterior.mvn)
    else:
        lazy_covar = posterior.mvn.lazy_covariance_matrix
    # Get the `q` new rows of the batched covariance matrix
    bottom_rows = lazy_covar[..., -q:, :].evaluate()
    # The covariance in block form is:
    # [K(X_baseline, X_baseline), K(X_baseline, X)]
    # [K(X, X_baseline), K(X, X)]
    # bl := K(X, X_baseline)
    # br := K(X, X)
    # Split the columns into exactly two sections of sizes `[n, q]`. Passing
    # a single int chunk size `n` instead would yield more than two chunks
    # whenever `q > n` (and is invalid for `n = 0`), breaking the unpacking.
    num_baseline = bottom_rows.shape[-1] - q
    bl, br = torch.split(bottom_rows, [num_baseline, q], dim=-1)
    # Solve Ax = b
    # where A = K(X_baseline, X_baseline) and b = K(X, X_baseline)^T
    # and bl_chol := x^T
    # bl_chol is the new `(batch_shape) x q x n`-dim bottom left block
    # of the cholesky decomposition
    bl_chol = torch.triangular_solve(
        bl.transpose(-2, -1), baseline_L, upper=False
    ).solution.transpose(-2, -1)
    # Compute the new bottom right block of the Cholesky
    # decomposition via:
    # Cholesky(K(X, X) - bl_chol @ bl_chol^T)
    br_to_chol = br - bl_chol @ bl_chol.transpose(-2, -1)
    # TODO: technically we should make sure that we add a
    # consistent nugget to the cached covariance and the new block
    br_chol = psd_safe_cholesky(br_to_chol, max_tries=max_tries)
    # Create a `(batch_shape) x q x (n+q)`-dim tensor containing the
    # `q` new bottom rows of the Cholesky decomposition
    new_Lq = torch.cat([bl_chol, br_chol], dim=-1)
    mean = posterior.mvn.mean
    base_samples = _reshape_base_samples(
        base_samples=base_samples, sample_shape=sample_shape, posterior=posterior
    )
    if not isinstance(posterior.mvn, MultitaskMultivariateNormal):
        # add output dim
        mean = mean.unsqueeze(-1)
        # add batch dim corresponding to output dim
        new_Lq = new_Lq.unsqueeze(-3)
    new_mean = mean[..., -q:, :]
    # Apply the new Cholesky rows to the base samples, then move the sample
    # dim to the front and swap the last two dims so the result lines up with
    # `new_mean` before adding it.
    res = (
        new_Lq.matmul(base_samples)
        .permute(-1, *range(mean.dim() - 2), -2, -3)
        .contiguous()
        .add(new_mean)
    )
    contains_nans = torch.isnan(res).any()
    contains_infs = torch.isinf(res).any()
    if contains_nans or contains_infs:
        suffix_args = []
        if contains_nans:
            suffix_args.append("nans")
        if contains_infs:
            suffix_args.append("infs")
        suffix = " and ".join(suffix_args)
        raise NanError(f"Samples contain {suffix}.")
    return res