Source code for botorch.sampling.pathwise.utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, List, Optional, overload, Tuple, Union

import torch
from botorch.models.approximate_gp import SingleTaskVariationalGP
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import Model, ModelList
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.utils.dispatcher import Dispatcher
from gpytorch.kernels import ScaleKernel
from gpytorch.kernels.kernel import Kernel
from torch import LongTensor, Tensor
from torch.nn import Module, ModuleList

TInputTransform = Union[InputTransform, Callable[[Tensor], Tensor]]
TOutputTransform = Union[OutcomeTransform, Callable[[Tensor], Tensor]]
GetTrainInputs = Dispatcher("get_train_inputs")
GetTrainTargets = Dispatcher("get_train_targets")


[docs]class TransformedModuleMixin: r"""Mixin that wraps a module's __call__ method with optional transforms.""" input_transform: Optional[TInputTransform] output_transform: Optional[TOutputTransform] def __call__(self, values: Tensor, *args: Any, **kwargs: Any) -> Tensor: input_transform = getattr(self, "input_transform", None) if input_transform is not None: values = ( input_transform.forward(values) if isinstance(input_transform, InputTransform) else input_transform(values) ) output = super().__call__(values, *args, **kwargs) output_transform = getattr(self, "output_transform", None) if output_transform is None: return output return ( output_transform.untransform(output)[0] if isinstance(output_transform, OutcomeTransform) else output_transform(output) )
[docs]class TensorTransform(ABC, Module): r"""Abstract base class for transforms that map tensor to tensor."""
[docs] @abstractmethod def forward(self, values: Tensor, **kwargs: Any) -> Tensor: pass # pragma: no cover
[docs]class ChainedTransform(TensorTransform): r"""A composition of TensorTransforms.""" def __init__(self, *transforms: TensorTransform): r"""Initializes a ChainedTransform instance. Args: transforms: A set of transforms to be applied from right to left. """ super().__init__() self.transforms = ModuleList(transforms)
[docs] def forward(self, values: Tensor) -> Tensor: for transform in reversed(self.transforms): values = transform(values) return values
[docs]class SineCosineTransform(TensorTransform): r"""A transform that returns concatenated sine and cosine features.""" def __init__(self, scale: Optional[Tensor] = None): r"""Initializes a SineCosineTransform instance. Args: scale: An optional tensor used to rescale the module's outputs. """ super().__init__() self.scale = scale
[docs] def forward(self, values: Tensor) -> Tensor: sincos = torch.concat([values.sin(), values.cos()], dim=-1) return sincos if self.scale is None else self.scale * sincos
[docs]class InverseLengthscaleTransform(TensorTransform): r"""A transform that divides its inputs by a kernels lengthscales.""" def __init__(self, kernel: Kernel): r"""Initializes an InverseLengthscaleTransform instance. Args: kernel: The kernel whose lengthscales are to be used. """ if not kernel.has_lengthscale: raise RuntimeError(f"{type(kernel)} does not implement `lengthscale`.") super().__init__() self.kernel = kernel
[docs] def forward(self, values: Tensor) -> Tensor: return self.kernel.lengthscale.reciprocal() * values
[docs]class OutputscaleTransform(TensorTransform): r"""A transform that multiplies its inputs by the square root of a kernel's outputscale.""" def __init__(self, kernel: ScaleKernel): r"""Initializes an OutputscaleTransform instance. Args: kernel: A ScaleKernel whose `outputscale` is to be used. """ super().__init__() self.kernel = kernel
[docs] def forward(self, values: Tensor) -> Tensor: outputscale = ( self.kernel.outputscale[..., None, None] if self.kernel.batch_shape else self.kernel.outputscale ) return outputscale.sqrt() * values
[docs]class FeatureSelector(TensorTransform): r"""A transform that returns a subset of its input's features. along a given tensor dimension.""" def __init__(self, indices: Iterable[int], dim: Union[int, LongTensor] = -1): r"""Initializes a FeatureSelector instance. Args: indices: A LongTensor of feature indices. dim: The dimensional along which to index features. """ super().__init__() self.register_buffer("dim", dim if torch.is_tensor(dim) else torch.tensor(dim)) self.register_buffer( "indices", indices if torch.is_tensor(indices) else torch.tensor(indices) )
[docs] def forward(self, values: Tensor) -> Tensor: return values.index_select(dim=self.dim, index=self.indices)
[docs]class OutcomeUntransformer(TensorTransform): r"""Module acting as a bridge for `OutcomeTransform.untransform`.""" def __init__( self, transform: OutcomeTransform, num_outputs: Union[int, LongTensor], ): r"""Initializes an OutcomeUntransformer instance. Args: transform: The wrapped OutcomeTransform instance. num_outputs: The number of outcome features that the OutcomeTransform transforms. """ super().__init__() self.transform = transform self.register_buffer( "num_outputs", num_outputs if torch.is_tensor(num_outputs) else torch.tensor(num_outputs), )
[docs] def forward(self, values: Tensor) -> Tensor: # OutcomeTransforms expect an explicit output dimension in the final position. if self.num_outputs == 1: # BoTorch has suppressed the output dimension output_values, _ = self.transform.untransform(values.unsqueeze(-1)) return output_values.squeeze(-1) # BoTorch has moved the output dimension inside as the final batch dimension. output_values, _ = self.transform.untransform(values.transpose(-2, -1)) return output_values.transpose(-2, -1)
[docs]def get_input_transform(model: GPyTorchModel) -> Optional[InputTransform]: r"""Returns a model's input_transform or None.""" return getattr(model, "input_transform", None)
[docs]def get_output_transform(model: GPyTorchModel) -> Optional[OutcomeUntransformer]: r"""Returns a wrapped version of a model's outcome_transform or None.""" transform = getattr(model, "outcome_transform", None) if transform is None: return None return OutcomeUntransformer(transform=transform, num_outputs=model.num_outputs)
@overload def get_train_inputs(model: Model, transformed: bool = False) -> Tuple[Tensor, ...]: pass # pragma: no cover @overload def get_train_inputs(model: ModelList, transformed: bool = False) -> List[...]: pass # pragma: no cover
[docs]def get_train_inputs(model: Model, transformed: bool = False): return GetTrainInputs(model, transformed=transformed)
@GetTrainInputs.register(Model) def _get_train_inputs_Model(model: Model, transformed: bool = False) -> Tuple[Tensor]: if not transformed: original_train_input = getattr(model, "_original_train_inputs", None) if torch.is_tensor(original_train_input): return (original_train_input,) (X,) = model.train_inputs transform = get_input_transform(model) if transform is None: return (X,) if model.training: return (transform.forward(X) if transformed else X,) return (X if transformed else transform.untransform(X),) @GetTrainInputs.register(SingleTaskVariationalGP) def _get_train_inputs_SingleTaskVariationalGP( model: SingleTaskVariationalGP, transformed: bool = False ) -> Tuple[Tensor]: (X,) = model.model.train_inputs if model.training != transformed: return (X,) transform = get_input_transform(model) if transform is None: return (X,) return (transform.forward(X) if model.training else transform.untransform(X),) @GetTrainInputs.register(ModelList) def _get_train_inputs_ModelList( model: ModelList, transformed: bool = False ) -> List[...]: return [get_train_inputs(m, transformed=transformed) for m in model.models] @overload def get_train_targets(model: Model, transformed: bool = False) -> Tensor: pass # pragma: no cover @overload def get_train_targets(model: ModelList, transformed: bool = False) -> List[...]: pass # pragma: no cover
[docs]def get_train_targets(model: Model, transformed: bool = False): return GetTrainTargets(model, transformed=transformed)
@GetTrainTargets.register(Model) def _get_train_targets_Model(model: Model, transformed: bool = False) -> Tensor: Y = model.train_targets # Note: Avoid using `get_output_transform` here since it creates a Module transform = getattr(model, "outcome_transform", None) if transformed or transform is None: return Y if model.num_outputs == 1: return transform.untransform(Y.unsqueeze(-1))[0].squeeze(-1) return transform.untransform(Y.transpose(-2, -1))[0].transpose(-2, -1) @GetTrainTargets.register(SingleTaskVariationalGP) def _get_train_targets_SingleTaskVariationalGP( model: Model, transformed: bool = False ) -> Tensor: Y = model.model.train_targets transform = getattr(model, "outcome_transform", None) if transformed or transform is None: return Y if model.num_outputs == 1: return transform.untransform(Y.unsqueeze(-1))[0].squeeze(-1) # SingleTaskVariationalGP.__init__ doesn't bring the multitoutpout dimension inside return transform.untransform(Y)[0] @GetTrainTargets.register(ModelList) def _get_train_targets_ModelList( model: ModelList, transformed: bool = False ) -> List[...]: return [get_train_targets(m, transformed=transformed) for m in model.models]