# Source code for botorch.models.kernels.categorical
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from gpytorch.kernels.kernel import Kernel
from torch import Tensor
class CategoricalKernel(Kernel):
    r"""A Kernel for categorical features.

    For each feature dimension `i`, let `dist_i(x1, x2)` be zero if
    `x1_i == x2_i` and one otherwise. The kernel computes
    `exp(-mean_i (dist_i(x1, x2) / lengthscale_i)**2)`, i.e. the squared,
    lengthscale-weighted mismatch averaged over the feature dimensions.
    With `last_dim_is_batch=True` the mean is skipped and one kernel matrix
    per feature dimension is returned instead.

    Note: This kernel is NOT differentiable w.r.t. the inputs.
    """

    has_lengthscale = True

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
        **kwargs,
    ) -> Tensor:
        r"""Evaluate the categorical kernel.

        Args:
            x1: A `batch_shape x n x d` tensor of categorical inputs.
            x2: A `batch_shape x m x d` tensor of categorical inputs.
            diag: If True, return only the diagonal entries `k(x1[j], x2[j])`.
            last_dim_is_batch: If True, treat the feature dimension `d` as an
                extra batch dimension instead of averaging over it.
            kwargs: Ignored; accepted for compatibility with the `Kernel` API.

        Returns:
            A `batch_shape x n x m` tensor (`batch_shape x d x n x m` if
            `last_dim_is_batch`). With `diag=True` the trailing `n x m`
            dimensions are replaced by their diagonal.
        """
        lengthscale = self.lengthscale
        if diag and x1.shape[-2] == x2.shape[-2]:
            # Only the paired rows (x1[j], x2[j]) are needed, so compare them
            # elementwise instead of building the full `n x n x d` tensor and
            # throwing away everything off the diagonal.
            delta = x1 != x2  # batch_shape x n x d
            dists = (delta / lengthscale).pow(2)
            if last_dim_is_batch:
                dists = dists.transpose(-2, -1)  # batch_shape x d x n
            else:
                dists = dists.mean(-1)  # batch_shape x n
            return torch.exp(-dists)

        # Pairwise mismatch indicator: batch_shape x n x m x d.
        delta = x1.unsqueeze(-2) != x2.unsqueeze(-3)
        # The extra unsqueeze lets the lengthscale broadcast over both the
        # n and m data dimensions.
        dists = (delta / lengthscale.unsqueeze(-2)).pow(2)
        if last_dim_is_batch:
            # n x m x d -> d x n x m. A single `transpose(-3, -1)` would give
            # d x m x n, i.e. a transposed kernel matrix whenever x1 != x2.
            dists = dists.transpose(-3, -1).transpose(-2, -1)
        else:
            dists = dists.mean(-1)
        res = torch.exp(-dists)
        if diag:
            # Only reached when n != m: keep the behaviour of taking the
            # (min(n, m)-long) diagonal of the full kernel matrix.
            res = torch.diagonal(res, dim1=-1, dim2=-2)
        return res