#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Assorted helper methods and objects for working with BoTorch models."""
from __future__ import annotations
import warnings
from contextlib import contextmanager, ExitStack
from typing import List, Optional, Tuple
import torch
from botorch import settings
from botorch.exceptions import InputDataError, InputDataWarning
from botorch.settings import _Flag
from gpytorch import settings as gpt_settings
from gpytorch.module import Module
from gpytorch.utils.broadcasting import _mul_broadcast_shape
from torch import Tensor
def _make_X_full(X: Tensor, output_indices: List[int], tf: int) -> Tensor:
r"""Helper to construct input tensor with task indices.
Args:
X: The raw input tensor (without task information).
output_indices: The output indices to generate (passed in via `posterior`).
tf: The task feature index.
Returns:
Tensor: The full input tensor for the multi-task model, including task
indices.
"""
index_shape = X.shape[:-1] + torch.Size([1])
indexers = (
torch.full(index_shape, fill_value=i, device=X.device, dtype=X.dtype)
for i in output_indices
)
X_l, X_r = X[..., :tf], X[..., tf:]
return torch.cat(
[torch.cat([X_l, indexer, X_r], dim=-1) for indexer in indexers], dim=-2
)
def add_output_dim(
    X: Tensor, original_batch_shape: torch.Size
) -> Tuple[Tensor, int]:
    r"""Insert the output dimension at the correct location.

    The trailing batch dimensions of X must match the original batch dimensions
    of the training inputs, but can also include extra batch dimensions.

    Args:
        X: A `(new_batch_shape) x (original_batch_shape) x n x d` tensor of
            features.
        original_batch_shape: the batch shape of the model's training inputs.

    Returns:
        2-element tuple containing

        - A `(new_batch_shape) x (original_batch_shape) x m x n x d` tensor of
          features, where the inserted `m` dimension has size 1.
        - The index corresponding to the output dimension.

    Raises:
        RuntimeError: If the batch shape of `X` cannot be broadcast against
            `original_batch_shape`.
    """
    X_batch_shape = X.shape[:-2]
    if len(X_batch_shape) > 0 and len(original_batch_shape) > 0:
        # check that X_batch_shape supports broadcasting or augments
        # original_batch_shape with extra batch dims. Uses the public torch
        # API rather than gpytorch's private (deprecated) broadcasting helper.
        try:
            torch.broadcast_shapes(X_batch_shape, original_batch_shape)
        except RuntimeError as err:
            raise RuntimeError(
                "The trailing batch dimensions of X must match the trailing "
                "batch dimensions of the training inputs."
            ) from err
    # insert the (size-1) `m` dimension directly in front of the `n x d` dims
    X = X.unsqueeze(-3)
    output_dim_idx = max(len(original_batch_shape), len(X_batch_shape))
    return X, output_dim_idx
def check_no_nans(Z: Tensor) -> None:
    r"""Validate that a tensor is free of NaN values.

    Args:
        Z: The input tensor.

    Raises:
        InputDataError: If any element of `Z` is NaN.
    """
    found_nan = bool(torch.isnan(Z).any())
    if found_nan:
        raise InputDataError("Input data contains NaN values.")
def check_min_max_scaling(
    X: Tensor,
    strict: bool = False,
    atol: float = 1e-2,
    raise_on_fail: bool = False,
    ignore_dims: Optional[List[int]] = None,
) -> None:
    r"""Check that tensor is normalized to the unit cube.

    Args:
        X: A `batch_shape x n x d` input tensor. Typically the training inputs
            of a model.
        strict: If True, require `X` to be scaled to the unit cube, i.e. for
            every checked column the minimum (maximum) over the `n` data points
            must be within `atol` of 0 (1). Otherwise only require `X` to be
            contained within the unit cube.
        atol: The tolerance for the boundary checks. Used by the containment
            check and, if `strict=True`, by the scaling check.
        raise_on_fail: If True, raise an `InputDataError` instead of emitting
            an `InputDataWarning`.
        ignore_dims: Subset of dimensions where the min-max scaling check is omitted.
    """
    ignore_dims = ignore_dims or []
    # sorted so the selected column order is deterministic
    check_dims = sorted(set(range(X.shape[-1])) - set(ignore_dims))
    if len(check_dims) == 0:
        return None
    with torch.no_grad():
        X_check = X[..., check_dims]
        if X_check.numel() == 0:
            # No data points to check; reducing over an empty dim would raise.
            return None
        # Reduce over the data-point dimension `n` (dim=-2) to get per-column
        # extrema. Reducing over `d` would compare different features of the
        # same point and make the `strict` check meaningless.
        Xmin = X_check.min(dim=-2).values
        Xmax = X_check.max(dim=-2).values
        msg = None
        if strict and max(Xmin.abs().max(), (Xmax - 1).abs().max()) > atol:
            msg = "scaled"
        # Leaving the cube is the more severe failure, so it takes precedence.
        if torch.any(Xmin < -atol) or torch.any(Xmax > 1 + atol):
            msg = "contained"
    if msg is not None:
        msg = (
            f"Input data is not {msg} to the unit cube. "
            "Please consider min-max scaling the input data."
        )
        if raise_on_fail:
            raise InputDataError(msg)
        warnings.warn(msg, InputDataWarning)
def check_standardization(
    Y: Tensor,
    atol_mean: float = 1e-2,
    atol_std: float = 1e-2,
    raise_on_fail: bool = False,
) -> None:
    r"""Verify that a tensor has (approximately) zero mean and unit variance.

    Statistics are computed over the `n`-dimension (dim=-2). Data counts as
    standardized unless some column's mean deviates from 0 by more than
    `atol_mean`, or some column's std deviates from 1 by more than `atol_std`.

    Args:
        Y: The input tensor of shape `batch_shape x n x m`. Typically the
            train targets of a model.
        atol_mean: The tolerance for the mean check.
        atol_std: The tolerance for the std check.
        raise_on_fail: If True, raise an `InputDataError` instead of emitting
            an `InputDataWarning`.
    """
    with torch.no_grad():
        mean_err = Y.mean(dim=-2).abs().max()
        std_err = (Y.std(dim=-2) - 1).abs().max()
        # Kept as `>` comparisons (not a negated `<=`) so that NaN statistics,
        # e.g. the unbiased std of a single point, never trigger the check.
        flagged = bool(mean_err > atol_mean or std_err > atol_std)
    if not flagged:
        return
    msg = (
        "Input data is not standardized. Please consider scaling the "
        "input to zero mean and unit variance."
    )
    if raise_on_fail:
        raise InputDataError(msg)
    warnings.warn(msg, InputDataWarning)
def mod_batch_shape(module: Module, names: List[str], b: int) -> None:
    r"""Change the last entry of a nested module's `batch_shape`, in place.

    The attribute path `names` is followed from `module` down to a leaf. Only
    the leaf is touched, and only if it has a non-empty `batch_shape`. Its last
    batch dimension is replaced by `b`, or removed entirely when `b <= 0`.

    Args:
        module: The module to be modified.
        names: The list of names to access the attribute. If the full name of
            the module is `"module.sub_module.leaf_module"`, this will be
            `["sub_module", "leaf_module"]`.
        b: The new size of the last element of the module's `batch_shape`
            attribute.
    """
    if not names:
        return
    # Descend to the leaf's parent one attribute at a time.
    parent = module
    for attr_name in names[:-1]:
        parent = getattr(parent, attr_name)
    leaf = getattr(parent, names[-1])
    if not hasattr(leaf, "batch_shape") or len(leaf.batch_shape) == 0:
        return
    new_tail = torch.Size([b]) if b > 0 else torch.Size([])
    leaf.batch_shape = leaf.batch_shape[:-1] + new_tail
@contextmanager
def gpt_posterior_settings():
    r"""Context manager for settings used for computing model posteriors.

    GPyTorch's `debug(False)` and `fast_pred_var()` contexts are entered only
    when the corresponding setting is still at its default, so explicit user
    choices are respected. `detach_test_caches` is always entered, with
    `settings.propagate_grads.off()` as its argument. All contexts are exited
    together when the block ends.
    """
    # (setting to inspect, factory for the context to enter if it is still at
    # its default) -- checked and entered in this order.
    overridable = (
        (gpt_settings.debug, lambda: gpt_settings.debug(False)),
        (gpt_settings.fast_pred_var, gpt_settings.fast_pred_var),
    )
    with ExitStack() as stack:
        for setting, make_context in overridable:
            if setting.is_default():
                stack.enter_context(make_context())
        stack.enter_context(
            gpt_settings.detach_test_caches(settings.propagate_grads.off())
        )
        yield
class fantasize(_Flag):
    r"""A flag denoting whether we are currently in a `fantasize` context.

    NOTE(review): how the flag is toggled (e.g. use as a context manager) is
    inherited from `botorch.settings._Flag`, which is not visible here.
    """

    # Class-level flag value; the default `False` means "not in a fantasize
    # context".
    _state: bool = False