Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/sharing.py: 43% (102 statements)

coverage.py v7.3.1, created at 2023-09-25 06:41 +0000

1""" 

2A module for sharing intermediates between contractions. 

3 

4Copyright (c) 2018 Uber Technologies 

5""" 

6 

7import contextlib 

8import functools 

9import numbers 

10import threading 

11from collections import Counter, defaultdict 

12from typing import Any 

13from typing import Counter as CounterType 

14from typing import Dict, Generator, List, Optional, Tuple, Union 

15 

16from .parser import alpha_canonicalize, parse_einsum_input 

17from .typing import ArrayType 

18 

19CacheKeyType = Union[Tuple[str, str, int, Tuple[int, ...]], Tuple[str, int]] 

20CacheType = Dict[CacheKeyType, ArrayType] 

21 

22__all__ = [ 

23 "currently_sharing", 

24 "get_sharing_cache", 

25 "shared_intermediates", 

26 "count_cached_ops", 

27 "transpose_cache_wrap", 

28 "einsum_cache_wrap", 

29 "to_backend_cache_wrap", 

30] 

31 

32_SHARING_STACK: Dict[int, List[CacheType]] = defaultdict(list) 

33 

34 
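
# Illustrative sketch, not part of the original module: the sharing cache maps
# tuple keys, whose first element names the operation, to backend arrays. The
# function name `_demo_cache_keys` and the arrays `a`, `b` are made up for this
# example; a NumPy backend is assumed.
def _demo_cache_keys() -> None:
    import numpy as np

    a = np.ones((2, 3))
    b = np.ones((3, 2))
    cache = {}
    cache["tensor", id(a)] = a                               # keep-alive entry (see _save_tensors)
    cache["transpose", "numpy", id(a), (1, 0)] = a.T         # key form used by transpose_cache_wrap
    cache["tensordot", "numpy", id(a), id(b), ((1,), (0,))] = a @ b  # key form used by tensordot_cache_wrap
    print(sorted(key[0] for key in cache))                   # ['tensor', 'tensordot', 'transpose']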

def currently_sharing() -> bool:
    """Check if we are currently sharing a cache -- thread specific."""
    return threading.get_ident() in _SHARING_STACK


def get_sharing_cache() -> CacheType:
    """Return the most recent sharing cache -- thread specific."""
    return _SHARING_STACK[threading.get_ident()][-1]


def _add_sharing_cache(cache: CacheType) -> Any:
    # Push a cache onto the current thread's stack of sharing caches.
    _SHARING_STACK[threading.get_ident()].append(cache)


def _remove_sharing_cache() -> None:
    # Pop the innermost cache; drop the thread's entry once its stack is empty,
    # so that ``currently_sharing()`` returns False again.
    tid = threading.get_ident()
    _SHARING_STACK[tid].pop()
    if not _SHARING_STACK[tid]:
        del _SHARING_STACK[tid]
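
# Illustrative sketch, not part of the original module: the helpers above keep
# one stack of caches per thread, so nested sharing contexts always see the
# innermost cache. The name `_demo_nested_caches` is made up for this example.
def _demo_nested_caches() -> None:
    assert not currently_sharing()
    outer, inner = {}, {}
    _add_sharing_cache(outer)
    _add_sharing_cache(inner)
    assert currently_sharing()
    assert get_sharing_cache() is inner   # innermost cache wins
    _remove_sharing_cache()
    assert get_sharing_cache() is outer   # back to the enclosing context
    _remove_sharing_cache()
    assert not currently_sharing()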

@contextlib.contextmanager
def shared_intermediates(
    cache: Optional[CacheType] = None,
) -> Generator[CacheType, None, None]:
    """Context in which contract intermediate results are shared.

    Note that intermediate computations will not be garbage collected until
    1. this context exits, and
    2. the yielded cache is garbage collected (if it was captured).

    **Parameters:**

    - **cache** - *(dict)* If specified, a user-stored dict in which intermediate results will be stored. This can be used to interleave sharing contexts.

    **Returns:**

    - **cache** - *(dict)* A dictionary in which sharing results are stored. If ignored,
        sharing results will be garbage collected when this context is
        exited. This dict can be passed to another context to resume
        sharing.
    """
    if cache is None:
        cache = {}
    _add_sharing_cache(cache)
    try:
        yield cache
    finally:
        _remove_sharing_cache()
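
# Illustrative usage sketch, not part of the original module: two contractions
# that share the ``x``/``y`` intermediate reuse it from the cache instead of
# recomputing it. Assumes opt_einsum's public ``contract`` function and a NumPy
# backend; the name `_demo_shared_intermediates` is made up for this example.
def _demo_shared_intermediates() -> None:
    import numpy as np
    from opt_einsum import contract

    x, y, z1, z2 = (np.random.rand(4, 4) for _ in range(4))
    with shared_intermediates() as cache:
        contract("ab,bc,cd->ad", x, y, z1)
        contract("ab,bc,cd->ad", x, y, z2)  # the x/y intermediate is looked up, not recomputed
    print(count_cached_ops(cache))          # e.g. Counter({'tensor': ..., 'tensordot': ...})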

def count_cached_ops(cache: CacheType) -> CounterType[str]:
    """Returns a counter of the types of each op in the cache.
    This is useful for profiling to increase sharing.
    """
    return Counter(key[0] for key in cache.keys())
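
# Illustrative sketch, not part of the original module: counting ops by the
# first element of each cache key. The name `_demo_count_cached_ops` and the
# hand-built cache entries are made up for this example.
def _demo_count_cached_ops() -> None:
    cache = {
        ("tensor", 1): None,
        ("tensor", 2): None,
        ("einsum", "numpy", "ab,bc->ac", (1, 2)): None,
    }
    assert count_cached_ops(cache) == Counter({"tensor": 2, "einsum": 1})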

def _save_tensors(*tensors: ArrayType) -> None:
    """Save tensors in the cache to prevent their ids from being recycled.
    This is needed to prevent false cache lookups.
    """
    cache = get_sharing_cache()
    for tensor in tensors:
        cache["tensor", id(tensor)] = tensor
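
# Illustrative sketch, not part of the original module: cache keys use ``id()``,
# so an operand must stay alive as long as its key does, otherwise Python may
# reuse the same id for a new object and a lookup could return a stale result.
# Storing the tensor itself in the cache pins its id. The name
# `_demo_save_tensors` is made up for this example.
def _demo_save_tensors() -> None:
    import numpy as np

    with shared_intermediates() as cache:
        a = np.ones(3)
        _save_tensors(a)
        assert cache["tensor", id(a)] is a  # the operand is pinned by the cache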

def _memoize(key: CacheKeyType, fn: Any, *args: Any, **kwargs: Any) -> ArrayType:
    """Memoize ``fn(*args, **kwargs)`` using the given ``key``.
    Results will be stored in the innermost ``cache`` yielded by
    :func:`shared_intermediates`.
    """
    cache = get_sharing_cache()
    if key in cache:
        return cache[key]
    result = fn(*args, **kwargs)
    cache[key] = result
    return result
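
# Illustrative sketch, not part of the original module: a second ``_memoize``
# call with the same key returns the cached result without calling ``fn`` again.
# The names `_demo_memoize`, `calls` and `expensive` are made up for this example.
def _demo_memoize() -> None:
    calls = []

    def expensive(x):
        calls.append(x)
        return x * 2

    with shared_intermediates():
        first = _memoize(("expensive", 1), expensive, 1)
        second = _memoize(("expensive", 1), expensive, 1)
    assert first == second == 2
    assert calls == [1]  # ``expensive`` ran only once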

def transpose_cache_wrap(transpose: Any) -> Any:
    """Decorates a ``transpose()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.
    """

    @functools.wraps(transpose)
    def cached_transpose(a, axes, backend="numpy"):
        if not currently_sharing():
            return transpose(a, axes, backend=backend)

        # hash by axes
        _save_tensors(a)
        axes = tuple(axes)
        key = "transpose", backend, id(a), axes
        return _memoize(key, transpose, a, axes, backend=backend)

    return cached_transpose
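
# Illustrative sketch, not part of the original module: wrapping a toy transpose
# implementation. Outside a sharing context the wrapper is a pass-through; inside,
# repeated calls with the same array and axes hit the cache. The names
# `_demo_transpose_cache_wrap` and `my_transpose` are made up for this example.
def _demo_transpose_cache_wrap() -> None:
    import numpy as np

    calls = []

    @transpose_cache_wrap
    def my_transpose(a, axes, backend="numpy"):
        calls.append(axes)
        return np.transpose(a, axes)

    a = np.ones((2, 3))
    with shared_intermediates():
        my_transpose(a, (1, 0))
        my_transpose(a, (1, 0))  # cached: the underlying implementation runs once
    assert calls == [(1, 0)]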

def tensordot_cache_wrap(tensordot: Any) -> Any:
    """Decorates a ``tensordot()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.
    """

    @functools.wraps(tensordot)
    def cached_tensordot(x, y, axes=2, backend="numpy"):
        if not currently_sharing():
            return tensordot(x, y, axes, backend=backend)

        # hash based on the (axes_x, axes_y) form of axes
        _save_tensors(x, y)
        if isinstance(axes, numbers.Number):
            axes = (
                list(range(len(x.shape)))[len(x.shape) - axes :],
                list(range(len(y.shape)))[:axes],
            )
        axes = tuple(axes[0]), tuple(axes[1])
        key = "tensordot", backend, id(x), id(y), axes
        return _memoize(key, tensordot, x, y, axes, backend=backend)

    return cached_tensordot
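
# Illustrative sketch, not part of the original module: how an integer ``axes``
# argument is normalized to the explicit ``(axes_x, axes_y)`` form used in the
# cache key, so equivalent spellings share one cache entry. NumPy ``tensordot``
# semantics are assumed; the name `_demo_tensordot_axes` is made up.
def _demo_tensordot_axes() -> None:
    x_ndim, y_ndim, axes = 3, 3, 2
    normalized = (
        list(range(x_ndim))[x_ndim - axes:],  # last ``axes`` dims of x -> [1, 2]
        list(range(y_ndim))[:axes],           # first ``axes`` dims of y -> [0, 1]
    )
    assert (tuple(normalized[0]), tuple(normalized[1])) == ((1, 2), (0, 1))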

def einsum_cache_wrap(einsum: Any) -> Any:
    """Decorates an ``einsum()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.
    """

    @functools.wraps(einsum)
    def cached_einsum(*args, **kwargs):
        if not currently_sharing():
            return einsum(*args, **kwargs)

        # hash modulo commutativity by computing a canonical ordering and names
        backend = kwargs.pop("backend", "numpy")
        equation = args[0]
        inputs, output, operands = parse_einsum_input(args)
        inputs = inputs.split(",")

        _save_tensors(*operands)

        # Build canonical key
        canonical = sorted(zip(inputs, map(id, operands)), key=lambda x: x[1])
        canonical_ids = tuple(id_ for _, id_ in canonical)
        canonical_inputs = ",".join(input_ for input_, _ in canonical)
        canonical_equation = alpha_canonicalize(canonical_inputs + "->" + output)

        key = "einsum", backend, canonical_equation, canonical_ids
        return _memoize(key, einsum, equation, *operands, backend=backend)

    return cached_einsum
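
# Illustrative sketch, not part of the original module: the canonicalization in
# ``cached_einsum`` makes the cache key invariant under operand reordering, so
# ``einsum("ab,bc->ac", x, y)`` and ``einsum("bc,ab->ac", y, x)`` share one entry.
# The names `_demo_einsum_canonical_key` and `canonical_key` are made up; NumPy
# operands are assumed.
def _demo_einsum_canonical_key() -> None:
    import numpy as np

    x = np.ones((2, 3))
    y = np.ones((3, 4))

    def canonical_key(*args):
        # Mirrors the key construction in ``cached_einsum`` above.
        inputs, output, operands = parse_einsum_input(args)
        canonical = sorted(zip(inputs.split(","), map(id, operands)), key=lambda t: t[1])
        canonical_ids = tuple(id_ for _, id_ in canonical)
        canonical_inputs = ",".join(term for term, _ in canonical)
        return alpha_canonicalize(canonical_inputs + "->" + output), canonical_ids

    # Swapping the operands (and their subscripts) yields the same cache key.
    assert canonical_key("ab,bc->ac", x, y) == canonical_key("bc,ab->ac", y, x)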

def to_backend_cache_wrap(to_backend: Any = None, constants: Any = False) -> Any:
    """Decorates a ``to_backend()`` implementation to be memoized inside a
    :func:`shared_intermediates` context (e.g. ``to_cupy``, ``to_torch``).
    """
    # manage the case that the decorator is called with arguments
    if to_backend is None:
        return functools.partial(to_backend_cache_wrap, constants=constants)

    if constants:

        @functools.wraps(to_backend)
        def cached_to_backend(array, constant=False):
            if not currently_sharing():
                return to_backend(array, constant=constant)

            # hash by id
            key = to_backend.__name__, id(array), constant
            return _memoize(key, to_backend, array, constant=constant)

    else:

        @functools.wraps(to_backend)
        def cached_to_backend(array):
            if not currently_sharing():
                return to_backend(array)

            # hash by id
            key = to_backend.__name__, id(array)
            return _memoize(key, to_backend, array)

    return cached_to_backend
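

# Illustrative sketch, not part of the original module: wrapping a toy backend
# conversion with the bare (``constants=False``) form of the decorator. Inside a
# sharing context the conversion runs once per array id. The names
# `_demo_to_backend_cache_wrap` and `to_mybackend` are made up for this example.
def _demo_to_backend_cache_wrap() -> None:
    import numpy as np

    conversions = []

    @to_backend_cache_wrap
    def to_mybackend(array):
        conversions.append(id(array))
        return np.asarray(array)

    a = np.ones(3)
    with shared_intermediates():
        to_mybackend(a)
        to_mybackend(a)  # same id -> the conversion is looked up in the cache
    assert len(conversions) == 1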