#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
r"""
Some basic data transformation helpers.
"""
from functools import wraps
from typing import Any, Callable, Optional
import torch
from torch import Tensor
def squeeze_last_dim(Y: Tensor) -> Tensor:
    r"""Squeeze the last dimension of a Tensor.

    The last dimension is removed only if it has size 1; for any other size
    the tensor is returned with its shape unchanged.

    Args:
        Y: A `... x d`-dim Tensor.

    Returns:
        The input tensor with last dimension squeezed.

    Example:
        >>> Y = torch.rand(4, 3, 1)
        >>> Y_squeezed = squeeze_last_dim(Y)
        >>> Y_squeezed.shape
        torch.Size([4, 3])
    """
    return Y.squeeze(-1)
def standardize(X: Tensor) -> Tensor:
    r"""Standardize a tensor by dim=0.

    Each column is shifted by its mean and divided by its standard deviation,
    both computed along dim 0. Columns whose standard deviation is not at
    least `1e-9` (e.g. constant columns) are only centered: their divisor is
    replaced by 1 to avoid dividing by (near) zero.

    Args:
        X: A `n` or `n x d`-dim tensor

    Returns:
        The standardized `X`.

    Example:
        >>> X = torch.rand(4, 3)
        >>> X_standardized = standardize(X)
    """
    col_std = X.std(dim=0)
    # Any entry failing the `>= 1e-9` check falls back to a unit divisor.
    # NOTE(review): a NaN std (presumably what a single-row input yields)
    # also fails the comparison and therefore also becomes 1 -- confirm.
    unit = torch.full_like(col_std, 1.0)
    safe_std = col_std.where(col_std >= 1e-9, unit)
    centered = X - X.mean(dim=0)
    return centered / safe_std
def normalize(X: Tensor, bounds: Tensor) -> Tensor:
    r"""Min-max normalize X to [0, 1] using the provided bounds.

    Values of `X` outside the given bounds map outside of [0, 1]. No guard is
    applied for columns whose lower and upper bounds coincide; those divide
    by zero.

    Args:
        X: `... x d` tensor of data
        bounds: `2 x d` tensor of lower and upper bounds for each of the X's d
            columns.

    Returns:
        A `... x d`-dim tensor of normalized data.

    Example:
        >>> X = torch.rand(4, 3)
        >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)])
        >>> X_normalized = normalize(X, bounds)
    """
    lower = bounds[0]
    upper = bounds[1]
    return (X - lower) / (upper - lower)
def unnormalize(X: Tensor, bounds: Tensor) -> Tensor:
    r"""Unscale X from [0, 1] to the original scale.

    This is the inverse of min-max normalization with the same `bounds`:
    0 maps to the lower bound and 1 maps to the upper bound of each column.

    Args:
        X: `... x d` tensor of data
        bounds: `2 x d` tensor of lower and upper bounds for each of the X's d
            columns.

    Returns:
        A `... x d`-dim tensor of unnormalized data.

    Example:
        >>> X_normalized = torch.rand(4, 3)
        >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)])
        >>> X = unnormalize(X_normalized, bounds)
    """
    lower = bounds[0]
    upper = bounds[1]
    return X * (upper - lower) + lower
def match_batch_shape(X: Tensor, Y: Tensor) -> Tensor:
    r"""Matches the batch dimension of a tensor to that of another tensor.

    Args:
        X: A `batch_shape_X x q x d` tensor, whose batch dimensions that
            correspond to batch dimensions of `Y` are to be matched to those
            (if compatible).
        Y: A `batch_shape_Y x q' x d` tensor.

    Returns:
        A `batch_shape_Y x q x d` tensor containing the data of `X` expanded to
        the batch dimensions of `Y` (if compatible). For instance, if `X` is
        `b'' x b' x q x d` and `Y` is `b x q x d`, then the returned tensor is
        `b'' x b x q x d`.

    Example:
        >>> X = torch.rand(2, 1, 5, 3)
        >>> Y = torch.rand(2, 6, 4, 3)
        >>> X_matched = match_batch_shape(X, Y)
        >>> X_matched.shape
        torch.Size([2, 6, 5, 3])
    """
    # Leading dims of X that have no counterpart in Y are kept as they are
    # (empty when X does not have more dims than Y).
    extra_batch = X.shape[: -Y.dim()]
    # Batch dims are taken from Y; the trailing `q x d` dims stay those of X.
    target_shape = extra_batch + Y.shape[:-2] + X.shape[-2:]
    # `expand` performs the compatibility check on the requested shape.
    return X.expand(target_shape)
def convert_to_target_pre_hook(module: torch.nn.Module, *args: Any) -> None:
    r"""Pre-hook for automatically calling `.to(X)` on module prior to `forward`

    Args:
        module: The module to convert; `module.to(...)` is called on it.
        args: The remaining hook arguments. Only `args[0][0]` is used as the
            target `X` passed to `.to`.
            NOTE(review): presumably `args[0]` is the tuple of positional
            inputs of `forward` (as passed to forward pre-hooks), making
            `X` the first input -- confirm against where the hook is
            registered.
    """
    target = args[0][0]
    module.to(target)