Source code for botorch.utils.multi_objective.box_decompositions.utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Utilities for box decomposition algorithms."""

import torch
from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
from botorch.utils.multi_objective.pareto import is_non_dominated
from torch import Size, Tensor


def _expand_ref_point(ref_point: Tensor, batch_shape: Size) -> Tensor:
    r"""Expand reference point to the proper batch_shape.

    Args:
        ref_point: A `(batch_shape) x m`-dim tensor containing the reference
            point.
        batch_shape: The batch shape.

    Returns:
        A `batch_shape x m`-dim tensor containing the expanded reference point
    """
    if ref_point.shape[:-1] != batch_shape:
        if ref_point.ndim > 1:
            raise BotorchTensorDimensionError(
                "Expected ref_point to be a `batch_shape x m` or `m`-dim tensor, "
                f"but got {ref_point.shape}."
            )
        ref_point = ref_point.view(
            *(1 for _ in batch_shape), ref_point.shape[-1]
        ).expand(batch_shape + ref_point.shape[-1:])
    return ref_point


def _pad_batch_pareto_frontier(
    Y: Tensor,
    ref_point: Tensor,
    is_pareto: bool = False,
    feasibility_mask: Tensor | None = None,
) -> Tensor:
    r"""Get a batch Pareto frontier by padding the pareto frontier with repeated points.

    This assumes maximization.

    Args:
        Y: A `(batch_shape) x n x m`-dim tensor of points
        ref_point: a `(batch_shape) x m`-dim tensor containing the reference point
        is_pareto: a boolean indicating whether the points in Y are already
            non-dominated.
        feasibility_mask: A `(batch_shape) x n`-dim tensor of booleans indicating
            whether each point is feasible.

    Returns:
        A `(batch_shape) x max_num_pareto x m`-dim tensor of padded Pareto
            frontiers.
    """
    tkwargs = {"dtype": Y.dtype, "device": Y.device}
    ref_point = ref_point.unsqueeze(-2)
    batch_shape = Y.shape[:-2]
    if len(batch_shape) > 1:
        raise UnsupportedError(
            "_pad_batch_pareto_frontier only supports a single "
            f"batch dimension, but got {len(batch_shape)} "
            "batch dimensions."
        )
    if feasibility_mask is not None:
        # set infeasible points to be the reference point (corresponding to the batch)
        Y = torch.where(feasibility_mask.unsqueeze(-1), Y, ref_point)
    if not is_pareto:
        pareto_mask = is_non_dominated(Y)
    else:
        pareto_mask = torch.ones(Y.shape[:-1], dtype=torch.bool, device=Y.device)
    better_than_ref = (Y > ref_point).all(dim=-1)
    # is_non_dominated assumes maximization
    # TODO: filter out points that are worse than the reference point first here
    pareto_mask = pareto_mask & better_than_ref
    if len(batch_shape) == 0:
        return Y[pareto_mask]
    # Note: in the batch case, the Pareto frontier is padded by repeating
    # a Pareto point. This ensures that the padded box-decomposition has
    # the same number of points, which enables fast batch operations.
    max_n_pareto = pareto_mask.sum(dim=-1).max().item()
    pareto_Y = torch.empty(*batch_shape, max_n_pareto, Y.shape[-1], **tkwargs)
    for i, pareto_i in enumerate(pareto_mask):
        pareto_i = Y[i, pareto_mask[i]]
        n_pareto = pareto_i.shape[0]
        if n_pareto > 0:
            pareto_Y[i, :n_pareto] = pareto_i
            # pad pareto_Y, so that all batches have the same size Pareto set
            pareto_Y[i, n_pareto:] = pareto_i[-1]
        else:
            # if there are no pareto points in this batch, use the reference
            # point
            pareto_Y[i, :] = ref_point[i]
    return pareto_Y


[docs] def compute_local_upper_bounds( U: Tensor, Z: Tensor, z: Tensor ) -> tuple[Tensor, Tensor]: r"""Compute local upper bounds. Note: this assumes minimization. This uses the incremental algorithm (Alg. 1) from [Lacour17]_. Args: U: A `n x m`-dim tensor containing the local upper bounds. Z: A `n x m x m`-dim tensor containing the defining points. z: A `m`-dim tensor containing the new point. Returns: 2-element tuple containing: - A new `n' x m`-dim tensor local upper bounds. - A `n' x m x m`-dim tensor containing the defining points. """ num_outcomes = U.shape[-1] z_dominates_U = (U > z).all(dim=-1) # Select upper bounds that are dominated by z. # These are the search zones that contain z. if not z_dominates_U.any(): return U, Z A = U[z_dominates_U] A_Z = Z[z_dominates_U] P = [] P_Z = [] mask = torch.ones(num_outcomes, dtype=torch.bool, device=U.device) for j in range(num_outcomes): mask[j] = 0 z_uj_max = A_Z[:, mask, j].max(dim=-1).values.view(-1) add_z = z[j] >= z_uj_max if add_z.any(): u_j = A[add_z].clone() u_j[:, j] = z[j] P.append(u_j) A_Z_filtered = A_Z[add_z] Z_ku = A_Z_filtered[:, mask] lt_zj = Z_ku[..., j] <= z[j] P_uj = torch.zeros( u_j.shape[0], num_outcomes, num_outcomes, dtype=U.dtype, device=U.device ) P_uj[:, mask] = Z_ku[lt_zj].view(P_uj.shape[0], num_outcomes - 1, -1) P_uj[:, ~mask] = z P_Z.append(P_uj) mask[j] = 1 # filter out elements of U that are in A not_z_dominates_U = ~z_dominates_U U = U[not_z_dominates_U] # remaining indices Z = Z[not_z_dominates_U] if len(P) > 0: # add points from P_Z Z = torch.cat([Z, *P_Z], dim=0) # return elements in P or elements in (U that are not in A) U = torch.cat([U, *P], dim=-2) return U, Z
[docs] def get_partition_bounds(Z: Tensor, U: Tensor, ref_point: Tensor) -> Tensor: r"""Get the cell bounds given the local upper bounds and the defining points. This implements Equation 2 in [Lacour17]_. Args: Z: A `n x m x m`-dim tensor containing the defining points. The first dimension corresponds to u_idx, the second dimension corresponds to j, and Z[u_idx, j] is the set of definining points Z^j(u) where u = U[u_idx]. U: A `n x m`-dim tensor containing the local upper bounds. ref_point: A `m`-dim tensor containing the reference point. Returns: A `2 x num_cells x m`-dim tensor containing the lower and upper vertices bounding each hypercell. """ bounds = torch.empty(2, U.shape[0], U.shape[-1], dtype=U.dtype, device=U.device) for u_idx in range(U.shape[0]): # z_1^1(u) bounds[0, u_idx, 0] = Z[u_idx, 0, 0] # z_1^r(u) bounds[1, u_idx, 0] = ref_point[0] for j in range(1, U.shape[-1]): bounds[0, u_idx, j] = Z[u_idx, :j, j].max() bounds[1, u_idx, j] = U[u_idx, j] # remove empty partitions # Note: the equality will evaluate as True if the lower and upper bound # are both (-inf), which could happen if the reference point is -inf. empty = (bounds[1] <= bounds[0]).any(dim=-1) return bounds[:, ~empty]
[docs] def update_local_upper_bounds_incremental( new_pareto_Y: Tensor, U: Tensor, Z: Tensor ) -> tuple[Tensor, Tensor]: r"""Update the current local upper with the new pareto points. This assumes minimization. Args: new_pareto_Y: A `n x m`-dim tensor containing the new Pareto points. U: A `n' x m`-dim tensor containing the local upper bounds. Z: A `n x m x m`-dim tensor containing the defining points. Returns: 2-element tuple containing: - A new `n' x m`-dim tensor local upper bounds. - A `n' x m x m`-dim tensor containing the defining points """ for i in range(new_pareto_Y.shape[-2]): U, Z = compute_local_upper_bounds(U=U, Z=Z, z=new_pareto_Y[i]) return U, Z
[docs] def compute_non_dominated_hypercell_bounds_2d( pareto_Y_sorted: Tensor, ref_point: Tensor ) -> Tensor: r"""Compute an axis-aligned partitioning of the non-dominated space for 2 objectives. Args: pareto_Y_sorted: A `(batch_shape) x n_pareto x 2`-dim tensor of pareto outcomes that are sorted by the 0th dimension in increasing order. All points must be better than the reference point. ref_point: A `(batch_shape) x 2`-dim reference point. Returns: A `2 x (batch_shape) x n_pareto + 1 x m`-dim tensor of cell bounds. """ # add boundary point to each front # the boundary point is the extreme value in each outcome # (a single coordinate of reference point) batch_shape = pareto_Y_sorted.shape[:-2] if ref_point.ndim == pareto_Y_sorted.ndim - 1: expanded_boundary_point = ref_point.unsqueeze(-2) else: view_shape = torch.Size([1] * len(batch_shape)) + torch.Size([1, 2]) expanded_shape = batch_shape + torch.Size([1, 2]) expanded_boundary_point = ref_point.view(view_shape).expand(expanded_shape) # add the points (ref, y) and (x, ref) to the corresponding ends pareto_Y_sorted0, pareto_Y_sorted1 = torch.split(pareto_Y_sorted, 1, dim=-1) expanded_boundary_point0, expanded_boundary_point1 = torch.split( expanded_boundary_point, 1, dim=-1 ) left_end = torch.cat( [expanded_boundary_point0[..., :1, :], pareto_Y_sorted1[..., :1, :]], dim=-1 ) right_end = torch.cat( [pareto_Y_sorted0[..., -1:, :], expanded_boundary_point1[..., :1, :]], dim=-1 ) front = torch.cat([left_end, pareto_Y_sorted, right_end], dim=-2) # The top left corners of axis-aligned rectangles in dominated partitioning. # These are the bottom left corners of the non-dominated partitioning front0, front1 = torch.split(front, 1, dim=-1) bottom_lefts = torch.cat([front0[..., :-1, :], front1[..., 1:, :]], dim=-1) top_right_xs = torch.cat( [ front0[..., 1:-1, :], torch.full( bottom_lefts.shape[:-2] + torch.Size([1, 1]), float("inf"), dtype=front.dtype, device=front.device, ), ], dim=-2, ) top_rights = torch.cat( [ top_right_xs, torch.full( bottom_lefts.shape[:-1] + torch.Size([1]), float("inf"), dtype=front.dtype, device=front.device, ), ], dim=-1, ) return torch.stack([bottom_lefts, top_rights], dim=0)
[docs] def compute_dominated_hypercell_bounds_2d( pareto_Y_sorted: Tensor, ref_point: Tensor ) -> Tensor: r"""Compute an axis-aligned partitioning of the dominated space for 2-objectives. Args: pareto_Y_sorted: A `(batch_shape) x n_pareto x 2`-dim tensor of pareto outcomes that are sorted by the 0th dimension in increasing order. ref_point: A `2`-dim reference point. Returns: A `2 x (batch_shape) x n_pareto x m`-dim tensor of cell bounds. """ # add boundary point to each front # the boundary point is the extreme value in each outcome # (a single coordinate of reference point) batch_shape = pareto_Y_sorted.shape[:-2] if ref_point.ndim == pareto_Y_sorted.ndim - 1: expanded_boundary_point = ref_point.unsqueeze(-2) else: view_shape = torch.Size([1] * len(batch_shape)) + torch.Size([1, 2]) expanded_shape = batch_shape + torch.Size([1, 2]) expanded_boundary_point = ref_point.view(view_shape).expand(expanded_shape) # add the points (ref, y) and (x, ref) to the corresponding ends pareto_Y_sorted0, pareto_Y_sorted1 = torch.split(pareto_Y_sorted, 1, dim=-1) expanded_boundary_point0, expanded_boundary_point1 = torch.split( expanded_boundary_point, 1, dim=-1 ) left_end = torch.cat( [expanded_boundary_point0[..., :1, :], pareto_Y_sorted0[..., :1, :]], dim=-1 ) right_end = torch.cat( [pareto_Y_sorted1[..., :1, :], expanded_boundary_point1[..., :1, :]], dim=-1 ) front = torch.cat([left_end, pareto_Y_sorted, right_end], dim=-2) # compute hypervolume by summing rectangles from min_x -> max_x top_rights = front[..., 1:-1, :] bottom_lefts = torch.cat( [ front[..., :-2, :1], expanded_boundary_point1.expand(*top_rights.shape[:-1], 1), ], dim=-1, ) return torch.stack([bottom_lefts, top_rights], dim=0)