# Source code for botorch.utils.rounding

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# LICENSE file in the root directory of this source tree.

r"""
Discretization (rounding) functions for acquisition optimization.

References

.. [Daulton2022bopr]
S. Daulton, X. Wan, D. Eriksson, M. Balandat, M. A. Osborne, E. Bakshy.
Bayesian Optimization over Discrete and Mixed Spaces via Probabilistic
Reparameterization. Advances in Neural Information Processing Systems
35, 2022.
"""

from __future__ import annotations

import torch
from torch import Tensor
from torch.autograd import Function
from torch.nn.functional import one_hot

def approximate_round(X: Tensor, tau: float = 1e-3) -> Tensor:
    r"""Differentiable approximate rounding function.

    This method is a piecewise approximation of a rounding function where
    each piece is a hyperbolic tangent function. On every unit interval
    `[k, k + 1)` the output is `k + (tanh((x - k - 0.5) / tau) + 1) / 2`,
    i.e. a smooth step centered at the interval midpoint.

    Args:
        X: The tensor to round to the nearest integer (element-wise).
        tau: A temperature hyperparameter. It divides the centered
            fractional part, so smaller positive values give a sharper
            step (closer to hard rounding). It is not validated here;
            `tau = 0` results in a division by zero.

    Returns:
        The approximately rounded input tensor.
    """
    # Integer part; the tanh step below is applied to what remains.
    offset = X.floor()
    # Center the fractional part at the midpoint 0.5 and sharpen it by 1 / tau.
    scaled_remainder = (X - offset - 0.5) / tau
    # Map tanh's (-1, 1) range onto (0, 1) so the step goes from 0 up to 1.
    rounding_component = (torch.tanh(scaled_remainder) + 1) / 2
    return offset + rounding_component

class IdentitySTEFunction(Function):
    """Base class for functions using straight through gradient estimators.

    This class approximates the gradient with the identity function. Only
    ``backward`` is defined here; a subclass supplies the ``forward``
    computation.
    """

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tensor:
        r"""Use a straight-through estimator for the gradient.

        This uses the identity function.

        Args:
            ctx: The autograd context object. It is not used.
            grad_output: A tensor of gradients with respect to the output
                of ``forward``.

        Returns:
            The provided tensor.
        """
        # Pass the incoming gradient through unchanged, as if ``forward``
        # were the identity map.
        return grad_output

class RoundSTE(IdentitySTEFunction):
    r"""Round the input tensor and use a straight-through gradient estimator.

    [Daulton2022bopr]_ proposes using this in acquisition optimization.
    """

    @staticmethod
    def forward(ctx, X: Tensor) -> Tensor:
        r"""Round the input tensor element-wise.

        Args:
            ctx: The autograd context object. It is not used.
            X: The tensor to be rounded.

        Returns:
            A tensor where each element is rounded to the nearest integer.
        """
        return X.round()

class OneHotArgmaxSTE(IdentitySTEFunction):
    r"""Discretize a continuous relaxation of a one-hot encoded categorical.

    This returns a one-hot encoded categorical and uses a straight-through
    gradient estimator via an identity function.

    [Daulton2022bopr]_ proposes using this in acquisition optimization.
    """

    @staticmethod
    def forward(ctx, X: Tensor) -> Tensor:
        r"""Discretize the input tensor.

        This applies an argmax along the last dimension of the input tensor
        and one-hot encodes the result.

        Args:
            ctx: The autograd context object. It is not used.
            X: The tensor to be discretized. Its last dimension indexes the
                categories.

        Returns:
            A tensor with the same shape as `X` that is one-hot along the
            last dimension, with the one at the position of the argmax.
            The result is converted via `.to(X)` to match `X`.
        """
        # `one_hot` yields an integer tensor; `.to(X)` converts it to match X.
        return one_hot(X.argmax(dim=-1), num_classes=X.shape[-1]).to(X)