Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/helpers.py: 22%
77 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
1"""
2Contains helper functions for opt_einsum testing scripts
3"""
5from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Optional, Tuple, Union, overload
7import numpy as np
9from .parser import get_symbol
10from .typing import ArrayIndexType, PathType
# Public API of this helper module.
__all__ = ["build_views", "compute_size_by_dict", "find_contraction", "flop_count"]

# Index letters usable in test equations, each paired (positionally) with a
# fixed dimension size below so repeated test runs build consistent shapes.
_valid_chars = "abcdefghijklmopqABC"
_sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4])
# Default index -> size mapping used by ``build_views`` when no dictionary is given.
_default_dim_dict = {c: s for c, s in zip(_valid_chars, _sizes)}
def build_views(string: str, dimension_dict: Optional[Dict[str, int]] = None) -> List[np.ndarray]:
    """
    Build random numpy arrays for testing.

    Parameters
    ----------
    string : str
        Comma-separated einsum-style input terms (any '->' output part is
        ignored) describing the arrays to build.
    dimension_dict : dictionary
        Mapping from index letter to dimension size; defaults to the
        module-level ``_default_dim_dict``.

    Returns
    -------
    ret : list of np.ndarry's
        One random array per input term.

    Examples
    --------
    >>> view = build_views('abbc', {'a': 2, 'b':3, 'c':5})
    >>> view[0].shape
    (2, 3, 3, 5)

    """
    sizes = _default_dim_dict if dimension_dict is None else dimension_dict

    # Only the left-hand side of the equation describes operands.
    input_terms = string.split("->")[0].split(",")
    return [np.random.rand(*(sizes[index] for index in term)) for term in input_terms]
@overload
def compute_size_by_dict(indices: Iterable[int], idx_dict: List[int]) -> int:
    ...


@overload
def compute_size_by_dict(indices: Collection[str], idx_dict: Dict[str, int]) -> int:
    ...


def compute_size_by_dict(indices: Any, idx_dict: Any) -> int:
    """
    Multiply together the sizes that ``idx_dict`` associates with each
    element of ``indices``.

    Parameters
    ----------
    indices : iterable
        Indices whose sizes should be multiplied.
    idx_dict : dictionary
        Dictionary of index _sizes

    Returns
    -------
    ret : int
        The product of the sizes (1 when ``indices`` is empty).

    Examples
    --------
    >>> compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
    90

    """
    product = 1
    for index in indices:  # lgtm [py/iteration-string-and-sequence]
        product *= idx_dict[index]
    return product
def find_contraction(
    positions: Collection[int],
    input_sets: List[ArrayIndexType],
    output_set: ArrayIndexType,
) -> Tuple[FrozenSet[str], List[ArrayIndexType], ArrayIndexType, ArrayIndexType]:
    """
    Finds the contraction for a given set of input and output sets.

    Parameters
    ----------
    positions : iterable
        Integer positions of terms used in the contraction.
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript

    Returns
    -------
    new_result : set
        The indices of the resulting contraction
    remaining : list
        List of sets that have not been contracted, the new set is appended to
        the end of this list
    idx_removed : set
        Indices removed from the entire contraction
    idx_contraction : set
        The indices used in the current contraction

    Examples
    --------

    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """

    remaining = list(input_sets)
    # Pop in descending order so earlier positions stay valid while popping.
    inputs = (remaining.pop(i) for i in sorted(positions, reverse=True))
    # Use an empty-frozenset instance so .union accepts any iterable of
    # indices.  The previous ``frozenset.union(*inputs)`` invoked the unbound
    # descriptor with the first popped set as ``self`` and raised TypeError
    # whenever ``input_sets`` held plain ``set`` objects, as in the examples
    # above.
    idx_contract = frozenset().union(*inputs)
    idx_remain = output_set.union(*remaining)

    # Indices appearing elsewhere (or in the output) survive the contraction;
    # everything else is summed away.
    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)

    return new_result, remaining, idx_removed, idx_contract
def flop_count(
    idx_contraction: Collection[str],
    inner: bool,
    num_terms: int,
    size_dictionary: Dict[str, int],
) -> int:
    """
    Computes the number of FLOPS in the contraction.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product?
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    # Each element of the full contraction space costs one multiply per term
    # beyond the first, plus one extra operation when an inner (summation)
    # product is required.
    ops_per_element = max(1, num_terms - 1) + (1 if inner else 0)
    return compute_size_by_dict(idx_contraction, size_dictionary) * ops_per_element
def rand_equation(
    n: int,
    reg: int,
    n_out: int = 0,
    d_min: int = 2,
    d_max: int = 9,
    seed: Optional[int] = None,
    global_dim: bool = False,
    return_size_dict: bool = False,
) -> Union[Tuple[str, PathType, Dict[str, int]], Tuple[str, PathType]]:
    """Generate a random contraction and shapes.

    Parameters
    ----------
    n : int
        Number of array arguments.
    reg : int
        'Regularity' of the contraction graph. This essentially determines how
        many indices each tensor shares with others on average.
    n_out : int, optional
        Number of output indices (i.e. the number of non-contracted indices).
        Defaults to 0, i.e., a contraction resulting in a scalar.
    d_min : int, optional
        Minimum dimension size.
    d_max : int, optional
        Maximum dimension size.
    seed: int, optional
        If not None, seed numpy's random generator with this.
    global_dim : bool, optional
        Add a global, 'broadcast', dimension to every operand.
    return_size_dict : bool, optional
        Return the mapping of indices to sizes.

    Returns
    -------
    eq : str
        The equation string.
    shapes : list[tuple[int]]
        The array shapes.
    size_dict : dict[str, int]
        The dict of index sizes, only returned if ``return_size_dict=True``.

    Examples
    --------
    >>> eq, shapes = rand_equation(n=10, reg=4, n_out=5, seed=42)
    >>> eq
    'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda'

    >>> shapes
    [(9, 5, 4, 5, 4),
     (4, 4, 8, 5),
     (9, 4, 6, 9),
     (6, 6),
     (6, 9, 7, 8),
     (4,),
     (9, 3, 9, 4, 9),
     (6, 8, 4, 6, 8, 6, 3),
     (4, 7, 8, 8, 6, 9, 6),
     (9, 5, 3, 3, 9, 5)]
    """

    if seed is not None:
        np.random.seed(seed)

    # total number of indices
    # Each non-output index is placed twice (a bond), so n*reg//2 bonds plus
    # one index per output dimension.
    num_inds = n * reg // 2 + n_out
    inputs = ["" for _ in range(n)]
    output = []

    # Draw a random size in [d_min, d_max] for every index symbol.
    size_dict = {get_symbol(i): np.random.randint(d_min, d_max + 1) for i in range(num_inds)}

    # generate a list of indices to place either once or twice
    # NOTE(review): relies on dict preserving insertion order (Python 3.7+)
    # so the first n_out symbols become the output indices.
    def gen():
        for i, ix in enumerate(size_dict):
            # generate an outer index
            if i < n_out:
                output.append(ix)
                yield ix
            # generate a bond
            else:
                yield ix
                yield ix

    # add the indices randomly to the inputs
    for i, ix in enumerate(np.random.permutation(list(gen()))):
        # make sure all inputs have at least one index
        if i < n:
            inputs[i] += ix
        else:
            # don't add any traces on same op
            # NOTE(review): this rejection loop could spin indefinitely if ix
            # already appears in every operand (tiny n / large reg) — presumably
            # callers choose parameters where that cannot happen; verify.
            where = np.random.randint(0, n)
            while ix in inputs[where]:
                where = np.random.randint(0, n)

            inputs[where] += ix

    # possibly add the same global dim to every arg
    if global_dim:
        gdim = get_symbol(num_inds)
        size_dict[gdim] = np.random.randint(d_min, d_max + 1)
        for i in range(n):
            inputs[i] += gdim
        # ``output`` is a list, so += iterates gdim; assumes get_symbol
        # returns a single character — TODO confirm.
        output += gdim

    # randomly transpose the output indices and form equation
    output = "".join(np.random.permutation(output))  # type: ignore
    eq = "{}->{}".format(",".join(inputs), output)

    # make the shapes
    shapes = [tuple(size_dict[ix] for ix in op) for op in inputs]

    ret = (eq, shapes)

    if return_size_dict:
        return ret + (size_dict,)
    else:
        return ret