#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Probability Distributions.
This is modified from https://github.com/probtorch/pytorch/pull/143 and
https://github.com/tensorflow/probability/blob/v0.11.1/
tensorflow_probability/python/distributions/kumaraswamy.py.
TODO: replace with PyTorch version once the PR is up and landed.
"""
from __future__ import annotations

from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.gumbel import euler_constant
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, PowerTransform
from torch.distributions.uniform import Uniform
from torch.distributions.utils import broadcast_all
def _weighted_logsumexp(
    logx: Tensor, w: Tensor, dim: int, keepdim: bool = False
) -> Tuple[Tensor, Tensor]:
    r"""Compute a signed, numerically stable log of a weighted sum of exponentials.

    Evaluates ``log(|sum(w * exp(logx), dim)|)`` together with the sign of the
    sum. The weights ``w`` may be negative or zero, so this supports
    differences of quantities that are stored in log space.

    Args:
        logx: Logarithms of the summands.
        w: Weights, broadcast against ``logx``; may be negative or zero.
        dim: Dimension to reduce over.
        keepdim: If True, the reduced dimension is retained with size 1.

    Returns:
        A 2-tuple ``(lswe, sgn)``: the log of the absolute value of the
        weighted sum and its sign (-1, 0 or 1). Up to floating-point error,
        ``sgn * lswe.exp()`` equals the weighted sum.
    """
    log_absw_x = logx + w.abs().log()
    max_log_absw_x = torch.max(log_absw_x, dim=dim, keepdim=True).values
    # An infinite maximum (e.g. every weight is zero -> all terms are -inf)
    # would give ``inf - inf = nan`` below, so shift by 0 instead.
    # ``masked_fill`` keeps the replacement on the input's device and dtype;
    # a freshly created ``torch.zeros([])`` would live on the CPU.
    max_log_absw_x = max_log_absw_x.masked_fill(torch.isinf(max_log_absw_x), 0.0)
    # Shift by the max before exponentiating so the largest term is exp(0) = 1.
    wx_over_max_absw_x = torch.sign(w) * torch.exp(log_absw_x - max_log_absw_x)
    sum_wx_over_max_absw_x = wx_over_max_absw_x.sum(dim=dim, keepdim=keepdim)
    if not keepdim:
        max_log_absw_x = max_log_absw_x.squeeze(dim)
    sgn = torch.sign(sum_wx_over_max_absw_x)
    # ``sgn * sum`` is |sum|, so the log is taken of a non-negative value.
    lswe = max_log_absw_x + torch.log(sgn * sum_wx_over_max_absw_x)
    return lswe, sgn
def _log_moments(a: Tensor, b: Tensor, n: int) -> Tensor:
    r"""Return ``log E[X^n]`` for ``X ~ Kumaraswamy(a, b)``.

    Uses the closed form ``E[X^n] = b * B(1 + n / a, b)``, where ``B`` is the
    Beta function. The Beta function is expressed through ``lgamma`` so the
    whole computation stays in log space.

    Args:
        a: 1st concentration parameter of the distribution
            (often referred to as alpha)
        b: 2nd concentration parameter of the distribution
            (often referred to as beta)
        n: The moment number

    Returns:
        The logarithm of the n-th moment of the Kumaraswamy distribution.
    """
    shifted = n / a + 1
    # log B(shifted, b) = lgamma(shifted) + lgamma(b) - lgamma(shifted + b)
    log_beta = torch.lgamma(shifted) + torch.lgamma(b) - torch.lgamma(shifted + b)
    return torch.log(b) + log_beta
class Kumaraswamy(TransformedDistribution):
    r"""A Kumaraswamy distribution on the unit interval.

    The distribution is a ``Uniform(0, 1)`` base distribution pushed through
    the transform chain ``u -> 1 - u -> (.)^(1/b) -> 1 - (.) -> (.)^(1/a)``,
    where ``a = concentration1`` and ``b = concentration0``. The chain is the
    inverse CDF ``x = (1 - (1 - u)^(1/b))^(1/a)``. Every step is
    differentiable, so samples are reparameterizable (``has_rsample``).

    Example::

        >>> m = Kumaraswamy(torch.Tensor([1.0]), torch.Tensor([1.0]))
        >>> m.sample()  # sample from a Kumaraswamy distribution
        tensor([ 0.1729])

    Args:
        concentration1: 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0: 2nd concentration parameter of the distribution
            (often referred to as beta)
        validate_args: Forwarded unchanged to ``TransformedDistribution``.
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(
        self,
        concentration1: Union[float, Tensor],
        concentration0: Union[float, Tensor],
        validate_args: bool = False,
    ):
        # Broadcast both parameters to a common tensor shape (floats included).
        self.concentration1, self.concentration0 = broadcast_all(
            concentration1, concentration0
        )
        base_dist = Uniform(
            torch.full_like(self.concentration0, 0.0),
            torch.full_like(self.concentration0, 1.0),
        )
        # Inverse-CDF chain: u -> (1 - (1 - u)^(1/b))^(1/a).
        transforms = [
            AffineTransform(loc=1.0, scale=-1.0),
            PowerTransform(exponent=self.concentration0.reciprocal()),
            AffineTransform(loc=1.0, scale=-1.0),
            PowerTransform(exponent=self.concentration1.reciprocal()),
        ]
        super().__init__(base_dist, transforms, validate_args=validate_args)

    def expand(
        self, batch_shape: torch.Size, _instance: Optional[Kumaraswamy] = None
    ) -> Kumaraswamy:
        """Return a copy whose concentration parameters are expanded to ``batch_shape``."""
        new = self._get_checked_instance(Kumaraswamy, _instance)
        new.concentration1 = self.concentration1.expand(batch_shape)
        new.concentration0 = self.concentration0.expand(batch_shape)
        return super().expand(batch_shape, _instance=new)

    @property
    def mean(self) -> Tensor:
        r"""Mean ``E[X] = b * B(1 + 1/a, b)``, evaluated in log space."""
        return _log_moments(a=self.concentration1, b=self.concentration0, n=1).exp()

    @property
    def variance(self) -> Tensor:
        r"""Variance ``E[X^2] - E[X]^2``.

        Both moments are kept in log space. They are combined with a signed,
        max-shifted log-sum-exp using weights ``+1`` and ``-1``.
        """
        log_moment2 = _log_moments(a=self.concentration1, b=self.concentration0, n=2)
        log_moment1 = _log_moments(a=self.concentration1, b=self.concentration0, n=1)
        # ``2 * log_moment1`` is log(E[X]^2); stack on a new trailing dim and reduce it.
        lswe, sgn = _weighted_logsumexp(
            logx=torch.stack([log_moment2, 2 * log_moment1], dim=-1),
            w=torch.tensor(
                [1.0, -1.0], dtype=log_moment1.dtype, device=log_moment1.device
            ),
            dim=-1,
        )
        return sgn * lswe.exp()

    def entropy(self) -> Tensor:
        r"""Differential entropy of the distribution.

        ``H = (1 - 1/b) + (1 - 1/a) * H_b - log(a) - log(b)``, where
        ``a = concentration1``, ``b = concentration0`` and
        ``H_b = digamma(b + 1) + euler_constant`` is the (generalized)
        harmonic number of ``b``.
        """
        t1 = 1 - self.concentration1.reciprocal()
        t0 = 1 - self.concentration0.reciprocal()
        H0 = torch.digamma(self.concentration0 + 1) + euler_constant
        return (
            t0
            + t1 * H0
            - torch.log(self.concentration1)
            - torch.log(self.concentration0)
        )