Source code for botorch.cross_validation

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Cross-validation utilities using batch evaluation mode.
"""

from __future__ import annotations

from typing import Any, NamedTuple

import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.fit import fit_gpytorch_mll
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.multitask import MultiTaskGP
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from torch import Tensor


[docs] class CVFolds(NamedTuple): train_X: Tensor test_X: Tensor train_Y: Tensor test_Y: Tensor train_Yvar: Tensor | None = None test_Yvar: Tensor | None = None
[docs] class CVResults(NamedTuple): model: GPyTorchModel posterior: GPyTorchPosterior observed_Y: Tensor observed_Yvar: Tensor | None = None
[docs] def gen_loo_cv_folds( train_X: Tensor, train_Y: Tensor, train_Yvar: Tensor | None = None ) -> CVFolds: r"""Generate LOO CV folds w.r.t. to `n`. Args: train_X: A `n x d` or `batch_shape x n x d` (batch mode) tensor of training features. train_Y: A `n x (m)` or `batch_shape x n x (m)` (batch mode) tensor of training observations. train_Yvar: An `n x (m)` or `batch_shape x n x (m)` (batch mode) tensor of observed measurement noise. Returns: CVFolds NamedTuple with the following fields: - train_X: A `n x (n-1) x d` or `batch_shape x n x (n-1) x d` tensor of training features. - test_X: A `n x 1 x d` or `batch_shape x n x 1 x d` tensor of test features. - train_Y: A `n x (n-1) x m` or `batch_shape x n x (n-1) x m` tensor of training observations. - test_Y: A `n x 1 x m` or `batch_shape x n x 1 x m` tensor of test observations. - train_Yvar: A `n x (n-1) x m` or `batch_shape x n x (n-1) x m` tensor of observed measurement noise. - test_Yvar: A `n x 1 x m` or `batch_shape x n x 1 x m` tensor of observed measurement noise. Example: >>> train_X = torch.rand(10, 1) >>> train_Y = torch.rand_like(train_X) >>> cv_folds = gen_loo_cv_folds(train_X, train_Y) >>> cv_folds.train_X.shape torch.Size([10, 9, 1]) """ masks = torch.eye(train_X.shape[-2], dtype=torch.uint8, device=train_X.device) masks = masks.to(dtype=torch.bool) if train_Y.dim() < train_X.dim(): # add output dimension train_Y = train_Y.unsqueeze(-1) if train_Yvar is not None: train_Yvar = train_Yvar.unsqueeze(-1) train_X_cv = torch.cat( [train_X[..., ~m, :].unsqueeze(dim=-3) for m in masks], dim=-3 ) test_X_cv = torch.cat([train_X[..., m, :].unsqueeze(dim=-3) for m in masks], dim=-3) train_Y_cv = torch.cat( [train_Y[..., ~m, :].unsqueeze(dim=-3) for m in masks], dim=-3 ) test_Y_cv = torch.cat([train_Y[..., m, :].unsqueeze(dim=-3) for m in masks], dim=-3) if train_Yvar is None: train_Yvar_cv = None test_Yvar_cv = None else: train_Yvar_cv = torch.cat( [train_Yvar[..., ~m, :].unsqueeze(dim=-3) for m in masks], dim=-3 ) test_Yvar_cv = torch.cat( [train_Yvar[..., m, :].unsqueeze(dim=-3) for m in masks], dim=-3 ) return CVFolds( train_X=train_X_cv, test_X=test_X_cv, train_Y=train_Y_cv, test_Y=test_Y_cv, train_Yvar=train_Yvar_cv, test_Yvar=test_Yvar_cv, )
[docs] def batch_cross_validation( model_cls: type[GPyTorchModel], mll_cls: type[MarginalLogLikelihood], cv_folds: CVFolds, fit_args: dict[str, Any] | None = None, observation_noise: bool = False, model_init_kwargs: dict[str, Any] | None = None, ) -> CVResults: r"""Perform cross validation by using GPyTorch batch mode. WARNING: This function is currently very memory inefficient; use it only for problems of small size. Args: model_cls: A GPyTorchModel class. This class must initialize the likelihood internally. Note: Multi-task GPs are not currently supported. mll_cls: A MarginalLogLikelihood class. cv_folds: A CVFolds tuple. fit_args: Arguments passed along to fit_gpytorch_mll. model_init_kwargs: Keyword arguments passed to the model constructor. Returns: A CVResults tuple with the following fields - model: GPyTorchModel for batched cross validation - posterior: GPyTorchPosterior where the mean has shape `n x 1 x m` or `batch_shape x n x 1 x m` - observed_Y: A `n x 1 x m` or `batch_shape x n x 1 x m` tensor of observations. - observed_Yvar: A `n x 1 x m` or `batch_shape x n x 1 x m` tensor of observed measurement noise. Example: >>> import torch >>> from botorch.cross_validation import ( ... batch_cross_validation, gen_loo_cv_folds ... ) >>> >>> from botorch.models import SingleTaskGP >>> from botorch.models.transforms.input import Normalize >>> from botorch.models.transforms.outcome import Standardize >>> from gpytorch.mlls import ExactMarginalLogLikelihood >>> train_X = torch.rand(10, 1) >>> train_Y = torch.rand_like(train_X) >>> cv_folds = gen_loo_cv_folds(train_X, train_Y) >>> input_transform = Normalize(d=train_X.shape[-1]) >>> outcome_transform = Standardize( ... m=train_Y.shape[-1], batch_shape=cv_folds.train_Y.shape[:-2] ... ) >>> >>> cv_results = batch_cross_validation( ... model_cls=SingleTaskGP, ... mll_cls=ExactMarginalLogLikelihood, ... cv_folds=cv_folds, ... model_init_kwargs={ ... "input_transform": input_transform, ... "outcome_transform": outcome_transform, ... }, ... ) """ if issubclass(model_cls, MultiTaskGP): raise UnsupportedError( "Multi-task GPs are not currently supported by `batch_cross_validation`." ) model_init_kws = model_init_kwargs if model_init_kwargs is not None else {} if cv_folds.train_Yvar is not None: model_init_kws["train_Yvar"] = cv_folds.train_Yvar model_cv = model_cls( train_X=cv_folds.train_X, train_Y=cv_folds.train_Y, **model_init_kws, ) mll_cv = mll_cls(model_cv.likelihood, model_cv) mll_cv.to(cv_folds.train_X) fit_args = fit_args or {} mll_cv = fit_gpytorch_mll(mll_cv, **fit_args) # Evaluate on the hold-out set in batch mode with torch.no_grad(): posterior = model_cv.posterior( cv_folds.test_X, observation_noise=observation_noise ) return CVResults( model=model_cv, posterior=posterior, observed_Y=cv_folds.test_Y, observed_Yvar=cv_folds.test_Yvar, )