#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Special implementations of mathematical functions that
solve numerical issues of naive implementations.
"""

from __future__ import annotations

import math

from typing import Tuple, Union

import torch
from botorch.exceptions import UnsupportedError
from botorch.utils.constants import get_constants_like
from torch import finfo, Tensor
from torch.nn.functional import softplus

# Natural logarithm of 2; `log1mexp` uses it as the threshold between its two branches.
_log2 = math.log(2)
# sqrt(1 / 3); `fatmoid` uses it as the shift that defines its inflection point.
_inv_sqrt_3 = math.sqrt(1 / 3)

# Unary ops
def exp(x: Tensor, **kwargs) -> Tensor:
    """Exponential whose input is clipped from above before `torch.exp` is applied.

    Args:
        x: The input tensor.
        kwargs: Forwarded unchanged to `torch.exp`.

    Returns:
        `torch.exp` evaluated on `x` clipped to at most `log(finfo.max) - 1e-4`.
    """
    largest_finite = finfo(x.dtype).max
    # Back off slightly from log(max) so that exponentiating the cap stays
    # below the largest finite value of the dtype.
    ceiling = get_constants_like(math.log(largest_finite) - 1e-4, x)
    clipped = x.clip(max=ceiling)
    return torch.exp(clipped, **kwargs)
def log(x: Tensor, **kwargs) -> Tensor:
    """Logarithm whose input is clipped from below before `torch.log` is applied.

    Args:
        x: The input tensor.
        kwargs: Forwarded unchanged to `torch.log`.

    Returns:
        `torch.log` evaluated on `x` clipped to at least `finfo(x.dtype).tiny`,
        so zero and negative inputs map to `log(tiny)` instead of -inf / NaN.
    """
    floor = finfo(x.dtype).tiny
    floored = x.clip(min=floor)
    return torch.log(floored, **kwargs)
# Binary ops
def add(a: Tensor, b: Tensor, **kwargs) -> Tensor:
    """Element-wise `a + b` where the sum of two unequal infinities is 0.

    Args:
        a: First summand.
        b: Second summand.
        kwargs: Accepted but not used by this implementation.

    Returns:
        `a + b`, with entries where both inputs are infinite and differ
        (i.e. `inf + (-inf)`) replaced by 0 instead of NaN.
    """
    zero = get_constants_like(0, a)
    both_infinite = a.isinf() & b.isinf()
    opposite_infinities = both_infinite & (a != b)
    total = a + b
    return torch.where(opposite_infinities, zero, total)
def sub(a: Tensor, b: Tensor) -> Tensor:
    """Element-wise `a - b` where the difference of two equal infinities is 0.

    Args:
        a: The minuend.
        b: The subtrahend.

    Returns:
        `a - b`, with entries where both inputs are infinite and equal
        (i.e. `inf - inf`) replaced by 0 instead of NaN.
    """
    zero = get_constants_like(0, a)
    both_infinite = a.isinf() & b.isinf()
    same_infinity = both_infinite & (a == b)
    difference = a - b
    return torch.where(same_infinity, zero, difference)
def div(a: Tensor, b: Tensor) -> Tensor:
    """Element-wise `a / b` with the indeterminate forms 0/0 and inf/inf defined.

    The special cases are resolved as follows:
        - `0 / 0` evaluates to 1,
        - `inf / inf` with equal signs evaluates to 1,
        - `inf / inf` with opposite signs evaluates to -1.
    Every other entry is the regular quotient `a / b`; in particular, an
    infinite numerator over a finite denominator stays infinite.

    Args:
        a: The numerator.
        b: The denominator.

    Returns:
        A Tensor of quotients with the indeterminate forms replaced by +/-1.
    """
    _0, _1 = get_constants_like(values=(0, 1), ref=a)
    both_zero = (a == _0) & (b == _0)
    # Both operands have to be infinite for the quotient to be indeterminate.
    # Testing `a` twice would wrongly turn `inf / finite` into +/-1.
    both_inf = a.isinf() & b.isinf()
    case = both_zero | both_inf
    # Within `case`, equal operands (0/0, same-signed infinities) give +1 and
    # unequal operands (opposite-signed infinities) give -1.
    special = torch.where(a != b, -_1, _1)
    # Substitute a harmless denominator in the special entries so that the
    # regular branch never produces NaN there.
    regular = a / torch.where(case, _1, b)
    return torch.where(case, special, regular)
def mul(a: Tensor, b: Tensor) -> Tensor:
    """Element-wise `a * b` where the product of an infinity and zero is 0.

    Args:
        a: First factor.
        b: Second factor.

    Returns:
        `a * b`, with entries of the form `inf * 0` or `0 * inf` replaced by 0
        instead of NaN.
    """
    zero = get_constants_like(values=0, ref=a)
    a_inf_b_zero = a.isinf() & (b == zero)
    b_inf_a_zero = b.isinf() & (a == zero)
    indeterminate = a_inf_b_zero | b_inf_a_zero
    # Zero out `b` in the indeterminate entries before multiplying.
    safe_b = torch.where(indeterminate, zero, b)
    product = a * safe_b
    return torch.where(indeterminate, zero, product)
def log1mexp(x: Tensor) -> Tensor:
    """Numerically accurate evaluation of log(1 - exp(x)) for x < 0.
    See [Maechler2012accurate]_ for details.

    Args:
        x: The input tensor, assumed to be negative.

    Returns:
        A Tensor of values corresponding to `log(1 - exp(x))`.
    """
    threshold = get_constants_like(values=_log2, ref=x)
    # Inputs in (-log(2), 0) use the expm1-based formula, all others the
    # log1p-based one.
    close_to_zero = x > -threshold
    via_expm1 = (-x.expm1()).log()
    via_log1p = (-x.exp()).log1p()
    return torch.where(close_to_zero, via_expm1, via_log1p)
def log1pexp(x: Tensor) -> Tensor:
    """Numerically accurate evaluation of log(1 + exp(x)).
    See [Maechler2012accurate]_ for details.

    Args:
        x: The input tensor.

    Returns:
        A Tensor of values corresponding to `log(1 + exp(x))`.
    """
    use_log1p = x <= 18
    # Each branch only ever sees the inputs it is responsible for; the other
    # entries are replaced by 0 so that no overflow occurs in the unused branch.
    moderate = x.masked_fill(~use_log1p, 0)
    large = x.masked_fill(use_log1p, 0)
    log1p_branch = moderate.exp().log1p()
    asymptotic_branch = large + (-large).exp()
    return torch.where(use_log1p, log1p_branch, asymptotic_branch)
def logexpit(X: Tensor) -> Tensor:
    """Computes the logarithm of the expit (a.k.a. sigmoid) function.

    Args:
        X: The input tensor.

    Returns:
        A Tensor of values `-log1pexp(-X)`, i.e. `log(1 / (1 + exp(-X)))`.
    """
    negated = -X
    return -log1pexp(negated)
def logdiffexp(log_a: Tensor, log_b: Tensor) -> Tensor:
    """Computes log(b - a) accurately given log(a) and log(b).
    Assumes, log_b > log_a, i.e. b > a > 0.

    Uses the identity `log(b - a) = log(b) + log(1 - exp(log(a) - log(b)))`.

    Args:
        log_a (Tensor): The logarithm of a, assumed to be less than log_b.
        log_b (Tensor): The logarithm of b, assumed to be larger than log_a.

    Returns:
        A Tensor of values corresponding to log(b - a).
    """
    log_ratio = log_a - log_b  # negative under the stated assumption
    correction = log1mexp(log_ratio)
    return log_b + correction
def logmeanexp(
    X: Tensor, dim: Union[int, Tuple[int, ...]], keepdim: bool = False
) -> Tensor:
    """Computes `log(mean(exp(X), dim=dim, keepdim=keepdim))`.

    Args:
        X: Values of which to compute the logmeanexp.
        dim: The dimension(s) over which to compute the mean.
        keepdim: If True, keeps the reduced dimensions.

    Returns:
        A Tensor of values corresponding to `log(mean(exp(X), dim=dim))`.
    """
    # Number of elements averaged over: a single axis length for an integer
    # `dim`, or the product of all reduced axis lengths for a tuple of dims.
    if isinstance(dim, int):
        n = X.shape[dim]
    else:
        n = math.prod(X.shape[i] for i in dim)
    # log(mean(exp(X))) = logsumexp(X) - log(n)
    return torch.logsumexp(X, dim=dim, keepdim=keepdim) - math.log(n)
def log_softplus(x: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
    """Computes the logarithm of the softplus function with high numerical accuracy.

    Args:
        x: Input tensor, should have single or double precision floats.
        tau: Decreasing tau increases the tightness of the
            approximation to ReLU. Non-negative and defaults to 1.0.

    Returns:
        Tensor corresponding to `log(softplus(x))`.
    """
    check_dtype_float32_or_float64(x)
    tau = torch.as_tensor(tau, dtype=x.dtype, device=x.device)
    # cutoffs chosen to achieve accuracy to machine epsilon
    if x.dtype == torch.float32:
        lower, upper = -15, 16
    else:
        lower, upper = -35, 32
    scaled = x / tau
    above_cutoff = scaled > lower
    # Entries at or below the cutoff are replaced before calling softplus;
    # their result is discarded in favor of the linear expression below.
    clamped = x.masked_fill(~above_cutoff, lower)
    direct = softplus(clamped, beta=(1 / tau), threshold=upper).log()
    linear_tail = scaled + tau.log()
    return torch.where(above_cutoff, direct, linear_tail)
def smooth_amax(X: Tensor, tau: Union[float, Tensor] = 1e-3, dim: int = -1) -> Tensor:
    """Computes a smooth approximation to `max(X, dim=dim)`, i.e the maximum value of
    `X` over dimension `dim`, using the logarithm of the `l_(1/tau)` norm of `exp(X)`.
    Note that when `X = log(U)` is the *logarithm* of an acquisition utility `U`,

    `logsumexp(log(U) / tau) * tau = log(sum(U^(1/tau))^tau) = log(norm(U, ord=(1/tau))`

    Args:
        X: A Tensor from which to compute the smoothed amax.
        tau: Temperature parameter controlling the smooth approximation
            to max operator, becomes tighter as tau goes to 0. Needs to be positive.
        dim: The dimension over which the smoothed maximum is taken.

    Returns:
        A Tensor of smooth approximations to `max(X, dim=dim)`.
    """
    # consider normalizing by log_n = math.log(X.shape[dim]) to reduce error
    tempered = X / tau
    soft_max = torch.logsumexp(tempered, dim=dim)
    return soft_max * tau  # ~ X.amax(dim=dim)
def check_dtype_float32_or_float64(X: Tensor) -> None:
    """Validates that `X` holds single or double precision floats.

    Args:
        X: The tensor whose dtype is checked.

    Raises:
        UnsupportedError: If the dtype of `X` is neither `torch.float32`
            nor `torch.float64`.
    """
    if X.dtype not in (torch.float32, torch.float64):
        raise UnsupportedError(
            f"Only dtypes float32 and float64 are supported, but received {X.dtype}."
        )
def log_fatplus(x: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
    """Computes the logarithm of the fat-tailed softplus.

    NOTE: Separated out in case the complexity of the `log` implementation increases
    in the future.

    Args:
        x: A Tensor on whose values to compute the function.
        tau: Temperature parameter passed through to `fatplus`.

    Returns:
        A Tensor of values `log(fatplus(x, tau))`.
    """
    value = fatplus(x, tau=tau)
    return value.log()
def fatplus(x: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
    """Computes a fat-tailed approximation to `ReLU(x) = max(x, 0)` by linearly
    combining a regular softplus function and the density function of a Cauchy
    distribution. The coefficient `alpha` of the Cauchy density is chosen to guarantee
    monotonicity and convexity.

    Args:
        x: A Tensor on whose values to compute the smoothed function.
        tau: Temperature parameter; the function is evaluated at `x / tau` and
            the result is scaled by `tau`. Defaults to 1.0.

    Returns:
        A Tensor of values of the fat-tailed softplus.
    """
    alpha = 1e-1  # guarantees monotonicity and convexity (TODO: ref + Lemma 4)
    scaled = x / tau
    unit_fatplus = softplus(scaled) + alpha * cauchy(scaled)
    return tau * unit_fatplus
def fatmax(X: Tensor, dim: int, tau: Union[float, Tensor] = 1.0) -> Tensor:
    """Computes a smooth approximation to amax(X, dim=dim) with a fat tail.

    Args:
        X: A Tensor from which to compute the smoothed amax.
        dim: The dimension over which the smoothed maximum is taken. If `X`
            has size 1 along `dim`, that dimension is simply squeezed out.
        tau: Temperature parameter controlling the smooth approximation
            to max operator, becomes tighter as tau goes to 0. Needs to be positive.

    Returns:
        A Tensor of smooth approximations to `max(X, dim=dim)` with a fat tail.
    """
    if X.shape[dim] == 1:
        return X.squeeze(dim)

    peak = X.amax(dim=dim, keepdim=True)
    # NOTE: this would cause NaNs when X has Infs.
    shifted = (X - peak) / tau
    log_total = cauchy(shifted).sum(dim=dim).log()  # could change to mean
    return peak.squeeze(dim) + tau * log_total
def log_fatmoid(X: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
    """Computes the logarithm of the fatmoid. Separated out in case the implementation
    of the logarithm becomes more complex in the future to ensure numerical stability.

    Args:
        X: A Tensor on whose values to compute the function.
        tau: Temperature parameter passed through to `fatmoid`.

    Returns:
        A Tensor of values `log(fatmoid(X, tau))`.
    """
    value = fatmoid(X, tau=tau)
    return value.log()
def fatmoid(X: Tensor, tau: Union[float, Tensor] = 1.0) -> Tensor:
    """Computes a twice continuously differentiable approximation to the Heaviside
    step function with a fat tail, i.e. `O(1 / x^2)` as `x` goes to -inf.

    Args:
        X: A Tensor from which to compute the smoothed step function.
        tau: Temperature parameter controlling the smoothness of the approximation.

    Returns:
        A tensor of fat-tailed approximations to the Heaviside step function.
    """
    scaled = X / tau
    shift = _inv_sqrt_3  # this defines the inflection point
    weight = 2 / 3
    # Both pieces evaluate to 1/2 at zero, where they are stitched together.
    negative_piece = weight * cauchy(scaled - shift)
    nonnegative_piece = 1 - weight * cauchy(scaled + shift)
    return torch.where(scaled < 0, negative_piece, nonnegative_piece)
def cauchy(x: Tensor) -> Tensor:
    """Computes a Lorentzian, i.e. an un-normalized Cauchy density function.

    Args:
        x: The input tensor.

    Returns:
        A Tensor of values `1 / (1 + x^2)`.
    """
    denominator = 1 + x.square()
    return 1 / denominator
def sigmoid(X: Tensor, log: bool = False, fat: bool = False) -> Tensor:
    """A sigmoid function with an optional fat tail and evaluation in log space for
    better numerical behavior. Notably, the fat-tailed sigmoid can be used to remedy
    numerical underflow problems in the value and gradient of the canonical sigmoid.

    Args:
        X: The Tensor on which to evaluate the sigmoid.
        log: Toggles the evaluation of the log sigmoid.
        fat: Toggles the evaluation of the fat-tailed sigmoid.

    Returns:
        A Tensor of (log-)sigmoid values.
    """
    # The value is always computed in log space first and only exponentiated
    # when the caller asked for the plain sigmoid.
    if fat:
        log_value = log_fatmoid(X)
    else:
        log_value = logexpit(X)
    if log:
        return log_value
    return log_value.exp()