Source code for botorch.models.kernels.contextual_sac

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any

import torch
from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
from gpytorch.kernels.kernel import Kernel
from gpytorch.kernels.scale_kernel import ScaleKernel
from gpytorch.priors.torch_priors import GammaPrior
from linear_operator.operators.sum_linear_operator import SumLinearOperator
from torch import Tensor
from torch.nn import ModuleDict  # pyre-ignore


class SACKernel(Kernel):
    r"""The structural additive contextual (SAC) kernel.

    The kernel is used for contextual BO without observing context breakdowns.
    There are d parameters and M contexts. In total, the dimension of the
    parameter space is d*M and an input x can be written as
    x=[x_11, ..., x_1d, x_21, ..., x_2d, ..., x_M1, ..., x_Md].

    The kernel uses the parameter decomposition and assumes an additive
    structure across contexts. Each context component is assumed to be
    independent.

    .. math::

        \begin{equation*}
            k(\mathbf{x}, \mathbf{x'}) =
                k_1(\mathbf{x}_{(1)}, \mathbf{x'}_{(1)}) + \cdots +
                k_M(\mathbf{x}_{(M)}, \mathbf{x'}_{(M)})
        \end{equation*}

    where

    * :math:`M` is the number of partitions of the parameter space. Each
      partition contains the same number of parameters d. Each kernel
      :math:`k_i` acts only on the d parameters of the i-th partition, i.e.
      :math:`\mathbf{x}_{(i)}`. Each kernel :math:`k_i` is a scaled RBF kernel
      with the same lengthscales but different outputscales.
    """

    def __init__(
        self,
        decomposition: dict[str, list[int]],
        batch_shape: torch.Size,
        device: torch.device | None = None,
    ) -> None:
        r"""
        Args:
            decomposition: Keys are context names. Values are the indexes of
                parameters belonging to the context. The parameter indexes are
                in the same order across contexts.
            batch_shape: Batch shape as usual for gpytorch kernels.
            device: The torch device on which the index tensors are created.

        Raises:
            ValueError: If ``decomposition`` is empty, or if the contexts do
                not all list the same number of parameters.
        """
        super().__init__(batch_shape=batch_shape)
        self.decomposition = decomposition
        self._device = device

        # Fail with a clear error instead of a bare StopIteration from next().
        if not decomposition:
            raise ValueError("decomposition must contain at least one context")

        # Every context must contribute the same number of parameters because
        # a single base kernel (with `num_param` ARD dims) is shared by all.
        num_param = len(next(iter(decomposition.values())))
        for active_parameters in decomposition.values():
            if len(active_parameters) != num_param:
                raise ValueError(
                    "num of parameters needs to be same across all contexts"
                )

        # Per-context index tensors used to slice the last input dimension.
        self._indexers = {
            context: torch.tensor(active_params, device=self.device)
            for context, active_params in self.decomposition.items()
        }

        # One base kernel instance is shared, so lengthscales are tied
        # across all context components.
        self.base_kernel = get_covar_module_with_dim_scaled_prior(
            ard_num_dims=num_param,
            batch_shape=batch_shape,
        )

        # One scaled kernel per parameter-space partition; each wraps the
        # shared base kernel but owns its own outputscale.
        self.kernel_dict = ModuleDict(
            {
                context: ScaleKernel(
                    base_kernel=self.base_kernel,
                    outputscale_prior=GammaPrior(2.0, 15.0),
                )
                for context in decomposition
            }
        )

    @property
    def device(self) -> torch.device | None:
        """The device passed at construction (``None`` if not specified)."""
        return self._device

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
        **params: Any,
    ) -> Tensor:
        """Sum the per-context covariances over all parameter partitions.

        For each context, the columns of ``x1`` and ``x2`` listed in the
        decomposition are selected along the last dimension and passed to
        that context's scaled kernel.

        Args:
            x1: First set of inputs; its last dimension is indexed by the
                parameter indexes of the decomposition.
            x2: Second set of inputs, indexed the same way as ``x1``.
            diag: Forwarded to each per-context kernel.
            last_dim_is_batch: Accepted for interface compatibility; it is
                not forwarded to the per-context kernels.
            **params: Additional keyword arguments; not used here.

        Returns:
            If ``diag`` is True, the elementwise sum of the per-context
            results; otherwise a ``SumLinearOperator`` over them.
        """
        covars = []
        for context, active_params in self._indexers.items():
            # The indexers live in a plain dict, so they do not follow the
            # module when it is moved; index_select needs the index on the
            # same device as the input. This is a no-op when they agree.
            index = active_params.to(device=x1.device)
            covars.append(
                # same lengthscale for all the components
                self.kernel_dict[context](
                    x1=x1.index_select(dim=-1, index=index),  # pyre-ignore
                    x2=x2.index_select(dim=-1, index=index),
                    diag=diag,
                )
            )

        if diag:
            return sum(covars)
        return SumLinearOperator(*covars)