# Source code for botorch.utils.multitask

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Helpers for multitask modeling.
"""

from __future__ import annotations

from typing import List

import torch
from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from linear_operator import to_linear_operator


def separate_mtmvn(mvn: MultitaskMultivariateNormal) -> List[MultivariateNormal]:
    """Split a multitask MVN into one MVN per task.

    Covariance between data points within a single task is kept, while
    covariance across different tasks is discarded.

    Args:
        mvn: The multitask distribution to split. The trailing two dimensions
            of ``mvn.mean`` are read as ``(num_data, num_tasks)``.

    Returns:
        A list with one ``MultivariateNormal`` per task. Entry ``t`` has mean
        ``mvn.mean[..., t]`` and the ``num_data x num_data`` block of the joint
        covariance that belongs to task ``t``.
    """
    # T150340766 Upstream as a class method on gpytorch MultitaskMultivariateNormal.
    joint_covar = mvn.lazy_covariance_matrix
    dev = joint_covar.device
    n_data, n_tasks = mvn.mean.shape[-2:]
    flat_size = n_data * n_tasks

    # Offsets whose sum gives the position of (data point, task) in the
    # flattened joint covariance; the layout depends on `_interleaved`.
    if not mvn._interleaved:
        # Blocked layout: position of (i, t) is t * n_data + i.
        data_offsets = torch.arange(n_data, device=dev)
        task_offsets = torch.arange(0, flat_size, n_data, device=dev)
    else:
        # Interleaved layout: position of (i, t) is i * n_tasks + t.
        data_offsets = torch.arange(0, flat_size, n_tasks, device=dev)
        task_offsets = torch.arange(n_tasks, device=dev)

    # positions[t, i] = flat position of data point i for task t -> (n_tasks, n_data)
    positions = task_offsets.unsqueeze(-1) + data_offsets
    # Broadcasting a (T, 1, N) row index against a (T, N, 1) column index
    # pulls out a (T, N, N) stack of per-task covariance blocks.
    row_idx = positions.unsqueeze(-2)
    col_idx = positions.unsqueeze(-1)
    per_task_covars = joint_covar[..., row_idx, col_idx]

    return [
        MultivariateNormal(
            mvn.mean[..., task], to_linear_operator(per_task_covars[..., task, :, :])
        )
        for task in range(n_tasks)
    ]