Source code for botorch.utils.containers

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Representations for different kinds of data."""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any

from torch import device as Device, dtype as Dtype, LongTensor, Size, Tensor


class BotorchContainer(ABC):
    r"""Abstract base class for BoTorch's data containers.

    A BotorchContainer stands for a single tensor, and that tensor should be
    the only object its `__call__` method returns. The tensor is expected to be
    made up of one or more "events" (e.g. data points or feature vectors); the
    shape of one event is recorded in the required `event_shape` field.

    Validation relies on `dataclasses.fields`, so concrete subclasses are
    expected to be dataclasses that declare `event_shape` as a field.

    Notice: Once version 3.10 becomes standard, this class should be reworked to
    take advantage of dataclasses' `kw_only` flag.
    """

    event_shape: Size

    def __post_init__(self, validate_init: bool = True) -> None:
        r"""Dataclass post-init hook; validates the instance unless disabled.

        Args:
            validate_init: When False, `_validate` is skipped.
        """
        if not validate_init:
            return
        self._validate()

    @abstractmethod
    def __call__(self) -> Tensor:
        r"""Return the tensor represented by this container."""
        raise NotImplementedError

    @abstractmethod
    def __eq__(self, other: Any) -> bool:
        r"""Return whether `other` is equal to this container."""
        raise NotImplementedError

    @property
    @abstractmethod
    def shape(self) -> Size:
        r"""Shape of the represented tensor."""
        raise NotImplementedError

    @property
    @abstractmethod
    def device(self) -> Device:
        r"""Device of the represented tensor."""
        raise NotImplementedError

    @property
    @abstractmethod
    def dtype(self) -> Dtype:
        r"""Dtype of the represented tensor."""
        raise NotImplementedError

    def _validate(self) -> None:
        r"""Check that `event_shape` is among the instance's dataclass fields.

        Raises:
            AttributeError: If no dataclass field is named `event_shape`.
        """
        declares_event_shape = any(
            spec.name == "event_shape" for spec in fields(self)
        )
        if not declares_event_shape:
            raise AttributeError("Missing required field `event_shape`.")
@dataclass(eq=False)
class DenseContainer(BotorchContainer):
    r"""Basic representation of data stored as a dense Tensor.

    Attributes:
        values: The dense tensor held by the container.
        event_shape: Shape of a single event; it must match the trailing
            dimensions of `values`.
    """

    values: Tensor
    event_shape: Size

    def __call__(self) -> Tensor:
        """Returns a dense tensor representation of the container's contents."""
        return self.values

    def __eq__(self, other: Any) -> bool:
        r"""Return True iff `other` has the same type, shape, and values.

        Note that `event_shape` is not part of the comparison.
        """
        return (
            type(other) is type(self)
            and self.shape == other.shape
            and self.values.equal(other.values)
        )

    @property
    def shape(self) -> Size:
        r"""Shape of `values`."""
        return self.values.shape

    @property
    def device(self) -> Device:
        r"""Device on which `values` lives."""
        return self.values.device

    @property
    def dtype(self) -> Dtype:
        r"""Dtype of `values`."""
        return self.values.dtype

    def _validate(self) -> None:
        r"""Check that `event_shape` matches the trailing dims of `values`.

        Raises:
            AttributeError: If the `event_shape` field is missing (raised by
                the parent class check).
            ValueError: If `event_shape` has more dimensions than `values`, or
                if it differs from the trailing dimensions of `values.shape`.
        """
        super()._validate()
        values_shape = self.values.shape
        event_shape = self.event_shape
        num_value_dims = len(values_shape)
        num_event_dims = len(event_shape)

        # A pairwise `zip` over the reversed shapes stops at the shorter one and
        # would silently accept an `event_shape` with more dimensions than
        # `values`; check the length explicitly before comparing the tail.
        # The start index is computed explicitly so that an empty `event_shape`
        # yields an empty tail (a `-0` slice start would select everything).
        compatible = num_event_dims <= num_value_dims and tuple(
            values_shape[num_value_dims - num_event_dims :]
        ) == tuple(event_shape)
        if not compatible:
            raise ValueError(
                f"Shape of `values` {self.values.shape} incompatible with "
                f"`event shape` {self.event_shape}."
            )
@dataclass(eq=False)
class SliceContainer(BotorchContainer):
    r"""Represent data points formed by concatenating (n-1)-dimensional slices
    taken from the leading dimension of an n-dimensional source tensor.

    Attributes:
        values: The source tensor; slices are selected along its dim 0.
        indices: Integer tensor of positions into dim 0 of `values`. Its last
            dimension enumerates the slices that make up one data point.
        event_shape: Shape of one assembled data point.
    """

    values: Tensor
    indices: LongTensor
    event_shape: Size

    def __call__(self) -> Tensor:
        r"""Gather the indexed slices and merge each group into one event."""
        source, index = self.values, self.indices
        gathered = source.index_select(dim=0, index=index.view(-1))
        leading_dims = index.shape[:-1]
        trailing_dims = source.shape[2:]
        # The `-1` entry absorbs the slices selected by the last dim of
        # `indices` together with dim 1 of `values`.
        return gathered.view(*leading_dims, -1, *trailing_dims)

    def __eq__(self, other: Any) -> bool:
        r"""Return True iff `other` has the same type, values, and indices."""
        if type(other) is not type(self):
            return False
        return self.values.equal(other.values) and self.indices.equal(
            other.indices
        )

    @property
    def shape(self) -> Size:
        r"""Leading dims of `indices` followed by `event_shape`."""
        leading_dims = self.indices.shape[:-1]
        return leading_dims + self.event_shape

    @property
    def device(self) -> Device:
        r"""Device on which `values` lives."""
        return self.values.device

    @property
    def dtype(self) -> Dtype:
        r"""Dtype of `values`."""
        return self.values.dtype

    def _validate(self) -> None:
        r"""Check `indices` and `event_shape` against `values`.

        Raises:
            AttributeError: If the `event_shape` field is missing (raised by
                the parent class check).
            AssertionError: If `indices` has fewer than two dimensions or holds
                entries outside `[0, len(values))`.
            ValueError: If `event_shape` differs from the shape implied by
                `values` and `indices`.
        """
        super()._validate()
        source = self.values
        index = self.indices

        # NOTE: these checks are `assert` statements, so they are skipped when
        # Python runs with optimizations enabled (`-O`).
        assert index.ndim > 1
        lower_ok = -1 < index.min()
        upper_ok = index.max() < len(source)
        assert lower_ok & upper_ok

        declared = self.event_shape
        implied = (index.shape[-1] * source.shape[1],) + source.shape[2:]
        if declared != implied:
            raise ValueError(
                f"Shapes of `values` {source.shape} and `indices` "
                f"{index.shape} incompatible with `event_shape` {declared}."
            )