Source code for botorch.models.kernels.linear_truncated_fidelity

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Any

import torch
from botorch.exceptions import UnsupportedError
from gpytorch.constraints import Interval, Positive
from gpytorch.kernels import Kernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.priors import Prior
from gpytorch.priors.torch_priors import GammaPrior
from torch import Tensor


class LinearTruncatedFidelityKernel(Kernel):
    r"""GPyTorch Linear Truncated Fidelity Kernel.

    Computes a covariance matrix based on the Linear truncated kernel between
    inputs `x_1` and `x_2` for up to two fidelity parameters:

        K(x_1, x_2) = k_0 + c_1(x_1, x_2)k_1 + c_2(x_1,x_2)k_2 + c_3(x_1,x_2)k_3

    where

    - `k_i(i=0,1,2,3)` are Matern kernels calculated between non-fidelity
        parameters of `x_1` and `x_2` with different priors. In this
        implementation `k_0` is `covar_module_unbiased` and all of
        `k_i(i>0)` share the single module `covar_module_biased`.
    - `c_1=(1 - x_1[f_1])(1 - x_2[f_1])(1 + x_1[f_1] x_2[f_1])^p` is the kernel
        of the bias term, which can be decomposed into a deterministic part
        and a polynomial kernel. Here `f_1` is the first fidelity dimension and
        `p` is the order of the polynomial kernel.
    - `c_3` is the same as `c_1` but is calculated for the second fidelity
        dimension `f_2`.
    - `c_2` is the interaction term with four deterministic terms and the
        polynomial kernel between `x_1[..., [f_1, f_2]]` and
        `x_2[..., [f_1, f_2]]`.

    Fidelity values are clamped to `[0, 1]` before the coefficients `c_i`
    are evaluated.

    Example:
        >>> x = torch.rand(10, 5)
        >>> # Non-batch: the last input column is the fidelity parameter
        >>> covar_module = LinearTruncatedFidelityKernel(
        ...     fidelity_dims=[4], dimension=5
        ... )
        >>> covar = covar_module(x)  # Output: LinearOperator of size (10 x 10)
        >>>
        >>> batch_x = torch.rand(2, 10, 5)
        >>> # Batch: separate hyperparameters for each batch
        >>> covar_module = LinearTruncatedFidelityKernel(
        ...     fidelity_dims=[4], dimension=5, batch_shape=torch.Size([2])
        ... )
        >>> covar = covar_module(batch_x)  # Output: LinearOperator (2 x 10 x 10)
    """

    def __init__(  # noqa C901
        self,
        fidelity_dims: list[int],
        dimension: int | None = None,
        power_prior: Prior | None = None,
        power_constraint: Interval | None = None,
        nu: float = 2.5,
        lengthscale_prior_unbiased: Prior | None = None,
        lengthscale_prior_biased: Prior | None = None,
        lengthscale_constraint_unbiased: Interval | None = None,
        lengthscale_constraint_biased: Interval | None = None,
        covar_module_unbiased: Kernel | None = None,
        covar_module_biased: Kernel | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            fidelity_dims: A list containing either one or two indices specifying
                the fidelity parameters of the input.
            dimension: The dimension of `x`. Unused if `active_dims` is specified.
            power_prior: Prior for the power parameter of the polynomial kernel.
                Default is `None`.
            power_constraint: Constraint on the power parameter of the polynomial
                kernel. Default is `Positive`.
            nu: The smoothness parameter for the Matern kernel: either 1/2, 3/2,
                or 5/2. Unused if both `covar_module_unbiased` and
                `covar_module_biased` are specified.
            lengthscale_prior_unbiased: Prior on the lengthscale parameter of
                Matern kernel `k_0`. Default is `GammaPrior(3, 6)`.
            lengthscale_constraint_unbiased: Constraint on the lengthscale
                parameter of the Matern kernel `k_0`. Default is `Positive`.
            lengthscale_prior_biased: Prior on the lengthscale parameter of
                Matern kernels `k_i(i>0)`. Default is `GammaPrior(6, 2)`.
            lengthscale_constraint_biased: Constraint on the lengthscale
                parameter of the Matern kernels `k_i(i>0)`. Default is
                `Positive`.
            covar_module_unbiased: Specify a custom kernel for `k_0`. If omitted,
                use a `MaternKernel`.
            covar_module_biased: Specify a custom kernel for the biased parts
                `k_i(i>0)`. If omitted, use a `MaternKernel`.
            batch_shape: If specified, use a separate lengthscale for each batch
                of input data. If `x1` is a `batch_shape x n x d` tensor, this
                should be `batch_shape`. Passed via `kwargs`.
            active_dims: Compute the covariance of a subset of input dimensions.
                The numbers correspond to the indices of the dimensions.
                Passed via `kwargs`.

        Raises:
            UnsupportedError: If neither `dimension` nor `active_dims` is given,
                or if `fidelity_dims` does not contain one or two entries.
            ValueError: If `fidelity_dims` has repeated entries or `nu` is not
                one of 0.5, 1.5, 2.5.
        """
        # Validate everything before touching the parent constructor so that a
        # bad configuration never leaves a half-initialized module behind.
        if dimension is None and kwargs.get("active_dims") is None:
            raise UnsupportedError(
                "Must specify dimension when not specifying active_dims."
            )
        n_fidelity = len(fidelity_dims)
        if len(set(fidelity_dims)) != n_fidelity:
            raise ValueError("fidelity_dims must not have repeated elements")
        if n_fidelity not in {1, 2}:
            raise UnsupportedError(
                "LinearTruncatedFidelityKernel accepts either one or two "
                "fidelity parameters."
            )
        if nu not in {0.5, 1.5, 2.5}:
            raise ValueError("nu must be one of 0.5, 1.5, or 2.5")

        super().__init__(**kwargs)
        self.fidelity_dims = fidelity_dims

        # Fill in defaults for anything the caller left unspecified.
        if power_constraint is None:
            power_constraint = Positive()
        if lengthscale_prior_unbiased is None:
            lengthscale_prior_unbiased = GammaPrior(3, 6)
        if lengthscale_prior_biased is None:
            lengthscale_prior_biased = GammaPrior(6, 2)
        if lengthscale_constraint_unbiased is None:
            lengthscale_constraint_unbiased = Positive()
        if lengthscale_constraint_biased is None:
            lengthscale_constraint_biased = Positive()

        # The polynomial order `p` is stored unconstrained as `raw_power` and
        # mapped through `power_constraint` by the `power` property.
        self.register_parameter(
            name="raw_power",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )
        self.register_constraint("raw_power", power_constraint)

        if power_prior is not None:
            self.register_prior(
                "power_prior",
                power_prior,
                lambda m: m.power,
                lambda m, v: m._set_power(v),
            )

        # `active_dims`, when given, takes precedence over `dimension`.
        if self.active_dims is not None:
            dimension = len(self.active_dims)

        # The Matern kernels only see the non-fidelity columns, hence the
        # `dimension - n_fidelity` ARD lengthscales.
        if covar_module_unbiased is None:
            covar_module_unbiased = MaternKernel(
                nu=nu,
                batch_shape=self.batch_shape,
                lengthscale_prior=lengthscale_prior_unbiased,
                ard_num_dims=dimension - n_fidelity,
                lengthscale_constraint=lengthscale_constraint_unbiased,
            )

        if covar_module_biased is None:
            covar_module_biased = MaternKernel(
                nu=nu,
                batch_shape=self.batch_shape,
                lengthscale_prior=lengthscale_prior_biased,
                ard_num_dims=dimension - n_fidelity,
                lengthscale_constraint=lengthscale_constraint_biased,
            )

        self.covar_module_unbiased = covar_module_unbiased
        self.covar_module_biased = covar_module_biased

    @property
    def power(self) -> Tensor:
        """The polynomial order `p`: `raw_power` mapped through its constraint."""
        return self.raw_power_constraint.transform(self.raw_power)

    @power.setter
    def power(self, value: Tensor) -> None:
        self._set_power(value)

    def _set_power(self, value: Tensor) -> None:
        """Set `power` by storing its inverse-transformed value in `raw_power`.

        Args:
            value: The desired (constrained) power. Non-tensor values are
                converted to a tensor matching `raw_power`.
        """
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_power)
        self.initialize(raw_power=self.raw_power_constraint.inverse_transform(value))

    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Tensor:
        """Evaluate the kernel between `x1` and `x2`.

        Args:
            x1: First input; its last dimension holds the features, including
                the fidelity column(s) listed in `fidelity_dims`.
            x2: Second input, indexed along its last dimension the same way
                as `x1`.
            diag: Forwarded to the two sub-kernels. If `True`, the fidelity
                coefficients are combined elementwise rather than pairwise and
                the result is reshaped to match the biased sub-kernel output.
            params: Additional keyword arguments; only `last_dim_is_batch`
                is inspected.

        Returns:
            `covar_unbiased + bias_factor * covar_biased`, where `bias_factor`
            is the sum of the fidelity coefficients `c_i`.

        Raises:
            NotImplementedError: If `last_dim_is_batch` is truthy in `params`.
            RuntimeError: If the inputs have no non-fidelity dimension.
        """
        if params.get("last_dim_is_batch", False):
            raise NotImplementedError(
                "last_dim_is_batch not yet supported by LinearTruncatedFidelityKernel"
            )

        # Reshape so that `power` broadcasts against `... x n x m` terms.
        power = self.power.view(*self.batch_shape, 1, 1)

        # Indices of the non-fidelity columns fed to the Matern kernels.
        active_dimsM = torch.tensor(
            [i for i in range(x1.size(-1)) if i not in self.fidelity_dims],
            device=x1.device,
        )
        if len(active_dimsM) == 0:
            raise RuntimeError(
                "Input to LinearTruncatedFidelityKernel must have at least one "
                "non-fidelity dimension."
            )
        x1_ = x1.index_select(dim=-1, index=active_dimsM)
        x2_ = x2.index_select(dim=-1, index=active_dimsM)
        covar_unbiased = self.covar_module_unbiased(x1_, x2_, diag=diag)
        covar_biased = self.covar_module_biased(x1_, x2_, diag=diag)

        # First fidelity dimension: clamp to avoid numerical issues
        fd_idxr0 = torch.full(
            (1,), self.fidelity_dims[0], dtype=torch.long, device=x1.device
        )
        x11_ = x1.index_select(dim=-1, index=fd_idxr0).clamp(0, 1)
        x21t_ = x2.index_select(dim=-1, index=fd_idxr0).clamp(0, 1)
        if not diag:
            # Row vector so that products with `x11_` broadcast to all pairs.
            x21t_ = x21t_.transpose(-1, -2)
        cross_term_1 = (1 - x11_) * (1 - x21t_)
        # c_1 term.
        bias_factor = cross_term_1 * (1 + x11_ * x21t_).pow(power)

        if len(self.fidelity_dims) > 1:
            # Second fidelity dimension: clamp to avoid numerical issues
            fd_idxr1 = torch.full(
                (1,), self.fidelity_dims[1], dtype=torch.long, device=x1.device
            )
            x12_ = x1.index_select(dim=-1, index=fd_idxr1).clamp(0, 1)
            x22t_ = x2.index_select(dim=-1, index=fd_idxr1).clamp(0, 1)
            x1b_ = torch.cat([x11_, x12_], dim=-1)
            # `k` is the polynomial kernel over both fidelity columns.
            if diag:
                x2bt_ = torch.cat([x21t_, x22t_], dim=-1)
                k = (1 + (x1b_ * x2bt_).sum(dim=-1, keepdim=True)).pow(power)
            else:
                x22t_ = x22t_.transpose(-1, -2)
                x2bt_ = torch.cat([x21t_, x22t_], dim=-2)
                k = (1 + x1b_ @ x2bt_).pow(power)

            cross_term_2 = (1 - x12_) * (1 - x22t_)
            # c_3 term (same form as c_1, second fidelity dimension).
            bias_factor += cross_term_2 * (1 + x12_ * x22t_).pow(power)
            # c_2 interaction term.
            bias_factor += cross_term_2 * cross_term_1 * k

        if diag:
            # Match the shape of the sub-kernel's diagonal output.
            bias_factor = bias_factor.view(covar_biased.shape)

        return covar_unbiased + bias_factor * covar_biased