# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import List, Optional, Tuple
import numpy
import torch
from botorch.exceptions.errors import UnsupportedError
from gpytorch.constraints import Interval, Positive
from gpytorch.kernels import Kernel
from torch import nn, Tensor
# Default `coeff_constraint` of `OrthogonalAdditiveKernel`.
# NOTE: this single module-level instance is used as a default argument value, so it
# is shared by every kernel constructed without an explicit `coeff_constraint`.
_positivity_constraint = Positive()
class OrthogonalAdditiveKernel(Kernel):
    r"""Orthogonal Additive Kernels (OAKs) were introduced in [Lu2022additive]_, though
    only for the case of Gaussian base kernels with a Gaussian input data distribution.

    The implementation here generalizes OAKs to arbitrary base kernels by using a
    Gauss-Legendre quadrature approximation to the required one-dimensional integrals
    involving the base kernels.

    .. [Lu2022additive]
        X. Lu, A. Boukouvalas, and J. Hensman. Additive Gaussian processes revisited.
        Proceedings of the 39th International Conference on Machine Learning. Jul 2022.
    """

    def __init__(
        self,
        base_kernel: Kernel,
        dim: int,
        quad_deg: int = 32,
        second_order: bool = False,
        batch_shape: Optional[torch.Size] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        coeff_constraint: Interval = _positivity_constraint,
    ):
        """
        Args:
            base_kernel: The kernel which to orthogonalize and evaluate in `forward`.
            dim: Input dimensionality of the kernel.
            quad_deg: Number of integration nodes for orthogonalization.
            second_order: Toggles second order interactions. If true, both the time and
                space complexity of evaluating the kernel are quadratic in `dim`.
            batch_shape: Optional batch shape for the kernel and its parameters.
                Defaults to an empty batch shape.
            dtype: Initialization dtype for required Tensors.
            device: Initialization device for required Tensors.
            coeff_constraint: Constraint on the coefficients of the additive kernel.
        """
        # Normalize up front: `batch_shape` is used below to size the parameters and
        # helper tensors, and `len(None)` / `expand(*None, 1)` would raise.
        batch_shape = torch.Size([]) if batch_shape is None else torch.Size(batch_shape)
        super().__init__(batch_shape=batch_shape)
        self.base_kernel = base_kernel
        # integration nodes, weights for [0, 1]
        tkwargs = {"dtype": dtype, "device": device}
        z, w = leggauss(deg=quad_deg, a=0, b=1, **tkwargs)
        self.z = z.unsqueeze(-1).expand(quad_deg, dim)  # deg x dim
        self.w = w.unsqueeze(-1)  # deg x 1
        self.register_parameter(
            name="raw_offset",
            parameter=nn.Parameter(torch.zeros(self.batch_shape, **tkwargs)),
        )
        # raw coefficients start at -log(dim) (first order) and -2 log(dim) (second
        # order) before being passed through `coeff_constraint`
        log_d = math.log(dim)
        self.register_parameter(
            name="raw_coeffs_1",
            parameter=nn.Parameter(
                torch.zeros(*self.batch_shape, dim, **tkwargs) - log_d
            ),
        )
        num_pairs = dim * (dim - 1) // 2  # number of strictly upper-triangular entries
        self.register_parameter(
            name="raw_coeffs_2",
            parameter=nn.Parameter(
                torch.zeros(*self.batch_shape, num_pairs, **tkwargs) - 2 * log_d
            )
            if second_order
            else None,
        )
        if second_order:
            self._rev_triu_indices = torch.tensor(
                _reverse_triu_indices(dim),
                device=device,
                dtype=torch.long,
            )
            # zero entry appended to the raw second-order coefficients, which
            # `coeffs_2` scatters into the diagonal and lower triangle
            self._quad_zero = torch.zeros(*self.batch_shape, 1, **tkwargs)
        self.coeff_constraint = coeff_constraint
        self.dim = dim
        # cached result of `normalizer`; see that method for when it is reused
        self._normalizer: Optional[Tensor] = None

    def train(self, mode: bool = True) -> "OrthogonalAdditiveKernel":
        """Sets the training mode of the kernel and drops the cached normalizer.

        The cached normalizer may have been computed with hyperparameters that were
        changed afterwards (e.g. by an optimizer step following the last forward
        pass), so it is invalidated whenever the mode is set.

        Args:
            mode: Whether to set training mode (True) or evaluation mode (False).

        Returns:
            The kernel itself.
        """
        self._normalizer = None
        return super().train(mode)

    def k(self, x1, x2) -> Tensor:
        """Evaluates the kernel matrix base_kernel(x1, x2) on each input dimension
        independently.

        Args:
            x1: `batch_shape x n1 x d`-dim Tensor in [0, 1]^dim.
            x2: `batch_shape x n2 x d`-dim Tensor in [0, 1]^dim.

        Returns:
            A `batch_shape x d x n1 x n2`-dim Tensor of kernel matrices.
        """
        return self.base_kernel(x1, x2, last_dim_is_batch=True).to_dense()

    @property
    def offset(self) -> Tensor:
        """Returns the `batch_shape`-dim Tensor of zeroth-order coefficients."""
        return self.coeff_constraint.transform(self.raw_offset)

    @property
    def coeffs_1(self) -> Tensor:
        """Returns the `batch_shape x d`-dim Tensor of first-order coefficients."""
        return self.coeff_constraint.transform(self.raw_coeffs_1)

    @property
    def coeffs_2(self) -> Optional[Tensor]:
        """Returns the upper-triangular tensor of second-order coefficients.

        NOTE: We only keep track of the upper triangular part of raw second order
        coefficients since the effect of the lower triangular part is identical and
        exclude the diagonal, since it is associated with first-order effects only.
        While we could further exploit this structure in the forward pass, the
        associated indexing and temporary allocations make it significantly less
        efficient than the einsum-based implementation in `forward`.

        Returns:
            `batch_shape x d x d`-dim Tensor of second-order coefficients, or None if
            the kernel was constructed with `second_order=False`.
        """
        if self.raw_coeffs_2 is None:
            return None
        C2 = self.coeff_constraint.transform(self.raw_coeffs_2)
        C2 = torch.cat((C2, self._quad_zero), dim=-1)  # batch_shape x (d(d-1)/2+1)
        C2 = C2.index_select(-1, self._rev_triu_indices)  # batch_shape x d^2
        return C2.reshape(*self.batch_shape, self.dim, self.dim)

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
    ) -> Tensor:
        """Computes the kernel matrix k(x1, x2).

        Args:
            x1: `batch_shape x n1 x d`-dim Tensor in [0, 1]^dim.
            x2: `batch_shape x n2 x d`-dim Tensor in [0, 1]^dim.
            diag: If True, only returns the diagonal of the kernel matrix.
            last_dim_is_batch: Not supported by this kernel.

        Returns:
            A `batch_shape x n1 x n2`-dim Tensor of kernel matrices, or, if `diag` is
            True, the Tensor of their diagonals over the last two dimensions.

        Raises:
            UnsupportedError: If `last_dim_is_batch` is True.
        """
        if last_dim_is_batch:
            raise UnsupportedError(
                "OrthogonalAdditiveKernel does not support `last_dim_is_batch`."
            )
        K_ortho = self._orthogonal_base_kernels(x1, x2)  # batch_shape x d x n1 x n2
        # contracting over d, leading to `batch_shape x n x n`-dim tensor, i.e.:
        # K1 = torch.sum(self.coeffs_1[..., None, None] * K_ortho, dim=-3)
        K1 = torch.einsum(self.coeffs_1, [..., 0], K_ortho, [..., 0, 1, 2], [..., 1, 2])
        # adding the non-batch dimensions to offset
        K = K1 + self.offset[..., None, None]
        # evaluate the property once; it re-runs the transform and indexing each time
        coeffs_2 = self.coeffs_2
        if coeffs_2 is not None:
            # Computing the tensor of second order interactions K2.
            # NOTE: K2 here is equivalent to:
            # K2 = K_ortho.unsqueeze(-4) * K_ortho.unsqueeze(-3)  # d x d x n x n
            # K2 = (coeffs_2[..., None, None] * K2).sum(dim=(-4, -3))
            # but avoids forming the `batch_shape x d x d x n x n`-dim tensor in memory.
            # Reducing over the dimensions with the O(d^2) quadratic terms:
            K2 = torch.einsum(
                K_ortho,
                [..., 0, 2, 3],
                K_ortho,
                [..., 1, 2, 3],
                coeffs_2,
                [..., 0, 1],
                [..., 2, 3],  # i.e. contracting over the first two non-batch dims
            )
            K = K + K2
        # `Tensor.diag` only accepts 1- or 2-dim inputs; taking the diagonal over the
        # last two dims also supports batched kernel matrices.
        return K.diagonal(dim1=-2, dim2=-1) if diag else K

    def _orthogonal_base_kernels(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Evaluates the set of `d` orthogonalized base kernels on (x1, x2).

        Note that even if the base kernel is positive, the orthogonalized versions
        can - and usually do - take negative values.

        Args:
            x1: `batch_shape x n1 x d`-dim inputs to the kernel.
            x2: `batch_shape x n2 x d`-dim inputs to the kernel.

        Returns:
            A `batch_shape x d x n1 x n2`-dim Tensor.

        Raises:
            ValueError: If `x1` or `x2` has entries outside of [0, 1].
        """
        _check_hypercube(x1, "x1")
        if x1 is not x2:
            _check_hypercube(x2, "x2")
        Kx1x2 = self.k(x1, x2)  # batch_shape x d x n1 x n2
        # NOTE: `self.z` and `self.w` are plain tensor attributes rather than
        # registered buffers, so `Module.to` does not move or cast them; they keep
        # the dtype and device given at construction.
        # include normalization constant in weights
        w = self.w / self.normalizer().sqrt()  # batch_shape x d x deg x 1
        Skx1 = self.k(x1, self.z) @ w  # batch_shape x d x n1 x 1
        # reuse the x1 projection when both inputs are the same object
        Skx2 = Skx1 if (x1 is x2) else self.k(x2, self.z) @ w  # ... x d x n2 x 1
        # this is a tensor of kernel matrices of orthogonal 1d kernels
        K_ortho = (Kx1x2 - Skx1 @ Skx2.transpose(-2, -1)).to_dense()
        return K_ortho

    def normalizer(self, eps: float = 1e-6) -> Tensor:
        """Integrates the `d` orthogonalized base kernels over `[0, 1] x [0, 1]`.

        NOTE: The value is recomputed whenever the module is in train mode (the
        underlying parameters might have changed) or gradient tracking is enabled (a
        cached tensor attached to an autograd graph cannot be back-propagated through
        more than once). Otherwise, the value cached by a previous call is reused; the
        cache is cleared by `train` / `eval`.

        Args:
            eps: Minimum value constraint on the normalizers. Avoids division by zero.

        Returns:
            A `batch_shape x d x 1 x 1`-dim tensor of normalization constants.
        """
        cached = getattr(self, "_normalizer", None)
        # NOTE: `self.training` (attribute) rather than `self.train()`, which would
        # switch the module into train mode and always be truthy.
        if cached is None or self.training or torch.is_grad_enabled():
            self._normalizer = (self.w.T @ self.k(self.z, self.z) @ self.w).clamp(eps)
        return self._normalizer
def leggauss(
    deg: int,
    a: float = -1.0,
    b: float = 1.0,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
    """Computes Gauss-Legendre quadrature nodes and weights. Wraps
    `numpy.polynomial.legendre.leggauss` and returns Torch Tensors.

    Args:
        deg: Number of sample points and weights. The resulting rule integrates
            polynomials of degree up to `2 * deg - 1` exactly.
        a, b: Lower and upper bound of integration domain.
        dtype: Desired floating point type of the return Tensors. Defaults to
            `torch.get_default_dtype()`.
        device: Desired device type of the return Tensors.

    Returns:
        A tuple of Gauss-Legendre quadrature nodes and weights of length deg.
    """
    dtype = dtype if dtype is not None else torch.get_default_dtype()
    x, w = numpy.polynomial.legendre.leggauss(deg=deg)
    x = torch.as_tensor(x, dtype=dtype, device=device)
    w = torch.as_tensor(w, dtype=dtype, device=device)
    if not (a == -1 and b == 1):  # need to normalize for different domain
        # affine map from the reference interval [-1, 1] onto [a, b]; the weights
        # pick up the Jacobian (b - a) / 2 of that map
        x = (b - a) * (x + 1) / 2 + a
        w = w * ((b - a) / 2)
    return x, w
def _check_hypercube(x: Tensor, name: str) -> None:
"""Raises a `ValueError` if an element `x` is not in [0, 1].
Args:
x: Tensor to be checked.
name: Name of the Tensor for the error message.
"""
if (x < 0).any() or (x > 1).any():
raise ValueError(name + " is not in hypercube [0, 1]^d.")
def _reverse_triu_indices(d: int) -> List[int]:
"""Computes a list of indices which, upon indexing a `d * (d - 1) / 2 + 1`-dim
Tensor whose last element is zero, will lead to a vectorized representation of
an upper-triangular matrix, whose diagonal is set to zero and whose super-diagonal
elements are set to the `d * (d - 1) / 2` values in the original tensor.
NOTE: This is a helper function for Orthogonal Additive Kernels, and allows the
implementation to only register `d * (d - 1) / 2` parameters to model the second
order interactions, instead of the full d^2 redundant terms.
Args:
d: Dimensionality that gives rise to the `d * (d - 1) / 2` quadratic terms.
Returns:
A list of integer indices in `[0, d * (d - 1) / 2]`. See above for details.
"""
indices = []
j = 0
d2 = int(d * (d - 1) / 2)
for i in range(d):
indices.extend(d2 for _ in range(i + 1)) # indexing zero (sub-diagonal)
indices.extend(range(j, j + d - i - 1)) # indexing coeffs (super-diagonal)
j += d - i - 1
return indices