botorch.cross_validation

Cross-validation utilities using batch evaluation mode.

class botorch.cross_validation.CVFolds(train_X, test_X, train_Y, test_Y, train_Yvar, test_Yvar)

Bases: NamedTuple

Create new instance of CVFolds(train_X, test_X, train_Y, test_Y, train_Yvar, test_Yvar)

Parameters:
  • train_X (Tensor)

  • test_X (Tensor)

  • train_Y (Tensor)

  • test_Y (Tensor)

  • train_Yvar (Tensor | None)

  • test_Yvar (Tensor | None)

train_X: Tensor

Alias for field number 0

test_X: Tensor

Alias for field number 1

train_Y: Tensor

Alias for field number 2

test_Y: Tensor

Alias for field number 3

train_Yvar: Tensor | None

Alias for field number 4

test_Yvar: Tensor | None

Alias for field number 5
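Since CVFolds is a NamedTuple, every field can be read either by name or by position. "Alias for field number k" means both forms refer to the same object. The sketch below uses a hypothetical stand-in class, FoldsSketch, with the same six fields in the same order but plain Python lists in place of tensors, so it runs without torch or BoTorch installed.

```python
from typing import NamedTuple, Optional


# Hypothetical stand-in for CVFolds: same field names and order,
# but plain lists instead of torch Tensors.
class FoldsSketch(NamedTuple):
    train_X: list
    test_X: list
    train_Y: list
    test_Y: list
    train_Yvar: Optional[list]
    test_Yvar: Optional[list]


folds = FoldsSketch(
    train_X=[[0.1], [0.2]],
    test_X=[[0.3]],
    train_Y=[[1.0], [2.0]],
    test_Y=[[3.0]],
    train_Yvar=None,  # no observed noise supplied
    test_Yvar=None,
)

# Access by name and by index yields the same object.
assert folds.train_X is folds[0]
assert folds.test_Yvar is folds[5]

# Tuple unpacking follows the documented field order.
train_X, test_X, train_Y, test_Y, train_Yvar, test_Yvar = folds
```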

class botorch.cross_validation.CVResults(model, posterior, observed_Y, observed_Yvar)

Bases: NamedTuple

Create new instance of CVResults(model, posterior, observed_Y, observed_Yvar)

Parameters:
  • model (GPyTorchModel)

  • posterior (GPyTorchPosterior)

  • observed_Y (Tensor)

  • observed_Yvar (Tensor | None)

model: GPyTorchModel

Alias for field number 0

posterior: GPyTorchPosterior

Alias for field number 1

observed_Y: Tensor

Alias for field number 2

observed_Yvar: Tensor | None

Alias for field number 3
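CVResults pairs each fold's posterior with the held-out observation, so a common follow-up is to score the LOO predictions. The sketch below assumes the predictive means and variances (for instance from cv_results.posterior.mean and cv_results.posterior.variance) and cv_results.observed_Y have already been flattened to Python lists for a single outcome. loo_metrics_sketch is a hypothetical helper, not a BoTorch function, and it treats each LOO predictive distribution as Gaussian.

```python
import math
from typing import Sequence, Tuple


def loo_metrics_sketch(
    means: Sequence[float],
    variances: Sequence[float],
    observed: Sequence[float],
) -> Tuple[float, float]:
    """Hypothetical helper (not a BoTorch API).

    Returns the root-mean-squared error of the LOO predictive means and the
    mean Gaussian log predictive density of the held-out observations.
    """
    if not (len(means) == len(variances) == len(observed)):
        raise ValueError("means, variances and observed must have equal length")
    n = len(observed)
    rmse = math.sqrt(sum((y - mu) ** 2 for mu, y in zip(means, observed)) / n)
    # log N(y | mu, var) for each held-out point
    log_pd = [
        -0.5 * math.log(2 * math.pi * var) - (y - mu) ** 2 / (2 * var)
        for mu, var, y in zip(means, variances, observed)
    ]
    return rmse, sum(log_pd) / n


rmse, mlpd = loo_metrics_sketch(
    means=[1.0, 2.0], variances=[1.0, 1.0], observed=[1.0, 3.0]
)
print(round(rmse, 4))  # 0.7071
print(round(mlpd, 4))  # -1.1689
```

For the log predictive density to be meaningful against noisy observations, the variances should include observation noise; the observation_noise argument of batch_cross_validation controls whether the returned posterior does.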

botorch.cross_validation.gen_loo_cv_folds(train_X, train_Y, train_Yvar=None)

Generate leave-one-out (LOO) cross-validation folds: each of the n training points is held out exactly once.

Parameters:
  • train_X (Tensor) – An n x d or batch_shape x n x d (batch mode) tensor of training features.

  • train_Y (Tensor) – An n x (m) or batch_shape x n x (m) (batch mode) tensor of training observations.

  • train_Yvar (Tensor | None) – An n x (m) or batch_shape x n x (m) (batch mode) tensor of observed measurement noise.

Returns:

A CVFolds NamedTuple with the following fields:

  • train_X: An n x (n-1) x d or batch_shape x n x (n-1) x d tensor of training features.

  • test_X: An n x 1 x d or batch_shape x n x 1 x d tensor of test features.

  • train_Y: An n x (n-1) x m or batch_shape x n x (n-1) x m tensor of training observations.

  • test_Y: An n x 1 x m or batch_shape x n x 1 x m tensor of test observations.

  • train_Yvar: An n x (n-1) x m or batch_shape x n x (n-1) x m tensor of observed measurement noise (None if train_Yvar is not provided).

  • test_Yvar: An n x 1 x m or batch_shape x n x 1 x m tensor of observed measurement noise (None if train_Yvar is not provided).

Return type:

CVFolds

Example

>>> import torch
>>> from botorch.cross_validation import gen_loo_cv_folds
>>> train_X = torch.rand(10, 1)
>>> train_Y = torch.rand_like(train_X)
>>> cv_folds = gen_loo_cv_folds(train_X, train_Y)
>>> cv_folds.train_X.shape
torch.Size([10, 9, 1])
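The fold layout documented above can be reproduced on plain nested lists to make the indexing explicit: fold i trains on every point except point i and is evaluated on point i alone. The helper loo_folds_sketch below is hypothetical and not part of BoTorch; it only mirrors the documented shapes (n x (n-1) x d for training, n x 1 x d for testing) and makes no claim about how gen_loo_cv_folds is implemented internally. Applying the same index selection to train_Y and train_Yvar gives the corresponding output fields.

```python
from typing import List, Sequence, Tuple

Row = List[float]


def loo_folds_sketch(
    X: Sequence[Sequence[float]],
) -> Tuple[List[List[Row]], List[List[Row]]]:
    """Hypothetical helper (not part of BoTorch): split n rows into n LOO folds.

    Returns (train, test): train is laid out n x (n-1) x d and test is
    laid out n x 1 x d, matching the shapes documented for gen_loo_cv_folds.
    """
    n = len(X)
    train: List[List[Row]] = []
    test: List[List[Row]] = []
    for i in range(n):
        # Fold i trains on every row except row i ...
        train.append([list(X[j]) for j in range(n) if j != i])
        # ... and is evaluated on row i alone.
        test.append([list(X[i])])
    return train, test


X = [[0.0], [0.5], [1.0], [1.5]]  # n = 4 points, d = 1 feature
train, test = loo_folds_sketch(X)
print(len(train), len(train[0]), len(train[0][0]))  # 4 3 1
print(len(test), len(test[0]), len(test[0][0]))  # 4 1 1
```

With tensors, the same selection can be expressed with a boolean mask, for example ~torch.eye(n, dtype=torch.bool) applied to train_X expanded to n x n x d, followed by a reshape to n x (n-1) x d.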

botorch.cross_validation.batch_cross_validation(model_cls, mll_cls, cv_folds, fit_args=None, observation_noise=False, model_init_kwargs=None)

Perform cross-validation using GPyTorch batch mode: all folds are fit simultaneously as a single batched model.

Parameters:
  • model_cls (Type[GPyTorchModel]) – A GPyTorchModel class. This class must initialize the likelihood internally. Note: Multi-task GPs are not currently supported.

  • mll_cls (Type[MarginalLogLikelihood]) – A MarginalLogLikelihood class.

  • cv_folds (CVFolds) – A CVFolds tuple.

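  • fit_args (Dict[str, Any] | None) – Arguments passed along to fit_gpytorch_mll.

  • observation_noise (bool) – If True, the posterior at the test points includes observation noise (i.e., it is the posterior predictive distribution). Defaults to False.

  • model_init_kwargs (Dict[str, Any] | None) – Keyword arguments passed to the model constructor.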

Returns:

A CVResults tuple with the following fields:

  • model: GPyTorchModel for batched cross-validation.

  • posterior: GPyTorchPosterior where the mean has shape n x 1 x m or batch_shape x n x 1 x m.

  • observed_Y: An n x 1 x m or batch_shape x n x 1 x m tensor of observations.

  • observed_Yvar: An n x 1 x m or batch_shape x n x 1 x m tensor of observed measurement noise (None if the folds carry no noise observations).

Return type:

CVResults

Example

>>> import torch
>>> from botorch.cross_validation import (
...     batch_cross_validation, gen_loo_cv_folds
... )
>>>
>>> from botorch.models import SingleTaskGP
>>> from botorch.models.transforms.input import Normalize
>>> from botorch.models.transforms.outcome import Standardize
>>> from gpytorch.mlls import ExactMarginalLogLikelihood
>>> train_X = torch.rand(10, 1)
>>> train_Y = torch.rand_like(train_X)
>>> cv_folds = gen_loo_cv_folds(train_X, train_Y)
>>> input_transform = Normalize(d=train_X.shape[-1])
>>> outcome_transform = Standardize(
...     m=train_Y.shape[-1], batch_shape=cv_folds.train_Y.shape[:-2]
... )
>>>
>>> cv_results = batch_cross_validation(
...    model_cls=SingleTaskGP,
...    mll_cls=ExactMarginalLogLikelihood,
...    cv_folds=cv_folds,
...    model_init_kwargs={
...        "input_transform": input_transform,
...        "outcome_transform": outcome_transform,
...    },
... )
WARNING: This function is currently very memory-inefficient; use it only for problems of small size.
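To see why this warning matters, note that LOO in batch mode fits n models, each on n-1 points. If each batch member materializes a dense (n-1) x (n-1) training covariance matrix (an assumption about the computation, not a documented detail), that matrix alone needs n * (n-1)^2 elements, so memory grows roughly cubically in n. The back-of-the-envelope helper below, loo_kernel_bytes, is hypothetical and not part of BoTorch; it computes that lower bound for a single outcome stored as float64.

```python
def loo_kernel_bytes(n: int, bytes_per_element: int = 8) -> int:
    """Hypothetical estimate (not a BoTorch API) of one batched LOO covariance.

    Assumes n folds, each holding a dense (n-1) x (n-1) training covariance
    for a single outcome, stored as float64 (8 bytes per element). Actual
    peak memory is higher: Cholesky factors, gradients and other
    intermediate tensors come on top of this.
    """
    return n * (n - 1) ** 2 * bytes_per_element


for n in (100, 1000):
    print(n, loo_kernel_bytes(n))
# 100 7840800
# 1000 7984008000
```

That is roughly 7.8 MB for n = 100 but about 8 GB for n = 1000, consistent with reserving this function for small problems.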