# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Multi-task Gaussian Process Regression models with fully Bayesian inference.
"""

from typing import Any, Dict, List, NoReturn, Optional, Tuple

import pyro
import torch

from botorch.acquisition.objective import PosteriorTransform
from botorch.models.fully_bayesian import (
matern52_kernel,
MIN_INFERRED_NOISE_LEVEL,
PyroModel,
reshape_and_detach,
SaasPyroModel,
)
from botorch.models.multitask import FixedNoiseMultiTaskGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.posteriors.fully_bayesian import FullyBayesianPosterior, MCMC_DIM
from botorch.utils.datasets import SupervisedDataset
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import MaternKernel
from gpytorch.kernels.kernel import Kernel
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.means.mean import Mean
from torch import Tensor
from torch.nn.parameter import Parameter


class MultitaskSaasPyroModel(SaasPyroModel):
    r"""
    Implementation of the multi-task sparse axis-aligned subspace priors (SAAS) model.

    The multi-task model uses an ICM kernel. The data kernel is the same as in the
    single-task SAAS model in order to handle high-dimensional parameter spaces. The
    task kernel is a Matern-5/2 kernel that uses learned task embeddings as the input.
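
    Illustrative sketch of the ICM covariance structure (the tensor names below are
    hypothetical stand-ins for exposition, not identifiers from this module):

    Example:
        >>> K_data = matern52_kernel(X=X_data, lengthscale=lengthscale)
        >>> K_task = matern52_kernel(X=embeddings[task_idcs], lengthscale=task_ls)
        >>> K = outputscale * K_data.mul(K_task)  # ICM: elementwise product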
"""

    def sample(self) -> None:
        r"""Sample from the SAAS model.

        This samples the mean, noise variance, outputscale, and lengthscales according
        to the SAAS prior.
        """
tkwargs = {"dtype": self.train_X.dtype, "device": self.train_X.device}
        base_idxr = torch.arange(self.ard_num_dims, device=tkwargs["device"])
base_idxr[self.task_feature :] += 1 # exclude task feature
task_indices = self.train_X[..., self.task_feature].to(
device=tkwargs["device"], dtype=torch.long
)
outputscale = self.sample_outputscale(concentration=2.0, rate=0.15, **tkwargs)
mean = self.sample_mean(**tkwargs)
noise = self.sample_noise(**tkwargs)
lengthscale = self.sample_lengthscale(dim=self.ard_num_dims, **tkwargs)
K = matern52_kernel(X=self.train_X[..., base_idxr], lengthscale=lengthscale)
# compute task covar matrix
task_latent_features = self.sample_latent_features(**tkwargs)[task_indices]
task_lengthscale = self.sample_task_lengthscale(**tkwargs)
task_covar = matern52_kernel(
X=task_latent_features, lengthscale=task_lengthscale
)
K = K.mul(task_covar)
K = outputscale * K + noise * torch.eye(self.train_X.shape[0], **tkwargs)
pyro.sample(
"Y",
pyro.distributions.MultivariateNormal(
loc=mean.view(-1).expand(self.train_X.shape[0]),
covariance_matrix=K,
),
obs=self.train_Y.squeeze(-1),
)
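
    # Usage sketch (an assumption for illustration, not part of the original
    # source): `sample` is the Pyro model that `fit_fully_bayesian_model_nuts`
    # hands to NUTS, roughly:
    #
    #   nuts = pyro.infer.mcmc.NUTS(pyro_model.sample)
    #   mcmc = pyro.infer.mcmc.MCMC(nuts, num_samples=256, warmup_steps=512)
    #   mcmc.run()
    #   samples = mcmc.get_samples()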

    def sample_latent_features(self, **tkwargs: Any):
        r"""Sample the latent task features from a standard Normal prior."""
        return pyro.sample(
            "latent_features",
            pyro.distributions.Normal(
                torch.tensor(0.0, **tkwargs),
                torch.tensor(1.0, **tkwargs),
            ).expand(torch.Size([self.num_tasks, self.task_rank])),
        )

    def sample_task_lengthscale(
        self, concentration: float = 6.0, rate: float = 3.0, **tkwargs: Any
    ):
        r"""Sample the lengthscale of the task kernel from a Gamma prior."""
        return pyro.sample(
            "task_lengthscale",
            pyro.distributions.Gamma(
                torch.tensor(concentration, **tkwargs),
                torch.tensor(rate, **tkwargs),
            ).expand(torch.Size([self.task_rank])),
        )

    def load_mcmc_samples(
        self, mcmc_samples: Dict[str, Tensor]
    ) -> Tuple[Mean, Kernel, Likelihood, Kernel, Parameter]:
        r"""Load the MCMC samples into the mean_module, covar_module, likelihood,
        task_covar_module, and latent_features.
        """
tkwargs = {"device": self.train_X.device, "dtype": self.train_X.dtype}
num_mcmc_samples = len(mcmc_samples["mean"])
batch_shape = torch.Size([num_mcmc_samples])
mean_module, covar_module, likelihood = super().load_mcmc_samples(
mcmc_samples=mcmc_samples
)
task_covar_module = MaternKernel(
nu=2.5,
ard_num_dims=self.task_rank,
batch_shape=batch_shape,
).to(**tkwargs)
task_covar_module.lengthscale = reshape_and_detach(
target=task_covar_module.lengthscale,
new_value=mcmc_samples["task_lengthscale"],
)
latent_features = Parameter(
torch.rand(
batch_shape + torch.Size([self.num_tasks, self.task_rank]),
requires_grad=True,
**tkwargs,
)
)
latent_features = reshape_and_detach(
target=latent_features,
new_value=mcmc_samples["latent_features"],
)
return mean_module, covar_module, likelihood, task_covar_module, latent_features
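
    # Shape sketch (for illustration; these follow from the `sample_*` methods
    # above): with `m` MCMC samples, `t` tasks, task rank `r`, and `d` non-task
    # input dimensions, `mcmc_samples` is expected to contain
    #   mcmc_samples["mean"]:             m
    #   mcmc_samples["outputscale"]:      m
    #   mcmc_samples["lengthscale"]:      m x d
    #   mcmc_samples["task_lengthscale"]: m x r
    #   mcmc_samples["latent_features"]:  m x t x r
    #   mcmc_samples["noise"]:            m (only when the noise is inferred)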


class SaasFullyBayesianMultiTaskGP(FixedNoiseMultiTaskGP):
    r"""A fully Bayesian multi-task GP model with the SAAS prior.

    This model assumes that the inputs have been normalized to [0, 1]^d and that the
    output has been stratified standardized to have zero mean and unit variance for
    each task. The SAAS model [Eriksson2021saasbo]_ with a Matern-5/2 kernel is used
    as the data kernel by default.

    You are expected to use `fit_fully_bayesian_model_nuts` to fit this model as it
    isn't compatible with `fit_gpytorch_model`.

    Example:
        >>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
        >>> i1, i2 = torch.zeros(10, 1), torch.ones(20, 1)
        >>> train_X = torch.cat([
        >>>     torch.cat([X1, i1], -1), torch.cat([X2, i2], -1),
        >>> ])
        >>> train_Y = torch.cat([f1(X1), f2(X2)]).unsqueeze(-1)
        >>> train_Yvar = 0.01 * torch.ones_like(train_Y)
        >>> mtsaas_gp = SaasFullyBayesianMultiTaskGP(
        >>>     train_X, train_Y, train_Yvar, task_feature=-1,
        >>> )
        >>> fit_fully_bayesian_model_nuts(mtsaas_gp)
        >>> posterior = mtsaas_gp.posterior(test_X)
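
    After fitting, the properties defined below can be used to inspect the
    hyperparameter samples, e.g.:
        >>> median_ls = mtsaas_gp.median_lengthscale
        >>> num_samples = mtsaas_gp.num_mcmc_samples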
"""

    def __init__(
self,
train_X: Tensor,
train_Y: Tensor,
train_Yvar: Optional[Tensor],
task_feature: int,
output_tasks: Optional[List[int]] = None,
rank: Optional[int] = None,
outcome_transform: Optional[OutcomeTransform] = None,
input_transform: Optional[InputTransform] = None,
pyro_model: Optional[PyroModel] = None,
) -> None:
r"""Initialize the fully Bayesian multi-task GP model.
Args:
train_X: Training inputs (n x (d + 1))
train_Y: Training targets (n x 1)
train_Yvar: Observed noise variance (n x 1). If None, we infer the noise.
Note that the inferred noise is common across all tasks.
task_feature: The index of the task feature (`-d <= task_feature <= d`).
rank: The num of learned task embeddings to be used in the task kernel.
If omitted, set it to be 1.
"""
if not (
train_X.ndim == train_Y.ndim == 2
and len(train_X) == len(train_Y)
and train_Y.shape[-1] == 1
):
raise ValueError(
"Expected train_X to have shape n x d and train_Y to have shape n x 1"
)
if train_Yvar is not None and train_Y.shape != train_Yvar.shape:
raise ValueError(
"Expected train_Yvar to be None or have the same shape as train_Y"
)
with torch.no_grad():
transformed_X = self.transform_inputs(
X=train_X, input_transform=input_transform
)
if outcome_transform is not None:
train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
if train_Yvar is not None: # Clamp after transforming
train_Yvar = train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL)
super().__init__(
train_X=train_X,
train_Y=train_Y,
# Using train_Y as a dummy input here when train_Yvar is None.
train_Yvar=train_Yvar if train_Yvar is not None else train_Y,
task_feature=task_feature,
output_tasks=output_tasks,
)
self.to(train_X)
self.mean_module = None
self.covar_module = None
self.likelihood = None
self.task_covar_module = None
if pyro_model is None:
pyro_model = MultitaskSaasPyroModel()
pyro_model.set_inputs(
train_X=transformed_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
task_feature=task_feature,
task_rank=rank,
)
self.pyro_model = pyro_model
if outcome_transform is not None:
self.outcome_transform = outcome_transform
if input_transform is not None:
self.input_transform = input_transform

    def train(self, mode: bool = True) -> None:
r"""Puts the model in `train` mode."""
super().train(mode=mode)
if mode:
self.mean_module = None
self.covar_module = None
self.likelihood = None
self.task_covar_module = None

    @property
def median_lengthscale(self) -> Tensor:
r"""Median lengthscales across the MCMC samples."""
self._check_if_fitted()
lengthscale = self.covar_module.base_kernel.lengthscale.clone()
return lengthscale.median(0).values.squeeze(0)

    @property
def num_mcmc_samples(self) -> int:
r"""Number of MCMC samples in the model."""
self._check_if_fitted()
return len(self.covar_module.outputscale)

    @property
    def batch_shape(self) -> torch.Size:
        r"""Batch shape of the model, equal to the number of MCMC samples.

        Note that `SaasFullyBayesianMultiTaskGP` does not support batching
        over input data at this point.
        """
self._check_if_fitted()
return torch.Size([self.num_mcmc_samples])
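
    # Illustration: with 256 MCMC samples, `model.batch_shape` is
    # `torch.Size([256])`, so downstream code treats this as a batched model
    # with one GP per MCMC sample.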

    def fantasize(self, *args, **kwargs) -> NoReturn:
        raise NotImplementedError("Fantasize is not implemented!")

    def _check_if_fitted(self):
r"""Raise an exception if the model hasn't been fitted."""
if self.covar_module is None:
raise RuntimeError(
"Model has not been fitted. You need to call "
"`fit_fully_bayesian_model_nuts` to fit the model."
)

    def load_mcmc_samples(self, mcmc_samples: Dict[str, Tensor]) -> None:
        r"""Load the MCMC hyperparameter samples into the model.

        This method will be called by `fit_fully_bayesian_model_nuts` when the model
        has been fitted in order to create a batched MultiTaskGP model.
        """
(
self.mean_module,
self.covar_module,
self.likelihood,
self.task_covar_module,
self.latent_features,
) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)

    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: bool = False,
        posterior_transform: Optional[PosteriorTransform] = None,
        **kwargs: Any,
    ) -> FullyBayesianPosterior:
        r"""Computes the posterior over model outputs at the provided points.

        Returns:
            A `FullyBayesianPosterior` object. Includes observation noise if specified.
        """
self._check_if_fitted()
posterior = super().posterior(
X=X,
output_indices=output_indices,
observation_noise=observation_noise,
posterior_transform=posterior_transform,
**kwargs,
)
posterior = FullyBayesianPosterior(distribution=posterior.distribution)
return posterior

    def forward(self, X: Tensor) -> MultivariateNormal:
        r"""Compute the model output at X using the ICM kernel, batched over the
        MCMC samples.
        """
        self._check_if_fitted()
        X = X.unsqueeze(MCMC_DIM)
x_basic, task_idcs = self._split_inputs(X)
mean_x = self.mean_module(x_basic)
covar_x = self.covar_module(x_basic)
        tsub_idcs = task_idcs.squeeze(-3).squeeze(-1)
        latent_features = self.latent_features[:, tsub_idcs, :]
        if X.ndim > 3:
            # batch eval mode
            # for X (batch_shape x num_samples x q x d), task_idcs[:, i, :, :] are
            # the same across batches; permute the batch dimensions of
            # latent_features to (batch_shape x num_samples x q x task_rank) so
            # that they align with X
            latent_features = latent_features.permute(
                [-i for i in range(X.ndim - 1, 2, -1)]
                + [0]
                + [-i for i in range(2, 0, -1)]
            )
# Combine the two in an ICM fashion
covar_i = self.task_covar_module(latent_features)
covar = covar_x.mul(covar_i)
return MultivariateNormal(mean_x, covar)
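
    # Posterior shape sketch (an illustrative assumption, not original source):
    # for a non-batched `test_X`, `model.posterior(test_X).mean` carries a leading
    # `num_mcmc_samples` batch dimension; a mixture summary can be formed with
    # plain torch ops, e.g.
    #
    #   mixture_mean = model.posterior(test_X).mean.mean(dim=MCMC_DIM)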