#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
r"""
A converter that simplifies using numpy-based optimizers with generic torch
`nn.Module` classes. This enables using a `scipy.optimize.minimize` optimizer
for optimizing module parameters.
"""
from collections import OrderedDict
from math import inf
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
import numpy as np
import torch
from torch.nn import Module
# Maps a parameter name to a (lower, upper) bound pair. A `None` entry on
# either side means that side is unbounded.
ParameterBounds = Dict[
    str,
    Tuple[Optional[float], Optional[float]],
]
class TorchAttr(NamedTuple):
    r"""Tensor attributes needed to rebuild a parameter from a flat array.

    Recorded per parameter by `module_to_array` and consumed by
    `set_params_with_array`.
    """

    # Original (unflattened) shape of the parameter tensor.
    shape: torch.Size
    # Element dtype of the parameter tensor.
    dtype: torch.dtype
    # Device on which the parameter tensor lives.
    device: torch.device
def module_to_array(
    module: Module,
    bounds: Optional[ParameterBounds] = None,
    exclude: Optional[Set[str]] = None,
) -> Tuple[np.ndarray, Dict[str, TorchAttr], Optional[np.ndarray]]:
    r"""Extract named parameters from a module into a numpy array.

    Only extracts parameters with requires_grad, since it is meant for optimizing.

    Args:
        module: A module with parameters. May specify parameter constraints in
            a `named_parameters_and_constraints` method.
        bounds: A ParameterBounds dictionary mapping parameter names to tuples
            of lower and upper bounds. Bounds specified here take precedence
            over bounds on the same parameters specified in the constraints
            registered with the module.
        exclude: A collection (e.g. a set) of parameter names that are to be
            excluded from extraction.

    Returns:
        3-element tuple containing
        - The parameter values as a 1-d float64 numpy array (empty if no
          parameter is extracted).
        - An ordered dictionary with the name and tensor attributes of each
          parameter.
        - A `2 x n_params` numpy array with lower and upper bounds if at least
          one constraint is finite, and None otherwise.

    Example:
        >>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
        >>> parameter_array, property_dict, bounds_out = module_to_array(mll)
    """
    x: List[np.ndarray] = []
    lower: List[np.ndarray] = []
    upper: List[np.ndarray] = []
    property_dict = OrderedDict()
    exclude = set() if exclude is None else exclude

    # Collect bounds from constraints registered on the module (if any). Only
    # constraints whose `enforced` attribute is falsy are turned into bounds.
    bounds_: ParameterBounds = {}
    if hasattr(module, "named_parameters_and_constraints"):
        for param_name, _, constraint in module.named_parameters_and_constraints():
            if constraint is not None and not constraint.enforced:
                bounds_[param_name] = constraint.lower_bound, constraint.upper_bound
    # User-supplied bounds overwrite module constraints on the same parameter.
    if bounds is not None:
        bounds_.update(bounds)

    for p_name, t in module.named_parameters():
        if p_name in exclude or not t.requires_grad:
            continue
        property_dict[p_name] = TorchAttr(
            shape=t.shape, dtype=t.dtype, device=t.device
        )
        # `reshape` (unlike `view`) also works for non-contiguous parameters.
        # `clone` guarantees the array never aliases the parameter's storage:
        # `cpu()` / `double()` return `self` for CPU float64 tensors.
        x.append(t.detach().reshape(-1).cpu().double().clone().numpy())
        # Per-parameter bounds are only tracked if any bounds exist at all.
        if bounds_:
            l_, u_ = bounds_.get(p_name, (-inf, inf))
            if torch.is_tensor(l_):
                l_ = l_.cpu().detach()
            if torch.is_tensor(u_):
                u_ = u_.cpu().detach()
            # Check for Nones here because they may be passed in manually in
            # `bounds`; None means "unbounded" on that side.
            lower.append(np.full(t.nelement(), l_ if l_ is not None else -inf))
            upper.append(np.full(t.nelement(), u_ if u_ is not None else inf))

    # `np.concatenate` raises on an empty list, so a module without any
    # trainable (non-excluded) parameters yields an empty array instead.
    x_out = np.concatenate(x) if x else np.zeros(0, dtype=np.float64)
    bounds_out = None
    # Only return bounds if at least one of them is finite.
    if bounds_ and not all(np.isinf(b).all() for lu in (lower, upper) for b in lu):
        bounds_out = np.stack([np.concatenate(lower), np.concatenate(upper)])
    return x_out, property_dict, bounds_out
def set_params_with_array(
    module: Module, x: np.ndarray, property_dict: Dict[str, TorchAttr]
) -> Module:
    r"""Set module parameters with values from numpy array.

    Args:
        module: Module with parameters to be set.
        x: Numpy array with parameter values, laid out consecutively in the
            iteration order of `property_dict`, each parameter flattened.
        property_dict: Dictionary of parameter names and torch attributes as
            returned by module_to_array.

    Returns:
        Module: module with parameters updated in-place.

    Example:
        >>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
        >>> parameter_array, property_dict, bounds_out = module_to_array(mll)
        >>> parameter_array += 0.1  # perturb parameters (for example only)
        >>> mll = set_params_with_array(mll, parameter_array, property_dict)
    """
    param_dict = OrderedDict(module.named_parameters())
    start_idx = 0
    for p_name, attrs in property_dict.items():
        # Number of entries of this parameter; np.prod(()) == 1.0 covers
        # scalar (0-d) parameters, hence the int() cast.
        end_idx = start_idx + int(np.prod(attrs.shape))
        # `torch.tensor` always copies, so the result never aliases `x`.
        new_data = torch.tensor(
            x[start_idx:end_idx], dtype=attrs.dtype, device=attrs.device
        ).reshape(attrs.shape)
        start_idx = end_idx
        # Update the parameter in-place. `no_grad` permits the in-place write
        # on a leaf that requires grad and, unlike toggling `requires_grad`,
        # cannot leave the parameter frozen if `copy_` raises.
        with torch.no_grad():
            param_dict[p_name].copy_(new_data)
    return module