Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/numpy/_core/einsumfunc.py: 6%
413 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-09 06:12 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-09 06:12 +0000
1"""
2Implementation of optimized einsum.
4"""
5import itertools
6import operator
8from numpy._core.multiarray import c_einsum
9from numpy._core.numeric import asanyarray, tensordot
10from numpy._core.overrides import array_function_dispatch
# Public names exported by this module.
__all__ = ['einsum', 'einsum_path']

# The label alphabet is spelled out literally instead of being taken from
# string.ascii_letters: importing string would be too slow.
# The first import before caching has been measured to take 800 µs (#23777).
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
# Set form of the same alphabet, used for set arithmetic on labels.
einsum_symbols_set = set(einsum_symbols)
def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Estimate the FLOP cost of a single contraction.

    Parameters
    ----------
    idx_contraction : iterable
        Every index taking part in the contraction.
    inner : bool
        Truthy when the contraction sums over at least one index.
    num_terms : int
        How many operands are contracted together.
    size_dictionary : dict
        Maps each index in ``idx_contraction`` to its extent.

    Returns
    -------
    flop_count : int
        Size of the full iteration space multiplied by the per-element
        operation factor.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    iteration_space = _compute_size_by_dict(idx_contraction, size_dictionary)

    # One operation per extra operand (never fewer than one), plus one more
    # when an inner product is required.
    per_element = max(1, num_terms - 1)
    if inner:
        per_element = per_element + 1

    return iteration_space * per_element
58def _compute_size_by_dict(indices, idx_dict):
59 """
60 Computes the product of the elements in indices based on the dictionary
61 idx_dict.
63 Parameters
64 ----------
65 indices : iterable
66 Indices to base the product on.
67 idx_dict : dictionary
68 Dictionary of index sizes
70 Returns
71 -------
72 ret : int
73 The resulting product.
75 Examples
76 --------
77 >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
78 90
80 """
81 ret = 1
82 for i in indices:
83 ret *= idx_dict[i]
84 return ret
87def _find_contraction(positions, input_sets, output_set):
88 """
89 Finds the contraction for a given set of input and output sets.
91 Parameters
92 ----------
93 positions : iterable
94 Integer positions of terms used in the contraction.
95 input_sets : list
96 List of sets that represent the lhs side of the einsum subscript
97 output_set : set
98 Set that represents the rhs side of the overall einsum subscript
100 Returns
101 -------
102 new_result : set
103 The indices of the resulting contraction
104 remaining : list
105 List of sets that have not been contracted, the new set is appended to
106 the end of this list
107 idx_removed : set
108 Indices removed from the entire contraction
109 idx_contraction : set
110 The indices used in the current contraction
112 Examples
113 --------
115 # A simple dot product test case
116 >>> pos = (0, 1)
117 >>> isets = [set('ab'), set('bc')]
118 >>> oset = set('ac')
119 >>> _find_contraction(pos, isets, oset)
120 ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
122 # A more complex case with additional terms in the contraction
123 >>> pos = (0, 2)
124 >>> isets = [set('abd'), set('ac'), set('bdc')]
125 >>> oset = set('ac')
126 >>> _find_contraction(pos, isets, oset)
127 ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
128 """
130 idx_contract = set()
131 idx_remain = output_set.copy()
132 remaining = []
133 for ind, value in enumerate(input_sets):
134 if ind in positions:
135 idx_contract |= value
136 else:
137 remaining.append(value)
138 idx_remain |= value
140 new_result = idx_remain & idx_contract
141 idx_removed = (idx_contract - new_result)
142 remaining.append(new_result)
144 return (new_result, remaining, idx_removed, idx_contract)
147def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
148 """
149 Computes all possible pair contractions, sieves the results based
150 on ``memory_limit`` and returns the lowest cost path. This algorithm
151 scales factorial with respect to the elements in the list ``input_sets``.
153 Parameters
154 ----------
155 input_sets : list
156 List of sets that represent the lhs side of the einsum subscript
157 output_set : set
158 Set that represents the rhs side of the overall einsum subscript
159 idx_dict : dictionary
160 Dictionary of index sizes
161 memory_limit : int
162 The maximum number of elements in a temporary array
164 Returns
165 -------
166 path : list
167 The optimal contraction order within the memory limit constraint.
169 Examples
170 --------
171 >>> isets = [set('abd'), set('ac'), set('bdc')]
172 >>> oset = set()
173 >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
174 >>> _optimal_path(isets, oset, idx_sizes, 5000)
175 [(0, 2), (0, 1)]
176 """
178 full_results = [(0, [], input_sets)]
179 for iteration in range(len(input_sets) - 1):
180 iter_results = []
182 # Compute all unique pairs
183 for curr in full_results:
184 cost, positions, remaining = curr
185 for con in itertools.combinations(
186 range(len(input_sets) - iteration), 2
187 ):
189 # Find the contraction
190 cont = _find_contraction(con, remaining, output_set)
191 new_result, new_input_sets, idx_removed, idx_contract = cont
193 # Sieve the results based on memory_limit
194 new_size = _compute_size_by_dict(new_result, idx_dict)
195 if new_size > memory_limit:
196 continue
198 # Build (total_cost, positions, indices_remaining)
199 total_cost = cost + _flop_count(
200 idx_contract, idx_removed, len(con), idx_dict
201 )
202 new_pos = positions + [con]
203 iter_results.append((total_cost, new_pos, new_input_sets))
205 # Update combinatorial list, if we did not find anything return best
206 # path + remaining contractions
207 if iter_results:
208 full_results = iter_results
209 else:
210 path = min(full_results, key=lambda x: x[0])[1]
211 path += [tuple(range(len(input_sets) - iteration))]
212 return path
214 # If we have not found anything return single einsum contraction
215 if len(full_results) == 0:
216 return [tuple(range(len(input_sets)))]
218 path = min(full_results, key=lambda x: x[0])[1]
219 return path
def _parse_possible_contraction(
        positions, input_sets, output_set, idx_dict,
        memory_limit, path_cost, naive_cost
    ):
    """Evaluate the candidate contraction given by ``positions`` and return
    its sort key (removed size, flops) together with the resulting terms.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    cost : (int, int)
        ``(-removed_size, flops)`` where ``removed_size`` is the summed size
        of the contracted operands minus the size of the result.
    positions : tuple of int
        The locations of the proposed tensors to contract.
    new_input_sets : list of sets
        The resulting new list of indices if this proposed contraction
        is performed.

    ``None`` is returned instead when the result exceeds ``memory_limit`` or
    when ``path_cost`` plus this contraction's flops exceeds ``naive_cost``.
    """
    result_idx, new_input_sets, removed_idx, contracted_idx = (
        _find_contraction(positions, input_sets, output_set)
    )

    # Reject results that do not fit in memory.
    result_size = _compute_size_by_dict(result_idx, idx_dict)
    if result_size > memory_limit:
        return None

    # Net number of elements freed by replacing the operands with the result.
    operand_total = 0
    for pos in positions:
        operand_total += _compute_size_by_dict(input_sets[pos], idx_dict)
    freed = operand_total - result_size

    # NB: removed_size used to be just the size of any removed indices i.e.:
    #     helpers.compute_size_by_dict(idx_removed, idx_dict)
    flops = _flop_count(contracted_idx, removed_idx, len(positions), idx_dict)

    # Reject anything that would make the path costlier than naive einsum.
    if (path_cost + flops) > naive_cost:
        return None

    return [(-freed, flops), positions, new_input_sets]
285def _update_other_results(results, best):
286 """Update the positions and provisional input_sets of ``results``
287 based on performing the contraction result ``best``. Remove any
288 involving the tensors contracted.
290 Parameters
291 ----------
292 results : list
293 List of contraction results produced by
294 ``_parse_possible_contraction``.
295 best : list
296 The best contraction of ``results`` i.e. the one that
297 will be performed.
299 Returns
300 -------
301 mod_results : list
302 The list of modified results, updated with outcome of
303 ``best`` contraction.
304 """
306 best_con = best[1]
307 bx, by = best_con
308 mod_results = []
310 for cost, (x, y), con_sets in results:
312 # Ignore results involving tensors just contracted
313 if x in best_con or y in best_con:
314 continue
316 # Update the input_sets
317 del con_sets[by - int(by > x) - int(by > y)]
318 del con_sets[bx - int(bx > x) - int(bx > y)]
319 con_sets.insert(-1, best[2][-1])
321 # Update the position indices
322 mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
323 mod_results.append((cost, mod_con, con_sets))
325 return mod_results
327def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
328 """
329 Finds the path by contracting the best pair until the input list is
330 exhausted. The best pair is found by minimizing the tuple
331 ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
332 matrix multiplication or inner product operations, then Hadamard like
333 operations, and finally outer operations. Outer products are limited by
334 ``memory_limit``. This algorithm scales cubically with respect to the
335 number of elements in the list ``input_sets``.
337 Parameters
338 ----------
339 input_sets : list
340 List of sets that represent the lhs side of the einsum subscript
341 output_set : set
342 Set that represents the rhs side of the overall einsum subscript
343 idx_dict : dictionary
344 Dictionary of index sizes
345 memory_limit : int
346 The maximum number of elements in a temporary array
348 Returns
349 -------
350 path : list
351 The greedy contraction order within the memory limit constraint.
353 Examples
354 --------
355 >>> isets = [set('abd'), set('ac'), set('bdc')]
356 >>> oset = set()
357 >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
358 >>> _greedy_path(isets, oset, idx_sizes, 5000)
359 [(0, 2), (0, 1)]
360 """
362 # Handle trivial cases that leaked through
363 if len(input_sets) == 1:
364 return [(0,)]
365 elif len(input_sets) == 2:
366 return [(0, 1)]
368 # Build up a naive cost
369 contract = _find_contraction(
370 range(len(input_sets)), input_sets, output_set
371 )
372 idx_result, new_input_sets, idx_removed, idx_contract = contract
373 naive_cost = _flop_count(
374 idx_contract, idx_removed, len(input_sets), idx_dict
375 )
377 # Initially iterate over all pairs
378 comb_iter = itertools.combinations(range(len(input_sets)), 2)
379 known_contractions = []
381 path_cost = 0
382 path = []
384 for iteration in range(len(input_sets) - 1):
386 # Iterate over all pairs on the first step, only previously
387 # found pairs on subsequent steps
388 for positions in comb_iter:
390 # Always initially ignore outer products
391 if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
392 continue
394 result = _parse_possible_contraction(
395 positions, input_sets, output_set, idx_dict,
396 memory_limit, path_cost, naive_cost
397 )
398 if result is not None:
399 known_contractions.append(result)
401 # If we do not have a inner contraction, rescan pairs
402 # including outer products
403 if len(known_contractions) == 0:
405 # Then check the outer products
406 for positions in itertools.combinations(
407 range(len(input_sets)), 2
408 ):
409 result = _parse_possible_contraction(
410 positions, input_sets, output_set, idx_dict,
411 memory_limit, path_cost, naive_cost
412 )
413 if result is not None:
414 known_contractions.append(result)
416 # If we still did not find any remaining contractions,
417 # default back to einsum like behavior
418 if len(known_contractions) == 0:
419 path.append(tuple(range(len(input_sets))))
420 break
422 # Sort based on first index
423 best = min(known_contractions, key=lambda x: x[0])
425 # Now propagate as many unused contractions as possible
426 # to the next iteration
427 known_contractions = _update_other_results(known_contractions, best)
429 # Next iteration only compute contractions with the new tensor
430 # All other contractions have been accounted for
431 input_sets = best[2]
432 new_tensor_pos = len(input_sets) - 1
433 comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))
435 # Update path and total cost
436 path.append(best[1])
437 path_cost += best[0][1]
439 return path
442def _can_dot(inputs, result, idx_removed):
443 """
444 Checks if we can use BLAS (np.tensordot) call and its beneficial to do so.
446 Parameters
447 ----------
448 inputs : list of str
449 Specifies the subscripts for summation.
450 result : str
451 Resulting summation.
452 idx_removed : set
453 Indices that are removed in the summation
456 Returns
457 -------
458 type : bool
459 Returns true if BLAS should and can be used, else False
461 Notes
462 -----
463 If the operations is BLAS level 1 or 2 and is not already aligned
464 we default back to einsum as the memory movement to copy is more
465 costly than the operation itself.
468 Examples
469 --------
471 # Standard GEMM operation
472 >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
473 True
475 # Can use the standard BLAS, but requires odd data movement
476 >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
477 False
479 # DDOT where the memory is not aligned
480 >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
481 False
483 """
485 # All `dot` calls remove indices
486 if len(idx_removed) == 0:
487 return False
489 # BLAS can only handle two operands
490 if len(inputs) != 2:
491 return False
493 input_left, input_right = inputs
495 for c in set(input_left + input_right):
496 # can't deal with repeated indices on same input or more than 2 total
497 nl, nr = input_left.count(c), input_right.count(c)
498 if (nl > 1) or (nr > 1) or (nl + nr > 2):
499 return False
501 # can't do implicit summation or dimension collapse e.g.
502 # "ab,bc->c" (implicitly sum over 'a')
503 # "ab,ca->ca" (take diagonal of 'a')
504 if nl + nr - 1 == int(c in result):
505 return False
507 # Build a few temporaries
508 set_left = set(input_left)
509 set_right = set(input_right)
510 keep_left = set_left - idx_removed
511 keep_right = set_right - idx_removed
512 rs = len(idx_removed)
514 # At this point we are a DOT, GEMV, or GEMM operation
516 # Handle inner products
518 # DDOT with aligned data
519 if input_left == input_right:
520 return True
522 # DDOT without aligned data (better to use einsum)
523 if set_left == set_right:
524 return False
526 # Handle the 4 possible (aligned) GEMV or GEMM cases
528 # GEMM or GEMV no transpose
529 if input_left[-rs:] == input_right[:rs]:
530 return True
532 # GEMM or GEMV transpose both
533 if input_left[:rs] == input_right[-rs:]:
534 return True
536 # GEMM or GEMV transpose right
537 if input_left[-rs:] == input_right[-rs:]:
538 return True
540 # GEMM or GEMV transpose left
541 if input_left[:rs] == input_right[:rs]:
542 return True
544 # Einsum is faster than GEMV if we have to copy data
545 if not keep_left or not keep_right:
546 return False
548 # We are a matrix-matrix product, but we need to copy data
549 return True
def _parse_einsum_input(operands):
    """
    Python re-implementation of the argument parsing that einsum does on the
    C side.

    Both call forms are accepted: a subscripts string followed by the
    operands, or alternating operands and sublists of integers/``Ellipsis``
    with an optional trailing output sublist.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        For missing operands, invalid characters, malformed ``->`` or
        ellipses, invalid output labels, or a mismatch between the number of
        subscript terms and operands.
    TypeError
        When a sublist entry is neither an integer nor ``Ellipsis``.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b]) # may vary
    """

    def _labels_from_sublist(sublist):
        # Turn one sublist (integers and Ellipsis) into subscript text.
        text = ""
        for entry in sublist:
            if entry is Ellipsis:
                text += "..."
                continue
            try:
                entry = operator.index(entry)
            except TypeError as e:
                raise TypeError(
                    "For this input type lists must contain "
                    "either int or Ellipsis"
                ) from e
            text += einsum_symbols[entry]
        return text

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for ch in subscripts:
            if ch not in '.,->' and ch not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % ch)

    else:
        # Alternating (operand, sublist) pairs; a single left-over trailing
        # argument is the output sublist.
        flat = list(operands)
        paired = 2 * (len(flat) // 2)
        operand_list = flat[0:paired:2]
        subscript_list = flat[1:paired:2]
        leftover = flat[paired:]

        output_list = leftover[-1] if len(leftover) else None
        operands = [asanyarray(v) for v in operand_list]

        subscripts = ",".join(
            _labels_from_sublist(sub) for sub in subscript_list
        )
        if output_list is not None:
            subscripts += "->" + _labels_from_sublist(output_list)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        too_many = (
            (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        )
        if too_many or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        # Labels not used explicitly are available to stand in for "...".
        ellipse_inds = "".join(einsum_symbols_set - set(used))
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            out_sub = True
        else:
            input_tmp = subscripts
            out_sub = False
        split_subscripts = input_tmp.split(",")

        for num, sub in enumerate(split_subscripts):
            if "." not in sub:
                continue

            if (sub.count(".") != 3) or (sub.count("...") != 1):
                raise ValueError("Invalid Ellipses.")

            # Number of dimensions the ellipsis of this term stands for.
            if operands[num].shape == ():
                ellipse_count = 0
            else:
                ellipse_count = max(operands[num].ndim, 1)
                ellipse_count -= (len(sub) - 3)

            longest = max(longest, ellipse_count)

            if ellipse_count < 0:
                raise ValueError("Ellipses lengths do not match.")

            if ellipse_count == 0:
                replacement = ''
            else:
                replacement = ellipse_inds[-ellipse_count:]
            split_subscripts[num] = sub.replace('...', replacement)

        subscripts = ",".join(split_subscripts)
        out_ellipse = "" if longest == 0 else ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            joined = subscripts.replace(",", "")
            once_only = ""
            for ch in sorted(set(joined)):
                if ch not in einsum_symbols:
                    raise ValueError("Character %s is not a valid symbol." % ch)
                if joined.count(ch) == 1:
                    once_only += ch
            normal_inds = ''.join(sorted(set(once_only) - set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Implicit mode: the output is every label that appears exactly
        # once, in sorted order.
        joined = subscripts.replace(",", "")
        output_subscript = ""
        for ch in sorted(set(joined)):
            if ch not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % ch)
            if joined.count(ch) == 1:
                output_subscript += ch

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if output_subscript.count(char) != 1:
            raise ValueError("Output character %s appeared more than once in "
                             "the output." % char)
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)
732def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
733 # NOTE: technically, we should only dispatch on array-like arguments, not
734 # subscripts (given as strings). But separating operands into
735 # arrays/subscripts is a little tricky/slow (given einsum's two supported
736 # signatures), so as a practical shortcut we dispatch on everything.
737 # Strings will be ignored for dispatching since they don't define
738 # __array_function__.
739 return operands
@array_function_dispatch(_einsum_path_dispatcher, module='numpy')
def einsum_path(*operands, optimize='greedy', einsum_call=False):
    """
    einsum_path(subscripts, *operands, optimize='greedy')

    Evaluates the lowest cost contraction order for an einsum expression by
    considering the creation of intermediate arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : {bool, list, tuple, 'greedy', 'optimal'}
        Choose the type of path. If a tuple is provided, the second argument is
        assumed to be the maximum intermediate size created. If only a single
        argument is provided the largest input or output array size is used
        as a maximum intermediate size.

        * if a list is given that starts with ``einsum_path``, uses this as the
          contraction path
        * if False no optimization is taken
        * if True defaults to the 'greedy' algorithm
        * 'optimal' An algorithm that combinatorially explores all possible
          ways of contracting the listed tensors and chooses the least costly
          path. Scales exponentially with the number of terms in the
          contraction.
        * 'greedy' An algorithm that chooses the best pair contraction
          at each step. Effectively, this algorithm searches the largest inner,
          Hadamard, and then outer products at each step. Scales cubically with
          the number of terms in the contraction. Equivalent to the 'optimal'
          path for most contractions.

        Default is 'greedy'.

    Returns
    -------
    path : list of tuples
        A list representation of the einsum path.
    string_repr : str
        A printable representation of the einsum path.

    Raises
    ------
    TypeError
        If ``optimize`` is not one of the supported forms.
    ValueError
        If an operand's number of dimensions does not match its subscript
        term, or if the sizes given for one label are inconsistent.
    KeyError
        If ``optimize`` names a path algorithm that does not exist.
    RuntimeError
        If an explicit path leaves more than one operand uncontracted.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should be
    contracted first, the result of this contraction is then appended to the
    end of the contraction list. This list can then be iterated over until all
    intermediate contractions are complete.

    See Also
    --------
    einsum, linalg.multi_dot

    Examples
    --------

    We can begin with a chain dot example. In this case, it is optimal to
    contract the ``b`` and ``c`` tensors first as represented by the first
    element of the path ``(1, 2)``. The resulting tensor is added to the end
    of the contraction and the remaining contraction ``(0, 1)`` is then
    completed.

    >>> np.random.seed(123)
    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
    >>> print(path_info[0])
    ['einsum_path', (1, 2), (0, 1)]
    >>> print(path_info[1])
     Complete contraction: ij,jk,kl->il # may vary
     Naive scaling: 4
     Optimized scaling: 3
     Naive FLOP count: 1.600e+02
     Optimized FLOP count: 5.600e+01
     Theoretical speedup: 2.857
     Largest intermediate: 4.000e+00 elements
    -------------------------------------------------------------------------
    scaling current remaining
    -------------------------------------------------------------------------
     3 kl,jk->jl ij,jl->il
     3 jl,ij->il il->il


    A more complex index transformation example.

    >>> I = np.random.rand(10, 10, 10, 10)
    >>> C = np.random.rand(10, 10)
    >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
    ...                            optimize='greedy')

    >>> print(path_info[0])
    ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
    >>> print(path_info[1])
     Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary
     Naive scaling: 8
     Optimized scaling: 5
     Naive FLOP count: 8.000e+08
     Optimized FLOP count: 8.000e+05
     Theoretical speedup: 1000.000
     Largest intermediate: 1.000e+04 elements
    --------------------------------------------------------------------------
    scaling current remaining
    --------------------------------------------------------------------------
     5 abcd,ea->bcde fb,gc,hd,bcde->efgh
     5 bcde,fb->cdef gc,hd,cdef->efgh
     5 cdef,gc->defg hd,defg->efgh
     5 defg,hd->efgh efgh->efgh
    """

    # Figure out what the path really is
    path_type = optimize
    if path_type is True:
        path_type = 'greedy'
    if path_type is None:
        path_type = False

    explicit_einsum_path = False
    memory_limit = None

    # No optimization or a named path algorithm
    if (path_type is False) or isinstance(path_type, str):
        pass

    # Given an explicit path
    elif len(path_type) and (path_type[0] == 'einsum_path'):
        explicit_einsum_path = True

    # Path tuple with memory limit
    elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
            isinstance(path_type[1], (int, float))):
        memory_limit = int(path_type[1])
        path_type = path_type[0]

    else:
        raise TypeError("Did not understand the path: %s" % str(path_type))

    # Python side parsing
    input_subscripts, output_subscript, operands = (
        _parse_einsum_input(operands)
    )

    # Build a few useful list and sets
    input_list = input_subscripts.split(',')
    input_sets = [set(x) for x in input_list]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))

    # Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    broadcast_indices = [[] for x in range(len(input_list))]
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape
        if len(sh) != len(term):
            # Report the whole offending term, not one character of the
            # comma-joined subscripts string.
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d."
                             % (term, tnum))
        for cnum, char in enumerate(term):
            dim = sh[cnum]

            # Build out broadcast indices
            if dim == 1:
                broadcast_indices[tnum].append(char)

            if char in dimension_dict:
                # For broadcasting cases we always want the largest dim size
                if dimension_dict[char] == 1:
                    dimension_dict[char] = dim
                elif dim not in (1, dimension_dict[char]):
                    # Argument order follows the message: this operand's
                    # size first, then the size recorded from earlier terms.
                    raise ValueError("Size of label '%s' for operand %d (%d) "
                                     "does not match previous terms (%d)."
                                     % (char, tnum, dim, dimension_dict[char]))
            else:
                dimension_dict[char] = dim

    # Convert broadcast inds to sets
    broadcast_indices = [set(x) for x in broadcast_indices]

    # Compute size of each input array plus the output array
    size_list = [_compute_size_by_dict(term, dimension_dict)
                 for term in input_list + [output_subscript]]
    max_size = max(size_list)

    if memory_limit is None:
        memory_arg = max_size
    else:
        memory_arg = memory_limit

    # Compute naive cost
    # This isn't quite right, need to look into exactly how einsum does this
    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
    naive_cost = _flop_count(
        indices, inner_product, len(input_list), dimension_dict
    )

    # Compute the path
    if explicit_einsum_path:
        # Coerce to a list so that a path supplied as a tuple can still be
        # prefixed with 'einsum_path' when the result is assembled below.
        path = list(path_type[1:])
    elif (
        (path_type is False)
        or (len(input_list) in [1, 2])
        or (indices == output_set)
    ):
        # Nothing to be optimized, leave it to einsum
        path = [tuple(range(len(input_list)))]
    elif path_type == "greedy":
        path = _greedy_path(
            input_sets, output_set, dimension_dict, memory_arg
        )
    elif path_type == "optimal":
        path = _optimal_path(
            input_sets, output_set, dimension_dict, memory_arg
        )
    else:
        raise KeyError("Path name %s not found" % path_type)

    cost_list, scale_list, size_list, contraction_list = [], [], [], []

    # Build contraction tuple (positions, gemm, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(contract_inds, reverse=True))

        contract = _find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract

        cost = _flop_count(
            idx_contract, idx_removed, len(contract_inds), dimension_dict
        )
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(_compute_size_by_dict(out_inds, dimension_dict))

        bcast = set()
        tmp_inputs = []
        for x in contract_inds:
            tmp_inputs.append(input_list.pop(x))
            bcast |= broadcast_indices.pop(x)

        new_bcast_inds = bcast - idx_removed

        # If we're broadcasting, nix blas
        if not len(idx_removed & bcast):
            do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
        else:
            do_blas = False

        # Last contraction
        if (cnum - len(path)) == -1:
            idx_result = output_subscript
        else:
            # Intermediate results get their labels ordered by
            # (dimension size, label).
            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
            idx_result = "".join([x[1] for x in sorted(sort_result)])

        input_list.append(idx_result)
        broadcast_indices.append(new_bcast_inds)
        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        contraction = (
            contract_inds, idx_removed, einsum_str, input_list[:], do_blas
        )
        contraction_list.append(contraction)

    opt_cost = sum(cost_list) + 1

    if len(input_list) != 1:
        # Explicit "einsum_path" is usually trusted, but we detect this kind of
        # mistake in order to prevent from returning an intermediate value.
        raise RuntimeError(
            "Invalid einsum_path is specified: {} more operands has to be "
            "contracted.".format(len(input_list) - 1))

    # Hidden option, only einsum should call this
    if einsum_call:
        return (operands, contraction_list)

    # Return the path along with a nice string representation
    overall_contraction = input_subscripts + "->" + output_subscript
    header = ("scaling", "current", "remaining")

    speedup = naive_cost / opt_cost
    max_i = max(size_list)

    path_print = " Complete contraction: %s\n" % overall_contraction
    path_print += " Naive scaling: %d\n" % len(indices)
    path_print += " Optimized scaling: %d\n" % max(scale_list)
    path_print += " Naive FLOP count: %.3e\n" % naive_cost
    path_print += " Optimized FLOP count: %.3e\n" % opt_cost
    path_print += " Theoretical speedup: %3.3f\n" % speedup
    path_print += " Largest intermediate: %.3e elements\n" % max_i
    path_print += "-" * 74 + "\n"
    path_print += "%6s %24s %40s\n" % header
    path_print += "-" * 74

    for n, contraction in enumerate(contraction_list):
        _, _, einsum_str, remaining, _ = contraction
        remaining_str = ",".join(remaining) + "->" + output_subscript
        path_run = (scale_list[n], einsum_str, remaining_str)
        path_print += "\n%4d %24s %40s" % path_run

    path = ['einsum_path'] + path
    return (path, path_print)
1049def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
1050 # Arguably we dispatch on more arguments than we really should; see note in
1051 # _einsum_path_dispatcher for why.
1052 yield from operands
1053 yield out
# Python front end for einsum: either forwards straight to the C kernel or
# evaluates the expression step by step along a contraction path.
@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, out=None, optimize=False, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations can be represented in a simple fashion.
    In *implicit* mode `einsum` computes these values.

    In *explicit* mode, `einsum` provides further flexibility to compute
    other array operations that might not be considered classical Einstein
    summation operations, by disabling, or forcing summation over specified
    subscript labels.

    See the notes and examples for clarification.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as a comma-separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : list of array_like
        These are the arrays for the operation.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout of the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur. Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

        * 'no' means the data types should not be cast at all.
        * 'equiv' means only byte-order changes are allowed.
        * 'safe' means only casts which can preserve values are allowed.
        * 'same_kind' means only safe casts or casts within a kind,
          like float64 to float32, are allowed.
        * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
    einsum:
        Similar verbose interface is provided by the
        `einops <https://github.com/arogozhnikov/einops>`_ package to cover
        additional operations: transpose, reshape/flatten, repeat/tile,
        squeeze/unsqueeze and reductions.
        The `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
        optimizes contraction order for einsum-like expressions
        in a backend-agnostic manner.

    Notes
    -----
    .. versionadded:: 1.6.0

    The Einstein summation convention can be used to compute
    many multi-dimensional, linear algebraic array operations. `einsum`
    provides a succinct way of representing these.

    A non-exhaustive list of these operations,
    which can be computed by `einsum`, is shown below along with examples:

    * Trace of an array, :py:func:`numpy.trace`.
    * Return a diagonal, :py:func:`numpy.diag`.
    * Array axis summations, :py:func:`numpy.sum`.
    * Transpositions and permutations, :py:func:`numpy.transpose`.
    * Matrix multiplication and dot product, :py:func:`numpy.matmul`
      :py:func:`numpy.dot`.
    * Vector inner and outer products, :py:func:`numpy.inner`
      :py:func:`numpy.outer`.
    * Broadcasting, element-wise and scalar multiplication,
      :py:func:`numpy.multiply`.
    * Tensor contractions, :py:func:`numpy.tensordot`.
    * Chained array operations, in efficient calculation order,
      :py:func:`numpy.einsum_path`.

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
    appears only once, it is not summed, so ``np.einsum('i', a)``
    produces a view of ``a`` with no changes. A further example
    ``np.einsum('ij,jk', a, b)`` describes traditional matrix multiplication
    and is equivalent to :py:func:`np.matmul(a,b) <numpy.matmul>`.
    Repeated subscript labels in one operand take the diagonal.
    For example, ``np.einsum('ii', a)`` is equivalent to
    :py:func:`np.trace(a) <numpy.trace>`.

    In *implicit mode*, the chosen subscripts are important
    since the axes of the output are reordered alphabetically. This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose. Additionally,
    ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
    ``np.einsum('ij,jh', a, b)`` returns the transpose of the
    multiplication since subscript 'h' precedes subscript 'i'.

    In *explicit mode* the output can be directly controlled by
    specifying output subscript labels. This requires the
    identifier '->' as well as the list of output subscript labels.
    This feature increases the flexibility of the function since
    summing can be disabled or forced when required. The call
    ``np.einsum('i->', a)`` is like :py:func:`np.sum(a) <numpy.sum>`
    if ``a`` is a 1-D array, and ``np.einsum('ii->i', a)``
    is like :py:func:`np.diag(a) <numpy.diag>` if ``a`` is a square 2-D array.
    The difference is that `einsum` does not allow broadcasting by default.
    Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
    order of the output subscript labels and therefore returns matrix
    multiplication, unlike the example above in implicit mode.

    To enable and control broadcasting, use an ellipsis. Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    ``np.einsum('...i->...', a)`` is like
    :py:func:`np.sum(a, axis=-1) <numpy.sum>` for array ``a`` of any shape.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, one can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view (changed in version 1.10.0).

    `einsum` also provides an alternative way to provide the subscripts and
    operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    If the output shape is not provided in this format `einsum` will be
    calculated in implicit mode, otherwise it will be performed explicitly.
    The examples below have corresponding `einsum` calls with the two
    parameter methods.

    .. versionadded:: 1.10.0

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
    and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    .. versionadded:: 1.12.0

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands
    this can greatly increase the computational efficiency at the cost of
    a larger memory footprint during computation.

    Typically a 'greedy' algorithm is applied which empirical tests have shown
    returns the optimal path in the majority of cases. In some cases 'optimal'
    will return the superlative path through a more expensive, exhaustive
    search. For iterative calculations it may be advisable to calculate
    the optimal path once and reuse that path by supplying it as an argument.
    An example is given below.

    See :py:func:`numpy.einsum_path` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    Trace of a matrix:

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    Extract the diagonal (requires explicit form):

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    Sum over an axis (requires explicit form):

    >>> np.einsum('ij->i', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [0,1], [0])
    array([ 10,  35,  60,  85, 110])
    >>> np.sum(a, axis=1)
    array([ 10,  35,  60,  85, 110])

    For higher dimensional arrays summing a single axis can be done
    with ellipsis:

    >>> np.einsum('...j->...', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
    array([ 10,  35,  60,  85, 110])

    Compute a matrix transpose, or reorder any number of axes:

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum('ij->ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.transpose(c)
    array([[0, 3],
           [1, 4],
           [2, 5]])

    Vector inner products:

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    Matrix vector multiplication:

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    Broadcasting and scalar multiplication:

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(',ij', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    Vector outer product:

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    Tensor contraction:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    Writeable returned arrays (since version 1.10.0):

    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])

    Example of ellipsis use:

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    Chained array operations. For more complicated contractions, speed ups
    might be achieved by repeatedly computing a 'greedy' path or pre-computing
    the 'optimal' path and repeatedly applying it, using an `einsum_path`
    insertion (since version 1.12.0). Performance improvements can be
    particularly significant with larger arrays:

    >>> a = np.ones(64).reshape(2,4,8)

    Basic `einsum`: ~1520ms  (benchmarked on 3.1GHz Intel i5.)

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)

    Sub-optimal `einsum` (due to repeated path calculation time): ~330ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
    ...         optimize='optimal')

    Greedy `einsum` (faster optimal path approximation): ~160ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')

    Optimal `einsum` (best usage pattern in some use cases): ~110ms

    >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
    ...     optimize='optimal')[0]
    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)

    """
    # Remember whether the caller supplied an output array; it is only
    # handed to the kernel that produces the final result.
    has_out = out is not None

    # Without optimization the whole request goes straight to the C kernel,
    # which performs its own keyword validation.
    if optimize is False:
        if has_out:
            kwargs['out'] = out
        return c_einsum(*operands, **kwargs)

    # Reject unexpected keywords up front so the failure is not a more
    # cryptic one later on; default values are deliberately not repeated here.
    valid_einsum_kwargs = ['dtype', 'order', 'casting']
    unknown_kwargs = [key for key in kwargs if key not in valid_einsum_kwargs]
    if unknown_kwargs:
        raise TypeError("Did not understand the following kwargs: %s"
                        % unknown_kwargs)

    # Let einsum_path normalise the operands and plan the pairwise steps.
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                             einsum_call=True)

    # 'order' is applied once to the final result rather than to every
    # intermediate; c_einsum accepts mixed case, hence the upper().
    output_order = kwargs.pop('order', 'K')
    if output_order.upper() == 'A':
        all_fortran = all(arr.flags.f_contiguous for arr in operands)
        output_order = 'F' if all_fortran else 'C'

    n_steps = len(contraction_list)
    for step, contraction in enumerate(contraction_list, start=1):
        inds, idx_rm, einsum_str, _remaining, blas = contraction
        tmp_operands = [operands.pop(pos) for pos in inds]

        # Only the last step may write into the user-supplied output. Both
        # branches below reach c_einsum whenever this flag is set.
        handle_out = has_out and step == n_steps
        if handle_out:
            kwargs["out"] = out

        if not blas:
            # Generic path: a single c_einsum call performs the contraction.
            new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)
        else:
            # tensordot path; eligibility was already checked when the
            # contraction list was built.
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            # Labels of the tensordot result: left then right operand
            # labels with the contracted ones dropped.
            tensor_result = "".join(
                label for label in input_left + input_right
                if label not in idx_rm
            )

            # Axis positions of each contracted label in both operands.
            contracted = sorted(idx_rm)
            left_pos = tuple(input_left.find(label) for label in contracted)
            right_pos = tuple(input_right.find(label) for label in contracted)

            new_view = tensordot(*tmp_operands, axes=(left_pos, right_pos))

            # Reorder the axes into the requested layout and/or copy into
            # `out` when that is needed.
            if handle_out or tensor_result != results_index:
                new_view = c_einsum(
                    tensor_result + '->' + results_index, new_view, **kwargs
                )

        # Queue the intermediate and drop local references so consumed
        # operands can be released as early as possible.
        operands.append(new_view)
        del tmp_operands, new_view

    if has_out:
        return out
    return asanyarray(operands[0], order=output_order)