Source code for botorch.models.transforms.input
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
import torch
from torch import Tensor
from torch.nn import Module, ModuleDict
from ...exceptions.errors import BotorchTensorDimensionError
class InputTransform(Module, ABC):
    r"""Abstract base class for transforms applied to model inputs.

    Subclasses must implement `forward`. The default `untransform` raises
    `NotImplementedError`; subclasses that support inversion are expected to
    override it.
    """

    @abstractmethod
    def forward(self, X: Tensor) -> Tensor:
        r"""Map model inputs into the transformed space.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of transformed inputs.
        """
        ...  # pragma: no cover

    def untransform(self, X: Tensor) -> Tensor:
        r"""Map transformed inputs back to the original input space.

        Args:
            X: A `batch_shape x n x d`-dim tensor of transformed inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of un-transformed inputs.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        cls_name = type(self).__name__
        raise NotImplementedError(
            f"{cls_name} does not implement the `untransform` method"
        )
class ChainedInputTransform(InputTransform, ModuleDict):
    r"""An input transform representing the chaining of individual transforms."""

    def __init__(self, **transforms: InputTransform) -> None:
        r"""Chaining of input transforms.

        Args:
            transforms: The transforms to chain. `forward` applies them in the
                order the keyword arguments were given. Internally, the names
                of the kwargs are used as the keys for accessing the individual
                transforms on the module.
        """
        from collections import OrderedDict

        # Some PyTorch releases sort the keys of a plain (non-OrderedDict)
        # mapping passed to `ModuleDict`, which would silently reorder the
        # chain alphabetically. An OrderedDict is always inserted in order.
        super().__init__(OrderedDict(transforms))

    def forward(self, X: Tensor) -> Tensor:
        r"""Transform the inputs to a model.

        Individual transforms are applied in sequence.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of transformed inputs.
        """
        for tf in self.values():
            X = tf.forward(X)
        return X

    def untransform(self, X: Tensor) -> Tensor:
        r"""Un-transform the inputs to a model.

        Un-transforms of the individual transforms are applied in reverse
        sequence, so the last transform applied by `forward` is undone first.

        Args:
            X: A `batch_shape x n x d`-dim tensor of transformed inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of un-transformed inputs.
        """
        # Materialize before reversing so this does not depend on whether the
        # view returned by `ModuleDict.values()` supports `reversed()`.
        for tf in reversed(list(self.values())):
            X = tf.untransform(X)
        return X
class Normalize(InputTransform):
    r"""Normalize the inputs to the unit cube.

    If no explicit bounds are provided this module is stateful: If in train mode,
    calling `forward` updates the module state (i.e. the normalizing bounds). If
    in eval mode, calling `forward` simply applies the normalization using the
    current module state.
    """

    def __init__(
        self,
        d: int,
        bounds: Optional[Tensor] = None,
        batch_shape: torch.Size = torch.Size(),  # noqa: B008
        min_range: float = 1e-8,
    ) -> None:
        r"""Normalize the inputs to the unit cube.

        Args:
            d: The dimension of the input space.
            bounds: If provided, use these bounds to normalize the inputs. If
                omitted, learn the bounds in train mode.
            batch_shape: The batch shape of the inputs (assuming input tensors
                of shape `batch_shape x n x d`). If provided, perform individual
                normalization per batch, otherwise uses a single normalization.
            min_range: Lower limit applied to learned ranges. An input dimension
                whose training values are all equal would otherwise get a range
                of zero, and normalization would divide by zero (NaN/inf).

        Raises:
            BotorchTensorDimensionError: If the last dimension of `bounds` does
                not equal `d`.
        """
        super().__init__()
        if bounds is not None:
            if bounds.size(-1) != d:
                raise BotorchTensorDimensionError(
                    "Incompatible dimensions of provided bounds"
                )
            # Index 0 along dim -2 holds the lower bounds, index 1 the upper
            # bounds. List indexing keeps the singleton dimension, so the
            # stored tensors broadcast over the `n` dimension of the inputs.
            mins = bounds[..., [0], :]
            ranges = bounds[..., [1], :] - mins
            self.learn_bounds = False
        else:
            # Placeholders until `forward` is called in train mode. Note that
            # zero ranges mean normalizing before any training divides by zero.
            mins = torch.zeros(*batch_shape, 1, d)
            ranges = torch.zeros(*batch_shape, 1, d)
            self.learn_bounds = True
        self.register_buffer("mins", mins)
        self.register_buffer("ranges", ranges)
        self._d = d
        self.min_range = min_range

    def forward(self, X: Tensor) -> Tensor:
        r"""Normalize the inputs.

        If no explicit bounds are provided, this is stateful: In train mode,
        calling `forward` updates the module state (i.e. the normalizing bounds).
        In eval mode, calling `forward` simply applies the normalization using
        the current module state.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of inputs normalized to the
            module's bounds.

        Raises:
            BotorchTensorDimensionError: If bounds are being learned and the
                last dimension of `X` does not match the module's dimension.
        """
        if self.learn_bounds and self.training:
            if X.size(-1) != self.mins.size(-1):
                raise BotorchTensorDimensionError(
                    f"Wrong input dimension. Received {X.size(-1)}, "
                    f"expected {self.mins.size(-1)}"
                )
            self.mins = X.min(dim=-2, keepdim=True)[0]
            ranges = X.max(dim=-2, keepdim=True)[0] - self.mins
            # A constant input dimension has zero range; clamp it so the
            # division below stays finite instead of producing NaN/inf.
            self.ranges = torch.clamp(ranges, min=self.min_range)
        return (X - self.mins) / self.ranges

    def untransform(self, X: Tensor) -> Tensor:
        r"""Un-normalize the inputs.

        Inverts `forward` using the current module state: `mins + X * ranges`.

        Args:
            X: A `batch_shape x n x d`-dim tensor of normalized inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of un-normalized inputs.
        """
        return self.mins + X * self.ranges

    @property
    def bounds(self) -> Tensor:
        r"""The bounds used for normalizing the inputs.

        Returns:
            A tensor whose second-to-last dimension has size 2: index 0 holds
            the lower bounds (`mins`) and index 1 the upper bounds
            (`mins + ranges`).
        """
        return torch.cat([self.mins, self.mins + self.ranges], dim=-2)