Source code for botorch.models.converter

#! /usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Utilities for converting between different models.
"""

from copy import deepcopy

import torch
from torch.nn import Module, ModuleList

from ..exceptions import UnsupportedError
from .gp_regression import FixedNoiseGP, HeteroskedasticSingleTaskGP
from .gpytorch import BatchedMultiOutputGPyTorchModel
from .model_list_gp_regression import ModelListGP


def _get_module(module: Module, name: str) -> Module:
    """Recursively get a sub-module from a module.

    Args:
        module: A `torch.nn.Module`.
        name: The name of the submodule to return, in the form of a period-delimited
            string: `sub_module.subsub_module.[...].leaf_module`.

    Returns:
        The requested sub-module.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> noise_prior = _get_module(gp, "likelihood.noise_covar.noise_prior")
    """
    current = module
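    # An empty `name` refers to the root module itself (matching the empty-string
    # key that `Module.named_modules()` yields for the root), so skip the lookup.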
    if name != "":
        for a in name.split("."):
            current = getattr(current, a)
    return current


def _check_compatibility(models: ModuleList) -> None:
    """Check if the submodels of a ModelListGP can be converted to a batched model."""
    # check that all submodules are of the same type
    for modn, mod in models[0].named_modules():
        mcls = mod.__class__
        if not all(isinstance(_get_module(m, modn), mcls) for m in models[1:]):
            raise UnsupportedError(
                "Sub-modules must be of the same type across models."
            )

    # check that each model is a BatchedMultiOutputGPyTorchModel
    if not all(isinstance(m, BatchedMultiOutputGPyTorchModel) for m in models):
        raise UnsupportedError(
            "All models must be of type BatchedMultiOutputGPyTorchModel."
        )

    # TODO: Add support for HeteroskedasticSingleTaskGP
    if any(isinstance(m, HeteroskedasticSingleTaskGP) for m in models):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP is currently unsupported."
        )

    # TODO: Add support for custom likelihoods
    if any(getattr(m, "_is_custom_likelihood", False) for m in models):
        raise NotImplementedError(
            "Conversion of models with custom likelihoods is currently unsupported."
        )

    # check that each model is single-output
    if not all(m._num_outputs == 1 for m in models):
        raise UnsupportedError("All models must be single-output.")

    # check that training inputs are the same
    if not all(
        torch.equal(ti, tj)
        for m in models[1:]
        for ti, tj in zip(models[0].train_inputs, m.train_inputs)
    ):
        raise UnsupportedError("training inputs must agree for all sub-models.")


def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorchModel:
    """Convert a ModelListGP to a BatchedMultiOutputGPyTorchModel.

    Args:
        model_list: The `ModelListGP` to be converted to the appropriate
            `BatchedMultiOutputGPyTorchModel`. All sub-models must be of the same
            type and have the same shape (batch shape and number of training inputs).

    Returns:
        The model converted into a `BatchedMultiOutputGPyTorchModel`.

    Example:
        >>> list_gp = ModelListGP(gp1, gp2)
        >>> batch_gp = model_list_to_batched(list_gp)
    """
    models = model_list.models
    _check_compatibility(models)

    # if the list has only one model, we can just return a copy of that
    if len(models) == 1:
        return deepcopy(models[0])

    # construct inputs
    train_X = deepcopy(models[0].train_inputs[0])
    train_Y = torch.stack([m.train_targets.clone() for m in models], dim=-1)
    kwargs = {"train_X": train_X, "train_Y": train_Y}
    if isinstance(models[0], FixedNoiseGP):
        kwargs["train_Yvar"] = torch.stack(
            [m.likelihood.noise_covar.noise.clone() for m in models], dim=-1
        )

    # construct the batched GP model
    batch_gp = models[0].__class__(**kwargs)
    tensors = {n for n, p in batch_gp.state_dict().items() if len(p.shape) > 0}
    scalars = set(batch_gp.state_dict()) - tensors
    input_batch_dims = len(models[0]._input_batch_shape)

    # ensure scalars agree (TODO: Allow different priors for different outputs)
    for n in scalars:
        v0 = _get_module(models[0], n)
        if not all(torch.equal(_get_module(m, n), v0) for m in models[1:]):
            raise UnsupportedError("All scalars must have the same value.")

    # ensure dimensions of all tensors agree
    for n in tensors:
        shape0 = _get_module(models[0], n).shape
        if not all(_get_module(m, n).shape == shape0 for m in models[1:]):
            raise UnsupportedError("All tensors must have the same shape.")

    # now construct the batched state dict
    scalar_state_dict = {
        s: p.clone() for s, p in models[0].state_dict().items() if s in scalars
    }
    tensor_state_dict = {
        t: torch.stack(
            [m.state_dict()[t].clone() for m in models], dim=input_batch_dims
        )
        for t in tensors
    }
    batch_state_dict = {**scalar_state_dict, **tensor_state_dict}

    # load the state dict into the new model
    batch_gp.load_state_dict(batch_state_dict)
    return batch_gp
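
# The batched model is built by stacking each tensor in the sub-models' state dicts
# along the input batch dimension; scalar hyperparameters must agree across models.
# A minimal sketch (illustrative only; `SingleTaskGP` is assumed to be imported
# from botorch.models):
#
#     >>> train_X = torch.rand(10, 2)
#     >>> gp1 = SingleTaskGP(train_X, torch.rand(10, 1))
#     >>> gp2 = SingleTaskGP(train_X, torch.rand(10, 1))
#     >>> batch_gp = model_list_to_batched(ModelListGP(gp1, gp2))
#     >>> batch_gp._num_outputs
#     2
#     >>> # e.g. the kernel lengthscales gain a leading batch dimension of size 2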


def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> ModelListGP:
    """Convert a BatchedMultiOutputGPyTorchModel to a ModelListGP.

    Args:
        batch_model: The `BatchedMultiOutputGPyTorchModel` to be converted to a
            `ModelListGP`.

    Returns:
        The model converted into a `ModelListGP`.

    Example:
        >>> train_X = torch.rand(5, 2)
        >>> train_Y = torch.rand(5, 2)
        >>> batch_gp = SingleTaskGP(train_X, train_Y)
        >>> list_gp = batched_to_model_list(batch_gp)
    """
    # TODO: Add support for HeteroskedasticSingleTaskGP
    if isinstance(batch_model, HeteroskedasticSingleTaskGP):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP currently not supported."
        )
    batch_sd = batch_model.state_dict()

    tensors = {n for n, p in batch_sd.items() if len(p.shape) > 0}
    scalars = set(batch_sd) - tensors
    input_bdims = len(batch_model._input_batch_shape)

    models = []
    for i in range(batch_model._num_outputs):
        scalar_sd = {s: batch_sd[s].clone() for s in scalars}
        tensor_sd = {t: batch_sd[t].select(input_bdims, i).clone() for t in tensors}
        sd = {**scalar_sd, **tensor_sd}
        kwargs = {
            "train_X": batch_model.train_inputs[0].select(input_bdims, i).clone(),
            "train_Y": batch_model.train_targets.select(input_bdims, i)
            .clone()
            .unsqueeze(-1),
        }
        if isinstance(batch_model, FixedNoiseGP):
            noise_covar = batch_model.likelihood.noise_covar
            kwargs["train_Yvar"] = (
                noise_covar.noise.select(input_bdims, i).clone().unsqueeze(-1)
            )
        model = batch_model.__class__(**kwargs)
        model.load_state_dict(sd)
        models.append(model)

    return ModelListGP(*models)
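
# Round-trip sketch (illustrative only; `SingleTaskGP` is assumed to be imported
# from botorch.models): a multi-output batched model splits into one single-output
# model per output, and converting back recovers an equivalent batched model.
#
#     >>> batch_gp = SingleTaskGP(torch.rand(5, 2), torch.rand(5, 2))
#     >>> list_gp = batched_to_model_list(batch_gp)
#     >>> len(list_gp.models)
#     2
#     >>> batch_gp2 = model_list_to_batched(list_gp)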