# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import numpy
import torch
from botorch.exceptions.errors import UnsupportedError
from gpytorch.constraints import Interval, Positive
from gpytorch.kernels import Kernel
from gpytorch.module import Module
from gpytorch.priors import Prior
from torch import nn, Tensor
_positivity_constraint = Positive()
SECOND_ORDER_PRIOR_ERROR_MSG = (
"Second order interactions are disabled, but there is a prior on the second order "
"coefficients. Please remove the second order prior or enable second order terms."
)
class OrthogonalAdditiveKernel(Kernel):
r"""Orthogonal Additive Kernels (OAKs) were introduced in [Lu2022additive]_, though
only for the case of Gaussian base kernels with a Gaussian input data distribution.
The implementation here generalizes OAKs to arbitrary base kernels by using a
Gauss-Legendre quadrature approximation to the required one-dimensional integrals
involving the base kernels.
.. [Lu2022additive]
X. Lu, A. Boukouvalas, and J. Hensman. Additive Gaussian processes revisited.
Proceedings of the 39th International Conference on Machine Learning. Jul 2022.
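
    Denoting the orthogonalized one-dimensional base kernels by `k_i`, the kernel
    computed in `forward` is `k(x, x') = c_0 + sum_i c_i k_i(x_i, x'_i)
    + sum_{i<j} c_ij k_i(x_i, x'_i) k_j(x_j, x'_j)`, with non-negative
    coefficients; the second sum is only present if `second_order=True`.

    Example:
        A minimal usage sketch (the RBF base kernel and the random inputs below
        are illustrative choices, not requirements):

        >>> from gpytorch.kernels import RBFKernel
        >>> oak = OrthogonalAdditiveKernel(RBFKernel(), dim=3, second_order=True)
        >>> X = torch.rand(8, 3)  # inputs must lie in the unit cube [0, 1]^3
        >>> K = oak(X, X).to_dense()  # an 8 x 8 kernel matrix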
"""
def __init__(
self,
base_kernel: Kernel,
dim: int,
quad_deg: int = 32,
second_order: bool = False,
batch_shape: torch.Size | None = None,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
coeff_constraint: Interval = _positivity_constraint,
offset_prior: Prior | None = None,
coeffs_1_prior: Prior | None = None,
coeffs_2_prior: Prior | None = None,
):
"""
Args:
            base_kernel: The kernel to orthogonalize and evaluate in `forward`.
dim: Input dimensionality of the kernel.
quad_deg: Number of integration nodes for orthogonalization.
second_order: Toggles second order interactions. If true, both the time and
space complexity of evaluating the kernel are quadratic in `dim`.
batch_shape: Optional batch shape for the kernel and its parameters.
dtype: Initialization dtype for required Tensors.
device: Initialization device for required Tensors.
coeff_constraint: Constraint on the coefficients of the additive kernel.
            offset_prior: Prior on the offset coefficient. Should be a prior with
                non-negative support.
            coeffs_1_prior: Prior on the parameter main effects. Should be a prior
                with non-negative support.
            coeffs_2_prior: Prior on the parameter interactions. Should be a prior
                with non-negative support.
"""
super().__init__(batch_shape=batch_shape)
self.base_kernel = base_kernel
if not second_order and coeffs_2_prior is not None:
raise AttributeError(SECOND_ORDER_PRIOR_ERROR_MSG)
# integration nodes, weights for [0, 1]
tkwargs = {"dtype": dtype, "device": device}
z, w = leggauss(deg=quad_deg, a=0, b=1, **tkwargs)
self.z = z.unsqueeze(-1).expand(quad_deg, dim) # deg x dim
self.w = w.unsqueeze(-1)
self.register_parameter(
name="raw_offset",
parameter=nn.Parameter(torch.zeros(self.batch_shape, **tkwargs)),
)
log_d = math.log(dim)
self.register_parameter(
name="raw_coeffs_1",
parameter=nn.Parameter(
torch.zeros(*self.batch_shape, dim, **tkwargs) - log_d
),
)
self.register_parameter(
name="raw_coeffs_2",
parameter=(
nn.Parameter(
torch.zeros(*self.batch_shape, int(dim * (dim - 1) / 2), **tkwargs)
- 2 * log_d
)
if second_order
else None
),
)
if offset_prior is not None:
self.register_prior(
name="offset_prior",
prior=offset_prior,
param_or_closure=_offset_param,
setting_closure=_offset_closure,
)
if coeffs_1_prior is not None:
self.register_prior(
name="coeffs_1_prior",
prior=coeffs_1_prior,
param_or_closure=_coeffs_1_param,
setting_closure=_coeffs_1_closure,
)
if coeffs_2_prior is not None:
self.register_prior(
name="coeffs_2_prior",
prior=coeffs_2_prior,
param_or_closure=_coeffs_2_param,
setting_closure=_coeffs_2_closure,
)
        # For second order interactions, we only register the strictly
        # upper-triangular coefficients; see `coeffs_2` below.
if second_order:
self._rev_triu_indices = torch.tensor(
_reverse_triu_indices(dim),
device=device,
dtype=int,
)
# zero tensor for construction of upper-triangular coefficient matrix
self._quad_zero = torch.zeros(
tuple(1 for _ in range(len(self.batch_shape) + 1)), **tkwargs
).expand(*self.batch_shape, 1)
self.coeff_constraint = coeff_constraint
self.dim = dim
def k(self, x1: Tensor, x2: Tensor) -> Tensor:
"""Evaluates the kernel matrix base_kernel(x1, x2) on each input dimension
independently.
Args:
x1: `batch_shape x n1 x d`-dim Tensor in [0, 1]^dim.
x2: `batch_shape x n2 x d`-dim Tensor in [0, 1]^dim.
Returns:
A `batch_shape x d x n1 x n2`-dim Tensor of kernel matrices.
"""
return self.base_kernel(x1, x2, last_dim_is_batch=True).to_dense()
@property
def offset(self) -> Tensor:
"""Returns the `batch_shape`-dim Tensor of zeroth-order coefficients."""
return self.coeff_constraint.transform(self.raw_offset)
@property
def coeffs_1(self) -> Tensor:
"""Returns the `batch_shape x d`-dim Tensor of first-order coefficients."""
return self.coeff_constraint.transform(self.raw_coeffs_1)
@property
def coeffs_2(self) -> Tensor | None:
"""Returns the upper-triangular tensor of second-order coefficients.
        NOTE: We only keep track of the strictly upper-triangular part of the raw
        second-order coefficients, since the effect of the lower-triangular part is
        identical, and we exclude the diagonal, since it is associated with
        first-order effects only.
While we could further exploit this structure in the forward pass, the
associated indexing and temporary allocations make it significantly less
efficient than the einsum-based implementation below.
Returns:
`batch_shape x d x d`-dim Tensor of second-order coefficients.
"""
if self.raw_coeffs_2 is not None:
C2 = self.coeff_constraint.transform(self.raw_coeffs_2)
C2 = torch.cat((C2, self._quad_zero), dim=-1) # batch_shape x (d(d-1)/2+1)
C2 = C2.index_select(-1, self._rev_triu_indices)
return C2.reshape(*self.batch_shape, self.dim, self.dim)
else:
return None
def _set_coeffs_1(self, value: Tensor) -> None:
value = torch.as_tensor(value).to(self.raw_coeffs_1)
value = value.expand(*self.batch_shape, self.dim)
self.initialize(raw_coeffs_1=self.coeff_constraint.inverse_transform(value))
def _set_coeffs_2(self, value: Tensor) -> None:
value = torch.as_tensor(value).to(self.raw_coeffs_1)
value = value.expand(*self.batch_shape, self.dim, self.dim)
row_idcs, col_idcs = torch.triu_indices(self.dim, self.dim, offset=1)
value = value[..., row_idcs, col_idcs].to(self.raw_coeffs_2)
self.initialize(raw_coeffs_2=self.coeff_constraint.inverse_transform(value))
def _set_offset(self, value: Tensor) -> None:
value = torch.as_tensor(value).to(self.raw_offset)
self.initialize(raw_offset=self.coeff_constraint.inverse_transform(value))
@coeffs_1.setter
def coeffs_1(self, value) -> None:
self._set_coeffs_1(value)
@coeffs_2.setter
def coeffs_2(self, value) -> None:
self._set_coeffs_2(value)
@offset.setter
def offset(self, value) -> None:
self._set_offset(value)
def forward(
self,
x1: Tensor,
x2: Tensor,
diag: bool = False,
last_dim_is_batch: bool = False,
) -> Tensor:
"""Computes the kernel matrix k(x1, x2).
Args:
x1: `batch_shape x n1 x d`-dim Tensor in [0, 1]^dim.
x2: `batch_shape x n2 x d`-dim Tensor in [0, 1]^dim.
diag: If True, only returns the diagonal of the kernel matrix.
last_dim_is_batch: Not supported by this kernel.
Returns:
A `batch_shape x n1 x n2`-dim Tensor of kernel matrices.
"""
if last_dim_is_batch:
raise UnsupportedError(
"OrthogonalAdditiveKernel does not support `last_dim_is_batch`."
)
K_ortho = self._orthogonal_base_kernels(x1, x2) # batch_shape x d x n1 x n2
# contracting over d, leading to `batch_shape x n x n`-dim tensor, i.e.:
# K1 = torch.sum(self.coeffs_1[..., None, None] * K_ortho, dim=-3)
K1 = torch.einsum(self.coeffs_1, [..., 0], K_ortho, [..., 0, 1, 2], [..., 1, 2])
# adding the non-batch dimensions to offset
K = K1 + self.offset[..., None, None]
if self.coeffs_2 is not None:
# Computing the tensor of second order interactions K2.
# NOTE: K2 here is equivalent to:
# K2 = K_ortho.unsqueeze(-4) * K_ortho.unsqueeze(-3) # d x d x n x n
# K2 = (self.coeffs_2[..., None, None] * K2).sum(dim=(-4, -3))
# but avoids forming the `batch_shape x d x d x n x n`-dim tensor in memory.
# Reducing over the dimensions with the O(d^2) quadratic terms:
K2 = torch.einsum(
K_ortho,
[..., 0, 2, 3],
K_ortho,
[..., 1, 2, 3],
self.coeffs_2,
[..., 0, 1],
[..., 2, 3], # i.e. contracting over the first two non-batch dims
)
K = K + K2
return K if not diag else K.diag() # poor man's diag (TODO)
def _orthogonal_base_kernels(self, x1: Tensor, x2: Tensor) -> Tensor:
"""Evaluates the set of `d` orthogonalized base kernels on (x1, x2).
Note that even if the base kernel is positive, the orthogonalized versions
can - and usually do - take negative values.
Args:
x1: `batch_shape x n1 x d`-dim inputs to the kernel.
x2: `batch_shape x n2 x d`-dim inputs to the kernel.
Returns:
A `batch_shape x d x n1 x n2`-dim Tensor.
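
        NOTE: With the quadrature rule `(z, w)` standing in for exact integration,
        each orthogonalized kernel computed below is
        `k_ortho(x, x') = k(x, x') - S(x) S(x') / N`, where
        `S(x) = int_0^1 k(x, s) ds` and `N = int_0^1 int_0^1 k(s, t) ds dt`,
        so that `int_0^1 k_ortho(x, s) ds = 0` for any `x` in `[0, 1]`.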
"""
_check_hypercube(x1, "x1")
if x1 is not x2:
_check_hypercube(x2, "x2")
Kx1x2 = self.k(x1, x2) # d x n x n
        # NOTE: one could move the pre-allocated quadrature tensors to the inputs'
        # dtype and device here:
        # self.z, self.w = self.z.to(x1), self.w.to(x1)
# include normalization constant in weights
w = self.w / self.normalizer().sqrt()
Skx1 = self.k(x1, self.z) @ w # batch_shape x d x n
Skx2 = Skx1 if (x1 is x2) else self.k(x2, self.z) @ w # d x n
# this is a tensor of kernel matrices of orthogonal 1d kernels
K_ortho = (Kx1x2 - Skx1 @ Skx2.transpose(-2, -1)).to_dense() # d x n x n
return K_ortho
def normalizer(self, eps: float = 1e-6) -> Tensor:
"""Integrates the `d` orthogonalized base kernels over `[0, 1] x [0, 1]`.
NOTE: If the module is in train mode, this needs to re-compute the normalizer
each time because the underlying parameters might have changed.
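
        With quadrature nodes `z` and weights `w`, the normalizer for each dimension
        is `w^T k(z, z) w ≈ int_0^1 int_0^1 k(s, t) ds dt`; these integrals are used
        to normalize the orthogonalized kernels in `_orthogonal_base_kernels`.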
Args:
eps: Minimum value constraint on the normalizers. Avoids division by zero.
Returns:
A `d`-dim tensor of normalization constants.
"""
        if self.training or getattr(self, "_normalizer", None) is None:
self._normalizer = (self.w.T @ self.k(self.z, self.z) @ self.w).clamp(eps)
return self._normalizer
def leggauss(
deg: int,
a: float = -1.0,
b: float = 1.0,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
) -> tuple[Tensor, Tensor]:
"""Computes Gauss-Legendre quadrature nodes and weights. Wraps
`numpy.polynomial.legendre.leggauss` and returns Torch Tensors.
Args:
        deg: Number of sample points and weights. Integrates polynomials of degree
            `2 * deg - 1` exactly.
a, b: Lower and upper bound of integration domain.
dtype: Desired floating point type of the return Tensors.
device: Desired device type of the return Tensors.
Returns:
A tuple of Gauss-Legendre quadrature nodes and weights of length deg.
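
    Example:
        A sanity check of the rescaled rule (exact here, since the integrand is a
        low-degree polynomial):

        >>> z, w = leggauss(deg=16, a=0.0, b=1.0)
        >>> (w * z).sum()  # the integral of f(x) = x over [0, 1], i.e. ~0.5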
"""
dtype = dtype if dtype is not None else torch.get_default_dtype()
x, w = numpy.polynomial.legendre.leggauss(deg=deg)
x = torch.as_tensor(x, dtype=dtype, device=device)
w = torch.as_tensor(w, dtype=dtype, device=device)
    if not (a == -1 and b == 1):  # map nodes and weights from [-1, 1] to [a, b]
x = (b - a) * (x + 1) / 2 + a
w = w * ((b - a) / 2)
return x, w
def _check_hypercube(x: Tensor, name: str) -> None:
"""Raises a `ValueError` if an element `x` is not in [0, 1].
Args:
x: Tensor to be checked.
name: Name of the Tensor for the error message.
"""
if (x < 0).any() or (x > 1).any():
raise ValueError(name + " is not in hypercube [0, 1]^d.")
def _reverse_triu_indices(d: int) -> list[int]:
"""Computes a list of indices which, upon indexing a `d * (d - 1) / 2 + 1`-dim
Tensor whose last element is zero, will lead to a vectorized representation of
an upper-triangular matrix, whose diagonal is set to zero and whose super-diagonal
elements are set to the `d * (d - 1) / 2` values in the original tensor.
NOTE: This is a helper function for Orthogonal Additive Kernels, and allows the
implementation to only register `d * (d - 1) / 2` parameters to model the second
order interactions, instead of the full d^2 redundant terms.
Args:
d: Dimensionality that gives rise to the `d * (d - 1) / 2` quadratic terms.
Returns:
A list of integer indices in `[0, d * (d - 1) / 2]`. See above for details.
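
    Example:
        For `d = 3`, indexing a tensor `[c0, c1, c2, 0]` with the returned list
        yields the flattened matrix `[[0, c0, c1], [0, 0, c2], [0, 0, 0]]`:

        >>> _reverse_triu_indices(3)
        [3, 0, 1, 3, 3, 2, 3, 3, 3]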
"""
indices = []
j = 0
d2 = int(d * (d - 1) / 2)
for i in range(d):
        indices.extend(d2 for _ in range(i + 1))  # index of zero (diagonal & below)
        indices.extend(range(j, j + d - i - 1))  # coefficient indices (above diag)
j += d - i - 1
return indices
def _coeffs_1_param(m: Module) -> Tensor:
return m.coeffs_1
def _coeffs_2_param(m: Module) -> Tensor:
return m.coeffs_2
def _offset_param(m: Module) -> Tensor:
return m.offset
def _coeffs_1_closure(m: Module, v: Tensor) -> None:
    m._set_coeffs_1(v)
def _coeffs_2_closure(m: Module, v: Tensor) -> None:
    m._set_coeffs_2(v)
def _offset_closure(m: Module, v: Tensor) -> None:
    m._set_offset(v)