Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/dispatch.py: 40%

55 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1""" 

2Handles dispatching array operations to the correct backend library, as well 

3as converting arrays to backend formats and then potentially storing them as 

4constants. 

5""" 

6 

7import importlib 

8 

9import numpy 

10 

11from . import object_arrays 

12from . import cupy as _cupy 

13from . import jax as _jax 

14from . import tensorflow as _tensorflow 

15from . import theano as _theano 

16from . import torch as _torch 

17 

18__all__ = ["get_func", "has_einsum", "has_tensordot", "build_expression", "evaluate_constants", "has_backend"] 

19 

20# known non top-level imports 

21_aliases = { 

22 'dask': 'dask.array', 

23 'theano': 'theano.tensor', 

24 'torch': 'opt_einsum.backends.torch', 

25 'jax': 'jax.numpy', 

26 'autograd': 'autograd.numpy', 

27 'mars': 'mars.tensor', 

28} 

29 

30 
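
# Note (editorial, not in the original source): ``_aliases`` maps a backend
# name to the module that actually holds its array functions. For example a
# lookup for the 'dask' backend resolves via
# ``importlib.import_module(_aliases.get('dask', 'dask'))`` and so imports
# ``dask.array`` rather than the bare ``dask`` package (assuming dask is
# installed).
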

def _import_func(func, backend, default=None):
    """Try and import ``{backend}.{func}``.
    If library is installed and func is found, return the func;
    otherwise if default is provided, return default;
    otherwise raise an error.
    """
    try:
        lib = importlib.import_module(_aliases.get(backend, backend))
        return getattr(lib, func) if default is None else getattr(lib, func, default)
    except AttributeError:
        error_msg = ("{} doesn't seem to provide the function {} - see "
                     "https://optimized-einsum.readthedocs.io/en/latest/backends.html "
                     "for details on which functions are required for which contractions.")
        raise AttributeError(error_msg.format(backend, func))

# manually cache functions as python2 doesn't support functools.lru_cache
# other libs will be added to this if needed, but pre-populate with numpy
_cached_funcs = {
    ('tensordot', 'numpy'): numpy.tensordot,
    ('transpose', 'numpy'): numpy.transpose,
    ('einsum', 'numpy'): numpy.einsum,
    # also pre-populate with the arbitrary object backend
    ('tensordot', 'object'): numpy.tensordot,
    ('transpose', 'object'): numpy.transpose,
    ('einsum', 'object'): object_arrays.object_einsum,
}

def get_func(func, backend='numpy', default=None):
    """Return ``{backend}.{func}``, e.g. ``numpy.einsum``,
    or a default func if provided. Cache result.
    """
    try:
        return _cached_funcs[func, backend]
    except KeyError:
        fn = _import_func(func, backend, default)
        _cached_funcs[func, backend] = fn
        return fn
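
# Usage sketch (editorial illustration, not part of the original module). Only
# the numpy entries are guaranteed to resolve; any other backend must actually
# be installed for the lazy import above to succeed.
#
#     tdot = get_func('tensordot', 'numpy')        # served from _cached_funcs
#     assert tdot is numpy.tensordot
#     # a missing attribute falls back to ``default`` instead of raising:
#     fallback = get_func('no_such_function', 'numpy', default=numpy.einsum)
#     assert fallback is numpy.einsum
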

# mark libs with einsum, else try to use tensordot/transpose as much as possible
_has_einsum = {}


def has_einsum(backend):
    """Check if ``{backend}.einsum`` exists, cache result for performance.
    """
    try:
        return _has_einsum[backend]
    except KeyError:
        try:
            get_func('einsum', backend)
            _has_einsum[backend] = True
        except AttributeError:
            _has_einsum[backend] = False

        return _has_einsum[backend]

_has_tensordot = {}


def has_tensordot(backend):
    """Check if ``{backend}.tensordot`` exists, cache result for performance.
    """
    try:
        return _has_tensordot[backend]
    except KeyError:
        try:
            get_func('tensordot', backend)
            _has_tensordot[backend] = True
        except AttributeError:
            _has_tensordot[backend] = False

        return _has_tensordot[backend]
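
# Usage sketch (editorial illustration): both checks memoize their answer, so
# repeated queries for the same backend never re-run the import machinery.
#
#     has_einsum('numpy')       # -> True, einsum is pre-cached above
#     has_tensordot('numpy')    # -> True
#     # for any other backend the answer depends on whether that (installed)
#     # library actually exposes an ``einsum``/``tensordot`` attribute
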

# Dispatch to correct expression backend
# these are the backends which support explicit to-and-from numpy conversion
CONVERT_BACKENDS = {
    'tensorflow': _tensorflow.build_expression,
    'theano': _theano.build_expression,
    'cupy': _cupy.build_expression,
    'torch': _torch.build_expression,
    'jax': _jax.build_expression,
}

EVAL_CONSTS_BACKENDS = {
    'tensorflow': _tensorflow.evaluate_constants,
    'theano': _theano.evaluate_constants,
    'cupy': _cupy.evaluate_constants,
    'torch': _torch.evaluate_constants,
    'jax': _jax.evaluate_constants,
}

def build_expression(backend, arrays, expr):
    """Build an expression, based on ``expr`` and initial arrays ``arrays``,
    that evaluates using backend ``backend``.
    """
    return CONVERT_BACKENDS[backend](arrays, expr)


def evaluate_constants(backend, arrays, expr):
    """Convert constant arrays to the correct backend, and perform as much of
    the contraction of ``expr`` with these as possible.
    """
    return EVAL_CONSTS_BACKENDS[backend](arrays, expr)


def has_backend(backend):
    """Checks if the backend is known.
    """
    return backend.lower() in CONVERT_BACKENDS
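

if __name__ == "__main__":
    # Editorial demonstration sketch, not part of the original module: exercise
    # the public helpers using only the always-available numpy backend.
    a = numpy.random.rand(2, 3)
    b = numpy.random.rand(3, 4)

    # 'torch' is a known conversion backend, plain 'numpy' deliberately is not
    print(has_backend("torch"))        # True
    print(has_backend("numpy"))        # False

    # numpy's functions are pre-cached, so these checks trigger no imports
    print(has_einsum("numpy"))         # True
    print(has_tensordot("numpy"))      # True

    tdot = get_func("tensordot", "numpy")
    print(tdot(a, b, axes=1).shape)    # (2, 4)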