#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Optional
import torch
from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
from gpytorch.constraints import Positive
from gpytorch.kernels.kernel import Kernel
from gpytorch.priors.torch_priors import GammaPrior
from linear_operator.operators import DiagLinearOperator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator
from torch import Tensor
from torch.nn import ModuleList
def get_order(indices: list[int]) -> list[int]:
r"""Get the order indices as integers ranging from 0 to the number of indices.
Args:
indices: A list of parameter indices.
Returns:
        A list of integers in ``[0, len(indices))`` giving the order of each index.
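    Example:
        A small illustration of the mapping (index values chosen arbitrarily):
        >>> get_order([0, 1, 2])
        [0, 1, 2]
        >>> get_order([3, 2])
        [1, 0]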
"""
return [i % len(indices) for i in indices]
def is_contiguous(indices: list[int]) -> bool:
r"""Check if the list of integers is contiguous.
Args:
indices: A list of parameter indices.
Returns:
A boolean indicating whether the indices are contiguous.
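    Example:
        A small illustration (index values chosen arbitrarily):
        >>> is_contiguous([2, 4, 3])
        True
        >>> is_contiguous([0, 1, 3])
        False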
"""
min_idx = min(indices)
return set(indices) == set(range(min_idx, min_idx + len(indices)))
def get_permutation(decomposition: dict[str, list[int]]) -> Optional[list[int]]:
"""Construct permutation to reorder the parameters such that:
1) the parameters for each context are contiguous.
2) The parameters for each context are in the same order
Args:
decomposition: A dictionary mapping context names to a list of
parameters.
Returns:
A permutation to reorder the parameters for (1) and (2).
        Returning `None` means that the ordering specified in `decomposition`
        already satisfies (1) and (2).
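    Example:
        A small illustration (context names and indices chosen arbitrarily):
        >>> get_permutation({"a": [0, 1], "b": [3, 2]})
        [0, 1, 2, 3]
        >>> get_permutation({"a": [0, 1], "b": [2, 3]}) is None
        True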
"""
permutation = None
if not all(
is_contiguous(indices=active_parameters)
for active_parameters in decomposition.values()
):
permutation = _create_new_permutation(decomposition=decomposition)
else:
same_order = True
expected_order = get_order(indices=next(iter(decomposition.values())))
for active_parameters in decomposition.values():
order = get_order(indices=active_parameters)
if order != expected_order:
same_order = False
break
if not same_order:
permutation = _create_new_permutation(decomposition=decomposition)
return permutation
def _create_new_permutation(decomposition: dict[str, list[int]]) -> list[int]:
# make contiguous and ordered
permutation = []
for active_parameters in decomposition.values():
sorted_indices = sorted(active_parameters)
permutation.extend(sorted_indices)
return permutation
class LCEAKernel(Kernel):
r"""The Latent Context Embedding Additive (LCE-A) Kernel.
    This kernel is similar to the SACKernel, and is used when context breakdowns
    are unobservable. It assumes the same additive structure and a spatial kernel
    shared across contexts. Rather than assuming independence, LCEAKernel models
    the correlation in the latent functions for each context by learning context
    embeddings.
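    Example:
        A minimal usage sketch; the decomposition, tensor shapes, and data below
        are illustrative assumptions rather than values from a real model:
        >>> # two contexts, each owning two of the four input dimensions
        >>> decomposition = {"context_A": [0, 1], "context_B": [2, 3]}
        >>> kernel = LCEAKernel(
        ...     decomposition=decomposition, batch_shape=torch.Size([])
        ... )
        >>> x = torch.rand(8, 4)
        >>> covar = kernel(x, x)  # a lazily evaluated 8 x 8 kernel matrix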
"""
def __init__(
self,
decomposition: dict[str, list[int]],
batch_shape: torch.Size,
train_embedding: bool = True,
cat_feature_dict: Optional[dict] = None,
embs_feature_dict: Optional[dict] = None,
embs_dim_list: Optional[list[int]] = None,
context_weight_dict: Optional[dict] = None,
device: Optional[torch.device] = None,
) -> None:
r"""
Args:
            decomposition: Keys are context names. Values are the indices of the
                parameters that belong to the context.
batch_shape: Batch shape as usual for gpytorch kernels. Model does not
support batch training. When batch_shape is non-empty, it is used for
loading hyper-parameter values generated from MCMC sampling.
            train_embedding: A boolean indicator of whether to learn context embeddings.
            cat_feature_dict: Keys are context names and values are lists of
                categorical features, i.e. {"context_name": [cat_0, ..., cat_{k-1}]},
                where k is the number of categorical variables. If None, uses the
                context names in the decomposition as the only categorical
                feature, i.e., k = 1.
embs_feature_dict: Pre-trained continuous embedding features of each
context.
            embs_dim_list: Embedding dimension for each categorical variable. Its
                length equals the number of categorical features k. If None, the
                embedding dimension is set to 1 for each categorical variable.
            context_weight_dict: Known population weights of each context.
            device: An optional torch device on which to allocate tensors.
"""
super().__init__(batch_shape=batch_shape)
self.batch_shape = batch_shape
self.train_embedding = train_embedding
self._device = device
self.num_param = len(next(iter(decomposition.values())))
self.context_list = list(decomposition.keys())
self.num_contexts = len(self.context_list)
# get parameter space decomposition
for active_parameters in decomposition.values():
            # check that the number of parameters is the same for each context
            if len(active_parameters) != self.num_param:
                raise ValueError(
                    "The number of parameters needs to be the same across all contexts."
                )
# reorder the parameter list based on decomposition such that
# parameters for each context are contiguous and in the same order for each
# context
self.permutation = get_permutation(decomposition=decomposition)
# get context features and set emb dim
self.context_cat_feature = None
self.context_emb_feature = None
self.n_embs = 0
self.emb_weight_matrix_list = None
self.emb_dims = None
self._set_context_features(
cat_feature_dict=cat_feature_dict,
embs_feature_dict=embs_feature_dict,
embs_dim_list=embs_dim_list,
)
        # construct embedding layer
if train_embedding:
self._set_emb_layers()
# task covariance matrix
self.task_covar_module = get_covar_module_with_dim_scaled_prior(
ard_num_dims=self.n_embs,
batch_shape=batch_shape,
)
# base kernel
self.base_kernel = get_covar_module_with_dim_scaled_prior(
ard_num_dims=self.num_param,
batch_shape=batch_shape,
)
        # outputscales for each context (note: these act like the square root
        # of a standard outputscale)
self.context_weight = None
if context_weight_dict is None:
outputscale_list = torch.zeros(
*batch_shape, self.num_contexts, device=self.device
)
else:
outputscale_list = torch.zeros(*batch_shape, 1, device=self.device)
self.context_weight = torch.tensor(
[context_weight_dict[c] for c in self.context_list], device=self.device
)
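            # when context weights are known, a single shared outputscale is
            # learned and scaled by the fixed per-context weights (see
            # `_eval_context_covar`)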
self.register_parameter(
name="raw_outputscale_list", parameter=torch.nn.Parameter(outputscale_list)
)
self.register_prior(
"outputscale_list_prior",
GammaPrior(2.0, 15.0),
lambda m: m.outputscale_list,
lambda m, v: m._set_outputscale_list(v),
)
self.register_constraint("raw_outputscale_list", Positive())
@property
def device(self) -> Optional[torch.device]:
return self._device
@property
def outputscale_list(self) -> Tensor:
return self.raw_outputscale_list_constraint.transform(self.raw_outputscale_list)
@outputscale_list.setter
def outputscale_list(self, value: Tensor) -> None:
self._set_outputscale_list(value)
def _set_outputscale_list(self, value: Tensor) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_outputscale_list)
self.initialize(
raw_outputscale_list=self.raw_outputscale_list_constraint.inverse_transform(
value
)
)
def _set_context_features(
self,
cat_feature_dict: Optional[dict] = None,
embs_feature_dict: Optional[dict] = None,
embs_dim_list: Optional[list[int]] = None,
) -> None:
"""Set context categorical features and continuous embedding features.
        If cat_feature_dict is None, context indices will be used; if embs_dim_list
        is None, a 1-d embedding is used for each categorical feature.
"""
# get context categorical features
if cat_feature_dict is None:
self.context_cat_feature = torch.arange(
self.num_contexts, device=self.device
).unsqueeze(-1)
else:
self.context_cat_feature = torch.tensor(
[cat_feature_dict[c] for c in self.context_list]
)
# construct emb_dims based on categorical features
if embs_dim_list is None:
# set embedding_dim = 1 for each categorical variable
embs_dim_list = [1 for _i in range(self.context_cat_feature.size(1))]
self.emb_dims = [
(len(self.context_cat_feature[:, i].unique()), embs_dim_list[i])
for i in range(self.context_cat_feature.size(1))
]
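        # for example (illustrative values): one categorical feature taking 3
        # distinct values with embs_dim_list=[2] gives emb_dims == [(3, 2)]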
if self.train_embedding:
self.n_embs = sum(embs_dim_list) # total num of emb features
# get context embedding features
if embs_feature_dict is not None:
self.context_emb_feature = torch.tensor(
[embs_feature_dict[c] for c in self.context_list], device=self.device
)
self.n_embs += self.context_emb_feature.size(1)
def _set_emb_layers(self) -> None:
"""Construct embedding layers.
        If the model is non-batch, we use nn.Embedding to learn the embedding
        weights. If the model is batched (self.batch_shape is non-empty), we load
        posterior samples of the embedding weights and construct a parameter list
        in which each parameter is the embedding weight matrix of one layer. The
        shape of each weight matrix is ns x num_contexts x emb_dim.
"""
self.emb_layers = ModuleList(
[
torch.nn.Embedding(num_embeddings=x, embedding_dim=y, max_norm=1.0)
for x, y in self.emb_dims
]
)
# use posterior of emb weights
if len(self.batch_shape) > 0:
self.emb_weight_matrix_list = torch.nn.ParameterList(
[
torch.nn.Parameter(
torch.zeros(
self.batch_shape + emb_layer.weight.shape,
device=self.device,
)
)
for emb_layer in self.emb_layers
]
)
def _eval_context_covar(self) -> Tensor:
"""Compute context covariance matrix.
Returns:
A (ns) x num_contexts x num_contexts tensor.
"""
if len(self.batch_shape) > 0:
# broadcast - (ns x num_contexts x k)
all_embs = self._task_embeddings_batch()
else:
all_embs = self._task_embeddings() # no broadcast - (num_contexts x k)
context_covar = self.task_covar_module(all_embs).to_dense()
if self.context_weight is None:
context_outputscales = self.outputscale_list
else:
context_outputscales = self.outputscale_list * self.context_weight
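        # scale the rows and columns of the context kernel by the per-context
        # outputscales: context_covar[..., i, j] *= s[i] * s[j]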
context_covar = (
(context_outputscales.unsqueeze(-2)) # (ns) x 1 x num_contexts
.mul(context_covar)
.mul(context_outputscales.unsqueeze(-1)) # (ns) x num_contexts x 1
)
return context_covar
def _task_embeddings(self) -> Tensor:
"""Generate embedding features of contexts when model is non-batch.
Returns:
            A (num_contexts x n_embs) tensor, where n_embs is the sum of the
            embedding dimensions, i.e. sum(embs_dim_list).
"""
        if not self.train_embedding:
return self.context_emb_feature # use pre-trained embedding only
context_features = torch.stack(
[self.context_cat_feature[i, :] for i in range(self.num_contexts)], dim=0
)
embeddings = [
emb_layer(context_features[:, i].to(device=self.device, dtype=torch.long))
for i, emb_layer in enumerate(self.emb_layers)
]
embeddings = torch.cat(embeddings, dim=1)
# add given embeddings if any
if self.context_emb_feature is not None:
embeddings = torch.cat([embeddings, self.context_emb_feature], dim=1)
return embeddings
def _task_embeddings_batch(self) -> Tensor:
"""Generate embedding features of contexts when model has multiple batches.
Returns:
            A (ns) x num_contexts x n_embs tensor, where ns is the batch size,
            i.e. the number of posterior samples, and n_embs is the sum of the
            embedding dimensions, i.e. sum(embs_dim_list).
"""
context_features = torch.cat(
[
self.context_cat_feature[i, :].unsqueeze(0)
for i in range(self.num_contexts)
]
)
        embeddings = []
        for b in range(self.batch_shape.numel()):  # pyre-ignore
            # loop over emb layers and concatenate the embeddings from each
            # layer along the feature dimension
            embeddings.append(
                torch.cat(
                    [
                        torch.nn.functional.embedding(
                            context_features[:, i].to(
                                dtype=torch.long, device=self.device
                            ),
                            self.emb_weight_matrix_list[i][b, :],
                        )
                        for i in range(len(self.emb_weight_matrix_list))
                    ],
                    dim=-1,
                ).unsqueeze(0)
            )
embeddings = torch.cat(embeddings, dim=0)
# add given embeddings if any
if self.context_emb_feature is not None:
embeddings = torch.cat(
[
embeddings,
self.context_emb_feature.expand(
*self.batch_shape + self.context_emb_feature.shape
),
],
dim=-1,
)
return embeddings
def train(self, mode: bool = True) -> None:
super().train(mode=mode)
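        # when switching to eval mode, cache the context covariance so repeated
        # forward calls do not recompute it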
if not mode:
self.register_buffer("_context_covar", self._eval_context_covar())
def forward(
self,
x1: Tensor,
x2: Tensor,
diag: bool = False,
last_dim_is_batch: bool = False,
**params: Any,
) -> Tensor:
"""Iterate across each partition of parameter space and sum the
covariance matrices together
"""
# context covar matrix
context_covar = (
self._eval_context_covar() if self.training else self._context_covar
)
base_covar_perm = self._eval_base_covar_perm(x1, x2)
# expand context_covar to match base_covar_perm
if base_covar_perm.dim() > context_covar.dim():
context_covar = context_covar.expand(base_covar_perm.shape)
        # weight the per-context-pair base covariances by the context kernel
        # and sum over all k x k context pairs
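        # i.e. covar[..., m, n] =
        #     sum_{i, j} context_covar[i, j] * base_covar_perm[..., m, n, i, j]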
einsum_str = "...nnki, ...nnki -> ...n" if diag else "...ki, ...ki -> ..."
covar_dense = torch.einsum(einsum_str, context_covar, base_covar_perm)
if diag:
return DiagLinearOperator(covar_dense)
return DenseLinearOperator(covar_dense)
def _eval_base_covar_perm(self, x1: Tensor, x2: Tensor) -> Tensor:
"""Computes the base covariance matrix on x1, x2, applying permutations and
reshaping the kernel matrix as required by `forward`.
NOTE: Using the notation n = num_observations, k = num_contexts, d = input_dim,
the input tensors have to have the following shapes.
Args:
x1: `batch_shape x n x (k*d)`-dim Tensor of kernel inputs.
x2: `batch_shape x n x (k*d)`-dim Tensor of kernel inputs.
Returns:
`batch_shape x n x n x k x k`-dim Tensor of base covariance values.
"""
if self.permutation is not None:
x1 = x1[..., self.permutation]
x2 = x2[..., self.permutation]
# turn last two dimensions of n x (k*d) into (n*k) x d.
x1_exp = x1.reshape(*x1.shape[:-2], -1, self.num_param)
x2_exp = x2.reshape(*x2.shape[:-2], -1, self.num_param)
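        # each row of x1_exp / x2_exp now holds the d parameters of a single
        # (observation, context) pair, so the base kernel below is evaluated
        # over all n*k rows at once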
# batch shape x n*k x n*k
base_covar = self.base_kernel(x1_exp, x2_exp)
# batch shape x n x n x k x k
view_shape = x1.shape[:-2] + torch.Size(
[
x1.shape[-2],
self.num_contexts,
x2.shape[-2],
self.num_contexts,
]
)
base_covar_perm = (
base_covar.to_dense()
.view(view_shape)
.permute(*list(range(x1.ndim - 2)), -4, -2, -3, -1)
)
return base_covar_perm