Source code for botorch.test_functions.utils
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from typing import Optional, Tuple
import torch
from torch import Tensor
def round_nearest(
    X: Tensor,
    increment: float,
    bounds: Optional[Tuple[float, float]] = None,
) -> Tensor:
    r"""Rounds the input tensor to the nearest multiple of `increment`.

    Args:
        X: The input to be rounded.
        increment: The increment to round to.
        bounds: An optional tuple of two floats representing the lower and upper
            bounds on `X`. If provided, any rounded value that falls below the
            lower bound is shifted up by one `increment`, and any value that
            (then) lies above the upper bound is shifted down by one
            `increment`. Defaults to `None`, in which case no bound correction
            is applied.

    Returns:
        The rounded input, with the same shape as `X`.

    Note:
        The bound correction moves each element by at most one `increment` in
        each direction. If no multiple of `increment` lies within `bounds`,
        the result is not guaranteed to lie within the bounds.
    """
    X_round = torch.round(X / increment) * increment
    if bounds is None:
        return X_round
    lower, upper = bounds[0], bounds[1]
    # Rounding may step just outside the feasible interval; pull such values
    # back to the adjacent multiple of `increment` on the feasible side.
    X_round = torch.where(X_round < lower, X_round + increment, X_round)
    X_round = torch.where(X_round > upper, X_round - increment, X_round)
    return X_round