Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/parser.py: 15%
124 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1#!/usr/bin/env python
2# coding: utf-8
3"""
4A functionally equivalent parser of the numpy.einsum input parser
5"""
7import itertools
8from collections import OrderedDict
10import numpy as np
# Public API of this module. The interleaved-input helpers
# ``convert_subscripts`` and ``convert_interleaved_input`` are not listed.
__all__ = [
    "is_valid_einsum_char",
    "has_valid_einsum_chars_only",
    "get_symbol",
    "gen_unused_symbols",
    "convert_to_valid_einsum_chars",
    "alpha_canonicalize",
    "find_output_str",
    "find_output_shape",
    "possibly_convert_to_numpy",
    "parse_einsum_input",
]

# The 52 ASCII letters usable as einsum subscripts, lowercase first; symbol
# indices 0-51 map into this string.
_einsum_symbols_base = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
def is_valid_einsum_char(x):
    """Check if the character ``x`` is valid for numpy einsum.

    The accepted characters are the 52 ASCII letters plus the punctuation
    ``,``, ``-``, ``>`` and ``.``. Membership is a substring test, so ``x``
    is expected to be a single character; an empty or multi-character string
    may also be reported as valid.

    Examples
    --------
    >>> is_valid_einsum_char("a")
    True

    >>> is_valid_einsum_char("Ǵ")
    False
    """
    if x in _einsum_symbols_base:
        return True
    return x in ',->.'
def has_valid_einsum_chars_only(einsum_str):
    """Check if ``einsum_str`` contains only valid characters for numpy einsum.

    Each character is tested with :func:`is_valid_einsum_char`, stopping at
    the first invalid one; an empty string is vacuously valid.

    Examples
    --------
    >>> has_valid_einsum_chars_only("abAZ")
    True

    >>> has_valid_einsum_chars_only("Över")
    False
    """
    return all(is_valid_einsum_char(char) for char in einsum_str)
def get_symbol(i):
    """Get the symbol corresponding to int ``i`` - runs through the usual 52
    letters before resorting to unicode characters, starting at ``chr(192)``
    and skipping the UTF-16 surrogate block ``chr(0xD800)``-``chr(0xDFFF)``.

    Surrogate code points are not valid characters on their own and cannot
    be encoded (e.g. ``str.encode('utf-8')`` raises), so they are never
    returned. Every ``i`` still maps to a distinct symbol, and every index
    whose symbol lies below the surrogate block is unchanged.

    Examples
    --------
    >>> get_symbol(2)
    'c'

    >>> get_symbol(200)
    'Ŕ'

    >>> get_symbol(20000)
    '京'

    Raises
    ------
    ValueError
        If ``i`` is so large that the resulting code point exceeds
        ``0x10FFFF`` (raised by ``chr``).
    """
    if i < 52:
        return _einsum_symbols_base[i]
    code_point = i + 140
    # Jump over all 2048 surrogates (0xD800-0xDFFF) so the mapping stays
    # injective and every returned symbol is encodable.
    if code_point >= 0xD800:
        code_point += 0x800
    return chr(code_point)
def gen_unused_symbols(used, n):
    """Generate ``n`` symbols that are not already in ``used``.

    Candidates are drawn lazily in :func:`get_symbol` order (``get_symbol(0)``,
    ``get_symbol(1)``, ...), skipping any found in ``used``.

    Examples
    --------
    >>> list(oe.parser.gen_unused_symbols("abd", 2))
    ['c', 'e']
    """
    produced = 0
    candidate = 0
    while produced < n:
        symbol = get_symbol(candidate)
        candidate += 1
        if symbol not in used:
            yield symbol
            produced += 1
def convert_to_valid_einsum_chars(einsum_str):
    """Convert the str ``einsum_str`` to contain only the alphabetic characters
    valid for numpy einsum. If there are too many symbols, let the backend
    throw an error.

    Every distinct character other than ``,``, ``-`` and ``>`` is renamed to
    ``get_symbol(k)``, where ``k`` is the character's rank in sorted (code
    point) order. ``.`` is not exempt, so an ellipsis would be renamed too.

    Examples
    --------
    >>> oe.parser.convert_to_valid_einsum_chars("Ĥěļļö")
    'cbdda'
    """
    distinct = set(einsum_str).difference(',->')
    mapping = {}
    for rank, char in enumerate(sorted(distinct)):
        mapping[char] = get_symbol(rank)
    return "".join(mapping.get(char, char) for char in einsum_str)
def alpha_canonicalize(equation):
    """Alpha convert an equation in an order-independent canonical way.

    Symbols are renamed to ``get_symbol(0)``, ``get_symbol(1)``, ... in the
    order they first appear; the characters ``.``, ``,``, ``-`` and ``>`` are
    left untouched.

    Examples
    --------
    >>> oe.parser.alpha_canonicalize("dcba")
    'abcd'

    >>> oe.parser.alpha_canonicalize("Ĥěļļö")
    'abccd'
    """
    rename = {}
    for char in equation:
        if char not in '.,->' and char not in rename:
            # the next fresh symbol is indexed by how many have been assigned
            rename[char] = get_symbol(len(rename))
    return ''.join(rename.get(char, char) for char in equation)
def find_output_str(subscripts):
    """
    Find the output string for the inputs ``subscripts`` under canonical
    einstein summation rules. That is, repeated indices are summed over by
    default.

    The output consists of every character that occurs exactly once across
    all comma-separated terms, in sorted (code point) order. The characters
    are counted in a single pass, so this stays linear in the length of
    ``subscripts`` even when there are many distinct symbols.

    Examples
    --------
    >>> oe.parser.find_output_str("ab,bc")
    'ac'

    >>> oe.parser.find_output_str("a,b")
    'ab'

    >>> oe.parser.find_output_str("a,a,b,b")
    ''
    """
    from collections import Counter  # the module header only imports OrderedDict

    counts = Counter(subscripts.replace(",", ""))
    return "".join(sorted(char for char, count in counts.items() if count == 1))
def find_output_shape(inputs, shapes, output):
    """Find the output shape for given inputs, shapes and output string, taking
    into account broadcasting.

    For each output character, take the size at its first position in every
    input term that contains it, and use the largest of those sizes. A size-1
    dimension therefore broadcasts against a larger one.

    Examples
    --------
    >>> oe.parser.find_output_shape(["ab", "bc"], [(2, 3), (3, 4)], "ac")
    (2, 4)

    # Broadcasting is accounted for
    >>> oe.parser.find_output_shape(["a", "a"], [(4, ), (1, )], "a")
    (4,)
    """
    dims = []
    for char in output:
        positions = [term.find(char) for term in inputs]
        dims.append(max(shape[pos] for shape, pos in zip(shapes, positions) if pos >= 0))
    return tuple(dims)
def possibly_convert_to_numpy(x):
    """Convert things without a 'shape' to ndarrays, but leave everything else.

    Any object exposing a ``shape`` attribute is returned unchanged (the same
    object); anything else goes through ``np.asanyarray``.

    Examples
    --------
    >>> oe.parser.possibly_convert_to_numpy(5)
    array(5)

    >>> oe.parser.possibly_convert_to_numpy([5, 3])
    array([5, 3])

    >>> oe.parser.possibly_convert_to_numpy(np.array([5, 3]))
    array([5, 3])

    # Any class with a shape is passed through
    >>> class Shape:
    ...     def __init__(self, shape):
    ...         self.shape = shape
    ...

    >>> myshape = Shape((5, 5))
    >>> oe.parser.possibly_convert_to_numpy(myshape) is myshape
    True
    """
    if hasattr(x, 'shape'):
        return x
    return np.asanyarray(x)
def convert_subscripts(old_sub, symbol_map):
    """Convert user custom subscripts list to subscript string according to `symbol_map`.

    ``Ellipsis`` entries become ``"..."``; every other entry is replaced by
    ``symbol_map[entry]``. The pieces are joined once rather than built with
    repeated string concatenation.

    Raises
    ------
    KeyError
        If a non-Ellipsis entry of ``old_sub`` is missing from ``symbol_map``.

    Examples
    --------
    >>> oe.parser.convert_subscripts(['abc', 'def'], {'abc':'a', 'def':'b'})
    'ab'
    >>> oe.parser.convert_subscripts([Ellipsis, object], {object:'a'})
    '...a'
    """
    return "".join("..." if s is Ellipsis else symbol_map[s] for s in old_sub)
def convert_interleaved_input(operands):
    """Convert 'interleaved' input to standard einsum input.

    The interleaved form is ``(op0, sub0, op1, sub1, ..., [out])``: operands
    alternate with subscript lists, optionally followed by one unpaired
    output subscript list. Subscript lists may contain ``Ellipsis`` and
    hashable, mutually comparable objects (e.g. int, str). These are renamed
    to single-character symbols via ``get_symbol``, in sorted order.

    Parameters
    ----------
    operands : sequence
        The interleaved arguments.

    Returns
    -------
    subscripts : str
        The equivalent einsum subscript string, e.g. ``"ab,bc->ac"``.
    operands : list
        The array operands, each passed through ``possibly_convert_to_numpy``.

    Raises
    ------
    TypeError
        If a subscript entry (other than ``Ellipsis``) is unhashable or cannot
        be ordered against the other entries.
    """
    tmp_operands = list(operands)
    n_pairs = len(tmp_operands) // 2
    # Positions 0, 2, 4, ... (below 2 * n_pairs) hold arrays and positions
    # 1, 3, 5, ... hold their subscript lists; an unpaired trailing element
    # is the output subscript list.
    operand_list = tmp_operands[0:2 * n_pairs:2]
    subscript_list = tmp_operands[1:2 * n_pairs:2]
    output_list = tmp_operands[-1] if len(tmp_operands) % 2 else None
    operands = [possibly_convert_to_numpy(x) for x in operand_list]

    # Map the output symbols as well. Otherwise a symbol that appears only in
    # the output raises a bare KeyError in convert_subscripts. Once mapped, it
    # becomes an ordinary output character that parse_einsum_input rejects
    # with a ValueError. Output symbols that also appear in the inputs do not
    # change the set, so the mapping is the same as before.
    all_sublists = subscript_list if output_list is None else subscript_list + [output_list]

    # build a map from user symbols to single-character symbols based on `get_symbol`
    try:
        # collect all user symbols
        symbol_set = set(itertools.chain.from_iterable(all_sublists))

        # remove Ellipsis because it can not be compared with other objects
        symbol_set.discard(Ellipsis)

        # sorting makes the mapping independent of set iteration order
        symbol_map = {symbol: get_symbol(idx) for idx, symbol in enumerate(sorted(symbol_set))}

    except TypeError as err:  # unhashable or uncomparable object
        raise TypeError("For this input type lists must contain either Ellipsis "
                        "or hashable and comparable object (e.g. int, str).") from err

    subscripts = ','.join(convert_subscripts(sub, symbol_map) for sub in subscript_list)
    if output_list is not None:
        subscripts += "->" + convert_subscripts(output_list, symbol_map)

    return subscripts, operands
def parse_einsum_input(operands):
    """
    A reproduction of numpy's C-side einsum input parsing in Python.

    Accepts either ``(subscripts, *arrays)`` with ``subscripts`` a string, or
    the interleaved form ``(array0, sublist0, array1, sublist1, ..., [out])``.
    Ellipses are expanded into explicit unused symbols, and an output string
    is inferred when none is given.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        If no operands are given, the subscripts contain a malformed ``->`` or
        ellipsis, an ellipsis term has more explicit indices than its operand
        has dimensions, an output character is absent from the inputs, or the
        number of subscript terms differs from the number of operands.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> parse_einsum_input(('...a,...a->...', a, b))
    ('da,cda', 'cd', [a, b])

    >>> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('da,cda', 'cd', [a, b])
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [possibly_convert_to_numpy(x) for x in operands[1:]]

    else:
        subscripts, operands = convert_interleaved_input(operands)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses; the result always contains an explicit "->"
    if "." in subscripts:
        subscripts = _expand_ellipses(subscripts, operands)

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts, output_subscript = subscripts, find_output_str(subscripts)

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character '{}' did not appear in the input".format(char))

    # Make sure number operands is equivalent to the number of terms
    _check_operand_count(input_subscripts, operands)

    return input_subscripts, output_subscript, operands


def _check_operand_count(input_subscripts, operands):
    """Raise ValueError unless ``input_subscripts`` has one comma-separated term per operand."""
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the " "number of operands.")


def _expand_ellipses(subscripts, operands):
    """Replace each ``...`` in ``subscripts`` with explicit unused symbols.

    Each input term's ellipsis covers that operand's leftover dimensions,
    using the trailing symbols of a shared pool. The output ellipsis covers
    the widest one. When no output is given, one is built from the ellipsis
    symbols followed by the remaining once-occurring indices in sorted order.
    Returns the rewritten subscripts, which always contain ``->``.
    """
    if "->" in subscripts:
        input_tmp, output_sub = subscripts.split("->")
        out_sub = True
    else:
        input_tmp = subscripts
        out_sub = False
    split_subscripts = input_tmp.split(",")

    # Validate the operand count before indexing operands[num] below; a short
    # operand list would otherwise surface as a bare IndexError (or max() of
    # an empty sequence). This also guarantees ``operands`` is non-empty.
    _check_operand_count(input_tmp, operands)

    used = subscripts.replace(".", "").replace(",", "").replace("->", "")
    ellipse_inds = "".join(gen_unused_symbols(used, max(len(x.shape) for x in operands)))
    longest = 0

    for num, sub in enumerate(split_subscripts):
        if "." not in sub:
            continue
        if (sub.count(".") != 3) or (sub.count("...") != 1):
            raise ValueError("Invalid Ellipses.")

        # Take into account numerical values
        if operands[num].shape == ():
            ellipse_count = 0
        else:
            ellipse_count = max(len(operands[num].shape), 1) - (len(sub) - 3)

        if ellipse_count > longest:
            longest = ellipse_count

        if ellipse_count < 0:
            raise ValueError("Ellipses lengths do not match.")
        elif ellipse_count == 0:
            split_subscripts[num] = sub.replace('...', '')
        else:
            split_subscripts[num] = sub.replace('...', ellipse_inds[-ellipse_count:])

    subscripts = ",".join(split_subscripts)

    # Figure out output ellipses
    if longest == 0:
        out_ellipse = ""
    else:
        out_ellipse = ellipse_inds[-longest:]

    if out_sub:
        return subscripts + "->" + output_sub.replace("...", out_ellipse)

    # Special care for outputless ellipses
    output_subscript = find_output_str(subscripts)
    normal_inds = ''.join(sorted(set(output_subscript) - set(out_ellipse)))
    return subscripts + "->" + out_ellipse + normal_inds