Source code for botorch.optim.utils.model_utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Utilities for fitting and manipulating models."""

from __future__ import annotations

from re import Pattern
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)
from warnings import warn

import torch
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, TensorDataset


class TorchAttr(NamedTuple):
    r"""Record of torch attributes: a `shape`, a `dtype` and a `device`.

    Presumably used to capture these attributes from a tensor so they can be
    restored or compared later (not shown in this module).
    """

    shape: torch.Size  # the recorded size
    dtype: torch.dtype  # the recorded element data type
    device: torch.device  # the recorded device
def _get_extra_mll_args(
    mll: MarginalLogLikelihood,
) -> Union[List[Tensor], List[List[Tensor]]]:
    r"""Obtain extra arguments for MarginalLogLikelihood objects.

    Get extra arguments (beyond the model output and training targets) required
    for the particular type of MarginalLogLikelihood for a forward pass.

    This helper is marked for deprecation and emits a `DeprecationWarning` on
    every call.

    Args:
        mll: The MarginalLogLikelihood module.

    Returns:
        Extra arguments for the MarginalLogLikelihood:
        - for an `ExactMarginalLogLikelihood`, a list of the tensors in
          `mll.model.train_inputs`;
        - for a `SumMarginalLogLikelihood`, one list per element of
          `mll.model.train_inputs` (presumably one per sub-model);
        - an empty list if the mll type is unknown.
    """
    # stacklevel=2 attributes the warning to the caller's line instead of this
    # module. Python's default filters then show it when it is called from
    # `__main__`.
    warn(
        "`_get_extra_mll_args` is marked for deprecation.",
        DeprecationWarning,
        stacklevel=2,
    )
    if isinstance(mll, ExactMarginalLogLikelihood):
        return list(mll.model.train_inputs)
    elif isinstance(mll, SumMarginalLogLikelihood):
        return [list(x) for x in mll.model.train_inputs]
    return []
def get_data_loader(
    model: GPyTorchModel, batch_size: int = 1024, **kwargs: Any
) -> DataLoader:
    r"""Wrap a model's training data in a `DataLoader`.

    Args:
        model: A model exposing `train_inputs` (unpacked positionally into a
            `TensorDataset`) and `train_targets` (appended as the last tensor).
        batch_size: Requested mini-batch size. It is capped at
            `len(model.train_targets)`.
        **kwargs: Extra keyword arguments forwarded verbatim to `DataLoader`.

    Returns:
        A `DataLoader` over `TensorDataset(*model.train_inputs,
        model.train_targets)`.
    """
    train_targets = model.train_targets
    dataset = TensorDataset(*model.train_inputs, train_targets)
    # Never request a batch larger than the number of training targets.
    capped_batch_size = min(batch_size, len(train_targets))
    return DataLoader(dataset=dataset, batch_size=capped_batch_size, **kwargs)
def get_parameters(
    module: Module,
    requires_grad: Optional[bool] = None,
    name_filter: Optional[Callable[[str], bool]] = None,
) -> Dict[str, Tensor]:
    r"""Collect a module's named parameters, optionally filtered.

    Args:
        module: The module whose `named_parameters()` are inspected.
        requires_grad: If not `None`, keep only parameters whose
            `requires_grad` attribute equals this value.
        name_filter: Optional predicate on parameter names. Parameters whose
            name it rejects are dropped.

    Returns:
        A dictionary mapping parameter names to the parameter tensors, in the
        order produced by `module.named_parameters()`.
    """
    return {
        name: param
        for name, param in module.named_parameters()
        if (requires_grad is None or param.requires_grad == requires_grad)
        and (not name_filter or name_filter(name))
    }
def get_parameters_and_bounds(
    module: Module,
    requires_grad: Optional[bool] = None,
    name_filter: Optional[Callable[[str], bool]] = None,
    default_bounds: Tuple[float, float] = (-float("inf"), float("inf")),
) -> Tuple[Dict[str, Tensor], Dict[str, Tuple[Optional[float], Optional[float]]]]:
    r"""Collect a module's parameters together with bounds from their constraints.

    If `module` has a `named_parameters_and_constraints` method, the function
    iterates its `(name, param, constraint)` triples. For every selected
    parameter that has a constraint, the constraint is iterated alongside
    `default_bounds`. Each bound is replaced by its default when it is `None`,
    and otherwise mapped through `constraint.inverse_transform`. Modules without
    that method fall back to `get_parameters` and report no bounds.

    Args:
        module: The target module from which parameters are to be extracted.
        requires_grad: Optional Boolean used to filter parameters based on
            whether or not their require_grad attribute matches the user
            provided value.
        name_filter: Optional Boolean function used to filter parameters by name.
        default_bounds: Default lower and upper bounds for constrained
            parameters with `None` typed bounds.

    Returns:
        A dictionary of parameters and a dictionary of parameter bounds. The
        latter only contains names of selected parameters that have a
        constraint.
    """
    if not hasattr(module, "named_parameters_and_constraints"):
        # No constraint information available: plain parameters, no bounds.
        plain = get_parameters(
            module, requires_grad=requires_grad, name_filter=name_filter
        )
        return plain, {}

    selected = {}
    limits = {}
    for pname, tensor, constraint in module.named_parameters_and_constraints():
        if requires_grad is not None and tensor.requires_grad != requires_grad:
            continue
        if name_filter is not None and not name_filter(pname):
            continue
        selected[pname] = tensor
        if constraint is None:
            continue
        limits[pname] = tuple(
            fallback if raw is None else constraint.inverse_transform(raw)
            for raw, fallback in zip(constraint, default_bounds)
        )
    return selected, limits
def get_name_filter(
    patterns: Iterator[Union[Pattern, str]]
) -> Callable[[Union[str, Tuple[Any, ...]]], bool]:
    r"""Returns a binary function that filters strings (or iterables whose first
    element is a string) according to a bank of excluded patterns. Typically, used
    in conjunction with generators such as `module.named_parameters()`.

    Args:
        patterns: A collection of regular expressions or strings that define the
            set of names to be excluded. Strings are matched exactly against the
            name. Compiled patterns are matched with `pattern.search`, i.e.
            anywhere in the name. The collection is consumed once.

    Returns:
        A binary function indicating whether or not an item should be filtered.
        It returns `False` for excluded names and `True` otherwise. For
        non-string items, the name is taken to be the item's first element.

    Raises:
        TypeError: If `patterns` contains an element that is neither a `str`
            nor an `re.Pattern`.
    """
    names = set()
    _patterns = set()
    for pattern in patterns:
        if isinstance(pattern, str):
            names.add(pattern)
        elif isinstance(pattern, Pattern):
            _patterns.add(pattern)
        else:
            raise TypeError(
                "Expected `patterns` to contain `str` or `re.Pattern` typed elements, "
                f"but found {type(pattern)}."
            )

    # `Tuple[Any, ...]` is the valid spelling of a variable-length tuple. The
    # first element is expected to be the name, as in `named_parameters()`.
    def name_filter(item: Union[str, Tuple[Any, ...]]) -> bool:
        name = item if isinstance(item, str) else next(iter(item))
        if name in names:
            return False
        return not any(pattern.search(name) for pattern in _patterns)

    return name_filter
def sample_all_priors(model: GPyTorchModel, max_retries: int = 100) -> None:
    r"""Sample from hyperparameter priors (in-place).

    For every entry of `model.named_priors()`, a sample is drawn from the prior.
    The sample shape passed to `prior.sample` is the shape of the value returned
    by the prior's closure. The sample is written back with the prior's setting
    closure. Values rejected for being out of bounds of the parameter's
    constraints are redrawn, up to `max_retries` attempts per prior. Priors
    whose sampling raises `NotImplementedError` are skipped with a warning.

    Args:
        model: A GPyTorchModel.
        max_retries: Maximum number of sampling attempts per prior when the
            sampled value violates the parameter's constraints. If less than 1,
            no sampling is attempted.

    Raises:
        RuntimeError: If a prior has no setting closure, or if no feasible value
            was found within `max_retries` attempts (chained from the last
            constraint error). Any other `RuntimeError` raised while sampling
            or setting a value is re-raised unchanged.
    """
    for _, module, prior, closure, setting_closure in model.named_priors():
        if setting_closure is None:
            raise RuntimeError(
                "Must provide inverse transform to be able to sample from prior."
            )
        for i in range(max_retries):
            try:
                setting_closure(module, prior.sample(closure(module).shape))
                break
            except NotImplementedError:
                # Priors without a sampling implementation are skipped rather
                # than treated as fatal.
                warn(
                    f"`rsample` not implemented for {type(prior)}. Skipping.",
                    BotorchWarning,
                )
                break
            except RuntimeError as e:
                # Only constraint violations are retried; anything else
                # propagates with its original traceback.
                if "out of bounds of its current constraints" not in str(e):
                    raise
                if i == max_retries - 1:
                    raise RuntimeError(
                        "Failed to sample a feasible parameter value "
                        f"from the prior after {max_retries} attempts."
                    ) from e