#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass, InitVar
from itertools import chain
from typing import Any, Optional, Sequence

import torch
from botorch.utils.probability.utils import swap_along_dim_
from linear_operator.utils.errors import NotPSDError
from torch import LongTensor, Tensor
from torch.nn.functional import pad

def block_matrix_concat(blocks: Sequence[Sequence[Tensor]]) -> Tensor:
    r"""Assembles a (batched) matrix from a nested sequence of blocks.

    Each inner sequence of `blocks` is one block-row: its tensors are joined
    along the last dimension, and the resulting block-rows are then stacked
    along the second-to-last dimension. Leading (batch) dimensions of all
    blocks are broadcast against one another first.

    Args:
        blocks: Block-rows of the matrix; every entry is a tensor with at
            least two dimensions whose leading dimensions are mutually
            broadcastable.

    Returns:
        The concatenated tensor, with the broadcast batch shape as its
        leading dimensions.
    """
    rows = []
    # Common batch shape shared by every block, regardless of its row.
    shape = torch.broadcast_shapes(*(x.shape[:-2] for x in chain.from_iterable(blocks)))
    for tensors in blocks:
        parts = [x.expand(*shape, *x.shape[-2:]) for x in tensors]
        if len(parts) > 1:
            rows.append(torch.cat(parts, dim=-1))
        else:
            # A single-block row needs no column-wise concatenation.
            rows.extend(parts)
    return torch.concat(rows, dim=-2)
def augment_cholesky(
    Laa: Tensor,
    Kbb: Tensor,
    Kba: Optional[Tensor] = None,
    Lba: Optional[Tensor] = None,
    jitter: Optional[float] = None,
) -> Tensor:
    r"""Computes the Cholesky factor of a block matrix `K = [[Kaa, Kab], [Kba, Kbb]]`
    based on a precomputed Cholesky factor `Kaa = Laa Laa^T`.

    Args:
        Laa: Cholesky factor of K's upper left block.
        Kbb: Lower-right block of K.
        Kba: Lower-left block of K.
        Lba: Precomputed solve `Kba Laa^{-T}`.
        jitter: Optional nugget to be added to the diagonal of Kbb.

    Returns:
        The block matrix `[[Laa, 0], [Lba, Lbb]]`, where `Lbb` is the Cholesky
        factor of `Kbb - Lba Lba^T`.

    Raises:
        ValueError: If both or neither of `Kba` and `Lba` are provided.
        NotPSDError: If the factorization of `Kbb - Lba Lba^T` reports a
            failure for any batch element.
    """
    if not (Kba is None) ^ (Lba is None):
        raise ValueError("One and only one of `Kba` or `Lba` must be provided.")

    if jitter is not None:
        # Work on a copy so the caller's `Kbb` is not modified in place.
        Kbb = Kbb.clone()
        Kbb.diagonal(dim1=-2, dim2=-1).add_(jitter)

    if Lba is None:
        # Solve `X Laa^T = Kba` for `X`; `Laa^T` is upper triangular.
        Lba = torch.linalg.solve_triangular(
            Laa.transpose(-2, -1), Kba, left=False, upper=True
        )

    Lbb, info = torch.linalg.cholesky_ex(Kbb - Lba @ Lba.transpose(-2, -1))
    if info.any():
        raise NotPSDError(
            "Schur complement of `K` with respect to `Kaa` not PSD for the given "
            "Cholesky factor `Laa`"
            f"{'.' if jitter is None else f' and nugget jitter={jitter}.'}"
        )

    # Right-pad `Laa` with zero columns so both block-rows have equal width.
    n = Lbb.shape[-1]
    return block_matrix_concat(blocks=([pad(Laa, (0, n))], [Lba, Lbb]))
@dataclass
class PivotedCholesky:
    r"""Mutable state of an incrementally computed, pivoted Cholesky factorization.

    Attributes:
        step: Index of the next row/column to be finalized by `update_`.
        tril: Batch of square matrices holding the partially computed factor;
            columns before `step` are final, the trailing lower triangle is
            the remaining work (see `update_`).
        perm: Index tensor of shape `tril.shape[:-1]` that is swapped in
            lockstep with the rows/columns of `tril` in `pivot_`
            (presumably the permutation applied so far -- confirm with callers).
        diag: Optional tensor of shape `tril.shape[:-1]` that is swapped
            alongside `perm`; its meaning is not visible in this class.
        validate_init: Init-only flag; if True, shapes are checked on creation.
    """

    step: int
    tril: Tensor
    perm: LongTensor
    diag: Optional[Tensor] = None
    validate_init: InitVar[bool] = True

    def __post_init__(self, validate_init: bool = True):
        r"""Validates that `tril`, `perm`, and `diag` have compatible shapes.

        Raises:
            ValueError: If `tril` is not square or if `perm`/`diag` do not have
                shape `tril.shape[:-1]`.
        """
        if not validate_init:
            return

        if self.tril.shape[-2] != self.tril.shape[-1]:
            raise ValueError(
                f"Expected square matrices but `matrix` has shape `{self.tril.shape}`."
            )

        if self.perm.shape != self.tril.shape[:-1]:
            raise ValueError(
                f"`perm` of shape `{self.perm.shape}` incompatible with "
                f"`matrix` of shape `{self.tril.shape}`."
            )

        if self.diag is not None and self.diag.shape != self.tril.shape[:-1]:
            raise ValueError(
                f"`diag` of shape `{self.diag.shape}` incompatible with "
                f"`matrix` of shape `{self.tril.shape}`."
            )

    def __getitem__(self, key: Any) -> "PivotedCholesky":
        r"""Returns a new instance whose tensors are each indexed by `key`."""
        return PivotedCholesky(
            step=self.step,
            tril=self.tril[key],
            perm=self.perm[key],
            diag=None if self.diag is None else self.diag[key],
        )

    def update_(self, eps: float = 1e-10) -> None:
        r"""Performs a single matrix decomposition step.

        Finalizes row and column `step` of `tril` in place, updates the
        trailing lower triangle, and increments `step`.

        Args:
            eps: Tolerance factor; column `step` is zeroed for batch elements
                whose pivot value is at most `step * eps`.
        """
        i = self.step
        L = self.tril
        # Negative diagonal entries are clipped to zero before the square root.
        Lii = self.tril[..., i, i].clone().clip(min=0).sqrt()

        # Finalize `i-th` row and column of Cholesky factor
        L[..., i, i] = Lii
        L[..., i, i + 1 :] = 0
        L[..., i + 1 :, i] = L[..., i + 1 :, i].clone() / Lii.unsqueeze(-1)

        # Update `tril(L[i + 1:, i + 1:])` to be the lower triangular part
        # of the Schur complement of `cov` with respect to `cov[:i, :i]`.
        rank1 = L[..., i + 1 :, i : i + 1].clone()
        rank1 = (rank1 * rank1.transpose(-1, -2)).tril()
        L[..., i + 1 :, i + 1 :] = L[..., i + 1 :, i + 1 :].clone() - rank1

        L[Lii <= i * eps, i:, i] = 0  # numerical stability clause
        self.step += 1

    def pivot_(self, pivot: LongTensor) -> None:
        r"""Swaps position `step` with position `pivot`, in place.

        The swap is applied to `perm` and `diag` along their last dimension and
        to the corresponding entries of the lower triangular `tril`, using flat
        (linear) indices into `tril.view(-1)`.

        Args:
            pivot: Index tensor with the same shape as the batch shape of
                `tril`, holding one pivot index per batch element.

        Raises:
            ValueError: If the shape of `pivot` differs from the batch shape.
        """
        *batch_shape, _, size = self.tril.shape
        if pivot.shape != tuple(batch_shape):
            raise ValueError("Argument `pivot` does to match with batch shape`.")

        # Perform basic swaps
        for key in ("perm", "diag"):
            tnsr = getattr(self, key, None)
            if tnsr is not None:
                swap_along_dim_(tnsr, i=self.step, j=pivot, dim=tnsr.ndim - 1)

        # Perform matrix swaps; preallocate buffers for row/column linear indices
        size2 = size**2
        min_pivot = pivot.min()
        tkwargs = {"device": pivot.device, "dtype": pivot.dtype}
        # `size * r` for rows `r > min_pivot`: flat offsets of row starts.
        buffer_col = torch.arange(size * (1 + min_pivot), size2, size, **tkwargs)
        buffer_row = torch.arange(0, max(self.step, pivot.max()), **tkwargs)
        head = buffer_row[: self.step]  # column indices `0, ..., step - 1`

        indices_v1 = []
        indices_v2 = []
        for i, piv in enumerate(pivot.view(-1, 1)):
            # Column vectors `[step, piv]` and `[piv, step]`.
            v1 = pad(piv, (1, 0), value=self.step).unsqueeze(-1)
            v2 = pad(piv, (0, 1), value=self.step).unsqueeze(-1)
            start = i * size2  # flat offset of the i-th matrix in the batch

            # Diagonal entries `(step, step)` and `(piv, piv)`.
            indices_v1.extend((start + v1 + size * v1).ravel())
            indices_v2.extend((start + v2 + size * v2).ravel())

            # Entries of rows `step` and `piv` in the columns before `step`.
            indices_v1.extend((start + size * v1 + head).ravel())
            indices_v2.extend((start + size * v2 + head).ravel())

            # Entries of columns `step` and `piv` in the rows after `piv`.
            tail = buffer_col[piv - min_pivot :]
            indices_v1.extend((start + v1 + tail).ravel())
            indices_v2.extend((start + v2 + tail).ravel())

            # Rows strictly between `step` and `piv`: entry `(r, step)` is
            # exchanged with entry `(piv, r)`.
            interior = buffer_row[min(piv, self.step + 1) : piv]
            indices_v1.extend(start + size * interior + self.step)
            indices_v2.extend(start + size * piv + interior)

        swap_along_dim_(
            self.tril.view(-1),
            i=torch.as_tensor(indices_v1, **tkwargs),
            j=torch.as_tensor(indices_v2, **tkwargs),
            dim=0,
        )

    def expand(self, *sizes: int) -> "PivotedCholesky":
        r"""Returns a new instance with each tensor expanded to batch shape `sizes`.

        The trailing dimensions of every field (one for `perm` and `diag`, two
        for `tril`) are kept; `None` fields are left unset.
        """
        fields = {}
        for name, ndim in {"perm": 1, "diag": 1, "tril": 2}.items():
            src = getattr(self, name)
            if src is not None:
                fields[name] = src.expand(sizes + src.shape[-ndim:])
        return type(self)(step=self.step, **fields)

    def concat(self, other: "PivotedCholesky", dim: int = 0) -> "PivotedCholesky":
        r"""Returns a new instance concatenating `self` and `other` along `dim`.

        Raises:
            ValueError: If the two decompositions are at different steps.
            NotImplementedError: If a field has different types on both sides
                (e.g. `diag` is set on only one of them).
        """
        if self.step != other.step:
            raise ValueError("Cannot conncatenate decompositions at different steps.")

        fields = {}
        for name in ("tril", "perm", "diag"):
            a = getattr(self, name)
            b = getattr(other, name)
            if type(a) is not type(b):
                raise NotImplementedError(f"Types of field {name} do not match.")

            if a is not None:
                fields[name] = torch.concat((a, b), dim=dim)

        return type(self)(step=self.step, **fields)

    def detach(self) -> "PivotedCholesky":
        r"""Returns a new instance whose tensors are detached from the autograd graph."""
        fields = {}
        for name in ("tril", "perm", "diag"):
            obj = getattr(self, name)
            if obj is not None:
                fields[name] = obj.detach()
        return type(self)(step=self.step, **fields)

    def clone(self) -> "PivotedCholesky":
        r"""Returns a new instance holding copies of all tensors."""
        fields = {}
        for name in ("tril", "perm", "diag"):
            obj = getattr(self, name)
            if obj is not None:
                fields[name] = obj.clone()
        return type(self)(step=self.step, **fields)