#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
Utilities for converting between different models.
"""
from __future__ import annotations
from copy import deepcopy
import torch
from botorch.exceptions import UnsupportedError
from botorch.models.gp_regression import FixedNoiseGP, HeteroskedasticSingleTaskGP
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.model_list_gp_regression import ModelListGP
from torch.nn import Module
def _get_module(module: Module, name: str) -> Module:
"""Recursively get a sub-module from a module.
Args:
module: A `torch.nn.Module`.
name: The name of the submodule to return, in the form of a period-delinated
string: `sub_module.subsub_module.[...].leaf_module`.
Returns:
The requested sub-module.
Example:
>>> gp = SingleTaskGP(train_X, train_Y)
>>> noise_prior = _get_module(gp, "likelihood.noise_covar.noise_prior")
"""
current = module
if name != "":
for a in name.split("."):
current = getattr(current, a)
return current
def _check_compatibility(models: ModelListGP) -> None:
    """Check if a ModelListGP can be converted.

    Args:
        models: The sub-models to validate. The container must support
            indexing, slicing and iteration (`models[0]`, `models[1:]`).

    Raises:
        UnsupportedError: If the sub-modules of the models differ in type or
            are missing from some model, if a model is not a
            `BatchedMultiOutputGPyTorchModel`, if a model is not
            single-output, or if the training inputs differ across models.
        NotImplementedError: If any model is a `HeteroskedasticSingleTaskGP`
            or has its `_is_custom_likelihood` attribute set to a truthy value.
    """
    submodule_msg = "Sub-modules must be of the same type across models."

    # check that all submodules are of the same type
    for modn, mod in models[0].named_modules():
        mcls = mod.__class__
        for m in models[1:]:
            try:
                other = _get_module(m, modn)
            except AttributeError as err:
                # A sub-module that exists on the first model but not on this
                # one is a structural mismatch; report it as such instead of
                # leaking the raw attribute lookup failure.
                raise UnsupportedError(submodule_msg) from err
            if not isinstance(other, mcls):
                raise UnsupportedError(submodule_msg)

    # check that each model is a BatchedMultiOutputGPyTorchModel
    if not all(isinstance(m, BatchedMultiOutputGPyTorchModel) for m in models):
        raise UnsupportedError(
            "All models must be of type BatchedMultiOutputGPyTorchModel."
        )

    # TODO: Add support for HeteroskedasticSingleTaskGP
    if any(isinstance(m, HeteroskedasticSingleTaskGP) for m in models):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP is currently unsupported."
        )

    # TODO: Add support for custom likelihoods
    if any(getattr(m, "_is_custom_likelihood", False) for m in models):
        raise NotImplementedError(
            "Conversion of models with custom likelihoods is currently unsupported."
        )

    # check that each model is single-output
    if not all(m._num_outputs == 1 for m in models):
        raise UnsupportedError("All models must be single-output.")

    # check that training inputs are the same
    if not all(
        torch.equal(ti, tj)
        for m in models[1:]
        for ti, tj in zip(models[0].train_inputs, m.train_inputs)
    ):
        raise UnsupportedError("training inputs must agree for all sub-models.")
def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorchModel:
    """Convert a ModelListGP to a BatchedMultiOutputGPyTorchModel.

    Args:
        model_list: The `ModelListGP` to be converted to the appropriate
            `BatchedMultiOutputGPyTorchModel`. All sub-models must be of the same
            type and have the same shape (batch shape and number of training
            inputs).

    Returns:
        The model converted into a `BatchedMultiOutputGPyTorchModel`. If
        `model_list` holds a single model, a deep copy of that model is returned.

    Raises:
        UnsupportedError: If the fidelity parameters of
            `SingleTaskMultiFidelityGP` sub-models differ, if zero-dimensional
            state entries differ in value, or if tensor-valued state entries
            differ in shape across sub-models. `_check_compatibility` may
            raise further errors.

    Example:
        >>> list_gp = ModelListGP(gp1, gp2)
        >>> batch_gp = model_list_to_batched(list_gp)
    """
    models = model_list.models
    _check_compatibility(models)

    # if the list has only one model, we can just return a copy of that
    if len(models) == 1:
        return deepcopy(models[0])

    # construct inputs: shared training inputs, targets stacked along a new
    # trailing (output) dimension
    train_X = deepcopy(models[0].train_inputs[0])
    train_Y = torch.stack([m.train_targets.clone() for m in models], dim=-1)
    kwargs = {"train_X": train_X, "train_Y": train_Y}
    if isinstance(models[0], FixedNoiseGP):
        kwargs["train_Yvar"] = torch.stack(
            [m.likelihood.noise_covar.noise.clone() for m in models], dim=-1
        )
    if isinstance(models[0], SingleTaskMultiFidelityGP):
        init_args = models[0]._init_args
        if not all(
            v == m._init_args[k] for m in models[1:] for k, v in init_args.items()
        ):
            raise UnsupportedError("All models must have the same fidelity parameters.")
        kwargs.update(init_args)

    # construct the batched GP model
    batch_gp = models[0].__class__(**kwargs)
    # split the state entries into zero-dimensional ("scalars") and all others
    batch_gp_sd = batch_gp.state_dict()
    tensors = {n for n, p in batch_gp_sd.items() if len(p.shape) > 0}
    scalars = set(batch_gp_sd) - tensors
    input_batch_dims = len(models[0]._input_batch_shape)

    # ensure scalars agree (TODO: Allow different priors for different outputs)
    for n in scalars:
        v0 = _get_module(models[0], n)
        if not all(torch.equal(_get_module(m, n), v0) for m in models[1:]):
            raise UnsupportedError("All scalars must have the same value.")

    # ensure dimensions of all tensors agree
    for n in tensors:
        shape0 = _get_module(models[0], n).shape
        if not all(_get_module(m, n).shape == shape0 for m in models[1:]):
            raise UnsupportedError("All tensors must have the same shape.")

    # now construct the batched state dict; each model's state dict is built
    # once here rather than once per key
    model_sds = [m.state_dict() for m in models]
    scalar_state_dict = {
        s: p.clone() for s, p in model_sds[0].items() if s in scalars
    }
    # tensors are stacked right after the input batch dimensions, except for
    # "active_dims" entries, which are taken from the first model unchanged
    tensor_state_dict = {
        t: (
            torch.stack([sd[t].clone() for sd in model_sds], dim=input_batch_dims)
            if "active_dims" not in t
            else model_sds[0][t].clone()
        )
        for t in tensors
    }
    batch_state_dict = {**scalar_state_dict, **tensor_state_dict}

    # load the state dict into the new model
    batch_gp.load_state_dict(batch_state_dict)

    return batch_gp
def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> ModelListGP:
    """Convert a BatchedMultiOutputGPyTorchModel to a ModelListGP.

    Args:
        batch_model: The `BatchedMultiOutputGPyTorchModel` to be converted to a
            `ModelListGP`.

    Returns:
        The model converted into a `ModelListGP`, containing one model of the
        same class as `batch_model` per output.

    Raises:
        NotImplementedError: If `batch_model` is a `HeteroskedasticSingleTaskGP`.

    Example:
        >>> train_X = torch.rand(5, 2)
        >>> train_Y = torch.rand(5, 2)
        >>> batch_gp = SingleTaskGP(train_X, train_Y)
        >>> list_gp = batched_to_model_list(batch_gp)
    """
    # TODO: Add support for HeteroskedasticSingleTaskGP
    if isinstance(batch_model, HeteroskedasticSingleTaskGP):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP currently not supported."
        )

    # split the state entries into zero-dimensional ("scalars") and all others
    batch_sd = batch_model.state_dict()
    tensors = {n for n, p in batch_sd.items() if len(p.shape) > 0}
    scalars = set(batch_sd) - tensors
    input_bdims = len(batch_model._input_batch_shape)

    # loop-invariant lookups
    model_cls = batch_model.__class__
    is_fixed_noise = isinstance(batch_model, FixedNoiseGP)
    is_multi_fidelity = isinstance(batch_model, SingleTaskMultiFidelityGP)
    train_X = batch_model.train_inputs[0]
    train_Y = batch_model.train_targets

    # NOTE(review): every `select(input_bdims, i)` below assumes the batched
    # tensors carry a per-output dimension at index `input_bdims`. Confirm
    # that this also holds when `_num_outputs == 1`.
    models = []
    for i in range(batch_model._num_outputs):
        scalar_sd = {s: batch_sd[s].clone() for s in scalars}
        # slice out the i-th output of each tensor; "active_dims" entries are
        # copied whole instead of being sliced
        tensor_sd = {
            t: (
                batch_sd[t].select(input_bdims, i).clone()
                if "active_dims" not in t
                else batch_sd[t].clone()
            )
            for t in tensors
        }
        sd = {**scalar_sd, **tensor_sd}

        kwargs = {
            "train_X": train_X.select(input_bdims, i).clone(),
            # restore the trailing output dimension on the targets
            "train_Y": train_Y.select(input_bdims, i).clone().unsqueeze(-1),
        }
        if is_fixed_noise:
            noise_covar = batch_model.likelihood.noise_covar
            kwargs["train_Yvar"] = (
                noise_covar.noise.select(input_bdims, i).clone().unsqueeze(-1)
            )
        if is_multi_fidelity:
            kwargs.update(batch_model._init_args)

        model = model_cls(**kwargs)
        model.load_state_dict(sd)
        models.append(model)

    return ModelListGP(*models)