Source code for botorch.utils.rounding

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Discretization (rounding) functions for acquisition optimization.

References

.. [Daulton2022bopr]
    S. Daulton, X. Wan, D. Eriksson, M. Balandat, M. A. Osborne, E. Bakshy.
    Bayesian Optimization over Discrete and Mixed Spaces via Probabilistic
    Reparameterization. Advances in Neural Information Processing Systems
    35, 2022.
"""

from __future__ import annotations

import torch
from torch import Tensor
from torch.autograd import Function
from torch.nn.functional import one_hot


[docs] def approximate_round(X: Tensor, tau: float = 1e-3) -> Tensor: r"""Diffentiable approximate rounding function. This method is a piecewise approximation of a rounding function where each piece is a hyperbolic tangent function. Args: X: The tensor to round to the nearest integer (element-wise). tau: A temperature hyperparameter. Returns: The approximately rounded input tensor. """ offset = X.floor() scaled_remainder = (X - offset - 0.5) / tau rounding_component = (torch.tanh(scaled_remainder) + 1) / 2 return offset + rounding_component
[docs] class IdentitySTEFunction(Function): """Base class for functions using straight through gradient estimators. This class approximates the gradient with the identity function. """
[docs] @staticmethod def backward(ctx, grad_output: Tensor) -> Tensor: r"""Use a straight-through estimator the gradient. This uses the identity function. Args: grad_output: A tensor of gradients. Returns: The provided tensor. """ return grad_output
[docs] class RoundSTE(IdentitySTEFunction): r"""Round the input tensor and use a straight-through gradient estimator. [Daulton2022bopr]_ proposes using this in acquisition optimization. """
[docs] @staticmethod def forward(ctx, X: Tensor) -> Tensor: r"""Round the input tensor element-wise. Args: X: The tensor to be rounded. Returns: A tensor where each element is rounded to the nearest integer. """ return X.round()
[docs] class OneHotArgmaxSTE(IdentitySTEFunction): r"""Discretize a continuous relaxation of a one-hot encoded categorical. This returns a one-hot encoded categorical and use a straight-through gradient estimator via an identity function. [Daulton2022bopr]_ proposes using this in acquisition optimization. """
[docs] @staticmethod def forward(ctx, X: Tensor) -> Tensor: r"""Discretize the input tensor. This applies a argmax along the last dimensions of the input tensor and one-hot encodes the result. Args: X: The tensor to be rounded. Returns: A tensor where each element is rounded to the nearest integer. """ return one_hot(X.argmax(dim=-1), num_classes=X.shape[-1]).to(X)