1"""
2Implementation of optimized einsum.
3
4"""
5import functools
6import itertools
7import operator
8
9from numpy._core.multiarray import c_einsum, matmul
10from numpy._core.numeric import asanyarray, reshape
11from numpy._core.overrides import array_function_dispatch
12from numpy._core.umath import multiply
13
14__all__ = ['einsum', 'einsum_path']
15
# ``string.ascii_letters`` is deliberately not used here: the first,
# not-yet-cached ``import string`` was measured at roughly 800 µs (#23777).
# The symbols start with the uppercase letters so that their order agrees
# with ASCII code-point order, which avoids surprises when they are sorted.
einsum_symbols = (
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    'abcdefghijklmnopqrstuvwxyz'
)
# Set form of ``einsum_symbols`` for fast membership / difference tests.
einsum_symbols_set = {*einsum_symbols}
21
22
23def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
24 """
25 Computes the number of FLOPS in the contraction.
26
27 Parameters
28 ----------
29 idx_contraction : iterable
30 The indices involved in the contraction
31 inner : bool
32 Does this contraction require an inner product?
33 num_terms : int
34 The number of terms in a contraction
35 size_dictionary : dict
36 The size of each of the indices in idx_contraction
37
38 Returns
39 -------
40 flop_count : int
41 The total number of FLOPS required for the contraction.
42
43 Examples
44 --------
45
46 >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
47 30
48
49 >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
50 60
51
52 """
53
54 overall_size = _compute_size_by_dict(idx_contraction, size_dictionary)
55 op_factor = max(1, num_terms - 1)
56 if inner:
57 op_factor += 1
58
59 return overall_size * op_factor
60
61def _compute_size_by_dict(indices, idx_dict):
62 """
63 Computes the product of the elements in indices based on the dictionary
64 idx_dict.
65
66 Parameters
67 ----------
68 indices : iterable
69 Indices to base the product on.
70 idx_dict : dictionary
71 Dictionary of index sizes
72
73 Returns
74 -------
75 ret : int
76 The resulting product.
77
78 Examples
79 --------
80 >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
81 90
82
83 """
84 ret = 1
85 for i in indices:
86 ret *= idx_dict[i]
87 return ret
88
89
90def _find_contraction(positions, input_sets, output_set):
91 """
92 Finds the contraction for a given set of input and output sets.
93
94 Parameters
95 ----------
96 positions : iterable
97 Integer positions of terms used in the contraction.
98 input_sets : list
99 List of sets that represent the lhs side of the einsum subscript
100 output_set : set
101 Set that represents the rhs side of the overall einsum subscript
102
103 Returns
104 -------
105 new_result : set
106 The indices of the resulting contraction
107 remaining : list
108 List of sets that have not been contracted, the new set is appended to
109 the end of this list
110 idx_removed : set
111 Indices removed from the entire contraction
112 idx_contraction : set
113 The indices used in the current contraction
114
115 Examples
116 --------
117
118 # A simple dot product test case
119 >>> pos = (0, 1)
120 >>> isets = [set('ab'), set('bc')]
121 >>> oset = set('ac')
122 >>> _find_contraction(pos, isets, oset)
123 ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
124
125 # A more complex case with additional terms in the contraction
126 >>> pos = (0, 2)
127 >>> isets = [set('abd'), set('ac'), set('bdc')]
128 >>> oset = set('ac')
129 >>> _find_contraction(pos, isets, oset)
130 ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
131 """
132
133 idx_contract = set()
134 idx_remain = output_set.copy()
135 remaining = []
136 for ind, value in enumerate(input_sets):
137 if ind in positions:
138 idx_contract |= value
139 else:
140 remaining.append(value)
141 idx_remain |= value
142
143 new_result = idx_remain & idx_contract
144 idx_removed = (idx_contract - new_result)
145 remaining.append(new_result)
146
147 return (new_result, remaining, idx_removed, idx_contract)
148
149
def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Exhaustively search pairwise contraction orders for the cheapest one.

    Every possible sequence of pair contractions is enumerated, candidates
    whose intermediate result exceeds ``memory_limit`` are dropped, and the
    sequence with the lowest accumulated flop count is returned. The search
    grows factorially with the number of entries in ``input_sets``.

    Parameters
    ----------
    input_sets : list
        Index sets for every term on the lhs of the einsum subscript.
    output_set : set
        Index set of the rhs (output) of the overall einsum subscript.
    idx_dict : dictionary
        Mapping from each index to its size.
    memory_limit : int
        Largest number of elements allowed in an intermediate array.

    Returns
    -------
    path : list
        The cheapest contraction order that respects ``memory_limit``. If at
        some step no pair fits in memory, the cheapest partial path is
        completed by one contraction over all remaining terms.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    # Each candidate is (accumulated cost, contractions so far, index sets
    # still waiting to be contracted).
    candidates = [(0, [], input_sets)]
    num_terms = len(input_sets)

    for step in range(num_terms - 1):
        # After ``step`` pair contractions there are num_terms - step terms.
        pair_range = range(num_terms - step)
        next_candidates = []

        for acc_cost, steps_so_far, pending in candidates:
            for pair in itertools.combinations(pair_range, 2):
                result, new_pending, removed, involved = _find_contraction(
                    pair, pending, output_set
                )

                # Drop contractions whose intermediate is too large
                if _compute_size_by_dict(result, idx_dict) > memory_limit:
                    continue

                step_cost = _flop_count(involved, removed, len(pair), idx_dict)
                next_candidates.append(
                    (acc_cost + step_cost, steps_so_far + [pair], new_pending)
                )

        if not next_candidates:
            # Nothing fits: finish the cheapest partial path with a single
            # contraction over every term that is left.
            best_steps = min(candidates, key=lambda c: c[0])[1]
            best_steps += [tuple(pair_range)]
            return best_steps

        candidates = next_candidates

    # Defensive fallback: hand everything to a single einsum call
    if not candidates:
        return [tuple(range(num_terms))]

    return min(candidates, key=lambda c: c[0])[1]
223
def _parse_possible_contraction(
        positions, input_sets, output_set, idx_dict,
        memory_limit, path_cost, naive_cost
):
    """Score a candidate contraction of the tensors at ``positions``.

    Parameters
    ----------
    positions : tuple of int
        Locations of the tensors proposed for contraction.
    input_sets : list of sets
        Index set of each current tensor.
    output_set : set
        Output indices of the whole expression.
    idx_dict : dict
        Mapping from each index to its size.
    memory_limit : int
        Largest size allowed for an intermediate tensor.
    path_cost : int
        Flop cost accumulated along the path so far.
    naive_cost : int
        Flop cost of evaluating the expression without optimization.

    Returns
    -------
    None or list
        ``None`` when the intermediate exceeds ``memory_limit`` or the
        cumulative cost would exceed ``naive_cost``. Otherwise
        ``[sort_key, positions, new_input_sets]``:

        sort_key : (int, int)
            ``(-removed_size, flop_cost)``, so that contractions freeing
            more elements sort first and ties go to the cheaper one.
        positions : tuple of int
            The ``positions`` argument, unchanged.
        new_input_sets : list of sets
            Index sets after performing this contraction.

    """
    produced, next_sets, dropped, involved = _find_contraction(
        positions, input_sets, output_set
    )

    # Reject candidates whose intermediate would be too large
    produced_size = _compute_size_by_dict(produced, idx_dict)
    if produced_size > memory_limit:
        return None

    # Elements "freed": combined size of the consumed inputs minus the size
    # of the intermediate. (This used to be only the size of the removed
    # indices, i.e. _compute_size_by_dict(idx_removed, idx_dict).)
    consumed = sum(
        _compute_size_by_dict(input_sets[p], idx_dict) for p in positions
    )
    removed_size = consumed - produced_size

    flops = _flop_count(involved, dropped, len(positions), idx_dict)

    # Reject candidates that would make the path costlier than naive einsum
    if path_cost + flops > naive_cost:
        return None

    return [(-removed_size, flops), positions, next_sets]
286
287
def _update_other_results(results, best):
    """Update the positions and provisional input_sets of ``results``
    based on performing the contraction result ``best``. Remove any
    involving the tensors contracted.

    Parameters
    ----------
    results : list
        List of contraction results produced by
        ``_parse_possible_contraction``; each entry unpacks as
        ``(sort_key, (x, y), con_sets)``.
    best : list
        The best contraction of ``results`` i.e. the one that
        will be performed.

    Returns
    -------
    mod_results : list
        The list of modified results, updated with outcome of
        ``best`` contraction. Entries touching either tensor of ``best``
        (including ``best`` itself) are dropped. Sort keys are carried over
        unchanged.

    Notes
    -----
    The ``con_sets`` list of every surviving entry is modified in place
    (two elements deleted, one inserted), not copied.
    NOTE(review): the index arithmetic assumes every pair is ascending
    (``x < y`` and ``bx < by``). That holds for the
    ``itertools.combinations`` and ``(i, new_tensor_pos)`` generators used
    in this module; confirm before reusing with other callers.
    """

    best_con = best[1]
    bx, by = best_con
    mod_results = []

    for cost, (x, y), con_sets in results:

        # Ignore results involving tensors just contracted
        if x in best_con or y in best_con:
            continue

        # Update the input_sets. ``con_sets`` is the tensor list *after* this
        # candidate's own (x, y) contraction (x and y removed, its result
        # appended last), so an original position p now sits at
        # ``p - (p > x) - (p > y)``. ``by`` is deleted before ``bx`` so the
        # larger index goes first and the smaller one stays valid.
        del con_sets[by - int(by > x) - int(by > y)]
        del con_sets[bx - int(bx > x) - int(bx > y)]
        # Add best's new tensor (last entry of best's input sets) just
        # before this candidate's own result, which must remain last.
        con_sets.insert(-1, best[2][-1])

        # Update the position indices: shift x and y down past the two
        # tensors consumed by ``best``.
        mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
        mod_results.append((cost, mod_con, con_sets))

    return mod_results
329
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-removed_size, cost)``, where ``removed_size`` is the combined size of
    the two inputs minus the size of their result and ``cost`` is the flop
    count of the contraction. What this amounts to is prioritizing
    matrix multiplication or inner product operations, then Hadamard like
    operations, and finally outer operations. Outer products (pairs with no
    shared index) are only considered when no other pair qualifies, and all
    candidates are limited by ``memory_limit``. This algorithm scales
    cubically with respect to the number of elements in the list
    ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.
        If at some step no pair qualifies, the path ends with a single
        contraction over all remaining terms.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost: the flop count of contracting every term in one
    # einsum call. Candidates whose cumulative cost would exceed it are
    # rejected inside _parse_possible_contraction.
    contract = _find_contraction(
        range(len(input_sets)), input_sets, output_set
    )
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(
        idx_contract, idx_removed, len(input_sets), idx_dict
    )

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    # Entries are [sort_key, positions, new_input_sets] as returned by
    # _parse_possible_contraction, with sort_key = (-removed_size, flops).
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # Iterate over all pairs on the first step, only pairs involving the
        # newest tensor on subsequent steps (comb_iter is a one-shot
        # iterator rebuilt at the end of each iteration). Older candidates
        # survive via known_contractions.
        for positions in comb_iter:

            # Always initially ignore outer products
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(
                positions, input_sets, output_set, idx_dict,
                memory_limit, path_cost, naive_cost
            )
            if result is not None:
                known_contractions.append(result)

        # If we do not have a inner contraction, rescan pairs
        # including outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(
                range(len(input_sets)), 2
            ):
                result = _parse_possible_contraction(
                    positions, input_sets, output_set, idx_dict,
                    memory_limit, path_cost, naive_cost
                )
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions,
            # default back to einsum like behavior
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Pick the candidate with the smallest sort key (most elements
        # freed first, then the lowest flop cost)
        best = min(known_contractions, key=lambda x: x[0])

        # Now propagate as many unused contractions as possible
        # to the next iteration (this also drops ``best`` itself)
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor
        # All other contractions have been accounted for
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost (best[0][1] is the flop cost)
        path.append(best[1])
        path_cost += best[0][1]

    return path
443
444
def _sublist_to_subscripts(sub):
    """
    Convert one subscript list of the interleaved ("sublist") einsum call
    format into the equivalent subscript string.

    Parameters
    ----------
    sub : iterable
        Elements are ``Ellipsis`` or integers in ``[0, len(einsum_symbols))``.

    Returns
    -------
    subscripts : str
        ``"..."`` for every ``Ellipsis`` and the matching character of
        ``einsum_symbols`` for every integer, in order.

    Raises
    ------
    TypeError
        If an element is neither ``Ellipsis`` nor an integer.
    ValueError
        If an integer element is outside ``[0, len(einsum_symbols))``.
    """
    parts = []
    for s in sub:
        if s is Ellipsis:
            parts.append("...")
            continue
        try:
            s = operator.index(s)
        except TypeError as e:
            raise TypeError(
                "For this input type lists must contain "
                "either int or Ellipsis"
            ) from e
        # Without this check a negative value would silently wrap around to
        # a symbol at the end of ``einsum_symbols``, and a too-large value
        # would surface as a bare IndexError.
        if not 0 <= s < len(einsum_symbols):
            raise ValueError(
                "subscript is not within the valid range "
                f"[0, {len(einsum_symbols)})."
            )
        parts.append(einsum_symbols[s])
    return "".join(parts)


def _parse_einsum_input(operands):
    """
    Reproduce the C-side einsum subscript parsing in Python.

    Parameters
    ----------
    operands : sequence
        Either ``(subscripts, *arrays)`` with a subscript string, or the
        interleaved ``(array, sublist, array, sublist, ..., [out_sublist])``
        form, where an odd trailing element is the output sublist.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        If no operands are given, a subscript character or sublist integer
        is invalid, ``->`` or an ellipsis is malformed, an output character
        is repeated or missing from the inputs, or the number of subscript
        terms differs from the number of operands.
    TypeError
        If a sublist contains something other than ints or ``Ellipsis``.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b]) # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError(f"Character {s} is not a valid symbol.")

    else:
        # Interleaved (operand, sublist, ...) pairs; an odd trailing
        # element is the output sublist.
        tmp_operands = list(operands)
        num_pairs = len(tmp_operands) // 2
        operand_list = tmp_operands[0:2 * num_pairs:2]
        subscript_list = tmp_operands[1:2 * num_pairs:2]
        output_list = tmp_operands[-1] if len(tmp_operands) % 2 else None

        operands = [asanyarray(v) for v in operand_list]
        subscripts = ",".join(
            _sublist_to_subscripts(sub) for sub in subscript_list
        )

        if output_list is not None:
            subscripts += "->" + _sublist_to_subscripts(output_list)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        # Ellipsis dimensions are named with symbols not used explicitly.
        # Set iteration order is arbitrary, hence "may vary" above.
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        unused = list(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        # Validate the term count before ``operands[num]`` is indexed below,
        # so a mismatch raises the documented ValueError, not IndexError.
        if len(split_subscripts) != len(operands):
            raise ValueError("Number of einsum subscripts must be equal to the "
                             "number of operands.")

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError(f"Character {s} is not a valid symbol.")
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Implicit mode: the output is every index that appears exactly
        # once, in sorted order.
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError(f"Character {s} is not a valid symbol.")
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if output_subscript.count(char) != 1:
            raise ValueError("Output character %s appeared more than once in "
                             "the output." % char)
        if char not in input_subscripts:
            raise ValueError(f"Output character {char} did not appear in the input")

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)
622
623
624def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
625 # NOTE: technically, we should only dispatch on array-like arguments, not
626 # subscripts (given as strings). But separating operands into
627 # arrays/subscripts is a little tricky/slow (given einsum's two supported
628 # signatures), so as a practical shortcut we dispatch on everything.
629 # Strings will be ignored for dispatching since they don't define
630 # __array_function__.
631 return operands
632
633
@array_function_dispatch(_einsum_path_dispatcher, module='numpy')
def einsum_path(*operands, optimize='greedy', einsum_call=False):
    """
    einsum_path(subscripts, *operands, optimize='greedy')

    Evaluates the lowest cost contraction order for an einsum expression by
    considering the creation of intermediate arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : {bool, list, tuple, 'greedy', 'optimal'}
        Choose the type of path. If a tuple is provided, the second argument is
        assumed to be the maximum intermediate size created. If only a single
        argument is provided the largest input or output array size is used
        as a maximum intermediate size.

        * if a list is given that starts with ``einsum_path``, uses this as the
          contraction path
        * if False no optimization is taken
        * if True defaults to the 'greedy' algorithm
        * 'optimal' An algorithm that combinatorially explores all possible
          ways of contracting the listed tensors and chooses the least costly
          path. Scales exponentially with the number of terms in the
          contraction.
        * 'greedy' An algorithm that chooses the best pair contraction
          at each step. Effectively, this algorithm searches the largest inner,
          Hadamard, and then outer products at each step. Scales cubically with
          the number of terms in the contraction. Equivalent to the 'optimal'
          path for most contractions.

        Default is 'greedy'.

    Returns
    -------
    path : list of tuples
        A list representation of the einsum path.
    string_repr : str
        A printable representation of the einsum path.

    Raises
    ------
    TypeError
        If ``optimize`` is not one of the accepted forms.
    KeyError
        If ``optimize`` names an unknown path algorithm.
    ValueError
        If the subscripts are malformed or do not match the number or the
        shapes of the operands.
    RuntimeError
        If an explicit ``einsum_path`` leaves operands uncontracted.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should be
    contracted first, the result of this contraction is then appended to the
    end of the contraction list. This list can then be iterated over until all
    intermediate contractions are complete.

    See Also
    --------
    einsum, linalg.multi_dot

    Examples
    --------

    We can begin with a chain dot example. In this case, it is optimal to
    contract the ``b`` and ``c`` tensors first as represented by the first
    element of the path ``(1, 2)``. The resulting tensor is added to the end
    of the contraction and the remaining contraction ``(0, 1)`` is then
    completed.

    >>> np.random.seed(123)
    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
    >>> print(path_info[0])
    ['einsum_path', (1, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction:  ij,jk,kl->il # may vary
             Naive scaling:  4
         Optimized scaling:  3
          Naive FLOP count:  1.200e+02
      Optimized FLOP count:  5.600e+01
       Theoretical speedup:  2.143
      Largest intermediate:  4.000e+00 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       3                kl,jk->jl                                ij,jl->il
       3                jl,ij->il                                   il->il


    A more complex index transformation example.

    >>> I = np.random.rand(10, 10, 10, 10)
    >>> C = np.random.rand(10, 10)
    >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
    ...                            optimize='greedy')

    >>> print(path_info[0])
    ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction:  ea,fb,abcd,gc,hd->efgh # may vary
             Naive scaling:  8
         Optimized scaling:  5
          Naive FLOP count:  5.000e+08
      Optimized FLOP count:  8.000e+05
       Theoretical speedup:  625.000
      Largest intermediate:  1.000e+04 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       5            abcd,ea->bcde                      fb,gc,hd,bcde->efgh
       5            bcde,fb->cdef                         gc,hd,cdef->efgh
       5            cdef,gc->defg                            hd,defg->efgh
       5            defg,hd->efgh                               efgh->efgh
    """

    # Figure out what the path really is
    path_type = optimize
    if path_type is True:
        path_type = 'greedy'
    if path_type is None:
        path_type = False

    explicit_einsum_path = False
    memory_limit = None

    # No optimization or a named path algorithm
    if (path_type is False) or isinstance(path_type, str):
        pass

    # Given an explicit path
    elif len(path_type) and (path_type[0] == 'einsum_path'):
        explicit_einsum_path = True

    # Path tuple with memory limit
    elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
            isinstance(path_type[1], (int, float))):
        memory_limit = int(path_type[1])
        path_type = path_type[0]

    else:
        raise TypeError(f"Did not understand the path: {str(path_type)}")

    # Hidden option, only einsum should call this: when true, return the
    # parsed operands and the raw contraction list instead of the report.
    einsum_call_arg = einsum_call

    # Python side parsing
    input_subscripts, output_subscript, operands = (
        _parse_einsum_input(operands)
    )

    # Build a few useful list and sets
    input_list = input_subscripts.split(',')
    num_inputs = len(input_list)
    input_sets = [set(x) for x in input_list]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))
    num_indices = len(indices)

    # Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape
        if len(sh) != len(term):
            # Report this operand's whole term (input_subscripts[tnum] would
            # be a single character of the joined subscript string).
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d."
                             % (term, tnum))
        for cnum, char in enumerate(term):
            dim = sh[cnum]

            if char in dimension_dict:
                # For broadcasting cases we always want the largest dim size
                if dimension_dict[char] == 1:
                    dimension_dict[char] = dim
                elif dim not in (1, dimension_dict[char]):
                    # This operand's size first, the earlier terms' size last
                    raise ValueError("Size of label '%s' for operand %d (%d) "
                                     "does not match previous terms (%d)."
                                     % (char, tnum, dim, dimension_dict[char]))
            else:
                dimension_dict[char] = dim

    # Compute size of each input array plus the output array
    size_list = [_compute_size_by_dict(term, dimension_dict)
                 for term in input_list + [output_subscript]]
    max_size = max(size_list)

    if memory_limit is None:
        memory_arg = max_size
    else:
        memory_arg = memory_limit

    # Compute the path
    if explicit_einsum_path:
        path = path_type[1:]
    elif (
        (path_type is False)
        or (num_inputs in [1, 2])
        or (indices == output_set)
    ):
        # Nothing to be optimized, leave it to einsum
        path = [tuple(range(num_inputs))]
    elif path_type == "greedy":
        path = _greedy_path(
            input_sets, output_set, dimension_dict, memory_arg
        )
    elif path_type == "optimal":
        path = _optimal_path(
            input_sets, output_set, dimension_dict, memory_arg
        )
    else:
        raise KeyError(f"Path name {path_type} not found")

    cost_list, scale_list, size_list, contraction_list = [], [], [], []

    # Build contraction tuple (positions, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(contract_inds, reverse=True))

        contract = _find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract

        if not einsum_call_arg:
            # these are only needed for printing info
            cost = _flop_count(
                idx_contract, idx_removed, len(contract_inds), dimension_dict
            )
            cost_list.append(cost)
            scale_list.append(len(idx_contract))
            size_list.append(_compute_size_by_dict(out_inds, dimension_dict))

        tmp_inputs = []
        for x in contract_inds:
            tmp_inputs.append(input_list.pop(x))

        # Last contraction
        if (cnum - len(path)) == -1:
            idx_result = output_subscript
        else:
            # Intermediate indices ordered by (size, symbol)
            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
            idx_result = "".join([x[1] for x in sorted(sort_result)])

        input_list.append(idx_result)
        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        contraction = (contract_inds, einsum_str, input_list[:])
        contraction_list.append(contraction)

    if len(input_list) != 1:
        # Explicit "einsum_path" is usually trusted, but we detect this kind of
        # mistake in order to prevent from returning an intermediate value.
        raise RuntimeError(
            f"Invalid einsum_path is specified: {len(input_list) - 1} more "
            "operands has to be contracted.")

    if einsum_call_arg:
        return (operands, contraction_list)

    # Return the path along with a nice string representation
    overall_contraction = input_subscripts + "->" + output_subscript
    header = ("scaling", "current", "remaining")

    # Compute naive cost
    # This isn't quite right, need to look into exactly how einsum does this
    inner_product = (
        sum(len(set(x)) for x in input_subscripts.split(',')) - num_indices
    ) > 0
    naive_cost = _flop_count(
        indices, inner_product, num_inputs, dimension_dict
    )

    # Report the exact optimized cost; only the ratio is guarded against a
    # zero cost (any zero-sized dimension makes the flop estimates zero).
    opt_cost = sum(cost_list)
    speedup = naive_cost / max(opt_cost, 1)
    max_i = max(size_list)

    path_print = f"  Complete contraction:  {overall_contraction}\n"
    path_print += f"         Naive scaling:  {num_indices}\n"
    path_print += "     Optimized scaling:  %d\n" % max(scale_list)
    path_print += f"      Naive FLOP count:  {naive_cost:.3e}\n"
    path_print += f"  Optimized FLOP count:  {opt_cost:.3e}\n"
    path_print += f"   Theoretical speedup:  {speedup:3.3f}\n"
    path_print += f"  Largest intermediate:  {max_i:.3e} elements\n"
    path_print += "-" * 74 + "\n"
    path_print += "%6s %24s %40s\n" % header
    path_print += "-" * 74

    for n, contraction in enumerate(contraction_list):
        _, einsum_str, remaining = contraction
        remaining_str = ",".join(remaining) + "->" + output_subscript
        path_run = (scale_list[n], einsum_str, remaining_str)
        path_print += "\n%4d %24s %40s" % path_run

    # Unpack so an explicit path given as a tuple also works
    path = ['einsum_path', *path]
    return (path, path_print)
923
924
925def _parse_eq_to_pure_multiplication(a_term, shape_a, b_term, shape_b, out):
926 """If there are no contracted indices, then we can directly transpose and
927 insert singleton dimensions into ``a`` and ``b`` such that (broadcast)
928 elementwise multiplication performs the einsum.
929
930 No need to cache this as it is within the cached
931 ``_parse_eq_to_batch_matmul``.
932
933 """
934 desired_a = ""
935 desired_b = ""
936 new_shape_a = []
937 new_shape_b = []
938 for ix in out:
939 if ix in a_term:
940 desired_a += ix
941 new_shape_a.append(shape_a[a_term.index(ix)])
942 else:
943 new_shape_a.append(1)
944 if ix in b_term:
945 desired_b += ix
946 new_shape_b.append(shape_b[b_term.index(ix)])
947 else:
948 new_shape_b.append(1)
949
950 if desired_a != a_term:
951 eq_a = f"{a_term}->{desired_a}"
952 else:
953 eq_a = None
954 if desired_b != b_term:
955 eq_b = f"{b_term}->{desired_b}"
956 else:
957 eq_b = None
958
959 return (
960 eq_a,
961 eq_b,
962 new_shape_a,
963 new_shape_b,
964 None, # new_shape_ab, not needed since not fusing
965 None, # perm_ab, not needed as we transpose a and b first
966 True, # pure_multiplication=True
967 )
968
969
@functools.lru_cache(2**12)
def _parse_eq_to_batch_matmul(eq, shape_a, shape_b):
    """Cached parsing of a two term einsum equation into the necessary
    sequence of arguments for contraction via batched matrix multiplication.
    The steps we need to specify are:

    1. Remove repeated and trivial indices from the left and right terms,
       and transpose them, done as a single einsum.
    2. Fuse the remaining indices so we have two 3D tensors (or two 2D
       tensors when there are no batch indices).
    3. Perform the batched matrix multiplication.
    4. Unfuse the output to get the desired final index order.

    Parameters
    ----------
    eq : str
        Explicit two term equation of the form ``"{a_term},{b_term}->{out}"``.
    shape_a, shape_b : tuple of int
        Shapes of the two operands. They must be hashable because the
        result is cached with ``functools.lru_cache``.

    Returns
    -------
    eq_a, eq_b : str or None
        Single term einsum equations that prepare each operand (diagonals,
        sums, transposition), or None if the term is already as desired.
    new_shape_a, new_shape_b : sequence of int or None
        Shapes to reshape the prepared operands to, or None if not needed.
    new_shape_ab : tuple of int or None
        Shape to reshape the matmul result to, or None if not needed.
    perm_ab : tuple of int or None
        Axis permutation to apply after that reshape, or None if not needed.
    pure_multiplication : bool
        True when there are no contracted indices. The tuple then comes from
        ``_parse_eq_to_pure_multiplication``, and the prepared operands
        should be multiplied elementwise rather than matmul'd.

    Raises
    ------
    ValueError
        If a term's length does not match its shape, or if an index is
        given two different sizes that are both not 1.

    """
    lhs, out = eq.split("->")
    a_term, b_term = lhs.split(",")

    if len(a_term) != len(shape_a):
        raise ValueError(f"Term '{a_term}' does not match shape {shape_a}.")
    if len(b_term) != len(shape_b):
        raise ValueError(f"Term '{b_term}' does not match shape {shape_b}.")

    # size of each index whose dimension is > 1, shared by both terms
    sizes = {}
    # size-1 indices, tracked so they can be reintroduced in the output
    singletons = set()

    # parse left term to unique indices with size > 1
    # (dicts are used as insertion-ordered sets, which fixes the index order
    # used for the transposes and reshapes computed below)
    left = {}
    for ix, d in zip(a_term, shape_a):
        if d == 1:
            # everything (including broadcasting) works nicely if we simply
            # ignore such dimensions, but we do need to track if they appear
            # in the output and thus should be reintroduced later
            singletons.add(ix)
            continue
        if sizes.setdefault(ix, d) != d:
            # set and check size
            raise ValueError(
                f"Index {ix} has mismatched sizes {sizes[ix]} and {d}."
            )
        left[ix] = True

    # parse right term to unique indices with size > 1
    right = {}
    for ix, d in zip(b_term, shape_b):
        # broadcast indices (size 1 on one input and size != 1
        # on the other) should not be treated as singletons
        if d == 1:
            if ix not in left:
                singletons.add(ix)
            continue
        # a real (size > 1) dimension here overrides a size-1 one on the left
        singletons.discard(ix)

        if sizes.setdefault(ix, d) != d:
            # set and check size
            raise ValueError(
                f"Index {ix} has mismatched sizes {sizes[ix]} and {d}."
            )
        right[ix] = True

    # now we classify the unique size > 1 indices only
    bat_inds = []  # appears on A, B, O
    con_inds = []  # appears on A, B, .
    a_keep = []  # appears on A, ., O
    b_keep = []  # appears on ., B, O
    # other indices (appearing on A or B only) will
    # be summed or traced out prior to the matmul
    for ix in left:
        # popping shared indices leaves only those unique to the right term
        if right.pop(ix, False):
            if ix in out:
                bat_inds.append(ix)
            else:
                con_inds.append(ix)
        elif ix in out:
            a_keep.append(ix)
    # now only indices unique to right remain
    for ix in right:
        if ix in out:
            b_keep.append(ix)

    if not con_inds:
        # contraction is pure multiplication, prepare inputs differently
        return _parse_eq_to_pure_multiplication(
            a_term, shape_a, b_term, shape_b, out
        )

    # only need the size one indices that appear in the output
    # (rebound to a list ordered as in the output string)
    singletons = [ix for ix in out if ix in singletons]

    # take diagonal, remove any trivial axes and transpose left
    desired_a = "".join((*bat_inds, *a_keep, *con_inds))
    if a_term != desired_a:
        eq_a = f"{a_term}->{desired_a}"
    else:
        eq_a = None

    # take diagonal, remove any trivial axes and transpose right
    desired_b = "".join((*bat_inds, *con_inds, *b_keep))
    if b_term != desired_b:
        eq_b = f"{b_term}->{desired_b}"
    else:
        eq_b = None

    # then we want to reshape
    if bat_inds:
        lgroups = (bat_inds, a_keep, con_inds)
        rgroups = (bat_inds, con_inds, b_keep)
        ogroups = (bat_inds, a_keep, b_keep)
    else:
        # avoid size 1 batch dimension if no batch indices
        lgroups = (a_keep, con_inds)
        rgroups = (con_inds, b_keep)
        ogroups = (a_keep, b_keep)

    if any(len(group) != 1 for group in lgroups):
        # need to fuse 'kept' and contracted indices
        # (though could allow batch indices to be broadcast)
        # an empty group fuses to a dimension of size 1 (reduce start value)
        new_shape_a = tuple(
            functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1)
            for ix_group in lgroups
        )
    else:
        new_shape_a = None

    if any(len(group) != 1 for group in rgroups):
        # need to fuse 'kept' and contracted indices
        # (though could allow batch indices to be broadcast)
        new_shape_b = tuple(
            functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1)
            for ix_group in rgroups
        )
    else:
        new_shape_b = None

    if any(len(group) != 1 for group in ogroups) or singletons:
        # unfuse the matmul output, prepending the size-1 output axes
        new_shape_ab = (1,) * len(singletons) + tuple(
            sizes[ix] for ix_group in ogroups for ix in ix_group
        )
    else:
        new_shape_ab = None

    # then we might need to permute the matmul produced output:
    out_produced = "".join((*singletons, *bat_inds, *a_keep, *b_keep))
    if out_produced != out:
        perm_ab = tuple(out_produced.index(ix) for ix in out)
    else:
        perm_ab = None

    return (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        False,  # pure_multiplication=False
    )
1125
1126
1127@functools.lru_cache(maxsize=64)
1128def _parse_output_order(order, a_is_fcontig, b_is_fcontig):
1129 order = order.upper()
1130 if order == "K":
1131 return None
1132 elif order in "CF":
1133 return order
1134 elif order == "A":
1135 if a_is_fcontig and b_is_fcontig:
1136 return "F"
1137 else:
1138 return "C"
1139 else:
1140 raise ValueError(
1141 "ValueError: order must be one of "
1142 f"'C', 'F', 'A', or 'K' (got '{order}')"
1143 )
1144
1145
def bmm_einsum(eq, a, b, out=None, **kwargs):
    """Perform arbitrary pairwise einsums using only ``matmul``, or
    ``multiply`` if no contracted indices are involved (plus maybe single term
    ``einsum`` to prepare the terms individually). The logic for each is cached
    based on the equation and array shape, and each step is only performed if
    necessary.

    Parameters
    ----------
    eq : str
        The explicit two term einsum equation, e.g. ``"ij,jk->ik"``.
    a : ndarray
        The first array to contract (its ``shape`` and ``flags`` are read).
    b : ndarray
        The second array to contract (its ``shape`` and ``flags`` are read).
    out : ndarray, optional
        If provided, the result is written into this array, which is then
        returned.
    **kwargs
        ``order`` is interpreted here. The remaining keyword arguments (for
        example ``dtype`` or ``casting``) are forwarded to ``multiply`` or
        ``matmul``.

    Returns
    -------
    ndarray
        The result of the contraction, or ``out`` itself if it was supplied.

    Notes
    -----
    A fuller description of this algorithm, and original source for this
    implementation, can be found at https://github.com/jcmgray/einsum_bmm.
    """
    (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        pure_multiplication,
    ) = _parse_eq_to_batch_matmul(eq, a.shape, b.shape)

    # n.b. one could special case various cases to call c_einsum directly here

    # need to handle `order` a little manually, since we do transpose
    # operations before and potentially after the ufunc calls
    output_order = _parse_output_order(
        kwargs.pop("order", "K"), a.flags.f_contiguous, b.flags.f_contiguous
    )

    # prepare left
    if eq_a is not None:
        # diagonals, sums, and transpose
        a = c_einsum(eq_a, a)
    if new_shape_a is not None:
        # fuse groups of axes
        a = reshape(a, new_shape_a)

    # prepare right
    if eq_b is not None:
        # diagonals, sums, and transpose
        b = c_einsum(eq_b, b)
    if new_shape_b is not None:
        # fuse groups of axes
        b = reshape(b, new_shape_b)

    if pure_multiplication:
        # no contracted indices
        if output_order is not None:
            kwargs["order"] = output_order

        # do the 'contraction' via multiplication!
        return multiply(a, b, out=out, **kwargs)

    # can only supply out here if no other reshaping / transposing
    matmul_out_compatible = (new_shape_ab is None) and (perm_ab is None)
    if matmul_out_compatible:
        kwargs["out"] = out

    # do the contraction!
    ab = matmul(a, b, **kwargs)

    # prepare the output: unfuse, then restore the requested index order
    if new_shape_ab is not None:
        ab = reshape(ab, new_shape_ab)
    if perm_ab is not None:
        ab = ab.transpose(perm_ab)

    if out is None:
        if output_order is not None:
            ab = asanyarray(ab, order=output_order)
    elif not matmul_out_compatible:
        # handle case where out is specified, but we also needed
        # to reshape / transpose ``ab`` after the matmul
        out[...] = ab
        ab = out
    # otherwise matmul already wrote into and returned ``out``. Its layout
    # is fixed by the caller, so ``order`` must not trigger a copy that
    # would make us return something other than ``out``.

    return ab
1234
1235
1236def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
1237 # Arguably we dispatch on more arguments than we really should; see note in
1238 # _einsum_path_dispatcher for why.
1239 yield from operands
1240 yield out
1241
1242
# Rewrite einsum to handle different cases
@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, out=None, optimize=False, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations can be represented in a simple fashion.
    In *implicit* mode `einsum` computes these values.

    In *explicit* mode, `einsum` provides further flexibility to compute
    other array operations that might not be considered classical Einstein
    summation operations, by disabling, or forcing summation over specified
    subscript labels.

    See the notes and examples for clarification.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as comma separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : list of array_like
        These are the arrays for the operation.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout as the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur. Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

        * 'no' means the data types should not be cast at all.
        * 'equiv' means only byte-order changes are allowed.
        * 'safe' means only casts which can preserve values are allowed.
        * 'same_kind' means only safe casts or casts within a kind,
          like float64 to float32, are allowed.
        * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
    einsum:
        Similar verbose interface is provided by the
        `einops <https://github.com/arogozhnikov/einops>`_ package to cover
        additional operations: transpose, reshape/flatten, repeat/tile,
        squeeze/unsqueeze and reductions.
        The `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
        optimizes contraction order for einsum-like expressions
        in backend-agnostic manner.

    Notes
    -----
    The Einstein summation convention can be used to compute
    many multi-dimensional, linear algebraic array operations. `einsum`
    provides a succinct way of representing these.

    A non-exhaustive list of these operations,
    which can be computed by `einsum`, is shown below along with examples:

    * Trace of an array, :py:func:`numpy.trace`.
    * Return a diagonal, :py:func:`numpy.diag`.
    * Array axis summations, :py:func:`numpy.sum`.
    * Transpositions and permutations, :py:func:`numpy.transpose`.
    * Matrix multiplication and dot product, :py:func:`numpy.matmul`
      :py:func:`numpy.dot`.
    * Vector inner and outer products, :py:func:`numpy.inner`
      :py:func:`numpy.outer`.
    * Broadcasting, element-wise and scalar multiplication,
      :py:func:`numpy.multiply`.
    * Tensor contractions, :py:func:`numpy.tensordot`.
    * Chained array operations, in efficient calculation order,
      :py:func:`numpy.einsum_path`.

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
    appears only once, it is not summed, so ``np.einsum('i', a)``
    produces a view of ``a`` with no changes. A further example
    ``np.einsum('ij,jk', a, b)`` describes traditional matrix multiplication
    and is equivalent to :py:func:`np.matmul(a,b) <numpy.matmul>`.
    Repeated subscript labels in one operand take the diagonal.
    For example, ``np.einsum('ii', a)`` is equivalent to
    :py:func:`np.trace(a) <numpy.trace>`.

    In *implicit mode*, the chosen subscripts are important
    since the axes of the output are reordered alphabetically. This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose. Additionally,
    ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
    ``np.einsum('ij,jh', a, b)`` returns the transpose of the
    multiplication since subscript 'h' precedes subscript 'i'.

    In *explicit mode* the output can be directly controlled by
    specifying output subscript labels. This requires the
    identifier '->' as well as the list of output subscript labels.
    This feature increases the flexibility of the function since
    summing can be disabled or forced when required. The call
    ``np.einsum('i->', a)`` is like :py:func:`np.sum(a) <numpy.sum>`
    if ``a`` is a 1-D array, and ``np.einsum('ii->i', a)``
    is like :py:func:`np.diag(a) <numpy.diag>` if ``a`` is a square 2-D array.
    The difference is that `einsum` does not allow broadcasting by default.
    Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
    order of the output subscript labels and therefore returns matrix
    multiplication, unlike the example above in implicit mode.

    To enable and control broadcasting, use an ellipsis. Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    ``np.einsum('...i->...', a)`` is like
    :py:func:`np.sum(a, axis=-1) <numpy.sum>` for array ``a`` of any shape.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, one can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view (changed in version 1.10.0).

    `einsum` also provides an alternative way to provide the subscripts and
    operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    If the output shape is not provided in this format `einsum` will be
    calculated in implicit mode, otherwise it will be performed explicitly.
    The examples below have corresponding `einsum` calls with the two
    parameter methods.

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
    and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands
    this can greatly increase the computational efficiency at the cost of
    a larger memory footprint during computation.

    Typically a 'greedy' algorithm is applied which empirical tests have shown
    returns the optimal path in the majority of cases. In some cases 'optimal'
    will return the superlative path through a more expensive, exhaustive
    search. For iterative calculations it may be advisable to calculate
    the optimal path once and reuse that path by supplying it as an argument.
    An example is given below.

    See :py:func:`numpy.einsum_path` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    Trace of a matrix:

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    Extract the diagonal (requires explicit form):

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    Sum over an axis (requires explicit form):

    >>> np.einsum('ij->i', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [0,1], [0])
    array([ 10,  35,  60,  85, 110])
    >>> np.sum(a, axis=1)
    array([ 10,  35,  60,  85, 110])

    For higher dimensional arrays summing a single axis can be done
    with ellipsis:

    >>> np.einsum('...j->...', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
    array([ 10,  35,  60,  85, 110])

    Compute a matrix transpose, or reorder any number of axes:

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum('ij->ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.transpose(c)
    array([[0, 3],
           [1, 4],
           [2, 5]])

    Vector inner products:

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    Matrix vector multiplication:

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    Broadcasting and scalar multiplication:

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(',ij', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    Vector outer product:

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    Tensor contraction:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    Writeable returned arrays (since version 1.10.0):

    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])

    Example of ellipsis use:

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    Chained array operations. For more complicated contractions, speed ups
    might be achieved by repeatedly computing a 'greedy' path or pre-computing
    the 'optimal' path and repeatedly applying it, using an `einsum_path`
    insertion (since version 1.12.0). Performance improvements can be
    particularly significant with larger arrays:

    >>> a = np.ones(64).reshape(2,4,8)

    Basic `einsum`: ~1520ms  (benchmarked on 3.1GHz Intel i5.)

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)

    Sub-optimal `einsum` (due to repeated path calculation time): ~330ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
    ...         optimize='optimal')

    Greedy `einsum` (faster optimal path approximation): ~160ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')

    Optimal `einsum` (best usage pattern in some use cases): ~110ms

    >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
    ...     optimize='optimal')[0]
    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)

    """
    # Special handling if out is specified
    specified_out = out is not None

    # If no optimization, run pure einsum
    if optimize is False:
        if specified_out:
            kwargs['out'] = out
        return c_einsum(*operands, **kwargs)

    # Check the kwargs to avoid a more cryptic error later, without having to
    # repeat default values here
    valid_einsum_kwargs = ('dtype', 'order', 'casting')
    unknown_kwargs = [k for k in kwargs if k not in valid_einsum_kwargs]
    if unknown_kwargs:
        raise TypeError(f"Did not understand the following kwargs: {unknown_kwargs}")

    # Build the contraction list and operand
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                             einsum_call=True)

    # Only the final contraction may write into ``out``
    last_step = len(contraction_list) - 1

    # Start contraction loop
    for num, (inds, einsum_str, _) in enumerate(contraction_list):
        tmp_operands = [operands.pop(x) for x in inds]

        # If out was specified, hand it to the last contraction only
        if specified_out and num == last_step:
            kwargs["out"] = out

        if len(tmp_operands) == 2:
            # Call (batched) matrix multiplication if possible
            new_view = bmm_einsum(einsum_str, *tmp_operands, **kwargs)
        else:
            # Call einsum
            new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out
    else:
        return operands[0]