Source code for botorch.models.kernels.contextual_lcea

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any

import torch
from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
from gpytorch.constraints import Positive
from gpytorch.kernels.kernel import Kernel
from gpytorch.priors.torch_priors import GammaPrior
from linear_operator.operators import DiagLinearOperator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator
from torch import Tensor
from torch.nn import ModuleList


def get_order(indices: list[int]) -> list[int]:
    r"""Get the order indices as integers ranging from 0 to the number of indices.

    Args:
        indices: A list of parameter indices.

    Returns:
        A list of integers ranging from 0 to the number of indices.
    """
    return [i % len(indices) for i in indices]
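

# Illustrative example (not part of the original module): `get_order` reduces
# each index modulo the number of indices, so contiguous blocks starting at
# different offsets can yield different orders:
#
#   get_order([3, 4, 5])  # -> [0, 1, 2]
#   get_order([4, 5, 6])  # -> [1, 2, 0]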


def is_contiguous(indices: list[int]) -> bool:
    r"""Check if the list of integers is contiguous.

    Args:
        indices: A list of parameter indices.

    Returns:
        A boolean indicating whether the indices are contiguous.
    """
    min_idx = min(indices)
    return set(indices) == set(range(min_idx, min_idx + len(indices)))
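

# Illustrative example (not part of the original module):
#
#   is_contiguous([2, 0, 1])  # -> True  (covers range(0, 3))
#   is_contiguous([0, 2, 3])  # -> False (1 is missing from the range)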


def get_permutation(decomposition: dict[str, list[int]]) -> list[int] | None:
    """Construct permutation to reorder the parameters such that:

    1) the parameters for each context are contiguous.
    2) The parameters for each context are in the same order

    Args:
        decomposition: A dictionary mapping context names to a list of
            parameters.

    Returns:
        A permutation to reorder the parameters for (1) and (2).
        Returning `None` means that the ordering specified in `decomposition`
        already satisfies (1) and (2).
    """
    permutation = None
    if not all(
        is_contiguous(indices=active_parameters)
        for active_parameters in decomposition.values()
    ):
        permutation = _create_new_permutation(decomposition=decomposition)
    else:
        same_order = True
        expected_order = get_order(indices=next(iter(decomposition.values())))
        for active_parameters in decomposition.values():
            order = get_order(indices=active_parameters)
            if order != expected_order:
                same_order = False
                break
        if not same_order:
            permutation = _create_new_permutation(decomposition=decomposition)
    return permutation


def _create_new_permutation(decomposition: dict[str, list[int]]) -> list[int]:
    # make contiguous and ordered
    permutation = []
    for active_parameters in decomposition.values():
        sorted_indices = sorted(active_parameters)
        permutation.extend(sorted_indices)
    return permutation
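

# Illustrative example (not part of the original module): interleaved context
# indices are not contiguous, so a new permutation is built by concatenating
# each context's sorted indices; an ordering that already satisfies (1) and
# (2) yields `None`:
#
#   get_permutation({"a": [0, 2], "b": [1, 3]})  # -> [0, 2, 1, 3]
#   get_permutation({"a": [0, 1], "b": [2, 3]})  # -> None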


class LCEAKernel(Kernel):
    r"""The Latent Context Embedding Additive (LCE-A) Kernel.

    This kernel is similar to the SACKernel, and is used when context breakdowns
    are unobservable. It assumes the same additive structure and a spatial kernel
    shared across contexts. Rather than assuming independence, LCEAKernel models
    the correlation in the latent functions for each context through learning
    context embeddings.
    """

    def __init__(
        self,
        decomposition: dict[str, list[int]],
        batch_shape: torch.Size,
        train_embedding: bool = True,
        cat_feature_dict: dict | None = None,
        embs_feature_dict: dict | None = None,
        embs_dim_list: list[int] | None = None,
        context_weight_dict: dict | None = None,
        device: torch.device | None = None,
    ) -> None:
        r"""
        Args:
            decomposition: Keys are context names. Values are the indices of the
                parameters belonging to the context.
            batch_shape: Batch shape as usual for gpytorch kernels. The model
                does not support batch training. When batch_shape is non-empty,
                it is used for loading hyper-parameter values generated from
                MCMC sampling.
            train_embedding: A boolean indicator of whether to learn context
                embeddings.
            cat_feature_dict: Keys are context names and values are lists of
                categorical features, i.e. {"context_name" : [cat_0, ..., cat_k]},
                where k equals the number of categorical variables. If None, uses
                context names in the decomposition as the only categorical
                feature, i.e., k = 1.
            embs_feature_dict: Pre-trained continuous embedding features of each
                context.
            embs_dim_list: Embedding dimension for each categorical variable. The
                length equals the number of categorical features k. If None, the
                embedding dimension is set to 1 for each categorical variable.
            context_weight_dict: Known population weights of each context.
        """
        super().__init__(batch_shape=batch_shape)
        self.batch_shape = batch_shape
        self.train_embedding = train_embedding
        self._device = device

        self.num_param = len(next(iter(decomposition.values())))
        self.context_list = list(decomposition.keys())
        self.num_contexts = len(self.context_list)

        # get parameter space decomposition
        for active_parameters in decomposition.values():
            # check number of parameters are same in each decomp
            if len(active_parameters) != self.num_param:
                raise ValueError(
                    "The number of parameters needs to be same across all contexts."
                )
        # reorder the parameter list based on decomposition such that
        # parameters for each context are contiguous and in the same order for
        # each context
        self.permutation = get_permutation(decomposition=decomposition)
        # get context features and set emb dim
        self.context_cat_feature = None
        self.context_emb_feature = None
        self.n_embs = 0
        self.emb_weight_matrix_list = None
        self.emb_dims = None
        self._set_context_features(
            cat_feature_dict=cat_feature_dict,
            embs_feature_dict=embs_feature_dict,
            embs_dim_list=embs_dim_list,
        )
        # construct embedding layer
        if train_embedding:
            self._set_emb_layers()
        # task covariance matrix
        self.task_covar_module = get_covar_module_with_dim_scaled_prior(
            ard_num_dims=self.n_embs,
            batch_shape=batch_shape,
        )
        # base kernel
        self.base_kernel = get_covar_module_with_dim_scaled_prior(
            ard_num_dims=self.num_param,
            batch_shape=batch_shape,
        )
        # outputscales for each context (note this is like sqrt of outputscale)
        self.context_weight = None
        if context_weight_dict is None:
            outputscale_list = torch.zeros(
                *batch_shape, self.num_contexts, device=self.device
            )
        else:
            outputscale_list = torch.zeros(*batch_shape, 1, device=self.device)
            self.context_weight = torch.tensor(
                [context_weight_dict[c] for c in self.context_list], device=self.device
            )
        self.register_parameter(
            name="raw_outputscale_list", parameter=torch.nn.Parameter(outputscale_list)
        )
        self.register_prior(
            "outputscale_list_prior",
            GammaPrior(2.0, 15.0),
            lambda m: m.outputscale_list,
            lambda m, v: m._set_outputscale_list(v),
        )
        self.register_constraint("raw_outputscale_list", Positive())

    @property
    def device(self) -> torch.device | None:
        return self._device

    @property
    def outputscale_list(self) -> Tensor:
        return self.raw_outputscale_list_constraint.transform(self.raw_outputscale_list)

    @outputscale_list.setter
    def outputscale_list(self, value: Tensor) -> None:
        self._set_outputscale_list(value)

    def _set_outputscale_list(self, value: Tensor) -> None:
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_outputscale_list)
        self.initialize(
            raw_outputscale_list=self.raw_outputscale_list_constraint.inverse_transform(
                value
            )
        )

    def _set_context_features(
        self,
        cat_feature_dict: dict | None = None,
        embs_feature_dict: dict | None = None,
        embs_dim_list: list[int] | None = None,
    ) -> None:
        """Set context categorical features and continuous embedding features.

        If cat_feature_dict is None, context indices will be used; if
        embs_dim_list is None, a 1-d embedding is used for each categorical
        feature.
        """
        # get context categorical features
        if cat_feature_dict is None:
            self.context_cat_feature = torch.arange(
                self.num_contexts, device=self.device
            ).unsqueeze(-1)
        else:
            self.context_cat_feature = torch.tensor(
                [cat_feature_dict[c] for c in self.context_list]
            )
        # construct emb_dims based on categorical features
        if embs_dim_list is None:
            # set embedding_dim = 1 for each categorical variable
            embs_dim_list = [1 for _i in range(self.context_cat_feature.size(1))]
        self.emb_dims = [
            (len(self.context_cat_feature[:, i].unique()), embs_dim_list[i])
            for i in range(self.context_cat_feature.size(1))
        ]
        if self.train_embedding:
            self.n_embs = sum(embs_dim_list)  # total num of emb features
        # get context embedding features
        if embs_feature_dict is not None:
            self.context_emb_feature = torch.tensor(
                [embs_feature_dict[c] for c in self.context_list], device=self.device
            )
            self.n_embs += self.context_emb_feature.size(1)

    def _set_emb_layers(self) -> None:
        """Construct embedding layers.

        If the model is non-batch, nn.Embedding is used to learn emb weights.
        If the model is batched (self.batch_shape is non-empty), we load emb
        weight posterior samples and construct a parameter list in which each
        parameter is the emb weight of one layer. The shape of each weight
        matrix is ns x num_contexts x emb_dim.
        """
        self.emb_layers = ModuleList(
            [
                torch.nn.Embedding(num_embeddings=x, embedding_dim=y, max_norm=1.0)
                for x, y in self.emb_dims
            ]
        )
        # use posterior of emb weights
        if len(self.batch_shape) > 0:
            self.emb_weight_matrix_list = torch.nn.ParameterList(
                [
                    torch.nn.Parameter(
                        torch.zeros(
                            self.batch_shape + emb_layer.weight.shape,
                            device=self.device,
                        )
                    )
                    for emb_layer in self.emb_layers
                ]
            )

    def _eval_context_covar(self) -> Tensor:
        """Compute context covariance matrix.

        Returns:
            A (ns) x num_contexts x num_contexts tensor.
        """
        if len(self.batch_shape) > 0:
            # broadcast - (ns x num_contexts x k)
            all_embs = self._task_embeddings_batch()
        else:
            all_embs = self._task_embeddings()  # no broadcast - (num_contexts x k)

        context_covar = self.task_covar_module(all_embs).to_dense()
        if self.context_weight is None:
            context_outputscales = self.outputscale_list
        else:
            context_outputscales = self.outputscale_list * self.context_weight
        context_covar = (
            (context_outputscales.unsqueeze(-2))  # (ns) x 1 x num_contexts
            .mul(context_covar)
            .mul(context_outputscales.unsqueeze(-1))  # (ns) x num_contexts x 1
        )
        return context_covar

    def _task_embeddings(self) -> Tensor:
        """Generate embedding features of contexts when model is non-batch.

        Returns:
            a (num_contexts x n_embs) tensor. n_embs is the sum of embedding
            dimensions, i.e. sum(embs_dim_list).
        """
        if self.train_embedding is False:
            return self.context_emb_feature  # use pre-trained embedding only
        context_features = torch.stack(
            [self.context_cat_feature[i, :] for i in range(self.num_contexts)], dim=0
        )
        embeddings = [
            emb_layer(context_features[:, i].to(device=self.device, dtype=torch.long))
            for i, emb_layer in enumerate(self.emb_layers)
        ]
        embeddings = torch.cat(embeddings, dim=1)
        # add given embeddings if any
        if self.context_emb_feature is not None:
            embeddings = torch.cat([embeddings, self.context_emb_feature], dim=1)
        return embeddings

    def _task_embeddings_batch(self) -> Tensor:
        """Generate embedding features of contexts when model has multiple batches.

        Returns:
            a (ns) x num_contexts x n_embs tensor. ns is the batch size, i.e.
            the number of posterior samples; n_embs is the sum of embedding
            dimensions, i.e. sum(embs_dim_list).
        """
        context_features = torch.cat(
            [
                self.context_cat_feature[i, :].unsqueeze(0)
                for i in range(self.num_contexts)
            ]
        )
        embeddings = []
        for b in range(self.batch_shape.numel()):  # pyre-ignore
            for i in range(len(self.emb_weight_matrix_list)):
                # loop over emb layer and concat embs from each layer
                embeddings.append(
                    torch.cat(
                        [
                            torch.nn.functional.embedding(
                                context_features[:, 0].to(
                                    dtype=torch.long, device=self.device
                                ),
                                self.emb_weight_matrix_list[i][b, :],
                            ).unsqueeze(0)
                        ],
                        dim=1,
                    )
                )
        embeddings = torch.cat(embeddings, dim=0)
        # add given embeddings if any
        if self.context_emb_feature is not None:
            embeddings = torch.cat(
                [
                    embeddings,
                    self.context_emb_feature.expand(
                        *self.batch_shape + self.context_emb_feature.shape
                    ),
                ],
                dim=-1,
            )
        return embeddings

    def train(self, mode: bool = True) -> None:
        super().train(mode=mode)
        if not mode:
            # cache the context covariance at eval time so `forward` does not
            # recompute it on every kernel evaluation
            self.register_buffer("_context_covar", self._eval_context_covar())

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
        **params: Any,
    ) -> Tensor:
        """Iterate across each partition of parameter space and sum the
        covariance matrices together.
        """
        # context covar matrix
        context_covar = (
            self._eval_context_covar() if self.training else self._context_covar
        )
        base_covar_perm = self._eval_base_covar_perm(x1, x2)
        # expand context_covar to match base_covar_perm
        if base_covar_perm.dim() > context_covar.dim():
            context_covar = context_covar.expand(base_covar_perm.shape)
        # then weight by the context kernel
        # compute the base kernel on the d parameters
        einsum_str = "...nnki, ...nnki -> ...n" if diag else "...ki, ...ki -> ..."
        covar_dense = torch.einsum(einsum_str, context_covar, base_covar_perm)
        if diag:
            return DiagLinearOperator(covar_dense)
        return DenseLinearOperator(covar_dense)

    def _eval_base_covar_perm(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Computes the base covariance matrix on x1, x2, applying permutations
        and reshaping the kernel matrix as required by `forward`.

        NOTE: Using the notation n = num_observations, k = num_contexts,
        d = input_dim, the input tensors have to have the following shapes.

        Args:
            x1: `batch_shape x n x (k*d)`-dim Tensor of kernel inputs.
            x2: `batch_shape x n x (k*d)`-dim Tensor of kernel inputs.

        Returns:
            `batch_shape x n x n x k x k`-dim Tensor of base covariance values.
        """
        if self.permutation is not None:
            x1 = x1[..., self.permutation]
            x2 = x2[..., self.permutation]
        # turn the last two dimensions of n x (k*d) into (n*k) x d.
        x1_exp = x1.reshape(*x1.shape[:-2], -1, self.num_param)
        x2_exp = x2.reshape(*x2.shape[:-2], -1, self.num_param)
        # batch shape x n*k x n*k
        base_covar = self.base_kernel(x1_exp, x2_exp)
        # batch shape x n x n x k x k
        view_shape = x1.shape[:-2] + torch.Size(
            [
                x1.shape[-2],
                self.num_contexts,
                x2.shape[-2],
                self.num_contexts,
            ]
        )
        base_covar_perm = (
            base_covar.to_dense()
            .view(view_shape)
            .permute(*list(range(x1.ndim - 2)), -4, -2, -3, -1)
        )
        return base_covar_perm