# Source code for botorch.utils.multi_objective.box_decompositions.box_decomposition_list

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# LICENSE file in the root directory of this source tree.

r"""Box decomposition container."""

from __future__ import annotations

from typing import List, Union

import torch
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.utils.multi_objective.box_decompositions.box_decomposition import (
BoxDecomposition,
)
from torch import Tensor
from torch.nn import Module, ModuleList

[docs]class BoxDecompositionList(Module):
r"""A list of box decompositions."""

def __init__(self, *box_decompositions: BoxDecomposition) -> None:
r"""Initialize the box decomposition list.

Args:
*box_decompositions: An variable number of box decompositions

Example:
>>> bd1 = FastNondominatedPartitioning(ref_point, Y=Y1)
>>> bd2 = FastNondominatedPartitioning(ref_point, Y=Y2)
>>> bd = BoxDecompositionList(bd1, bd2)
"""
super().__init__()
self.box_decompositions = ModuleList(box_decompositions)

@property
def pareto_Y(self) -> List[Tensor]:
    r"""Collect the non-dominated set of every box decomposition.

    Note: Internally, we store the negative pareto set (minimization).

    Returns:
        A list whose ith entry is the `n_pareto_i x m`-dim tensor of
        pareto optimal outcomes of box decomposition i.
    """
    pareto_sets = []
    for decomposition in self.box_decompositions:
        pareto_sets.append(decomposition.pareto_Y)
    return pareto_sets

@property
def ref_point(self) -> Tensor:
    r"""Stack the reference points of all box decompositions.

    Note: Internally, we store the negative reference point (minimization).

    Returns:
        A `n_box_decompositions x m`-dim tensor of outcomes, one row per
        box decomposition.
    """
    ref_points = [
        decomposition.ref_point for decomposition in self.box_decompositions
    ]
    return torch.stack(ref_points, dim=0)

def get_hypercell_bounds(self) -> Tensor:
    r"""Get the bounds of each hypercell in the decomposition.

    Decompositions may have different numbers of cells. Those with fewer
    cells than the largest one are padded with all-zero cells so that the
    per-decomposition bounds can be stacked into a single tensor.

    Returns:
        A `2 x n_box_decompositions x num_cells x num_outcomes`-dim tensor
        containing the lower and upper vertices bounding each hypercell,
        where `num_cells` is the largest cell count across decompositions.
    """
    bounds_list = [p.get_hypercell_bounds() for p in self.box_decompositions]
    max_num_cells = max((b.shape[-2] for b in bounds_list), default=0)
    # Pad each decomposition with empty (all-zero) cells so that all
    # decompositions have the same number of cells.
    for i, bounds in enumerate(bounds_list):
        num_missing = max_num_cells - bounds.shape[-2]
        if num_missing > 0:
            empty_cells = torch.zeros(
                2,
                num_missing,
                bounds.shape[-1],
                dtype=bounds.dtype,
                device=bounds.device,
            )
            bounds_list[i] = torch.cat([bounds, empty_cells], dim=-2)
    # Insert the decomposition axis right after the leading lower/upper axis.
    return torch.stack(bounds_list, dim=-3)

def update(self, Y: Union[List[Tensor], Tensor]) -> None:
    r"""Update the partitioning.

    Each box decomposition i is updated with ``Y[i]``. The input is
    validated before any decomposition is touched.

    Args:
        Y: A `n_box_decompositions x n x num_outcomes`-dim tensor or a list
            where the ith element contains the new points for
            box_decomposition i.

    Raises:
        BotorchTensorDimensionError: If `Y` is a tensor that is not
            3-dimensional or whose leading dimension differs from the
            number of box decompositions, or if `Y` is a list whose length
            differs from the number of box decompositions.
    """
    num_decompositions = len(self.box_decompositions)
    if torch.is_tensor(Y):
        # A batched tensor must be 3-d with one batch per decomposition.
        invalid = Y.ndim != 3 or Y.shape[0] != num_decompositions
    elif isinstance(Y, list):
        invalid = len(Y) != num_decompositions
    else:
        # Other indexable inputs are passed through unchecked.
        invalid = False
    if invalid:
        raise BotorchTensorDimensionError(
            "BoxDecompositionList.update requires either a batched tensor Y, "
            "with one batch per box decomposition or a list of tensors with "
            "one element per box decomposition."
        )
    for i, p in enumerate(self.box_decompositions):
        p.update(Y[i])

[docs]    def compute_hypervolume(self) -> Tensor:
r"""Compute hypervolume that is dominated by the Pareto Frontier.

Returns:
A (batch_shape)-dim tensor containing the hypervolume dominated by
each Pareto frontier.
"""