# Source code for botorch.models.kernels.contextual_sac

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional

import torch
from gpytorch.kernels.kernel import Kernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.kernels.scale_kernel import ScaleKernel
from gpytorch.priors.torch_priors import GammaPrior
from linear_operator.operators.sum_linear_operator import SumLinearOperator
from torch import Tensor
from torch.nn import ModuleDict  # pyre-ignore

class SACKernel(Kernel):
    r"""The structural additive contextual (SAC) kernel.

    The kernel is used for contextual BO without observing context breakdowns.
    There are :math:`d` parameters and :math:`M` contexts. In total, the dimension
    of the parameter space is :math:`d \cdot M` and an input :math:`x` can be
    written as
    :math:`x = [x_{11}, \ldots, x_{1d}, x_{21}, \ldots, x_{2d}, \ldots,
    x_{M1}, \ldots, x_{Md}]`.

    The kernel uses the parameter decomposition and assumes an additive structure
    across contexts. Each context component is assumed to be independent.

    .. math::
        \begin{equation*}
            k(\mathbf{x}, \mathbf{x}') = k_1(\mathbf{x}_{(1)}, \mathbf{x}'_{(1)})
            + \cdots + k_M(\mathbf{x}_{(M)}, \mathbf{x}'_{(M)})
        \end{equation*}

    where

    * :math:`M` is the number of partitions of the parameter space. Each
      partition contains the same number of parameters :math:`d`.
    * Each kernel :math:`k_i` acts only on the :math:`d` parameters of the
      :math:`i`-th partition, i.e. :math:`\mathbf{x}_{(i)}`.
    * Each kernel :math:`k_i` is a scaled Matern (:math:`\nu = 2.5`) kernel. All
      :math:`k_i` share the same lengthscales but have different outputscales.
    """

    def __init__(
        self,
        decomposition: Dict[str, List[int]],
        batch_shape: torch.Size,
        device: Optional[torch.device] = None,
    ) -> None:
        r"""
        Args:
            decomposition: Keys are context names. Values are the indexes of
                parameters belonging to the context. The parameter indexes are in
                the same order across contexts, and every context must have the
                same number of parameters.
            batch_shape: Batch shape as usual for gpytorch kernels.
            device: The torch device on which the per-context index tensors are
                created.

        Raises:
            ValueError: If ``decomposition`` is empty, or if the contexts do not
                all have the same number of parameters.
        """
        super().__init__(batch_shape=batch_shape)
        self.decomposition = decomposition
        self._device = device

        # Fail with a clear message instead of a bare StopIteration from
        # ``next(iter(...))`` below (and an empty sum of covariances in forward).
        if not decomposition:
            raise ValueError("decomposition must contain at least one context")

        num_param = len(next(iter(decomposition.values())))
        # check number of parameters are same in each decomp
        if any(len(params) != num_param for params in decomposition.values()):
            raise ValueError(
                "num of parameters needs to be same across all contexts"
            )

        # Integer index tensors selecting each context's columns of the input.
        # They are plain attributes (not buffers), so ``.to()``/``.cuda()`` on the
        # module does not move them; ``forward`` aligns them with the inputs.
        self._indexers = {
            context: torch.tensor(
                active_params, dtype=torch.long, device=self.device
            )
            for context, active_params in self.decomposition.items()
        }

        # A single Matern-5/2 base kernel with one lengthscale per context
        # parameter; sharing this instance ties the lengthscales across contexts.
        self.base_kernel = MaternKernel(
            nu=2.5,
            ard_num_dims=num_param,
            batch_shape=batch_shape,
            lengthscale_prior=GammaPrior(3.0, 6.0),
        )

        # One ScaleKernel (with its own outputscale) per parameter space
        # partition, all wrapping the shared base kernel.
        # NOTE(review): ``batch_shape`` is not passed to ScaleKernel here; confirm
        # whether a batched outputscale is intended.
        self.kernel_dict = ModuleDict(
            {
                context: ScaleKernel(
                    base_kernel=self.base_kernel,
                    outputscale_prior=GammaPrior(2.0, 15.0),
                )
                for context in decomposition
            }
        )

    @property
    def device(self) -> Optional[torch.device]:
        """The device passed at construction (may be ``None``)."""
        return self._device

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
        **params: Any,
    ) -> Tensor:
        r"""Iterate across each partition of the parameter space and sum the
        per-context covariances.

        Args:
            x1: First set of inputs. Its last dimension is indexed by the
                parameter indexes given in ``decomposition``.
            x2: Second set of inputs, with the same last-dimension layout as
                ``x1``.
            diag: If True, evaluate only the diagonal of each per-context
                covariance and sum those.
            last_dim_is_batch: Accepted for interface compatibility; not used by
                this kernel.
            **params: Accepted for interface compatibility; not used by this
                kernel.

        Returns:
            If ``diag``, the elementwise sum of the per-context results;
            otherwise a ``SumLinearOperator`` over the per-context covariances.
        """
        covars = []
        for context, active_params in self._indexers.items():
            # The index tensors are not registered buffers, so they may sit on a
            # different device than the inputs after the module was moved.
            index = active_params.to(device=x1.device)
            # same lengthscale for all the components (shared base kernel)
            covars.append(
                self.kernel_dict[context](
                    x1=x1.index_select(dim=-1, index=index),  # pyre-ignore
                    x2=x2.index_select(dim=-1, index=index),
                    diag=diag,
                )
            )

        if diag:
            return sum(covars)
        return SumLinearOperator(*covars)