Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/sharing.py: 40%

96 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1""" 

2A module for sharing intermediates between contractions. 

3 

4Copyright (c) 2018 Uber Technologies 

5""" 

6 

7import contextlib 

8import functools 

9import numbers 

10import threading 

11from collections import Counter, defaultdict 

12 

13from .parser import alpha_canonicalize, parse_einsum_input 

14 

__all__ = [
    "currently_sharing", "get_sharing_cache", "shared_intermediates", "count_cached_ops", "transpose_cache_wrap",
    "tensordot_cache_wrap", "einsum_cache_wrap", "to_backend_cache_wrap"
]

# Per-thread stacks of active sharing caches, keyed by ``threading.get_ident()``.
# The innermost (most recently entered) cache is the last element of each list;
# ``_remove_sharing_cache`` deletes a thread's entry once its stack is empty, and
# ``currently_sharing`` tests membership of the thread id in this mapping.
_SHARING_STACK = defaultdict(list)

21 

22 

def currently_sharing():
    """Report whether the calling thread has an active sharing cache.

    Returns True exactly when this thread's id is a key of ``_SHARING_STACK``.
    """
    thread_id = threading.get_ident()
    return thread_id in _SHARING_STACK

27 

28 

def get_sharing_cache():
    """Return the most recent sharing cache -- thread specific.

    Raises
    ------
    IndexError
        If no :func:`shared_intermediates` context is active in this thread.
    """
    # Use ``.get`` rather than indexing: indexing the defaultdict would insert an
    # empty stack for this thread, after which ``currently_sharing()`` would
    # wrongly report True and every cached op would fail on lookup.
    stack = _SHARING_STACK.get(threading.get_ident())
    if not stack:
        raise IndexError("get_sharing_cache() called outside of a shared_intermediates context")
    return stack[-1]

33 

34 

def _add_sharing_cache(cache):
    # Push ``cache`` onto the calling thread's stack, creating the stack on
    # first use via the defaultdict factory.
    thread_stack = _SHARING_STACK[threading.get_ident()]
    thread_stack.append(cache)

37 

38 

def _remove_sharing_cache():
    # Pop the innermost cache for the calling thread; drop the thread's entry
    # entirely once nothing is left so currently_sharing() turns False again.
    ident = threading.get_ident()
    thread_stack = _SHARING_STACK[ident]
    thread_stack.pop()
    if len(thread_stack) == 0:
        del _SHARING_STACK[ident]

44 

45 

@contextlib.contextmanager
def shared_intermediates(cache=None):
    """Context in which contract intermediate results are shared.

    While the context is active, memoized operations in this thread store their
    results in the yielded cache. Those results stay alive until both
    1. this context exits, and
    2. the yielded cache is garbage collected (if it was captured).

    Parameters
    ----------
    cache : dict
        Optional user-owned dict in which intermediate results will be stored.
        Passing the same dict to several contexts lets them interleave sharing.

    Returns
    -------
    cache : dict
        The dictionary holding the shared results. If it is not kept, results
        are released when this context exits; if kept, it can be handed to a
        later context to resume sharing.
    """
    active = {} if cache is None else cache
    _add_sharing_cache(active)
    try:
        yield active
    finally:
        # Always unwind this thread's stack, even if the body raised.
        _remove_sharing_cache()

75 

76 

def count_cached_ops(cache):
    """Tally the entries of a sharing cache by operation type.

    Keys written by this module are tuples whose first element names the op
    (e.g. 'tensor', 'transpose', 'tensordot', 'einsum', or a to-backend
    function name); counting those tags is handy when profiling sharing.
    """
    op_names = (entry[0] for entry in cache.keys())
    return Counter(op_names)

82 

83 

def _save_tensors(*tensors):
    """Pin each tensor in the active sharing cache under ('tensor', id(tensor)).

    Holding a reference keeps the object alive, so its id cannot be handed to
    a new object while the cache exists -- otherwise id-based keys could
    produce false cache hits.
    """
    store = get_sharing_cache()
    for t in tensors:
        store[('tensor', id(t))] = t

91 

92 

def _memoize(key, fn, *args, **kwargs):
    """Look up ``key`` in the innermost :func:`shared_intermediates` cache.

    On a miss, evaluate ``fn(*args, **kwargs)``, store it under ``key`` and
    return it; on a hit, return the stored value without calling ``fn``.
    """
    cache = get_sharing_cache()
    if key not in cache:
        value = fn(*args, **kwargs)
        cache[key] = value
        return value
    return cache[key]

104 

105 

def transpose_cache_wrap(transpose):
    """Wrap a ``transpose()`` implementation so that, inside a
    :func:`shared_intermediates` context, results are memoized per
    (backend, id of the input, axes).
    """
    @functools.wraps(transpose)
    def cached_transpose(a, axes, backend='numpy'):
        if currently_sharing():
            # keep ``a`` alive so its id stays unique, then key on a hashable axes
            _save_tensors(a)
            axes = tuple(axes)
            return _memoize(('transpose', backend, id(a), axes), transpose, a, axes, backend=backend)
        return transpose(a, axes, backend=backend)

    return cached_transpose

122 

123 

def tensordot_cache_wrap(tensordot):
    """Decorates a ``tensordot()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.

    Outside a sharing context the wrapped function is called unchanged. Inside
    one, ``axes`` is normalised to ``(tuple_of_x_axes, tuple_of_y_axes)`` and the
    result is memoized by (backend, id(x), id(y), normalised axes).
    """

    def _as_axis_tuple(side):
        # Like numpy.tensordot, accept a bare integer for either side of the
        # pair (e.g. ``axes=(1, 0)``) as well as a sequence of axes.
        if isinstance(side, numbers.Number):
            return (side,)
        return tuple(side)

    @functools.wraps(tensordot)
    def cached_tensordot(x, y, axes=2, backend='numpy'):
        if not currently_sharing():
            return tensordot(x, y, axes, backend=backend)

        # hash based on the (axes_x, axes_y) form of axes
        _save_tensors(x, y)
        if isinstance(axes, numbers.Number):
            # An integer n contracts the last n axes of x with the first n of y.
            axes = list(range(len(x.shape)))[len(x.shape) - axes:], list(range(len(y.shape)))[:axes]
        axes = _as_axis_tuple(axes[0]), _as_axis_tuple(axes[1])
        key = 'tensordot', backend, id(x), id(y), axes
        return _memoize(key, tensordot, x, y, axes, backend=backend)

    return cached_tensordot

142 

143 

def einsum_cache_wrap(einsum):
    """Decorates an ``einsum()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.

    Inside a sharing context, calls are keyed on the backend, an
    alpha-canonicalized equation and the operand ids (sorted by id, so the same
    operands supplied in a different order can share an entry). Calls that pass
    keyword arguments other than ``backend`` are forwarded without memoization,
    because those keywords are not part of the cache key.
    """
    @functools.wraps(einsum)
    def cached_einsum(*args, **kwargs):
        if not currently_sharing():
            return einsum(*args, **kwargs)

        backend = kwargs.pop('backend', 'numpy')

        # Keywords such as ``out`` or ``dtype`` change what the call does but
        # are not in the key; memoizing would silently drop them (e.g. never
        # writing into ``out``), so forward such calls uncached instead.
        if kwargs:
            return einsum(*args, backend=backend, **kwargs)

        # hash modulo commutativity by computing a canonical ordering and names
        inputs, output, operands = parse_einsum_input(args)
        inputs = inputs.split(',')

        # The interleaved form ``einsum(a, sub_a, b, sub_b, ...)`` has no
        # equation string in args[0]; rebuild one from the parsed subscripts.
        if isinstance(args[0], str):
            equation = args[0]
        else:
            equation = ','.join(inputs) + '->' + output

        _save_tensors(*operands)

        # Build canonical key
        canonical = sorted(zip(inputs, map(id, operands)), key=lambda x: x[1])
        canonical_ids = tuple(id_ for _, id_ in canonical)
        canonical_inputs = ','.join(input_ for input_, _ in canonical)
        canonical_equation = alpha_canonicalize(canonical_inputs + "->" + output)

        key = 'einsum', backend, canonical_equation, canonical_ids
        return _memoize(key, einsum, equation, *operands, backend=backend)

    return cached_einsum

171 

172 

def to_backend_cache_wrap(to_backend=None, constants=False):
    """Decorates a ``to_backend()`` implementation to be memoized inside a
    :func:`shared_intermediates` context (e.g. ``to_cupy``, ``to_torch``).

    Usable bare (``@to_backend_cache_wrap``) or with arguments
    (``@to_backend_cache_wrap(constants=True)``). With ``constants=True`` the
    wrapped function takes a ``constant`` keyword, which is also part of the
    cache key.
    """
    # manage the case that decorator is called with args
    if to_backend is None:
        return functools.partial(to_backend_cache_wrap, constants=constants)

    if constants:

        @functools.wraps(to_backend)
        def cached_to_backend(array, constant=False):
            if not currently_sharing():
                return to_backend(array, constant=constant)

            # The key uses id(array) only, so pin ``array`` in the cache: a freed
            # array's id can be reused by a new one, which would otherwise get
            # back the stale converted tensor.
            _save_tensors(array)
            key = to_backend.__name__, id(array), constant
            return _memoize(key, to_backend, array, constant=constant)

    else:

        @functools.wraps(to_backend)
        def cached_to_backend(array):
            if not currently_sharing():
                return to_backend(array)

            # pin ``array`` so its id cannot be recycled while cached (see above)
            _save_tensors(array)
            key = to_backend.__name__, id(array)
            return _memoize(key, to_backend, array)

    return cached_to_backend