#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
References
.. [lin2024scaling]
J. A. Lin, S. Ament, M. Balandat, E. Bakshy. Scaling Gaussian Processes
for Learning Curve Prediction via Latent Kronecker Structure. NeurIPS 2024
Bayesian Decision-making and Uncertainty Workshop.
.. [lin2023sampling]
J. A. Lin, J. Antorán, S. Padhy, D. Janz, J. M. Hernández-Lobato, A. Terenin.
Sampling from Gaussian Process Posteriors using Stochastic Gradient Descent.
Advances in Neural Information Processing Systems 2023.
"""
import contextlib
import warnings
from typing import Any
import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import FantasizeMixin, Model
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.latent_kronecker import LatentKroneckerGPPosterior
from botorch.utils.types import _DefaultType, DEFAULT
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.means import Mean, ZeroMean
from gpytorch.models.exact_gp import ExactGP
from gpytorch.module import Module
from linear_operator import settings
from linear_operator.operators import (
ConstantDiagLinearOperator,
KroneckerProductLinearOperator,
MaskedLinearOperator,
)
from linear_operator.utils.warnings import PerformanceWarning
from torch import Tensor
[docs]
class MinMaxStandardize(Standardize):
    r"""Outcome transform that shifts by an extremum and scales by the std.

    Works like `Standardize`, except that the offset subtracted from the
    outcomes is the per-output maximum (or, with `use_min=True`, the minimum)
    of the training targets rather than their mean. The scale is still the
    standard deviation of the training targets.
    """

    def __init__(
        self,
        m: int = 1,
        use_min: bool = False,
        outputs: list[int] | None = None,
        batch_shape: torch.Size = torch.Size(),  # noqa: B008
        min_stdv: float = 1e-8,
    ) -> None:
        r"""Set up the transform.

        Args:
            m: The output dimension.
            use_min: If True, center about the minimum of the training
                targets; otherwise center about their maximum.
            outputs: Which of the outputs to standardize. If omitted, all
                outputs will be standardized.
            batch_shape: The batch_shape of the training targets.
            min_stdv: The minimum standard deviation for which to rescale
                (outputs with a lower standard deviation are only shifted).
        """
        super().__init__(
            m=m,
            outputs=outputs,
            batch_shape=batch_shape,
            min_stdv=min_stdv,
        )
        self._use_min = use_min

    def forward(
        self, Y: Tensor, Yvar: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None]:
        r"""Shift and rescale outcomes.

        In train mode the normalizing constants stored on the module are
        recomputed from `Y` before being applied. In eval mode the stored
        constants are applied unchanged.

        Args:
            Y: A `batch_shape x n x m`-dim tensor of training targets.
            Yvar: A `batch_shape x n x m`-dim tensor of observation noises
                associated with the training targets (if applicable).

        Returns:
            A two-tuple with the transformed outcomes:

            - The transformed outcome observations.
            - The transformed observation noise (if applicable).

        Raises:
            RuntimeError: In train mode, if the batch shape or the output
                dimension of `Y` does not match the configured values.
            ValueError: In train mode, if `Y` contains no observations.
        """
        if self.training:
            leading_shape = Y.shape[:-2]
            if leading_shape != self._batch_shape:
                raise RuntimeError(
                    f"Expected Y.shape[:-2] to be {self._batch_shape}, matching "
                    "the `batch_shape` argument to `Standardize`, but got "
                    f"Y.shape[:-2]={Y.shape[:-2]}."
                )
            if Y.size(-1) != self._m:
                raise RuntimeError(
                    f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
                    f"{self._m}."
                )

            num_obs = Y.shape[-2]
            if num_obs < 1:
                raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")
            if num_obs == 1:
                # A single observation has no spread; use a unit scale.
                scale = torch.ones(
                    (*leading_shape, 1, Y.shape[-1]), dtype=Y.dtype, device=Y.device
                )
            else:
                scale = Y.std(dim=-2, keepdim=True)
            # Scales below the threshold are replaced by one (shift only).
            scale = torch.where(
                scale >= self._min_stdv, scale, torch.ones_like(scale)
            )

            if self._use_min:
                offset = Y.min(dim=-2, keepdim=True).values
            else:
                offset = Y.max(dim=-2, keepdim=True).values

            if self._outputs is not None:
                # Outputs that were not selected pass through unchanged.
                skipped = [j for j in range(self._m) if j not in self._outputs]
                offset[..., skipped] = 0.0
                scale[..., skipped] = 1.0

            self.means = offset
            self.stdvs = scale
            self._stdvs_sq = scale.pow(2)
            self._is_trained = torch.tensor(True)

        Y_tf = (Y - self.means) / self.stdvs
        Yvar_tf = None if Yvar is None else Yvar / self._stdvs_sq
        return Y_tf, Yvar_tf
[docs]
class LatentKroneckerGP(GPyTorchModel, ExactGP, FantasizeMixin):
    r"""
    A multi-task GP model which uses Kronecker structure despite missing entries.

    Leverages pathwise conditioning and iterative linear system solvers to
    efficiently draw samples from the GP posterior. See [lin2024scaling]_
    for details.

    For more information about pathwise conditioning, see [wilson2021pathwise]_
    and [Maddox2021bohdo]_. Details about iterative linear system solvers for GPs
    with pathwise conditioning can be found in [lin2023sampling]_.

    NOTE: This model requires iterative methods for efficient posterior inference.
    To enable iterative methods, the `use_iterative_methods` helper function can be
    used as a context manager.

    Example:
        >>> model = LatentKroneckerGP(train_X, train_Y)
        >>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
        >>> with model.use_iterative_methods():
        >>>     fit_gpytorch_mll(mll)
        >>>     samples = model.posterior(test_X).rsample()
    """

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        train_Y_valid: Tensor | None = None,
        T: Tensor | None = None,
        likelihood: Likelihood | None = None,
        mean_module_X: Mean | None = None,
        mean_module_T: Mean | None = None,
        covar_module_X: Module | None = None,
        covar_module_T: Module | None = None,
        input_transform: InputTransform | None = None,
        outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
    ) -> None:
        r"""
        Args:
            train_X: A `batch_shape x n x d` tensor of training features.
            train_Y: A `batch_shape x n x t` tensor of training observations.
            train_Y_valid: A `n x t` boolean tensor of valid values.
                True indicates that the corresponding value is valid.
                False indicates that the corresponding value is missing.
                Does not allow explicit `batch_shape` because
                the mask must be shared across batch dimensions.
                If omitted, all values are treated as valid.
            T: A `batch_shape x t` tensor of training time steps.
                If omitted, use `t` evenly spaced values in `[0, 1]`.
            likelihood: A likelihood. If omitted, use a standard
                `GaussianLikelihood` with inferred noise level.
            mean_module_X: The mean function to be used for X.
                If omitted, use a `ZeroMean`.
            mean_module_T: The mean function to be used for T.
                If omitted, use a `ZeroMean`.
            covar_module_X: The module computing the covariance matrix of X.
                If omitted, use a `MaternKernel`.
            covar_module_T: The module computing the covariance matrix of T.
                If omitted, use a `MaternKernel` wrapped in a `ScaleKernel`.
            input_transform: An input transform that is applied to X.
            outcome_transform: An outcome transform that is applied to Y.
                If left at `DEFAULT`, use a `MinMaxStandardize`; pass `None`
                to disable outcome transformation.

        Raises:
            BotorchTensorDimensionError: If `train_Y_valid` or `T` does not
                have the expected shape.
            AssertionError: If `train_Y_valid` is not a boolean tensor.
        """
        with torch.no_grad():
            # transform inputs here to check resulting shapes
            # actual transforms will be applied in forward() and posterior()
            transformed_X = self.transform_inputs(
                X=train_X, input_transform=input_transform
            )
        self._validate_tensor_args(X=transformed_X, Y=train_Y)
        batch_shape, ard_num_dims = transformed_X.shape[:-2], transformed_X.shape[-1]

        self.T = self._init_T(T, batch_shape, train_Y)
        self._num_outputs = self.T.shape[-1]

        if likelihood is None:
            likelihood = GaussianLikelihood(batch_shape=batch_shape)

        if train_Y_valid is not None:
            if train_Y_valid.shape != train_Y.shape[-2:]:
                raise BotorchTensorDimensionError(
                    "Explicit batch_shape not allowed for train_Y_valid, "
                    "because the mask must be shared across batch dimensions. "
                    f"Expected train_Y_valid with shape: {train_Y.shape[-2:]} "
                    f"(got {train_Y_valid.shape})."
                )
            # Explicit raise (rather than `assert`) so the check survives
            # `python -O`; a non-boolean mask would otherwise be used for
            # integer indexing below and silently select the wrong targets.
            # The exception type is kept for backward compatibility.
            if train_Y_valid.dtype != torch.bool:
                raise AssertionError(
                    "train_Y_valid must have dtype torch.bool "
                    f"(got {train_Y_valid.dtype})."
                )
            self.mask = train_Y_valid.reshape(-1)
        else:
            mask_len = train_Y.shape[-2] * train_Y.shape[-1]
            self.mask = torch.ones(mask_len, dtype=torch.bool, device=train_Y.device)
        # Flatten the (n, t) grid and keep only the observed entries.
        train_Y = train_Y.reshape(*batch_shape, -1)[..., self.mask]

        if outcome_transform == DEFAULT:
            outcome_transform = MinMaxStandardize(batch_shape=batch_shape)
        if outcome_transform is not None:
            # transform outputs once and keep the results
            train_Y = outcome_transform(train_Y.unsqueeze(-1))[0].squeeze(-1)

        ExactGP.__init__(
            self,
            train_inputs=train_X,
            train_targets=train_Y,
            likelihood=likelihood,
        )

        if mean_module_X is None:
            mean_module_X = ZeroMean(batch_shape=batch_shape)
        self.mean_module_X: Module = mean_module_X

        if mean_module_T is None:
            mean_module_T = ZeroMean(batch_shape=batch_shape)
        self.mean_module_T: Module = mean_module_T

        if covar_module_X is None:
            covar_module_X = MaternKernel(
                ard_num_dims=ard_num_dims, batch_shape=batch_shape
            )
        if covar_module_T is None:
            # The outputscale gets the same batch shape as every other default
            # module, so that batched models do not share a single outputscale.
            covar_module_T = ScaleKernel(
                base_kernel=MaternKernel(ard_num_dims=1, batch_shape=batch_shape),
                batch_shape=batch_shape,
            )
        self.covar_module_X: Module = covar_module_X
        self.covar_module_T: Module = covar_module_T

        if input_transform is not None:
            self.input_transform = input_transform
        if outcome_transform is not None:
            self.outcome_transform = outcome_transform

        # Quantities reused by `_rsample_from_base_samples` when it is called
        # again with identical base samples.
        self._cached_base_samples = None
        self._cached_L_train_train_X = None
        self._cached_L_T = None
        self._cached_H_inv_v = None

        self.to(train_X)

    def _init_T(
        self, T: Tensor | None, batch_shape: torch.Size, train_Y: Tensor
    ) -> Tensor:
        r"""Validate the provided time steps or construct default ones.

        Args:
            T: Optional tensor of time steps, expected to have shape
                `batch_shape x t` with `t = train_Y.shape[-1]`.
            batch_shape: The batch shape of the (transformed) training inputs.
            train_Y: The training observations; supplies `t`, and the dtype
                and device of the default time steps.

        Returns:
            `T` itself if provided; otherwise `t` evenly spaced values in
            `[0, 1]`, expanded to `batch_shape x t`.

        Raises:
            BotorchTensorDimensionError: If `T` is provided with a shape other
                than `batch_shape x t`.
        """
        if T is not None:
            expected_shape = torch.Size([*batch_shape, train_Y.shape[-1]])
            if T.shape != expected_shape:
                raise BotorchTensorDimensionError(
                    f"Expected T with shape: {expected_shape} (got {T.shape})."
                )
            return T
        else:
            T = torch.linspace(
                0, 1, train_Y.shape[-1], dtype=train_Y.dtype, device=train_Y.device
            )
            T = T.expand(*batch_shape, -1)
            return T

    def use_iterative_methods(
        self,
        tol: float = 0.01,
        max_iter: int = 10000,
        covar_root_decomposition: bool = False,
        log_prob: bool = True,
        solves: bool = True,
    ) -> contextlib.ExitStack:
        r"""Enable the `linear_operator` settings for iterative methods.

        The settings are entered when this method is called. Ownership of the
        entered contexts is then handed to the returned `ExitStack`, which
        restores the previous settings when it is closed. Intended usage is
        `with model.use_iterative_methods(): ...`.

        Args:
            tol: Value passed to `settings.cg_tolerance`.
            max_iter: Value passed to `settings.max_cg_iterations`.
            covar_root_decomposition: Forwarded to `settings.fast_computations`.
            log_prob: Forwarded to `settings.fast_computations`.
            solves: Forwarded to `settings.fast_computations`.

        Returns:
            An `ExitStack` holding the entered settings contexts.
        """
        with contextlib.ExitStack() as stack:
            stack.enter_context(
                settings.fast_computations(
                    covar_root_decomposition=covar_root_decomposition,
                    log_prob=log_prob,
                    solves=solves,
                )
            )
            stack.enter_context(settings.cg_tolerance(tol))
            stack.enter_context(settings.max_cg_iterations(max_iter))
            # pop_all() moves the contexts to a fresh stack, so leaving this
            # `with` block does not exit them.
            return stack.pop_all()

    def _get_mean(self, X: Tensor, mask: Tensor | None = None) -> Tensor:
        r"""Evaluate the prior mean as a Kronecker product of the X and T means.

        Args:
            X: The features at which to evaluate `mean_module_X`.
            mask: Optional boolean mask over the flattened product mean;
                if given, only the selected entries are returned.

        Returns:
            The flattened product mean, optionally restricted by `mask`.
        """
        mean_X = self.mean_module_X(X).unsqueeze(-1)
        mean_T = self.mean_module_T(self.T.unsqueeze(-1)).unsqueeze(-1)
        mean = KroneckerProductLinearOperator(mean_X, mean_T).squeeze(-1)
        return mean[..., mask] if mask is not None else mean

    def forward(self, X: Tensor) -> MultivariateNormal:
        r"""Compute the prior over the (masked) flattened product grid at `X`.

        In train mode the input transform is applied to `X` and the training
        mask is used. Otherwise an all-True mask over `X.shape[-2] * t` entries
        is built and its leading entries are overwritten with the training
        mask.
        NOTE(review): this presumably relies on `X` starting with the training
        inputs in eval mode -- confirm against the caller.

        Args:
            X: A `batch_shape x n x d` tensor of features.

        Returns:
            A `MultivariateNormal` whose covariance is the masked Kronecker
            product of the X and T covariances.
        """
        if self.training:
            X = self.transform_inputs(X)
            mask = self.mask
        else:
            total_len = X.shape[-2] * self._num_outputs
            mask = torch.ones(total_len, dtype=torch.bool, device=X.device)
            mask[: self.mask.shape[-1]] = self.mask

        mean = self._get_mean(X, mask)
        covar_X = self.covar_module_X(X)
        covar_T = self.covar_module_T(self.T.unsqueeze(-1))
        covar = KroneckerProductLinearOperator(covar_X, covar_T)
        covar = MaskedLinearOperator(covar, row_mask=mask, col_mask=mask)
        return MultivariateNormal(mean, covar)

    def posterior(
        self,
        X: Tensor,
        observation_noise: bool | Tensor = False,
        posterior_transform: PosteriorTransform | None = None,
        **kwargs: Any,
    ) -> GPyTorchPosterior:
        r"""Construct the posterior over model outputs at `X`.

        Args:
            X: The features at which to evaluate the posterior.
            observation_noise: Must be `False`; anything else is rejected.
            posterior_transform: Must be `None`; transforms are rejected.
            **kwargs: Ignored.

        Returns:
            A `LatentKroneckerGPPosterior` built from this model and `X`.

        Raises:
            NotImplementedError: If a posterior transform is given, if the
                likelihood is not a `GaussianLikelihood`, or if
                `observation_noise` is not `False`.
        """
        if posterior_transform is not None:
            raise NotImplementedError(
                "Posterior transforms currently not supported for "
                f"{self.__class__.__name__}"
            )
        if not isinstance(self.likelihood, GaussianLikelihood):
            raise NotImplementedError(
                "Only GaussianLikelihood currently supported for "
                f"{self.__class__.__name__}"
            )
        if observation_noise is not False:
            raise NotImplementedError(
                "Observation noise currently not supported for "
                f"{self.__class__.__name__}"
            )
        return LatentKroneckerGPPosterior(self, X)

    def _rsample_from_base_samples(
        self,
        X: Tensor,
        base_samples: Tensor,
        observation_noise: bool | Tensor = False,
    ) -> Tensor:
        r"""Sample from the posterior distribution at the provided points `X`
        using Matheron's rule.

        The last dimension of `base_samples` is split into three consecutive
        parts of sizes `n_train_x * t` (prior sample at the full training
        grid), `n_train` (noise sample for the observed training targets) and
        `q * t` (prior sample at the test grid).

        The Cholesky factors and the linear solve are cached on the model and
        reused as long as this method is called with equal `base_samples`.
        This method switches the model to eval mode.

        Args:
            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
                of the feature space and `q` is the number of points considered
                jointly.
            base_samples: A Tensor of `N(0, I)` base samples of shape
                `sample_shape x base_sample_shape`, typically obtained from
                a `Sampler`. This is used for deterministic optimization.
            observation_noise: Not used by this method.

        Returns:
            Samples from the posterior, a tensor of shape
            `sample_shape x broadcast_shape x q x t`.
        """
        # toggle eval mode to switch the behavior of input / outcome transforms
        # this also implicitly applies the input transform to the train_inputs
        self.eval()

        X_train = self.train_inputs[0]
        X_test = self.transform_inputs(X)

        n_train_full = X_train.shape[-2] * self._num_outputs
        n_train = self.train_targets.shape[-1]
        n_test = X_test.shape[-2] * self._num_outputs

        sample_shape = base_samples.shape[: -len(self.batch_shape) - 1]
        w_train, eps_base, w_test = torch.split(
            base_samples, [n_train_full, n_train, n_test], dim=-1
        )

        K_T = self.covar_module_T(self.T.unsqueeze(-1))

        if self._cached_base_samples is not None and torch.equal(
            base_samples, self._cached_base_samples
        ):
            L_train_train_X = self._cached_L_train_train_X
            L_T = self._cached_L_T
            H_inv_v = self._cached_H_inv_v
        else:
            # The noise sample is only needed when the solve is recomputed.
            eps = torch.sqrt(self.likelihood.noise) * eps_base

            # Evaluate prior mean at training data
            m_train = self._get_mean(X_train, self.mask)

            # Calculate prior sample
            K_train_train_X = self.covar_module_X(X_train)
            L_train_train_X = K_train_train_X.cholesky(upper=False)
            L_T = K_T.cholesky(upper=False)
            L_train_train = KroneckerProductLinearOperator(L_train_train_X, L_T)
            f_prior_train = L_train_train @ w_train.unsqueeze(-1)
            f_prior_train = m_train + f_prior_train.squeeze(-1)[..., self.mask]

            K_train_train = KroneckerProductLinearOperator(K_train_train_X, K_T)
            K_train_train = MaskedLinearOperator(
                K_train_train, row_mask=self.mask, col_mask=self.mask
            )
            noise_covar = ConstantDiagLinearOperator(
                self.likelihood.noise
                * torch.ones(*self.batch_shape, 1, dtype=X.dtype, device=X.device),
                diag_shape=n_train,
            )
            H = K_train_train + noise_covar
            v = self.train_targets - (f_prior_train + eps)

            # Expand once here to avoid repeated expansion
            # by MaskedLinearOperator later
            H_inv_v = torch.zeros(
                *sample_shape,
                *self.batch_shape,
                n_train_full,
                dtype=X.dtype,
                device=X.device,
            )
            if settings._fast_solves.off():
                warn_msg = (
                    "Iterative methods are disabled. Performing linear solve using "
                    "full joint covariance matrix, which might be slow and require "
                    "a lot of memory. Iterative methods can be enabled using "
                    "'with model.use_iterative_methods():'."
                )
                warnings.warn(
                    warn_msg,
                    PerformanceWarning,
                    stacklevel=2,
                )
            H_inv_v[..., self.mask] = H.solve(v.unsqueeze(-1)).squeeze(-1)

            self._cached_base_samples = base_samples
            self._cached_L_train_train_X = L_train_train_X
            self._cached_L_T = L_T
            self._cached_H_inv_v = H_inv_v

        # Evaluate prior mean at test data
        m_test = self._get_mean(X_test)

        K_train_test_X = self.covar_module_X(X_train, X_test).evaluate_kernel()
        K_test_test_X = self.covar_module_X(X_test).evaluate_kernel()

        L_train_test_X = L_train_train_X.solve_triangular(
            K_train_test_X.tensor, upper=False
        )
        L_test_test_X = (
            K_test_test_X - L_train_test_X.transpose(-2, -1) @ L_train_test_X
        ).cholesky(upper=False)

        L_test_train = KroneckerProductLinearOperator(
            L_train_test_X.transpose(-2, -1), L_T
        )
        L_test_test = KroneckerProductLinearOperator(L_test_test_X, L_T)

        # match dimensions for broadcasting
        broadcast_shape = L_test_train.shape[:-2]
        extra_batch_dims = len(broadcast_shape) - len(self.batch_shape)
        for _ in range(extra_batch_dims):
            w_train = w_train.unsqueeze(len(sample_shape))
            w_test = w_test.unsqueeze(len(sample_shape))
            H_inv_v = H_inv_v.unsqueeze(len(sample_shape))

        f_prior_test = L_test_train @ w_train.unsqueeze(-1)
        f_prior_test = f_prior_test + L_test_test @ w_test.unsqueeze(-1)
        f_prior_test = m_test + f_prior_test.squeeze(-1)

        K_train_test = KroneckerProductLinearOperator(K_train_test_X, K_T)
        # no MaskedLinearOperator here because H_inv_v is already expanded
        samples = K_train_test.transpose(-2, -1) @ H_inv_v.unsqueeze(-1)
        samples = samples + f_prior_test.unsqueeze(-1)

        # reshape samples to separate X and T dimensions
        # samples.shape = (*sample_shape, *broadcast_shape, n_test_x * n_t, 1)
        samples = samples.reshape(
            *samples.shape[:-2], X_test.shape[-2], self._num_outputs
        )
        # samples.shape = (*sample_shape, *broadcast_shape, n_test_x, n_t)

        if hasattr(self, "outcome_transform") and self.outcome_transform is not None:
            samples, _ = self.outcome_transform.untransform(samples)
        return samples

    def condition_on_observations(
        self, X: Tensor, Y: Tensor, noise: Tensor | None = None, **kwargs: Any
    ) -> Model:
        r"""Not supported for this model.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"Conditioning currently not supported for {self.__class__.__name__}"
        )