#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Utility functions for models.
"""
from __future__ import annotations
import warnings
from contextlib import ExitStack, contextmanager
from typing import List, Optional, Tuple
import torch
from botorch import settings
from botorch.exceptions import InputDataError, InputDataWarning
from gpytorch import settings as gpt_settings
from gpytorch.module import Module
from gpytorch.utils.broadcasting import _mul_broadcast_shape
from torch import Tensor
def _make_X_full(X: Tensor, output_indices: List[int], tf: int) -> Tensor:
r"""Helper to construct input tensor with task indices.
Args:
X: The raw input tensor (without task information).
output_indices: The output indices to generate (passed in via `posterior`).
tf: The task feature index.
Returns:
Tensor: The full input tensor for the multi-task model, including task
indices.
"""
index_shape = X.shape[:-1] + torch.Size([1])
indexers = (
torch.full(index_shape, fill_value=i, device=X.device, dtype=X.dtype)
for i in output_indices
)
X_l, X_r = X[..., :tf], X[..., tf:]
return torch.cat(
[torch.cat([X_l, indexer, X_r], dim=-1) for indexer in indexers], dim=-2
)
def add_output_dim(X: Tensor, original_batch_shape: torch.Size) -> Tuple[Tensor, int]:
    r"""Insert the output dimension at the correct location.

    The trailing batch dimensions of X must match the original batch dimensions
    of the training inputs, but can also include extra batch dimensions.

    Args:
        X: A `(new_batch_shape) x (original_batch_shape) x n x d` tensor of
            features.
        original_batch_shape: the batch shape of the model's training inputs.

    Returns:
        2-element tuple containing

        - A `(new_batch_shape) x (original_batch_shape) x 1 x n x d` tensor of
          features, i.e. `X` with a new singleton dimension inserted at
          position -3 for the output (`m`) dimension.
        - The index corresponding to the output dimension. This equals the
          number of batch dimensions obtained when broadcasting the batch shape
          of `X` against `original_batch_shape`.

    Raises:
        RuntimeError: If both `X` and the training inputs have batch dimensions
            and the two batch shapes are not broadcastable.
    """
    X_batch_shape = X.shape[:-2]
    if len(X_batch_shape) > 0 and len(original_batch_shape) > 0:
        # check that X_batch_shape supports broadcasting or augments
        # original_batch_shape with extra batch dims. Uses the public torch
        # API rather than gpytorch's private `_mul_broadcast_shape` helper;
        # both raise RuntimeError on a mismatch.
        try:
            torch.broadcast_shapes(X_batch_shape, original_batch_shape)
        except RuntimeError as err:
            raise RuntimeError(
                "The trailing batch dimensions of X must match the trailing "
                "batch dimensions of the training inputs."
            ) from err
    # insert `m` dimension
    X = X.unsqueeze(-3)
    output_dim_idx = max(len(original_batch_shape), len(X_batch_shape))
    return X, output_dim_idx
def check_no_nans(Z: Tensor) -> None:
    r"""Ensure a tensor is free of NaN entries.

    Args:
        Z: The input tensor.

    Raises:
        InputDataError: If at least one element of `Z` is NaN.
    """
    nan_found = bool(torch.isnan(Z).any())
    if nan_found:
        raise InputDataError("Input data contains NaN values.")
def check_min_max_scaling(
    X: Tensor, strict: bool = False, atol: float = 1e-2, raise_on_fail: bool = False
) -> None:
    r"""Check that tensor is normalized to the unit cube.

    Args:
        X: A `batch_shape x n x d` input tensor. Typically the training inputs
            of a model.
        strict: If True, require `X` to be scaled to the unit cube, i.e. for
            every input dimension the minimum over the `n` points must be
            within `atol` of 0 and the maximum within `atol` of 1 (rather than
            just being contained within the unit cube).
        atol: The tolerance for the boundary checks. It is used both for the
            containment check and, if `strict=True`, for the scaling check.
        raise_on_fail: If True, raise an exception instead of a warning.

    Raises:
        InputDataError: If the check fails and `raise_on_fail=True`.
    """
    with torch.no_grad():
        # Reduce over the `n` (data point) dimension so that Xmin / Xmax are the
        # per-feature extremes. Reducing over `d` instead would compare the
        # features of a single point with each other, which says nothing about
        # whether each feature is scaled to [0, 1].
        Xmin, Xmax = torch.min(X, dim=-2)[0], torch.max(X, dim=-2)[0]
        msg = None
        if strict and max(torch.abs(Xmin).max(), torch.abs(Xmax - 1).max()) > atol:
            msg = "scaled"
        # Points outside the cube are the stronger violation, so this message
        # takes precedence over "scaled".
        if torch.any(Xmin < -atol) or torch.any(Xmax > 1 + atol):
            msg = "contained"
        if msg is not None:
            msg = (
                f"Input data is not {msg} to the unit cube. "
                "Please consider min-max scaling the input data."
            )
            if raise_on_fail:
                raise InputDataError(msg)
            warnings.warn(msg, InputDataWarning)
def check_standardization(
    Y: Tensor,
    atol_mean: float = 1e-2,
    atol_std: float = 1e-2,
    raise_on_fail: bool = False,
) -> None:
    r"""Check that tensor is standardized (zero mean, unit variance).

    Args:
        Y: The input tensor of shape `batch_shape x n x m`. Typically the
            train targets of a model. Standardization is checked across the
            `n`-dimension.
        atol_mean: The tolerance for the mean check.
        atol_std: The tolerance for the std check. The std check is skipped
            when `n <= 1`, since the sample standard deviation is undefined for
            fewer than two points.
        raise_on_fail: If True, raise an exception instead of a warning.

    Raises:
        InputDataError: If the check fails and `raise_on_fail=True`.
    """
    with torch.no_grad():
        Ymean = torch.mean(Y, dim=-2)
        not_standardized = bool(torch.abs(Ymean).max() > atol_mean)
        # torch.std (unbiased) of a single point is NaN and torch may warn about
        # non-positive degrees of freedom, so only check the std for n > 1.
        if not not_standardized and Y.shape[-2] > 1:
            Ystd = torch.std(Y, dim=-2)
            not_standardized = bool(torch.abs(Ystd - 1).max() > atol_std)
        if not_standardized:
            msg = (
                "Input data is not standardized. Please consider scaling the "
                "input to zero mean and unit variance."
            )
            if raise_on_fail:
                raise InputDataError(msg)
            warnings.warn(msg, InputDataWarning)
def mod_batch_shape(module: Module, names: List[str], b: int) -> None:
    r"""Walk down a chain of submodules and resize the leaf's batch shape.

    The update happens in place. Only the module reached by the final name is
    modified, and only if it has a non-empty `batch_shape` attribute. In that
    case, the last entry of `batch_shape` is replaced by `b`, or dropped
    entirely when `b <= 0`. An empty `names` list is a no-op.

    Args:
        module: The module to be modified.
        names: The list of names to access the attribute. If the full name of
            the module is `"module.sub_module.leaf_module"`, this will be
            `["sub_module", "leaf_module"]`.
        b: The new size of the last element of the module's `batch_shape`
            attribute.
    """
    current = module
    last_depth = len(names) - 1
    for depth, attr_name in enumerate(names):
        child = getattr(current, attr_name)
        at_leaf = depth == last_depth
        if at_leaf and hasattr(child, "batch_shape") and len(child.batch_shape) > 0:
            new_tail = torch.Size([b]) if b > 0 else torch.Size([])
            child.batch_shape = child.batch_shape[:-1] + new_tail
            return
        current = child
@contextmanager
def gpt_posterior_settings():
    r"""Context manager for settings used for computing model posteriors.

    While active, this context:

    - disables GPyTorch debug mode;
    - enables `fast_pred_var`;
    - sets `detach_test_caches` to the value of
      `settings.propagate_grads.off()`. Presumably this means test caches are
      detached unless BoTorch gradient propagation is on; confirm against
      `botorch.settings`.

    The settings are entered in that order and exited in reverse order.
    """
    with gpt_settings.debug(False):
        with gpt_settings.fast_pred_var():
            # Evaluated only after the two outer settings are active, matching
            # the original entry order.
            with gpt_settings.detach_test_caches(settings.propagate_grads.off()):
                yield