1"""
2A module for sharing intermediates between contractions.
4Copyright (c) 2018 Uber Technologies
5"""

import contextlib
import functools
import numbers
import threading
from collections import Counter, defaultdict

from .parser import alpha_canonicalize, parse_einsum_input

__all__ = [
    "currently_sharing", "get_sharing_cache", "shared_intermediates", "count_cached_ops", "transpose_cache_wrap",
    "einsum_cache_wrap", "to_backend_cache_wrap"
]

# map from a thread id to the stack of sharing caches active on that thread
_SHARING_STACK = defaultdict(list)


def currently_sharing():
    """Check if we are currently sharing a cache -- thread specific.
    """
    return threading.get_ident() in _SHARING_STACK


def get_sharing_cache():
    """Return the most recent sharing cache -- thread specific.
    """
    return _SHARING_STACK[threading.get_ident()][-1]


def _add_sharing_cache(cache):
    _SHARING_STACK[threading.get_ident()].append(cache)


def _remove_sharing_cache():
    tid = threading.get_ident()
    _SHARING_STACK[tid].pop()
    if not _SHARING_STACK[tid]:
        del _SHARING_STACK[tid]


@contextlib.contextmanager
def shared_intermediates(cache=None):
    """Context in which contract intermediate results are shared.

    Note that intermediate computations will not be garbage collected until
    1. this context exits, and
    2. the yielded cache is garbage collected (if it was captured).

    Parameters
    ----------
    cache : dict
        If specified, a user-stored dict in which intermediate results will
        be stored. This can be used to interleave sharing contexts.

    Returns
    -------
    cache : dict
        A dictionary in which sharing results are stored. If ignored,
        sharing results will be garbage collected when this context is
        exited. This dict can be passed to another context to resume
        sharing.
    """
    if cache is None:
        cache = {}
    _add_sharing_cache(cache)
    try:
        yield cache
    finally:
        _remove_sharing_cache()
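
# Illustrative usage sketch (not part of this module): the yielded cache can be
# captured and passed to a later context to resume sharing. ``contract`` stands
# in for ``opt_einsum.contract`` and ``x``, ``y``, ``z`` are assumed array
# operands.
#
#     with shared_intermediates() as cache:
#         contract('ab,bc->ac', x, y)
#
#     with shared_intermediates(cache):       # resume sharing
#         contract('ab,bc,cd->ad', x, y, z)   # the 'ab,bc' intermediate can be reused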


def count_cached_ops(cache):
    """Returns a counter of the types of each op in the cache.
    This is useful for profiling to increase sharing.
    """
    return Counter(key[0] for key in cache.keys())
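
# Illustrative profile of a populated cache (keys start with the op name, so the
# counter tallies entries such as 'tensor', 'einsum', 'tensordot', 'transpose'):
#
#     count_cached_ops(cache)
#     # e.g. Counter({'tensor': 3, 'einsum': 2})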


def _save_tensors(*tensors):
    """Save tensors in the cache to prevent their ids from being recycled.
    This is needed to prevent false cache lookups.
    """
    cache = get_sharing_cache()
    for tensor in tensors:
        cache['tensor', id(tensor)] = tensor


def _memoize(key, fn, *args, **kwargs):
    """Memoize ``fn(*args, **kwargs)`` using the given ``key``.
    Results will be stored in the innermost ``cache`` yielded by
    :func:`shared_intermediates`.
    """
    cache = get_sharing_cache()
    if key in cache:
        return cache[key]
    result = fn(*args, **kwargs)
    cache[key] = result
    return result
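
# Minimal sketch of the pattern used by the wrappers below (``my_op`` and
# ``arr`` are hypothetical): keys embed ``id()`` of the operands, which is why
# ``_save_tensors`` pins them in the cache so those ids cannot be recycled.
#
#     with shared_intermediates():
#         _save_tensors(arr)
#         key = 'my_op', 'numpy', id(arr)
#         out = _memoize(key, my_op, arr)   # computed once
#         out = _memoize(key, my_op, arr)   # served from the cache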


def transpose_cache_wrap(transpose):
    """Decorates a ``transpose()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.
    """
    @functools.wraps(transpose)
    def cached_transpose(a, axes, backend='numpy'):
        if not currently_sharing():
            return transpose(a, axes, backend=backend)

        # hash by axes
        _save_tensors(a)
        axes = tuple(axes)
        key = 'transpose', backend, id(a), axes
        return _memoize(key, transpose, a, axes, backend=backend)

    return cached_transpose


def tensordot_cache_wrap(tensordot):
    """Decorates a ``tensordot()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.
    """
    @functools.wraps(tensordot)
    def cached_tensordot(x, y, axes=2, backend='numpy'):
        if not currently_sharing():
            return tensordot(x, y, axes, backend=backend)

        # hash based on the (axes_x, axes_y) form of axes
        _save_tensors(x, y)
        if isinstance(axes, numbers.Number):
            axes = list(range(len(x.shape)))[len(x.shape) - axes:], list(range(len(y.shape)))[:axes]
        axes = tuple(axes[0]), tuple(axes[1])
        key = 'tensordot', backend, id(x), id(y), axes
        return _memoize(key, tensordot, x, y, axes, backend=backend)

    return cached_tensordot
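
# Worked example of the axes canonicalization above (illustrative): for 3-d
# operands ``x`` and ``y`` with the integer shorthand ``axes=2``,
#
#     axes = [1, 2], [0, 1]    # last 2 axes of x paired with first 2 axes of y
#     axes = (1, 2), (0, 1)    # made hashable for the cache key
#
# so the integer form and the equivalent explicit form produce the same key.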


def einsum_cache_wrap(einsum):
    """Decorates an ``einsum()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.
    """
    @functools.wraps(einsum)
    def cached_einsum(*args, **kwargs):
        if not currently_sharing():
            return einsum(*args, **kwargs)

        # hash modulo commutativity by computing a canonical ordering and names
        backend = kwargs.pop('backend', 'numpy')
        equation = args[0]
        inputs, output, operands = parse_einsum_input(args)
        inputs = inputs.split(',')

        _save_tensors(*operands)

        # build canonical key
        canonical = sorted(zip(inputs, map(id, operands)), key=lambda x: x[1])
        canonical_ids = tuple(id_ for _, id_ in canonical)
        canonical_inputs = ','.join(input_ for input_, _ in canonical)
        canonical_equation = alpha_canonicalize(canonical_inputs + "->" + output)

        key = 'einsum', backend, canonical_equation, canonical_ids
        return _memoize(key, einsum, equation, *operands, backend=backend)

    return cached_einsum
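
# Commutativity sketch (illustrative, assuming array operands ``x`` and ``y``):
#
#     einsum('ij,jk->ik', x, y)
#     einsum('jk,ij->ik', y, x)
#
# Both calls sort their inputs by operand id and alpha-canonicalize the
# resulting equation, so inside a sharing context they hit the same cache entry.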


def to_backend_cache_wrap(to_backend=None, constants=False):
    """Decorates a ``to_backend()`` implementation to be memoized inside a
    :func:`shared_intermediates` context (e.g. ``to_cupy``, ``to_torch``).
    """
    # manage the case where the decorator is called with arguments
    if to_backend is None:
        return functools.partial(to_backend_cache_wrap, constants=constants)

    if constants:

        @functools.wraps(to_backend)
        def cached_to_backend(array, constant=False):
            if not currently_sharing():
                return to_backend(array, constant=constant)

            # hash by id
            key = to_backend.__name__, id(array), constant
            return _memoize(key, to_backend, array, constant=constant)

    else:

        @functools.wraps(to_backend)
        def cached_to_backend(array):
            if not currently_sharing():
                return to_backend(array)

            # hash by id
            key = to_backend.__name__, id(array)
            return _memoize(key, to_backend, array)

    return cached_to_backend
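
# Usage sketch (illustrative; ``to_mybackend`` and ``mybackend`` are hypothetical
# stand-ins for real conversions such as ``to_cupy`` or ``to_torch``):
#
#     @to_backend_cache_wrap
#     def to_mybackend(array):
#         return mybackend.asarray(array)
#
# Inside ``shared_intermediates()``, repeated conversions of the same array
# (matched by ``id``) return the cached backend tensor instead of converting again.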