#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from typing import Any, Optional
import torch
from botorch.exceptions import UnsupportedError
from gpytorch.constraints import Interval, Positive
from gpytorch.kernels import Kernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.priors import Prior
from gpytorch.priors.torch_priors import GammaPrior
from torch import Tensor
[docs]
class LinearTruncatedFidelityKernel(Kernel):
r"""GPyTorch Linear Truncated Fidelity Kernel.
Computes a covariance matrix based on the Linear truncated kernel between
inputs `x_1` and `x_2` for up to two fidelity parmeters:
K(x_1, x_2) = k_0 + c_1(x_1, x_2)k_1 + c_2(x_1,x_2)k_2 + c_3(x_1,x_2)k_3
where
- `k_i(i=0,1,2,3)` are Matern kernels calculated between non-fidelity
parameters of `x_1` and `x_2` with different priors.
- `c_1=(1 - x_1[f_1])(1 - x_2[f_1]))(1 + x_1[f_1] x_2[f_1])^p` is the kernel
of the the bias term, which can be decomposed into a determistic part
and a polynomial kernel. Here `f_1` is the first fidelity dimension and
`p` is the order of the polynomial kernel.
- `c_3` is the same as `c_1` but is calculated for the second fidelity
dimension `f_2`.
- `c_2` is the interaction term with four deterministic terms and the
polynomial kernel between `x_1[..., [f_1, f_2]]` and
`x_2[..., [f_1, f_2]]`.
Example:
>>> x = torch.randn(10, 5)
>>> # Non-batch: Simple option
>>> covar_module = LinearTruncatedFidelityKernel()
>>> covar = covar_module(x) # Output: LinearOperator of size (10 x 10)
>>>
>>> batch_x = torch.randn(2, 10, 5)
>>> # Batch: Simple option
>>> covar_module = LinearTruncatedFidelityKernel(batch_shape = torch.Size([2]))
>>> covar = covar_module(x) # Output: LinearOperator of size (2 x 10 x 10)
"""
def __init__( # noqa C901
self,
fidelity_dims: list[int],
dimension: Optional[int] = None,
power_prior: Optional[Prior] = None,
power_constraint: Optional[Interval] = None,
nu: float = 2.5,
lengthscale_prior_unbiased: Optional[Prior] = None,
lengthscale_prior_biased: Optional[Prior] = None,
lengthscale_constraint_unbiased: Optional[Interval] = None,
lengthscale_constraint_biased: Optional[Interval] = None,
covar_module_unbiased: Optional[Kernel] = None,
covar_module_biased: Optional[Kernel] = None,
**kwargs: Any,
) -> None:
"""
Args:
fidelity_dims: A list containing either one or two indices specifying
the fidelity parameters of the input.
dimension: The dimension of `x`. Unused if `active_dims` is specified.
power_prior: Prior for the power parameter of the polynomial kernel.
Default is `None`.
power_constraint: Constraint on the power parameter of the polynomial
kernel. Default is `Positive`.
nu: The smoothness parameter for the Matern kernel: either 1/2, 3/2,
or 5/2. Unused if both `covar_module_unbiased` and
`covar_module_biased` are specified.
lengthscale_prior_unbiased: Prior on the lengthscale parameter of Matern
kernel `k_0`. Default is `Gamma(1.1, 1/20)`.
lengthscale_constraint_unbiased: Constraint on the lengthscale parameter
of the Matern kernel `k_0`. Default is `Positive`.
lengthscale_prior_biased: Prior on the lengthscale parameter of Matern
kernels `k_i(i>0)`. Default is `Gamma(5, 1/20)`.
lengthscale_constraint_biased: Constraint on the lengthscale parameter
of the Matern kernels `k_i(i>0)`. Default is `Positive`.
covar_module_unbiased: Specify a custom kernel for `k_0`. If omitted,
use a `MaternKernel`.
covar_module_biased: Specify a custom kernel for the biased parts
`k_i(i>0)`. If omitted, use a `MaternKernel`.
batch_shape: If specified, use a separate lengthscale for each batch of
input data. If `x1` is a `batch_shape x n x d` tensor, this should
be `batch_shape`.
active_dims: Compute the covariance of a subset of input dimensions. The
numbers correspond to the indices of the dimensions.
"""
if dimension is None and kwargs.get("active_dims") is None:
raise UnsupportedError(
"Must specify dimension when not specifying active_dims."
)
n_fidelity = len(fidelity_dims)
if len(set(fidelity_dims)) != n_fidelity:
raise ValueError("fidelity_dims must not have repeated elements")
if n_fidelity not in {1, 2}:
raise UnsupportedError(
"LinearTruncatedFidelityKernel accepts either one or two"
"fidelity parameters."
)
if nu not in {0.5, 1.5, 2.5}:
raise ValueError("nu must be one of 0.5, 1.5, or 2.5")
super().__init__(**kwargs)
self.fidelity_dims = fidelity_dims
if power_constraint is None:
power_constraint = Positive()
if lengthscale_prior_unbiased is None:
lengthscale_prior_unbiased = GammaPrior(3, 6)
if lengthscale_prior_biased is None:
lengthscale_prior_biased = GammaPrior(6, 2)
if lengthscale_constraint_unbiased is None:
lengthscale_constraint_unbiased = Positive()
if lengthscale_constraint_biased is None:
lengthscale_constraint_biased = Positive()
self.register_parameter(
name="raw_power",
parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
)
self.register_constraint("raw_power", power_constraint)
if power_prior is not None:
self.register_prior(
"power_prior",
power_prior,
lambda m: m.power,
lambda m, v: m._set_power(v),
)
if self.active_dims is not None:
dimension = len(self.active_dims)
if covar_module_unbiased is None:
covar_module_unbiased = MaternKernel(
nu=nu,
batch_shape=self.batch_shape,
lengthscale_prior=lengthscale_prior_unbiased,
ard_num_dims=dimension - n_fidelity,
lengthscale_constraint=lengthscale_constraint_unbiased,
)
if covar_module_biased is None:
covar_module_biased = MaternKernel(
nu=nu,
batch_shape=self.batch_shape,
lengthscale_prior=lengthscale_prior_biased,
ard_num_dims=dimension - n_fidelity,
lengthscale_constraint=lengthscale_constraint_biased,
)
self.covar_module_unbiased = covar_module_unbiased
self.covar_module_biased = covar_module_biased
@property
def power(self) -> Tensor:
return self.raw_power_constraint.transform(self.raw_power)
@power.setter
def power(self, value: Tensor) -> None:
self._set_power(value)
def _set_power(self, value: Tensor) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_power)
self.initialize(raw_power=self.raw_power_constraint.inverse_transform(value))
def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Tensor:
if params.get("last_dim_is_batch", False):
raise NotImplementedError(
"last_dim_is_batch not yet supported by LinearTruncatedFidelityKernel"
)
power = self.power.view(*self.batch_shape, 1, 1)
active_dimsM = torch.tensor(
[i for i in range(x1.size(-1)) if i not in self.fidelity_dims],
device=x1.device,
)
if len(active_dimsM) == 0:
raise RuntimeError(
"Input to LinearTruncatedFidelityKernel must have at least one "
"non-fidelity dimension."
)
x1_ = x1.index_select(dim=-1, index=active_dimsM)
x2_ = x2.index_select(dim=-1, index=active_dimsM)
covar_unbiased = self.covar_module_unbiased(x1_, x2_, diag=diag)
covar_biased = self.covar_module_biased(x1_, x2_, diag=diag)
# clamp to avoid numerical issues
fd_idxr0 = torch.full(
(1,), self.fidelity_dims[0], dtype=torch.long, device=x1.device
)
x11_ = x1.index_select(dim=-1, index=fd_idxr0).clamp(0, 1)
x21t_ = x2.index_select(dim=-1, index=fd_idxr0).clamp(0, 1)
if not diag:
x21t_ = x21t_.transpose(-1, -2)
cross_term_1 = (1 - x11_) * (1 - x21t_)
bias_factor = cross_term_1 * (1 + x11_ * x21t_).pow(power)
if len(self.fidelity_dims) > 1:
# clamp to avoid numerical issues
fd_idxr1 = torch.full(
(1,), self.fidelity_dims[1], dtype=torch.long, device=x1.device
)
x12_ = x1.index_select(dim=-1, index=fd_idxr1).clamp(0, 1)
x22t_ = x2.index_select(dim=-1, index=fd_idxr1).clamp(0, 1)
x1b_ = torch.cat([x11_, x12_], dim=-1)
if diag:
x2bt_ = torch.cat([x21t_, x22t_], dim=-1)
k = (1 + (x1b_ * x2bt_).sum(dim=-1, keepdim=True)).pow(power)
else:
x22t_ = x22t_.transpose(-1, -2)
x2bt_ = torch.cat([x21t_, x22t_], dim=-2)
k = (1 + x1b_ @ x2bt_).pow(power)
cross_term_2 = (1 - x12_) * (1 - x22t_)
bias_factor += cross_term_2 * (1 + x12_ * x22t_).pow(power)
bias_factor += cross_term_2 * cross_term_1 * k
if diag:
bias_factor = bias_factor.view(covar_biased.shape)
return covar_unbiased + bias_factor * covar_biased