# Source code for botorch.sampling.pathwise.features.maps

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Optional, Union

import torch
from botorch.sampling.pathwise.utils import (
    TInputTransform,
    TOutputTransform,
    TransformedModuleMixin,
)
from gpytorch.kernels import Kernel
from linear_operator.operators import LinearOperator
from torch import Size, Tensor
from torch.nn import Module


class FeatureMap(TransformedModuleMixin, Module):
    r"""Abstract base for feature maps with optional input/output transforms.

    Only the attributes below are declared here (annotated, not assigned);
    concrete subclasses are responsible for providing them.
    """

    # Size of the last dimension of the map's (transformed) output.
    num_outputs: int
    # Batch shape associated with the feature map.
    batch_shape: Size
    # Optional transform for inputs (presumably applied by TransformedModuleMixin).
    input_transform: Optional[TInputTransform]
    # Optional transform for outputs (presumably applied by TransformedModuleMixin).
    output_transform: Optional[TOutputTransform]
class KernelEvaluationMap(FeatureMap):
    r"""A feature map defined by centering a kernel at a set of points."""

    def __init__(
        self,
        kernel: Kernel,
        points: Tensor,
        input_transform: Optional[TInputTransform] = None,
        output_transform: Optional[TOutputTransform] = None,
    ) -> None:
        r"""Initializes a KernelEvaluationMap instance:

        .. code-block:: text

            feature_map(x) = output_transform(kernel(input_transform(x), points)).

        Args:
            kernel: The kernel :math:`k` used to define the feature map.
            points: A tensor passed as the kernel's second argument. Its leading
                batch dimensions ``points.shape[:-2]`` must broadcast with
                ``kernel.batch_shape``; each row ``points[..., i, :]`` is one
                kernel center and hence yields one output feature.
            input_transform: An optional input transform for the module.
            output_transform: An optional output transform for the module.

        Raises:
            RuntimeError: If ``points.shape[:-2]`` is not broadcastable with
                ``kernel.batch_shape``.
        """
        try:
            torch.broadcast_shapes(points.shape[:-2], kernel.batch_shape)
        except RuntimeError as e:
            # Re-raise with a clearer message but keep the original cause.
            raise RuntimeError(
                f"Shape mismatch: {points.shape=}, but {kernel.batch_shape=}."
            ) from e

        super().__init__()
        self.kernel = kernel
        self.points = points
        self.input_transform = input_transform
        self.output_transform = output_transform

    def forward(self, x: Tensor) -> Union[Tensor, LinearOperator]:
        r"""Evaluates the kernel between ``x`` and the stored points.

        No transforms are applied here; presumably ``TransformedModuleMixin``
        applies ``input_transform``/``output_transform`` around this call
        (TODO confirm).
        """
        return self.kernel(x, self.points)

    @property
    def num_outputs(self) -> int:
        r"""Size of the last dimension of the feature map's output.

        ``kernel(x, points)`` produces one column per row of ``points``, so the
        raw feature count is ``points.shape[-2]`` (the number of centers), not
        the input dimension ``points.shape[-1]``.
        """
        num_points = self.points.shape[-2]
        if self.output_transform is None:
            return num_points

        # Push a single dummy row of raw features through the output transform
        # to discover how it changes the feature dimension. Only the resulting
        # shape is read, so the uninitialized values from `torch.empty` are fine.
        canary = torch.empty(
            1, num_points, device=self.points.device, dtype=self.points.dtype
        )
        return self.output_transform(canary).shape[-1]

    @property
    def batch_shape(self) -> Size:
        r"""Batch shape of the underlying kernel."""
        return self.kernel.batch_shape
class KernelFeatureMap(FeatureMap):
    r"""Representation of a kernel :math:`k: \mathcal{X}^2 \to \mathbb{R}` as an
    n-dimensional feature map :math:`\phi: \mathcal{X} \to \mathbb{R}^n` satisfying:
    :math:`k(x, x') ≈ \phi(x)^\top \phi(x')`.
    """

    def __init__(
        self,
        kernel: Kernel,
        weight: Tensor,
        bias: Optional[Tensor] = None,
        input_transform: Optional[TInputTransform] = None,
        output_transform: Optional[TOutputTransform] = None,
    ) -> None:
        r"""Initializes a KernelFeatureMap instance:

        .. code-block:: text

            feature_map(x) = output_transform(input_transform(x) @ weight^{T} + bias).

        Args:
            kernel: The kernel :math:`k` used to define the feature map.
            weight: A tensor of weights used to linearly combine the module's
                inputs. Its last two dimensions are read as
                ``(num_features, input_dim)``.
            bias: A tensor of biases to be added to the linearly combined inputs.
            input_transform: An optional input transform for the module.
            output_transform: An optional output transform for the module.
        """
        super().__init__()
        self.kernel = kernel
        self.weight = weight
        self.bias = bias
        self.input_transform = input_transform
        self.output_transform = output_transform

    def forward(self, x: Tensor) -> Tensor:
        r"""Computes ``x @ weight^T`` plus ``bias`` (when one is set).

        No transforms are applied here; presumably ``TransformedModuleMixin``
        applies ``input_transform``/``output_transform`` around this call
        (TODO confirm).
        """
        out = x @ self.weight.transpose(-2, -1)
        return out if self.bias is None else out + self.bias

    @property
    def num_outputs(self) -> int:
        r"""Size of the last dimension of the feature map's output.

        Without an output transform this is ``weight.shape[-2]``, the number of
        rows of ``weight`` (one per feature).
        """
        num_features = self.weight.shape[-2]
        if self.output_transform is None:
            return num_features

        # Probe the output transform with a `(1, num_features)` dummy, matching
        # the `(n, num_features)` layout `forward` produces for a 2-D input, so
        # transforms that rely on a row dimension see a realistic rank. Only the
        # shape is read; the uninitialized `torch.empty` values are irrelevant.
        canary = torch.empty(
            1, num_features, device=self.weight.device, dtype=self.weight.dtype
        )
        return self.output_transform(canary).shape[-1]

    @property
    def batch_shape(self) -> Size:
        r"""Batch shape of the underlying kernel."""
        return self.kernel.batch_shape