#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Risk Measures implemented as Monte-Carlo objectives, based on Bayesian
optimization of risk measures as introduced in [Cakmak2020risk]_. For a
broader discussion of Monte-Carlo methods for VaR and CVaR risk measures,
see also [Hong2014review]_.

.. [Cakmak2020risk]
    S. Cakmak, R. Astudillo, P. Frazier, and E. Zhou. Bayesian Optimization of
    Risk Measures. Advances in Neural Information Processing Systems 33, 2020.

.. [Hong2014review]
    L. J. Hong, Z. Hu, and G. Liu. Monte Carlo methods for value-at-risk and
    conditional value-at-risk: a review. ACM Transactions on Modeling and
    Computer Simulation, 2014.
"""
from abc import ABC, abstractmethod
from math import ceil
from typing import Optional
from botorch.acquisition.objective import MCAcquisitionObjective
from torch import Tensor
class RiskMeasureMCObjective(MCAcquisitionObjective, ABC):
    r"""Objective transforming the posterior samples to samples of a risk measure.

    The risk measure is calculated over joint q-batch samples from the posterior.
    If the q-batch includes samples corresponding to multiple inputs, it is assumed
    that first `n_w` samples correspond to first input, second `n_w` samples
    correspond to second input etc.

    The risk measures are commonly defined for minimization by considering the
    upper tail of the distribution, i.e., treating larger values as being undesirable.
    BoTorch by default assumes a maximization objective, so the default behavior here
    is to calculate the risk measures w.r.t. the lower tail of the distribution.
    This can be changed by passing `weights=torch.tensor([-1.0])`.
    """

    def __init__(
        self,
        n_w: int,
        weights: Optional[Tensor] = None,
    ) -> None:
        r"""Transform the posterior samples to samples of a risk measure.

        Args:
            n_w: The size of the `w_set` to calculate the risk measure over.
            weights: An optional `m`-dim tensor of weights for scalarizing
                multi-output samples before calculating the risk measure.
        """
        super().__init__()
        self.n_w = n_w
        self.weights = weights

    def _prepare_samples(self, samples: Tensor) -> Tensor:
        r"""Prepare samples for risk measure calculations by scalarizing and
        separating out the q-batch dimension.

        Args:
            samples: A `sample_shape x batch_shape x (q * n_w) x m`-dim tensor of
                posterior samples. The q-batches should be ordered so that each
                `n_w` block of samples correspond to the same input.

        Returns:
            A `sample_shape x batch_shape x q x n_w`-dim tensor of prepared samples.

        Raises:
            RuntimeError: If the samples have more than one output and no
                `weights` were provided for scalarization.
        """
        if self.weights is None:
            if samples.shape[-1] > 1:
                raise RuntimeError(
                    "Multi-output samples require `weights` for scalarization!"
                )
            scalarized = samples.squeeze(-1)
        else:
            # Match the dtype and device of `samples` so that the product does
            # not fail when the weights were created elsewhere (e.g. on CPU or
            # in a different precision). The stored attribute is not modified.
            scalarized = samples @ self.weights.to(samples)
        # `reshape` returns a view whenever `view` would have succeeded and
        # additionally supports non-contiguous inputs.
        return scalarized.reshape(*scalarized.shape[:-1], -1, self.n_w)

    @abstractmethod
    def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
        r"""Calculate the risk measure corresponding to the given samples.

        Args:
            samples: A `sample_shape x batch_shape x (q * n_w) x m`-dim tensor of
                posterior samples. The q-batches should be ordered so that each
                `n_w` block of samples correspond to the same input.
            X: A `batch_shape x q x d`-dim tensor of inputs. Ignored.

        Returns:
            A `sample_shape x batch_shape x q`-dim tensor of risk measure samples.
        """
        pass  # pragma: no cover
class CVaR(RiskMeasureMCObjective):
    r"""The Conditional Value-at-Risk risk measure.

    The Conditional Value-at-Risk measures the expectation of the worst outcomes
    (small rewards or large losses) with a total probability of `1 - alpha`. It
    is commonly defined as the conditional expectation of the reward function,
    with the condition that the reward is smaller than the corresponding
    Value-at-Risk (also defined below).

    Note: Due to the use of a discrete `w_set` of samples, the VaR and CVaR
        calculated here are (possibly biased) Monte-Carlo approximations of
        the true risk measures.
    """

    def __init__(
        self,
        alpha: float,
        n_w: int,
        weights: Optional[Tensor] = None,
    ) -> None:
        r"""Transform the posterior samples to samples of a risk measure.

        Args:
            alpha: The risk level, float in `(0.0, 1.0]`.
            n_w: The size of the `w_set` to calculate the risk measure over.
            weights: An optional `m`-dim tensor of weights for scalarizing
                multi-objective samples before calculating the risk measure.

        Raises:
            ValueError: If `alpha` is not in `(0.0, 1.0]`.
        """
        super().__init__(n_w=n_w, weights=weights)
        if not 0 < alpha <= 1:
            raise ValueError("alpha must be in (0.0, 1.0]")
        self.alpha = alpha
        # Zero-based position, within each `n_w` block sorted in descending
        # order, of the sample from which the tail starts.
        self.alpha_idx = ceil(n_w * alpha) - 1

    def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
        r"""Calculate the CVaR corresponding to the given samples.

        Args:
            samples: A `sample_shape x batch_shape x (q * n_w) x m`-dim tensor of
                posterior samples. The q-batches should be ordered so that each
                `n_w` block of samples correspond to the same input.
            X: A `batch_shape x q x d`-dim tensor of inputs. Ignored.

        Returns:
            A `sample_shape x batch_shape x q`-dim tensor of CVaR samples.
        """
        prepared_samples = self._prepare_samples(samples)
        # Largest values first, so the slice below keeps the smallest ones.
        sorted_samples = prepared_samples.sort(dim=-1, descending=True).values
        # Average the sample at `alpha_idx` together with everything below it.
        return sorted_samples[..., self.alpha_idx :].mean(dim=-1)
class VaR(CVaR):
    r"""The Value-at-Risk risk measure.

    Value-at-Risk measures the smallest possible reward (or largest possible loss)
    after excluding the worst outcomes with a total probability of `1 - alpha`. It
    is commonly used in financial risk management, and it corresponds to the
    `1 - alpha` quantile of a given random variable.
    """

    def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
        r"""Calculate the VaR corresponding to the given samples.

        Args:
            samples: A `sample_shape x batch_shape x (q * n_w) x m`-dim tensor of
                posterior samples. The q-batches should be ordered so that each
                `n_w` block of samples correspond to the same input.
            X: A `batch_shape x q x d`-dim tensor of inputs. Ignored.

        Returns:
            A `sample_shape x batch_shape x q`-dim tensor of VaR samples.
        """
        prepared_samples = self._prepare_samples(samples)
        # Sort each `n_w` block from largest to smallest and pick the single
        # sample at the position stored in `alpha_idx`.
        sorted_samples = prepared_samples.sort(dim=-1, descending=True).values
        return sorted_samples[..., self.alpha_idx]
class WorstCase(RiskMeasureMCObjective):
    r"""The worst-case risk measure."""

    def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
        r"""Calculate the worst-case measure corresponding to the given samples.

        Args:
            samples: A `sample_shape x batch_shape x (q * n_w) x m`-dim tensor of
                posterior samples. The q-batches should be ordered so that each
                `n_w` block of samples correspond to the same input.
            X: A `batch_shape x q x d`-dim tensor of inputs. Ignored.

        Returns:
            A `sample_shape x batch_shape x q`-dim tensor of worst-case samples.
        """
        prepared_samples = self._prepare_samples(samples)
        # The worst case is the smallest value within each `n_w` block.
        return prepared_samples.min(dim=-1).values