# Source code for botorch.sampling.pathwise.paths

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC
from collections.abc import Iterable, Iterator, Mapping
from typing import Any, Callable, Optional, Union

from botorch.exceptions.errors import UnsupportedError
from botorch.sampling.pathwise.features import FeatureMap
from botorch.sampling.pathwise.utils import (
    TInputTransform,
    TOutputTransform,
    TransformedModuleMixin,
)
from torch import Tensor
from torch.nn import Module, ModuleDict, ModuleList, Parameter


class SamplePath(ABC, TransformedModuleMixin, Module):
    r"""Abstract base class for Botorch sample paths.

    Combines ``TransformedModuleMixin`` with ``torch.nn.Module``. The concrete
    paths defined alongside it set ``input_transform`` / ``output_transform``
    attributes and implement ``forward``. Note that no method is declared
    abstract here, so ``ABC`` alone does not prevent instantiation.
    """
class PathDict(SamplePath):
    r"""A dictionary of SamplePaths.

    Paths are held in a ``ModuleDict`` so they are registered as submodules.
    ``forward`` evaluates every path on the same inputs and returns either a
    dict of outputs keyed by path name or, if a ``join`` rule was given, the
    result of combining the list of outputs.
    """

    def __init__(
        self,
        paths: Optional[Mapping[str, SamplePath]] = None,
        join: Optional[Callable[[list[Tensor]], Tensor]] = None,
        input_transform: Optional[TInputTransform] = None,
        output_transform: Optional[TOutputTransform] = None,
    ) -> None:
        r"""Initializes a PathDict instance.

        Args:
            paths: An optional mapping of strings to sample paths. A
                ``ModuleDict`` is used as-is (not copied); any other mapping is
                wrapped in a new ``ModuleDict``.
            join: An optional callable used to combine each path's outputs.
            input_transform: An optional input transform for the module.
            output_transform: An optional output transform for the module.

        Raises:
            UnsupportedError: If ``output_transform`` is given without ``join``.
        """
        # Without ``join``, ``forward`` returns a dict of per-path outputs, so an
        # output transform is only supported on the joined result.
        if join is None and output_transform is not None:
            raise UnsupportedError("Output transforms must be preceded by a join rule.")

        super().__init__()
        self.join = join
        self.input_transform = input_transform
        self.output_transform = output_transform
        self.paths = (
            paths
            if isinstance(paths, ModuleDict)
            else ModuleDict({} if paths is None else paths)
        )

    def forward(self, x: Tensor, **kwargs: Any) -> Union[Tensor, dict[str, Tensor]]:
        r"""Evaluates every path on ``x``.

        Args:
            x: Inputs passed to each path.
            **kwargs: Extra keyword arguments forwarded to each path.

        Returns:
            ``join(outputs)`` when ``join`` is set, where ``outputs`` lists the
            path outputs in the ``ModuleDict``'s insertion order; otherwise a
            dict mapping each path's key to its output.
        """
        out = [path(x, **kwargs) for path in self.paths.values()]
        return dict(zip(self.paths, out)) if self.join is None else self.join(out)

    def items(self) -> Iterable[tuple[str, SamplePath]]:
        r"""Returns the ``(key, path)`` pairs of the underlying ``ModuleDict``."""
        return self.paths.items()

    def keys(self) -> Iterable[str]:
        r"""Returns the keys of the underlying ``ModuleDict``."""
        return self.paths.keys()

    def values(self) -> Iterable[SamplePath]:
        r"""Returns the paths stored in the underlying ``ModuleDict``."""
        return self.paths.values()

    def __len__(self) -> int:
        return len(self.paths)

    def __iter__(self) -> Iterator[str]:
        # Like a plain ``dict``, iterating a ``ModuleDict`` yields its keys.
        yield from self.paths

    def __delitem__(self, key: str) -> None:
        del self.paths[key]

    def __getitem__(self, key: str) -> SamplePath:
        return self.paths[key]

    def __setitem__(self, key: str, val: SamplePath) -> None:
        self.paths[key] = val
class PathList(SamplePath):
    r"""A list of SamplePaths.

    Paths are held in a ``ModuleList`` so they are registered as submodules.
    ``forward`` evaluates every path on the same inputs and returns either the
    list of outputs or, if a ``join`` rule was given, the result of combining
    that list.
    """

    def __init__(
        self,
        paths: Optional[Iterable[SamplePath]] = None,
        join: Optional[Callable[[list[Tensor]], Tensor]] = None,
        input_transform: Optional[TInputTransform] = None,
        output_transform: Optional[TOutputTransform] = None,
    ) -> None:
        r"""Initializes a PathList instance.

        Args:
            paths: An optional iterable of sample paths. A ``ModuleList`` is
                used as-is (not copied); any other iterable is wrapped in a new
                ``ModuleList``.
            join: An optional callable used to combine each path's outputs.
            input_transform: An optional input transform for the module.
            output_transform: An optional output transform for the module.

        Raises:
            UnsupportedError: If ``output_transform`` is given without ``join``.
        """
        # Without ``join``, ``forward`` returns a list of per-path outputs, so an
        # output transform is only supported on the joined result.
        if join is None and output_transform is not None:
            raise UnsupportedError("Output transforms must be preceded by a join rule.")

        super().__init__()
        self.join = join
        self.input_transform = input_transform
        self.output_transform = output_transform
        # Default to an empty *list*; the previous ``{}`` (a dict) only worked
        # because an empty dict happens to iterate to nothing.
        self.paths = (
            paths
            if isinstance(paths, ModuleList)
            else ModuleList([] if paths is None else paths)
        )

    def forward(self, x: Tensor, **kwargs: Any) -> Union[Tensor, list[Tensor]]:
        r"""Evaluates every path on ``x``.

        Args:
            x: Inputs passed to each path.
            **kwargs: Extra keyword arguments forwarded to each path.

        Returns:
            ``join(outputs)`` when ``join`` is set; otherwise ``outputs``, the
            list of path outputs in the order the paths are stored.
        """
        out = [path(x, **kwargs) for path in self.paths]
        return out if self.join is None else self.join(out)

    def __len__(self) -> int:
        return len(self.paths)

    def __iter__(self) -> Iterator[SamplePath]:
        yield from self.paths

    def __delitem__(self, key: int) -> None:
        del self.paths[key]

    def __getitem__(self, key: int) -> SamplePath:
        return self.paths[key]

    def __setitem__(self, key: int, val: SamplePath) -> None:
        self.paths[key] = val
class GeneralizedLinearPath(SamplePath):
    r"""A sample path in the form of a generalized linear model."""

    def __init__(
        self,
        feature_map: FeatureMap,
        weight: Union[Parameter, Tensor],
        bias_module: Optional[Module] = None,
        input_transform: Optional[TInputTransform] = None,
        output_transform: Optional[TOutputTransform] = None,
    ):
        r"""Initializes a GeneralizedLinearPath instance.

        .. code-block:: text

            path(x) = output_transform(bias_module(z) + feature_map(z)^T weight),
            where z = input_transform(x).

        Args:
            feature_map: A map used to featurize the module's inputs.
            weight: A tensor of weights used to combine input features. A
                ``Parameter`` is registered as a parameter; any other tensor is
                registered as a buffer.
            bias_module: An optional module used to define additive offsets.
            input_transform: An optional input transform for the module.
            output_transform: An optional output transform for the module.
        """
        super().__init__()
        self.feature_map = feature_map

        # A plain tensor is registered as a buffer, so it is not returned by
        # ``parameters()``. The assignment that follows registers a
        # ``Parameter`` as a parameter (or just re-stores the buffer).
        is_parameter = isinstance(weight, Parameter)
        if not is_parameter:
            self.register_buffer("weight", weight)
        self.weight = weight

        self.bias_module = bias_module
        self.input_transform = input_transform
        self.output_transform = output_transform

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        r"""Computes ``feature_map(x) @ weight``, plus ``bias_module(x)`` if set.

        ``kwargs`` are forwarded to ``feature_map`` only.
        """
        features = self.feature_map(x, **kwargs)
        # View ``weight`` as a column so matmul contracts the last dimension of
        # ``features``, then drop the trailing singleton dimension.
        linear_term = (features @ self.weight.unsqueeze(-1)).squeeze(-1)
        if self.bias_module is None:
            return linear_term
        return linear_term + self.bias_module(x)