Source code for botorch.utils.multi_objective.box_decompositions.utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# LICENSE file in the root directory of this source tree.

r"""Utilities for box decomposition algorithms."""

from typing import Optional, Tuple

import torch
from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
from botorch.utils.multi_objective.pareto import is_non_dominated
from torch import Size, Tensor

def _expand_ref_point(ref_point: Tensor, batch_shape: Size) -> Tensor:
r"""Expand reference point to the proper batch_shape.

Args:
ref_point: A (batch_shape) x m-dim tensor containing the reference
point.
batch_shape: The batch shape.

Returns:
A batch_shape x m-dim tensor containing the expanded reference point
"""
if ref_point.shape[:-1] != batch_shape:
if ref_point.ndim > 1:
raise BotorchTensorDimensionError(
"Expected ref_point to be a batch_shape x m or m-dim tensor, "
f"but got {ref_point.shape}."
)
ref_point = ref_point.view(
*(1 for _ in batch_shape), ref_point.shape[-1]
).expand(batch_shape + ref_point.shape[-1:])
return ref_point

Y: Tensor,
ref_point: Tensor,
is_pareto: bool = False,
) -> Tensor:
r"""Get a batch Pareto frontier by padding the pareto frontier with repeated points.

This assumes maximization.

Args:
Y: A (batch_shape) x n x m-dim tensor of points
ref_point: a (batch_shape) x m-dim tensor containing the reference point
is_pareto: a boolean indicating whether the points in Y are already
non-dominated.
feasibility_mask: A (batch_shape) x n-dim tensor of booleans indicating
whether each point is feasible.

Returns:
A (batch_shape) x max_num_pareto x m-dim tensor of padded Pareto
frontiers.
"""
tkwargs = {"dtype": Y.dtype, "device": Y.device}
ref_point = ref_point.unsqueeze(-2)
batch_shape = Y.shape[:-2]
if len(batch_shape) > 1:
raise UnsupportedError(
"_pad_batch_pareto_frontier only supports a single "
f"batch dimension, but got {len(batch_shape)} "
"batch dimensions."
)
# set infeasible points to be the reference point (corresponding to the batch)
if not is_pareto:
else:
better_than_ref = (Y > ref_point).all(dim=-1)
# is_non_dominated assumes maximization
# TODO: filter out points that are worse than the reference point first here
if len(batch_shape) == 0:
# Note: in the batch case, the Pareto frontier is padded by repeating
# a Pareto point. This ensures that the padded box-decomposition has
# the same number of points, which enables fast batch operations.
pareto_Y = torch.empty(*batch_shape, max_n_pareto, Y.shape[-1], **tkwargs)
n_pareto = pareto_i.shape[0]
if n_pareto > 0:
pareto_Y[i, :n_pareto] = pareto_i
# pad pareto_Y, so that all batches have the same size Pareto set
pareto_Y[i, n_pareto:] = pareto_i[-1]
else:
# if there are no pareto points in this batch, use the reference
# point
pareto_Y[i, :] = ref_point[i]
return pareto_Y

[docs]def compute_local_upper_bounds(
U: Tensor, Z: Tensor, z: Tensor
) -> Tuple[Tensor, Tensor]:
r"""Compute local upper bounds.

Note: this assumes minimization.

This uses the incremental algorithm (Alg. 1) from [Lacour17]_.

Args:
U: A n x m-dim tensor containing the local upper bounds.
Z: A n x m x m-dim tensor containing the defining points.
z: A m-dim tensor containing the new point.

Returns:
2-element tuple containing:

- A new n' x m-dim tensor local upper bounds.
- A n' x m x m-dim tensor containing the defining points.
"""
num_outcomes = U.shape[-1]
z_dominates_U = (U > z).all(dim=-1)
# Select upper bounds that are dominated by z.
# These are the search zones that contain z.
if not z_dominates_U.any():
return U, Z
A = U[z_dominates_U]
A_Z = Z[z_dominates_U]
P = []
P_Z = []
for j in range(num_outcomes):
u_j[:, j] = z[j]
P.append(u_j)
lt_zj = Z_ku[..., j] <= z[j]
P_uj = torch.zeros(
u_j.shape[0], num_outcomes, num_outcomes, dtype=U.dtype, device=U.device
)
P_uj[:, mask] = Z_ku[lt_zj].view(P_uj.shape[0], num_outcomes - 1, -1)
P_Z.append(P_uj)
# filter out elements of U that are in A
not_z_dominates_U = ~z_dominates_U
U = U[not_z_dominates_U]
# remaining indices
Z = Z[not_z_dominates_U]
if len(P) > 0:
Z = torch.cat([Z, *P_Z], dim=0)
# return elements in P or elements in (U that are not in A)
U = torch.cat([U, *P], dim=-2)
return U, Z

[docs]def get_partition_bounds(Z: Tensor, U: Tensor, ref_point: Tensor) -> Tensor:
r"""Get the cell bounds given the local upper bounds and the defining points.

This implements Equation 2 in [Lacour17]_.

Args:
Z: A n x m x m-dim tensor containing the defining points. The first
dimension corresponds to u_idx, the second dimension corresponds to j,
and Z[u_idx, j] is the set of definining points Z^j(u) where
u = U[u_idx].
U: A n x m-dim tensor containing the local upper bounds.
ref_point: A m-dim tensor containing the reference point.

Returns:
A 2 x num_cells x m-dim tensor containing the lower and upper vertices
bounding each hypercell.
"""
bounds = torch.empty(2, U.shape[0], U.shape[-1], dtype=U.dtype, device=U.device)
for u_idx in range(U.shape[0]):
# z_1^1(u)
bounds[0, u_idx, 0] = Z[u_idx, 0, 0]
# z_1^r(u)
bounds[1, u_idx, 0] = ref_point[0]
for j in range(1, U.shape[-1]):
bounds[0, u_idx, j] = Z[u_idx, :j, j].max()
bounds[1, u_idx, j] = U[u_idx, j]
# remove empty partitions
# Note: the equality will evaluate as True if the lower and upper bound
# are both (-inf), which could happen if the reference point is -inf.
empty = (bounds[1] <= bounds[0]).any(dim=-1)
return bounds[:, ~empty]

[docs]def update_local_upper_bounds_incremental(
new_pareto_Y: Tensor, U: Tensor, Z: Tensor
) -> Tuple[Tensor, Tensor]:
r"""Update the current local upper with the new pareto points.

This assumes minimization.

Args:
new_pareto_Y: A n x m-dim tensor containing the new
Pareto points.
U: A n' x m-dim tensor containing the local upper bounds.
Z: A n x m x m-dim tensor containing the defining points.

Returns:
2-element tuple containing:

- A new n' x m-dim tensor local upper bounds.
- A n' x m x m-dim tensor containing the defining points
"""
for i in range(new_pareto_Y.shape[-2]):
U, Z = compute_local_upper_bounds(U=U, Z=Z, z=new_pareto_Y[i])
return U, Z

[docs]def compute_non_dominated_hypercell_bounds_2d(
pareto_Y_sorted: Tensor, ref_point: Tensor
) -> Tensor:
r"""Compute an axis-aligned partitioning of the non-dominated space for 2
objectives.

Args:
pareto_Y_sorted: A (batch_shape) x n_pareto x 2-dim tensor of pareto outcomes
that are sorted by the 0th dimension in increasing order. All points must be
better than the reference point.
ref_point: A (batch_shape) x 2-dim reference point.

Returns:
A 2 x (batch_shape) x n_pareto + 1 x m-dim tensor of cell bounds.
"""
# add boundary point to each front
# the boundary point is the extreme value in each outcome
# (a single coordinate of reference point)
batch_shape = pareto_Y_sorted.shape[:-2]
if ref_point.ndim == pareto_Y_sorted.ndim - 1:
expanded_boundary_point = ref_point.unsqueeze(-2)
else:
view_shape = torch.Size([1] * len(batch_shape)) + torch.Size([1, 2])
expanded_shape = batch_shape + torch.Size([1, 2])
expanded_boundary_point = ref_point.view(view_shape).expand(expanded_shape)

# add the points (ref, y) and (x, ref) to the corresponding ends
pareto_Y_sorted0, pareto_Y_sorted1 = torch.split(pareto_Y_sorted, 1, dim=-1)
expanded_boundary_point0, expanded_boundary_point1 = torch.split(
expanded_boundary_point, 1, dim=-1
)
left_end = torch.cat(
[expanded_boundary_point0[..., :1, :], pareto_Y_sorted1[..., :1, :]], dim=-1
)
right_end = torch.cat(
[pareto_Y_sorted0[..., -1:, :], expanded_boundary_point1[..., :1, :]], dim=-1
)
front = torch.cat([left_end, pareto_Y_sorted, right_end], dim=-2)
# The top left corners of axis-aligned rectangles in dominated partitioning.
# These are the bottom left corners of the non-dominated partitioning
front0, front1 = torch.split(front, 1, dim=-1)
bottom_lefts = torch.cat([front0[..., :-1, :], front1[..., 1:, :]], dim=-1)
top_right_xs = torch.cat(
[
front0[..., 1:-1, :],
torch.full(
bottom_lefts.shape[:-2] + torch.Size([1, 1]),
float("inf"),
dtype=front.dtype,
device=front.device,
),
],
dim=-2,
)
top_rights = torch.cat(
[
top_right_xs,
torch.full(
bottom_lefts.shape[:-1] + torch.Size([1]),
float("inf"),
dtype=front.dtype,
device=front.device,
),
],
dim=-1,
)

[docs]def compute_dominated_hypercell_bounds_2d(
pareto_Y_sorted: Tensor, ref_point: Tensor
) -> Tensor:
r"""Compute an axis-aligned partitioning of the dominated space for 2-objectives.

Args:
pareto_Y_sorted: A (batch_shape) x n_pareto x 2-dim tensor of pareto outcomes
that are sorted by the 0th dimension in increasing order.
ref_point: A 2-dim reference point.

Returns:
A 2 x (batch_shape) x n_pareto x m-dim tensor of cell bounds.
"""
# add boundary point to each front
# the boundary point is the extreme value in each outcome
# (a single coordinate of reference point)
batch_shape = pareto_Y_sorted.shape[:-2]
if ref_point.ndim == pareto_Y_sorted.ndim - 1:
expanded_boundary_point = ref_point.unsqueeze(-2)
else:
view_shape = torch.Size([1] * len(batch_shape)) + torch.Size([1, 2])
expanded_shape = batch_shape + torch.Size([1, 2])
expanded_boundary_point = ref_point.view(view_shape).expand(expanded_shape)

# add the points (ref, y) and (x, ref) to the corresponding ends
pareto_Y_sorted0, pareto_Y_sorted1 = torch.split(pareto_Y_sorted, 1, dim=-1)
expanded_boundary_point0, expanded_boundary_point1 = torch.split(
expanded_boundary_point, 1, dim=-1
)
left_end = torch.cat(
[expanded_boundary_point0[..., :1, :], pareto_Y_sorted0[..., :1, :]], dim=-1
)
right_end = torch.cat(
[pareto_Y_sorted1[..., :1, :], expanded_boundary_point1[..., :1, :]], dim=-1
)
front = torch.cat([left_end, pareto_Y_sorted, right_end], dim=-2)
# compute hypervolume by summing rectangles from min_x -> max_x
top_rights = front[..., 1:-1, :]
bottom_lefts = torch.cat(
[
front[..., :-2, :1],
expanded_boundary_point1.expand(*top_rights.shape[:-1], 1),
],
dim=-1,
)