Source code for botorch.models.transforms.input

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module, ModuleDict

from ...exceptions.errors import BotorchTensorDimensionError


class InputTransform(Module, ABC):
    r"""Base class that every input transform must subclass.

    Subclasses are required to implement `forward`. Implementing `untransform`
    is optional; the default implementation refuses to invert the transform.
    """

    @abstractmethod
    def forward(self, X: Tensor) -> Tensor:
        r"""Map model inputs to their transformed representation.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor holding the transformed inputs.
        """
        ...  # pragma: no cover

    def untransform(self, X: Tensor) -> Tensor:
        r"""Map transformed inputs back to their original representation.

        This base implementation always raises; transforms that can be
        inverted override it.

        Args:
            X: A `batch_shape x n x d`-dim tensor of transformed inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of un-transformed inputs.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        transform_name = type(self).__name__
        raise NotImplementedError(
            f"{transform_name} does not implement the `untransform` method"
        )
class ChainedInputTransform(InputTransform, ModuleDict):
    r"""An input transform representing the chaining of individual transforms.

    The individual transforms are applied in the order in which they were
    passed to the constructor. `untransform` applies their un-transforms in
    reverse order.
    """

    def __init__(self, **transforms: InputTransform) -> None:
        r"""Chaining of input transforms.

        Args:
            transforms: The transforms to chain. Internally, the names of the
                kwargs are used as the keys for accessing the individual
                transforms on the module.
        """
        # Local import keeps this block self-contained.
        from collections import OrderedDict

        # Wrap in an OrderedDict so ModuleDict keeps the call order of the
        # kwargs: older PyTorch releases sort the keys of a plain ``dict``
        # passed to ModuleDict. The chain is order-sensitive, so that matters.
        super().__init__(OrderedDict(transforms))

    def forward(self, X: Tensor) -> Tensor:
        r"""Transform the inputs to a model.

        Individual transforms are applied in sequence.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of transformed inputs.
        """
        for tf in self.values():
            X = tf.forward(X)
        return X

    def untransform(self, X: Tensor) -> Tensor:
        r"""Un-transform the inputs to a model.

        Un-transforms of the individual transforms are applied in reverse
        sequence.

        Args:
            X: A `batch_shape x n x d`-dim tensor of transformed inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of un-transformed inputs.
        """
        # Materialize the values so `reversed` works regardless of the view
        # type that `ModuleDict.values()` returns.
        for tf in reversed(list(self.values())):
            X = tf.untransform(X)
        return X
class Normalize(InputTransform):
    r"""Normalize the inputs to the unit cube.

    If no explicit bounds are provided, this module is stateful. In train
    mode, calling `forward` updates the module state (i.e. the normalizing
    bounds). In eval mode, calling `forward` simply applies the normalization
    using the current module state.
    """

    def __init__(
        self,
        d: int,
        bounds: Optional[Tensor] = None,
        batch_shape: torch.Size = torch.Size(),  # noqa: B008
        min_range: float = 1e-8,
    ) -> None:
        r"""Normalize the inputs to the unit cube.

        Args:
            d: The dimension of the input space.
            bounds: If provided, use these bounds to normalize the inputs.
                Row 0 (along dim -2) is used as the lower bound and row 1 as
                the upper bound. If omitted, learn the bounds in train mode.
            batch_shape: The batch shape of the inputs (assuming input tensors
                of shape `batch_shape x n x d`). If provided, perform
                individual normalization per batch, otherwise uses a single
                normalization.
            min_range: When bounds are learned, any column whose observed
                range is below this value is normalized with a range of 1
                instead. This avoids dividing by (near) zero for constant
                inputs.

        Raises:
            BotorchTensorDimensionError: If `bounds` does not have `d` columns.
        """
        super().__init__()
        if bounds is not None:
            if bounds.size(-1) != d:
                raise BotorchTensorDimensionError(
                    "Incompatible dimensions of provided bounds"
                )
            # Index with lists to keep the singleton row dim: `... x 1 x d`.
            mins = bounds[..., [0], :]
            ranges = bounds[..., [1], :] - mins
            self.learn_bounds = False
        else:
            # Placeholders until the first `forward` call in train mode.
            mins = torch.zeros(*batch_shape, 1, d)
            ranges = torch.zeros(*batch_shape, 1, d)
            self.learn_bounds = True
        self.register_buffer("mins", mins)
        self.register_buffer("ranges", ranges)
        self._d = d
        self._min_range = min_range

    def forward(self, X: Tensor) -> Tensor:
        r"""Normalize the inputs.

        If no explicit bounds are provided, this is stateful. In train mode,
        calling `forward` updates the module state (i.e. the normalizing
        bounds). In eval mode, calling `forward` simply applies the
        normalization using the current module state.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of inputs normalized to the
            module's bounds.

        Raises:
            BotorchTensorDimensionError: If bounds are being learned and the
                last dimension of `X` does not match the module's dimension.
        """
        if self.learn_bounds and self.training:
            if X.size(-1) != self.mins.size(-1):
                raise BotorchTensorDimensionError(
                    f"Wrong input dimension. Received {X.size(-1)}, "
                    f"expected {self.mins.size(-1)}"
                )
            self.mins = X.min(dim=-2, keepdim=True)[0]
            ranges = X.max(dim=-2, keepdim=True)[0] - self.mins
            # A constant column has zero range and would yield NaN/inf below.
            # Normalize such columns by 1 so they are merely shifted to 0.
            # `untransform` remains an exact inverse.
            self.ranges = torch.where(
                ranges < self._min_range, torch.ones_like(ranges), ranges
            )
        return (X - self.mins) / self.ranges

    def untransform(self, X: Tensor) -> Tensor:
        r"""Un-normalize the inputs.

        Args:
            X: A `batch_shape x n x d`-dim tensor of normalized inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of un-normalized inputs.
        """
        return self.mins + X * self.ranges

    @property
    def bounds(self) -> Tensor:
        r"""The bounds used for normalizing the inputs.

        Returns the lower bounds stacked over the upper bounds (`mins + ranges`)
        along dim -2.
        """
        return torch.cat([self.mins, self.mins + self.ranges], dim=-2)