Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/dispatch.py: 40%
55 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""
2Handles dispatching array operations to the correct backend library, as well
3as converting arrays to backend formats and then potentially storing them as
4constants.
5"""
7import importlib
9import numpy
11from . import object_arrays
12from . import cupy as _cupy
13from . import jax as _jax
14from . import tensorflow as _tensorflow
15from . import theano as _theano
16from . import torch as _torch
# Public API of this dispatch module.
__all__ = [
    "get_func",
    "has_einsum",
    "has_tensordot",
    "build_expression",
    "evaluate_constants",
    "has_backend",
]
# Backend names whose functions do not live in the top-level package of the
# same name: maps the user-facing backend name to the module that is actually
# imported (``'torch'`` maps to opt_einsum's own torch wrapper module).
_aliases = {
    'dask': 'dask.array',
    'theano': 'theano.tensor',
    'torch': 'opt_einsum.backends.torch',
    'jax': 'jax.numpy',
    'autograd': 'autograd.numpy',
    'mars': 'mars.tensor',
}


def _import_func(func, backend, default=None):
    """Try and import ``{backend}.{func}``.

    The backend name is first translated through ``_aliases`` (e.g.
    ``'dask'`` -> ``'dask.array'``) and the resulting module is imported.

    Parameters
    ----------
    func : str
        Name of the attribute to fetch from the backend module, e.g. ``'einsum'``.
    backend : str
        Backend name: either an importable module name or a key of ``_aliases``.
    default : object, optional
        Returned when the module exists but lacks ``func``. ``None`` means
        "no default": a missing attribute raises instead.

    Returns
    -------
    The attribute ``func`` of the backend module, or ``default``.

    Raises
    ------
    ImportError
        If the backend module itself cannot be imported (propagated unchanged).
    AttributeError
        If the module lacks ``func`` and no ``default`` was given.
    """
    # Import outside any ``try``: an AttributeError raised *while importing* a
    # broken library must surface as-is, not be mislabelled as a missing func.
    lib = importlib.import_module(_aliases.get(backend, backend))
    if default is not None:
        return getattr(lib, func, default)
    try:
        return getattr(lib, func)
    except AttributeError as err:
        error_msg = ("{} doesn't seem to provide the function {} - see "
                     "https://optimized-einsum.readthedocs.io/en/latest/backends.html "
                     "for details on which functions are required for which contractions.")
        raise AttributeError(error_msg.format(backend, func)) from err
# Memo table used by ``get_func``, keyed by ``(func_name, backend)``. It is
# pre-populated so the default 'numpy' backend needs no import at all, and so
# the 'object' pseudo-backend reuses numpy's tensordot/transpose while routing
# einsum to ``object_arrays.object_einsum``. Other libs are added on demand.
_cached_funcs = {
    **{(name, 'numpy'): getattr(numpy, name) for name in ('tensordot', 'transpose', 'einsum')},
    ('tensordot', 'object'): numpy.tensordot,
    ('transpose', 'object'): numpy.transpose,
    ('einsum', 'object'): object_arrays.object_einsum,
}
def get_func(func, backend='numpy', default=None):
    """Return ``{backend}.{func}``, e.g. ``numpy.einsum``,
    or a default func if provided.

    Genuine backend functions are cached in ``_cached_funcs`` under
    ``(func, backend)``. A caller-supplied ``default`` is never cached, so it
    cannot leak into later lookups.

    Parameters
    ----------
    func : str
        Function name to look up, e.g. ``'tensordot'``.
    backend : str, optional
        Backend name, ``'numpy'`` by default.
    default : object, optional
        Fallback returned when the backend lacks ``func``.

    Returns
    -------
    The backend's function, or ``default``.

    Raises
    ------
    ImportError
        If the backend module cannot be imported.
    AttributeError
        If the backend lacks ``func`` and no ``default`` was given.
    """
    key = (func, backend)
    try:
        return _cached_funcs[key]
    except KeyError:
        pass

    # Resolved outside the ``except`` clause so that a failure here is not
    # reported as having happened "during handling of" the cache-miss KeyError.
    fn = _import_func(func, backend, default)

    # Memoize only a genuine ``{backend}.{func}``. Caching the fallback under
    # the same key would hand it to later callers passing a different default,
    # or none at all (who expect an AttributeError), and would make
    # ``has_einsum``/``has_tensordot`` wrongly report True.
    if default is None or fn is not default:
        _cached_funcs[key] = fn
    return fn
# Per-backend memo of whether ``{backend}.einsum`` exists; for backends
# without it, callers try to use tensordot/transpose as much as possible.
_has_einsum = {}


def has_einsum(backend):
    """Report whether ``{backend}.einsum`` exists, memoizing the answer.

    Only an AttributeError from the lookup counts as "missing"; any other
    error (e.g. an ImportError for an uninstalled backend) propagates and
    nothing is cached.
    """
    known = _has_einsum.get(backend)
    if known is not None:
        return known

    try:
        get_func('einsum', backend)
        found = True
    except AttributeError:
        found = False

    _has_einsum[backend] = found
    return found
# Per-backend memo of whether ``{backend}.tensordot`` exists.
_has_tensordot = {}


def has_tensordot(backend):
    """Report whether ``{backend}.tensordot`` exists, memoizing the answer.

    Only an AttributeError from the lookup counts as "missing"; any other
    error (e.g. an ImportError for an uninstalled backend) propagates and
    nothing is cached.
    """
    known = _has_tensordot.get(backend)
    if known is not None:
        return known

    try:
        get_func('tensordot', backend)
        found = True
    except AttributeError:
        found = False

    _has_tensordot[backend] = found
    return found
# Dispatch tables for the backends that support explicit conversion to and
# from numpy arrays. Keys are lowercase backend names; values are the
# matching functions from each backend's wrapper module in this package.
CONVERT_BACKENDS = dict(
    tensorflow=_tensorflow.build_expression,
    theano=_theano.build_expression,
    cupy=_cupy.build_expression,
    torch=_torch.build_expression,
    jax=_jax.build_expression,
)

EVAL_CONSTS_BACKENDS = dict(
    tensorflow=_tensorflow.evaluate_constants,
    theano=_theano.evaluate_constants,
    cupy=_cupy.evaluate_constants,
    torch=_torch.evaluate_constants,
    jax=_jax.evaluate_constants,
)
def build_expression(backend, arrays, expr):
    """Build an expression, based on ``expr`` and initial arrays ``arrays``,
    that evaluates using backend ``backend``.

    ``backend`` is matched case-insensitively, the same way ``has_backend``
    checks it, so every name ``has_backend`` accepts can be dispatched here.

    Raises
    ------
    KeyError
        If ``backend`` has no entry in ``CONVERT_BACKENDS``.
    """
    return CONVERT_BACKENDS[backend.lower()](arrays, expr)
def evaluate_constants(backend, arrays, expr):
    """Convert constant arrays to the correct backend, and perform as much of
    the contraction of ``expr`` with these as possible.

    ``backend`` is matched case-insensitively, the same way ``has_backend``
    checks it, so every name ``has_backend`` accepts can be dispatched here.

    Raises
    ------
    KeyError
        If ``backend`` has no entry in ``EVAL_CONSTS_BACKENDS``.
    """
    return EVAL_CONSTS_BACKENDS[backend.lower()](arrays, expr)
def has_backend(backend):
    """Return whether ``backend`` (compared case-insensitively) is one of the
    numpy-convertible backends registered in ``CONVERT_BACKENDS``.
    """
    normalized = backend.lower()
    return normalized in CONVERT_BACKENDS