Source code for botorch.utils.multi_objective.box_decompositions.box_decomposition
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Box decomposition algorithms."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
import torch
from botorch.exceptions.errors import BotorchError, BotorchTensorDimensionError
from botorch.utils.multi_objective.box_decompositions.utils import (
_expand_ref_point,
_pad_batch_pareto_frontier,
)
from torch import Tensor
from torch.nn import Module
class BoxDecomposition(Module, ABC):
    r"""An abstract class for box decompositions.

    Concrete subclasses implement `partition_space_2d` and
    `get_hypercell_bounds`. `partition_space` also falls back to
    `self._partition_space()`, which is not defined on this class and is
    therefore expected to be supplied by subclasses.

    Note: Internally, we store the negative reference point (minimization).
    The outcomes and the Pareto frontier are stored negated as well; the
    public properties flip the sign back.
    """

    def __init__(
        self, ref_point: Tensor, sort: bool, Y: Optional[Tensor] = None
    ) -> None:
        """Initialize BoxDecomposition.

        Args:
            ref_point: A `m`-dim tensor containing the reference point.
            sort: A boolean indicating whether to sort the Pareto frontier.
            Y: A `(batch_shape) x n x m`-dim tensor of outcomes. If provided,
                the decomposition is computed immediately via `update`.
        """
        super().__init__()
        self.register_buffer("_neg_ref_point", -ref_point)
        # Stored as a boolean tensor buffer rather than a plain Python bool.
        self.register_buffer("sort", torch.tensor(sort, dtype=torch.bool))
        self.num_outcomes = ref_point.shape[-1]
        if Y is not None:
            self.update(Y=Y)

    @property
    def pareto_Y(self) -> Tensor:
        r"""This returns the non-dominated set.

        Returns:
            A `(batch_shape) x n_pareto x m`-dim tensor of outcomes.

        Raises:
            BotorchError: If the Pareto frontier has not been computed yet.
        """
        try:
            return -self._neg_pareto_Y
        except AttributeError as err:
            raise BotorchError("pareto_Y has not been initialized") from err

    @property
    def ref_point(self) -> Tensor:
        r"""Get the reference point.

        Returns:
            A `m`-dim tensor of outcomes.
        """
        return -self._neg_ref_point

    @property
    def Y(self) -> Tensor:
        r"""Get the raw outcomes.

        Returns:
            A `(batch_shape) x n x m`-dim tensor of outcomes.

        Raises:
            BotorchError: If no outcomes have been provided yet.
        """
        # Mirror `pareto_Y`: an AttributeError escaping a property on a
        # `Module` is re-reported by `Module.__getattr__` as a missing `Y`
        # attribute, which hides the real cause.
        try:
            return -self._neg_Y
        except AttributeError as err:
            raise BotorchError("Y data has not been initialized") from err

    def _update_pareto_Y(self) -> bool:
        r"""Update the non-dominated front.

        Reads `self._neg_Y`, `self.batch_shape` and `self.sort`, and stores
        the (negated) frontier in the `_neg_pareto_Y` buffer when it differs
        from the one currently stored.

        Returns:
            A boolean indicating whether the Pareto frontier has changed.
        """
        if self._neg_Y.shape[-2] == 0:
            # No outcomes: the (empty) frontier is the outcome tensor itself.
            pareto_Y = self._neg_Y
        else:
            # The helper assumes maximization, so it receives the
            # un-negated outcomes / reference point and its result is
            # negated back into the internal minimization convention.
            pareto_Y = -_pad_batch_pareto_frontier(
                Y=self.Y,
                ref_point=_expand_ref_point(
                    ref_point=self.ref_point, batch_shape=self.batch_shape
                ),
            )
            if self.sort:
                # sort by first objective
                if len(self.batch_shape) > 0:
                    # Batched: order every batch independently along the
                    # point dimension, applying the order to all objectives.
                    order = torch.argsort(pareto_Y[..., :1], dim=-2)
                    pareto_Y = pareto_Y.gather(
                        index=order.expand(pareto_Y.shape), dim=-2
                    )
                else:
                    pareto_Y = pareto_Y[torch.argsort(pareto_Y[:, 0])]

        is_unchanged = hasattr(self, "_neg_pareto_Y") and torch.equal(
            pareto_Y, self._neg_pareto_Y
        )
        if is_unchanged:
            return False
        self.register_buffer("_neg_pareto_Y", pareto_Y)
        return True

    def partition_space(self) -> None:
        r"""Compute box decomposition.

        Tries the 2-objective routine first and falls back to
        `self._partition_space()` if it raises
        `BotorchTensorDimensionError`.
        """
        try:
            self.partition_space_2d()
        except BotorchTensorDimensionError:
            # NOTE: `_partition_space` is not defined on this base class;
            # subclasses that can reach this branch must provide it.
            self._partition_space()

    @abstractmethod
    def partition_space_2d(self) -> None:
        r"""Compute box decomposition for 2 objectives."""
        pass  # pragma: no cover

    @abstractmethod
    def get_hypercell_bounds(self) -> Tensor:
        r"""Get the bounds of each hypercell in the decomposition.

        Returns:
            A `2 x num_cells x num_outcomes`-dim tensor containing the
                lower and upper vertices bounding each hypercell.
        """
        pass  # pragma: no cover

    def update(self, Y: Tensor) -> None:
        r"""Update non-dominated front and decomposition.

        The decomposition is only recomputed when the Pareto frontier
        actually changes. If `Y` is rejected, the previously stored state
        (including `batch_shape`) is left untouched.

        Args:
            Y: A `(batch_shape) x n x m`-dim tensor of outcomes.

        Raises:
            NotImplementedError: If `Y` has more than one batch dimension,
                or has a batch dimension while there are more than two
                outcomes.
        """
        # Validate before touching any attribute so that a rejected input
        # cannot leave `batch_shape` out of sync with the stored outcomes.
        batch_shape = Y.shape[:-2]
        num_batch_dims = len(batch_shape)
        if num_batch_dims > 1:
            raise NotImplementedError(
                f"{type(self).__name__} only supports a single "
                f"batch dimension, but got {num_batch_dims} "
                "batch dimensions."
            )
        if num_batch_dims > 0 and self.num_outcomes > 2:
            raise NotImplementedError(
                f"{type(self).__name__} only supports a batched box "
                "decompositions in the 2-objective setting."
            )
        self.batch_shape = batch_shape
        # multiply by -1, since internally we minimize.
        self._neg_Y = -Y
        # Update decomposition if the Pareto front changed
        if self._update_pareto_Y():
            self.partition_space()