# Source code for botorch.utils.containers

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Containers to standardize inputs into models and acquisition functions.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional

import torch
from botorch.exceptions.errors import UnsupportedError
from torch import Tensor


@dataclass
class TrainingData:
    r"""Standardized container of model training data for models.

    Properties:
        Xs: A list of tensors, each of shape `batch_shape x n_i x d`,
            where `n_i` is the number of training inputs for the i-th model.
        Ys: A list of tensors, each of shape `batch_shape x n_i x 1`,
            where `n_i` is the number of training observations for the
            i-th (single-output) model.
        Yvars: A list of tensors, each of shape `batch_shape x n_i x 1`,
            where `n_i` is the number of training observations of the
            observation noise for the i-th (single-output) model.
            If `None`, the observation noise level is unobserved.
    """

    Xs: List[Tensor]  # `batch_shape x n_i x d`
    Ys: List[Tensor]  # `batch_shape x n_i x 1`
    Yvars: Optional[List[Tensor]] = None  # `batch_shape x n_i x 1`

    def __post_init__(self):
        # Block design <=> every model's inputs equal the first model's inputs.
        # This is computed once at construction; `self.Xs[0]` is only touched
        # when there are at least two entries, so an empty / singleton `Xs`
        # is (vacuously) treated as a block design.
        self._is_block_design = all(
            torch.equal(X, self.Xs[0]) for X in self.Xs[1:]
        )

    @classmethod
    def from_block_design(
        cls, X: Tensor, Y: Tensor, Yvar: Optional[Tensor] = None
    ) -> TrainingData:
        r"""Construct a TrainingData object from a block design description.

        Args:
            X: A `batch_shape x n x d` tensor of training points (shared across
                all outcomes).
            Y: A `batch_shape x n x m` tensor of training observations.
            Yvar: A `batch_shape x n x m` tensor of training noise variance
                observations, or `None`.

        Returns:
            The `TrainingData` object (with `is_block_design=True`).
        """
        return cls(
            # The same input tensor is shared by each of the `m` outcomes.
            Xs=[X for _ in range(Y.shape[-1])],
            # Split the `m` outcome columns into `m` single-output tensors.
            Ys=list(torch.split(Y, 1, dim=-1)),
            Yvars=None if Yvar is None else list(torch.split(Yvar, 1, dim=-1)),
        )

    @property
    def is_block_design(self) -> bool:
        r"""Indicates whether training data is a "block design".

        Block designs are designs in which all outcomes are observed
        at the same training inputs.
        """
        return self._is_block_design

    def _require_block_design(self, name: str) -> None:
        r"""Raise `UnsupportedError` if the data is not a block design.

        Args:
            name: Name of the accessor being used, included in the message.

        Raises:
            UnsupportedError: If `self.is_block_design` is False.
        """
        if not self.is_block_design:
            raise UnsupportedError(
                f"`{name}` is only supported for block designs, i.e. when all "
                "outcomes are observed at the same training inputs."
            )

    @property
    def X(self) -> Tensor:
        r"""The training inputs (block-design only).

        This raises an `UnsupportedError` in the non-block-design case.
        """
        self._require_block_design("X")
        return self.Xs[0]

    @property
    def Y(self) -> Tensor:
        r"""The training observations (block-design only).

        The per-outcome tensors are concatenated along the last dimension.
        This raises an `UnsupportedError` in the non-block-design case.
        """
        self._require_block_design("Y")
        return torch.cat(self.Ys, dim=-1)

    @property
    def Yvar(self) -> Optional[Tensor]:
        r"""The training observations' noise variance (block-design only).

        Returns `None` if no noise observations are available (`Yvars` is
        `None`); otherwise the per-outcome tensors are concatenated along the
        last dimension. This raises an `UnsupportedError` in the
        non-block-design case when noise observations are present.
        """
        if self.Yvars is None:
            return None
        self._require_block_design("Yvar")
        return torch.cat(self.Yvars, dim=-1)

    @staticmethod
    def _tensor_lists_equal(first: List[Tensor], second: List[Tensor]) -> bool:
        r"""Return True iff both lists have equal length and equal tensors."""
        return len(first) == len(second) and all(
            torch.equal(a, b) for a, b in zip(first, second)
        )

    def __eq__(self, other: object) -> bool:
        r"""Deep-compare training data tensor by tensor.

        Returns `NotImplemented` for objects that are not `TrainingData`, so
        Python falls back to its default (identity-based) comparison instead
        of raising an `AttributeError`.
        """
        if not isinstance(other, TrainingData):
            return NotImplemented
        # Noise observations must be either absent on both sides or present
        # on both sides.
        if (self.Yvars is None) != (other.Yvars is None):
            return False
        return (
            self._tensor_lists_equal(self.Xs, other.Xs)
            and self._tensor_lists_equal(self.Ys, other.Ys)
            and (
                self.Yvars is None
                or self._tensor_lists_equal(self.Yvars, other.Yvars)
            )
        )