Source code for botorch.models.converter

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Utilities for converting between different models.
"""

from __future__ import annotations

import warnings
from copy import deepcopy
from typing import Dict, Optional, Set, Tuple

import torch
from botorch.exceptions import UnsupportedError
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.gp_regression import HeteroskedasticSingleTaskGP
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from torch import Tensor
from torch.nn import Module, ModuleList

# Message of the `DeprecationWarning` that the public converter functions in this
# module (`model_list_to_batched`, `batched_to_model_list` and
# `batched_multi_output_to_single_output`) emit when they are called.
DEPRECATION_MESSAGE = (
    "Model converter code is deprecated and will be removed in v0.13 release. "
    "Its correct behavior is dependent on some assumptions about model priors "
    "that do not always hold. Use it at your own risk! See "
    "https://github.com/cornellius-gp/gpytorch/issues/2550."
)


def _get_module(module: Module, name: str) -> Module:
    """Resolve a dotted attribute path starting from `module`.

    Args:
        module: The `torch.nn.Module` at which the lookup starts.
        name: A period-delimited path to the requested sub-module, in the form
            `sub_module.subsub_module.[...].leaf_module`. The empty string refers
            to `module` itself.

    Returns:
        The object found at the end of the path.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> noise_prior = _get_module(gp, "likelihood.noise_covar.noise_prior")
    """
    if name == "":
        return module
    node = module
    # Descend one attribute at a time; a missing attribute raises AttributeError.
    for part in name.split("."):
        node = getattr(node, part)
    return node


def _check_compatibility(models: ModuleList) -> None:
    """Check if the submodels of a ModelListGP are compatible with the converter.

    Args:
        models: The sub-models of the `ModelListGP` to be converted. The first
            model serves as the reference that the other models are compared to.

    Raises:
        UnsupportedError: If the sub-modules differ in type across models, a model
            is not a single-output `BatchedMultiOutputGPyTorchModel`, a model has
            an outcome transform, the training inputs differ, an input transform
            is batched, or the input transforms are not the same for all models
            (including the case where only some of the models have one).
        NotImplementedError: If a model is a `HeteroskedasticSingleTaskGP` or uses
            a custom likelihood.
    """
    # Check that all submodules are of the same type.
    for modn, mod in models[0].named_modules():
        mcls = mod.__class__
        if not all(isinstance(_get_module(m, modn), mcls) for m in models[1:]):
            raise UnsupportedError(
                "Sub-modules must be of the same type across models."
            )
        # A prior with an empty state dict exposes nothing that could be compared.
        if "prior" in modn and len(mod.state_dict()) == 0:
            warnings.warn(
                "Model converter cannot verify compatibility of GPyTorch priors "
                "that do not register their parameters as buffers. If the prior "
                "is different than the default prior set by the model constructor "
                "this may not work correctly. Use it at your own risk! See "
                "https://github.com/cornellius-gp/gpytorch/issues/2550.",
                BotorchWarning,
                stacklevel=3,
            )

    # Check that each model is a BatchedMultiOutputGPyTorchModel.
    if not all(isinstance(m, BatchedMultiOutputGPyTorchModel) for m in models):
        raise UnsupportedError(
            "All models must be of type BatchedMultiOutputGPyTorchModel."
        )

    # TODO: Add support for HeteroskedasticSingleTaskGP.
    if any(isinstance(m, HeteroskedasticSingleTaskGP) for m in models):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP is currently unsupported."
        )

    # TODO: Add support for custom likelihoods.
    if any(getattr(m, "_is_custom_likelihood", False) for m in models):
        raise NotImplementedError(
            "Conversion of models with custom likelihoods is currently unsupported."
        )

    # TODO: Add support for outcome transforms.
    if any(getattr(m, "outcome_transform", None) is not None for m in models):
        raise UnsupportedError(
            "Conversion of models with outcome transforms is currently unsupported."
        )

    # check that each model is single-output
    if not all(m._num_outputs == 1 for m in models):
        raise UnsupportedError("All models must be single-output.")

    # check that training inputs are the same
    if not all(
        torch.equal(ti, tj)
        for m in models[1:]
        for ti, tj in zip(models[0].train_inputs, m.train_inputs)
    ):
        raise UnsupportedError("training inputs must agree for all sub-models.")

    # Collect the input transforms once; a model without one contributes `None`.
    input_transforms = [getattr(m, "input_transform", None) for m in models]

    # check that there are no batched input transforms
    default_size = torch.Size([])
    for tf in input_transforms:
        if tf is not None and len(getattr(tf, "batch_shape", default_size)) != 0:
            raise UnsupportedError("Batched input_transforms are not supported.")

    # check that all models have the same input transforms. A model that lacks a
    # transform while another model has one counts as a mismatch, rather than
    # failing with an AttributeError on the missing attribute.
    if any(tf is not None for tf in input_transforms):
        reference_tf = input_transforms[0]
        if reference_tf is None or not all(
            tf is not None and tf.equals(reference_tf) for tf in input_transforms[1:]
        ):
            raise UnsupportedError("All models must have the same input_transforms.")


def model_list_to_batched(
    model_list: ModelListGP,
) -> BatchedMultiOutputGPyTorchModel:
    """Convert a ModelListGP to a BatchedMultiOutputGPyTorchModel.

    Args:
        model_list: The `ModelListGP` to be converted to the appropriate
            `BatchedMultiOutputGPyTorchModel`. All sub-models must be of the same
            type and have the same shape (batch shape and number of training
            inputs). Note that `model_list` is switched to train mode by this
            function and is not switched back.

    Returns:
        The model converted into a `BatchedMultiOutputGPyTorchModel`, in the
        train/eval mode that `model_list` was in when this function was called.
        If the list contains a single model, a deep copy of that model is returned.

    Raises:
        UnsupportedError: If the sub-models are not compatible with each other
            (see `_check_compatibility`), their fidelity parameters differ, or
            their state dict entries disagree in value (scalars) or shape
            (tensors).
        NotImplementedError: If the sub-models are of a type for which the
            conversion is not implemented (see `_check_compatibility`).

    Example:
        >>> list_gp = ModelListGP(gp1, gp2)
        >>> batch_gp = model_list_to_batched(list_gp)
    """
    warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
    was_training = model_list.training
    model_list.train()
    models = model_list.models
    _check_compatibility(models)

    # if the list has only one model, we can just return a copy of that
    # (in the original train/eval mode, like the multi-model path below)
    if len(models) == 1:
        return deepcopy(models[0]).train(mode=was_training)

    # construct inputs
    train_X = deepcopy(models[0].train_inputs[0])
    train_Y = torch.stack([m.train_targets.clone() for m in models], dim=-1)
    kwargs = {"train_X": train_X, "train_Y": train_Y}
    if isinstance(models[0].likelihood, FixedNoiseGaussianLikelihood):
        kwargs["train_Yvar"] = torch.stack(
            [m.likelihood.noise_covar.noise.clone() for m in models], dim=-1
        )
    if isinstance(models[0], SingleTaskMultiFidelityGP):
        init_args = models[0]._init_args
        if not all(
            v == m._init_args[k] for m in models[1:] for k, v in init_args.items()
        ):
            raise UnsupportedError("All models must have the same fidelity parameters.")
        kwargs.update(init_args)

    # add batched kernel, except if the model type is SingleTaskMultiFidelityGP,
    # which does not have a `covar_module`
    if not isinstance(models[0], SingleTaskMultiFidelityGP):
        batch_length = len(models)
        covar_module = _batched_kernel(models[0].covar_module, batch_length)
        kwargs["covar_module"] = covar_module

    # construct the batched GP model
    input_transform = getattr(models[0], "input_transform", None)
    batch_gp = models[0].__class__(input_transform=input_transform, **kwargs)
    adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
        batch_state_dict=batch_gp.state_dict(), input_transform=input_transform
    )
    input_batch_dims = len(models[0]._input_batch_shape)

    # ensure scalars agree (TODO: Allow different priors for different outputs)
    for n in non_adjusted_batch_keys:
        v0 = _get_module(models[0], n)
        if not all(torch.equal(_get_module(m, n), v0) for m in models[1:]):
            raise UnsupportedError("All scalars must have the same value.")

    # ensure dimensions of all tensors agree
    for n in adjusted_batch_keys:
        shape0 = _get_module(models[0], n).shape
        if not all(_get_module(m, n).shape == shape0 for m in models[1:]):
            raise UnsupportedError("All tensors must have the same shape.")

    # now construct the batched state dict; the per-model state dicts are built
    # once up front instead of once per key
    state_dicts = [m.state_dict() for m in models]
    non_adjusted_batch_state_dict = {
        s: p.clone() for s, p in state_dicts[0].items() if s in non_adjusted_batch_keys
    }
    adjusted_batch_state_dict = {
        t: (
            torch.stack([sd[t].clone() for sd in state_dicts], dim=input_batch_dims)
            if "active_dims" not in t
            # `active_dims` entries are copied from the first model, not stacked
            else state_dicts[0][t].clone()
        )
        for t in adjusted_batch_keys
    }
    batch_state_dict = {**non_adjusted_batch_state_dict, **adjusted_batch_state_dict}

    # load the state dict into the new model
    batch_gp.load_state_dict(batch_state_dict)
    return batch_gp.train(mode=was_training)
def _batched_kernel(kernel, batch_length: int):
    """Adds a batch dimension of size `batch_length` to all non-scalar
    Tensor parameters that govern the kernel function `kernel`.

    The kernel passed in is not modified; a deep copy is batched and returned.
    Besides non-scalar tensors, entries whose final name component is
    `raw_outputscale` are batched even when they are scalars. Entries whose key
    contains `active_dims` are left untouched.

    NOTE: prior or constraint parameters are excluded from batching.

    Args:
        kernel: The kernel whose state dict entries are to be batched.
        batch_length: The size of the batch dimension that is prepended.

    Returns:
        A copy of `kernel` with the batched entries set as `torch.nn.Parameter`s.
    """
    # copy just in case there are non-tensor parameters that are passed by reference
    kernel = deepcopy(kernel)
    search_str = "raw_outputscale"
    for key, attr in kernel.state_dict().items():
        # The converters in this module copy `active_dims` entries unchanged
        # rather than stacking them, so they must keep their unbatched shape here.
        # They are index tensors in GPyTorch, which also cannot be wrapped in a
        # gradient-requiring `Parameter`.
        if "active_dims" in key:
            continue
        if isinstance(attr, Tensor) and (
            attr.ndim > 0 or (search_str == key.rpartition(".")[-1])
        ):
            attr = attr.unsqueeze(0).expand(batch_length, *attr.shape).clone()
            set_attribute(kernel, key, torch.nn.Parameter(attr))
    return kernel


# two helper functions for `batched_kernel`
# like `setattr` and `getattr` for object hierarchies
def set_attribute(obj, attr: str, val):
    """Like `setattr` but works with hierarchical attribute specification.

    `attr` may be a period-separated path into nested objects. E.g. if obj=Zoo(),
    and attr="tiger.age", set_attribute(obj, attr, 3), would set the Zoo's
    tiger's age to three.
    """
    parent_path, _, leaf_name = attr.rpartition(".")
    target = obj
    if parent_path:
        # walk down to the object that owns the final attribute
        for part in parent_path.split("."):
            target = getattr(target, part)
    setattr(target, leaf_name, val)
def get_attribute(obj, attr: str):
    """Like `getattr` but works with hierarchical attribute specification.

    `attr` may be a period-separated path into nested objects. E.g. if obj=Zoo(),
    and attr="tiger.age", get_attribute(obj, attr), would return the Zoo's
    tiger's age.
    """
    current = obj
    for part in attr.split("."):
        current = getattr(current, part)
    return current
def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> ModelListGP:
    """Convert a BatchedMultiOutputGPyTorchModel to a ModelListGP.

    One single-output model of the same class as `batch_model` is created per
    output and populated from the corresponding slice of the batched state dict.
    Note that `batch_model` is switched to train mode by this function.

    Args:
        batch_model: The `BatchedMultiOutputGPyTorchModel` to be converted to a
            `ModelListGP`.

    Returns:
        The model converted into a `ModelListGP`, in the train/eval mode that
        `batch_model` was in when this function was called.

    Example:
        >>> train_X = torch.rand(5, 2)
        >>> train_Y = torch.rand(5, 2)
        >>> batch_gp = SingleTaskGP(train_X, train_Y)
        >>> list_gp = batched_to_model_list(batch_gp)
    """
    warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
    was_training = batch_model.training
    batch_model.train()

    # TODO: Add support for HeteroskedasticSingleTaskGP.
    if isinstance(batch_model, HeteroskedasticSingleTaskGP):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP is currently not supported."
        )
    if isinstance(batch_model, MixedSingleTaskGP):
        raise NotImplementedError(
            "Conversion of MixedSingleTaskGP is currently not supported."
        )

    input_transform = getattr(batch_model, "input_transform", None)
    outcome_transform = getattr(batch_model, "outcome_transform", None)
    batch_sd = batch_model.state_dict()
    adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
        batch_state_dict=batch_sd,
        input_transform=input_transform,
        outcome_transform=outcome_transform,
    )
    # The output dimension sits right after the input batch dimensions.
    input_bdims = len(batch_model._input_batch_shape)

    sub_models = []
    for i in range(batch_model._num_outputs):
        # State dict of the i-th model: shared entries are copied as they are ...
        sd = {key: batch_sd[key].clone() for key in non_adjusted_batch_keys}
        # ... and batched entries are sliced at output `i`, except `active_dims`.
        for key in adjusted_batch_keys:
            entry = batch_sd[key]
            if "active_dims" not in key:
                entry = entry.select(input_bdims, i)
            sd[key] = entry.clone()

        train_X = batch_model.train_inputs[0].select(input_bdims, i).clone()
        train_Y = batch_model.train_targets.select(input_bdims, i).clone()
        kwargs = {"train_X": train_X, "train_Y": train_Y.unsqueeze(-1)}
        if isinstance(batch_model.likelihood, FixedNoiseGaussianLikelihood):
            noise = batch_model.likelihood.noise_covar.noise
            kwargs["train_Yvar"] = noise.select(input_bdims, i).clone().unsqueeze(-1)
        if isinstance(batch_model, SingleTaskMultiFidelityGP):
            kwargs.update(batch_model._init_args)

        # NOTE: Adding outcome transform to kwargs to avoid the multiple
        # values for same kwarg issue with SingleTaskMultiFidelityGP.
        if outcome_transform is None:
            kwargs["outcome_transform"] = None
        else:
            octf = outcome_transform.subset_output(idcs=[i])
            kwargs["outcome_transform"] = octf
            # Update the outcome transform state dict entries.
            sd.update(
                {"outcome_transform." + k: v for k, v in octf.state_dict().items()}
            )

        model = batch_model.__class__(input_transform=input_transform, **kwargs)
        model.load_state_dict(sd)
        sub_models.append(model)

    return ModelListGP(*sub_models).train(mode=was_training)
def batched_multi_output_to_single_output(
    batch_mo_model: BatchedMultiOutputGPyTorchModel,
) -> BatchedMultiOutputGPyTorchModel:
    """Convert a model from batched multi-output to a batched single-output.

    Note: the underlying GPyTorch GP does not change. The GPyTorch GP's batch_shape
    (referred to as `_aug_batch_shape`) is still `_input_batch_shape x num_outputs`.
    The only things that change are the attributes of the
    BatchedMultiOutputGPyTorchModel that are responsible for the internal
    accounting of the number of outputs: namely, num_outputs, _input_batch_shape,
    and _aug_batch_shape.
    Initially for the batched MO models these are: `num_outputs = m`,
    `_input_batch_shape = train_X.batch_shape`, and
    `_aug_batch_shape = train_X.batch_shape + torch.Size([num_outputs])`.
    In the new SO model, these are: `num_outputs = 1`,
    `_input_batch_shape = train_X.batch_shape + torch.Size([num_outputs])`,
    and `_aug_batch_shape = train_X.batch_shape + torch.Size([num_outputs])`.

    This is a (hopefully) temporary measure until multi-output MVNs with
    independent outputs have better support in GPyTorch (see
    https://github.com/cornellius-gp/gpytorch/pull/1083).

    Note that `batch_mo_model` is switched to train mode by this function.

    Args:
        batch_mo_model: The BatchedMultiOutputGPyTorchModel

    Returns:
        The model converted into a batch single-output model, in the train/eval
        mode that `batch_mo_model` was in when this function was called.

    Raises:
        NotImplementedError: If `batch_mo_model` is a
            `HeteroskedasticSingleTaskGP`, uses a custom likelihood, or has an
            outcome transform.
        UnsupportedError: If `batch_mo_model` is not a
            `BatchedMultiOutputGPyTorchModel`.

    Example:
        >>> train_X = torch.rand(5, 2)
        >>> train_Y = torch.rand(5, 2)
        >>> batch_mo_gp = SingleTaskGP(train_X, train_Y)
        >>> batch_so_gp = batched_multi_output_to_single_output(batch_mo_gp)
    """
    warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
    was_training = batch_mo_model.training
    batch_mo_model.train()
    # TODO: Add support for HeteroskedasticSingleTaskGP.
    if isinstance(batch_mo_model, HeteroskedasticSingleTaskGP):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP currently not supported."
        )
    elif not isinstance(batch_mo_model, BatchedMultiOutputGPyTorchModel):
        raise UnsupportedError("Only BatchedMultiOutputGPyTorchModels are supported.")
    # TODO: Add support for custom likelihoods.
    elif getattr(batch_mo_model, "_is_custom_likelihood", False):
        raise NotImplementedError(
            "Conversion of models with custom likelihoods is currently unsupported."
        )
    input_transform = getattr(batch_mo_model, "input_transform", None)
    batch_sd = batch_mo_model.state_dict()

    # TODO: add support for outcome transforms.
    # Same test as in `_check_compatibility`: an `outcome_transform` attribute
    # that is explicitly `None` means there is no transform to convert.
    if getattr(batch_mo_model, "outcome_transform", None) is not None:
        raise NotImplementedError(
            "Converting batched multi-output models with outcome transforms "
            "is not currently supported."
        )

    kwargs = {
        "train_X": batch_mo_model.train_inputs[0].clone(),
        "train_Y": batch_mo_model.train_targets.clone().unsqueeze(-1),
    }
    if isinstance(batch_mo_model.likelihood, FixedNoiseGaussianLikelihood):
        noise_covar = batch_mo_model.likelihood.noise_covar
        kwargs["train_Yvar"] = noise_covar.noise.clone().unsqueeze(-1)
    if isinstance(batch_mo_model, SingleTaskMultiFidelityGP):
        kwargs.update(batch_mo_model._init_args)
    single_outcome_model = batch_mo_model.__class__(
        input_transform=input_transform, **kwargs
    )
    # The batched state dict is loaded as is; no entries are reshaped.
    single_outcome_model.load_state_dict(batch_sd)
    return single_outcome_model.train(mode=was_training)
def _get_adjusted_batch_keys(
    batch_state_dict: Dict[str, Tensor],
    input_transform: Optional[InputTransform],
    outcome_transform: Optional[OutcomeTransform] = None,
) -> Tuple[Set[str], Set[str]]:
    r"""Split the state dict keys by whether their values need a batch shape change.

    Args:
        batch_state_dict: The state dict of the batch model.
        input_transform: The input transform.
        outcome_transform: The outcome transform.

    Returns:
        A two-element tuple containing:
        - The keys of the parameters/buffers that require a batch shape adjustment.
        - The keys of the parameters/buffers that do not require a batch shape
            adjustment.
    """
    # Every entry with at least one dimension is a candidate for an adjustment.
    adjusted_batch_keys = {
        name for name, value in batch_state_dict.items() if len(value.shape) > 0
    }
    # Transform buffers are not modified, so their (prefixed) keys are taken out
    # of the adjusted set and thereby end up in the non-adjusted set.
    prefixed_transforms = (
        ("input_transform.", input_transform),
        ("outcome_transform.", outcome_transform),
    )
    for prefix, transform in prefixed_transforms:
        if transform is None:
            continue
        adjusted_batch_keys -= {prefix + name for name in transform.state_dict()}
    # Whatever is left over does not need its batch shape adjusted.
    non_adjusted_batch_keys = set(batch_state_dict) - adjusted_batch_keys
    return adjusted_batch_keys, non_adjusted_batch_keys