1"""
2Contains helper functions for opt_einsum testing scripts
3"""
5from collections import OrderedDict
7import numpy as np
9from .parser import get_symbol
11__all__ = ["build_views", "compute_size_by_dict", "find_contraction", "flop_count"]
13_valid_chars = "abcdefghijklmopqABC"
14_sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4])
15_default_dim_dict = {c: s for c, s in zip(_valid_chars, _sizes)}
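
# Note (added for clarity): only the indices listed in _valid_chars above get a
# default size, e.g. 'a' -> 2, 'b' -> 3 and 'c' -> 4; build_views below falls
# back to this mapping when no dimension_dict is passed.
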
def build_views(string, dimension_dict=None):
    """
    Builds random numpy arrays for testing.

    Parameters
    ----------
    string : str
        Einsum subscript string of the tensors to build
    dimension_dict : dictionary
        Dictionary of index _sizes

    Returns
    -------
    ret : list of np.ndarray
        The resulting views.

    Examples
    --------
    >>> view = build_views('abbc', {'a': 2, 'b':3, 'c':5})
    >>> view[0].shape
    (2, 3, 3, 5)

    """

    if dimension_dict is None:
        dimension_dict = _default_dim_dict

    views = []
    terms = string.split('->')[0].split(',')
    for term in terms:
        dims = [dimension_dict[x] for x in term]
        views.append(np.random.rand(*dims))
    return views

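
# Usage sketch (added for illustration, not part of the original module): a full
# einsum expression is accepted and everything after '->' is ignored, e.g. with
# the default sizes
#
#     a_view, b_view = build_views('ab,bc->ac')
#     a_view.shape  # (2, 3)
#     b_view.shape  # (3, 4)
#
# Indices missing from the dimension dictionary raise a KeyError.
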
def compute_size_by_dict(indices, idx_dict):
    """
    Computes the product of the elements in indices based on the dictionary
    idx_dict.

    Parameters
    ----------
    indices : iterable
        Indices to base the product on.
    idx_dict : dictionary
        Dictionary of index _sizes

    Returns
    -------
    ret : int
        The resulting product.

    Examples
    --------
    >>> compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
    90

    """
    ret = 1
    for i in indices:  # lgtm [py/iteration-string-and-sequence]
        ret *= idx_dict[i]
    return ret

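
# Worked example (added for clarity): repeated indices are counted once per
# occurrence, so compute_size_by_dict('abbc', {'a': 2, 'b': 3, 'c': 5}) gives
# 2 * 3 * 3 * 5 = 90, i.e. the number of elements in a view of shape (2, 3, 3, 5).
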
def find_contraction(positions, input_sets, output_set):
    """
    Finds the contraction for a given set of input and output sets.

    Parameters
    ----------
    positions : iterable
        Integer positions of terms used in the contraction.
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript

    Returns
    -------
    new_result : set
        The indices of the resulting contraction
    remaining : list
        List of sets that have not been contracted, the new set is appended to
        the end of this list
    idx_removed : set
        Indices removed from the entire contraction
    idx_contraction : set
        The indices used in the current contraction

    Examples
    --------

    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """

    remaining = list(input_sets)
    inputs = (remaining.pop(i) for i in sorted(positions, reverse=True))
    idx_contract = set.union(*inputs)
    idx_remain = output_set.union(*remaining)

    new_result = idx_remain & idx_contract
    idx_removed = (idx_contract - new_result)
    remaining.append(new_result)

    return new_result, remaining, idx_removed, idx_contract

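
# Step-by-step reading of the dot-product doctest above (added for clarity):
# with positions (0, 1), input_sets [{'a', 'b'}, {'b', 'c'}] and output_set
# {'a', 'c'}, the contracted indices are {'a', 'b', 'c'}; the indices still
# needed afterwards are {'a', 'c'} (no other term references 'b'), so 'b' is the
# summed-over index and the new term {'a', 'c'} is appended to `remaining`.
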
def flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Computes the number of FLOPS in the contraction.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product?
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """

    overall_size = compute_size_by_dict(idx_contraction, size_dictionary)
    op_factor = max(1, num_terms - 1)
    if inner:
        op_factor += 1

    return overall_size * op_factor

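
# Worked example (added for clarity): for idx_contraction 'abc' with sizes
# {'a': 2, 'b': 3, 'c': 5}, overall_size is 2 * 3 * 5 = 30.  A single-term,
# non-inner contraction therefore costs 30 FLOPs, while a two-term inner
# product costs 30 * (max(1, 2 - 1) + 1) = 60 FLOPs, matching the doctests.
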
def rand_equation(n, reg, n_out=0, d_min=2, d_max=9, seed=None, global_dim=False, return_size_dict=False):
    """Generate a random contraction and shapes.

    Parameters
    ----------
    n : int
        Number of array arguments.
    reg : int
        'Regularity' of the contraction graph. This essentially determines how
        many indices each tensor shares with others on average.
    n_out : int, optional
        Number of output indices (i.e. the number of non-contracted indices).
        Defaults to 0, i.e., a contraction resulting in a scalar.
    d_min : int, optional
        Minimum dimension size.
    d_max : int, optional
        Maximum dimension size.
    seed : int, optional
        If not None, seed numpy's random generator with this.
    global_dim : bool, optional
        Add a global, 'broadcast', dimension to every operand.
    return_size_dict : bool, optional
        Return the mapping of indices to sizes.

    Returns
    -------
    eq : str
        The equation string.
    shapes : list[tuple[int]]
        The array shapes.
    size_dict : dict[str, int]
        The dict of index sizes, only returned if ``return_size_dict=True``.

    Examples
    --------
    >>> eq, shapes = rand_equation(n=10, reg=4, n_out=5, seed=42)
    >>> eq
    'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda'

    >>> shapes
    [(9, 5, 4, 5, 4),
     (4, 4, 8, 5),
     (9, 4, 6, 9),
     (6, 6),
     (6, 9, 7, 8),
     (4,),
     (9, 3, 9, 4, 9),
     (6, 8, 4, 6, 8, 6, 3),
     (4, 7, 8, 8, 6, 9, 6),
     (9, 5, 3, 3, 9, 5)]
    """

    if seed is not None:
        np.random.seed(seed)

    # total number of indices
    num_inds = n * reg // 2 + n_out
    inputs = ["" for _ in range(n)]
    output = []

    size_dict = OrderedDict((get_symbol(i), np.random.randint(d_min, d_max + 1)) for i in range(num_inds))

    # generate a list of indices to place either once or twice
    def gen():
        for i, ix in enumerate(size_dict):
            # generate an outer index
            if i < n_out:
                output.append(ix)
                yield ix
            # generate a bond
            else:
                yield ix
                yield ix

    # add the indices randomly to the inputs
    for i, ix in enumerate(np.random.permutation(list(gen()))):
        # make sure all inputs have at least one index
        if i < n:
            inputs[i] += ix
        else:
            # don't add any traces on same op
            where = np.random.randint(0, n)
            while ix in inputs[where]:
                where = np.random.randint(0, n)

            inputs[where] += ix

    # possibly add the same global dim to every arg
    if global_dim:
        gdim = get_symbol(num_inds)
        size_dict[gdim] = np.random.randint(d_min, d_max + 1)
        for i in range(n):
            inputs[i] += gdim
        output += gdim

    # randomly transpose the output indices and form equation
    output = "".join(np.random.permutation(output))
    eq = "{}->{}".format(",".join(inputs), output)

    # make the shapes
    shapes = [tuple(size_dict[ix] for ix in op) for op in inputs]

    ret = (eq, shapes)

    if return_size_dict:
        ret += (size_dict, )

    return ret
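

# Minimal smoke-test sketch (added for illustration; not part of the original
# module).  It uses only the helpers defined above plus numpy, and assumes the
# generated equation stays within np.einsum's plain-letter index range, which
# holds for the small sizes chosen here.
if __name__ == "__main__":
    eq, shapes, size_dict = rand_equation(n=4, reg=2, n_out=2, seed=0, return_size_dict=True)

    # Build random operands directly from the returned shapes (build_views only
    # knows the indices in _valid_chars, so it is not used here).
    operands = [np.random.rand(*shape) for shape in shapes]
    result = np.einsum(eq, *operands)

    # The output shape should follow the right-hand side of the equation.
    out_inds = eq.split("->")[1]
    assert result.shape == tuple(size_dict[ix] for ix in out_inds)

    # Naive cost estimate of contracting all terms in a single step.
    cost = flop_count(size_dict.keys(), True, len(shapes), size_dict)
    print(eq, shapes, cost)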