Source code for botorch.utils.dispatcher
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from collections.abc import Callable
from inspect import getsource, getsourcefile
from typing import Any
from multipledispatch.dispatcher import (
Dispatcher as MDDispatcher,
MDNotImplementedError, # trivial subclass of NotImplementedError
str_signature,
)
[docs]
def type_bypassing_encoder(arg: Any) -> type:
    r"""Encodes an argument as a type, passing type objects through unchanged.

    Args:
        arg: Either a type (treated as a pre-encoded argument) or an arbitrary
            object whose type should be used.

    Returns:
        `arg` itself when it is already a type, otherwise `type(arg)`.
    """
    if isinstance(arg, type):
        # Type objects are forwarded as-is so callers may pass pre-encoded keys.
        return arg
    return type(arg)
[docs]
class Dispatcher(MDDispatcher):
    r"""Clearing house for multiple dispatch functionality. This class extends
    `<multipledispatch.Dispatcher>` by: (i) generalizing the argument encoding
    convention during method lookup, (ii) implementing `__getitem__` as a dedicated
    method lookup function.
    """

    def __init__(
        self,
        name: str,
        doc: str | None = None,
        encoder: Callable[[Any], type] = type,
    ) -> None:
        """
        Args:
            name: A string identifier for the `Dispatcher` instance.
            doc: A docstring for the multiply dispatched method(s).
            encoder: A callable that individually transforms the arguments passed
                at runtime in order to construct the key used for method lookup as
                `tuple(map(encoder, args))`. Defaults to `type`.
        """
        super().__init__(name=name, doc=doc)
        self._encoder = encoder

    def __getitem__(
        self,
        args: Any | None = None,
        types: tuple[type, ...] | None = None,
    ) -> Callable:
        r"""Method lookup.

        Exactly one of `args` or `types` must be given. Successful lookups are
        memoized in `self._cache`, keyed by the tuple of types.

        Args:
            args: A set of arguments that act as identifiers for a stored method.
            types: A tuple of types that encodes `args`.

        Returns:
            A callable corresponding to the given `args` or `types`.

        Raises:
            RuntimeError: If neither or both of `args` and `types` are provided.
            NotImplementedError: If no registered method matches the signature.
        """
        if types is None:
            if args is None:
                raise RuntimeError("One of `args` or `types` must be provided.")
            types = self.encode_args(args)
        elif args is not None:
            raise RuntimeError("Only one of `args` or `types` may be provided.")

        try:
            func = self._cache[types]
        except KeyError:
            func = self.dispatch(*types)
            if not func:
                # Render the signature as `<A, B>`; the closing bracket keeps the
                # message balanced and readable.
                msg = f"{self.name}: <{', '.join(cls.__name__ for cls in types)}>"
                raise NotImplementedError(f"Could not find signature for {msg}")
            self._cache[types] = func
        return func

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        r"""Multiply dispatches a call to a collection of methods.

        Args:
            args: A set of arguments that act as identifiers for a stored method.
            kwargs: Optional keyword arguments passed to the retrieved method.

        Returns:
            The result of evaluating `func(*args, **kwargs)`, where `func` is
            the function obtained via method lookup.

        Raises:
            NotImplementedError: If no method matches, or if every matching
                method raised `MDNotImplementedError`.
        """
        types = self.encode_args(args)
        func = self.__getitem__(types=types)
        try:
            return func(*args, **kwargs)
        except MDNotImplementedError:
            # Traverses registered methods in order, yields whenever a match is found
            funcs = self.dispatch_iter(*types)
            next(funcs)  # burn first, same as self.__getitem__(types=types)
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    # This candidate declined the call; fall through to the next.
                    pass
            raise NotImplementedError(
                f"Matching functions for {self.name:s}: {str_signature(types):s} "
                "found, but none completed successfully"
            )

    def dispatch(self, *types: type) -> Callable | None:
        r"""Method lookup strategy. Checks for an exact match before traversing
        the set of registered methods according to the current ordering.

        Args:
            types: A tuple of types that gets compared with the signatures
                of registered methods to determine compatibility.

        Returns:
            The first method encountered with a matching signature, or `None`
            if no registered method matches.
        """
        if types in self.funcs:
            return self.funcs[types]
        # An exhausted iterator means nothing matched; report that as `None`.
        return next(self.dispatch_iter(*types), None)

    def encode_args(self, args: Any) -> tuple[type, ...]:
        r"""Converts arguments into a tuple of types used during method lookup.

        Args:
            args: Either a tuple of arguments or a single (non-tuple) argument,
                which is treated as a one-element tuple.

        Returns:
            The tuple obtained by applying `self.encoder` to each argument.
        """
        return tuple(map(self.encoder, args if isinstance(args, tuple) else (args,)))

    def _help(self, *args: Any) -> str | None:
        r"""Returns the retrieved method's docstring.

        Raises:
            TypeError: If no method matches the encoded arguments. Without this
                check the lookup result would be `None` and the docstring of
                `None` would be returned instead of an error.
        """
        func = self.dispatch(*self.encode_args(args))
        if not func:
            raise TypeError("No function found")
        return func.__doc__

    def help(self, *args: Any, **kwargs: Any) -> None:
        r"""Prints the retrieved method's docstring.

        Keyword arguments are accepted but not used.
        """
        print(self._help(*args))

    def _source(self, *args: Any) -> str:
        r"""Returns the retrieved method's source file and source code as a string.

        Raises:
            TypeError: If no method matches the encoded arguments.
        """
        func = self.dispatch(*self.encode_args(args))
        if not func:
            raise TypeError("No function found")
        return f"File: {getsourcefile(func)}\n\n{getsource(func)}"

    def source(self, *args, **kwargs) -> None:
        r"""Prints the retrieved method's source file and source code.

        Keyword arguments are accepted but not used.
        """
        print(self._source(*args))

    @property
    def encoder(self) -> Callable[[Any], type]:
        r"""The callable used to encode each argument into a type for lookup."""
        return self._encoder