# Source code for botorch.sampling.pathwise.paths

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
)

from botorch.exceptions.errors import UnsupportedError
from botorch.sampling.pathwise.features import FeatureMap
from botorch.sampling.pathwise.utils import (
    TInputTransform,
    TOutputTransform,
    TransformedModuleMixin,
)
from torch import Tensor
from torch.nn import Module, ModuleDict, ModuleList, Parameter


class SamplePath(ABC, TransformedModuleMixin, Module):
    r"""Abstract base class for Botorch sample paths.

    A sample path is a ``torch.nn.Module`` that also mixes in
    ``TransformedModuleMixin``. Concrete paths are expected to implement
    ``forward`` and to set ``input_transform`` / ``output_transform`` attributes.

    NOTE(review): presumably ``TransformedModuleMixin`` applies those transforms
    around ``forward`` when the module is called; confirm in
    ``botorch.sampling.pathwise.utils``.
    """
class PathDict(SamplePath):
    r"""A dictionary of SamplePaths.

    Every contained path is evaluated on the same inputs. Without a ``join``
    rule the outputs are returned as a ``dict`` keyed by path name. With one,
    the list of outputs (in the container's key order) is passed to ``join``.
    """

    def __init__(
        self,
        paths: Optional[Mapping[str, SamplePath]] = None,
        join: Optional[Callable[[List[Tensor]], Tensor]] = None,
        input_transform: Optional[TInputTransform] = None,
        output_transform: Optional[TOutputTransform] = None,
    ) -> None:
        r"""Initializes a PathDict instance.

        Args:
            paths: An optional mapping of strings to sample paths. A
                ``ModuleDict`` is used as-is; any other mapping (or ``None``)
                is wrapped in a new ``ModuleDict``.
            join: An optional callable used to combine each path's outputs.
            input_transform: An optional input transform for the module.
            output_transform: An optional output transform for the module.

        Raises:
            UnsupportedError: If ``output_transform`` is given without ``join``.
        """
        # Without a join rule ``forward`` returns a dict of tensors rather than
        # a single output, so an output transform is rejected in that case.
        if join is None and output_transform is not None:
            raise UnsupportedError("Output transforms must be preceded by a join rule.")

        super().__init__()
        self.join = join
        self.input_transform = input_transform
        self.output_transform = output_transform
        self.paths = (
            paths
            if isinstance(paths, ModuleDict)
            else ModuleDict({} if paths is None else paths)
        )

    def forward(self, x: Tensor, **kwargs: Any) -> Union[Tensor, Dict[str, Tensor]]:
        r"""Evaluates every path at ``x`` (forwarding ``kwargs`` to each path).

        Returns:
            ``{name: output}`` if no ``join`` rule was given, otherwise
            ``join([output, ...])`` with outputs ordered like the path names.
        """
        out = [path(x, **kwargs) for path in self.paths.values()]
        return dict(zip(self.paths, out)) if self.join is None else self.join(out)

    def items(self) -> Iterable[Tuple[str, SamplePath]]:
        r"""Returns the ``(name, path)`` pairs of the underlying ``ModuleDict``."""
        return self.paths.items()

    def keys(self) -> Iterable[str]:
        r"""Returns the path names of the underlying ``ModuleDict``."""
        return self.paths.keys()

    def values(self) -> Iterable[SamplePath]:
        r"""Returns the paths of the underlying ``ModuleDict``."""
        return self.paths.values()

    def __len__(self) -> int:
        return len(self.paths)

    def __iter__(self) -> Iterator[str]:
        # Iterating a ModuleDict yields its keys, so (like ``dict``) iterating a
        # PathDict yields path names, not the paths themselves.
        yield from self.paths

    def __delitem__(self, key: str) -> None:
        del self.paths[key]

    def __getitem__(self, key: str) -> SamplePath:
        return self.paths[key]

    def __setitem__(self, key: str, val: SamplePath) -> None:
        self.paths[key] = val
class PathList(SamplePath):
    r"""A list of SamplePaths.

    Every contained path is evaluated on the same inputs. Without a ``join``
    rule the outputs are returned as a list in path order. With one, that list
    is passed to ``join``.
    """

    def __init__(
        self,
        paths: Optional[Iterable[SamplePath]] = None,
        join: Optional[Callable[[List[Tensor]], Tensor]] = None,
        input_transform: Optional[TInputTransform] = None,
        output_transform: Optional[TOutputTransform] = None,
    ) -> None:
        r"""Initializes a PathList instance.

        Args:
            paths: An optional iterable of sample paths. A ``ModuleList`` is
                used as-is; any other iterable (or ``None``) is wrapped in a new
                ``ModuleList``.
            join: An optional callable used to combine each path's outputs.
            input_transform: An optional input transform for the module.
            output_transform: An optional output transform for the module.

        Raises:
            UnsupportedError: If ``output_transform`` is given without ``join``.
        """
        # Without a join rule ``forward`` returns a list of tensors rather than
        # a single output, so an output transform is rejected in that case.
        if join is None and output_transform is not None:
            raise UnsupportedError("Output transforms must be preceded by a join rule.")

        super().__init__()
        self.join = join
        self.input_transform = input_transform
        self.output_transform = output_transform
        # Default to an empty *list* (previously ``{}``, an empty dict that only
        # worked because ModuleList accepts any iterable).
        self.paths = (
            paths
            if isinstance(paths, ModuleList)
            else ModuleList([] if paths is None else paths)
        )

    def forward(self, x: Tensor, **kwargs: Any) -> Union[Tensor, List[Tensor]]:
        r"""Evaluates every path at ``x`` (forwarding ``kwargs`` to each path).

        Returns:
            The list of per-path outputs if no ``join`` rule was given,
            otherwise ``join`` applied to that list.
        """
        out = [path(x, **kwargs) for path in self.paths]
        return out if self.join is None else self.join(out)

    def __len__(self) -> int:
        return len(self.paths)

    def __iter__(self) -> Iterator[SamplePath]:
        yield from self.paths

    def __delitem__(self, key: int) -> None:
        del self.paths[key]

    def __getitem__(self, key: int) -> SamplePath:
        return self.paths[key]

    def __setitem__(self, key: int, val: SamplePath) -> None:
        self.paths[key] = val
class GeneralizedLinearPath(SamplePath):
    r"""A sample path in the form of a generalized linear model.

    The inputs are featurized by ``feature_map``, and the feature dimension is
    contracted against ``weight``. If a ``bias_module`` is present, its output
    on the same inputs is added as an offset.
    """

    def __init__(
        self,
        feature_map: FeatureMap,
        weight: Union[Parameter, Tensor],
        bias_module: Optional[Module] = None,
        input_transform: Optional[TInputTransform] = None,
        output_transform: Optional[TOutputTransform] = None,
    ):
        r"""Initializes a GeneralizedLinearPath instance.

        .. code-block:: text

            path(x) = output_transform(bias_module(z) + feature_map(z)^T weight),
            where z = input_transform(x).

        Args:
            feature_map: Map used to featurize the module's inputs.
            weight: Weights used to combine the input features. A ``Parameter``
                is registered as a module parameter; any other tensor is
                registered as a buffer.
            bias_module: Optional module whose output is added as an offset.
            input_transform: Optional input transform for the module.
            output_transform: Optional output transform for the module.
        """
        super().__init__()
        self.feature_map = feature_map
        # Non-Parameter weights are stored as buffers so they travel with the
        # module (device moves, state_dict) without becoming parameters.
        if not isinstance(weight, Parameter):
            self.register_buffer("weight", weight)
        self.weight = weight
        self.bias_module = bias_module
        self.input_transform = input_transform
        self.output_transform = output_transform

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        r"""Evaluates the linear path at ``x``.

        ``kwargs`` are forwarded to ``feature_map`` only. ``bias_module`` (if
        any) is called with ``x`` alone.
        """
        features = self.feature_map(x, **kwargs)
        # Append a trailing singleton to ``weight`` so the matmul contracts the
        # last dim of ``features`` with the last dim of ``weight``; the extra
        # singleton dim of the result is squeezed away afterwards.
        projected = (features @ self.weight.unsqueeze(-1)).squeeze(-1)
        if self.bias_module is None:
            return projected
        return projected + self.bias_module(x)