Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/parser.py: 55%
132 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
1"""
2A functionally equivalent parser of the numpy.einsum input parser
3"""
5import itertools
6from typing import Any, Dict, Iterator, List, Tuple, Union
8import numpy as np
10from .typing import ArrayType, TensorShapeType
# Names exported by ``from <this module> import *``.
__all__ = [
    "is_valid_einsum_char",
    "has_valid_einsum_chars_only",
    "get_symbol",
    "gen_unused_symbols",
    "convert_to_valid_einsum_chars",
    "alpha_canonicalize",
    "find_output_str",
    "find_output_shape",
    "possibly_convert_to_numpy",
    "parse_einsum_input",
]

# The 52 ASCII letters, lower case first. ``get_symbol`` indexes into this
# string for ``i < 52`` and ``is_valid_einsum_char`` tests membership in it.
_einsum_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
def is_valid_einsum_char(x: str) -> bool:
    """Tell whether ``x`` may appear in a numpy einsum expression.

    Accepted are the letters in ``_einsum_symbols_base`` and the structural
    characters ``,``, ``-``, ``>`` and ``.``.

    **Examples:**

    ```python
    is_valid_einsum_char("a")
    #> True

    is_valid_einsum_char("Ǵ")
    #> False
    ```
    """
    # Index letters first, then the punctuation that structures an expression.
    if x in _einsum_symbols_base:
        return True
    return x in ",->."
def has_valid_einsum_chars_only(einsum_str: str) -> bool:
    """Tell whether every character of ``einsum_str`` is valid for numpy einsum.

    An empty string is considered valid.

    **Examples:**

    ```python
    has_valid_einsum_chars_only("abAZ")
    #> True

    has_valid_einsum_chars_only("Över")
    #> False
    ```
    """
    for char in einsum_str:
        if not is_valid_einsum_char(char):
            return False
    return True
def get_symbol(i: int) -> str:
    """Get the symbol corresponding to int ``i``.

    The first 52 indices map onto the letters in ``_einsum_symbols_base``.
    Larger indices map onto consecutive unicode code points starting at
    ``chr(192)`` and jump over the surrogate block ``chr(55296)`` to
    ``chr(57343)``, so no returned symbol is a lone surrogate.

    **Parameters:**

    - **i** - *(int)* Index of the requested symbol.

    **Returns:**

    - **symbol** - *(str)* A single-character string.

    **Examples:**

    ```python
    get_symbol(2)
    #> 'c'

    get_symbol(200)
    #> 'Ŕ'

    get_symbol(20000)
    #> '京'
    ```
    """
    if i < 52:
        return _einsum_symbols_base[i]

    # Index 52 corresponds to chr(192).
    code_point = i + 140
    if code_point >= 55296:
        # Decide on the code point, not on the raw index: the surrogate block
        # chr(55296)..chr(57343) is 2048 code points wide and is first reached
        # at index 55156, so testing ``i >= 55296`` lets 140 surrogates through.
        code_point += 2048
    return chr(code_point)
def gen_unused_symbols(used: str, n: int) -> Iterator[str]:
    """Lazily yield ``n`` symbols that do not occur in ``used``.

    Candidates are taken from ``get_symbol`` in increasing index order.

    **Examples:**
    ```python
    list(oe.parser.gen_unused_symbols("abd", 2))
    #> ['c', 'e']
    ```
    """
    emitted = 0
    for index in itertools.count():
        # Stop before asking for another candidate once enough were produced.
        if emitted >= n:
            break
        candidate = get_symbol(index)
        if candidate not in used:
            yield candidate
            emitted += 1
def convert_to_valid_einsum_chars(einsum_str: str) -> str:
    """Convert the str ``einsum_str`` to contain only the alphabetic characters
    valid for numpy einsum. If there are too many symbols, let the backend
    throw an error.

    The structural characters ``.``, ``,``, ``-`` and ``>`` are left untouched;
    every other distinct character is replaced, in sorted order, by
    ``get_symbol(0)``, ``get_symbol(1)``, ...

    Parameters
    ----------
    einsum_str : str
        The einsum expression to convert.

    Returns
    -------
    str
        The expression with its index symbols renamed.

    Examples
    --------
    >>> oe.parser.convert_to_valid_einsum_chars("Ĥěļļö")
    'cbdda'
    """
    # "." must be excluded along with ",->": otherwise an ellipsis would be
    # treated as an index symbol and rewritten to a letter.
    symbols = sorted(set(einsum_str) - set(".,->"))
    replacer = {x: get_symbol(i) for i, x in enumerate(symbols)}
    return "".join(replacer.get(x, x) for x in einsum_str)
def alpha_canonicalize(equation: str) -> str:
    """Rename the indices of ``equation`` by order of first appearance.

    The first distinct index becomes ``get_symbol(0)``, the second
    ``get_symbol(1)`` and so on; the characters ``.``, ``,``, ``-`` and ``>``
    are kept as they are.

    Examples
    --------
    >>> oe.parser.alpha_canonicalize("dcba")
    'abcd'

    >>> oe.parser.alpha_canonicalize("Ĥěļļö")
    'abccd'
    """
    mapping: Dict[str, str] = {}
    for char in equation:
        if char not in ".,->" and char not in mapping:
            # The next free symbol is indexed by how many were handed out so far.
            mapping[char] = get_symbol(len(mapping))

    renamed = [mapping.get(char, char) for char in equation]
    return "".join(renamed)
def find_output_str(subscripts: str) -> str:
    """
    Derive the implicit output indices for the input ``subscripts``.

    Following the canonical einstein summation convention, an index that
    occurs more than once across the inputs is summed over; the indices that
    occur exactly once are returned in sorted order.

    Examples
    --------
    >>> oe.parser.find_output_str("ab,bc")
    'ac'

    >>> oe.parser.find_output_str("a,b")
    'ab'

    >>> oe.parser.find_output_str("a,a,b,b")
    ''
    """
    flat = subscripts.replace(",", "")
    kept = [index for index in sorted(set(flat)) if flat.count(index) == 1]
    return "".join(kept)
def find_output_shape(inputs: List[str], shapes: List[TensorShapeType], output: str) -> TensorShapeType:
    """Compute the shape of the output for the given input terms and shapes.

    For every index in ``output`` the result holds the largest dimension found
    at the first position of that index in each input term that contains it,
    which accounts for broadcasting of size-1 dimensions.

    Examples
    --------
    >>> oe.parser.find_output_shape(["ab", "bc"], [(2, 3), (3, 4)], "ac")
    (2, 4)

    # Broadcasting is accounted for
    >>> oe.parser.find_output_shape(["a", "a"], [(4, ), (1, )], "a")
    (4,)
    """
    dims = []
    for index in output:
        # ``str.find`` gives -1 for terms that do not carry this index.
        positions = [term.find(index) for term in inputs]
        dims.append(max(shape[pos] for shape, pos in zip(shapes, positions) if pos >= 0))
    return tuple(dims)
def possibly_convert_to_numpy(x: Any) -> Any:
    """Return ``x`` unchanged if it has a ``shape`` attribute, else as an ndarray.

    Objects without a ``shape`` attribute are passed through ``np.asanyarray``.

    Examples
    --------
    >>> oe.parser.possibly_convert_to_numpy(5)
    array(5)

    >>> oe.parser.possibly_convert_to_numpy([5, 3])
    array([5, 3])

    >>> oe.parser.possibly_convert_to_numpy(np.array([5, 3]))
    array([5, 3])

    # Any class with a shape is passed through
    >>> class Shape:
    ...     def __init__(self, shape):
    ...         self.shape = shape
    ...

    >>> myshape = Shape((5, 5))
    >>> oe.parser.possibly_convert_to_numpy(myshape)
    <__main__.Shape object at 0x10f850710>
    """
    if hasattr(x, "shape"):
        return x
    return np.asanyarray(x)
def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str:
    """Translate a list of user subscripts into a subscript string via `symbol_map`.

    ``Ellipsis`` entries become ``"..."``; every other entry is looked up in
    ``symbol_map``.

    Examples
    --------
    >>> oe.parser.convert_subscripts(['abc', 'def'], {'abc':'a', 'def':'b'})
    'ab'
    >>> oe.parser.convert_subscripts([Ellipsis, object], {object:'a'})
    '...a'
    """
    # Plain indexing (no try/except): the caller is expected to have built
    # ``symbol_map`` from the symbols it passes in.
    return "".join("..." if s is Ellipsis else symbol_map[s] for s in old_sub)
def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[str, List[Any]]:
    """Convert 'interleaved' input to standard einsum input.

    The input alternates ``operand, subscript_list, operand, subscript_list, ...``
    and may end with one extra subscript list describing the output. Returns
    the equivalent subscript string and the operands, each passed through
    ``possibly_convert_to_numpy``.

    Raises ``TypeError`` if the subscript lists hold symbols that are
    unhashable or cannot be sorted.
    """
    flat = list(operands)
    paired = 2 * (len(flat) // 2)

    # Even positions hold the operands, odd positions their subscript lists.
    operand_list = flat[0:paired:2]
    subscript_list = flat[1:paired:2]

    # An odd trailing element is the explicit output subscript list.
    output_list = flat[-1] if len(flat) > paired else None
    operands = [possibly_convert_to_numpy(x) for x in operand_list]

    # Map every user symbol to a single-character symbol from `get_symbol`,
    # numbering them in sorted order so the result is deterministic.
    try:
        user_symbols = set(itertools.chain.from_iterable(subscript_list))

        # Ellipsis can not be compared with other objects, so drop it before sorting
        user_symbols.discard(Ellipsis)

        symbol_map = {}
        for position, symbol in enumerate(sorted(user_symbols)):
            symbol_map[symbol] = get_symbol(position)

    except TypeError:  # unhashable or uncomparable object
        raise TypeError(
            "For this input type lists must contain either Ellipsis "
            "or hashable and comparable object (e.g. int, str)."
        )

    terms = [convert_subscripts(sub, symbol_map) for sub in subscript_list]
    subscripts = ",".join(terms)
    if output_list is not None:
        subscripts = subscripts + "->" + convert_subscripts(output_list, symbol_map)

    return subscripts, operands
def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]:
    """
    A reproduction of einsum c side einsum parsing in python.

    **Parameters:**

    Intakes the same inputs as `contract_path`, but NOT the keyword args. The only
    supported keyword argument is:

    - **shapes** - *(bool, optional)* Whether ``parse_einsum_input`` should assume arrays (the default) or
        array shapes have been supplied.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        If no operands are given; if ``shapes`` is true and, in the string
        form, an operand has a ``shape`` attribute; if the subscripts contain
        a malformed ``->`` or ellipsis; if an output index does not appear in
        the inputs; or if the number of subscript terms differs from the
        number of operands.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b])

    >>> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b])
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        if shapes:
            if any(hasattr(o, "shape") for o in operands[1:]):
                raise ValueError(
                    "shapes is set to True but given at least one operand looks like an array"
                    " (at least one operand has a shape attribute). "
                )
        operands = [possibly_convert_to_numpy(x) for x in operands[1:]]
    else:
        subscripts, operands = convert_interleaved_input(operands)

    if shapes:
        operand_shapes = operands
    else:
        operand_shapes = [o.shape for o in operands]

    def check_term_count(n_terms: int) -> None:
        # Make sure number operands is equivalent to the number of terms
        if n_terms != len(operands):
            raise ValueError(
                f"Number of einsum subscripts, {n_terms}, must be equal to the "
                f"number of operands, {len(operands)}."
            )

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")

        # Do we have an output to account for?
        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(",")
            out_sub = False

        # The loop below indexes ``operand_shapes`` by term position, so a
        # count mismatch has to be reported here; otherwise it would surface
        # as an IndexError or as "max() arg is an empty sequence".
        check_term_count(len(split_subscripts))

        ellipse_inds = "".join(gen_unused_symbols(used, max(len(x) for x in operand_shapes)))
        longest = 0

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values (scalars have an empty shape).
                # ``len`` is used rather than ``== ()`` because with ``shapes=True``
                # the shapes have been through ``possibly_convert_to_numpy`` and are
                # ndarrays, for which ``==`` is an elementwise comparison.
                rank = len(operand_shapes[num])
                if rank == 0:
                    ellipse_count = 0
                else:
                    ellipse_count = rank - (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace("...", "")
                else:
                    split_subscripts[num] = sub.replace("...", ellipse_inds[-ellipse_count:])

        subscripts = ",".join(split_subscripts)

        # Figure out output ellipses
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            output_subscript = find_output_str(subscripts)
            normal_inds = "".join(sorted(set(output_subscript) - set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts, output_subscript = subscripts, find_output_str(subscripts)

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError(f"Output character '{char}' did not appear in the input")

    check_term_count(len(input_subscripts.split(",")))

    return input_subscripts, output_subscript, operands