Source code for botorch.utils.probability.utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from functools import lru_cache
from math import pi
from numbers import Number
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, Union

import torch
from numpy.polynomial.legendre import leggauss as numpy_leggauss
from torch import BoolTensor, LongTensor, Tensor

# A single switch "case" as consumed by `case_dispatcher`: a zero-argument
# predicate returning a boolean mask, paired with a function that maps a mask
# to the values written at the masked positions.
CaseNd = Tuple[Callable[[], BoolTensor], Callable[[BoolTensor], Tensor]]

# Precomputed Python floats: 1 / sqrt(2 * pi) and -1 / sqrt(2).
_inv_sqrt_2pi = (2 * pi) ** -0.5
_neg_inv_sqrt2 = -(2**-0.5)
# NOTE(review): not referenced by the code visible here; presumably the bounds
# within which standardized values are expected to lie -- confirm with callers.
STANDARDIZED_RANGE: Tuple[float, float] = (-1e6, 1e6)


def case_dispatcher(
    out: Tensor,
    cases: Iterable[CaseNd] = (),
    default: Callable[[BoolTensor], Tensor] = None,
) -> Tensor:
    r"""Tensorized analogue of a switch/case statement that fills `out` in-place.

    Args:
        out: Tensor that receives the outcome of each case.
        cases: Iterable of `(pred, func)` pairs. Calling `pred()` yields a boolean
            mask marking the entries of `out` that `func` applies to. Earlier
            cases take precedence: an entry claimed by one case is never
            overwritten by a later one.
        default: Optional `func` invoked for all entries of `out` that no case
            has claimed.

    Returns:
        The tensor `out`, after the applicable writes have been made.
    """
    # `unclaimed` stays None until some case claims a strict subset of entries;
    # until then every entry of `out` is still up for grabs.
    unclaimed = None
    for predicate, handler in cases:
        requested = predicate()
        if not requested.any():
            continue

        # Restrict the request to entries that earlier cases left untouched.
        hits = requested if unclaimed is None else requested & unclaimed
        if not hits.any():
            continue

        if hits.all():
            # Everything is covered: write through Ellipsis to skip mask indexing.
            out[...] = handler(...)
            return out

        out[hits] = handler(hits)
        if unclaimed is None:
            unclaimed = ~hits
        else:
            unclaimed[hits] = False

        if not unclaimed.any():
            break

    if default is None:
        return out

    if unclaimed is None:
        out[...] = default(...)
    elif unclaimed.any():
        out[unclaimed] = default(unclaimed)

    return out
# `typed=True` keeps `1`, `1.0` and `True` in separate cache slots. They hash and
# compare equal, so an untyped cache would hand back whichever tensor was built
# first -- with the wrong dtype whenever `dtype` is left as None.
@lru_cache(maxsize=None, typed=True)
def get_constants(
    values: Union[Number, Iterator[Number]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Union[Tensor, Tuple[Tensor, ...]]:
    r"""Returns scalar-valued Tensors containing each of the given constants.

    Used to expedite tensor operations involving scalar arithmetic. Results are
    memoized, so the returned Tensors are shared between callers and should not
    be modified in-place.

    Args:
        values: A single number, or a hashable collection of numbers (e.g. a
            tuple). The argument is used as part of the cache key. Note that the
            cache only distinguishes the types of top-level arguments, not of
            the elements inside a tuple.
        device: Device on which the constants are created.
        dtype: Data type of the constants. When None, `torch.full` chooses the
            data type based on the type of each value.

    Returns:
        A zero-dimensional Tensor if `values` is a number, otherwise a tuple of
        zero-dimensional Tensors, one per element of `values`.
    """
    if isinstance(values, Number):
        return torch.full((), values, dtype=dtype, device=device)

    return tuple(torch.full((), val, dtype=dtype, device=device) for val in values)
def get_constants_like(
    values: Union[Number, Iterator[Number]],
    ref: Tensor,
) -> Union[Tensor, Tuple[Tensor, ...]]:
    r"""Returns scalar-valued Tensors for `values` matching `ref`'s device and dtype.

    Thin wrapper around `get_constants`; the same caveats apply, i.e. `values`
    must be hashable and the returned Tensors should not be modified in-place.

    Args:
        values: A single number or a hashable collection of numbers.
        ref: Tensor whose `device` and `dtype` are used for the constants.

    Returns:
        A zero-dimensional Tensor if `values` is a number, otherwise a tuple of
        zero-dimensional Tensors.
    """
    return get_constants(values, device=ref.device, dtype=ref.dtype)
def gen_positional_indices(
    shape: torch.Size,
    dim: int,
    device: Optional[torch.device] = None,
) -> Iterator[torch.LongTensor]:
    r"""Yields per-dimension offsets into a flattened view of a `shape`d tensor.

    Dimensions are visited from `dim` down to `0`. For each one, the yielded
    tensor holds `0, c, 2c, ...` where `c` is the product of the sizes of all
    later dimensions, and carries one trailing singleton dimension for every
    dimension already visited, so that the yielded tensors broadcast against
    one another.

    Args:
        shape: Shape of the tensor being indexed.
        dim: The innermost dimension to generate offsets for; may be negative.
        device: Device on which the offset tensors are created.

    Raises:
        ValueError: If `dim` does not refer to a dimension of `shape`. Since this
            is a generator, the error surfaces on first iteration.
    """
    rank = len(shape)
    axis = dim + rank if dim < 0 else dim
    if not 0 <= axis < rank:
        raise ValueError(f"dim={dim} invalid for shape {shape}.")

    # Number of elements spanned by one step along the current dimension.
    stride = shape[axis + 1 :].numel()
    trailing = ()  # one `None` per dimension that has already been emitted
    for size in reversed(shape[: axis + 1]):
        offsets = torch.arange(0, size * stride, stride, device=device)
        yield offsets[(...,) + trailing]
        trailing += (None,)
        stride *= size
def build_positional_indices(
    shape: torch.Size,
    dim: int,
    device: Optional[torch.device] = None,
) -> LongTensor:
    r"""Sums the offsets produced by `gen_positional_indices` into one tensor.

    Args:
        shape: Shape of the tensor being indexed.
        dim: The innermost dimension to generate offsets for; may be negative.
        device: Device on which the offsets are created.

    Returns:
        The broadcasted sum of the per-dimension offset tensors.
    """
    # Start from the integer 0, exactly as the builtin `sum` would.
    total = 0
    for offsets in gen_positional_indices(shape=shape, dim=dim, device=device):
        total = total + offsets
    return total
@lru_cache(maxsize=None)
def leggauss(deg: int, **tkwargs: Any) -> Tuple[Tensor, Tensor]:
    r"""Returns NumPy's `leggauss(deg)` sample points and weights as Tensors.

    Results are memoized per `(deg, tkwargs)`, so every keyword value must be
    hashable and callers receive the same Tensor objects on repeated calls.

    Args:
        deg: Number of sample points and weights.
        **tkwargs: Keyword arguments forwarded to `torch.as_tensor`.

    Returns:
        A two-tuple `(x, w)` holding the sample points and the weights.
    """
    nodes, weights = numpy_leggauss(deg)
    nodes_tensor = torch.as_tensor(nodes, **tkwargs)
    weights_tensor = torch.as_tensor(weights, **tkwargs)
    return nodes_tensor, weights_tensor
def ndtr(x: Tensor) -> Tensor:
    r"""Standard normal CDF, evaluated as `0.5 * erfc(-x / sqrt(2))`."""
    half, ninv_sqrt2 = get_constants_like((0.5, _neg_inv_sqrt2), x)
    scaled = ninv_sqrt2 * x
    return half * torch.erfc(scaled)
def phi(x: Tensor) -> Tensor:
    r"""Standard normal PDF, evaluated as `exp(-x^2 / 2) / sqrt(2 * pi)`."""
    inv_sqrt_2pi, neg_half = get_constants_like((_inv_sqrt_2pi, -0.5), x)
    exponent = neg_half * x.square()
    return inv_sqrt_2pi * exponent.exp()
def swap_along_dim_(
    values: Tensor,
    i: Union[int, LongTensor],
    j: Union[int, LongTensor],
    dim: int,
    buffer: Optional[Tensor] = None,
) -> Tensor:
    r"""Swaps Tensor slices in-place along dimension `dim`.

    When passed as Tensors, `i` (and `j`) should be `dim`-dimensional tensors
    with the same shape as `values.shape[:dim]`. The exception to this rule
    occurs when `dim=0`, in which case `i` (and `j`) should be (at most)
    one-dimensional when passed as a Tensor.

    Args:
        values: Tensor whose values are to be swapped.
        i: Indices for slices along dimension `dim`.
        j: Indices for slices along dimension `dim`.
        dim: The dimension of `values` along which to swap slices.
        buffer: Optional buffer used internally to store copied values.

    Returns:
        The original `values` tensor.

    Raises:
        ValueError: If a Tensor-valued `i` or `j` has a shape that is
            incompatible with the rules above.
    """
    axis = values.ndim + dim if dim < 0 else dim
    i_is_tensor = isinstance(i, Tensor)
    j_is_tensor = isinstance(j, Tensor)
    per_batch = (i_is_tensor and i.ndim > 0) or (j_is_tensor and j.ndim > 0)

    if axis and per_batch:
        # Batched, heterogeneous swaps: convert to linear indices and recurse on
        # a view of `values` whose leading `axis + 1` dimensions are flattened.
        batch_shape = values.shape[:axis]
        if i_is_tensor and i.shape != batch_shape:
            raise ValueError("Batch shapes of `i` and `values` do not match.")

        if j_is_tensor and j.shape != batch_shape:
            raise ValueError("Batch shapes of `j` and `values` do not match.")

        offsets = build_positional_indices(
            shape=values.shape[: axis + 1], dim=-2, device=values.device
        )
        flat_values = values.view(-1, *values.shape[axis + 1 :])
        swap_along_dim_(
            flat_values,
            i=(offsets + i).view(-1),
            j=(offsets + j).view(-1),
            dim=0,
            buffer=buffer,
        )
        return values

    # Base cases: the same swap for every batch, or one-dimensional index
    # tensors applied along the leading dimension.
    if i_is_tensor and i.ndim > 1:
        raise ValueError("Tensor `i` must be at most 1-dimensional when `dim=0`.")

    if j_is_tensor and j.ndim > 1:
        raise ValueError("Tensor `j` must be at most 1-dimensional when `dim=0`.")

    if axis:
        # Select everything along the leading dimensions.
        leading = (slice(None),) * axis
        first, second = leading + (i,), leading + (j,)
    else:
        first, second = i, j

    # Stash the `i` slices before they are overwritten.
    if buffer is None:
        buffer = values[first].clone()
    else:
        buffer.copy_(values[first])

    values[first] = values[second]
    values[second] = buffer
    return values