#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Special implementations of mathematical functions that
solve numerical issues of naive implementations.
.. [Maechler2012accurate]
M. Mächler. Accurately Computing log (1 - exp (-| a|))
Assessed by the Rmpfr package. Technical report, 2012.
"""
from __future__ import annotations
import math
from typing import Callable, Union
import torch
from botorch.exceptions import UnsupportedError
from botorch.utils.constants import get_constants_like
from torch import finfo, Tensor
from torch.nn.functional import softplus
# log(2): cutoff used by `log1mexp` to switch between its two evaluation branches.
_log2 = math.log(2)
# sqrt(1 / 3): offset applied to the arguments of the Cauchy terms in `fatmoid`.
_inv_sqrt_3 = math.sqrt(1 / 3)
TAU = 1.0  # default temperature parameter for smooth approximations to non-linearities
ALPHA = 2.0  # default alpha parameter for the asymptotic power decay of _pareto
# Unary ops
[docs]
def exp(x: Tensor, **kwargs) -> Tensor:
    """Exponential whose input is clipped from above before exponentiation.

    The upper clip value is `log(finfo(x.dtype).max) - 1e-4`; the clip is
    presumably there to keep the result from overflowing to `inf`.

    Args:
        x: The Tensor to exponentiate.
        **kwargs: Forwarded unchanged to `torch.exp`.

    Returns:
        `torch.exp` evaluated on the clipped input.
    """
    largest_finite = finfo(x.dtype).max
    upper_bound = get_constants_like(math.log(largest_finite) - 1e-4, x)
    clipped = x.clip(max=upper_bound)
    return torch.exp(clipped, **kwargs)
[docs]
def log(x: Tensor, **kwargs) -> Tensor:
    """Logarithm whose input is clipped from below at `finfo(x.dtype).tiny`.

    Inputs that are zero or negative are therefore mapped to `log(tiny)`
    instead of `-inf` or `nan`.

    Args:
        x: The Tensor of which to take the logarithm.
        **kwargs: Forwarded unchanged to `torch.log`.

    Returns:
        `torch.log` evaluated on the clipped input.
    """
    smallest_normal = finfo(x.dtype).tiny
    clipped = x.clip(min=smallest_normal)
    return torch.log(clipped, **kwargs)
# Binary ops
[docs]
def add(a: Tensor, b: Tensor, **kwargs) -> Tensor:
    """Element-wise `a + b` that returns 0 where infinities of opposite sign meet.

    Args:
        a: The first summand.
        b: The second summand.
        **kwargs: Accepted but not used by this implementation.

    Returns:
        `a + b`, with entries where both operands are infinite and unequal
        (i.e. `inf + (-inf)`) replaced by zero.
    """
    zero = get_constants_like(0, a)
    both_infinite = a.isinf() & b.isinf()
    opposite_infinities = both_infinite & (a != b)
    total = a + b
    return torch.where(opposite_infinities, zero, total)
[docs]
def sub(a: Tensor, b: Tensor) -> Tensor:
    """Element-wise `a - b` that returns 0 where equal infinities are subtracted.

    Args:
        a: The minuend.
        b: The subtrahend.

    Returns:
        `a - b`, with entries where both operands are infinite and equal
        (i.e. `inf - inf` or `-inf - (-inf)`) replaced by zero.
    """
    zero = get_constants_like(0, a)
    both_infinite = a.isinf() & b.isinf()
    same_infinity = both_infinite & (a == b)
    difference = a - b
    return torch.where(same_infinity, zero, difference)
[docs]
def div(a: Tensor, b: Tensor) -> Tensor:
    """Element-wise `a / b` with defined values for `0 / 0` and `inf / inf`.

    The indeterminate cases are resolved as follows:
        - `0 / 0` yields 1 (the operands compare equal),
        - `inf / inf` yields 1 if both infinities have the same sign and -1
          if their signs differ.
    All other entries are the regular quotient `a / b`; in particular an
    infinite numerator over a finite denominator stays infinite.

    Args:
        a: The numerator.
        b: The denominator.

    Returns:
        The quotient with the indeterminate entries replaced as described.
    """
    _0, _1 = get_constants_like(values=(0, 1), ref=a)
    # Indeterminate forms only: both operands zero, or BOTH operands infinite.
    # (Testing `a.isinf()` alone would also catch inf / finite and wrongly
    # collapse it to +-1.)
    case = ((a == _0) & (b == _0)) | (a.isinf() & b.isinf())
    # Equal operands give 1, unequal ones (infinities of opposite sign) give -1.
    special = torch.where(a != b, -_1, _1)
    # Swap the denominator for 1 on the special entries so that the regular
    # quotient does not evaluate 0 / 0 or inf / inf there.
    regular = a / torch.where(case, _1, b)
    return torch.where(case, special, regular)
[docs]
def mul(a: Tensor, b: Tensor) -> Tensor:
    """Element-wise `a * b` that returns 0 for `inf * 0` and `0 * inf`.

    Args:
        a: The first factor.
        b: The second factor.

    Returns:
        `a * b`, with entries where one operand is infinite and the other is
        zero replaced by zero.
    """
    zero = get_constants_like(values=0, ref=a)
    a_inf_b_zero = a.isinf() & (b == zero)
    b_inf_a_zero = b.isinf() & (a == zero)
    inf_times_zero = a_inf_b_zero | b_inf_a_zero
    # The second factor is zeroed on the special entries before multiplying.
    masked_b = torch.where(inf_times_zero, zero, b)
    return torch.where(inf_times_zero, zero, a * masked_b)
[docs]
def log1mexp(x: Tensor) -> Tensor:
    """Numerically accurate evaluation of log(1 - exp(x)) for x < 0.

    Two algebraically equivalent formulas are combined: `log(-expm1(x))` for
    `x > -log(2)` and `log1p(-exp(x))` otherwise.
    See [Maechler2012accurate]_ for details.

    Args:
        x: The (negative) input Tensor.

    Returns:
        A Tensor of values of `log(1 - exp(x))`.
    """
    cutoff = get_constants_like(values=_log2, ref=x)
    close_to_zero = x > -cutoff  # x < 0
    via_expm1 = (-x.expm1()).log()
    via_log1p = (-x.exp()).log1p()
    return torch.where(close_to_zero, via_expm1, via_log1p)
[docs]
def log1pexp(x: Tensor) -> Tensor:
    """Numerically accurate evaluation of log(1 + exp(x)).

    For `x <= 18` the value is computed as `log1p(exp(x))`, for larger `x` as
    `x + exp(-x)`. Each formula only ever sees the entries it is responsible
    for; the remaining entries are replaced by 0 before it is evaluated.
    See [Maechler2012accurate]_ for details.

    Args:
        x: The input Tensor.

    Returns:
        A Tensor of values of `log(1 + exp(x))`.
    """
    is_moderate = x <= 18
    moderate_input = x.masked_fill(~is_moderate, 0)
    large_input = x.masked_fill(is_moderate, 0)
    moderate_value = moderate_input.exp().log1p()
    large_value = large_input + (-large_input).exp()
    return torch.where(is_moderate, moderate_value, large_value)
[docs]
def logexpit(X: Tensor) -> Tensor:
    """Computes the logarithm of the expit (a.k.a. sigmoid) function.

    Evaluated as `-log1pexp(-X)`, i.e. `-log(1 + exp(-X))`.

    Args:
        X: The input Tensor.

    Returns:
        A Tensor of log-sigmoid values.
    """
    negated = -X
    return -log1pexp(negated)
[docs]
def logplusexp(a: Tensor, b: Tensor) -> Tensor:
    """Computes log(exp(a) + exp(b)) similar to logsumexp.

    The operands are broadcast against each other, stacked along a new last
    dimension, and reduced over it with this module's `logsumexp`.

    Args:
        a: The first Tensor.
        b: The second Tensor.

    Returns:
        A Tensor of values of `log(exp(a) + exp(b))`.
    """
    operands = torch.broadcast_tensors(a, b)
    stacked = torch.stack(operands, dim=-1)
    return logsumexp(stacked, dim=-1)
[docs]
def logdiffexp(log_a: Tensor, log_b: Tensor) -> Tensor:
    """Computes log(b - a) accurately given log(a) and log(b).

    Assumes, log_b > log_a, i.e. b > a > 0.

    Args:
        log_a (Tensor): The logarithm of a, assumed to be less than log_b.
        log_b (Tensor): The logarithm of b, assumed to be larger than log_a.

    Returns:
        A Tensor of values corresponding to log(b - a).
    """
    log_a, log_b = torch.broadcast_tensors(log_a, log_b)
    # log_b == -inf implies log_a == -inf by assumption; subtracting 0 instead
    # of -inf in that case avoids forming (-inf) - (-inf).
    b_is_zero = log_b == -torch.inf
    shift = log_b.masked_fill(b_is_zero, 0.0)
    return log_b + log1mexp(log_a - shift)
[docs]
def logsumexp(
    x: Tensor, dim: Union[int, tuple[int, ...]], keepdim: bool = False
) -> Tensor:
    """Version of logsumexp that has a well-behaved backward pass when
    x contains infinities.

    In particular, the gradient of the standard torch version becomes NaN
    1) for any element that is positive infinity, and 2) for any slice that
    only contains negative infinities.

    This version returns a gradient of 1 for any positive infinities in case 1, and
    for all elements of the slice in case 2, in agreement with the asymptotic behavior
    of the function.

    Args:
        x: The Tensor to which to apply `logsumexp`.
        dim: An integer or a tuple of integers, representing the dimensions to reduce.
        keepdim: Whether to keep the reduced dimensions. Defaults to False.

    Returns:
        A Tensor representing the log of the summed exponentials of `x`.
    """
    # The infinity handling lives in the shared helper; `torch.logsumexp` is
    # only applied to the shifted, infinity-free values.
    return _inf_max_helper(max_fun=torch.logsumexp, x=x, dim=dim, keepdim=keepdim)
def _inf_max_helper(
    max_fun: Callable[[Tensor], Tensor],
    x: Tensor,
    dim: Union[int, tuple[int, ...]],
    keepdim: bool,
) -> Tensor:
    """Helper function that generalizes the treatment of infinities for approximations
    to the maximum operator, i.e., `max(X, dim, keepdim)`. At the point of writing of
    this function, it is used to define `logsumexp` and `fatmax`.

    For slices whose maximum is infinite, the result is the sum of the entries that
    attain that infinite maximum. For all other slices, the maximum `M` is subtracted
    first and the result is `M + max_fun(x - M)`.

    Args:
        max_fun: The function that is used to smoothly penalize the difference of an
            element to the true maximum. It is called as
            `max_fun(values, dim=dim, keepdim=True)`.
        x: The Tensor on which to compute the smooth approximation to the maximum.
        dim: The dimension(s) to reduce over.
        keepdim: Whether to keep the reduced dimension.

    Returns:
        The Tensor representing the smooth approximation to the maximum over the
        specified dimensions.
    """
    slice_max = x.amax(dim=dim, keepdim=True)
    max_is_inf = slice_max.isinf()
    # Entries that are equal to an infinite slice maximum.
    attains_inf_max = torch.logical_and(
        *torch.broadcast_tensors(max_is_inf, x == slice_max)
    )
    slice_has_inf_max = _any(attains_inf_max, dim=dim, keepdim=True)

    # Infinite branch: keep only the entries attaining the infinite maximum.
    only_inf_entries = x.masked_fill(~attains_inf_max, 0.0)
    # Finite branch: slices with an infinite maximum are zeroed out entirely,
    # and the (finite) maximum is subtracted from the remaining slices.
    finite_max = slice_max.masked_fill(max_is_inf, 0.0)
    shifted = x.masked_fill(slice_has_inf_max, 0.0) - finite_max

    inf_result = only_inf_entries.sum(dim=dim, keepdim=True)
    finite_result = finite_max + max_fun(shifted, dim=dim, keepdim=True)
    reduced = torch.where(slice_has_inf_max, inf_result, finite_result)
    if keepdim:
        return reduced
    # NOTE: Using `sum` instead of `squeeze` because PyTorch < 2.0 does not support
    # tuple `dim` arguments. `sum` and `squeeze` are equivalent here because the
    # `dim` dimensions have length one after the reductions in the previous lines.
    # TODO: Replace `sum` with `squeeze` once PyTorch >= 2.0 is required.
    return reduced.sum(dim=dim)
def _any(x: Tensor, dim: Union[int, tuple[int, ...]], keepdim: bool = False) -> Tensor:
    """Extension of torch.any, which supports reducing over tuples of dimensions.

    The reduction is applied one dimension at a time with `keepdim=True`, so the
    dimension indices stay valid throughout; the reduced dimensions are squeezed
    at the end if `keepdim` is False.

    Args:
        x: The Tensor to reduce over.
        dim: An integer or a tuple of integers, representing the dimensions to reduce.
        keepdim: Whether to keep the reduced dimensions. Defaults to False.

    Returns:
        The Tensor corresponding to `any` over the specified dimensions.
    """
    dims_to_reduce = dim if isinstance(dim, tuple) else (dim,)
    for d in dims_to_reduce:
        x = x.any(dim=d, keepdim=True)
    if keepdim:
        return x
    # NOTE(review): for a tuple `dim` this passes a tuple to `squeeze`, which
    # requires a PyTorch version that accepts tuple dims there.
    return x.squeeze(dim)
[docs]
def logmeanexp(
    X: Tensor, dim: Union[int, tuple[int, ...]], keepdim: bool = False
) -> Tensor:
    """Computes `log(mean(exp(X), dim=dim, keepdim=keepdim))`.

    Evaluated as `logsumexp(X) - log(n)`, where `n` is the number of elements
    that are reduced over.

    Args:
        X: Values of which to compute the logmeanexp.
        dim: The dimension(s) over which to compute the mean.
        keepdim: If True, keeps the reduced dimensions.

    Returns:
        A Tensor of values corresponding to `log(mean(exp(X), dim=dim))`.
    """
    if isinstance(dim, int):
        n = X.shape[dim]
    else:
        n = math.prod(X.shape[i] for i in dim)
    log_sum = logsumexp(X, dim=dim, keepdim=keepdim)
    return log_sum - math.log(n)
[docs]
def log_softplus(x: Tensor, tau: Union[float, Tensor] = TAU) -> Tensor:
    """Computes the logarithm of the softplus function with high numerical accuracy.

    Above a dtype-dependent lower cutoff on `x / tau`, the value is the log of
    `softplus(x, beta=1 / tau)`. Below the cutoff, the expression
    `x / tau + log(tau)` is used instead.

    Args:
        x: Input tensor, should have single or double precision floats.
        tau: Decreasing tau increases the tightness of the
            approximation to ReLU. Defaults to 1.0. Should be positive, since
            the code divides by it and takes its logarithm.

    Returns:
        Tensor corresponding to `log(softplus(x))`.
    """
    check_dtype_float32_or_float64(x)
    tau = torch.as_tensor(tau, dtype=x.dtype, device=x.device)
    # cutoffs chosen to achieve accuracy to machine epsilon
    is_single_precision = x.dtype == torch.float32
    upper = 16 if is_single_precision else 32
    lower = -15 if is_single_precision else -35
    scaled = x / tau
    above_cutoff = scaled > lower
    # Entries below the cutoff are replaced before calling softplus; their
    # result is discarded by the `where` below.
    softplus_input = x.masked_fill(~above_cutoff, lower)
    log_of_softplus = softplus(softplus_input, beta=(1 / tau), threshold=upper).log()
    tail_value = scaled + tau.log()
    return torch.where(above_cutoff, log_of_softplus, tail_value)
[docs]
def smooth_amax(
    X: Tensor,
    dim: Union[int, tuple[int, ...]] = -1,
    keepdim: bool = False,
    tau: Union[float, Tensor] = 1.0,
) -> Tensor:
    """Computes a smooth approximation to `max(X, dim=dim)`, i.e the maximum value of
    `X` over dimension `dim`, using the logarithm of the `l_(1/tau)` norm of `exp(X)`.
    Note that when `X = log(U)` is the *logarithm* of an acquisition utility `U`,

    `logsumexp(log(U) / tau) * tau = log(sum(U^(1/tau))^tau) = log(norm(U, ord=(1/tau))`

    Args:
        X: A Tensor from which to compute the smoothed amax.
        dim: The dimensions to reduce over.
        keepdim: If True, keeps the reduced dimensions.
        tau: Temperature parameter controlling the smooth approximation
            to max operator, becomes tighter as tau goes to 0. Needs to be positive.

    Returns:
        A Tensor of smooth approximations to `max(X, dim=dim)`.
    """
    # consider normalizing by log_n = math.log(X.shape[dim]) to reduce error
    tempered = X / tau
    log_sum = logsumexp(tempered, dim=dim, keepdim=keepdim)
    return log_sum * tau  # ~ X.amax(dim=dim)
[docs]
def smooth_amin(
    X: Tensor,
    dim: Union[int, tuple[int, ...]] = -1,
    keepdim: bool = False,
    tau: Union[float, Tensor] = 1.0,
) -> Tensor:
    """A smooth approximation to `min(X, dim=dim)`, similar to `smooth_amax`.

    Computed as the negated `smooth_amax` of `-X`.

    Args:
        X: A Tensor from which to compute the smoothed amin.
        dim: The dimensions to reduce over.
        keepdim: If True, keeps the reduced dimensions.
        tau: Temperature parameter, passed through to `smooth_amax`.

    Returns:
        A Tensor of smooth approximations to `min(X, dim=dim)`.
    """
    smooth_max_of_negation = smooth_amax(X=-X, dim=dim, keepdim=keepdim, tau=tau)
    return -smooth_max_of_negation
[docs]
def check_dtype_float32_or_float64(X: Tensor) -> None:
    """Validates that `X` has dtype float32 or float64.

    Args:
        X: The Tensor whose dtype is checked.

    Raises:
        UnsupportedError: If the dtype of `X` is neither `torch.float32` nor
            `torch.float64`.
    """
    if X.dtype in (torch.float32, torch.float64):
        return
    raise UnsupportedError(
        f"Only dtypes float32 and float64 are supported, but received {X.dtype}."
    )
[docs]
def log_fatplus(x: Tensor, tau: Union[float, Tensor] = TAU) -> Tensor:
    """Computes the logarithm of the fat-tailed softplus.

    NOTE: Separated out in case the complexity of the `log` implementation increases
    in the future.

    Args:
        x: A Tensor on whose values to compute the function.
        tau: Temperature parameter, passed through to `fatplus`.

    Returns:
        A Tensor of values of `log(fatplus(x, tau))`.
    """
    fat_softplus = fatplus(x, tau=tau)
    return fat_softplus.log()
[docs]
def fatplus(x: Tensor, tau: Union[float, Tensor] = TAU) -> Tensor:
    """Computes a fat-tailed approximation to `ReLU(x) = max(x, 0)` by linearly
    combining a regular softplus function and the density function of a Cauchy
    distribution. The coefficient `alpha` of the Cauchy density is chosen to guarantee
    monotonicity and convexity.

    The combination is evaluated on `x / tau` and the result is multiplied by `tau`.

    Args:
        x: A Tensor on whose values to compute the smoothed function.
        tau: Temperature parameter controlling the smoothness of the approximation.

    Returns:
        A Tensor of values of the fat-tailed softplus.
    """
    alpha = 1e-1  # guarantees monotonicity and convexity (TODO: ref + Lemma 4)
    scaled = x / tau
    unit_fatplus = softplus(scaled) + alpha * cauchy(scaled)
    return tau * unit_fatplus
[docs]
def fatmax(
    x: Tensor,
    dim: Union[int, tuple[int, ...]],
    keepdim: bool = False,
    tau: Union[float, Tensor] = TAU,
    alpha: float = ALPHA,
) -> Tensor:
    """Computes a smooth approximation to amax(x, dim=dim) with a fat tail.

    Args:
        x: A Tensor from which to compute the smoothed maximum.
        dim: The dimensions to reduce over.
        keepdim: If True, keeps the reduced dimensions.
        tau: Temperature parameter controlling the smooth approximation
            to max operator, becomes tighter as tau goes to 0. Needs to be positive.
        alpha: The exponent of the asymptotic power decay of the approximation. The
            default value is 2. Higher alpha parameters make the function behave more
            similarly to the standard logsumexp approximation to the max, so it is
            recommended to keep this value low or moderate, e.g. < 10.

    Returns:
        A Tensor of smooth approximations to `amax(x, dim=dim)` with a fat tail.
    """

    def _tempered_log_pareto_sum(
        x: Tensor, dim: Union[int, tuple[int, ...]], keepdim: bool = False
    ) -> Tensor:
        # `x` holds the differences to the slice maximum (see `_inf_max_helper`),
        # so `-x / tau` is the argument handed to `_pareto`.
        weights = _pareto(-x / tau, alpha=alpha)
        return tau * weights.sum(dim=dim, keepdim=keepdim).log()

    return _inf_max_helper(
        max_fun=_tempered_log_pareto_sum, x=x, dim=dim, keepdim=keepdim
    )
[docs]
def fatmin(
    x: Tensor,
    dim: Union[int, tuple[int, ...]],
    keepdim: bool = False,
    tau: Union[float, Tensor] = TAU,
    alpha: float = ALPHA,
) -> Tensor:
    """Computes a smooth approximation to amin(x, dim=dim) with a fat tail.

    Computed as the negated `fatmax` of `-x`.

    Args:
        x: A Tensor from which to compute the smoothed minimum.
        dim: The dimensions to reduce over.
        keepdim: If True, keeps the reduced dimensions.
        tau: Temperature parameter controlling the smooth approximation
            to min operator, becomes tighter as tau goes to 0. Needs to be positive.
        alpha: The exponent of the asymptotic power decay of the approximation. The
            default value is 2. Higher alpha parameters make the function behave more
            similarly to the standard logsumexp approximation to the max, so it is
            recommended to keep this value low or moderate, e.g. < 10.

    Returns:
        A Tensor of smooth approximations to `amin(x, dim=dim)` with a fat tail.
    """
    fat_max_of_negation = fatmax(-x, dim=dim, keepdim=keepdim, tau=tau, alpha=alpha)
    return -fat_max_of_negation
[docs]
def fatmaximum(
    a: Tensor, b: Tensor, tau: Union[float, Tensor] = TAU, alpha: float = ALPHA
) -> Tensor:
    """Computes a smooth approximation to torch.maximum(a, b) with a fat tail.

    The operands are broadcast against each other, stacked along a new last
    dimension, and reduced over it with `fatmax`.

    Args:
        a: The first Tensor from which to compute the smoothed component-wise maximum.
        b: The second Tensor from which to compute the smoothed component-wise maximum.
        tau: Temperature parameter controlling the smoothness of the approximation. A
            smaller tau corresponds to a tighter approximation that leads to a sharper
            objective landscape that might be more difficult to optimize.
        alpha: The exponent of the asymptotic power decay of the approximation,
            passed through to `fatmax`.

    Returns:
        A smooth approximation of torch.maximum(a, b).
    """
    stacked = torch.stack(torch.broadcast_tensors(a, b), dim=-1)
    # `alpha` has to be forwarded explicitly; otherwise `fatmax` silently falls
    # back to its own default and the caller's value is ignored.
    return fatmax(stacked, dim=-1, keepdim=False, tau=tau, alpha=alpha)
[docs]
def fatminimum(
    a: Tensor, b: Tensor, tau: Union[float, Tensor] = TAU, alpha: float = ALPHA
) -> Tensor:
    """Computes a smooth approximation to torch.minimum(a, b) with a fat tail.

    Computed as the negated `fatmaximum` of `-a` and `-b`.

    Args:
        a: The first Tensor from which to compute the smoothed component-wise minimum.
        b: The second Tensor from which to compute the smoothed component-wise minimum.
        tau: Temperature parameter controlling the smoothness of the approximation. A
            smaller tau corresponds to a tighter approximation that leads to a sharper
            objective landscape that might be more difficult to optimize.
        alpha: Passed through to `fatmaximum`.

    Returns:
        A smooth approximation of torch.minimum(a, b).
    """
    fat_max_of_negations = fatmaximum(-a, -b, tau=tau, alpha=alpha)
    return -fat_max_of_negations
[docs]
def log_fatmoid(X: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
    """Computes the logarithm of the fatmoid. Separated out in case the implementation
    of the logarithm becomes more complex in the future to ensure numerical stability.

    Args:
        X: A Tensor from which to compute the log of the smoothed step function.
        tau: Temperature parameter, passed through to `fatmoid`.

    Returns:
        A Tensor of values of `log(fatmoid(X, tau))`.
    """
    fat_sigmoid = fatmoid(X, tau=tau)
    return fat_sigmoid.log()
[docs]
def fatmoid(X: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
    """Computes a twice continuously differentiable approximation to the Heaviside
    step function with a fat tail, i.e. `O(1 / x^2)` as `x` goes to -inf.

    With `Z = X / tau` and `m = sqrt(1 / 3)`, the value is `2/3 * cauchy(Z - m)` for
    `Z < 0` and `1 - 2/3 * cauchy(Z + m)` otherwise; both expressions equal 1/2 at
    `Z = 0`.

    Args:
        X: A Tensor from which to compute the smoothed step function.
        tau: Temperature parameter controlling the smoothness of the approximation.

    Returns:
        A tensor of fat-tailed approximations to the Heaviside step function.
    """
    scaled = X / tau
    offset = _inv_sqrt_3  # this defines the inflection point
    negative_side = 2 / 3 * cauchy(scaled - offset)
    non_negative_side = 1 - 2 / 3 * cauchy(scaled + offset)
    return torch.where(scaled < 0, negative_side, non_negative_side)
[docs]
def cauchy(x: Tensor) -> Tensor:
    """Computes a Lorentzian, i.e. an un-normalized Cauchy density function.

    Args:
        x: The input Tensor.

    Returns:
        A Tensor of values of `1 / (1 + x^2)`.
    """
    denominator = 1 + x.square()
    return 1 / denominator
def _pareto(x: Tensor, alpha: float, check: bool = True) -> Tensor:
"""Computes a rational polynomial that is
1) monotonically decreasing for `x > 0`,
2) is equal to 1 at `x = 0`,
3) has a first and second derivative of 1 at `x = 0`, and
4) has an asymptotic decay of `O(1 / x^alpha)`.
These properties make it possible to use the function to define a smooth and
fat-tailed approximation to the maximum, which enables better gradient propagation,
see `fatmax` for details.
Args:
x: The input tensor.
alpha: The exponent of the asymptotic decay.
check: Whether to check if the input tensor only contains non-negative values.
Returns:
The tensor corresponding to the rational polynomial with the stated properties.
"""
if check and (x < 0).any():
raise ValueError("Argument `x` must be non-negative.")
alpha = alpha / 2 # so that alpha stands for the power decay
# choosing beta_0, beta_1 so that first and second derivatives at x = 0 are 1.
beta_1 = 2 * alpha
beta_0 = alpha * beta_1
return (beta_0 / (beta_0 + beta_1 * x + x.square())).pow(alpha)
[docs]
def sigmoid(X: Tensor, log: bool = False, fat: bool = False) -> Tensor:
    """A sigmoid function with an optional fat tail and evaluation in log space for
    better numerical behavior. Notably, the fat-tailed sigmoid can be used to remedy
    numerical underflow problems in the value and gradient of the canonical sigmoid.

    The value is always computed in log space first (`log_fatmoid` if `fat`, else
    `logexpit`) and exponentiated afterwards unless `log` is set.

    Args:
        X: The Tensor on which to evaluate the sigmoid.
        log: Toggles the evaluation of the log sigmoid.
        fat: Toggles the evaluation of the fat-tailed sigmoid.

    Returns:
        A Tensor of (log-)sigmoid values.
    """
    if fat:
        log_value = log_fatmoid(X)
    else:
        log_value = logexpit(X)
    if log:
        return log_value
    return log_value.exp()