# Source code for botorch.utils.probability.linalg

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass, InitVar
from itertools import chain
from typing import Any, Optional, Sequence

import torch
from botorch.utils.probability.utils import swap_along_dim_
from linear_operator.utils.errors import NotPSDError
from torch import LongTensor, Tensor
def augment_cholesky(
    Laa: Tensor,
    Kbb: Tensor,
    Kba: Optional[Tensor] = None,
    Lba: Optional[Tensor] = None,
    jitter: Optional[float] = None,
) -> Tensor:
    r"""Extend a precomputed Cholesky factor to a larger block matrix.

    Given ``Laa`` with ``Kaa = Laa Laa^T``, build the lower-triangular Cholesky
    factor of the block matrix ``K = [[Kaa, Kab], [Kba, Kbb]]`` as

        L = [[Laa, 0], [Lba, Lbb]],
        Lba = Kba Laa^{-T},
        Lbb = chol(Kbb - Lba Lba^T).

    Args:
        Laa: Cholesky factor of K's upper left block.
        Kbb: Lower-right block of K.
        Kba: Lower-left block of K. Mutually exclusive with ``Lba``.
        Lba: Precomputed solve Kba Laa^{-T}. Mutually exclusive with ``Kba``.
        jitter: Optional nugget added to the diagonal of (a copy of) Kbb.

    Returns:
        The Cholesky factor of K, assembled with ``block_matrix_concat``.

    Raises:
        ValueError: If both or neither of ``Kba`` and ``Lba`` are given.
        NotPSDError: If ``torch.linalg.cholesky_ex`` reports a failure on the
            Schur complement ``Kbb - Lba Lba^T``.
    """
    kba_given = Kba is not None
    lba_given = Lba is not None
    if kba_given == lba_given:
        raise ValueError("One and only one of Kba or Lba must be provided.")

    if jitter is not None:
        # Copy first so the caller's Kbb is not modified by the in-place add.
        Kbb = Kbb.clone()
        Kbb.diagonal(dim1=-2, dim2=-1).add_(jitter)

    if Lba is None:
        # Solve Lba @ Laa^T = Kba for Lba, i.e. Lba = Kba Laa^{-T}.
        Lba = torch.linalg.solve_triangular(
            Laa.transpose(-2, -1), Kba, left=False, upper=True
        )

    schur = Kbb - Lba @ Lba.transpose(-2, -1)
    Lbb, info = torch.linalg.cholesky_ex(schur)
    if info.any():
        suffix = "." if jitter is None else f" and nugget jitter={jitter}."
        raise NotPSDError(
            "Schur complement of K with respect to Kaa not PSD for the given "
            "Cholesky factor Laa" + suffix
        )

    # Pad Laa with Lbb.shape[-1] zero columns on the right to form the upper row.
    upper_row = [pad(Laa, (0, Lbb.shape[-1]))]
    lower_row = [Lba, Lbb]
    return block_matrix_concat(blocks=(upper_row, lower_row))
@dataclass
class PivotedCholesky:
    r"""State of a step-wise, in-place pivoted Cholesky decomposition.

    Attributes:
        step: Index of the next row/column to be processed by ``update_``.
        tril: Square matrix (or batch of square matrices) of shape
            ``batch_shape x n x n``; modified in place by ``update_``/``pivot_``.
        perm: Tensor of shape ``batch_shape x n`` whose entries are swapped by
            ``pivot_`` (presumably the running permutation — confirm with callers).
        diag: Optional tensor of shape ``batch_shape x n`` swapped alongside
            ``perm``.
        validate_init: If True (default), check the shapes above on construction.
    """

    step: int
    tril: Tensor
    perm: LongTensor
    diag: Optional[Tensor] = None
    validate_init: InitVar[bool] = True

    def __post_init__(self, validate_init: bool = True):
        """Check that ``tril`` is square and ``perm``/``diag`` match its shape."""
        if not validate_init:
            return

        tril_shape = self.tril.shape
        if tril_shape[-1] != tril_shape[-2]:
            raise ValueError(
                f"Expected square matrices but matrix has shape {tril_shape}."
            )

        # perm and diag must match tril with its last dimension dropped.
        lead_shape = tril_shape[:-1]
        if self.perm.shape != lead_shape:
            raise ValueError(
                f"perm of shape {self.perm.shape} incompatible with "
                f"matrix of shape {tril_shape}."
            )
        if self.diag is not None and self.diag.shape != lead_shape:
            raise ValueError(
                f"diag of shape {self.diag.shape} incompatible with "
                f"matrix of shape {tril_shape}."
            )

    def __getitem__(self, key: Any) -> PivotedCholesky:
        """Return a new ``PivotedCholesky`` built from ``tril[key]``, ``perm[key]``
        and ``diag[key]``, keeping the same ``step``.

        The same key is applied to every tensor, so it should only index batch
        dimensions; otherwise the shape validation on construction will likely
        fail.
        """
        sub_tril = self.tril[key]
        sub_perm = self.perm[key]
        sub_diag = None
        if self.diag is not None:
            sub_diag = self.diag[key]
        return PivotedCholesky(
            step=self.step,
            tril=sub_tril,
            perm=sub_perm,
            diag=sub_diag,
        )
[docs] def update_(self, eps: float = 1e-10) -> None: r"""Performs a single matrix decomposition step.""" i = self.step L = self.tril Lii = self.tril[..., i, i].clone().clip(min=0).sqrt() # Finalize i-th row and column of Cholesky factor L[..., i, i] = Lii L[..., i, i + 1 :] = 0 L[..., i + 1 :, i] = L[..., i + 1 :, i].clone() / Lii.unsqueeze(-1) # Update tril(L[i + 1:, i + 1:]) to be the lower triangular part # of the Schur complement of cov with respect to cov[:i, :i]. rank1 = L[..., i + 1 :, i : i + 1].clone() rank1 = (rank1 * rank1.transpose(-1, -2)).tril() L[..., i + 1 :, i + 1 :] = L[..., i + 1 :, i + 1 :].clone() - rank1 L[Lii <= i * eps, i:, i] = 0 # numerical stability clause self.step += 1
[docs] def pivot_(self, pivot: LongTensor) -> None: *batch_shape, _, size = self.tril.shape if pivot.shape != tuple(batch_shape): raise ValueError("Argument pivot does to match with batch shape`.") # Perform basic swaps for key in ("perm", "diag"): tnsr = getattr(self, key, None) if tnsr is not None: swap_along_dim_(tnsr, i=self.step, j=pivot, dim=tnsr.ndim - 1) # Perform matrix swaps; prealloacte buffers for row/column linear indices size2 = size**2 min_pivot = pivot.min() tkwargs = {"device": pivot.device, "dtype": pivot.dtype} buffer_col = torch.arange(size * (1 + min_pivot), size2, size, **tkwargs) buffer_row = torch.arange(0, max(self.step, pivot.max()), **tkwargs) head = buffer_row[: self.step] indices_v1 = [] indices_v2 = [] for i, piv in enumerate(pivot.view(-1, 1)): v1 = pad(piv, (1, 0), value=self.step).unsqueeze(-1) v2 = pad(piv, (0, 1), value=self.step).unsqueeze(-1) start = i * size2 indices_v1.extend((start + v1 + size * v1).ravel()) indices_v2.extend((start + v2 + size * v2).ravel()) indices_v1.extend((start + size * v1 + head).ravel()) indices_v2.extend((start + size * v2 + head).ravel()) tail = buffer_col[piv - min_pivot :] indices_v1.extend((start + v1 + tail).ravel()) indices_v2.extend((start + v2 + tail).ravel()) interior = buffer_row[min(piv, self.step + 1) : piv] indices_v1.extend(start + size * interior + self.step) indices_v2.extend(start + size * piv + interior) swap_along_dim_( self.tril.view(-1), i=torch.as_tensor(indices_v1, **tkwargs), j=torch.as_tensor(indices_v2, **tkwargs), dim=0, )