#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Probability Distributions.

This is modified from https://github.com/probtorch/pytorch/pull/143 and
https://github.com/tensorflow/probability/blob/v0.11.1/
tensorflow_probability/python/distributions/kumaraswamy.py.

TODO: Replace with the PyTorch version once the upstream PR lands.
"""

from __future__ import annotations

from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.gumbel import euler_constant
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, PowerTransform
from torch.distributions.uniform import Uniform
from torch.distributions.utils import broadcast_all


def _weighted_logsumexp(
    logx: Tensor, w: Tensor, dim: int, keepdim: bool = False
) -> Tuple[Tensor, Tensor]:
    r"""Numerically stable log of a weighted sum of exponentials.

    Computes `log(|sum_i w_i * exp(logx_i)|)` along dimension `dim`, together
    with the sign of the (possibly negative) weighted sum. This mirrors
    TensorFlow Probability's `reduce_weighted_logsumexp`.

    Args:
        logx: A tensor of log-values.
        w: A tensor of weights, broadcastable with `logx`. Weights may be
            negative, which is why the sign is returned separately.
        dim: The dimension to reduce over.
        keepdim: If True, retain the reduced dimension with size 1.

    Returns:
        A two-element tuple of the log-weighted-sumexp and the sign of the sum.
    """
    log_absw_x = logx + w.abs().log()
    max_log_absw_x = torch.max(log_absw_x, dim=dim, keepdim=True).values
    # Guard against -inf maxima to avoid NaNs from the subtraction below.
    max_log_absw_x = torch.where(
        torch.isinf(max_log_absw_x),
        torch.zeros([], dtype=max_log_absw_x.dtype, device=max_log_absw_x.device),
        max_log_absw_x,
    )
    wx_over_max_absw_x = torch.sign(w) * torch.exp(log_absw_x - max_log_absw_x)
    sum_wx_over_max_absw_x = wx_over_max_absw_x.sum(dim=dim, keepdim=keepdim)
    if not keepdim:
        max_log_absw_x = max_log_absw_x.squeeze(dim)
    sgn = torch.sign(sum_wx_over_max_absw_x)
    lswe = max_log_absw_x + torch.log(sgn * sum_wx_over_max_absw_x)
    return lswe, sgn
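

# A quick sanity check of `_weighted_logsumexp` (illustrative values): with
# logx = log([1, 2, 3]) and weights w = [1, -1, 1], the weighted sum of
# exponentials is 1 - 2 + 3 = 2, so the expected result is (log(2), +1).
#
#   >>> logx = torch.tensor([1.0, 2.0, 3.0]).log()
#   >>> w = torch.tensor([1.0, -1.0, 1.0])
#   >>> lswe, sgn = _weighted_logsumexp(logx, w, dim=0)
#   >>> bool(torch.isclose(lswe, torch.tensor(2.0).log())), sgn
#   (True, tensor(1.))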


def _log_moments(a: Tensor, b: Tensor, n: int) -> Tensor:
    r"""Computes the logarithm of the n-th raw moment of the Kumaraswamy
    distribution, `E[X^n] = b * B(1 + n/a, b)`, where `B` is the Beta function.

    Args:
        a: 1st concentration parameter of the distribution
            (often referred to as alpha).
        b: 2nd concentration parameter of the distribution
            (often referred to as beta).
        n: The moment number.

    Returns:
        The logarithm of the n-th moment of the Kumaraswamy distribution.
    """
    arg1 = 1 + n / a
    # log B(1 + n/a, b), computed via log-gamma for numerical stability.
    log_value = torch.lgamma(arg1) + torch.lgamma(b) - torch.lgamma(arg1 + b)
    return b.log() + log_value
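

# A quick sanity check of `_log_moments` (illustrative values): for a = b = 1
# the Kumaraswamy distribution reduces to Uniform(0, 1), whose first raw
# moment is 1/2.
#
#   >>> a = b = torch.tensor(1.0)
#   >>> _log_moments(a, b, n=1).exp()
#   tensor(0.5000)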


class Kumaraswamy(TransformedDistribution):
    r"""A Kumaraswamy distribution.

    Example::

        >>> m = Kumaraswamy(torch.Tensor([1.0]), torch.Tensor([1.0]))
        >>> m.sample()  # sample from a Kumaraswamy distribution
        tensor([ 0.1729])

    Args:
        concentration1: 1st concentration parameter of the distribution
            (often referred to as alpha).
        concentration0: 2nd concentration parameter of the distribution
            (often referred to as beta).
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(
        self,
        concentration1: Union[float, Tensor],
        concentration0: Union[float, Tensor],
        validate_args: bool = False,
    ):
        self.concentration1, self.concentration0 = broadcast_all(
            concentration1, concentration0
        )
        base_dist = Uniform(
            torch.full_like(self.concentration0, 0.0),
            torch.full_like(self.concentration0, 1.0),
        )
        # Inverse-CDF sampling: X = (1 - (1 - U)^(1/b))^(1/a) for
        # U ~ Uniform(0, 1) yields a Kumaraswamy(a, b) random variable.
        transforms = [
            AffineTransform(loc=1.0, scale=-1.0),
            PowerTransform(exponent=self.concentration0.reciprocal()),
            AffineTransform(loc=1.0, scale=-1.0),
            PowerTransform(exponent=self.concentration1.reciprocal()),
        ]
        super().__init__(base_dist, transforms, validate_args=validate_args)

    def expand(
        self, batch_shape: torch.Size, _instance: Optional[Kumaraswamy] = None
    ) -> Kumaraswamy:
        new = self._get_checked_instance(Kumaraswamy, _instance)
        new.concentration1 = self.concentration1.expand(batch_shape)
        new.concentration0 = self.concentration0.expand(batch_shape)
        return super().expand(batch_shape, _instance=new)

    @property
    def mean(self) -> Tensor:
        return _log_moments(a=self.concentration1, b=self.concentration0, n=1).exp()

    @property
    def variance(self) -> Tensor:
        # Var[X] = E[X^2] - E[X]^2, computed in log-space via a signed
        # weighted logsumexp for numerical stability.
        log_moment2 = _log_moments(a=self.concentration1, b=self.concentration0, n=2)
        log_moment1 = _log_moments(a=self.concentration1, b=self.concentration0, n=1)
        lswe, sgn = _weighted_logsumexp(
            logx=torch.stack([log_moment2, 2 * log_moment1], dim=-1),
            w=torch.tensor(
                [1.0, -1.0], dtype=log_moment1.dtype, device=log_moment1.device
            ),
            dim=-1,
        )
        return sgn * lswe.exp()

    def entropy(self) -> Tensor:
        t1 = 1 - self.concentration1.reciprocal()
        t0 = 1 - self.concentration0.reciprocal()
        H0 = torch.digamma(self.concentration0 + 1) + euler_constant
        return (
            t0
            + t1 * H0
            - torch.log(self.concentration1)
            - torch.log(self.concentration0)
        )
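

# Example usage (illustrative; exact sample statistics vary with the seed):
# because `has_rsample` is True, `rsample` draws reparameterized samples that
# support pathwise gradients, and the empirical mean should approach the
# closed-form `mean` property as the sample size grows.
#
#   >>> m = Kumaraswamy(torch.tensor([2.0]), torch.tensor([5.0]))
#   >>> x = m.rsample(torch.Size([10000]))  # differentiable samples in (0, 1)
#   >>> x.mean()  # close to m.mean
#   >>> m.mean, m.variance, m.entropy()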