#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
.. [rahimi2007random]
A. Rahimi and B. Recht. Random features for large-scale kernel machines.
Advances in Neural Information Processing Systems 20 (2007).
.. [sutherland2015error]
D. J. Sutherland and J. Schneider. On the error of random Fourier features.
arXiv preprint arXiv:1506.02785 (2015).
"""
from __future__ import annotations
from typing import Any, Callable
import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.sampling.pathwise.features.maps import KernelFeatureMap
from botorch.sampling.pathwise.utils import (
ChainedTransform,
FeatureSelector,
InverseLengthscaleTransform,
OutputscaleTransform,
SineCosineTransform,
)
from botorch.utils.dispatcher import Dispatcher
from botorch.utils.sampling import draw_sobol_normal_samples
from gpytorch import kernels
from gpytorch.kernels.kernel import Kernel
from torch import Size, Tensor
from torch.distributions import Gamma
TKernelFeatureMapGenerator = Callable[[Kernel, int, int], KernelFeatureMap]
GenKernelFeatures = Dispatcher("gen_kernel_features")
[docs]def gen_kernel_features(
kernel: kernels.Kernel,
num_inputs: int,
num_outputs: int,
**kwargs: Any,
) -> KernelFeatureMap:
r"""Generates a feature map :math:`\phi: \mathcal{X} \to \mathbb{R}^{n}` such that
:math:`k(x, x') ≈ \phi(x)^{T} \phi(x')`. For stationary kernels :math:`k`, defaults
to the method of random Fourier features. For more details, see [rahimi2007random]_
and [sutherland2015error]_.
Args:
kernel: The kernel :math:`k` to be represented via a finite-dim basis.
num_inputs: The number of input features.
num_outputs: The number of kernel features.
"""
return GenKernelFeatures(
kernel,
num_inputs=num_inputs,
num_outputs=num_outputs,
**kwargs,
)
def _gen_fourier_features(
kernel: kernels.Kernel,
weight_generator: Callable[[Size], Tensor],
num_inputs: int,
num_outputs: int,
) -> KernelFeatureMap:
r"""Generate a feature map :math:`\phi: \mathcal{X} \to \mathbb{R}^{2l}` that
approximates a stationary kernel so that :math:`k(x, x') ≈ \phi(x)^\top \phi(x')`.
Following [sutherland2015error]_, we represent complex exponentials by pairs of
basis functions :math:`\phi_{i}(x) = \sin(x^\top w_{i})` and
:math:`\phi_{i + l} = \cos(x^\top w_{i}).
Args:
kernel: A stationary kernel :math:`k(x, x') = k(x - x')`.
weight_generator: A callable used to generate weight vectors :math:`w`.
num_inputs: The number of input features.
num_outputs: The number of Fourier features.
"""
if num_outputs % 2:
raise UnsupportedError(
f"Expected an even number of output features, but received {num_outputs=}."
)
input_transform = InverseLengthscaleTransform(kernel)
if kernel.active_dims is not None:
num_inputs = len(kernel.active_dims)
input_transform = ChainedTransform(
input_transform, FeatureSelector(indices=kernel.active_dims)
)
weight = weight_generator(
Size([kernel.batch_shape.numel() * num_outputs // 2, num_inputs])
).reshape(*kernel.batch_shape, num_outputs // 2, num_inputs)
output_transform = SineCosineTransform(
torch.tensor((2 / num_outputs) ** 0.5, device=kernel.device, dtype=kernel.dtype)
)
return KernelFeatureMap(
kernel=kernel,
weight=weight,
input_transform=input_transform,
output_transform=output_transform,
)
@GenKernelFeatures.register(kernels.RBFKernel)
def _gen_kernel_features_rbf(
kernel: kernels.RBFKernel,
*,
num_inputs: int,
num_outputs: int,
) -> KernelFeatureMap:
def _weight_generator(shape: Size) -> Tensor:
try:
n, d = shape
except ValueError:
raise UnsupportedError(
f"Expected `shape` to be 2-dimensional, but {len(shape)=}."
)
return draw_sobol_normal_samples(
n=n,
d=d,
device=kernel.lengthscale.device,
dtype=kernel.lengthscale.dtype,
)
return _gen_fourier_features(
kernel=kernel,
weight_generator=_weight_generator,
num_inputs=num_inputs,
num_outputs=num_outputs,
)
@GenKernelFeatures.register(kernels.MaternKernel)
def _gen_kernel_features_matern(
kernel: kernels.MaternKernel,
*,
num_inputs: int,
num_outputs: int,
) -> KernelFeatureMap:
def _weight_generator(shape: Size) -> Tensor:
try:
n, d = shape
except ValueError:
raise UnsupportedError(
f"Expected `shape` to be 2-dimensional, but {len(shape)=}."
)
dtype = kernel.lengthscale.dtype
device = kernel.lengthscale.device
nu = torch.tensor(kernel.nu, device=device, dtype=dtype)
normals = draw_sobol_normal_samples(n=n, d=d, device=device, dtype=dtype)
return Gamma(nu, nu).rsample((n, 1)).rsqrt() * normals
return _gen_fourier_features(
kernel=kernel,
weight_generator=_weight_generator,
num_inputs=num_inputs,
num_outputs=num_outputs,
)
@GenKernelFeatures.register(kernels.ScaleKernel)
def _gen_kernel_features_scale(
kernel: kernels.ScaleKernel,
*,
num_inputs: int,
num_outputs: int,
) -> KernelFeatureMap:
active_dims = kernel.active_dims
feature_map = gen_kernel_features(
kernel.base_kernel,
num_inputs=num_inputs if active_dims is None else len(active_dims),
num_outputs=num_outputs,
)
if active_dims is not None and active_dims is not kernel.base_kernel.active_dims:
feature_map.input_transform = ChainedTransform(
feature_map.input_transform, FeatureSelector(indices=active_dims)
)
feature_map.output_transform = ChainedTransform(
OutputscaleTransform(kernel), feature_map.output_transform
)
return feature_map