#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import torch
from gpytorch.constraints import Positive
from gpytorch.kernels import Kernel
from torch import Tensor
[docs]
class InfiniteWidthBNNKernel(Kernel):
r"""Infinite-width BNN kernel.
Defines the GP kernel which is equivalent to performing exact Bayesian
inference on a fully-connected deep neural network with ReLU activations
and i.i.d. priors in the infinite-width limit.
See [Cho2009kernel]_ and [Lee2018deep]_ for details.
.. [Cho2009kernel]
Y. Cho, and L. Saul. Kernel methods for deep learning.
Advances in Neural Information Processing Systems 22. 2009.
.. [Lee2018deep]
J. Lee, Y. Bahri, R. Novak, S. Schoenholz, J. Pennington, and J. Dickstein.
Deep Neural Networks as Gaussian Processes.
International Conference on Learning Representations. 2018.
"""
has_lengthscale = False
def __init__(
self,
depth: int = 3,
batch_shape: torch.Size | None = None,
active_dims: tuple[int, ...] | None = None,
acos_eps: float = 1e-7,
device: torch.device | None = None,
) -> None:
r"""
Args:
depth: Depth of neural network.
batch_shape: This will set a separate weight/bias var for each batch.
It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf` is
a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
param active_dims: Compute the covariance of only a few input dimensions.
The ints corresponds to the indices of the dimensions.
param acos_eps: A small positive value to restrict acos inputs to
:math`[-1 + \epsilon, 1 - \epsilon]`
param device: Device for parameters.
"""
super().__init__(batch_shape=batch_shape, active_dims=active_dims)
self.depth = depth
self.acos_eps = acos_eps
self.register_parameter(
"raw_weight_var",
torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, 1, device=device)),
)
self.register_constraint("raw_weight_var", Positive())
self.register_parameter(
"raw_bias_var",
torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, 1, device=device)),
)
self.register_constraint("raw_bias_var", Positive())
@property
def weight_var(self) -> Tensor:
return self.raw_weight_var_constraint.transform(self.raw_weight_var)
@weight_var.setter
def weight_var(self, value) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_weight_var)
self.initialize(
raw_weight_var=self.raw_weight_var_constraint.inverse_transform(value)
)
@property
def bias_var(self) -> Tensor:
return self.raw_bias_var_constraint.transform(self.raw_bias_var)
@bias_var.setter
def bias_var(self, value) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_bias_var)
self.initialize(
raw_bias_var=self.raw_bias_var_constraint.inverse_transform(value)
)
def _initialize_var(self, x: Tensor) -> Tensor:
"""Computes the initial variance of x for layer 0"""
return (
self.weight_var * torch.sum(x * x, dim=-1, keepdim=True) / x.shape[-1]
+ self.bias_var
)
def _update_var(self, K: Tensor, x: Tensor) -> Tensor:
"""Computes the updated variance of x for next layer"""
return self.weight_var * K / 2 + self.bias_var
def k(self, x1: Tensor, x2: Tensor) -> Tensor:
r"""
For single-layer infinite-width neural networks with i.i.d. priors,
the covariance between outputs can be computed by
:math:`K^0(x, x')=\sigma_b^2+\sigma_w^2\frac{x \cdot x'}{d_\text{input}}`.
For deeper networks, we can recursively define the covariance as
:math:`K^l(x, x')=\sigma_b^2+\sigma_w^2
F_\phi(K^{l-1}(x, x'), K^{l-1}(x, x), K^{l-1}(x', x'))`
where :math:`F_\phi` is a deterministic function based on the
activation function :math:`\phi`.
For ReLU activations, this yields the arc-cosine kernel, which can be computed
analytically.
Args:
x1: `batch_shape x n1 x d`-dim Tensor
x2: `batch_shape x n2 x d`-dim Tensor
"""
K_12 = (
self.weight_var * (x1.matmul(x2.transpose(-2, -1)) / x1.shape[-1])
+ self.bias_var
)
for layer in range(self.depth):
if layer == 0:
K_11 = self._initialize_var(x1)
K_22 = self._initialize_var(x2)
else:
K_11 = self._update_var(K_11, x1)
K_22 = self._update_var(K_22, x2)
sqrt_term = torch.sqrt(K_11.matmul(K_22.transpose(-2, -1)))
fraction = K_12 / sqrt_term
fraction = torch.clamp(
fraction, min=-1 + self.acos_eps, max=1 - self.acos_eps
)
theta = torch.acos(fraction)
theta_term = torch.sin(theta) + (torch.pi - theta) * fraction
K_12 = (
self.weight_var / (2 * torch.pi) * sqrt_term * theta_term
+ self.bias_var
)
return K_12
def forward(
self,
x1: Tensor,
x2: Tensor,
diag: bool | None = False,
last_dim_is_batch: bool | None = False,
**params,
) -> Tensor:
"""
Args:
x1: `batch_shape x n1 x d`-dim Tensor
x2: `batch_shape x n2 x d`-dim Tensor
diag: If True, only returns the diagonal of the kernel matrix.
last_dim_is_batch: Not supported by this kernel.
"""
if last_dim_is_batch:
raise RuntimeError("last_dim_is_batch not supported by this kernel.")
if diag:
K = self._initialize_var(x1)
for _ in range(self.depth):
K = self._update_var(K, x1)
return K.squeeze(-1)
else:
return self.k(x1, x2)