Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/dispatch.py: 41%
56 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
1"""
2Handles dispatching array operations to the correct backend library, as well
3as converting arrays to backend formats and then potentially storing them as
4constants.
5"""
7import importlib
8from typing import Any, Dict
10import numpy
12from . import cupy as _cupy
13from . import jax as _jax
14from . import object_arrays
15from . import tensorflow as _tensorflow
16from . import theano as _theano
17from . import torch as _torch
# Public API of this dispatch module.
__all__ = [
    "get_func", "has_einsum", "has_tensordot",
    "build_expression", "evaluate_constants", "has_backend",
]

# Backends whose array functions live in a submodule rather than in the
# top-level package; ``_import_func`` imports the mapped module name instead
# of the bare backend name.
_aliases = dict(
    dask="dask.array",
    theano="theano.tensor",
    torch="opt_einsum.backends.torch",
    jax="jax.numpy",
    autograd="autograd.numpy",
    mars="mars.tensor",
)
def _import_func(func: str, backend: str, default: Any = None) -> Any:
    """Try and import ``{backend}.{func}``.

    The module name is first looked up in ``_aliases``, so for example
    ``backend="jax"`` imports ``jax.numpy``.

    If the module is importable and provides ``func``, return it. Otherwise,
    if ``default`` is not None, return ``default``. Otherwise raise.

    Raises:
        ImportError: if the backend module itself cannot be imported. This is
            raised even when ``default`` is given.
        AttributeError: if the module has no attribute ``func`` and no
            ``default`` was given.
    """
    # Import outside any ``try``: an AttributeError raised while the backend
    # module is being imported must not be misreported as "function missing".
    lib = importlib.import_module(_aliases.get(backend, backend))

    if default is not None:
        return getattr(lib, func, default)

    try:
        return getattr(lib, func)
    except AttributeError as err:
        error_msg = (
            "{} doesn't seem to provide the function {} - see "
            "https://optimized-einsum.readthedocs.io/en/latest/backends.html "
            "for details on which functions are required for which contractions."
        )
        raise AttributeError(error_msg.format(backend, func)) from err
# Cache of already-resolved backend functions, keyed by ``(func, backend)``.
# Further entries are added lazily by ``get_func``.
# Pre-populated with numpy, and with the generic "object" backend. The object
# backend reuses numpy's tensordot/transpose but needs its own einsum that
# works on arrays of arbitrary Python objects.
_cached_funcs: Dict[tuple, Any] = {
    ("tensordot", "numpy"): numpy.tensordot,
    ("transpose", "numpy"): numpy.transpose,
    ("einsum", "numpy"): numpy.einsum,
    # also pre-populate with the arbitrary object backend
    ("tensordot", "object"): numpy.tensordot,
    ("transpose", "object"): numpy.transpose,
    ("einsum", "object"): object_arrays.object_einsum,
}
def get_func(func: str, backend: str = "numpy", default: Any = None) -> Any:
    """Return ``{backend}.{func}``, e.g. ``numpy.einsum``.

    If the backend lacks ``func`` and ``default`` is not None, return
    ``default`` instead.

    Functions actually found on the backend are cached in ``_cached_funcs``.
    Fallback ``default`` values are not cached.

    Raises:
        AttributeError: if ``func`` is missing and no default was given.
        ImportError: if the backend module cannot be imported.
    """
    try:
        return _cached_funcs[func, backend]
    except KeyError:
        pass

    fn = _import_func(func, backend, default)

    # Only cache a genuine lookup. Caching a fallback would make every later
    # call return this caller's ``default``. That includes calls passing a
    # different default, and calls passing none, which ought to raise instead.
    if default is None or fn is not default:
        _cached_funcs[func, backend] = fn
    return fn
# Per-backend memo of whether ``{backend}.einsum`` exists. For libs without
# einsum, tensordot/transpose are used as much as possible instead.
_has_einsum: Dict[str, bool] = {}


def has_einsum(backend: str) -> bool:
    """Report whether ``backend`` provides an ``einsum`` function.

    The result is memoised in ``_has_einsum``, so the lookup happens once per
    backend. Only an ``AttributeError`` from ``get_func`` counts as "not
    available". Any other exception propagates to the caller.
    """
    if backend not in _has_einsum:
        try:
            get_func("einsum", backend)
        except AttributeError:
            available = False
        else:
            available = True
        _has_einsum[backend] = available
    return _has_einsum[backend]
# Per-backend memo of whether ``{backend}.tensordot`` exists.
_has_tensordot: Dict[str, bool] = {}


def has_tensordot(backend: str) -> bool:
    """Report whether ``backend`` provides a ``tensordot`` function.

    The result is memoised in ``_has_tensordot``. Only an ``AttributeError``
    from ``get_func`` is interpreted as "missing". Any other exception
    propagates to the caller.
    """
    known = _has_tensordot.get(backend)
    if known is not None:
        return known

    try:
        get_func("tensordot", backend)
    except AttributeError:
        _has_tensordot[backend] = False
    else:
        _has_tensordot[backend] = True
    return _has_tensordot[backend]
# Dispatch to the correct expression backend.
# These backends support explicit conversion to and from numpy. Each wrapper
# module provides both a ``build_expression`` and an ``evaluate_constants``
# function. Both tables below are derived from this single list, so their keys
# always agree.
_CONVERSION_MODULES = (
    ("tensorflow", _tensorflow),
    ("theano", _theano),
    ("cupy", _cupy),
    ("torch", _torch),
    ("jax", _jax),
)

CONVERT_BACKENDS = {
    name: module.build_expression for name, module in _CONVERSION_MODULES
}

EVAL_CONSTS_BACKENDS = {
    name: module.evaluate_constants for name, module in _CONVERSION_MODULES
}
def build_expression(backend: str, arrays: Any, expr: Any) -> Any:
    """Build an expression, based on ``expr`` and initial arrays ``arrays``,
    that evaluates using backend ``backend``.

    ``backend`` is matched case-insensitively, mirroring ``has_backend``. Any
    name that ``has_backend`` accepts therefore dispatches here too.

    Raises:
        KeyError: if the lower-cased ``backend`` is not in ``CONVERT_BACKENDS``.
    """
    return CONVERT_BACKENDS[backend.lower()](arrays, expr)
def evaluate_constants(backend: str, arrays: Any, expr: Any) -> Any:
    """Convert constant arrays to the correct backend, and perform as much of
    the contraction of ``expr`` with these as possible.

    ``backend`` is matched case-insensitively, mirroring ``has_backend``. Any
    name that ``has_backend`` accepts therefore dispatches here too.

    Raises:
        KeyError: if the lower-cased ``backend`` is not in
            ``EVAL_CONSTS_BACKENDS``.
    """
    return EVAL_CONSTS_BACKENDS[backend.lower()](arrays, expr)
def has_backend(backend: str) -> bool:
    """Return True if ``backend``, compared case-insensitively, is a key of
    ``CONVERT_BACKENDS``. Those keys are the backends supporting explicit
    conversion to and from numpy.
    """
    normalized = backend.lower()
    return normalized in CONVERT_BACKENDS