1"""
2Implementation of optimized einsum.
3
4"""
5import itertools
6import operator
7
8from numpy.core.multiarray import c_einsum
9from numpy.core.numeric import asanyarray, tensordot
10from numpy.core.overrides import array_function_dispatch
11
12__all__ = ['einsum', 'einsum_path']
13
14einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
15einsum_symbols_set = set(einsum_symbols)
16
17
def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Estimate the FLOP cost of a single contraction.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product?
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    # One pass over the full index space per pairwise multiply; an inner
    # product adds one extra pass for the summation.
    op_factor = max(1, num_terms - 1)
    if inner:
        op_factor += 1

    return op_factor * _compute_size_by_dict(idx_contraction, size_dictionary)
55
56def _compute_size_by_dict(indices, idx_dict):
57 """
58 Computes the product of the elements in indices based on the dictionary
59 idx_dict.
60
61 Parameters
62 ----------
63 indices : iterable
64 Indices to base the product on.
65 idx_dict : dictionary
66 Dictionary of index sizes
67
68 Returns
69 -------
70 ret : int
71 The resulting product.
72
73 Examples
74 --------
75 >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
76 90
77
78 """
79 ret = 1
80 for i in indices:
81 ret *= idx_dict[i]
82 return ret
83
84
85def _find_contraction(positions, input_sets, output_set):
86 """
87 Finds the contraction for a given set of input and output sets.
88
89 Parameters
90 ----------
91 positions : iterable
92 Integer positions of terms used in the contraction.
93 input_sets : list
94 List of sets that represent the lhs side of the einsum subscript
95 output_set : set
96 Set that represents the rhs side of the overall einsum subscript
97
98 Returns
99 -------
100 new_result : set
101 The indices of the resulting contraction
102 remaining : list
103 List of sets that have not been contracted, the new set is appended to
104 the end of this list
105 idx_removed : set
106 Indices removed from the entire contraction
107 idx_contraction : set
108 The indices used in the current contraction
109
110 Examples
111 --------
112
113 # A simple dot product test case
114 >>> pos = (0, 1)
115 >>> isets = [set('ab'), set('bc')]
116 >>> oset = set('ac')
117 >>> _find_contraction(pos, isets, oset)
118 ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
119
120 # A more complex case with additional terms in the contraction
121 >>> pos = (0, 2)
122 >>> isets = [set('abd'), set('ac'), set('bdc')]
123 >>> oset = set('ac')
124 >>> _find_contraction(pos, isets, oset)
125 ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
126 """
127
128 idx_contract = set()
129 idx_remain = output_set.copy()
130 remaining = []
131 for ind, value in enumerate(input_sets):
132 if ind in positions:
133 idx_contract |= value
134 else:
135 remaining.append(value)
136 idx_remain |= value
137
138 new_result = idx_remain & idx_contract
139 idx_removed = (idx_contract - new_result)
140 remaining.append(new_result)
141
142 return (new_result, remaining, idx_removed, idx_contract)
143
144
def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Exhaustively explores every ordering of pairwise contractions, discards
    candidates whose intermediate exceeds ``memory_limit``, and returns the
    cheapest complete path. Cost grows factorially with the number of
    elements in ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Each state: (cost so far, contraction positions chosen, tensors left).
    states = [(0, [], input_sets)]
    for step in range(len(input_sets) - 1):
        next_states = []

        # Expand every surviving state by every possible pair contraction.
        for cost, chosen, tensors in states:
            for pair in itertools.combinations(range(len(input_sets) - step), 2):

                result, new_tensors, removed, contracting = _find_contraction(
                    pair, tensors, output_set)

                # Sieve out intermediates above the memory cap.
                if _compute_size_by_dict(result, idx_dict) > memory_limit:
                    continue

                new_cost = cost + _flop_count(contracting, removed, len(pair), idx_dict)
                next_states.append((new_cost, chosen + [pair], new_tensors))

        if next_states:
            states = next_states
        else:
            # The memory sieve removed every candidate: take the cheapest
            # state found so far and finish it with one big contraction.
            path = min(states, key=lambda s: s[0])[1]
            path += [tuple(range(len(input_sets) - step))]
            return path

    # Defensive fallback: contract everything in a single einsum call.
    if len(states) == 0:
        return [tuple(range(len(input_sets)))]

    return min(states, key=lambda s: s[0])[1]
214
def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost):
    """Compute the cost (removed size + flops) and resultant indices for
    performing the contraction specified by ``positions``.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    cost : (int, int)
        A tuple containing the size of any indices removed, and the flop cost.
    positions : tuple of int
        The locations of the proposed tensors to contract.
    new_input_sets : list of sets
        The resulting new list of indices if this proposed contraction is performed.

    """

    # Describe the proposed contraction.
    idx_result, new_input_sets, idx_removed, idx_contract = _find_contraction(
        positions, input_sets, output_set)

    # Reject intermediates larger than the memory cap.
    new_size = _compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None

    # Sort key: favor contractions that shrink the working set the most,
    # breaking ties by flop count. (Historically this was only the size of
    # the removed indices, i.e. compute_size_by_dict(idx_removed, idx_dict).)
    shrinkage = sum(_compute_size_by_dict(input_sets[p], idx_dict)
                    for p in positions) - new_size
    cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)

    # Reject anything already costlier than the unoptimized expression.
    if (path_cost + cost) > naive_cost:
        return None

    return [(-shrinkage, cost), positions, new_input_sets]
271
272
273def _update_other_results(results, best):
274 """Update the positions and provisional input_sets of ``results`` based on
275 performing the contraction result ``best``. Remove any involving the tensors
276 contracted.
277
278 Parameters
279 ----------
280 results : list
281 List of contraction results produced by ``_parse_possible_contraction``.
282 best : list
283 The best contraction of ``results`` i.e. the one that will be performed.
284
285 Returns
286 -------
287 mod_results : list
288 The list of modified results, updated with outcome of ``best`` contraction.
289 """
290
291 best_con = best[1]
292 bx, by = best_con
293 mod_results = []
294
295 for cost, (x, y), con_sets in results:
296
297 # Ignore results involving tensors just contracted
298 if x in best_con or y in best_con:
299 continue
300
301 # Update the input_sets
302 del con_sets[by - int(by > x) - int(by > y)]
303 del con_sets[bx - int(bx > x) - int(bx > y)]
304 con_sets.insert(-1, best[2][-1])
305
306 # Update the position indices
307 mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
308 mod_results.append((cost, mod_con, con_sets))
309
310 return mod_results
311
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
    matrix multiplication or inner product operations, then Hadamard like
    operations, and finally outer operations. Outer products are limited by
    ``memory_limit``. This algorithm scales cubically with respect to the
    number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost
    # (the cost of contracting everything in one shot; used as an upper
    # bound to sieve candidates in _parse_possible_contraction)
    contract = _find_contraction(range(len(input_sets)), input_sets, output_set)
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict)

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    # Candidate contractions carried over between iterations; each entry is
    # a [sort_key, positions, new_input_sets] list from
    # _parse_possible_contraction.
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # Iterate over all pairs on first step, only previously found pairs on subsequent steps
        for positions in comb_iter:

            # Always initially ignore outer products
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost,
                                                 naive_cost)
            if result is not None:
                known_contractions.append(result)

        # If we do not have a inner contraction, rescan pairs including outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(range(len(input_sets)), 2):
                result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit,
                                                     path_cost, naive_cost)
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions, default back to einsum like behavior
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Sort based on first index
        # (sort key is (-removed_size, cost): biggest shrinkage, then flops)
        best = min(known_contractions, key=lambda x: x[0])

        # Now propagate as many unused contractions as possible to next iteration
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor
        # All other contractions have been accounted for
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost
        path.append(best[1])
        path_cost += best[0][1]

    return path
411
412
413def _can_dot(inputs, result, idx_removed):
414 """
415 Checks if we can use BLAS (np.tensordot) call and its beneficial to do so.
416
417 Parameters
418 ----------
419 inputs : list of str
420 Specifies the subscripts for summation.
421 result : str
422 Resulting summation.
423 idx_removed : set
424 Indices that are removed in the summation
425
426
427 Returns
428 -------
429 type : bool
430 Returns true if BLAS should and can be used, else False
431
432 Notes
433 -----
434 If the operations is BLAS level 1 or 2 and is not already aligned
435 we default back to einsum as the memory movement to copy is more
436 costly than the operation itself.
437
438
439 Examples
440 --------
441
442 # Standard GEMM operation
443 >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
444 True
445
446 # Can use the standard BLAS, but requires odd data movement
447 >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
448 False
449
450 # DDOT where the memory is not aligned
451 >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
452 False
453
454 """
455
456 # All `dot` calls remove indices
457 if len(idx_removed) == 0:
458 return False
459
460 # BLAS can only handle two operands
461 if len(inputs) != 2:
462 return False
463
464 input_left, input_right = inputs
465
466 for c in set(input_left + input_right):
467 # can't deal with repeated indices on same input or more than 2 total
468 nl, nr = input_left.count(c), input_right.count(c)
469 if (nl > 1) or (nr > 1) or (nl + nr > 2):
470 return False
471
472 # can't do implicit summation or dimension collapse e.g.
473 # "ab,bc->c" (implicitly sum over 'a')
474 # "ab,ca->ca" (take diagonal of 'a')
475 if nl + nr - 1 == int(c in result):
476 return False
477
478 # Build a few temporaries
479 set_left = set(input_left)
480 set_right = set(input_right)
481 keep_left = set_left - idx_removed
482 keep_right = set_right - idx_removed
483 rs = len(idx_removed)
484
485 # At this point we are a DOT, GEMV, or GEMM operation
486
487 # Handle inner products
488
489 # DDOT with aligned data
490 if input_left == input_right:
491 return True
492
493 # DDOT without aligned data (better to use einsum)
494 if set_left == set_right:
495 return False
496
497 # Handle the 4 possible (aligned) GEMV or GEMM cases
498
499 # GEMM or GEMV no transpose
500 if input_left[-rs:] == input_right[:rs]:
501 return True
502
503 # GEMM or GEMV transpose both
504 if input_left[:rs] == input_right[-rs:]:
505 return True
506
507 # GEMM or GEMV transpose right
508 if input_left[-rs:] == input_right[-rs:]:
509 return True
510
511 # GEMM or GEMV transpose left
512 if input_left[:rs] == input_right[:rs]:
513 return True
514
515 # Einsum is faster than GEMV if we have to copy data
516 if not keep_left or not keep_right:
517 return False
518
519 # We are a matrix-matrix product, but we need to copy data
520 return True
521
522
def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b]) # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    # Two calling conventions: a subscripts string followed by operands, or
    # interleaved (operand, axis-list) pairs with an optional output list.
    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved form: pull out alternating (operand, subscript-list)
        # pairs; a trailing odd element, if any, is the output list.
        tmp_operands = list(operands)
        operand_list = []
        subscript_list = []
        for p in range(len(operands) // 2):
            operand_list.append(tmp_operands.pop(0))
            subscript_list.append(tmp_operands.pop(0))

        output_list = tmp_operands[-1] if len(tmp_operands) else None
        operands = [asanyarray(v) for v in operand_list]
        # Convert the integer axis labels to a subscripts string, mapping
        # each integer to a letter from einsum_symbols.
        subscripts = ""
        last = len(subscript_list) - 1
        for num, sub in enumerate(subscript_list):
            for s in sub:
                if s is Ellipsis:
                    subscripts += "..."
                else:
                    try:
                        s = operator.index(s)
                    except TypeError as e:
                        raise TypeError("For this input type lists must contain "
                                        "either int or Ellipsis") from e
                    subscripts += einsum_symbols[s]
            if num != last:
                subscripts += ","

        if output_list is not None:
            subscripts += "->"
            for s in output_list:
                if s is Ellipsis:
                    subscripts += "..."
                else:
                    try:
                        s = operator.index(s)
                    except TypeError as e:
                        raise TypeError("For this input type lists must contain "
                                        "either int or Ellipsis") from e
                    subscripts += einsum_symbols[s]
    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        # Ellipses are replaced with concrete symbols drawn from the pool of
        # letters not already used in the expression.
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        unused = list(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    # Number of dimensions covered by the ellipsis for this
                    # operand: its ndim minus the explicit labels in `sub`.
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    # Use the tail of the pool so shorter ellipses align on
                    # the trailing (broadcast) dimensions.
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            # Implicit output: broadcast (ellipse) indices first, then the
            # indices that appear exactly once, in sorted order.
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Build output subscripts
        # (implicit mode: indices appearing exactly once, sorted)
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)
694
695
def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
    """Dispatcher for `einsum_path`: relay all positional arguments to the
    ``__array_function__`` protocol."""
    # NOTE: technically, we should only dispatch on array-like arguments, not
    # subscripts (given as strings). But separating operands into
    # arrays/subscripts is a little tricky/slow (given einsum's two supported
    # signatures), so as a practical shortcut we dispatch on everything.
    # Strings will be ignored for dispatching since they don't define
    # __array_function__.
    return operands
704
705
@array_function_dispatch(_einsum_path_dispatcher, module='numpy')
def einsum_path(*operands, optimize='greedy', einsum_call=False):
    """
    einsum_path(subscripts, *operands, optimize='greedy')

    Evaluates the lowest cost contraction order for an einsum expression by
    considering the creation of intermediate arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : {bool, list, tuple, 'greedy', 'optimal'}
        Choose the type of path. If a tuple is provided, the second argument is
        assumed to be the maximum intermediate size created. If only a single
        argument is provided the largest input or output array size is used
        as a maximum intermediate size.

        * if a list is given that starts with ``einsum_path``, uses this as the
          contraction path
        * if False no optimization is taken
        * if True defaults to the 'greedy' algorithm
        * 'optimal' An algorithm that combinatorially explores all possible
          ways of contracting the listed tensors and choosest the least costly
          path. Scales exponentially with the number of terms in the
          contraction.
        * 'greedy' An algorithm that chooses the best pair contraction
          at each step. Effectively, this algorithm searches the largest inner,
          Hadamard, and then outer products at each step. Scales cubically with
          the number of terms in the contraction. Equivalent to the 'optimal'
          path for most contractions.

        Default is 'greedy'.

    Returns
    -------
    path : list of tuples
        A list representation of the einsum path.
    string_repr : str
        A printable representation of the einsum path.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should be
    contracted first, the result of this contraction is then appended to the
    end of the contraction list. This list can then be iterated over until all
    intermediate contractions are complete.

    See Also
    --------
    einsum, linalg.multi_dot

    Examples
    --------

    We can begin with a chain dot example. In this case, it is optimal to
    contract the ``b`` and ``c`` tensors first as represented by the first
    element of the path ``(1, 2)``. The resulting tensor is added to the end
    of the contraction and the remaining contraction ``(0, 1)`` is then
    completed.

    >>> np.random.seed(123)
    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
    >>> print(path_info[0])
    ['einsum_path', (1, 2), (0, 1)]
    >>> print(path_info[1])
    Complete contraction: ij,jk,kl->il # may vary
    Naive scaling: 4
    Optimized scaling: 3
    Naive FLOP count: 1.600e+02
    Optimized FLOP count: 5.600e+01
    Theoretical speedup: 2.857
    Largest intermediate: 4.000e+00 elements
    -------------------------------------------------------------------------
    scaling current remaining
    -------------------------------------------------------------------------
    3 kl,jk->jl ij,jl->il
    3 jl,ij->il il->il


    A more complex index transformation example.

    >>> I = np.random.rand(10, 10, 10, 10)
    >>> C = np.random.rand(10, 10)
    >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
    ...                            optimize='greedy')

    >>> print(path_info[0])
    ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
    >>> print(path_info[1])
    Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary
    Naive scaling: 8
    Optimized scaling: 5
    Naive FLOP count: 8.000e+08
    Optimized FLOP count: 8.000e+05
    Theoretical speedup: 1000.000
    Largest intermediate: 1.000e+04 elements
    --------------------------------------------------------------------------
    scaling current remaining
    --------------------------------------------------------------------------
    5 abcd,ea->bcde fb,gc,hd,bcde->efgh
    5 bcde,fb->cdef gc,hd,cdef->efgh
    5 cdef,gc->defg hd,defg->efgh
    5 defg,hd->efgh efgh->efgh
    """

    # Figure out what the path really is
    path_type = optimize
    if path_type is True:
        path_type = 'greedy'
    if path_type is None:
        path_type = False

    explicit_einsum_path = False
    memory_limit = None

    # No optimization or a named path algorithm
    if (path_type is False) or isinstance(path_type, str):
        pass

    # Given an explicit path
    elif len(path_type) and (path_type[0] == 'einsum_path'):
        explicit_einsum_path = True

    # Path tuple with memory limit
    elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
            isinstance(path_type[1], (int, float))):
        memory_limit = int(path_type[1])
        path_type = path_type[0]

    else:
        raise TypeError("Did not understand the path: %s" % str(path_type))

    # Hidden option, only einsum should call this
    einsum_call_arg = einsum_call

    # Python side parsing
    input_subscripts, output_subscript, operands = _parse_einsum_input(operands)

    # Build a few useful list and sets
    input_list = input_subscripts.split(',')
    input_sets = [set(x) for x in input_list]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))

    # Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    broadcast_indices = [[] for x in range(len(input_list))]
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape
        if len(sh) != len(term):
            # BUGFIX: report the operand's subscript term, not a single
            # character of the joined subscripts string.
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d."
                             % (term, tnum))
        for cnum, char in enumerate(term):
            dim = sh[cnum]

            # Build out broadcast indices
            if dim == 1:
                broadcast_indices[tnum].append(char)

            if char in dimension_dict.keys():
                # For broadcasting cases we always want the largest dim size
                if dimension_dict[char] == 1:
                    dimension_dict[char] = dim
                elif dim not in (1, dimension_dict[char]):
                    raise ValueError("Size of label '%s' for operand %d (%d) "
                                     "does not match previous terms (%d)."
                                     % (char, tnum, dimension_dict[char], dim))
            else:
                dimension_dict[char] = dim

    # Convert broadcast inds to sets
    broadcast_indices = [set(x) for x in broadcast_indices]

    # Compute size of each input array plus the output array
    size_list = [_compute_size_by_dict(term, dimension_dict)
                 for term in input_list + [output_subscript]]
    max_size = max(size_list)

    if memory_limit is None:
        memory_arg = max_size
    else:
        memory_arg = memory_limit

    # Compute naive cost
    # This isn't quite right, need to look into exactly how einsum does this
    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
    naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict)

    # Compute the path
    if explicit_einsum_path:
        path = path_type[1:]
    elif (
        (path_type is False)
        or (len(input_list) in [1, 2])
        or (indices == output_set)
    ):
        # Nothing to be optimized, leave it to einsum
        path = [tuple(range(len(input_list)))]
    elif path_type == "greedy":
        path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
    elif path_type == "optimal":
        path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
    else:
        # BUGFIX: format the message; previously the comma created a
        # KeyError carrying a (message, value) tuple instead of a string.
        raise KeyError("Path name %s not found" % path_type)

    cost_list, scale_list, size_list, contraction_list = [], [], [], []

    # Build contraction tuple (positions, gemm, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(list(contract_inds), reverse=True))

        contract = _find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract

        cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(_compute_size_by_dict(out_inds, dimension_dict))

        bcast = set()
        tmp_inputs = []
        for x in contract_inds:
            tmp_inputs.append(input_list.pop(x))
            bcast |= broadcast_indices.pop(x)

        new_bcast_inds = bcast - idx_removed

        # If we're broadcasting, nix blas
        if not len(idx_removed & bcast):
            do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
        else:
            do_blas = False

        # Last contraction
        if (cnum - len(path)) == -1:
            idx_result = output_subscript
        else:
            # Keep the intermediate's indices sorted by dimension size for
            # better memory locality in subsequent contractions.
            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
            idx_result = "".join([x[1] for x in sorted(sort_result)])

        input_list.append(idx_result)
        broadcast_indices.append(new_bcast_inds)
        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas)
        contraction_list.append(contraction)

    opt_cost = sum(cost_list) + 1

    if len(input_list) != 1:
        # Explicit "einsum_path" is usually trusted, but we detect this kind of
        # mistake in order to prevent from returning an intermediate value.
        raise RuntimeError(
            "Invalid einsum_path is specified: {} more operands has to be "
            "contracted.".format(len(input_list) - 1))

    if einsum_call_arg:
        return (operands, contraction_list)

    # Return the path along with a nice string representation
    overall_contraction = input_subscripts + "->" + output_subscript
    header = ("scaling", "current", "remaining")

    speedup = naive_cost / opt_cost
    max_i = max(size_list)

    path_print = " Complete contraction: %s\n" % overall_contraction
    path_print += " Naive scaling: %d\n" % len(indices)
    path_print += " Optimized scaling: %d\n" % max(scale_list)
    path_print += " Naive FLOP count: %.3e\n" % naive_cost
    path_print += " Optimized FLOP count: %.3e\n" % opt_cost
    path_print += " Theoretical speedup: %3.3f\n" % speedup
    path_print += " Largest intermediate: %.3e elements\n" % max_i
    path_print += "-" * 74 + "\n"
    path_print += "%6s %24s %40s\n" % header
    path_print += "-" * 74

    for n, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        remaining_str = ",".join(remaining) + "->" + output_subscript
        path_run = (scale_list[n], einsum_str, remaining_str)
        path_print += "\n%4d %24s %40s" % path_run

    path = ['einsum_path'] + path
    return (path, path_print)
999
1000
1001def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
1002 # Arguably we dispatch on more arguments than we really should; see note in
1003 # _einsum_path_dispatcher for why.
1004 yield from operands
1005 yield out
1006
1007
1008# Rewrite einsum to handle different cases
@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, out=None, optimize=False, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations can be represented in a simple fashion.
    In *implicit* mode `einsum` computes these values.

    In *explicit* mode, `einsum` provides further flexibility to compute
    other array operations that might not be considered classical Einstein
    summation operations, by disabling, or forcing summation over specified
    subscript labels.

    See the notes and examples for clarification.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as comma separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : list of array_like
        These are the arrays for the operation.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout as the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur.  Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

          * 'no' means the data types should not be cast at all.
          * 'equiv' means only byte-order changes are allowed.
          * 'safe' means only casts which can preserve values are allowed.
          * 'same_kind' means only safe casts or casts within a kind,
            like float64 to float32, are allowed.
          * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
    einops :
        similar verbose interface is provided by
        `einops <https://github.com/arogozhnikov/einops>`_ package to cover
        additional operations: transpose, reshape/flatten, repeat/tile,
        squeeze/unsqueeze and reductions.
    opt_einsum :
        `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
        optimizes contraction order for einsum-like expressions
        in backend-agnostic manner.

    Notes
    -----
    .. versionadded:: 1.6.0

    The Einstein summation convention can be used to compute
    many multi-dimensional, linear algebraic array operations. `einsum`
    provides a succinct way of representing these.

    A non-exhaustive list of these operations,
    which can be computed by `einsum`, is shown below along with examples:

    * Trace of an array, :py:func:`numpy.trace`.
    * Return a diagonal, :py:func:`numpy.diag`.
    * Array axis summations, :py:func:`numpy.sum`.
    * Transpositions and permutations, :py:func:`numpy.transpose`.
    * Matrix multiplication and dot product, :py:func:`numpy.matmul` :py:func:`numpy.dot`.
    * Vector inner and outer products, :py:func:`numpy.inner` :py:func:`numpy.outer`.
    * Broadcasting, element-wise and scalar multiplication, :py:func:`numpy.multiply`.
    * Tensor contractions, :py:func:`numpy.tensordot`.
    * Chained array operations, in efficient calculation order, :py:func:`numpy.einsum_path`.

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
    appears only once, it is not summed, so ``np.einsum('i', a)`` produces a
    view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)``
    describes traditional matrix multiplication and is equivalent to
    :py:func:`np.matmul(a,b) <numpy.matmul>`. Repeated subscript labels in one
    operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent
    to :py:func:`np.trace(a) <numpy.trace>`.

    In *implicit mode*, the chosen subscripts are important
    since the axes of the output are reordered alphabetically.  This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose. Additionally,
    ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
    ``np.einsum('ij,jh', a, b)`` returns the transpose of the
    multiplication since subscript 'h' precedes subscript 'i'.

    In *explicit mode* the output can be directly controlled by
    specifying output subscript labels.  This requires the
    identifier '->' as well as the list of output subscript labels.
    This feature increases the flexibility of the function since
    summing can be disabled or forced when required. The call
    ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) <numpy.sum>`,
    and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) <numpy.diag>`.
    The difference is that `einsum` does not allow broadcasting by default.
    Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
    order of the output subscript labels and therefore returns matrix
    multiplication, unlike the example above in implicit mode.

    To enable and control broadcasting, use an ellipsis.  Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, one can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array.  Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view (changed in version 1.10.0).

    `einsum` also provides an alternative way to provide the subscripts
    and operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    If the output shape is not provided in this format `einsum` will be
    calculated in implicit mode, otherwise it will be performed explicitly.
    The examples below have corresponding `einsum` calls with the two
    parameter methods.

    .. versionadded:: 1.10.0

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
    and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    .. versionadded:: 1.12.0

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands this
    can greatly increase the computational efficiency at the cost of a larger
    memory footprint during computation.

    Typically a 'greedy' algorithm is applied which empirical tests have shown
    returns the optimal path in the majority of cases. In some cases 'optimal'
    will return the superlative path through a more expensive, exhaustive search.
    For iterative calculations it may be advisable to calculate the optimal path
    once and reuse that path by supplying it as an argument. An example is given
    below.

    See :py:func:`numpy.einsum_path` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    Trace of a matrix:

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    Extract the diagonal (requires explicit form):

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    Sum over an axis (requires explicit form):

    >>> np.einsum('ij->i', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [0,1], [0])
    array([ 10,  35,  60,  85, 110])
    >>> np.sum(a, axis=1)
    array([ 10,  35,  60,  85, 110])

    For higher dimensional arrays summing a single axis can be done with ellipsis:

    >>> np.einsum('...j->...', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
    array([ 10,  35,  60,  85, 110])

    Compute a matrix transpose, or reorder any number of axes:

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum('ij->ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.transpose(c)
    array([[0, 3],
           [1, 4],
           [2, 5]])

    Vector inner products:

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    Matrix vector multiplication:

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    Broadcasting and scalar multiplication:

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(',ij', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    Vector outer product:

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    Tensor contraction:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    Writeable returned arrays (since version 1.10.0):

    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])

    Example of ellipsis use:

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    Chained array operations. For more complicated contractions, speed ups
    might be achieved by repeatedly computing a 'greedy' path or pre-computing the
    'optimal' path and repeatedly applying it, using an
    `einsum_path` insertion (since version 1.12.0). Performance improvements can be
    particularly significant with larger arrays:

    >>> a = np.ones(64).reshape(2,4,8)

    Basic `einsum`: ~1520ms  (benchmarked on 3.1GHz Intel i5.)

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)

    Sub-optimal `einsum` (due to repeated path calculation time): ~330ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')

    Greedy `einsum` (faster optimal path approximation): ~160ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')

    Optimal `einsum` (best usage pattern in some use cases): ~110ms

    >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0]
    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)

    """
    # Special handling if out is specified
    specified_out = out is not None

    # If no optimization, run pure einsum: hand everything straight to the
    # C implementation.  Note the identity check -- only an explicit False
    # (the documented default) takes this fast path.
    if optimize is False:
        if specified_out:
            kwargs['out'] = out
        return c_einsum(*operands, **kwargs)

    # Check the kwargs to avoid a more cryptic error later, without having to
    # repeat default values here
    valid_einsum_kwargs = ['dtype', 'order', 'casting']
    unknown_kwargs = [k for (k, v) in kwargs.items() if
                      k not in valid_einsum_kwargs]
    if len(unknown_kwargs):
        raise TypeError("Did not understand the following kwargs: %s"
                        % unknown_kwargs)

    # Build the contraction list and operand
    # (einsum_call=True is an internal flag: einsum_path then returns the
    # parsed operand list and per-step contraction tuples rather than the
    # human-readable path description)
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                             einsum_call=True)

    # Handle order kwarg for output array, c_einsum allows mixed case.
    # 'A' resolves to 'F' only when every input is Fortran-contiguous,
    # otherwise to 'C'; the resolved order is applied to the final result.
    output_order = kwargs.pop('order', 'K')
    if output_order.upper() == 'A':
        if all(arr.flags.f_contiguous for arr in operands):
            output_order = 'F'
        else:
            output_order = 'C'

    # Start contraction loop.  Each step consumes some operands and appends
    # one intermediate result, so `operands` shrinks to a single entry.
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        # NOTE(review): successive pops assume `inds` is ordered such that
        # earlier pops do not shift later indices -- presumably guaranteed
        # by einsum_path; confirm if modifying the path construction.
        tmp_operands = [operands.pop(x) for x in inds]

        # Do we need to deal with the output?  Only on the very last step.
        handle_out = specified_out and ((num + 1) == len(contraction_list))

        # Call tensordot if still possible
        if blas:
            # Checks have already been handled
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            # Indices surviving the contraction, in concatenated input order
            tensor_result = input_left + input_right
            for s in idx_rm:
                tensor_result = tensor_result.replace(s, "")

            # Find indices to contract over (positions of each removed
            # label within the left and right operand subscripts)
            left_pos, right_pos = [], []
            for s in sorted(idx_rm):
                left_pos.append(input_left.find(s))
                right_pos.append(input_right.find(s))

            # Contract!
            new_view = tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)))

            # Build a new view if needed: transpose the tensordot result to
            # the requested subscript order, and/or write into `out` on the
            # final step.
            if (tensor_result != results_index) or handle_out:
                if handle_out:
                    kwargs["out"] = out
                new_view = c_einsum(tensor_result + '->' + results_index, new_view, **kwargs)

        # Call einsum
        else:
            # If out was specified
            if handle_out:
                kwargs["out"] = out

            # Do the contraction
            new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out
    else:
        # A single operand remains: the final result, in the resolved order.
        return asanyarray(operands[0], order=output_order)