Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/sharing.py: 43%
102 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
1"""
2A module for sharing intermediates between contractions.
4Copyright (c) 2018 Uber Technologies
5"""
7import contextlib
8import functools
9import numbers
10import threading
11from collections import Counter, defaultdict
12from typing import Any
13from typing import Counter as CounterType
14from typing import Dict, Generator, List, Optional, Tuple, Union
16from .parser import alpha_canonicalize, parse_einsum_input
17from .typing import ArrayType
# Every cache key is a tuple whose first element names the cached operation
# (this is what ``count_cached_ops`` tallies). The alternatives below mirror
# the keys built by the helpers and wrappers in this module.
CacheKeyType = Union[
    # ("tensor", id(tensor)) from _save_tensors, or
    # (to_backend.__name__, id(array)) from to_backend_cache_wrap
    Tuple[str, int],
    # (to_backend.__name__, id(array), constant) from to_backend_cache_wrap
    Tuple[str, int, Any],
    # ("transpose", backend, id(a), axes)
    Tuple[str, str, int, Tuple[int, ...]],
    # ("tensordot", backend, id(x), id(y), (axes_x, axes_y))
    Tuple[str, str, int, int, Tuple[Tuple[int, ...], Tuple[int, ...]]],
    # ("einsum", backend, canonical_equation, canonical_ids)
    Tuple[str, str, str, Tuple[int, ...]],
]
CacheType = Dict[CacheKeyType, ArrayType]

__all__ = [
    "currently_sharing",
    "get_sharing_cache",
    "shared_intermediates",
    "count_cached_ops",
    "transpose_cache_wrap",
    "tensordot_cache_wrap",
    "einsum_cache_wrap",
    "to_backend_cache_wrap",
]

# Per-thread stacks of active sharing caches, keyed by threading.get_ident().
# The innermost (most recently entered) cache is the last list element.
_SHARING_STACK: Dict[int, List[CacheType]] = defaultdict(list)
def currently_sharing() -> bool:
    """Check if we are currently sharing a cache -- thread specific.

    Returns ``True`` only when the calling thread has at least one active
    cache pushed by :func:`shared_intermediates`.
    """
    # Test the stack's contents rather than key membership: _SHARING_STACK is a
    # defaultdict, so any plain ``_SHARING_STACK[tid]`` lookup can leave an
    # empty list behind, and an empty stack must not count as "sharing".
    return bool(_SHARING_STACK.get(threading.get_ident()))
def get_sharing_cache() -> CacheType:
    """Return the most recent sharing cache -- thread specific.

    **Raises:**

    - **IndexError** - if the calling thread is not inside a
      :func:`shared_intermediates` context.
    """
    # Use .get() instead of indexing: indexing the defaultdict would insert an
    # empty stack for this thread, after which currently_sharing() could report
    # True even though no context is active.
    stack = _SHARING_STACK.get(threading.get_ident())
    if not stack:
        raise IndexError("get_sharing_cache() called outside of a shared_intermediates() context")
    return stack[-1]
def _add_sharing_cache(cache: CacheType) -> None:
    """Push ``cache`` onto the calling thread's sharing stack, making it the
    innermost cache returned by :func:`get_sharing_cache`."""
    _SHARING_STACK[threading.get_ident()].append(cache)
def _remove_sharing_cache() -> None:
    """Pop the innermost cache from the calling thread's sharing stack.

    Once the stack is empty, the thread's entry is dropped entirely so that
    ``_SHARING_STACK`` does not accumulate one key per thread ever seen.
    """
    thread_id = threading.get_ident()
    stack = _SHARING_STACK[thread_id]
    stack.pop()
    if len(stack) == 0:
        del _SHARING_STACK[thread_id]
@contextlib.contextmanager
def shared_intermediates(
    cache: Optional[CacheType] = None,
) -> Generator[CacheType, None, None]:
    """Context in which contract intermediate results are shared.

    Note that intermediate computations will not be garbage collected until
    1. this context exits, and
    2. the yielded cache is garbage collected (if it was captured).

    **Parameters:**

    - **cache** - *(dict)* If specified, a user-stored dict in which intermediate results will be stored. This can be used to interleave sharing contexts.

    **Returns:**

    - **cache** - *(dict)* A dictionary in which sharing results are stored. If ignored,
      sharing results will be garbage collected when this context is
      exited. This dict can be passed to another context to resume
      sharing.
    """
    active = {} if cache is None else cache
    _add_sharing_cache(active)
    try:
        yield active
    finally:
        # Always unwind the per-thread stack, even if the body raised.
        _remove_sharing_cache()
def count_cached_ops(cache: CacheType) -> CounterType[str]:
    """Return a counter of the op types stored in ``cache``.

    The op type is the first element of each cache key. This is useful for
    profiling to increase sharing.
    """
    op_names = (cache_key[0] for cache_key in cache)
    return Counter(op_names)
def _save_tensors(*tensors: ArrayType) -> None:
    """Pin ``tensors`` in the innermost sharing cache.

    The wrappers below key results by ``id(...)``. Keeping a strong reference
    here stops those ids from being recycled by new objects while the cache is
    alive, which would otherwise cause false cache lookups.
    """
    store = get_sharing_cache()
    for arr in tensors:
        store[("tensor", id(arr))] = arr
def _memoize(key: CacheKeyType, fn: Any, *args: Any, **kwargs: Any) -> ArrayType:
    """Memoize ``fn(*args, **kwargs)`` under ``key``.

    Results live in the innermost ``cache`` yielded by
    :func:`shared_intermediates`. A cache hit returns the stored value without
    calling ``fn``.
    """
    cache = get_sharing_cache()
    if key not in cache:
        value = fn(*args, **kwargs)
        cache[key] = value
        return value
    return cache[key]
def transpose_cache_wrap(transpose: Any) -> Any:
    """Decorate a ``transpose()`` implementation so that it is memoized inside
    a :func:`shared_intermediates` context.

    Outside such a context the wrapped function is called unchanged.
    """

    @functools.wraps(transpose)
    def cached_transpose(a, axes, backend="numpy"):
        if currently_sharing():
            # Results are keyed by the operand's id plus the axes permutation.
            _save_tensors(a)
            axes_tuple = tuple(axes)
            cache_key = ("transpose", backend, id(a), axes_tuple)
            return _memoize(cache_key, transpose, a, axes_tuple, backend=backend)
        return transpose(a, axes, backend=backend)

    return cached_transpose
def tensordot_cache_wrap(tensordot: Any) -> Any:
    """Decorate a ``tensordot()`` implementation so that it is memoized inside
    a :func:`shared_intermediates` context.

    Outside such a context the wrapped function is called unchanged.
    """

    @functools.wraps(tensordot)
    def cached_tensordot(x, y, axes=2, backend="numpy"):
        if not currently_sharing():
            return tensordot(x, y, axes, backend=backend)

        _save_tensors(x, y)
        # Normalise a scalar ``axes`` to the explicit (axes_x, axes_y) form, so
        # equivalent calls produce the same cache key: the last ``axes`` dims
        # of x contract with the first ``axes`` dims of y.
        if isinstance(axes, numbers.Number):
            ndim_x = len(x.shape)
            axes = (range(ndim_x)[ndim_x - axes :], range(len(y.shape))[:axes])
        first, second = axes[0], axes[1]
        axes = tuple(first), tuple(second)
        cache_key = ("tensordot", backend, id(x), id(y), axes)
        return _memoize(cache_key, tensordot, x, y, axes, backend=backend)

    return cached_tensordot
def einsum_cache_wrap(einsum: Any) -> Any:
    """Decorate an ``einsum()`` implementation so that it is memoized inside a
    :func:`shared_intermediates` context.

    The cache key is built from a canonical ordering of the operands (by id)
    and a canonicalized equation, so equivalent contractions can share one
    result. Calls that pass keyword arguments other than ``backend`` are not
    memoized. The memoized call would neither forward those arguments nor
    include them in its key, so they would otherwise be silently dropped.
    """

    @functools.wraps(einsum)
    def cached_einsum(*args, **kwargs):
        if not currently_sharing():
            return einsum(*args, **kwargs)

        # Only ``backend`` survives into the memoized call below. Any other
        # keyword would be lost, so compute such calls directly instead.
        if any(name != "backend" for name in kwargs):
            return einsum(*args, **kwargs)

        # hash modulo commutativity by computing a canonical ordering and names
        backend = kwargs.pop("backend", "numpy")
        equation = args[0]
        inputs, output, operands = parse_einsum_input(args)
        inputs = inputs.split(",")

        _save_tensors(*operands)

        # Build canonical key: pair each input subscript with its operand's id,
        # sort by id, then canonicalize the resulting equation.
        canonical = sorted(zip(inputs, map(id, operands)), key=lambda pair: pair[1])
        canonical_ids = tuple(id_ for _, id_ in canonical)
        canonical_inputs = ",".join(input_ for input_, _ in canonical)
        canonical_equation = alpha_canonicalize(canonical_inputs + "->" + output)

        key = "einsum", backend, canonical_equation, canonical_ids
        return _memoize(key, einsum, equation, *operands, backend=backend)

    return cached_einsum
def to_backend_cache_wrap(to_backend: Any = None, constants: Any = False) -> Any:
    """Decorate a ``to_backend()`` implementation so that it is memoized inside
    a :func:`shared_intermediates` context (e.g. ``to_cupy``, ``to_torch``).

    May be used bare (``@to_backend_cache_wrap``) or with arguments
    (``@to_backend_cache_wrap(constants=True)``). When ``constants`` is truthy,
    the wrapped function accepts a ``constant`` flag that becomes part of the
    cache key.
    """
    # manage the case that decorator is called with args
    if to_backend is None:
        return functools.partial(to_backend_cache_wrap, constants=constants)

    if constants:

        @functools.wraps(to_backend)
        def cached_to_backend(array, constant=False):
            if not currently_sharing():
                return to_backend(array, constant=constant)

            # hash by id -- pin the input first so its id cannot be recycled
            # by another object while the cache is alive (false cache hits).
            _save_tensors(array)
            key = to_backend.__name__, id(array), constant
            return _memoize(key, to_backend, array, constant=constant)

    else:

        @functools.wraps(to_backend)
        def cached_to_backend(array):
            if not currently_sharing():
                return to_backend(array)

            # hash by id -- pin the input first so its id cannot be recycled
            # by another object while the cache is alive (false cache hits).
            _save_tensors(array)
            key = to_backend.__name__, id(array)
            return _memoize(key, to_backend, array)

    return cached_to_backend