Source code for botorch.utils.containers
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Representations for different kinds of data."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any
from torch import device as Device, dtype as Dtype, LongTensor, Size, Tensor
[docs]
class BotorchContainer(ABC):
    r"""Abstract base class for BoTorch's data containers.

    A BotorchContainer represents a tensor, which should be the sole object
    returned by its `__call__` method. That tensor is expected to consist of
    one or more "events" (e.g. data points or feature vectors), each of whose
    shape is given by the required `event_shape` field.

    Concrete subclasses are expected to be dataclasses: `__post_init__` runs
    `_validate`, which inspects the instance's dataclass fields.

    Notice: Once version 3.10 becomes standard, this class should
    be reworked to take advantage of dataclasses' `kw_only` flag.
    """

    event_shape: Size

    def __post_init__(self, validate_init: bool = True) -> None:
        # Called by the dataclass-generated `__init__` of subclasses; passing
        # `validate_init=False` explicitly skips validation.
        if not validate_init:
            return
        self._validate()

    @abstractmethod
    def __call__(self) -> Tensor:
        """Return the tensor represented by this container."""
        raise NotImplementedError

    @abstractmethod
    def __eq__(self, other: Any) -> bool:
        """Compare this container with `other`."""
        raise NotImplementedError

    @property
    @abstractmethod
    def shape(self) -> Size:
        """Shape of the tensor returned by `__call__`."""
        raise NotImplementedError

    @property
    @abstractmethod
    def device(self) -> Device:
        """Device on which the container's data lives."""
        raise NotImplementedError

    @property
    @abstractmethod
    def dtype(self) -> Dtype:
        """Data type of the container's data."""
        raise NotImplementedError

    def _validate(self) -> None:
        """Ensure the (dataclass) instance declares an `event_shape` field.

        `dataclasses.fields` raises `TypeError` if `self` is not a dataclass.

        Raises:
            AttributeError: If no field named `event_shape` exists.
        """
        declared = {field.name for field in fields(self)}
        if "event_shape" not in declared:
            raise AttributeError("Missing required field `event_shape`.")
[docs]
@dataclass(eq=False)
class DenseContainer(BotorchContainer):
    r"""Basic representation of data stored as a dense Tensor.

    `event_shape` must match the trailing dimensions of `values.shape`; any
    remaining leading dimensions of `values` enumerate the events.
    """

    values: Tensor
    event_shape: Size

    def __call__(self) -> Tensor:
        """Returns a dense tensor representation of the container's contents."""
        return self.values

    def __eq__(self, other: Any) -> bool:
        """Equal iff `other` has the same type, shape and element values.

        Note that `event_shape` itself is not compared.
        """
        return (
            type(other) is type(self)
            and self.shape == other.shape
            and self.values.equal(other.values)
        )

    @property
    def shape(self) -> Size:
        """Shape of the underlying `values` tensor."""
        return self.values.shape

    @property
    def device(self) -> Device:
        """Device of the underlying `values` tensor."""
        return self.values.device

    @property
    def dtype(self) -> Dtype:
        """Data type of the underlying `values` tensor."""
        return self.values.dtype

    def _validate(self) -> None:
        """Check that `event_shape` is a suffix of `values.shape`.

        Raises:
            AttributeError: If the `event_shape` field is missing (base check).
            ValueError: If `event_shape` has more dimensions than `values` or
                any of its sizes differs from the matching trailing size of
                `values.shape`.
        """
        super()._validate()
        values_shape = self.values.shape
        event_shape = self.event_shape
        # `zip` stops at the shorter sequence, so without the length check an
        # `event_shape` with more dimensions than `values` would pass silently.
        if len(event_shape) > len(values_shape) or any(
            a != b for a, b in zip(reversed(event_shape), reversed(values_shape))
        ):
            raise ValueError(
                f"Shape of `values` {self.values.shape} incompatible with "
                f"`event shape` {self.event_shape}."
            )
[docs]
@dataclass(eq=False)
class SliceContainer(BotorchContainer):
    r"""Represent data points formed by concatenating (n-1)-dimensional slices
    taken from the leading dimension of an n-dimensional source tensor.

    With `values` of shape `(m, d, *rest)` and `indices` of shape
    `(*batch, k)`, the container represents a tensor of shape
    `(*batch, k * d, *rest)`: each row of `indices` selects `k` slices
    `values[i]` (each of shape `(d, *rest)`), which are concatenated along
    their first dimension.
    """

    values: Tensor
    indices: LongTensor
    event_shape: Size

    def __call__(self) -> Tensor:
        """Gather the indexed slices of `values` into a dense tensor.

        Returns:
            A tensor of shape `(*indices.shape[:-1], k * d, *values.shape[2:])`
            where `k = indices.shape[-1]` and `d = values.shape[1]`.
        """
        values = self.values
        indices = self.indices
        # `reshape` (unlike `view`) also accepts non-contiguous indices, e.g.
        # ones produced by `expand` or `transpose`.
        flat = values.index_select(dim=0, index=indices.reshape(-1))
        # Spell out the event length instead of using `-1`: `view` cannot
        # infer a `-1` dimension when `flat` has zero elements.
        event_len = indices.shape[-1] * values.shape[1]
        return flat.view(*indices.shape[:-1], event_len, *values.shape[2:])

    def __eq__(self, other: Any) -> bool:
        """Equal iff `other` has the same type, `values` and `indices`."""
        return (
            type(other) is type(self)
            and self.values.equal(other.values)
            and self.indices.equal(other.indices)
        )

    @property
    def shape(self) -> Size:
        """Batch shape of `indices` followed by `event_shape`."""
        return self.indices.shape[:-1] + self.event_shape

    @property
    def device(self) -> Device:
        """Device of the underlying `values` tensor."""
        return self.values.device

    @property
    def dtype(self) -> Dtype:
        """Data type of the underlying `values` tensor."""
        return self.values.dtype

    def _validate(self) -> None:
        """Check that `indices` and `event_shape` are consistent with `values`.

        Raises:
            AttributeError: If the `event_shape` field is missing (base check).
            AssertionError: If `indices` has fewer than two dimensions or
                contains an index outside `[0, len(values))`.
            ValueError: If `event_shape` differs from
                `(indices.shape[-1] * values.shape[1],) + values.shape[2:]`.
        """
        super()._validate()
        values = self.values
        indices = self.indices
        # NOTE: these checks use `assert` (raising `AssertionError`), so they
        # are skipped when Python runs with `-O`.
        assert indices.ndim > 1, (
            f"`indices` must have at least 2 dimensions, got shape {indices.shape}."
        )
        # `min`/`max` raise on empty tensors; empty `indices` (zero events)
        # contain nothing to bounds-check.
        if indices.numel() > 0:
            assert (-1 < indices.min()) & (indices.max() < len(values)), (
                f"`indices` must lie in [0, {len(values)})."
            )
        event_shape = self.event_shape
        _event_shape = (indices.shape[-1] * values.shape[1],) + values.shape[2:]
        if event_shape != _event_shape:
            raise ValueError(
                f"Shapes of `values` {values.shape} and `indices` "
                f"{indices.shape} incompatible with `event_shape` {event_shape}."
            )