Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/numpy/_core/einsumfunc.py: 6%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

413 statements  

1""" 

2Implementation of optimized einsum. 

3 

4""" 

5import itertools 

6import operator 

7 

8from numpy._core.multiarray import c_einsum 

9from numpy._core.numeric import asanyarray, tensordot 

10from numpy._core.overrides import array_function_dispatch 

11 

12__all__ = ['einsum', 'einsum_path'] 

13 

# importing string for string.ascii_letters would be too slow
# the first import before caching has been measured to take 800 µs (#23777)
# All 52 single-letter subscript labels einsum accepts, ordered so that an
# integer index into this string maps to a unique label (a-z, then A-Z).
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
# Set form for O(1) membership tests while validating subscripts.
einsum_symbols_set = set(einsum_symbols)

18 

19 

def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Computes the number of FLOPS in the contraction.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product?
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    # One multiply per element for each pairwise combination of terms;
    # an inner (summed) contraction costs one additional pass.
    factor = max(num_terms - 1, 1)
    if inner:
        factor = factor + 1
    return _compute_size_by_dict(idx_contraction, size_dictionary) * factor

57 

58def _compute_size_by_dict(indices, idx_dict): 

59 """ 

60 Computes the product of the elements in indices based on the dictionary 

61 idx_dict. 

62 

63 Parameters 

64 ---------- 

65 indices : iterable 

66 Indices to base the product on. 

67 idx_dict : dictionary 

68 Dictionary of index sizes 

69 

70 Returns 

71 ------- 

72 ret : int 

73 The resulting product. 

74 

75 Examples 

76 -------- 

77 >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5}) 

78 90 

79 

80 """ 

81 ret = 1 

82 for i in indices: 

83 ret *= idx_dict[i] 

84 return ret 

85 

86 

87def _find_contraction(positions, input_sets, output_set): 

88 """ 

89 Finds the contraction for a given set of input and output sets. 

90 

91 Parameters 

92 ---------- 

93 positions : iterable 

94 Integer positions of terms used in the contraction. 

95 input_sets : list 

96 List of sets that represent the lhs side of the einsum subscript 

97 output_set : set 

98 Set that represents the rhs side of the overall einsum subscript 

99 

100 Returns 

101 ------- 

102 new_result : set 

103 The indices of the resulting contraction 

104 remaining : list 

105 List of sets that have not been contracted, the new set is appended to 

106 the end of this list 

107 idx_removed : set 

108 Indices removed from the entire contraction 

109 idx_contraction : set 

110 The indices used in the current contraction 

111 

112 Examples 

113 -------- 

114 

115 # A simple dot product test case 

116 >>> pos = (0, 1) 

117 >>> isets = [set('ab'), set('bc')] 

118 >>> oset = set('ac') 

119 >>> _find_contraction(pos, isets, oset) 

120 ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'}) 

121 

122 # A more complex case with additional terms in the contraction 

123 >>> pos = (0, 2) 

124 >>> isets = [set('abd'), set('ac'), set('bdc')] 

125 >>> oset = set('ac') 

126 >>> _find_contraction(pos, isets, oset) 

127 ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'}) 

128 """ 

129 

130 idx_contract = set() 

131 idx_remain = output_set.copy() 

132 remaining = [] 

133 for ind, value in enumerate(input_sets): 

134 if ind in positions: 

135 idx_contract |= value 

136 else: 

137 remaining.append(value) 

138 idx_remain |= value 

139 

140 new_result = idx_remain & idx_contract 

141 idx_removed = (idx_contract - new_result) 

142 remaining.append(new_result) 

143 

144 return (new_result, remaining, idx_removed, idx_contract) 

145 

146 

def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    n = len(input_sets)
    # Each search state is (cost so far, contraction positions, terms left).
    states = [(0, [], input_sets)]

    for step in range(n - 1):
        next_states = []

        # Expand every surviving state by every possible pair contraction.
        for cost, taken, remaining in states:
            for pair in itertools.combinations(range(n - step), 2):

                contraction = _find_contraction(pair, remaining, output_set)
                new_result, new_inputs, idx_removed, idx_contract = contraction

                # Discard intermediates that exceed the memory budget.
                if _compute_size_by_dict(new_result, idx_dict) > memory_limit:
                    continue

                total_cost = cost + _flop_count(
                    idx_contract, idx_removed, len(pair), idx_dict
                )
                next_states.append((total_cost, taken + [pair], new_inputs))

        if next_states:
            states = next_states
        else:
            # The memory sieve pruned everything: keep the cheapest partial
            # path and finish the remaining terms in a single einsum call.
            best_partial = min(states, key=lambda s: s[0])[1]
            return best_partial + [tuple(range(n - step))]

    # Defensive: no pairwise path at all -> contract everything at once.
    if not states:
        return [tuple(range(n))]

    return min(states, key=lambda s: s[0])[1]

220 

def _parse_possible_contraction(
        positions, input_sets, output_set, idx_dict,
        memory_limit, path_cost, naive_cost
):
    """Compute the cost (removed size + flops) and resultant indices for
    performing the contraction specified by ``positions``.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    cost : (int, int)
        A tuple containing the size of any indices removed, and the flop cost.
    positions : tuple of int
        The locations of the proposed tensors to contract.
    new_input_sets : list of sets
        The resulting new list of indices if this proposed contraction
        is performed.

    """

    idx_result, new_input_sets, idx_removed, idx_contract = _find_contraction(
        positions, input_sets, output_set
    )

    # Reject the candidate if its intermediate exceeds the memory budget.
    new_size = _compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None

    # How much input data this contraction eliminates.  (Historically this
    # was the size of the removed indices themselves.)
    removed_size = sum(
        _compute_size_by_dict(input_sets[p], idx_dict) for p in positions
    ) - new_size

    cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)

    # Reject candidates that would already exceed the naive cost.
    if (path_cost + cost) > naive_cost:
        return None

    # Sort key favors large removals first, then low flop cost.
    return [(-removed_size, cost), positions, new_input_sets]

283 

284 

285def _update_other_results(results, best): 

286 """Update the positions and provisional input_sets of ``results`` 

287 based on performing the contraction result ``best``. Remove any 

288 involving the tensors contracted. 

289 

290 Parameters 

291 ---------- 

292 results : list 

293 List of contraction results produced by 

294 ``_parse_possible_contraction``. 

295 best : list 

296 The best contraction of ``results`` i.e. the one that 

297 will be performed. 

298 

299 Returns 

300 ------- 

301 mod_results : list 

302 The list of modified results, updated with outcome of 

303 ``best`` contraction. 

304 """ 

305 

306 best_con = best[1] 

307 bx, by = best_con 

308 mod_results = [] 

309 

310 for cost, (x, y), con_sets in results: 

311 

312 # Ignore results involving tensors just contracted 

313 if x in best_con or y in best_con: 

314 continue 

315 

316 # Update the input_sets 

317 del con_sets[by - int(by > x) - int(by > y)] 

318 del con_sets[bx - int(bx > x) - int(bx > y)] 

319 con_sets.insert(-1, best[2][-1]) 

320 

321 # Update the position indices 

322 mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by) 

323 mod_results.append((cost, mod_con, con_sets)) 

324 

325 return mod_results 

326 

def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
    matrix multiplication or inner product operations, then Hadamard like
    operations, and finally outer operations. Outer products are limited by
    ``memory_limit``. This algorithm scales cubically with respect to the
    number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Trivial cases that leaked through the caller's shortcuts.
    if len(input_sets) == 1:
        return [(0,)]
    if len(input_sets) == 2:
        return [(0, 1)]

    # The single-shot (naive) cost bounds the greedy search.
    full_contract = _find_contraction(
        range(len(input_sets)), input_sets, output_set
    )
    _, _, idx_removed, idx_contract = full_contract
    naive_cost = _flop_count(
        idx_contract, idx_removed, len(input_sets), idx_dict
    )

    # The first step scans every pair; later steps only pairs that involve
    # the most recently created tensor.
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    known_contractions = []

    path_cost = 0
    path = []

    for _ in range(len(input_sets) - 1):

        # Gather candidate contractions, skipping outer products for now.
        for positions in comb_iter:
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            candidate = _parse_possible_contraction(
                positions, input_sets, output_set, idx_dict,
                memory_limit, path_cost, naive_cost
            )
            if candidate is not None:
                known_contractions.append(candidate)

        # No inner contraction found: rescan, allowing outer products.
        if not known_contractions:

            for positions in itertools.combinations(
                    range(len(input_sets)), 2
            ):
                candidate = _parse_possible_contraction(
                    positions, input_sets, output_set, idx_dict,
                    memory_limit, path_cost, naive_cost
                )
                if candidate is not None:
                    known_contractions.append(candidate)

            # Still nothing viable: hand the rest to a single einsum call.
            if not known_contractions:
                path.append(tuple(range(len(input_sets))))
                break

        # Cheapest candidate by the (-removed_size, cost) sort key.
        best = min(known_contractions, key=lambda cand: cand[0])

        # Carry forward the candidates not touched by ``best``.
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only needs pairs involving the new tensor; every
        # other pair has already been accounted for.
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        path.append(best[1])
        path_cost += best[0][1]

    return path

440 

441 

442def _can_dot(inputs, result, idx_removed): 

443 """ 

444 Checks if we can use BLAS (np.tensordot) call and its beneficial to do so. 

445 

446 Parameters 

447 ---------- 

448 inputs : list of str 

449 Specifies the subscripts for summation. 

450 result : str 

451 Resulting summation. 

452 idx_removed : set 

453 Indices that are removed in the summation 

454 

455 

456 Returns 

457 ------- 

458 type : bool 

459 Returns true if BLAS should and can be used, else False 

460 

461 Notes 

462 ----- 

463 If the operations is BLAS level 1 or 2 and is not already aligned 

464 we default back to einsum as the memory movement to copy is more 

465 costly than the operation itself. 

466 

467 

468 Examples 

469 -------- 

470 

471 # Standard GEMM operation 

472 >>> _can_dot(['ij', 'jk'], 'ik', set('j')) 

473 True 

474 

475 # Can use the standard BLAS, but requires odd data movement 

476 >>> _can_dot(['ijj', 'jk'], 'ik', set('j')) 

477 False 

478 

479 # DDOT where the memory is not aligned 

480 >>> _can_dot(['ijk', 'ikj'], '', set('ijk')) 

481 False 

482 

483 """ 

484 

485 # All `dot` calls remove indices 

486 if len(idx_removed) == 0: 

487 return False 

488 

489 # BLAS can only handle two operands 

490 if len(inputs) != 2: 

491 return False 

492 

493 input_left, input_right = inputs 

494 

495 for c in set(input_left + input_right): 

496 # can't deal with repeated indices on same input or more than 2 total 

497 nl, nr = input_left.count(c), input_right.count(c) 

498 if (nl > 1) or (nr > 1) or (nl + nr > 2): 

499 return False 

500 

501 # can't do implicit summation or dimension collapse e.g. 

502 # "ab,bc->c" (implicitly sum over 'a') 

503 # "ab,ca->ca" (take diagonal of 'a') 

504 if nl + nr - 1 == int(c in result): 

505 return False 

506 

507 # Build a few temporaries 

508 set_left = set(input_left) 

509 set_right = set(input_right) 

510 keep_left = set_left - idx_removed 

511 keep_right = set_right - idx_removed 

512 rs = len(idx_removed) 

513 

514 # At this point we are a DOT, GEMV, or GEMM operation 

515 

516 # Handle inner products 

517 

518 # DDOT with aligned data 

519 if input_left == input_right: 

520 return True 

521 

522 # DDOT without aligned data (better to use einsum) 

523 if set_left == set_right: 

524 return False 

525 

526 # Handle the 4 possible (aligned) GEMV or GEMM cases 

527 

528 # GEMM or GEMV no transpose 

529 if input_left[-rs:] == input_right[:rs]: 

530 return True 

531 

532 # GEMM or GEMV transpose both 

533 if input_left[:rs] == input_right[-rs:]: 

534 return True 

535 

536 # GEMM or GEMV transpose right 

537 if input_left[-rs:] == input_right[-rs:]: 

538 return True 

539 

540 # GEMM or GEMV transpose left 

541 if input_left[:rs] == input_right[:rs]: 

542 return True 

543 

544 # Einsum is faster than GEMV if we have to copy data 

545 if not keep_left or not keep_right: 

546 return False 

547 

548 # We are a matrix-matrix product, but we need to copy data 

549 return True 

550 

551 

def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b]) # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        # String form: ("ij,jk->ik", a, b)
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved form: (a, [0, 1], b, [1, 2], [0, 2]); rebuild the
        # equivalent subscript string using einsum_symbols[index] as labels.
        tmp_operands = list(operands)
        operand_list = []
        subscript_list = []
        for p in range(len(operands) // 2):
            operand_list.append(tmp_operands.pop(0))
            subscript_list.append(tmp_operands.pop(0))

        # A trailing odd element, if present, is the output subscript list.
        output_list = tmp_operands[-1] if len(tmp_operands) else None
        operands = [asanyarray(v) for v in operand_list]
        subscripts = ""
        last = len(subscript_list) - 1
        for num, sub in enumerate(subscript_list):
            for s in sub:
                if s is Ellipsis:
                    subscripts += "..."
                else:
                    try:
                        s = operator.index(s)
                    except TypeError as e:
                        raise TypeError(
                            "For this input type lists must contain "
                            "either int or Ellipsis"
                        ) from e
                    subscripts += einsum_symbols[s]
            if num != last:
                subscripts += ","

        if output_list is not None:
            subscripts += "->"
            for s in output_list:
                if s is Ellipsis:
                    subscripts += "..."
                else:
                    try:
                        s = operator.index(s)
                    except TypeError as e:
                        raise TypeError(
                            "For this input type lists must contain "
                            "either int or Ellipsis"
                        ) from e
                    subscripts += einsum_symbols[s]
    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses: replace each "..." by fresh labels drawn from the
    # unused symbols, sized to the broadcast dimensions of each operand.
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        unused = list(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    # Use the tail of ellipse_inds so all terms broadcast
                    # against the same (right-aligned) labels.
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses: output is the broadcast
            # labels followed by every label that appears exactly once.
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Build output subscripts: implicit output keeps, in sorted order,
        # every label that appears exactly once on the input side.
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if output_subscript.count(char) != 1:
            raise ValueError("Output character %s appeared more than once in "
                             "the output." % char)
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)

730 

731 

def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
    # Dispatcher for array_function_dispatch: yields the arguments whose
    # __array_function__ should be consulted.
    # NOTE: technically, we should only dispatch on array-like arguments, not
    # subscripts (given as strings). But separating operands into
    # arrays/subscripts is a little tricky/slow (given einsum's two supported
    # signatures), so as a practical shortcut we dispatch on everything.
    # Strings will be ignored for dispatching since they don't define
    # __array_function__.
    return operands

740 

741 

742@array_function_dispatch(_einsum_path_dispatcher, module='numpy') 

743def einsum_path(*operands, optimize='greedy', einsum_call=False): 

744 """ 

745 einsum_path(subscripts, *operands, optimize='greedy') 

746 

747 Evaluates the lowest cost contraction order for an einsum expression by 

748 considering the creation of intermediate arrays. 

749 

750 Parameters 

751 ---------- 

752 subscripts : str 

753 Specifies the subscripts for summation. 

754 *operands : list of array_like 

755 These are the arrays for the operation. 

756 optimize : {bool, list, tuple, 'greedy', 'optimal'} 

757 Choose the type of path. If a tuple is provided, the second argument is 

758 assumed to be the maximum intermediate size created. If only a single 

759 argument is provided the largest input or output array size is used 

760 as a maximum intermediate size. 

761 

762 * if a list is given that starts with ``einsum_path``, uses this as the 

763 contraction path 

764 * if False no optimization is taken 

765 * if True defaults to the 'greedy' algorithm 

766 * 'optimal' An algorithm that combinatorially explores all possible 

767 ways of contracting the listed tensors and chooses the least costly 

768 path. Scales exponentially with the number of terms in the 

769 contraction. 

770 * 'greedy' An algorithm that chooses the best pair contraction 

771 at each step. Effectively, this algorithm searches the largest inner, 

772 Hadamard, and then outer products at each step. Scales cubically with 

773 the number of terms in the contraction. Equivalent to the 'optimal' 

774 path for most contractions. 

775 

776 Default is 'greedy'. 

777 

778 Returns 

779 ------- 

780 path : list of tuples 

781 A list representation of the einsum path. 

782 string_repr : str 

783 A printable representation of the einsum path. 

784 

785 Notes 

786 ----- 

787 The resulting path indicates which terms of the input contraction should be 

788 contracted first, the result of this contraction is then appended to the 

789 end of the contraction list. This list can then be iterated over until all 

790 intermediate contractions are complete. 

791 

792 See Also 

793 -------- 

794 einsum, linalg.multi_dot 

795 

796 Examples 

797 -------- 

798 

799 We can begin with a chain dot example. In this case, it is optimal to 

800 contract the ``b`` and ``c`` tensors first as represented by the first 

801 element of the path ``(1, 2)``. The resulting tensor is added to the end 

802 of the contraction and the remaining contraction ``(0, 1)`` is then 

803 completed. 

804 

805 >>> np.random.seed(123) 

806 >>> a = np.random.rand(2, 2) 

807 >>> b = np.random.rand(2, 5) 

808 >>> c = np.random.rand(5, 2) 

809 >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy') 

810 >>> print(path_info[0]) 

811 ['einsum_path', (1, 2), (0, 1)] 

812 >>> print(path_info[1]) 

813 Complete contraction: ij,jk,kl->il # may vary 

814 Naive scaling: 4 

815 Optimized scaling: 3 

816 Naive FLOP count: 1.600e+02 

817 Optimized FLOP count: 5.600e+01 

818 Theoretical speedup: 2.857 

819 Largest intermediate: 4.000e+00 elements 

820 ------------------------------------------------------------------------- 

821 scaling current remaining 

822 ------------------------------------------------------------------------- 

823 3 kl,jk->jl ij,jl->il 

824 3 jl,ij->il il->il 

825 

826 

827 A more complex index transformation example. 

828 

829 >>> I = np.random.rand(10, 10, 10, 10) 

830 >>> C = np.random.rand(10, 10) 

831 >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C, 

832 ... optimize='greedy') 

833 

834 >>> print(path_info[0]) 

835 ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)] 

836 >>> print(path_info[1]) 

837 Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary 

838 Naive scaling: 8 

839 Optimized scaling: 5 

840 Naive FLOP count: 8.000e+08 

841 Optimized FLOP count: 8.000e+05 

842 Theoretical speedup: 1000.000 

843 Largest intermediate: 1.000e+04 elements 

844 -------------------------------------------------------------------------- 

845 scaling current remaining 

846 -------------------------------------------------------------------------- 

847 5 abcd,ea->bcde fb,gc,hd,bcde->efgh 

848 5 bcde,fb->cdef gc,hd,cdef->efgh 

849 5 cdef,gc->defg hd,defg->efgh 

850 5 defg,hd->efgh efgh->efgh 

851 """ 

852 

853 # Figure out what the path really is 

854 path_type = optimize 

855 if path_type is True: 

856 path_type = 'greedy' 

857 if path_type is None: 

858 path_type = False 

859 

860 explicit_einsum_path = False 

861 memory_limit = None 

862 

863 # No optimization or a named path algorithm 

864 if (path_type is False) or isinstance(path_type, str): 

865 pass 

866 

867 # Given an explicit path 

868 elif len(path_type) and (path_type[0] == 'einsum_path'): 

869 explicit_einsum_path = True 

870 

871 # Path tuple with memory limit 

872 elif ((len(path_type) == 2) and isinstance(path_type[0], str) and 

873 isinstance(path_type[1], (int, float))): 

874 memory_limit = int(path_type[1]) 

875 path_type = path_type[0] 

876 

877 else: 

878 raise TypeError("Did not understand the path: %s" % str(path_type)) 

879 

880 # Hidden option, only einsum should call this 

881 einsum_call_arg = einsum_call 

882 

883 # Python side parsing 

884 input_subscripts, output_subscript, operands = ( 

885 _parse_einsum_input(operands) 

886 ) 

887 

888 # Build a few useful list and sets 

889 input_list = input_subscripts.split(',') 

890 input_sets = [set(x) for x in input_list] 

891 output_set = set(output_subscript) 

892 indices = set(input_subscripts.replace(',', '')) 

893 

894 # Get length of each unique dimension and ensure all dimensions are correct 

895 dimension_dict = {} 

896 broadcast_indices = [[] for x in range(len(input_list))] 

897 for tnum, term in enumerate(input_list): 

898 sh = operands[tnum].shape 

899 if len(sh) != len(term): 

900 raise ValueError("Einstein sum subscript %s does not contain the " 

901 "correct number of indices for operand %d." 

902 % (input_subscripts[tnum], tnum)) 

903 for cnum, char in enumerate(term): 

904 dim = sh[cnum] 

905 

906 # Build out broadcast indices 

907 if dim == 1: 

908 broadcast_indices[tnum].append(char) 

909 

910 if char in dimension_dict.keys(): 

911 # For broadcasting cases we always want the largest dim size 

912 if dimension_dict[char] == 1: 

913 dimension_dict[char] = dim 

914 elif dim not in (1, dimension_dict[char]): 

915 raise ValueError("Size of label '%s' for operand %d (%d) " 

916 "does not match previous terms (%d)." 

917 % (char, tnum, dimension_dict[char], dim)) 

918 else: 

919 dimension_dict[char] = dim 

920 

921 # Convert broadcast inds to sets 

922 broadcast_indices = [set(x) for x in broadcast_indices] 

923 

924 # Compute size of each input array plus the output array 

925 size_list = [_compute_size_by_dict(term, dimension_dict) 

926 for term in input_list + [output_subscript]] 

927 max_size = max(size_list) 

928 

929 if memory_limit is None: 

930 memory_arg = max_size 

931 else: 

932 memory_arg = memory_limit 

933 

934 # Compute naive cost 

935 # This isn't quite right, need to look into exactly how einsum does this 

936 inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0 

937 naive_cost = _flop_count( 

938 indices, inner_product, len(input_list), dimension_dict 

939 ) 

940 

941 # Compute the path 

942 if explicit_einsum_path: 

943 path = path_type[1:] 

944 elif ( 

945 (path_type is False) 

946 or (len(input_list) in [1, 2]) 

947 or (indices == output_set) 

948 ): 

949 # Nothing to be optimized, leave it to einsum 

950 path = [tuple(range(len(input_list)))] 

951 elif path_type == "greedy": 

952 path = _greedy_path( 

953 input_sets, output_set, dimension_dict, memory_arg 

954 ) 

955 elif path_type == "optimal": 

956 path = _optimal_path( 

957 input_sets, output_set, dimension_dict, memory_arg 

958 ) 

959 else: 

960 raise KeyError("Path name %s not found", path_type) 

961 

962 cost_list, scale_list, size_list, contraction_list = [], [], [], [] 

963 

964 # Build contraction tuple (positions, gemm, einsum_str, remaining) 

965 for cnum, contract_inds in enumerate(path): 

966 # Make sure we remove inds from right to left 

967 contract_inds = tuple(sorted(contract_inds, reverse=True)) 

968 

969 contract = _find_contraction(contract_inds, input_sets, output_set) 

970 out_inds, input_sets, idx_removed, idx_contract = contract 

971 

972 cost = _flop_count( 

973 idx_contract, idx_removed, len(contract_inds), dimension_dict 

974 ) 

975 cost_list.append(cost) 

976 scale_list.append(len(idx_contract)) 

977 size_list.append(_compute_size_by_dict(out_inds, dimension_dict)) 

978 

979 bcast = set() 

980 tmp_inputs = [] 

981 for x in contract_inds: 

982 tmp_inputs.append(input_list.pop(x)) 

983 bcast |= broadcast_indices.pop(x) 

984 

985 new_bcast_inds = bcast - idx_removed 

986 

987 # If we're broadcasting, nix blas 

988 if not len(idx_removed & bcast): 

989 do_blas = _can_dot(tmp_inputs, out_inds, idx_removed) 

990 else: 

991 do_blas = False 

992 

993 # Last contraction 

994 if (cnum - len(path)) == -1: 

995 idx_result = output_subscript 

996 else: 

997 sort_result = [(dimension_dict[ind], ind) for ind in out_inds] 

998 idx_result = "".join([x[1] for x in sorted(sort_result)]) 

999 

1000 input_list.append(idx_result) 

1001 broadcast_indices.append(new_bcast_inds) 

1002 einsum_str = ",".join(tmp_inputs) + "->" + idx_result 

1003 

1004 contraction = ( 

1005 contract_inds, idx_removed, einsum_str, input_list[:], do_blas 

1006 ) 

1007 contraction_list.append(contraction) 

1008 

1009 opt_cost = sum(cost_list) + 1 

1010 

1011 if len(input_list) != 1: 

1012 # Explicit "einsum_path" is usually trusted, but we detect this kind of 

1013 # mistake in order to prevent from returning an intermediate value. 

1014 raise RuntimeError( 

1015 "Invalid einsum_path is specified: {} more operands has to be " 

1016 "contracted.".format(len(input_list) - 1)) 

1017 

1018 if einsum_call_arg: 

1019 return (operands, contraction_list) 

1020 

1021 # Return the path along with a nice string representation 

1022 overall_contraction = input_subscripts + "->" + output_subscript 

1023 header = ("scaling", "current", "remaining") 

1024 

1025 speedup = naive_cost / opt_cost 

1026 max_i = max(size_list) 

1027 

1028 path_print = " Complete contraction: %s\n" % overall_contraction 

1029 path_print += " Naive scaling: %d\n" % len(indices) 

1030 path_print += " Optimized scaling: %d\n" % max(scale_list) 

1031 path_print += " Naive FLOP count: %.3e\n" % naive_cost 

1032 path_print += " Optimized FLOP count: %.3e\n" % opt_cost 

1033 path_print += " Theoretical speedup: %3.3f\n" % speedup 

1034 path_print += " Largest intermediate: %.3e elements\n" % max_i 

1035 path_print += "-" * 74 + "\n" 

1036 path_print += "%6s %24s %40s\n" % header 

1037 path_print += "-" * 74 

1038 

1039 for n, contraction in enumerate(contraction_list): 

1040 inds, idx_rm, einsum_str, remaining, blas = contraction 

1041 remaining_str = ",".join(remaining) + "->" + output_subscript 

1042 path_run = (scale_list[n], einsum_str, remaining_str) 

1043 path_print += "\n%4d %24s %40s" % path_run 

1044 

1045 path = ['einsum_path'] + path 

1046 return (path, path_print) 

1047 

1048 

1049def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs): 

1050 # Arguably we dispatch on more arguments than we really should; see note in 

1051 # _einsum_path_dispatcher for why. 

1052 yield from operands 

1053 yield out 

1054 

1055 

1056# Rewrite einsum to handle different cases 

@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, out=None, optimize=False, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations can be represented in a simple fashion.
    In *implicit* mode `einsum` computes these values.

    In *explicit* mode, `einsum` provides further flexibility to compute
    other array operations that might not be considered classical Einstein
    summation operations, by disabling, or forcing summation over specified
    subscript labels.

    See the notes and examples for clarification.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as comma separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : list of array_like
        These are the arrays for the operation.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout as the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur. Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

        * 'no' means the data types should not be cast at all.
        * 'equiv' means only byte-order changes are allowed.
        * 'safe' means only casts which can preserve values are allowed.
        * 'same_kind' means only safe casts or casts within a kind,
          like float64 to float32, are allowed.
        * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
    einsum:
        Similar verbose interface is provided by the
        `einops <https://github.com/arogozhnikov/einops>`_ package to cover
        additional operations: transpose, reshape/flatten, repeat/tile,
        squeeze/unsqueeze and reductions.
        The `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
        optimizes contraction order for einsum-like expressions
        in backend-agnostic manner.

    Notes
    -----
    The Einstein summation convention can be used to compute
    many multi-dimensional, linear algebraic array operations. `einsum`
    provides a succinct way of representing these.

    A non-exhaustive list of these operations,
    which can be computed by `einsum`, is shown below along with examples:

    * Trace of an array, :py:func:`numpy.trace`.
    * Return a diagonal, :py:func:`numpy.diag`.
    * Array axis summations, :py:func:`numpy.sum`.
    * Transpositions and permutations, :py:func:`numpy.transpose`.
    * Matrix multiplication and dot product, :py:func:`numpy.matmul`
      :py:func:`numpy.dot`.
    * Vector inner and outer products, :py:func:`numpy.inner`
      :py:func:`numpy.outer`.
    * Broadcasting, element-wise and scalar multiplication,
      :py:func:`numpy.multiply`.
    * Tensor contractions, :py:func:`numpy.tensordot`.
    * Chained array operations, in efficient calculation order,
      :py:func:`numpy.einsum_path`.

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
    appears only once, it is not summed, so ``np.einsum('i', a)``
    produces a view of ``a`` with no changes. A further example
    ``np.einsum('ij,jk', a, b)`` describes traditional matrix multiplication
    and is equivalent to :py:func:`np.matmul(a,b) <numpy.matmul>`.
    Repeated subscript labels in one operand take the diagonal.
    For example, ``np.einsum('ii', a)`` is equivalent to
    :py:func:`np.trace(a) <numpy.trace>`.

    In *implicit mode*, the chosen subscripts are important
    since the axes of the output are reordered alphabetically. This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose. Additionally,
    ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
    ``np.einsum('ij,jh', a, b)`` returns the transpose of the
    multiplication since subscript 'h' precedes subscript 'i'.

    In *explicit mode* the output can be directly controlled by
    specifying output subscript labels. This requires the
    identifier '->' as well as the list of output subscript labels.
    This feature increases the flexibility of the function since
    summing can be disabled or forced when required. The call
    ``np.einsum('i->', a)`` is like :py:func:`np.sum(a) <numpy.sum>`
    if ``a`` is a 1-D array, and ``np.einsum('ii->i', a)``
    is like :py:func:`np.diag(a) <numpy.diag>` if ``a`` is a square 2-D array.
    The difference is that `einsum` does not allow broadcasting by default.
    Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
    order of the output subscript labels and therefore returns matrix
    multiplication, unlike the example above in implicit mode.

    To enable and control broadcasting, use an ellipsis. Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    ``np.einsum('...i->...', a)`` is like
    :py:func:`np.sum(a, axis=-1) <numpy.sum>` for array ``a`` of any shape.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, one can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view (changed in version 1.10.0).

    `einsum` also provides an alternative way to provide the subscripts and
    operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    If the output shape is not provided in this format `einsum` will be
    calculated in implicit mode, otherwise it will be performed explicitly.
    The examples below have corresponding `einsum` calls with the two
    parameter methods.

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
    and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands
    this can greatly increase the computational efficiency at the cost of
    a larger memory footprint during computation.

    Typically a 'greedy' algorithm is applied which empirical tests have shown
    returns the optimal path in the majority of cases. In some cases 'optimal'
    will return the superlative path through a more expensive, exhaustive
    search. For iterative calculations it may be advisable to calculate
    the optimal path once and reuse that path by supplying it as an argument.
    An example is given below.

    See :py:func:`numpy.einsum_path` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    Trace of a matrix:

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    Extract the diagonal (requires explicit form):

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    Sum over an axis (requires explicit form):

    >>> np.einsum('ij->i', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [0,1], [0])
    array([ 10,  35,  60,  85, 110])
    >>> np.sum(a, axis=1)
    array([ 10,  35,  60,  85, 110])

    For higher dimensional arrays summing a single axis can be done
    with ellipsis:

    >>> np.einsum('...j->...', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
    array([ 10,  35,  60,  85, 110])

    Compute a matrix transpose, or reorder any number of axes:

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum('ij->ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.transpose(c)
    array([[0, 3],
           [1, 4],
           [2, 5]])

    Vector inner products:

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    Matrix vector multiplication:

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    Broadcasting and scalar multiplication:

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(',ij', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    Vector outer product:

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    Tensor contraction:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    Writeable returned arrays (since version 1.10.0):

    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])

    Example of ellipsis use:

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    Chained array operations. For more complicated contractions, speed ups
    might be achieved by repeatedly computing a 'greedy' path or pre-computing
    the 'optimal' path and repeatedly applying it, using an `einsum_path`
    insertion (since version 1.12.0). Performance improvements can be
    particularly significant with larger arrays:

    >>> a = np.ones(64).reshape(2,4,8)

    Basic `einsum`: ~1520ms  (benchmarked on 3.1GHz Intel i5.)

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)

    Sub-optimal `einsum` (due to repeated path calculation time): ~330ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
    ...         optimize='optimal')

    Greedy `einsum` (faster optimal path approximation): ~160ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')

    Optimal `einsum` (best usage pattern in some use cases): ~110ms

    >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
    ...     optimize='optimal')[0]
    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)

    """
    # Special handling if out is specified
    specified_out = out is not None

    # If no optimization, run pure einsum
    if optimize is False:
        if specified_out:
            kwargs['out'] = out
        return c_einsum(*operands, **kwargs)

    # Check the kwargs to avoid a more cryptic error later, without having to
    # repeat default values here
    valid_einsum_kwargs = ['dtype', 'order', 'casting']
    unknown_kwargs = [k for (k, v) in kwargs.items() if
                      k not in valid_einsum_kwargs]
    if len(unknown_kwargs):
        raise TypeError("Did not understand the following kwargs: %s"
                        % unknown_kwargs)

    # Build the contraction list and operand
    # Each element of contraction_list is a tuple of
    # (operand_indices, contracted_labels, einsum_str, remaining, do_blas)
    # as assembled by einsum_path with einsum_call=True.
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                             einsum_call=True)

    # Handle order kwarg for output array, c_einsum allows mixed case
    output_order = kwargs.pop('order', 'K')
    if output_order.upper() == 'A':
        if all(arr.flags.f_contiguous for arr in operands):
            output_order = 'F'
        else:
            output_order = 'C'

    # Start contraction loop
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        # inds were sorted in decreasing order by einsum_path, so popping
        # them one at a time does not shift the indices still to be popped.
        tmp_operands = [operands.pop(x) for x in inds]

        # Do we need to deal with the output?
        handle_out = specified_out and ((num + 1) == len(contraction_list))

        # Call tensordot if still possible
        if blas:
            # Checks have already been handled
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            # tensordot produces axes in the order (left non-contracted,
            # right non-contracted), which is what tensor_result spells out.
            tensor_result = input_left + input_right
            for s in idx_rm:
                tensor_result = tensor_result.replace(s, "")

            # Find indices to contract over
            left_pos, right_pos = [], []
            for s in sorted(idx_rm):
                left_pos.append(input_left.find(s))
                right_pos.append(input_right.find(s))

            # Contract!
            new_view = tensordot(
                *tmp_operands, axes=(tuple(left_pos), tuple(right_pos))
            )

            # Build a new view if needed: a trailing c_einsum reorders the
            # tensordot axes into results_index and/or writes into `out`.
            if (tensor_result != results_index) or handle_out:
                if handle_out:
                    kwargs["out"] = out
                new_view = c_einsum(
                    tensor_result + '->' + results_index, new_view, **kwargs
                )

        # Call einsum
        else:
            # If out was specified
            if handle_out:
                kwargs["out"] = out

            # Do the contraction
            new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out
    else:
        # Exactly one operand remains; coerce it to the requested order.
        return asanyarray(operands[0], order=output_order)