Source code for botorch.models.kernels.downsampling
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from gpytorch.constraints import Interval, Positive
from gpytorch.kernels import Kernel
from gpytorch.priors import Prior
from torch import Tensor
class DownsamplingKernel(Kernel):
    r"""GPyTorch Downsampling Kernel.

    Computes a covariance matrix based on the down sampling kernel between
    inputs `x_1` and `x_2` (we expect `d = 1`):

        K(\mathbf{x_1}, \mathbf{x_2}) = c + (1 - x_1)^{1 + \delta}
            (1 - x_2)^{1 + \delta},

    where `c` is an offset parameter, and `\delta` is a power parameter.
    For `d > 1` inputs, the per-dimension products are summed (a dot product
    of the transformed inputs) before the offset is added, both for the full
    matrix and for its diagonal.

    Args:
        power_prior: Prior over the power parameter.
        offset_prior: Prior over the offset parameter.
        power_constraint: Constraint to place on power parameter. Default is
            `Positive`.
        offset_constraint: Constraint to place on offset parameter. Default is
            `Positive`.
        active_dims: List of data dimensions to operate on (passed through
            to `Kernel`). Since the kernel expects `d = 1`, this should
            normally select a single dimension.
    """

    def __init__(
        self,
        power_prior: Optional[Prior] = None,
        offset_prior: Optional[Prior] = None,
        power_constraint: Optional[Interval] = None,
        offset_constraint: Optional[Interval] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if power_constraint is None:
            power_constraint = Positive()
        if offset_constraint is None:
            offset_constraint = Positive()
        # One raw (unconstrained) value per kernel batch: `(*batch_shape, 1)`.
        self.register_parameter(
            name="raw_power",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )
        self.register_parameter(
            name="raw_offset",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )
        if power_prior is not None:
            self.register_prior(
                "power_prior",
                power_prior,
                lambda: self.power,
                lambda v: self._set_power(v),
            )
        self.register_constraint("raw_power", power_constraint)
        if offset_prior is not None:
            self.register_prior(
                "offset_prior",
                offset_prior,
                lambda: self.offset,
                lambda v: self._set_offset(v),
            )
        self.register_constraint("raw_offset", offset_constraint)

    @property
    def power(self) -> Tensor:
        """The constrained power parameter `delta`."""
        return self.raw_power_constraint.transform(self.raw_power)

    @power.setter
    def power(self, value: Tensor) -> None:
        self._set_power(value)

    def _set_power(self, value: Tensor) -> None:
        """Set `power` by storing the inverse-transformed raw value."""
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_power)
        self.initialize(raw_power=self.raw_power_constraint.inverse_transform(value))

    @property
    def offset(self) -> Tensor:
        """The constrained offset parameter `c`."""
        return self.raw_offset_constraint.transform(self.raw_offset)

    @offset.setter
    def offset(self, value: Tensor) -> None:
        self._set_offset(value)

    def _set_offset(self, value: Tensor) -> None:
        """Set `offset` by storing the inverse-transformed raw value."""
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_offset)
        self.initialize(raw_offset=self.raw_offset_constraint.inverse_transform(value))

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: Optional[bool] = False,
        last_dim_is_batch: Optional[bool] = False,
        **params,
    ) -> Tensor:
        r"""Evaluate the downsampling kernel.

        Args:
            x1: `... x n x d` tensor of inputs (`d = 1` expected).
            x2: `... x m x d` tensor of inputs (`d = 1` expected).
            diag: If True, return only the diagonal of the kernel matrix
                (requires `n == m`).
            last_dim_is_batch: If True, treat each of the `d` input
                dimensions as a separate batch of 1-d inputs.

        Returns:
            A `... x n x m` kernel matrix, or its `... x n` diagonal if
            `diag=True`. With `last_dim_is_batch=True`, an extra `d` batch
            dimension is inserted before the `n` dimension.
        """
        # Reshape the `(*batch_shape, 1)` hyperparameters to
        # `(*batch_shape, 1, 1)` so they broadcast against `n x m` matrices.
        offset = self.offset.view(*self.batch_shape, 1, 1)
        exponent = 1 + self.power.view(*self.batch_shape, 1, 1)
        if last_dim_is_batch:
            # `... x n x d` -> `... x d x n x 1`.
            x1 = x1.transpose(-1, -2).unsqueeze(-1)
            x2 = x2.transpose(-1, -2).unsqueeze(-1)
            # The new `d` dimension sits between the kernel batch dims and the
            # matrix dims; give the hyperparameters a matching singleton dim so
            # they align with `batch_shape` rather than with `d`.
            offset = offset.unsqueeze(-3)
            exponent = exponent.unsqueeze(-3)
        x1_ = (1 - x1).pow(exponent)
        x2_ = (1 - x2).pow(exponent)
        if diag:
            # Elementwise counterpart of the matmul below, so the result equals
            # the diagonal of the full matrix. Drop the trailing singleton dim
            # of `offset` so the output is `... x n`, not `... x 1 x n`.
            return offset.squeeze(-1) + (x1_ * x2_).sum(dim=-1)
        return offset + x1_ @ x2_.transpose(-2, -1)