#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""Utilities for fitting and manipulating models."""
from __future__ import annotations
from re import Pattern
from typing import Any, Callable, Dict, Iterator, NamedTuple, Optional, Tuple, Union
from warnings import warn
import torch
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.gpytorch import GPyTorchModel
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, TensorDataset
[docs]class TorchAttr(NamedTuple):
shape: torch.Size
dtype: torch.dtype
device: torch.device
[docs]def get_data_loader(
model: GPyTorchModel, batch_size: int = 1024, **kwargs: Any
) -> DataLoader:
dataset = TensorDataset(*model.train_inputs, model.train_targets)
return DataLoader(
dataset=dataset, batch_size=min(batch_size, len(model.train_targets)), **kwargs
)
[docs]def get_parameters(
module: Module,
requires_grad: Optional[bool] = None,
name_filter: Optional[Callable[[str], bool]] = None,
) -> Dict[str, Tensor]:
r"""Helper method for obtaining a module's parameters and their respective ranges.
Args:
module: The target module from which parameters are to be extracted.
requires_grad: Optional Boolean used to filter parameters based on whether
or not their require_grad attribute matches the user provided value.
name_filter: Optional Boolean function used to filter parameters by name.
Returns:
A dictionary of parameters.
"""
parameters = {}
for name, param in module.named_parameters():
if requires_grad is not None and param.requires_grad != requires_grad:
continue
if name_filter and not name_filter(name):
continue
parameters[name] = param
return parameters
[docs]def get_parameters_and_bounds(
module: Module,
requires_grad: Optional[bool] = None,
name_filter: Optional[Callable[[str], bool]] = None,
default_bounds: Tuple[float, float] = (-float("inf"), float("inf")),
) -> Tuple[Dict[str, Tensor], Dict[str, Tuple[Optional[float], Optional[float]]]]:
r"""Helper method for obtaining a module's parameters and their respective ranges.
Args:
module: The target module from which parameters are to be extracted.
name_filter: Optional Boolean function used to filter parameters by name.
requires_grad: Optional Boolean used to filter parameters based on whether
or not their require_grad attribute matches the user provided value.
default_bounds: Default lower and upper bounds for constrained parameters
with `None` typed bounds.
Returns:
A dictionary of parameters and a dictionary of parameter bounds.
"""
if hasattr(module, "named_parameters_and_constraints"):
bounds = {}
params = {}
for name, param, constraint in module.named_parameters_and_constraints():
if (requires_grad is None or (param.requires_grad == requires_grad)) and (
name_filter is None or name_filter(name)
):
params[name] = param
if constraint is None:
continue
bounds[name] = tuple(
default if bound is None else constraint.inverse_transform(bound)
for (bound, default) in zip(constraint, default_bounds)
)
return params, bounds
params = get_parameters(
module, requires_grad=requires_grad, name_filter=name_filter
)
return params, {}
[docs]def get_name_filter(
patterns: Iterator[Union[Pattern, str]]
) -> Callable[[Union[str, Tuple[str, Any, ...]]], bool]:
r"""Returns a binary function that filters strings (or iterables whose first
element is a string) according to a bank of excluded patterns. Typically, used
in conjunction with generators such as `module.named_parameters()`.
Args:
patterns: A collection of regular expressions or strings that
define the set of names to be excluded.
Returns:
A binary function indicating whether or not an item should be filtered.
"""
names = set()
_patterns = set()
for pattern in patterns:
if isinstance(pattern, str):
names.add(pattern)
elif isinstance(pattern, Pattern):
_patterns.add(pattern)
else:
raise TypeError(
"Expected `patterns` to contain `str` or `re.Pattern` typed elements, "
f"but found {type(pattern)}."
)
def name_filter(item: Union[str, Tuple[str, Any, ...]]) -> bool:
name = item if isinstance(item, str) else next(iter(item))
if name in names:
return False
for pattern in _patterns:
if pattern.search(name):
return False
return True
return name_filter
[docs]def sample_all_priors(model: GPyTorchModel, max_retries: int = 100) -> None:
r"""Sample from hyperparameter priors (in-place).
Args:
model: A GPyTorchModel.
"""
for _, module, prior, closure, setting_closure in model.named_priors():
if setting_closure is None:
raise RuntimeError(
"Must provide inverse transform to be able to sample from prior."
)
for i in range(max_retries):
try:
setting_closure(module, prior.sample(closure(module).shape))
break
except NotImplementedError:
warn(
f"`rsample` not implemented for {type(prior)}. Skipping.",
BotorchWarning,
)
break
except RuntimeError as e:
if "out of bounds of its current constraints" in str(e):
if i == max_retries - 1:
raise RuntimeError(
"Failed to sample a feasible parameter value "
f"from the prior after {max_retries} attempts."
)
else:
raise e