1"""
2Implementation of optimized einsum.
3
4"""
5import itertools
6import operator
7
8from numpy._core.multiarray import c_einsum
9from numpy._core.numeric import asanyarray, tensordot
10from numpy._core.overrides import array_function_dispatch
11
# Public names exported by this module.
__all__ = ['einsum', 'einsum_path']

# The 52 letters usable as einsum subscript labels. A label's position in
# this string is also its integer value in the sublist call format
# (e.g. 0 -> 'a', 26 -> 'A').
# The string is spelled out because importing ``string`` just for
# ``string.ascii_letters`` would be too slow: the first import before caching
# has been measured to take 800 µs (#23777).
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
# Set form of ``einsum_symbols`` for set arithmetic on labels.
einsum_symbols_set = set(einsum_symbols)
18
19
def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Computes the number of FLOPS in the contraction.

    The estimate is the size of the full index space spanned by
    ``idx_contraction`` times an operation factor. The factor is
    ``num_terms - 1`` (but at least 1), plus one more when ``inner`` is truthy.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product? Any value is
        accepted and only its truthiness matters (callers pass the set of
        removed indices here).
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    index_space = _compute_size_by_dict(idx_contraction, size_dictionary)
    per_element_ops = max(1, num_terms - 1)
    return index_space * (per_element_ops + 1 if inner else per_element_ops)
57
58def _compute_size_by_dict(indices, idx_dict):
59 """
60 Computes the product of the elements in indices based on the dictionary
61 idx_dict.
62
63 Parameters
64 ----------
65 indices : iterable
66 Indices to base the product on.
67 idx_dict : dictionary
68 Dictionary of index sizes
69
70 Returns
71 -------
72 ret : int
73 The resulting product.
74
75 Examples
76 --------
77 >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
78 90
79
80 """
81 ret = 1
82 for i in indices:
83 ret *= idx_dict[i]
84 return ret
85
86
87def _find_contraction(positions, input_sets, output_set):
88 """
89 Finds the contraction for a given set of input and output sets.
90
91 Parameters
92 ----------
93 positions : iterable
94 Integer positions of terms used in the contraction.
95 input_sets : list
96 List of sets that represent the lhs side of the einsum subscript
97 output_set : set
98 Set that represents the rhs side of the overall einsum subscript
99
100 Returns
101 -------
102 new_result : set
103 The indices of the resulting contraction
104 remaining : list
105 List of sets that have not been contracted, the new set is appended to
106 the end of this list
107 idx_removed : set
108 Indices removed from the entire contraction
109 idx_contraction : set
110 The indices used in the current contraction
111
112 Examples
113 --------
114
115 # A simple dot product test case
116 >>> pos = (0, 1)
117 >>> isets = [set('ab'), set('bc')]
118 >>> oset = set('ac')
119 >>> _find_contraction(pos, isets, oset)
120 ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
121
122 # A more complex case with additional terms in the contraction
123 >>> pos = (0, 2)
124 >>> isets = [set('abd'), set('ac'), set('bdc')]
125 >>> oset = set('ac')
126 >>> _find_contraction(pos, isets, oset)
127 ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
128 """
129
130 idx_contract = set()
131 idx_remain = output_set.copy()
132 remaining = []
133 for ind, value in enumerate(input_sets):
134 if ind in positions:
135 idx_contract |= value
136 else:
137 remaining.append(value)
138 idx_remain |= value
139
140 new_result = idx_remain & idx_contract
141 idx_removed = (idx_contract - new_result)
142 remaining.append(new_result)
143
144 return (new_result, remaining, idx_removed, idx_contract)
145
146
def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorially with respect to the elements in the list
    ``input_sets``.

    At every step each surviving partial path is extended by every pair of
    its remaining terms. Any pair whose result exceeds ``memory_limit`` is
    discarded. If at some step no pair survives, the cheapest partial path is
    returned with one final contraction of all remaining terms appended.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    # Each candidate is (accumulated cost, path so far, remaining sets).
    candidates = [(0, [], input_sets)]
    n_inputs = len(input_sets)

    for step in range(n_inputs - 1):
        n_left = n_inputs - step
        extended = []

        for cost, path_so_far, remaining in candidates:
            for pair in itertools.combinations(range(n_left), 2):
                new_result, new_input_sets, idx_removed, idx_contract = (
                    _find_contraction(pair, remaining, output_set)
                )

                # Drop contractions whose intermediate is too large
                if _compute_size_by_dict(new_result, idx_dict) > memory_limit:
                    continue

                step_cost = _flop_count(
                    idx_contract, idx_removed, len(pair), idx_dict
                )
                extended.append(
                    (cost + step_cost, path_so_far + [pair], new_input_sets)
                )

        if not extended:
            # Nothing fits: finish the cheapest partial path with a single
            # contraction over everything that is left.
            best_path = min(candidates, key=lambda c: c[0])[1]
            best_path += [tuple(range(n_left))]
            return best_path

        candidates = extended

    # ``candidates`` always holds at least the initial entry, so min() is safe.
    return min(candidates, key=lambda c: c[0])[1]
220
def _parse_possible_contraction(
        positions, input_sets, output_set, idx_dict,
        memory_limit, path_cost, naive_cost
    ):
    """Compute the cost (removed size + flops) and resultant indices for
    performing the contraction specified by ``positions``.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    result : list or None
        ``None`` in two cases: the intermediate would hold more than
        ``memory_limit`` elements, or ``path_cost`` plus this contraction's
        flop cost would exceed ``naive_cost``. Otherwise a list
        ``[cost, positions, new_input_sets]`` where:

        cost : (int, int)
            ``(-removed_size, flop_cost)``. ``removed_size`` is the summed
            size of the contracted inputs minus the size of their result.
        positions : tuple of int
            The locations of the proposed tensors to contract.
        new_input_sets : list of sets
            The resulting new list of indices if this proposed contraction
            is performed.

    """
    idx_result, new_input_sets, idx_removed, idx_contract = _find_contraction(
        positions, input_sets, output_set
    )

    new_size = _compute_size_by_dict(idx_result, idx_dict)
    # Reject intermediates that would break the memory budget.
    if new_size > memory_limit:
        return None

    # Net number of elements that disappear by replacing the inputs with the
    # result. (This used to be just the size of the removed indices.)
    removed_size = sum(
        _compute_size_by_dict(input_sets[p], idx_dict) for p in positions
    ) - new_size

    cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)

    # Reject contractions that would push the path past the naive cost.
    if path_cost + cost > naive_cost:
        return None

    return [(-removed_size, cost), positions, new_input_sets]
283
284
def _update_other_results(results, best):
    """Update the positions and provisional input_sets of ``results``
    based on performing the contraction result ``best``. Remove any
    involving the tensors contracted.

    Parameters
    ----------
    results : list
        List of contraction results produced by
        ``_parse_possible_contraction``.
    best : list
        The best contraction of ``results`` i.e. the one that
        will be performed.

    Returns
    -------
    mod_results : list
        The list of modified results, updated with outcome of
        ``best`` contraction.

    Notes
    -----
    Each entry of ``results`` is unpacked as ``(cost, (x, y), con_sets)``.
    ``con_sets`` is the list of index sets left after contracting ``x`` and
    ``y``, with that pair's result set last. The ``con_sets`` lists are
    modified in place. Surviving entries are returned as tuples that keep
    their original ``cost``.
    """

    best_con = best[1]
    bx, by = best_con
    mod_results = []

    for cost, (x, y), con_sets in results:

        # Ignore results involving tensors just contracted
        if x in best_con or y in best_con:
            continue

        # Update the input_sets
        # ``con_sets`` no longer holds x and y, so bx/by are shifted down by
        # one for each of x, y that precedes them. ``by`` is deleted first so
        # that bx's shifted index is still valid afterwards. This assumes
        # bx < by, which holds for the ascending pairs built in _greedy_path.
        del con_sets[by - int(by > x) - int(by > y)]
        del con_sets[bx - int(bx > x) - int(bx > y)]
        # Put the result of ``best`` just before this pair's own result,
        # which stays in the last slot.
        con_sets.insert(-1, best[2][-1])

        # Update the position indices
        # Renumber x and y for a list from which bx and by have been removed.
        mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
        mod_results.append((cost, mod_con, con_sets))

    return mod_results
326
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-removed_size, cost)`` built by ``_parse_possible_contraction``.
    ``removed_size`` is the summed size of the two input terms minus the size
    of their result, and ``cost`` is the FLOP estimate of the pair. What this
    amounts to is prioritizing matrix multiplication or inner product
    operations, then Hadamard like operations, and finally outer operations.
    Pairs that share no index (outer products) are only considered when no
    other pair qualifies. Outer products are limited by ``memory_limit``. If
    no pair qualifies at all, the remaining terms are contracted in one final
    step. This algorithm scales cubically with respect to the number of
    elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost
    # (contracting every term at once); candidate pairs whose running
    # cost would exceed it are rejected in _parse_possible_contraction.
    contract = _find_contraction(
        range(len(input_sets)), input_sets, output_set
    )
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(
        idx_contract, idx_removed, len(input_sets), idx_dict
    )

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    # Candidate results carried over between iterations, re-indexed by
    # _update_other_results after each chosen contraction.
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # Iterate over all pairs on the first step, only previously
        # found pairs on subsequent steps
        # (on later steps comb_iter yields only pairs involving the newest
        # tensor; older pairs survive in known_contractions).
        for positions in comb_iter:

            # Always initially ignore outer products
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(
                positions, input_sets, output_set, idx_dict,
                memory_limit, path_cost, naive_cost
            )
            if result is not None:
                known_contractions.append(result)

        # If we do not have a inner contraction, rescan pairs
        # including outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(
                range(len(input_sets)), 2
            ):
                result = _parse_possible_contraction(
                    positions, input_sets, output_set, idx_dict,
                    memory_limit, path_cost, naive_cost
                )
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions,
            # default back to einsum like behavior
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Pick the candidate with the smallest (-removed_size, cost) key:
        # the largest removed size wins, with ties going to the lower cost.
        best = min(known_contractions, key=lambda x: x[0])

        # Now propagate as many unused contractions as possible
        # to the next iteration
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor
        # All other contractions have been accounted for
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost
        # (best[0][1] is the flop cost of the chosen contraction).
        path.append(best[1])
        path_cost += best[0][1]

    return path
440
441
442def _can_dot(inputs, result, idx_removed):
443 """
444 Checks if we can use BLAS (np.tensordot) call and its beneficial to do so.
445
446 Parameters
447 ----------
448 inputs : list of str
449 Specifies the subscripts for summation.
450 result : str
451 Resulting summation.
452 idx_removed : set
453 Indices that are removed in the summation
454
455
456 Returns
457 -------
458 type : bool
459 Returns true if BLAS should and can be used, else False
460
461 Notes
462 -----
463 If the operations is BLAS level 1 or 2 and is not already aligned
464 we default back to einsum as the memory movement to copy is more
465 costly than the operation itself.
466
467
468 Examples
469 --------
470
471 # Standard GEMM operation
472 >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
473 True
474
475 # Can use the standard BLAS, but requires odd data movement
476 >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
477 False
478
479 # DDOT where the memory is not aligned
480 >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
481 False
482
483 """
484
485 # All `dot` calls remove indices
486 if len(idx_removed) == 0:
487 return False
488
489 # BLAS can only handle two operands
490 if len(inputs) != 2:
491 return False
492
493 input_left, input_right = inputs
494
495 for c in set(input_left + input_right):
496 # can't deal with repeated indices on same input or more than 2 total
497 nl, nr = input_left.count(c), input_right.count(c)
498 if (nl > 1) or (nr > 1) or (nl + nr > 2):
499 return False
500
501 # can't do implicit summation or dimension collapse e.g.
502 # "ab,bc->c" (implicitly sum over 'a')
503 # "ab,ca->ca" (take diagonal of 'a')
504 if nl + nr - 1 == int(c in result):
505 return False
506
507 # Build a few temporaries
508 set_left = set(input_left)
509 set_right = set(input_right)
510 keep_left = set_left - idx_removed
511 keep_right = set_right - idx_removed
512 rs = len(idx_removed)
513
514 # At this point we are a DOT, GEMV, or GEMM operation
515
516 # Handle inner products
517
518 # DDOT with aligned data
519 if input_left == input_right:
520 return True
521
522 # DDOT without aligned data (better to use einsum)
523 if set_left == set_right:
524 return False
525
526 # Handle the 4 possible (aligned) GEMV or GEMM cases
527
528 # GEMM or GEMV no transpose
529 if input_left[-rs:] == input_right[:rs]:
530 return True
531
532 # GEMM or GEMV transpose both
533 if input_left[:rs] == input_right[-rs:]:
534 return True
535
536 # GEMM or GEMV transpose right
537 if input_left[-rs:] == input_right[-rs:]:
538 return True
539
540 # GEMM or GEMV transpose left
541 if input_left[:rs] == input_right[:rs]:
542 return True
543
544 # Einsum is faster than GEMV if we have to copy data
545 if not keep_left or not keep_right:
546 return False
547
548 # We are a matrix-matrix product, but we need to copy data
549 return True
550
551
def _sublist_to_subscripts(sublist):
    """Convert one sublist-format subscript list into a subscript string.

    Each entry must be ``Ellipsis`` (rendered as ``"..."``) or an integer in
    ``[0, len(einsum_symbols))`` (rendered as the matching letter of
    ``einsum_symbols``).

    Raises
    ------
    TypeError
        If an entry is neither ``Ellipsis`` nor integer-like.
    ValueError
        If an integer entry is outside the valid label range.
    """
    out = ""
    for s in sublist:
        if s is Ellipsis:
            out += "..."
            continue
        try:
            s = operator.index(s)
        except TypeError as e:
            raise TypeError(
                "For this input type lists must contain "
                "either int or Ellipsis"
            ) from e
        # Plain indexing would silently wrap negative labels around to the
        # end of einsum_symbols (-1 -> 'Z'), so check the range explicitly.
        if not 0 <= s < len(einsum_symbols):
            raise ValueError(
                "subscript is not within the valid range [0, %d)"
                % len(einsum_symbols)
            )
        out += einsum_symbols[s]
    return out


def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Parameters
    ----------
    operands : tuple
        Either ``(subscripts, *arrays)`` with ``subscripts`` a str, or the
        sublist form ``(arr0, sub0, arr1, sub1, ..., [out_sub])``. In the
        sublist form each ``sub`` is a sequence of ints in ``[0, 52)`` and/or
        ``Ellipsis``.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        Raised in any of these cases:

        * no operands are given;
        * a subscript character or integer label is invalid;
        * ``->`` or an ellipsis is malformed;
        * an output label is repeated or missing from the inputs;
        * the number of subscript terms differs from the number of operands.
    TypeError
        If a sublist contains something other than an int or Ellipsis.

    Notes
    -----
    Ellipses are replaced with labels not otherwise used in the expression.
    The labels are taken from the end of ``einsum_symbols`` in a fixed
    order, so the parsed result is reproducible across runs.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('Za,YZa', 'YZ', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('Za,YZa', 'YZ', [a, b]) # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Sublist format: operands and subscript lists alternate, optionally
        # followed by one output subscript list.
        tmp_operands = list(operands)
        operand_list = []
        subscript_list = []
        for _ in range(len(operands) // 2):
            operand_list.append(tmp_operands.pop(0))
            subscript_list.append(tmp_operands.pop(0))

        output_list = tmp_operands[-1] if len(tmp_operands) else None
        operands = [asanyarray(v) for v in operand_list]
        subscripts = ",".join(
            _sublist_to_subscripts(sub) for sub in subscript_list
        )

        if output_list is not None:
            subscripts += "->" + _sublist_to_subscripts(output_list)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        # Collect unused labels in einsum_symbols order. Iterating a set of
        # str depends on hash randomization, which made the chosen labels,
        # and hence the intermediate layouts picked by einsum_path, vary
        # between interpreter runs.
        ellipse_inds = "".join(s for s in einsum_symbols if s not in used)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        # Without this check, operands[num] below raises a bare IndexError.
        if len(split_subscripts) > len(operands):
            raise ValueError("Number of einsum subscripts must be equal to the "
                             "number of operands.")

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    # Right-align the broadcast labels, as in broadcasting.
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Build output subscripts
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if output_subscript.count(char) != 1:
            raise ValueError("Output character %s appeared more than once in "
                             "the output." % char)
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)
730
731
732def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
733 # NOTE: technically, we should only dispatch on array-like arguments, not
734 # subscripts (given as strings). But separating operands into
735 # arrays/subscripts is a little tricky/slow (given einsum's two supported
736 # signatures), so as a practical shortcut we dispatch on everything.
737 # Strings will be ignored for dispatching since they don't define
738 # __array_function__.
739 return operands
740
741
@array_function_dispatch(_einsum_path_dispatcher, module='numpy')
def einsum_path(*operands, optimize='greedy', einsum_call=False):
    """
    einsum_path(subscripts, *operands, optimize='greedy')

    Evaluates the lowest cost contraction order for an einsum expression by
    considering the creation of intermediate arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : {bool, list, tuple, 'greedy', 'optimal'}
        Choose the type of path. If a tuple is provided, the second argument is
        assumed to be the maximum intermediate size created. If only a single
        argument is provided the largest input or output array size is used
        as a maximum intermediate size.

        * if a list (or tuple) is given that starts with ``einsum_path``, uses
          this as the contraction path
        * if False no optimization is taken
        * if True defaults to the 'greedy' algorithm
        * 'optimal' An algorithm that combinatorially explores all possible
          ways of contracting the listed tensors and chooses the least costly
          path. Scales exponentially with the number of terms in the
          contraction.
        * 'greedy' An algorithm that chooses the best pair contraction
          at each step. Effectively, this algorithm searches the largest inner,
          Hadamard, and then outer products at each step. Scales cubically with
          the number of terms in the contraction. Equivalent to the 'optimal'
          path for most contractions.

        Default is 'greedy'.

    Returns
    -------
    path : list of tuples
        A list representation of the einsum path.
    string_repr : str
        A printable representation of the einsum path.

    Raises
    ------
    TypeError
        If ``optimize`` is not one of the accepted forms.
    KeyError
        If a named path algorithm other than 'greedy' or 'optimal' is given.
    ValueError
        If the subscripts are malformed or do not match the operand shapes.
    RuntimeError
        If an explicit path does not contract all operands into one.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should be
    contracted first, the result of this contraction is then appended to the
    end of the contraction list. This list can then be iterated over until all
    intermediate contractions are complete.

    See Also
    --------
    einsum, linalg.multi_dot

    Examples
    --------

    We can begin with a chain dot example. In this case, it is optimal to
    contract the ``b`` and ``c`` tensors first as represented by the first
    element of the path ``(1, 2)``. The resulting tensor is added to the end
    of the contraction and the remaining contraction ``(0, 1)`` is then
    completed.

    >>> np.random.seed(123)
    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
    >>> print(path_info[0])
    ['einsum_path', (1, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction: ij,jk,kl->il # may vary
             Naive scaling: 4
         Optimized scaling: 3
          Naive FLOP count: 1.600e+02
      Optimized FLOP count: 5.600e+01
       Theoretical speedup: 2.857
      Largest intermediate: 4.000e+00 elements
    -------------------------------------------------------------------------
    scaling                  current                                remaining
    -------------------------------------------------------------------------
       3                   kl,jk->jl                                ij,jl->il
       3                   jl,ij->il                                   il->il


    A more complex index transformation example.

    >>> I = np.random.rand(10, 10, 10, 10)
    >>> C = np.random.rand(10, 10)
    >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
    ...                            optimize='greedy')

    >>> print(path_info[0])
    ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary
             Naive scaling: 8
         Optimized scaling: 5
          Naive FLOP count: 8.000e+08
      Optimized FLOP count: 8.000e+05
       Theoretical speedup: 1000.000
      Largest intermediate: 1.000e+04 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       5               abcd,ea->bcde                      fb,gc,hd,bcde->efgh
       5               bcde,fb->cdef                         gc,hd,cdef->efgh
       5               cdef,gc->defg                            hd,defg->efgh
       5               defg,hd->efgh                               efgh->efgh
    """

    # Figure out what the path really is
    path_type = optimize
    if path_type is True:
        path_type = 'greedy'
    if path_type is None:
        path_type = False

    explicit_einsum_path = False
    memory_limit = None

    # No optimization or a named path algorithm
    if (path_type is False) or isinstance(path_type, str):
        pass

    # Given an explicit path
    elif len(path_type) and (path_type[0] == 'einsum_path'):
        explicit_einsum_path = True

    # Path tuple with memory limit
    elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
            isinstance(path_type[1], (int, float))):
        memory_limit = int(path_type[1])
        path_type = path_type[0]

    else:
        raise TypeError("Did not understand the path: %s" % str(path_type))

    # Python side parsing
    input_subscripts, output_subscript, operands = (
        _parse_einsum_input(operands)
    )

    # Build a few useful list and sets
    input_list = input_subscripts.split(',')
    input_sets = [set(x) for x in input_list]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))

    # Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    broadcast_indices = [[] for _ in range(len(input_list))]
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape
        if len(sh) != len(term):
            # Report the whole term: input_subscripts[tnum] would only be a
            # single character of the comma-joined subscript string.
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d."
                             % (term, tnum))
        for cnum, char in enumerate(term):
            dim = sh[cnum]

            # Build out broadcast indices
            if dim == 1:
                broadcast_indices[tnum].append(char)

            if char in dimension_dict:
                # For broadcasting cases we always want the largest dim size
                if dimension_dict[char] == 1:
                    dimension_dict[char] = dim
                elif dim not in (1, dimension_dict[char]):
                    # This operand's size first, then the size recorded from
                    # the previous terms, matching the message wording.
                    raise ValueError("Size of label '%s' for operand %d (%d) "
                                     "does not match previous terms (%d)."
                                     % (char, tnum, dim, dimension_dict[char]))
            else:
                dimension_dict[char] = dim

    # Convert broadcast inds to sets
    broadcast_indices = [set(x) for x in broadcast_indices]

    # Compute size of each input array plus the output array
    size_list = [_compute_size_by_dict(term, dimension_dict)
                 for term in input_list + [output_subscript]]
    max_size = max(size_list)

    if memory_limit is None:
        memory_arg = max_size
    else:
        memory_arg = memory_limit

    # Compute naive cost
    # This isn't quite right, need to look into exactly how einsum does this
    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
    naive_cost = _flop_count(
        indices, inner_product, len(input_list), dimension_dict
    )

    # Compute the path
    if explicit_einsum_path:
        # Copy into a list: the path may be given as a tuple, and it is
        # concatenated with a list when building the return value.
        path = list(path_type[1:])
    elif (
        (path_type is False)
        or (len(input_list) in [1, 2])
        or (indices == output_set)
    ):
        # Nothing to be optimized, leave it to einsum
        path = [tuple(range(len(input_list)))]
    elif path_type == "greedy":
        path = _greedy_path(
            input_sets, output_set, dimension_dict, memory_arg
        )
    elif path_type == "optimal":
        path = _optimal_path(
            input_sets, output_set, dimension_dict, memory_arg
        )
    else:
        raise KeyError("Path name %s not found" % path_type)

    cost_list, scale_list, size_list, contraction_list = [], [], [], []

    # Build one contraction tuple per step:
    # (positions, idx_removed, einsum_str, remaining, do_blas)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(contract_inds, reverse=True))

        contract = _find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract

        cost = _flop_count(
            idx_contract, idx_removed, len(contract_inds), dimension_dict
        )
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(_compute_size_by_dict(out_inds, dimension_dict))

        bcast = set()
        tmp_inputs = []
        for x in contract_inds:
            tmp_inputs.append(input_list.pop(x))
            bcast |= broadcast_indices.pop(x)

        new_bcast_inds = bcast - idx_removed

        # If we're broadcasting, nix blas
        if not len(idx_removed & bcast):
            do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
        else:
            do_blas = False

        # Last contraction
        if cnum == len(path) - 1:
            idx_result = output_subscript
        else:
            # Intermediates order their labels by (size, label).
            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
            idx_result = "".join([x[1] for x in sorted(sort_result)])

        input_list.append(idx_result)
        broadcast_indices.append(new_bcast_inds)
        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        contraction = (
            contract_inds, idx_removed, einsum_str, input_list[:], do_blas
        )
        contraction_list.append(contraction)

    opt_cost = sum(cost_list) + 1

    if len(input_list) != 1:
        # Explicit "einsum_path" is usually trusted, but we detect this kind of
        # mistake in order to prevent from returning an intermediate value.
        raise RuntimeError(
            "Invalid einsum_path is specified: {} more operands has to be "
            "contracted.".format(len(input_list) - 1))

    # Hidden option, only einsum should call this
    if einsum_call:
        return (operands, contraction_list)

    # Return the path along with a nice string representation
    overall_contraction = input_subscripts + "->" + output_subscript
    header = ("scaling", "current", "remaining")

    speedup = naive_cost / opt_cost
    max_i = max(size_list)

    path_print = " Complete contraction: %s\n" % overall_contraction
    path_print += " Naive scaling: %d\n" % len(indices)
    path_print += " Optimized scaling: %d\n" % max(scale_list)
    path_print += " Naive FLOP count: %.3e\n" % naive_cost
    path_print += " Optimized FLOP count: %.3e\n" % opt_cost
    path_print += " Theoretical speedup: %3.3f\n" % speedup
    path_print += " Largest intermediate: %.3e elements\n" % max_i
    path_print += "-" * 74 + "\n"
    path_print += "%6s %24s %40s\n" % header
    path_print += "-" * 74

    for n, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        remaining_str = ",".join(remaining) + "->" + output_subscript
        path_run = (scale_list[n], einsum_str, remaining_str)
        path_print += "\n%4d %24s %40s" % path_run

    path = ['einsum_path'] + path
    return (path, path_print)
1047
1048
def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
    # Dispatcher for ``einsum``'s __array_function__ override: every
    # positional argument is offered as a dispatch candidate, followed by
    # ``out``.  This is deliberately broader than strictly necessary; see the
    # note in _einsum_path_dispatcher for the reasoning.
    for candidate in operands:
        yield candidate
    yield out
1054
1055
# Rewrite einsum to handle different cases
@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, out=None, optimize=False, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations can be represented in a simple fashion.
    In *implicit* mode `einsum` computes these values.

    In *explicit* mode, `einsum` provides further flexibility to compute
    other array operations that might not be considered classical Einstein
    summation operations, by disabling, or forcing summation over specified
    subscript labels.

    See the notes and examples for clarification.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as a comma-separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : list of array_like
        These are the arrays for the operation.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout of the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur.  Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

        * 'no' means the data types should not be cast at all.
        * 'equiv' means only byte-order changes are allowed.
        * 'safe' means only casts which can preserve values are allowed.
        * 'same_kind' means only safe casts or casts within a kind,
          like float64 to float32, are allowed.
        * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
    einsum:
        Similar verbose interface is provided by the
        `einops <https://github.com/arogozhnikov/einops>`_ package to cover
        additional operations: transpose, reshape/flatten, repeat/tile,
        squeeze/unsqueeze and reductions.
        The `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
        optimizes contraction order for einsum-like expressions
        in a backend-agnostic manner.

    Notes
    -----
    The Einstein summation convention can be used to compute
    many multi-dimensional, linear algebraic array operations. `einsum`
    provides a succinct way of representing these.

    A non-exhaustive list of these operations,
    which can be computed by `einsum`, is shown below along with examples:

    * Trace of an array, :py:func:`numpy.trace`.
    * Return a diagonal, :py:func:`numpy.diag`.
    * Array axis summations, :py:func:`numpy.sum`.
    * Transpositions and permutations, :py:func:`numpy.transpose`.
    * Matrix multiplication and dot product, :py:func:`numpy.matmul`
      :py:func:`numpy.dot`.
    * Vector inner and outer products, :py:func:`numpy.inner`
      :py:func:`numpy.outer`.
    * Broadcasting, element-wise and scalar multiplication,
      :py:func:`numpy.multiply`.
    * Tensor contractions, :py:func:`numpy.tensordot`.
    * Chained array operations, in efficient calculation order,
      :py:func:`numpy.einsum_path`.

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
    appears only once, it is not summed, so ``np.einsum('i', a)``
    produces a view of ``a`` with no changes. A further example
    ``np.einsum('ij,jk', a, b)`` describes traditional matrix multiplication
    and is equivalent to :py:func:`np.matmul(a,b) <numpy.matmul>`.
    Repeated subscript labels in one operand take the diagonal.
    For example, ``np.einsum('ii', a)`` is equivalent to
    :py:func:`np.trace(a) <numpy.trace>`.

    In *implicit mode*, the chosen subscripts are important
    since the axes of the output are reordered alphabetically.  This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose. Additionally,
    ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
    ``np.einsum('ij,jh', a, b)`` returns the transpose of the
    multiplication since subscript 'h' precedes subscript 'i'.

    In *explicit mode* the output can be directly controlled by
    specifying output subscript labels.  This requires the
    identifier '->' as well as the list of output subscript labels.
    This feature increases the flexibility of the function since
    summing can be disabled or forced when required. The call
    ``np.einsum('i->', a)`` is like :py:func:`np.sum(a) <numpy.sum>`
    if ``a`` is a 1-D array, and ``np.einsum('ii->i', a)``
    is like :py:func:`np.diag(a) <numpy.diag>` if ``a`` is a square 2-D array.
    The difference is that `einsum` does not allow broadcasting by default.
    Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
    order of the output subscript labels and therefore returns matrix
    multiplication, unlike the example above in implicit mode.

    To enable and control broadcasting, use an ellipsis.  Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    ``np.einsum('...i->...', a)`` is like
    :py:func:`np.sum(a, axis=-1) <numpy.sum>` for array ``a`` of any shape.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, one can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array.  Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view (changed in version 1.10.0).

    `einsum` also provides an alternative way to provide the subscripts and
    operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    If the output shape is not provided in this format `einsum` will be
    calculated in implicit mode, otherwise it will be performed explicitly.
    The examples below have corresponding `einsum` calls with the two
    parameter methods.

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
    and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands
    this can greatly increase the computational efficiency at the cost of
    a larger memory footprint during computation.

    Typically a 'greedy' algorithm is applied which empirical tests have shown
    returns the optimal path in the majority of cases. In some cases 'optimal'
    will return the superlative path through a more expensive, exhaustive
    search. For iterative calculations it may be advisable to calculate
    the optimal path once and reuse that path by supplying it as an argument.
    An example is given below.

    See :py:func:`numpy.einsum_path` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    Trace of a matrix:

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    Extract the diagonal (requires explicit form):

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    Sum over an axis (requires explicit form):

    >>> np.einsum('ij->i', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [0,1], [0])
    array([ 10,  35,  60,  85, 110])
    >>> np.sum(a, axis=1)
    array([ 10,  35,  60,  85, 110])

    For higher dimensional arrays summing a single axis can be done
    with ellipsis:

    >>> np.einsum('...j->...', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
    array([ 10,  35,  60,  85, 110])

    Compute a matrix transpose, or reorder any number of axes:

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum('ij->ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.transpose(c)
    array([[0, 3],
           [1, 4],
           [2, 5]])

    Vector inner products:

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    Matrix vector multiplication:

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    Broadcasting and scalar multiplication:

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(',ij', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    Vector outer product:

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    Tensor contraction:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    Writeable returned arrays (since version 1.10.0):

    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])

    Example of ellipsis use:

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    Chained array operations. For more complicated contractions, speed ups
    might be achieved by repeatedly computing a 'greedy' path or pre-computing
    the 'optimal' path and repeatedly applying it, using an `einsum_path`
    insertion (since version 1.12.0). Performance improvements can be
    particularly significant with larger arrays:

    >>> a = np.ones(64).reshape(2,4,8)

    Basic `einsum`: ~1520ms  (benchmarked on 3.1GHz Intel i5.)

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)

    Sub-optimal `einsum` (due to repeated path calculation time): ~330ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
    ...         optimize='optimal')

    Greedy `einsum` (faster optimal path approximation): ~160ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')

    Optimal `einsum` (best usage pattern in some use cases): ~110ms

    >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
    ...     optimize='optimal')[0]
    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)

    """
    # Special handling if out is specified
    specified_out = out is not None

    # If no optimization, run pure einsum: all keyword arguments are
    # forwarded to the C implementation unchanged.
    if optimize is False:
        if specified_out:
            kwargs['out'] = out
        return c_einsum(*operands, **kwargs)

    # Check the kwargs to avoid a more cryptic error later, without having to
    # repeat default values here.  Only the keyword names matter here.
    valid_einsum_kwargs = ['dtype', 'order', 'casting']
    unknown_kwargs = [k for k in kwargs if k not in valid_einsum_kwargs]
    if unknown_kwargs:
        raise TypeError("Did not understand the following kwargs: %s"
                        % unknown_kwargs)

    # Build the contraction list and operand.  With einsum_call=True,
    # einsum_path returns the operand list together with one tuple per
    # contraction step, laid out as
    #   (contract_inds, idx_removed, einsum_str, remaining, do_blas)
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                             einsum_call=True)

    # Handle order kwarg for output array, c_einsum allows mixed case.
    # 'order' is removed from kwargs so the intermediate c_einsum calls do
    # not receive it; it is applied only to the final result via asanyarray
    # at the end (and not at all when `out` is given).
    output_order = kwargs.pop('order', 'K')
    if output_order.upper() == 'A':
        if all(arr.flags.f_contiguous for arr in operands):
            output_order = 'F'
        else:
            output_order = 'C'

    # Start contraction loop
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        # Pop in the same order einsum_path popped its subscript list, so the
        # positions stored in `inds` refer to the current operand list.
        tmp_operands = [operands.pop(x) for x in inds]

        # Only the final contraction writes into the user-supplied `out`.
        handle_out = specified_out and ((num + 1) == len(contraction_list))

        # Call tensordot if still possible
        if blas:
            # Checks have already been handled by einsum_path; the
            # two-operand split below relies on them.
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            # Index order of tensordot's result: left then right free indices.
            tensor_result = input_left + input_right
            for s in idx_rm:
                tensor_result = tensor_result.replace(s, "")

            # Find indices to contract over
            left_pos, right_pos = [], []
            for s in sorted(idx_rm):
                left_pos.append(input_left.find(s))
                right_pos.append(input_right.find(s))

            # Contract!
            new_view = tensordot(
                *tmp_operands, axes=(tuple(left_pos), tuple(right_pos))
            )

            # Build a new view if needed: reorder axes to the requested
            # output subscripts and/or write into `out`.
            if (tensor_result != results_index) or handle_out:
                if handle_out:
                    kwargs["out"] = out
                new_view = c_einsum(
                    tensor_result + '->' + results_index, new_view, **kwargs
                )

        # Call einsum
        else:
            # If out was specified
            if handle_out:
                kwargs["out"] = out

            # Do the contraction
            new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out
    else:
        return asanyarray(operands[0], order=output_order)