Source code for botorch.utils.rounding

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import torch
from torch import Tensor


def approximate_round(X: Tensor, tau: float = 1e-3) -> Tensor:
    r"""Differentiable approximate rounding function.

    This method is a piecewise approximation of a rounding function where
    each piece is a hyperbolic tangent function. On each unit interval
    ``[k, k + 1)`` the output is

    .. math::
        k + \frac{\tanh((x - k - 0.5) / \tau) + 1}{2},

    so values below the midpoint are pulled toward ``k`` and values above it
    toward ``k + 1``; an input exactly at the midpoint maps to ``k + 0.5``.

    Args:
        X: The tensor to round to the nearest integer (element-wise).
        tau: A temperature hyperparameter. The remainder is divided by
            ``tau`` before the ``tanh``, so smaller values give a sharper
            transition around the midpoint of each unit interval.

    Returns:
        The approximately rounded input tensor.
    """
    # Integer part of each element; the step to the next integer is then
    # approximated by a smooth sigmoid-like curve on top of it.
    offset = X.floor()
    # Center the fractional part at the midpoint of the unit interval and
    # scale it by the temperature.
    scaled_remainder = (X - offset - 0.5) / tau
    # Map tanh's (-1, 1) range onto (0, 1): approximately 0 below the
    # midpoint and approximately 1 above it.
    rounding_component = (torch.tanh(scaled_remainder) + 1) / 2
    return offset + rounding_component