from __future__ import annotations

from functools import lru_cache
from numbers import Number
from typing import Iterator, Optional, Tuple, Union

import torch
from torch import Tensor

@lru_cache(maxsize=None, typed=True)
def _get_constant(
    value: Number,
    device: Optional[torch.device],
    dtype: Optional[torch.dtype],
) -> Tensor:
    r"""Returns a cached scalar-valued Tensor holding a single constant."""
    # `typed=True` matters: `1`, `1.0` and `True` compare and hash equal, so an
    # untyped cache would hand back whichever was requested first. With
    # `dtype=None`, `torch.full` infers the dtype from the fill value, so that
    # collision would silently return a tensor of the wrong dtype.
    return torch.full((), value, dtype=dtype, device=device)


def get_constants(
    values: Union[Number, Iterator[Number]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Union[Tensor, Tuple[Tensor, ...]]:
    r"""Returns scalar-valued Tensors containing each of the given constants.

    Used to expedite tensor operations involving scalar arithmetic. Each
    constant is cached individually, keyed on its value *and* Python type
    together with `device` and `dtype`, so repeated requests return the very
    same Tensor object. Note that the returned Tensors should therefore not be
    modified in-place.

    When `device` or `dtype` is None, the cached Tensor reflects whatever
    defaults `torch.full` applied at the time of the first call.

    Args:
        values: A single number, or an iterable of numbers (e.g. a tuple, list
            or generator). Every number must be hashable.
        device: Device on which to create the Tensors; passed to `torch.full`.
        dtype: Data type of the Tensors; passed to `torch.full`.

    Returns:
        A zero-dimensional Tensor if `values` is a single number, otherwise a
        tuple containing one zero-dimensional Tensor per element of `values`.
    """
    if isinstance(values, Number):
        return _get_constant(values, device, dtype)
    # Caching happens per element rather than per collection, which keeps
    # `(1, 2)` and `(1.0, 2.0)` distinct and also admits unhashable iterables.
    return tuple(_get_constant(val, device, dtype) for val in values)


# Keep the cache-management hooks reachable under the public name, as they were
# when `get_constants` itself was wrapped by `lru_cache`.
get_constants.cache_clear = _get_constant.cache_clear
get_constants.cache_info = _get_constant.cache_info
def get_constants_like(
    values: Union[Number, Iterator[Number]],
    ref: Tensor,
) -> Union[Tensor, Tuple[Tensor, ...]]:
    r"""Returns scalar-valued Tensors for the given constants, matching `ref`.

    Convenience wrapper around `get_constants` that takes the target device
    and dtype from a reference Tensor instead of explicit arguments. The
    results come from `get_constants` and, as noted there, should not be
    modified in-place.

    Args:
        values: A single number or a collection of numbers, forwarded
            unchanged to `get_constants`.
        ref: Tensor whose `device` and `dtype` the constants should share.

    Returns:
        Whatever `get_constants` returns for these arguments: a single Tensor
        for a single number, otherwise a tuple of Tensors.
    """
    return get_constants(values, device=ref.device, dtype=ref.dtype)