Source code for botorch.models.kernels.downsampling

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Optional

import torch
from gpytorch.constraints import Interval, Positive
from gpytorch.kernels import Kernel
from gpytorch.priors import Prior
from torch import Tensor


[docs]class DownsamplingKernel(Kernel): r"""GPyTorch Downsampling Kernel. Computes a covariance matrix based on the down sampling kernel between inputs `x_1` and `x_2` (we expect `d = 1`): K(\mathbf{x_1}, \mathbf{x_2}) = c + (1 - x_1)^(1 + delta) * (1 - x_2)^(1 + delta). where `c` is an offset parameter, and `delta` is a power parameter. Args: power_constraint: Constraint to place on power parameter. Default is `Positive`. power_prior: Prior over the power parameter. offset_constraint: Constraint to place on offset parameter. Default is `Positive`. active_dims: List of data dimensions to operate on. `len(active_dims)` should equal `num_dimensions`. """ def __init__( self, power_prior: Optional[Prior] = None, offset_prior: Optional[Prior] = None, power_constraint: Optional[Interval] = None, offset_constraint: Optional[Interval] = None, **kwargs, ): super().__init__(**kwargs) if power_constraint is None: power_constraint = Positive() if offset_constraint is None: offset_constraint = Positive() self.register_parameter( name="raw_power", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)), ) self.register_parameter( name="raw_offset", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)), ) if power_prior is not None: self.register_prior( "power_prior", power_prior, lambda m: m.power, lambda m, v: m._set_power(v), ) self.register_constraint("raw_power", power_constraint) if offset_prior is not None: self.register_prior( "offset_prior", offset_prior, lambda m: m.offset, lambda m, v: m._set_offset(v), ) self.register_constraint("raw_offset", offset_constraint) @property def power(self) -> Tensor: return self.raw_power_constraint.transform(self.raw_power) @power.setter def power(self, value: Tensor) -> None: self._set_power(value) def _set_power(self, value: Tensor) -> None: if not torch.is_tensor(value): value = torch.as_tensor(value).to(self.raw_power) self.initialize(raw_power=self.raw_power_constraint.inverse_transform(value)) @property def offset(self) -> Tensor: return self.raw_offset_constraint.transform(self.raw_offset) @offset.setter def offset(self, value: Tensor) -> None: self._set_offset(value) def _set_offset(self, value: Tensor) -> None: if not torch.is_tensor(value): value = torch.as_tensor(value).to(self.raw_offset) self.initialize(raw_offset=self.raw_offset_constraint.inverse_transform(value)) def forward( self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, last_dim_is_batch: Optional[bool] = False, **params, ) -> Tensor: offset = self.offset exponent = 1 + self.power if last_dim_is_batch: x1 = x1.transpose(-1, -2).unsqueeze(-1) x2 = x2.transpose(-1, -2).unsqueeze(-1) x1_ = 1 - x1 x2_ = 1 - x2 if diag: return offset + (x1_ * x2_).sum(dim=-1).pow(exponent) offset = offset.unsqueeze(-1) # unsqueeze enables batch evaluation exponent = exponent.unsqueeze(-1) # unsqueeze enables batch evaluation return offset + x1_.pow(exponent) @ x2_.transpose(-2, -1).pow(exponent)