#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from typing import Any, Callable, Optional
from botorch.models.approximate_gp import ApproximateGPyTorchModel
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.sampling.pathwise.features import gen_kernel_features
from botorch.sampling.pathwise.features.generators import TKernelFeatureMapGenerator
from botorch.sampling.pathwise.paths import GeneralizedLinearPath, PathList, SamplePath
from botorch.sampling.pathwise.utils import (
get_input_transform,
get_output_transform,
get_train_inputs,
TInputTransform,
TOutputTransform,
)
from botorch.utils.dispatcher import Dispatcher
from botorch.utils.sampling import draw_sobol_normal_samples
from gpytorch.kernels import Kernel
from gpytorch.models import ApproximateGP, ExactGP, GP
from gpytorch.variational import _VariationalStrategy
from torch import Size, Tensor
from torch.nn import Module
TPathwisePriorSampler = Callable[[GP, Size], SamplePath]
DrawKernelFeaturePaths = Dispatcher("draw_kernel_feature_paths")
[docs]
def draw_kernel_feature_paths(
model: GP, sample_shape: Size, **kwargs: Any
) -> GeneralizedLinearPath:
r"""Draws functions from a Bayesian-linear-model-based approximation to a GP prior.
When evaluted, sample paths produced by this method return Tensors with dimensions
`sample_dims x batch_dims x [joint_dim]`, where `joint_dim` denotes the penultimate
dimension of the input tensor. For multioutput models, outputs are returned as the
final batch dimension.
Args:
model: The prior over functions.
sample_shape: The shape of the sample paths to be drawn.
"""
return DrawKernelFeaturePaths(model, sample_shape=sample_shape, **kwargs)
def _draw_kernel_feature_paths_fallback(
num_inputs: int,
mean_module: Optional[Module],
covar_module: Kernel,
sample_shape: Size,
num_features: int = 1024,
map_generator: TKernelFeatureMapGenerator = gen_kernel_features,
input_transform: Optional[TInputTransform] = None,
output_transform: Optional[TOutputTransform] = None,
weight_generator: Optional[Callable[[Size], Tensor]] = None,
) -> GeneralizedLinearPath:
# Generate a kernel feature map
feature_map = map_generator(
kernel=covar_module,
num_inputs=num_inputs,
num_outputs=num_features,
)
# Sample random weights with which to combine kernel features
if weight_generator is None:
weight = draw_sobol_normal_samples(
n=sample_shape.numel() * covar_module.batch_shape.numel(),
d=feature_map.num_outputs,
device=covar_module.device,
dtype=covar_module.dtype,
).reshape(sample_shape + covar_module.batch_shape + (feature_map.num_outputs,))
else:
weight = weight_generator(
sample_shape + covar_module.batch_shape + (feature_map.num_outputs,)
).to(device=covar_module.device, dtype=covar_module.dtype)
# Return the sample paths
return GeneralizedLinearPath(
feature_map=feature_map,
weight=weight,
bias_module=mean_module,
input_transform=input_transform,
output_transform=output_transform,
)
@DrawKernelFeaturePaths.register(ExactGP)
def _draw_kernel_feature_paths_ExactGP(
model: ExactGP, **kwargs: Any
) -> GeneralizedLinearPath:
(train_X,) = get_train_inputs(model, transformed=False)
return _draw_kernel_feature_paths_fallback(
num_inputs=train_X.shape[-1],
mean_module=model.mean_module,
covar_module=model.covar_module,
input_transform=get_input_transform(model),
output_transform=get_output_transform(model),
**kwargs,
)
@DrawKernelFeaturePaths.register(ModelListGP)
def _draw_kernel_feature_paths_list(
model: ModelListGP,
join: Optional[Callable[[list[Tensor]], Tensor]] = None,
**kwargs: Any,
) -> PathList:
paths = [draw_kernel_feature_paths(m, **kwargs) for m in model.models]
return PathList(paths=paths, join=join)
@DrawKernelFeaturePaths.register(ApproximateGPyTorchModel)
def _draw_kernel_feature_paths_ApproximateGPyTorchModel(
model: ApproximateGPyTorchModel, **kwargs: Any
) -> GeneralizedLinearPath:
(train_X,) = get_train_inputs(model, transformed=False)
return DrawKernelFeaturePaths(
model.model,
num_inputs=train_X.shape[-1],
input_transform=get_input_transform(model),
output_transform=get_output_transform(model),
**kwargs,
)
@DrawKernelFeaturePaths.register(ApproximateGP)
def _draw_kernel_feature_paths_ApproximateGP(
model: ApproximateGP, **kwargs: Any
) -> GeneralizedLinearPath:
return DrawKernelFeaturePaths(model, model.variational_strategy, **kwargs)
@DrawKernelFeaturePaths.register(ApproximateGP, _VariationalStrategy)
def _draw_kernel_feature_paths_ApproximateGP_fallback(
model: ApproximateGP,
_: _VariationalStrategy,
*,
num_inputs: int,
**kwargs: Any,
) -> GeneralizedLinearPath:
return _draw_kernel_feature_paths_fallback(
num_inputs=num_inputs,
mean_module=model.mean_module,
covar_module=model.covar_module,
**kwargs,
)