Source code for botorch.utils.constraints
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
r"""
Helpers for handling outcome constraints.
"""
from functools import partial
from typing import Callable, List, Optional, Tuple
import torch
from torch import Tensor
[docs]def get_outcome_constraint_transforms(
outcome_constraints: Optional[Tuple[Tensor, Tensor]],
) -> Optional[List[Callable[[Tensor], Tensor]]]:
r"""Create outcome constraint callables from outcome constraint tensors.
Args:
outcome_constraints: A tuple of `(A, b)`. For `k` outcome constraints
and `m` outputs at `f(x)``, `A` is `k x m` and `b` is `k x 1` such
that `A f(x) <= b`.
Returns:
A list of callables, each mapping a Tensor of size `b x q x m` to a
tensor of size `b x q`, where `m` is the number of outputs of the model.
Negative values imply feasibility. The callables support broadcasting
(e.g. for calling on a tensor of shape `mc_samples x b x q x m`).
Example:
>>> # constrain `f(x)[0] <= 0`
>>> A = torch.tensor([[1., 0.]])
>>> b = torch.tensor([[0.]])
>>> outcome_constraints = get_outcome_constraint_transforms((A, b))
"""
if outcome_constraints is None:
return None
A, b = outcome_constraints
def _oc(a: Tensor, rhs: Tensor, Y: Tensor) -> Tensor:
r"""Evaluate constraints.
Note: einsum multiples Y by a and sums over the `m`-dimension. Einsum
is ~2x faster than using `(Y * a.view(1, 1, -1)).sum(dim-1)`.
Args:
a: `m`-dim tensor of weights for the outcomes
rhs: Singleton tensor containing the outcome constraint value
Y: `... x b x q x m` tensor of function values
Returns:
A `... x b x q`-dim tensor where negative values imply feasibility
"""
lhs = torch.einsum("...m, m", [Y, a])
return lhs - rhs
return [partial(_oc, a, rhs) for a, rhs in zip(A, b)]