#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Utilities for fitting and manipulating models."""
from __future__ import annotations

from re import Pattern
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)
from warnings import warn

import torch
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, TensorDataset


class TorchAttr(NamedTuple):
    r"""Lightweight record of a tensor's shape, dtype, and device."""

    shape: torch.Size
    dtype: torch.dtype
    device: torch.device


def _get_extra_mll_args(
    mll: MarginalLogLikelihood,
) -> Union[List[Tensor], List[List[Tensor]]]:
    r"""Obtain extra arguments for MarginalLogLikelihood objects.

    Get extra arguments (beyond the model output and training targets) required
    for the particular type of MarginalLogLikelihood for a forward pass.

    Args:
        mll: The MarginalLogLikelihood module.

    Returns:
        Extra arguments for the MarginalLogLikelihood.
        Returns an empty list if the mll type is unknown.
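
    Example (illustrative sketch only; assumes a fitted
    `ExactMarginalLogLikelihood` named `mll`, not defined here):
        >>> extra = _get_extra_mll_args(mll)  # the model's training inputs
        >>> output = mll.model(*mll.model.train_inputs)
        >>> loss = -mll(output, mll.model.train_targets, *extra)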
"""
warn("`_get_extra_mll_args` is marked for deprecation.", DeprecationWarning)
if isinstance(mll, ExactMarginalLogLikelihood):
return list(mll.model.train_inputs)
elif isinstance(mll, SumMarginalLogLikelihood):
return [list(x) for x in mll.model.train_inputs]
return []


def get_data_loader(
    model: GPyTorchModel, batch_size: int = 1024, **kwargs: Any
) -> DataLoader:
    r"""Wrap a model's training data in a torch `DataLoader`.

    Args:
        model: A GPyTorchModel whose `train_inputs` and `train_targets`
            define the dataset.
        batch_size: Size of each mini-batch, capped at the number of
            training examples.
        kwargs: Additional keyword arguments passed to the `DataLoader`.

    Returns:
        A `DataLoader` yielding `(*train_inputs, train_targets)` batches.
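
    Example (illustrative sketch only; assumes a fitted `SingleTaskGP`
    named `model`, not defined here):
        >>> loader = get_data_loader(model, batch_size=256, shuffle=True)
        >>> for *inputs, targets in loader:
        ...     pass  # e.g. accumulate a mini-batch loss here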
    """
    dataset = TensorDataset(*model.train_inputs, model.train_targets)
    return DataLoader(
        dataset=dataset, batch_size=min(batch_size, len(model.train_targets)), **kwargs
    )


def get_parameters(
    module: Module,
    requires_grad: Optional[bool] = None,
    name_filter: Optional[Callable[[str], bool]] = None,
) -> Dict[str, Tensor]:
    r"""Helper method for obtaining a module's parameters.

    Args:
        module: The target module from which parameters are to be extracted.
        requires_grad: Optional Boolean used to filter parameters based on
            whether or not their `requires_grad` attribute matches the user
            provided value.
        name_filter: Optional Boolean function used to filter parameters by name.

    Returns:
        A dictionary mapping parameter names to parameters.
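
    Example (illustrative sketch only; assumes a fitted `SingleTaskGP`
    named `model`, not defined here):
        >>> # Trainable parameters, excluding any whose name mentions "noise".
        >>> params = get_parameters(
        ...     model,
        ...     requires_grad=True,
        ...     name_filter=lambda name: "noise" not in name,
        ... )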
"""
parameters = {}
for name, param in module.named_parameters():
if requires_grad is not None and param.requires_grad != requires_grad:
continue
if name_filter and not name_filter(name):
continue
parameters[name] = param
return parameters


def get_parameters_and_bounds(
    module: Module,
    requires_grad: Optional[bool] = None,
    name_filter: Optional[Callable[[str], bool]] = None,
    default_bounds: Tuple[float, float] = (-float("inf"), float("inf")),
) -> Tuple[Dict[str, Tensor], Dict[str, Tuple[Optional[float], Optional[float]]]]:
    r"""Helper method for obtaining a module's parameters and their respective
    ranges.

    Args:
        module: The target module from which parameters are to be extracted.
        requires_grad: Optional Boolean used to filter parameters based on
            whether or not their `requires_grad` attribute matches the user
            provided value.
        name_filter: Optional Boolean function used to filter parameters by name.
        default_bounds: Default lower and upper bounds for constrained parameters
            with `None` typed bounds.

    Returns:
        A dictionary of parameters and a dictionary of parameter bounds.
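
    Example (illustrative sketch only; assumes a fitted `SingleTaskGP`
    named `model`, not defined here):
        >>> params, bounds = get_parameters_and_bounds(model)
        >>> # `bounds` maps constrained parameter names to (lower, upper)
        >>> # tuples expressed in the parameters' unconstrained (raw) space.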
"""
if hasattr(module, "named_parameters_and_constraints"):
bounds = {}
params = {}
for name, param, constraint in module.named_parameters_and_constraints():
if (requires_grad is None or (param.requires_grad == requires_grad)) and (
name_filter is None or name_filter(name)
):
params[name] = param
if constraint is None:
continue
bounds[name] = tuple(
default if bound is None else constraint.inverse_transform(bound)
for (bound, default) in zip(constraint, default_bounds)
)
return params, bounds
params = get_parameters(
module, requires_grad=requires_grad, name_filter=name_filter
)
return params, {}


def get_name_filter(
    patterns: Iterator[Union[Pattern, str]]
) -> Callable[[Union[str, Tuple[str, Any, ...]]], bool]:
    r"""Returns a binary function that filters strings (or iterables whose first
    element is a string) according to a bank of excluded patterns. Typically used
    in conjunction with generators such as `module.named_parameters()`.

    Args:
        patterns: A collection of regular expressions or strings that
            define the set of names to be excluded.

    Returns:
        A binary function indicating whether or not an item should be filtered.
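
    Example (illustrative sketch only; assumes a fitted `SingleTaskGP`
    named `model`, not defined here):
        >>> import re
        >>> # Exclude the raw noise parameter by exact name and anything
        >>> # matching a lengthscale pattern; keep everything else.
        >>> keep = get_name_filter(
        ...     ["likelihood.noise_covar.raw_noise", re.compile(r"lengthscale")]
        ... )
        >>> params = get_parameters(model, name_filter=keep)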
"""
names = set()
_patterns = set()
for pattern in patterns:
if isinstance(pattern, str):
names.add(pattern)
elif isinstance(pattern, Pattern):
_patterns.add(pattern)
else:
raise TypeError(
"Expected `patterns` to contain `str` or `re.Pattern` typed elements, "
f"but found {type(pattern)}."
)
def name_filter(item: Union[str, Tuple[str, Any, ...]]) -> bool:
name = item if isinstance(item, str) else next(iter(item))
if name in names:
return False
for pattern in _patterns:
if pattern.search(name):
return False
return True
return name_filter


def sample_all_priors(model: GPyTorchModel, max_retries: int = 100) -> None:
    r"""Sample from hyperparameter priors (in-place).

    Args:
        model: A GPyTorchModel.
        max_retries: The maximum number of attempts at sampling a value that
            satisfies a parameter's current constraints.
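
    Example (illustrative sketch only; assumes a fitted `SingleTaskGP` named
    `model` whose hyperparameters carry priors, not defined here):
        >>> sample_all_priors(model)  # re-initializes hyperparameters in-place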
"""
for _, module, prior, closure, setting_closure in model.named_priors():
if setting_closure is None:
raise RuntimeError(
"Must provide inverse transform to be able to sample from prior."
)
for i in range(max_retries):
try:
setting_closure(module, prior.sample(closure(module).shape))
break
except NotImplementedError:
warn(
f"`rsample` not implemented for {type(prior)}. Skipping.",
BotorchWarning,
)
break
except RuntimeError as e:
if "out of bounds of its current constraints" in str(e):
if i == max_retries - 1:
raise RuntimeError(
"Failed to sample a feasible parameter value "
f"from the prior after {max_retries} attempts."
)
else:
raise e


def allclose_mll(
    a: MarginalLogLikelihood,
    b: MarginalLogLikelihood,
    transform_a: Optional[Callable[[Tensor], Tensor]] = None,
    transform_b: Optional[Callable[[Tensor], Tensor]] = None,
    rtol: float = 1e-05,
    atol: float = 1e-08,
) -> bool:
    r"""Convenience method for testing whether the log likelihoods produced by
    different MarginalLogLikelihood instances, when evaluated on their respective
    models' training sets, are allclose.

    Args:
        a: A MarginalLogLikelihood instance.
        b: A second MarginalLogLikelihood instance.
        transform_a: Optional callable used to post-transform log likelihoods
            under `a`.
        transform_b: Optional callable used to post-transform log likelihoods
            under `b`.
        rtol: Relative tolerance.
        atol: Absolute tolerance.

    Returns:
        Boolean result of the allclose test.
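
    Example (illustrative sketch only; assumes fitted
    `ExactMarginalLogLikelihood` instances `mll_a` and `mll_b` over the same
    training data, not defined here):
        >>> same = allclose_mll(mll_a, mll_b, rtol=1e-4)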
"""
warn("`allclose_mll` is marked for deprecation.", DeprecationWarning)
values_a = a(
a.model(*a.model.train_inputs),
a.model.train_targets,
*_get_extra_mll_args(a),
)
if transform_a:
values_a = transform_a(values_a)
values_b = b(
b.model(*b.model.train_inputs),
b.model.train_targets,
*_get_extra_mll_args(b),
)
if transform_b:
values_b = transform_b(values_b)
return values_a.allclose(values_b, rtol=rtol, atol=atol)