Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/dispatch.py: 41% (56 statements)

coverage.py v7.3.1, created at 2023-09-25 06:41 +0000

1""" 

2Handles dispatching array operations to the correct backend library, as well 

3as converting arrays to backend formats and then potentially storing them as 

4constants. 

5""" 

6 

7import importlib 

8from typing import Any, Dict 

9 

10import numpy 

11 

12from . import cupy as _cupy 

13from . import jax as _jax 

14from . import object_arrays 

15from . import tensorflow as _tensorflow 

16from . import theano as _theano 

17from . import torch as _torch 

18 

19__all__ = [ 

20 "get_func", 

21 "has_einsum", 

22 "has_tensordot", 

23 "build_expression", 

24 "evaluate_constants", 

25 "has_backend", 

26] 

27 

28# known non top-level imports 

29_aliases = { 

30 "dask": "dask.array", 

31 "theano": "theano.tensor", 

32 "torch": "opt_einsum.backends.torch", 

33 "jax": "jax.numpy", 

34 "autograd": "autograd.numpy", 

35 "mars": "mars.tensor", 

36} 
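# Illustration (editor's note, not part of the original source): the table
# above lets a short backend name map onto the module that actually holds its
# array functions, so "dask" resolves to ``dask.array`` and "jax" to
# ``jax.numpy``, while any name not listed here (e.g. "numpy") is imported
# as-is:
#
#     >>> _aliases.get("dask", "dask")
#     'dask.array'
#     >>> _aliases.get("numpy", "numpy")
#     'numpy'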


def _import_func(func: str, backend: str, default: Any = None) -> Any:
    """Try to import ``{backend}.{func}``.

    If the library is installed and the function is found, return it;
    otherwise, if a default is provided, return the default;
    otherwise, raise an error.
    """
    try:
        lib = importlib.import_module(_aliases.get(backend, backend))
        return getattr(lib, func) if default is None else getattr(lib, func, default)
    except AttributeError:
        error_msg = (
            "{} doesn't seem to provide the function {} - see "
            "https://optimized-einsum.readthedocs.io/en/latest/backends.html "
            "for details on which functions are required for which contractions."
        )
        raise AttributeError(error_msg.format(backend, func))
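# Illustration (editor's note, not part of the original source): assuming
# numpy is installed, ``_import_func`` looks up an attribute on the (possibly
# aliased) backend module, falls back to ``default`` when one is given, and
# raises the descriptive AttributeError above otherwise. A missing *library*
# (ImportError from ``importlib.import_module``) is not caught here.
#
#     >>> _import_func("tensordot", "numpy") is numpy.tensordot
#     True
#     >>> _import_func("no_such_function", "numpy", default=numpy.einsum) is numpy.einsum
#     True
#     >>> _import_func("no_such_function", "numpy")
#     Traceback (most recent call last):
#         ...
#     AttributeError: numpy doesn't seem to provide the function no_such_function - see ...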


# manually cache functions as python2 doesn't support functools.lru_cache
# other libs will be added to this if needed, but pre-populate with numpy
_cached_funcs = {
    ("tensordot", "numpy"): numpy.tensordot,
    ("transpose", "numpy"): numpy.transpose,
    ("einsum", "numpy"): numpy.einsum,
    # also pre-populate with the arbitrary object backend
    ("tensordot", "object"): numpy.tensordot,
    ("transpose", "object"): numpy.transpose,
    ("einsum", "object"): object_arrays.object_einsum,
}


def get_func(func: str, backend: str = "numpy", default: Any = None) -> Any:
    """Return ``{backend}.{func}``, e.g. ``numpy.einsum``,
    or a default func if provided. Cache result.
    """
    try:
        return _cached_funcs[func, backend]
    except KeyError:
        fn = _import_func(func, backend, default)
        _cached_funcs[func, backend] = fn
        return fn
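# Illustration (editor's note, not part of the original source): ``get_func``
# is the memoized front end to ``_import_func``. The numpy/object entries
# above are served straight from ``_cached_funcs``; anything else is imported
# once and then stored under its ``(func, backend)`` key.
#
#     >>> get_func("einsum", "numpy") is numpy.einsum   # pre-populated hit
#     True
#     >>> fn = get_func("tensordot", "dask")            # assuming dask is installed
#     >>> ("tensordot", "dask") in _cached_funcs        # imported from dask.array, now cached
#     True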


# mark libs with einsum, else try to use tensordot/transpose as much as possible
_has_einsum: Dict[str, bool] = {}


def has_einsum(backend: str) -> bool:
    """Check if ``{backend}.einsum`` exists, cache result for performance."""
    try:
        return _has_einsum[backend]
    except KeyError:
        try:
            get_func("einsum", backend)
            _has_einsum[backend] = True
        except AttributeError:
            _has_einsum[backend] = False

        return _has_einsum[backend]


_has_tensordot: Dict[str, bool] = {}


def has_tensordot(backend: str) -> bool:
    """Check if ``{backend}.tensordot`` exists, cache result for performance."""
    try:
        return _has_tensordot[backend]
    except KeyError:
        try:
            get_func("tensordot", backend)
            _has_tensordot[backend] = True
        except AttributeError:
            _has_tensordot[backend] = False

        return _has_tensordot[backend]
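# Illustration (editor's note, not part of the original source): both
# predicates probe ``get_func`` once and memoize the outcome, so repeated
# checks cost a dict lookup.
#
#     >>> has_einsum("numpy")
#     True
#     >>> _has_einsum            # the probe result is memoized
#     {'numpy': True}
#     >>> has_tensordot("numpy")
#     True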


# Dispatch to correct expression backend
# these are the backends which support explicit to-and-from numpy conversion
CONVERT_BACKENDS = {
    "tensorflow": _tensorflow.build_expression,
    "theano": _theano.build_expression,
    "cupy": _cupy.build_expression,
    "torch": _torch.build_expression,
    "jax": _jax.build_expression,
}

EVAL_CONSTS_BACKENDS = {
    "tensorflow": _tensorflow.evaluate_constants,
    "theano": _theano.evaluate_constants,
    "cupy": _cupy.evaluate_constants,
    "torch": _torch.evaluate_constants,
    "jax": _jax.evaluate_constants,
}


def build_expression(backend, arrays, expr):
    """Build an expression, based on ``expr`` and initial arrays ``arrays``,
    that evaluates using backend ``backend``.
    """
    return CONVERT_BACKENDS[backend](arrays, expr)


def evaluate_constants(backend, arrays, expr):
    """Convert constant arrays to the correct backend, and perform as much of
    the contraction of ``expr`` with these as possible.
    """
    return EVAL_CONSTS_BACKENDS[backend](arrays, expr)
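# Illustration (editor's note, not part of the original source): these two
# hooks are normally reached through ``opt_einsum.contract_expression`` when
# constant tensors are declared and one of the CONVERT_BACKENDS above is
# requested; ``evaluate_constants`` converts and pre-contracts the constants
# once, and ``build_expression`` compiles the remaining expression. A minimal
# sketch, assuming torch (or any backend listed above) is installed:
#
#     >>> import numpy as np, opt_einsum as oe
#     >>> const = np.random.rand(3, 4)
#     >>> expr = oe.contract_expression("ij,jk->ik", (2, 3), const, constants=[1])
#     >>> out = expr(np.random.rand(2, 3), backend="torch")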


def has_backend(backend: str) -> bool:
    """Check if the backend is known."""
    return backend.lower() in CONVERT_BACKENDS
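# Illustration (editor's note, not part of the original source):
# ``has_backend`` only reports membership in CONVERT_BACKENDS, i.e. the
# backends with an explicit to-and-from numpy conversion path noted above;
# other backend names simply return False here.
#
#     >>> has_backend("Torch")   # case-insensitive
#     True
#     >>> has_backend("dask")
#     False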