Source code for botorch.utils.containers

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Representations for different kinds of data."""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any

from torch import device as Device, dtype as Dtype, LongTensor, Size, Tensor


class BotorchContainer(ABC):
    r"""Abstract base class for BoTorch's data containers.

    A BotorchContainer represents a tensor, which should be the sole object
    returned by its `__call__` method. Said tensor is expected to consist of
    one or more "events" (e.g. data points or feature vectors), whose shape is
    given by the required `event_shape` field.

    Notice: Once version 3.10 becomes standard, this class should
    be reworked to take advantage of dataclasses' `kw_only` flag.

    :meta private:
    """

    def __post_init__(self, validate_init: bool = True) -> None:
        if validate_init:
            self._validate()

    @abstractmethod
    def __call__(self) -> Tensor:
        raise NotImplementedError

    @abstractmethod
    def __eq__(self, other: Any) -> bool:
        raise NotImplementedError

    @property
    @abstractmethod
    def shape(self) -> Size:
        raise NotImplementedError

    @property
    @abstractmethod
    def device(self) -> Device:
        raise NotImplementedError

    @property
    @abstractmethod
    def dtype(self) -> Dtype:
        raise NotImplementedError

    def _validate(self) -> None:
        for field in fields(self):
            if field.name == "event_shape":
                return
        raise AttributeError("Missing required field `event_shape`.")


[docs]@dataclass(eq=False) class DenseContainer(BotorchContainer): r"""Basic representation of data stored as a dense Tensor.""" values: Tensor event_shape: Size def __call__(self) -> Tensor: """Returns a dense tensor representation of the container's contents.""" return self.values def __eq__(self, other: Any) -> bool: return ( type(other) is type(self) and self.shape == other.shape and self.values.equal(other.values) ) @property def shape(self) -> Size: return self.values.shape @property def device(self) -> Device: return self.values.device @property def dtype(self) -> Dtype: return self.values.dtype def _validate(self) -> None: super()._validate() for a, b in zip(reversed(self.event_shape), reversed(self.values.shape)): if a != b: raise ValueError( f"Shape of `values` {self.values.shape} incompatible with " f"`event shape` {self.event_shape}." )
[docs]@dataclass(eq=False) class SliceContainer(BotorchContainer): r"""Represent data points formed by concatenating (n-1)-dimensional slices taken from the leading dimension of an n-dimensional source tensor.""" values: Tensor indices: LongTensor event_shape: Size def __call__(self) -> Tensor: flat = self.values.index_select(dim=0, index=self.indices.view(-1)) return flat.view(*self.indices.shape[:-1], -1, *self.values.shape[2:]) def __eq__(self, other: Any) -> bool: return ( type(other) is type(self) and self.values.equal(other.values) and self.indices.equal(other.indices) ) @property def shape(self) -> Size: return self.indices.shape[:-1] + self.event_shape @property def device(self) -> Device: return self.values.device @property def dtype(self) -> Dtype: return self.values.dtype def _validate(self) -> None: super()._validate() values = self.values indices = self.indices assert indices.ndim > 1 assert (-1 < indices.min()) & (indices.max() < len(values)) event_shape = self.event_shape _event_shape = (indices.shape[-1] * values.shape[1],) + values.shape[2:] if event_shape != _event_shape: raise ValueError( f"Shapes of `values` {values.shape} and `indices` " f"{indices.shape} incompatible with `event_shape` {event_shape}." )