# Source code for botorch.utils.multi_objective.box_decompositions.box_decomposition_list
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Box decomposition container."""
from __future__ import annotations
from typing import List, Union
import torch
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.utils.multi_objective.box_decompositions.box_decomposition import (
BoxDecomposition,
)
from torch import Tensor
from torch.nn import Module, ModuleList
class BoxDecompositionList(Module):
    r"""A list of box decompositions.

    Holds several box decompositions in a ``ModuleList`` and exposes batched
    views of their reference points, hypercell bounds and hypervolumes, where
    the decomposition index acts as a batch dimension.
    """

    def __init__(self, *box_decompositions: BoxDecomposition) -> None:
        r"""Initialize the box decomposition list.

        Args:
            *box_decompositions: A variable number of box decompositions.

        Example:
            >>> bd1 = FastNondominatedPartitioning(ref_point, Y=Y1)
            >>> bd2 = FastNondominatedPartitioning(ref_point, Y=Y2)
            >>> bd = BoxDecompositionList(bd1, bd2)
        """
        super().__init__()
        # Wrapping in a ModuleList registers each decomposition as a child
        # module of this Module (rather than hiding them in a plain list).
        self.box_decompositions = ModuleList(box_decompositions)

    @property
    def pareto_Y(self) -> List[Tensor]:
        r"""This returns the non-dominated set.

        Note: Internally, we store the negative pareto set (minimization).

        Returns:
            A list where the ith element is the `n_pareto_i x m`-dim tensor
            of pareto optimal outcomes for each box_decomposition `i`.
        """
        # The number of Pareto points can differ per decomposition
        # (`n_pareto_i`), so the sets are returned as a list, not stacked.
        return [p.pareto_Y for p in self.box_decompositions]

    @property
    def ref_point(self) -> Tensor:
        r"""Get the reference point.

        Note: Internally, we store the negative reference point (minimization).

        Returns:
            A `n_box_decompositions x m`-dim tensor of outcomes.
        """
        return torch.stack([p.ref_point for p in self.box_decompositions], dim=0)

    def get_hypercell_bounds(self) -> Tensor:
        r"""Get the bounds of each hypercell in the decomposition.

        Decompositions with fewer cells than the largest one are padded with
        all-zero cells (lower == upper == 0). These have zero volume, so they
        let every decomposition share one `num_cells` dimension.

        Returns:
            A `2 x n_box_decompositions x num_cells x num_outcomes`-dim tensor
            containing the lower and upper vertices bounding each hypercell.
        """
        bounds_list = [p.get_hypercell_bounds() for p in self.box_decompositions]
        # `default=0` keeps an empty list from failing here; torch.stack below
        # then reports the empty input, as before.
        max_num_cells = max((bounds.shape[-2] for bounds in bounds_list), default=0)
        padded_list = []
        for bounds in bounds_list:
            num_missing = max_num_cells - bounds.shape[-2]
            if num_missing > 0:
                # new_zeros inherits dtype and device from `bounds`.
                padding = bounds.new_zeros(2, num_missing, bounds.shape[-1])
                bounds = torch.cat([bounds, padding], dim=-2)
            padded_list.append(bounds)
        # Stack at dim=-3 so the result is `2 x n_bd x num_cells x m`.
        return torch.stack(padded_list, dim=-3)

    def update(self, Y: Union[List[Tensor], Tensor]) -> None:
        r"""Update the partitioning.

        Args:
            Y: A `n_box_decompositions x n x num_outcomes`-dim tensor or a list
                where the ith element contains the new points for
                box_decomposition `i`.

        Raises:
            BotorchTensorDimensionError: If `Y` is a tensor that is not 3-dim
                or whose leading dimension differs from the number of box
                decompositions, or if `Y` is a list whose length differs from
                the number of box decompositions.
        """
        num_decompositions = len(self.box_decompositions)
        if torch.is_tensor(Y):
            # Either mismatch is invalid on its own. Requiring both (as an
            # `and`) would let e.g. a `5 x n x m` tensor for 2 decompositions
            # through and silently drop the extra batches.
            invalid = Y.ndim != 3 or Y.shape[0] != num_decompositions
        else:
            invalid = isinstance(Y, list) and len(Y) != num_decompositions
        if invalid:
            raise BotorchTensorDimensionError(
                "BoxDecompositionList.update requires either a batched tensor Y, "
                "with one batch per box decomposition or a list of tensors with "
                "one element per box decomposition."
            )
        for i, p in enumerate(self.box_decompositions):
            p.update(Y[i])

    def compute_hypervolume(self) -> Tensor:
        r"""Compute hypervolume that is dominated by the Pareto Frontier.

        Returns:
            A tensor whose leading dimension has size `n_box_decompositions`.
            Entry `i` is the value returned by `compute_hypervolume()` of
            box_decomposition `i`, which may itself carry `(batch_shape)`
            dimensions.
        """
        return torch.stack(
            [p.compute_hypervolume() for p in self.box_decompositions], dim=0
        )