# Source code for botorch.models.kernels.linear_truncated_fidelity

#!/usr/bin/env python3
#
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from copy import deepcopy
from typing import Any, List, Optional

import torch
from botorch.exceptions import UnsupportedError
from gpytorch.constraints import Interval, Positive
from gpytorch.kernels import Kernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.priors import Prior
from gpytorch.priors.torch_priors import GammaPrior
from torch import Tensor

class LinearTruncatedFidelityKernel(Kernel):
    r"""GPyTorch Linear Truncated Fidelity Kernel.

    Computes a covariance matrix based on the Linear truncated kernel between
    inputs x_1 and x_2 for up to two fidelity parameters:

        K(x_1, x_2) = k_0 + c_1(x_1, x_2)k_1 + c_2(x_1,x_2)k_2 + c_3(x_1,x_2)k_3

    where

    - k_i(i=0,1,2,3) are Matern kernels calculated between non-fidelity
        parameters of x_1 and x_2 with different priors.
    - c_1=(1 - x_1[f_1])(1 - x_2[f_1])(1 + x_1[f_1] x_2[f_1])^p is the kernel
        of the bias term, which can be decomposed into a deterministic part
        and a polynomial kernel. Here f_1 is the first fidelity dimension and
        p is the order of the polynomial kernel.
    - c_3 is the same as c_1 but is calculated for the second fidelity
        dimension f_2.
    - c_2 is the interaction term with four deterministic terms and the
        polynomial kernel between x_1[..., [f_1, f_2]] and
        x_2[..., [f_1, f_2]].

    Fidelity values are clamped to [0, 1] before the bias factors are computed.

    Args:
        fidelity_dims: A list containing either one or two indices specifying
            the fidelity parameters of the input.
        dimension: The dimension of x. Unused if active_dims is specified.
        power_prior: Prior for the power parameter of the polynomial kernel.
            Default is None.
        power_constraint: Constraint on the power parameter of the polynomial
            kernel. Default is Positive.
        nu: The smoothness parameter for the Matern kernel: either 1/2, 3/2,
            or 5/2. Unused if both covar_module_unbiased and
            covar_module_biased are specified.
        lengthscale_prior_unbiased: Prior on the lengthscale parameter of Matern
            kernel k_0. Default is GammaPrior(3, 6).
        lengthscale_constraint_unbiased: Constraint on the lengthscale parameter
            of the Matern kernel k_0. Default is Positive.
        lengthscale_prior_biased: Prior on the lengthscale parameter of Matern
            kernels k_i(i>0). Default is GammaPrior(6, 2).
        lengthscale_constraint_biased: Constraint on the lengthscale parameter
            of the Matern kernels k_i(i>0). Default is Positive.
        covar_module_unbiased: Specify a custom kernel for k_0. If omitted,
            use a MaternKernel.
        covar_module_biased: Specify a custom kernel for the biased parts
            k_i(i>0). If omitted, use a MaternKernel.
        batch_shape: If specified, use a separate lengthscale for each batch of
            input data. If x1 is a batch_shape x n x d tensor, this should
            be batch_shape.
        active_dims: Compute the covariance of a subset of input dimensions. The
            numbers correspond to the indices of the dimensions.

    Raises:
        UnsupportedError: If neither dimension nor active_dims is specified, or
            if fidelity_dims does not contain exactly one or two entries.
        ValueError: If fidelity_dims contains repeated elements, or if nu is
            not one of 0.5, 1.5, or 2.5.

    Example:
        >>> x = torch.randn(10, 5)
        >>> # Non-batch: Simple option
        >>> covar_module = LinearTruncatedFidelityKernel(
        ...     fidelity_dims=[3, 4], dimension=5
        ... )
        >>> covar = covar_module(x)  # Output: LazyVariable of size (10 x 10)
        >>>
        >>> batch_x = torch.randn(2, 10, 5)
        >>> # Batch: Simple option
        >>> covar_module = LinearTruncatedFidelityKernel(
        ...     fidelity_dims=[3, 4], dimension=5, batch_shape=torch.Size([2])
        ... )
        >>> covar = covar_module(batch_x)  # Output: LazyVariable of size (2 x 10 x 10)
    """

    def __init__(  # noqa C901
        self,
        fidelity_dims: List[int],
        dimension: Optional[int] = None,
        power_prior: Optional[Prior] = None,
        power_constraint: Optional[Interval] = None,
        nu: float = 2.5,
        lengthscale_prior_unbiased: Optional[Prior] = None,
        lengthscale_prior_biased: Optional[Prior] = None,
        lengthscale_constraint_unbiased: Optional[Interval] = None,
        lengthscale_constraint_biased: Optional[Interval] = None,
        covar_module_unbiased: Optional[Kernel] = None,
        covar_module_biased: Optional[Kernel] = None,
        **kwargs: Any,
    ) -> None:
        # Validate everything before touching the base class so that a bad
        # configuration never leaves a half-initialized module behind.
        if dimension is None and kwargs.get("active_dims") is None:
            raise UnsupportedError(
                "Must specify dimension when not specifying active_dims."
            )
        n_fidelity = len(fidelity_dims)
        if len(set(fidelity_dims)) != n_fidelity:
            raise ValueError("fidelity_dims must not have repeated elements")
        if n_fidelity not in {1, 2}:
            raise UnsupportedError(
                "LinearTruncatedFidelityKernel accepts either one or two "
                "fidelity parameters."
            )
        if nu not in {0.5, 1.5, 2.5}:
            raise ValueError("nu must be one of 0.5, 1.5, or 2.5")

        super().__init__(**kwargs)
        self.fidelity_dims = fidelity_dims

        # Fill in defaults for anything the caller did not provide.
        if power_constraint is None:
            power_constraint = Positive()

        if lengthscale_prior_unbiased is None:
            lengthscale_prior_unbiased = GammaPrior(3, 6)

        if lengthscale_prior_biased is None:
            lengthscale_prior_biased = GammaPrior(6, 2)

        if lengthscale_constraint_unbiased is None:
            lengthscale_constraint_unbiased = Positive()

        if lengthscale_constraint_biased is None:
            lengthscale_constraint_biased = Positive()

        # The polynomial power is stored unconstrained ("raw") and mapped
        # through `power_constraint` by the `power` property.
        self.register_parameter(
            name="raw_power",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )
        self.register_constraint("raw_power", power_constraint)

        if power_prior is not None:
            self.register_prior(
                "power_prior",
                power_prior,
                lambda m: m.power,
                lambda m, v: m._set_power(v),
            )

        # active_dims (if given) takes precedence over the `dimension` argument.
        if self.active_dims is not None:
            dimension = len(self.active_dims)

        # The Matern kernels only ever see the non-fidelity columns, hence
        # ARD over `dimension - n_fidelity` dimensions.
        if covar_module_unbiased is None:
            covar_module_unbiased = MaternKernel(
                nu=nu,
                batch_shape=self.batch_shape,
                lengthscale_prior=lengthscale_prior_unbiased,
                ard_num_dims=dimension - n_fidelity,
                lengthscale_constraint=lengthscale_constraint_unbiased,
            )

        if covar_module_biased is None:
            covar_module_biased = MaternKernel(
                nu=nu,
                batch_shape=self.batch_shape,
                lengthscale_prior=lengthscale_prior_biased,
                ard_num_dims=dimension - n_fidelity,
                lengthscale_constraint=lengthscale_constraint_biased,
            )

        self.covar_module_unbiased = covar_module_unbiased
        self.covar_module_biased = covar_module_biased

    @property
    def power(self) -> Tensor:
        """The polynomial power, i.e. `raw_power` mapped through its constraint."""
        return self.raw_power_constraint.transform(self.raw_power)

    @power.setter
    def power(self, value: Tensor) -> None:
        self._set_power(value)

    def _set_power(self, value: Tensor) -> None:
        """Set the constrained power by writing its inverse transform to `raw_power`.

        Non-tensor values are converted to a tensor matching `raw_power`.
        """
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_power)
        self.initialize(raw_power=self.raw_power_constraint.inverse_transform(value))

    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Tensor:
        """Evaluate the kernel between `x1` and `x2`.

        Args:
            x1: First input; the last dimension indexes features, including
                the fidelity column(s) given by `fidelity_dims`.
            x2: Second input, laid out like `x1`.
            diag: If True, compute only the diagonal of the covariance.
            **params: Additional kernel parameters; `last_dim_is_batch` is
                not supported.

        Returns:
            `covar_unbiased + bias_factor * covar_biased`, where both
            covariances are evaluated on the non-fidelity columns only.

        Raises:
            NotImplementedError: If `last_dim_is_batch` is requested.
            RuntimeError: If the input has no non-fidelity dimension.
        """
        if params.get("last_dim_is_batch", False):
            raise NotImplementedError(
                "last_dim_is_batch not yet supported by LinearTruncatedFidelityKernel"
            )

        power = self.power.view(*self.batch_shape, 1, 1)

        # Columns that are not fidelity parameters feed the Matern kernels.
        non_fidelity_dims = [
            i for i in range(x1.size(-1)) if i not in self.fidelity_dims
        ]
        if not non_fidelity_dims:
            raise RuntimeError(
                "Input to LinearTruncatedFidelityKernel must have at least one "
                "non-fidelity dimension"
            )
        active_dimsM = torch.tensor(non_fidelity_dims, device=x1.device)
        x1_ = x1.index_select(dim=-1, index=active_dimsM)
        x2_ = x2.index_select(dim=-1, index=active_dimsM)
        covar_unbiased = self.covar_module_unbiased(x1_, x2_, diag=diag)
        covar_biased = self.covar_module_biased(x1_, x2_, diag=diag)

        # First fidelity column; clamp to avoid numerical issues.
        fd_idxr0 = torch.full(
            (1,), self.fidelity_dims[0], dtype=torch.long, device=x1.device
        )
        x11_ = x1.index_select(dim=-1, index=fd_idxr0).clamp(0, 1)
        x21t_ = x2.index_select(dim=-1, index=fd_idxr0).clamp(0, 1)
        if not diag:
            # Row vector so the products below broadcast to a full matrix.
            x21t_ = x21t_.transpose(-1, -2)
        cross_term_1 = (1 - x11_) * (1 - x21t_)
        bias_factor = cross_term_1 * (1 + x11_ * x21t_).pow(power)

        if len(self.fidelity_dims) > 1:
            # Second fidelity column; clamp to avoid numerical issues.
            fd_idxr1 = torch.full(
                (1,), self.fidelity_dims[1], dtype=torch.long, device=x1.device
            )
            x12_ = x1.index_select(dim=-1, index=fd_idxr1).clamp(0, 1)
            x22t_ = x2.index_select(dim=-1, index=fd_idxr1).clamp(0, 1)
            x1b_ = torch.cat([x11_, x12_], dim=-1)
            if diag:
                # Element-wise inner product of the paired fidelity vectors.
                x2bt_ = torch.cat([x21t_, x22t_], dim=-1)
                k = (1 + (x1b_ * x2bt_).sum(dim=-1, keepdim=True)).pow(power)
            else:
                # Full polynomial kernel matrix over both fidelity columns.
                x22t_ = x22t_.transpose(-1, -2)
                x2bt_ = torch.cat([x21t_, x22t_], dim=-2)
                k = (1 + x1b_ @ x2bt_).pow(power)

            cross_term_2 = (1 - x12_) * (1 - x22t_)
            bias_factor += cross_term_2 * (1 + x12_ * x22t_).pow(power)
            bias_factor += cross_term_2 * cross_term_1 * k

        if diag:
            # Drop the trailing singleton so the shape matches the diagonal
            # returned by the Matern kernels.
            bias_factor = bias_factor.view(covar_biased.shape)

        return covar_unbiased + bias_factor * covar_biased

    def __getitem__(self, index) -> LinearTruncatedFidelityKernel:
        """Return a copy of this kernel restricted to the batch entries at `index`.

        The sub-kernels and `raw_power` are indexed with `index`; the new
        batch shape is read off the indexed `raw_power` (all but its trailing
        size-1 dimension), so integer, tuple and slice indices are handled
        uniformly.
        """
        new_kernel = deepcopy(self)
        new_kernel.covar_module_unbiased = new_kernel.covar_module_unbiased[index]
        new_kernel.covar_module_biased = new_kernel.covar_module_biased[index]
        indexed_raw_power = new_kernel.raw_power[index]
        new_kernel.raw_power = torch.nn.Parameter(indexed_raw_power)
        # raw_power has shape `batch_shape x 1`, so everything but the last
        # dimension of the indexed tensor is the remaining batch shape.
        new_kernel.batch_shape = indexed_raw_power.shape[:-1]
        return new_kernel