Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/numpy/_core/einsumfunc.py: 6%

413 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-09 06:12 +0000

1""" 

2Implementation of optimized einsum. 

3 

4""" 

5import itertools 

6import operator 

7 

8from numpy._core.multiarray import c_einsum 

9from numpy._core.numeric import asanyarray, tensordot 

10from numpy._core.overrides import array_function_dispatch 

11 

12__all__ = ['einsum', 'einsum_path'] 

13 

# The label alphabet is spelled out literally instead of being taken from
# ``string.ascii_letters``: the first import of ``string`` before caching has
# been measured to take 800 µs (#23777).
einsum_symbols = (
    'abcdefghijklmnopqrstuvwxyz'
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
)
# The same alphabet as a set, for set arithmetic on labels.
einsum_symbols_set = set(einsum_symbols)

18 

19 

def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Estimate the number of FLOPS needed by one contraction.

    Parameters
    ----------
    idx_contraction : iterable
        Every index taking part in the contraction.
    inner : bool
        Whether the contraction sums over an index (an inner product). Only
        its truthiness is used, so a set of removed indices also works.
    num_terms : int
        How many operands the contraction combines.
    size_dictionary : dict
        Extent of each index in ``idx_contraction``.

    Returns
    -------
    flop_count : int
        Size of the full index space times the operations per element.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    # Multiplying n operands takes n - 1 operations per element (never fewer
    # than one); a summation adds one more.
    ops_per_element = max(1, num_terms - 1) + (1 if inner else 0)
    index_space = _compute_size_by_dict(idx_contraction, size_dictionary)
    return index_space * ops_per_element

57 

def _compute_size_by_dict(indices, idx_dict):
    """
    Multiply together the sizes of ``indices`` as looked up in ``idx_dict``.

    Repeated indices contribute their size once per occurrence.

    Parameters
    ----------
    indices : iterable
        Indices whose sizes are multiplied.
    idx_dict : dictionary
        Size of every index.

    Returns
    -------
    ret : int
        The product (1 for an empty ``indices``).

    Raises
    ------
    KeyError
        If an index is missing from ``idx_dict``.

    Examples
    --------
    >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
    90

    """
    size = 1
    for extent in (idx_dict[index] for index in indices):
        size *= extent
    return size

85 

86 

def _find_contraction(positions, input_sets, output_set):
    """
    Work out what contracting the terms at ``positions`` produces.

    Parameters
    ----------
    positions : iterable
        Integer positions of the terms being contracted.
    input_sets : list
        Index sets of the current terms (the lhs of the einsum subscript).
    output_set : set
        Index set of the overall einsum output (the rhs).

    Returns
    -------
    new_result : set
        Indices the new intermediate keeps: those of the contracted terms
        still needed by the output or by some untouched term.
    remaining : list
        Index sets of the untouched terms, in their original order, followed
        by ``new_result``.
    idx_removed : set
        Indices summed away by this contraction.
    idx_contraction : set
        Every index involved in this contraction.

    None of the argument sets are modified.

    Examples
    --------

    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> _find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> _find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """
    contracted = set()
    still_needed = output_set.copy()
    remaining = []
    for pos, index_set in enumerate(input_sets):
        if pos not in positions:
            # Untouched term: it survives and keeps its indices alive.
            remaining.append(index_set)
            still_needed |= index_set
        else:
            contracted |= index_set

    new_result = still_needed & contracted
    idx_removed = contracted - new_result
    remaining.append(new_result)
    return (new_result, remaining, idx_removed, contracted)

145 

146 

def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Exhaustively search pairwise contraction orders for the cheapest path.

    Every pair contraction is explored at every step. Candidates whose new
    intermediate exceeds ``memory_limit`` elements are dropped. The search
    grows factorially with the number of entries in ``input_sets``.

    Parameters
    ----------
    input_sets : list
        Index sets of the input terms (lhs of the einsum subscript).
    output_set : set
        Index set of the overall output (rhs of the einsum subscript).
    idx_dict : dictionary
        Size of every index.
    memory_limit : int
        Largest number of elements allowed in an intermediate array.

    Returns
    -------
    path : list
        The cheapest contraction order within the memory limit. If at some
        step no pair fits, the cheapest partial path is completed with one
        contraction over every remaining term.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    num_inputs = len(input_sets)
    # Each candidate is (accumulated flops, path so far, current index sets).
    candidates = [(0, [], input_sets)]

    for step in range(num_inputs - 1):
        # After ``step`` pairwise contractions this many terms are left.
        num_left = num_inputs - step
        next_candidates = []

        for cost, positions, remaining in candidates:
            for pair in itertools.combinations(range(num_left), 2):
                new_result, new_input_sets, idx_removed, idx_contract = (
                    _find_contraction(pair, remaining, output_set)
                )

                # Drop contractions whose intermediate is too large
                if _compute_size_by_dict(new_result, idx_dict) > memory_limit:
                    continue

                step_cost = _flop_count(
                    idx_contract, idx_removed, len(pair), idx_dict
                )
                next_candidates.append(
                    (cost + step_cost, positions + [pair], new_input_sets)
                )

        if not next_candidates:
            # Nothing fits in memory: finish the cheapest partial path with
            # one contraction over all remaining terms.
            best_path = min(candidates, key=lambda c: c[0])[1]
            best_path += [tuple(range(num_left))]
            return best_path
        candidates = next_candidates

    # If we have not found anything return single einsum contraction
    if not candidates:
        return [tuple(range(num_inputs))]

    return min(candidates, key=lambda c: c[0])[1]

220 

def _parse_possible_contraction(
        positions, input_sets, output_set, idx_dict,
        memory_limit, path_cost, naive_cost
    ):
    """Evaluate contracting the tensors at ``positions`` as a greedy candidate.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensor.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    list or None
        ``None`` when the new intermediate would exceed ``memory_limit`` or
        ``path_cost`` plus this contraction's flops would exceed
        ``naive_cost``. Otherwise ``[sort, positions, new_input_sets]``:

        sort : (int, int)
            ``(-removed_size, flop_cost)``. ``removed_size`` is the combined
            size of the contracted tensors minus the size of the new one.
            Smaller keys are better.
        positions : tuple of int
            The ``positions`` argument, unchanged.
        new_input_sets : list of sets
            The index sets that result if this contraction is performed.

    """
    idx_result, new_input_sets, idx_removed, idx_contract = _find_contraction(
        positions, input_sets, output_set
    )

    # Intermediate too large: not a candidate.
    result_size = _compute_size_by_dict(idx_result, idx_dict)
    if result_size > memory_limit:
        return None

    # Net reduction in stored elements. This used to be just the size of the
    # removed indices, i.e. _compute_size_by_dict(idx_removed, idx_dict).
    consumed_size = sum(
        _compute_size_by_dict(input_sets[p], idx_dict) for p in positions
    )
    removed_size = consumed_size - result_size

    flops = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)

    # Too expensive: the path would cost more than naive einsum.
    if path_cost + flops > naive_cost:
        return None

    return [(-removed_size, flops), positions, new_input_sets]

283 

284 

def _update_other_results(results, best):
    """Carry the candidates in ``results`` forward past contraction ``best``.

    Candidates that use either tensor consumed by ``best`` are dropped. The
    rest have their positions renumbered and their provisional index-set
    lists edited to reflect ``best`` having been performed.

    Parameters
    ----------
    results : list
        Candidates of the form ``(sort, (x, y), con_sets)``, as produced by
        ``_parse_possible_contraction``.
    best : list
        The candidate from ``results`` that is being performed.

    Returns
    -------
    mod_results : list
        The surviving candidates as ``(sort, (x, y), con_sets)`` tuples.
        Each ``con_sets`` is the same list object as in ``results``, edited
        in place.
    """
    best_con = best[1]
    bx, by = best_con

    def _shifted(pos, first, second):
        # Where ``pos`` lands once entries ``first`` and ``second`` are gone.
        return pos - (pos > first) - (pos > second)

    carried = []
    for sort_key, (x, y), con_sets in results:
        if x in best_con or y in best_con:
            # Shares a tensor with ``best``; it can no longer be performed.
            continue

        # ``con_sets`` already lacks x and y, so bx/by are shifted relative
        # to them. by is removed before bx; when bx < by (as the greedy
        # search generates pairs) bx's shifted index stays valid.
        del con_sets[_shifted(by, x, y)]
        del con_sets[_shifted(bx, x, y)]
        # Insert best's new tensor before this candidate's own output,
        # which stays last.
        con_sets.insert(-1, best[2][-1])

        new_pair = (_shifted(x, bx, by), _shifted(y, bx, by))
        carried.append((sort_key, new_pair, con_sets))

    return carried

326 

def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by repeatedly performing the best pair contraction until
    the input list is exhausted.

    Candidates are ranked by the sort key built in
    ``_parse_possible_contraction``: ``(-removed_size, cost)``.
    ``removed_size`` is the combined size of the two contracted tensors minus
    the size of the tensor they produce, and ``cost`` is the flop count. The
    best candidate is the one with the smallest key.

    Pairs with no shared index (outer products) are only considered when no
    other candidate survives. Every candidate must keep its intermediate
    within ``memory_limit`` and keep the running cost at or below the naive
    single-einsum cost. The original description says the algorithm scales
    cubically with the number of elements in ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint. If
        at some step no candidate survives, the final entry contracts all
        remaining terms at once.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Cost of doing everything in one einsum call; used as an upper bound
    # for the running path cost (see _parse_possible_contraction)
    contract = _find_contraction(
        range(len(input_sets)), input_sets, output_set
    )
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(
        idx_contract, idx_removed, len(input_sets), idx_dict
    )

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    # Candidate contractions still valid for the current input_sets; each
    # entry is (sort_key, positions, resulting index sets)
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # First step: evaluate all pairs. Later steps: evaluate only pairs
        # involving the newest tensor; older candidates are carried over in
        # known_contractions by _update_other_results
        for positions in comb_iter:

            # Always initially ignore outer products (no shared index)
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(
                positions, input_sets, output_set, idx_dict,
                memory_limit, path_cost, naive_cost
            )
            if result is not None:
                known_contractions.append(result)

        # If we do not have an inner contraction, rescan all pairs
        # including outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(
                range(len(input_sets)), 2
            ):
                result = _parse_possible_contraction(
                    positions, input_sets, output_set, idx_dict,
                    memory_limit, path_cost, naive_cost
                )
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions,
            # default back to einsum like behavior
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Pick the candidate with the smallest sort key
        # (-removed_size, flop cost); ties go to the earliest found
        best = min(known_contractions, key=lambda x: x[0])

        # Now propagate as many unused contractions as possible
        # to the next iteration
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor,
        # which sits at the end of best's resulting index sets.
        # All other contractions have been accounted for
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost (best[0][1] is the flop count)
        path.append(best[1])
        path_cost += best[0][1]

    return path

440 

441 

def _can_dot(inputs, result, idx_removed):
    """
    Decide whether a contraction should go to BLAS (``np.tensordot``) and
    whether that is worthwhile.

    Parameters
    ----------
    inputs : list of str
        Subscripts of the operands being contracted.
    result : str
        Subscripts of the contraction's output.
    idx_removed : set
        Indices summed away by the contraction.

    Returns
    -------
    type : bool
        True when BLAS is both applicable and expected to pay off.

    Notes
    -----
    BLAS level 1 or 2 operations (DOT/GEMV) whose operands are not already
    aligned are left to einsum, because copying the data costs more than the
    operation itself.

    Examples
    --------

    # Standard GEMM operation
    >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
    True

    # Can use the standard BLAS, but requires odd data movement
    >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
    False

    # DDOT where the memory is not aligned
    >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
    False

    """
    # tensordot always removes at least one index and takes two operands
    if not idx_removed or len(inputs) != 2:
        return False

    left, right = inputs

    for label in set(left + right):
        n_left = left.count(label)
        n_right = right.count(label)
        # Repeated inside one operand, or appearing more than twice overall
        if max(n_left, n_right) > 1 or n_left + n_right > 2:
            return False
        # Single-operand index missing from the output (implicit sum, e.g.
        # "ab,bc->c") or shared index kept in the output (diagonal, e.g.
        # "ab,ca->ca")
        if n_left + n_right - 1 == int(label in result):
            return False

    left_set, right_set = set(left), set(right)
    left_kept = left_set - idx_removed
    right_kept = right_set - idx_removed
    width = len(idx_removed)

    # From here this is a DOT, GEMV or GEMM.
    # Inner product: BLAS only when both operands share the same layout
    if left == right:
        return True
    if left_set == right_set:
        return False

    # Aligned GEMV/GEMM: contracted block at the start or end of each
    # operand (no transpose, transpose both, transpose right, transpose left)
    left_head, left_tail = left[:width], left[-width:]
    right_head, right_tail = right[:width], right[-width:]
    if (left_tail == right_head or left_head == right_tail
            or left_tail == right_tail or left_head == right_head):
        return True

    # Needs a copy: not worth it for a GEMV, still worth it for a GEMM
    return bool(left_kept) and bool(right_kept)

550 

551 

def _subscript_symbol(s):
    """Translate one entry of a list-form einsum subscript into its letters.

    ``Ellipsis`` becomes ``"..."``. Any other entry must be accepted by
    ``operator.index`` and lie in ``range(len(einsum_symbols))``.

    Raises
    ------
    TypeError
        If ``s`` is neither ``Ellipsis`` nor usable as an integer.
    ValueError
        If the integer label is outside ``[0, len(einsum_symbols))``.
    """
    if s is Ellipsis:
        return "..."
    try:
        s = operator.index(s)
    except TypeError as e:
        raise TypeError(
            "For this input type lists must contain "
            "either int or Ellipsis"
        ) from e
    # Without this check a negative label would index einsum_symbols from
    # the end and silently alias an unrelated letter.
    if not 0 <= s < len(einsum_symbols):
        raise ValueError(
            "subscript is not within the valid range [0, %d)"
            % len(einsum_symbols)
        )
    return einsum_symbols[s]


def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Accepts either calling form of ``einsum``:

    * a subscripts string followed by the operands, or
    * interleaved ``operand, sublist`` pairs, optionally followed by an
      output sublist.

    Ellipses are expanded into letters not otherwise used in the expression.
    An implicit output (no ``->``) is made explicit.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        If there are no operands; a subscript character or integer label is
        invalid; ``->`` or an ellipsis is malformed; ellipsis lengths do not
        fit the operands; an output label repeats or is absent from the
        inputs; or the number of input terms differs from the number of
        operands.
    TypeError
        If a sublist contains something other than an int or ``Ellipsis``.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,yza', 'yz', [a, b])  # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,yza', 'yz', [a, b])  # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved form: op0, sub0, op1, sub1, ..., [output_sublist]
        tmp_operands = list(operands)
        num_pairs = len(tmp_operands) // 2
        operand_list = tmp_operands[0:2 * num_pairs:2]
        subscript_list = tmp_operands[1:2 * num_pairs:2]
        output_list = tmp_operands[-1] if len(tmp_operands) % 2 else None

        operands = [asanyarray(v) for v in operand_list]
        subscripts = ",".join(
            "".join(_subscript_symbol(s) for s in sub)
            for sub in subscript_list
        )
        if output_list is not None:
            subscripts += "->" + "".join(
                _subscript_symbol(s) for s in output_list
            )

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        # Sorted so the letters chosen for ellipses do not depend on set
        # iteration order, which varies with the per-process string hash
        # seed.
        unused = sorted(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        # Checked here, not only at the end, because the loop below indexes
        # ``operands`` by term number.
        if len(split_subscripts) != len(operands):
            raise ValueError("Number of einsum subscripts must be equal to the "
                             "number of operands.")

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    # Right-align: every term takes the trailing letters
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses: broadcast dimensions come
            # first, followed by the labels appearing exactly once, sorted
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in einsum_symbols:
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Implicit output: labels appearing exactly once, in sorted order
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if output_subscript.count(char) != 1:
            raise ValueError("Output character %s appeared more than once in "
                             "the output." % char)
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)

730 

731 

def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
    """Return the arguments that ``einsum_path`` dispatches on.

    Strictly, only array-like arguments should be dispatched on. Separating
    arrays from subscripts is fiddly and slow given einsum's two calling
    forms, so every positional argument is returned. Subscript strings do not
    define ``__array_function__`` and are therefore ignored by dispatch.
    ``optimize`` and ``einsum_call`` are accepted only to mirror the
    signature of ``einsum_path``.
    """
    return operands

740 

741 

742@array_function_dispatch(_einsum_path_dispatcher, module='numpy') 

743def einsum_path(*operands, optimize='greedy', einsum_call=False): 

744 """ 

745 einsum_path(subscripts, *operands, optimize='greedy') 

746 

747 Evaluates the lowest cost contraction order for an einsum expression by 

748 considering the creation of intermediate arrays. 

749 

750 Parameters 

751 ---------- 

752 subscripts : str 

753 Specifies the subscripts for summation. 

754 *operands : list of array_like 

755 These are the arrays for the operation. 

756 optimize : {bool, list, tuple, 'greedy', 'optimal'} 

757 Choose the type of path. If a tuple is provided, the second argument is 

758 assumed to be the maximum intermediate size created. If only a single 

759 argument is provided the largest input or output array size is used 

760 as a maximum intermediate size. 

761 

762 * if a list is given that starts with ``einsum_path``, uses this as the 

763 contraction path 

764 * if False no optimization is taken 

765 * if True defaults to the 'greedy' algorithm 

766 * 'optimal' An algorithm that combinatorially explores all possible 

767 ways of contracting the listed tensors and chooses the least costly 

768 path. Scales exponentially with the number of terms in the 

769 contraction. 

770 * 'greedy' An algorithm that chooses the best pair contraction 

771 at each step. Effectively, this algorithm searches the largest inner, 

772 Hadamard, and then outer products at each step. Scales cubically with 

773 the number of terms in the contraction. Equivalent to the 'optimal' 

774 path for most contractions. 

775 

776 Default is 'greedy'. 

777 

778 Returns 

779 ------- 

780 path : list of tuples 

781 A list representation of the einsum path. 

782 string_repr : str 

783 A printable representation of the einsum path. 

784 

785 Notes 

786 ----- 

787 The resulting path indicates which terms of the input contraction should be 

788 contracted first, the result of this contraction is then appended to the 

789 end of the contraction list. This list can then be iterated over until all 

790 intermediate contractions are complete. 

791 

792 See Also 

793 -------- 

794 einsum, linalg.multi_dot 

795 

796 Examples 

797 -------- 

798 

799 We can begin with a chain dot example. In this case, it is optimal to 

800 contract the ``b`` and ``c`` tensors first as represented by the first 

801 element of the path ``(1, 2)``. The resulting tensor is added to the end 

802 of the contraction and the remaining contraction ``(0, 1)`` is then 

803 completed. 

804 

805 >>> np.random.seed(123) 

806 >>> a = np.random.rand(2, 2) 

807 >>> b = np.random.rand(2, 5) 

808 >>> c = np.random.rand(5, 2) 

809 >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy') 

810 >>> print(path_info[0]) 

811 ['einsum_path', (1, 2), (0, 1)] 

812 >>> print(path_info[1]) 

813 Complete contraction: ij,jk,kl->il # may vary 

814 Naive scaling: 4 

815 Optimized scaling: 3 

816 Naive FLOP count: 1.600e+02 

817 Optimized FLOP count: 5.600e+01 

818 Theoretical speedup: 2.857 

819 Largest intermediate: 4.000e+00 elements 

820 ------------------------------------------------------------------------- 

821 scaling current remaining 

822 ------------------------------------------------------------------------- 

823 3 kl,jk->jl ij,jl->il 

824 3 jl,ij->il il->il 

825 

826 

827 A more complex index transformation example. 

828 

829 >>> I = np.random.rand(10, 10, 10, 10) 

830 >>> C = np.random.rand(10, 10) 

831 >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C, 

832 ... optimize='greedy') 

833 

834 >>> print(path_info[0]) 

835 ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)] 

836 >>> print(path_info[1])  

837 Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary 

838 Naive scaling: 8 

839 Optimized scaling: 5 

840 Naive FLOP count: 8.000e+08 

841 Optimized FLOP count: 8.000e+05 

842 Theoretical speedup: 1000.000 

843 Largest intermediate: 1.000e+04 elements 

844 -------------------------------------------------------------------------- 

845 scaling current remaining 

846 -------------------------------------------------------------------------- 

847 5 abcd,ea->bcde fb,gc,hd,bcde->efgh 

848 5 bcde,fb->cdef gc,hd,cdef->efgh 

849 5 cdef,gc->defg hd,defg->efgh 

850 5 defg,hd->efgh efgh->efgh 

851 """ 

852 

853 # Figure out what the path really is 

854 path_type = optimize 

855 if path_type is True: 

856 path_type = 'greedy' 

857 if path_type is None: 

858 path_type = False 

859 

860 explicit_einsum_path = False 

861 memory_limit = None 

862 

863 # No optimization or a named path algorithm 

864 if (path_type is False) or isinstance(path_type, str): 

865 pass 

866 

867 # Given an explicit path 

868 elif len(path_type) and (path_type[0] == 'einsum_path'): 

869 explicit_einsum_path = True 

870 

871 # Path tuple with memory limit 

872 elif ((len(path_type) == 2) and isinstance(path_type[0], str) and 

873 isinstance(path_type[1], (int, float))): 

874 memory_limit = int(path_type[1]) 

875 path_type = path_type[0] 

876 

877 else: 

878 raise TypeError("Did not understand the path: %s" % str(path_type)) 

879 

880 # Hidden option, only einsum should call this 

881 einsum_call_arg = einsum_call 

882 

883 # Python side parsing 

884 input_subscripts, output_subscript, operands = ( 

885 _parse_einsum_input(operands) 

886 ) 

887 

888 # Build a few useful list and sets 

889 input_list = input_subscripts.split(',') 

890 input_sets = [set(x) for x in input_list] 

891 output_set = set(output_subscript) 

892 indices = set(input_subscripts.replace(',', '')) 

893 

894 # Get length of each unique dimension and ensure all dimensions are correct 

895 dimension_dict = {} 

896 broadcast_indices = [[] for x in range(len(input_list))] 

897 for tnum, term in enumerate(input_list): 

898 sh = operands[tnum].shape 

899 if len(sh) != len(term): 

900 raise ValueError("Einstein sum subscript %s does not contain the " 

901 "correct number of indices for operand %d." 

902 % (input_subscripts[tnum], tnum)) 

903 for cnum, char in enumerate(term): 

904 dim = sh[cnum] 

905 

906 # Build out broadcast indices 

907 if dim == 1: 

908 broadcast_indices[tnum].append(char) 

909 

910 if char in dimension_dict.keys(): 

911 # For broadcasting cases we always want the largest dim size 

912 if dimension_dict[char] == 1: 

913 dimension_dict[char] = dim 

914 elif dim not in (1, dimension_dict[char]): 

915 raise ValueError("Size of label '%s' for operand %d (%d) " 

916 "does not match previous terms (%d)." 

917 % (char, tnum, dimension_dict[char], dim)) 

918 else: 

919 dimension_dict[char] = dim 

920 

921 # Convert broadcast inds to sets 

922 broadcast_indices = [set(x) for x in broadcast_indices] 

923 

924 # Compute size of each input array plus the output array 

925 size_list = [_compute_size_by_dict(term, dimension_dict) 

926 for term in input_list + [output_subscript]] 

927 max_size = max(size_list) 

928 

929 if memory_limit is None: 

930 memory_arg = max_size 

931 else: 

932 memory_arg = memory_limit 

933 

934 # Compute naive cost 

935 # This isn't quite right, need to look into exactly how einsum does this 

936 inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0 

937 naive_cost = _flop_count( 

938 indices, inner_product, len(input_list), dimension_dict 

939 ) 

940 

941 # Compute the path 

942 if explicit_einsum_path: 

943 path = path_type[1:] 

944 elif ( 

945 (path_type is False) 

946 or (len(input_list) in [1, 2]) 

947 or (indices == output_set) 

948 ): 

949 # Nothing to be optimized, leave it to einsum 

950 path = [tuple(range(len(input_list)))] 

951 elif path_type == "greedy": 

952 path = _greedy_path( 

953 input_sets, output_set, dimension_dict, memory_arg 

954 ) 

955 elif path_type == "optimal": 

956 path = _optimal_path( 

957 input_sets, output_set, dimension_dict, memory_arg 

958 ) 

959 else: 

960 raise KeyError("Path name %s not found", path_type) 

961 

962 cost_list, scale_list, size_list, contraction_list = [], [], [], [] 

963 

964 # Build contraction tuple (positions, gemm, einsum_str, remaining) 

965 for cnum, contract_inds in enumerate(path): 

966 # Make sure we remove inds from right to left 

967 contract_inds = tuple(sorted(list(contract_inds), reverse=True)) 

968 

969 contract = _find_contraction(contract_inds, input_sets, output_set) 

970 out_inds, input_sets, idx_removed, idx_contract = contract 

971 

972 cost = _flop_count( 

973 idx_contract, idx_removed, len(contract_inds), dimension_dict 

974 ) 

975 cost_list.append(cost) 

976 scale_list.append(len(idx_contract)) 

977 size_list.append(_compute_size_by_dict(out_inds, dimension_dict)) 

978 

979 bcast = set() 

980 tmp_inputs = [] 

981 for x in contract_inds: 

982 tmp_inputs.append(input_list.pop(x)) 

983 bcast |= broadcast_indices.pop(x) 

984 

985 new_bcast_inds = bcast - idx_removed 

986 

987 # If we're broadcasting, nix blas 

988 if not len(idx_removed & bcast): 

989 do_blas = _can_dot(tmp_inputs, out_inds, idx_removed) 

990 else: 

991 do_blas = False 

992 

993 # Last contraction 

994 if (cnum - len(path)) == -1: 

995 idx_result = output_subscript 

996 else: 

997 sort_result = [(dimension_dict[ind], ind) for ind in out_inds] 

998 idx_result = "".join([x[1] for x in sorted(sort_result)]) 

999 

1000 input_list.append(idx_result) 

1001 broadcast_indices.append(new_bcast_inds) 

1002 einsum_str = ",".join(tmp_inputs) + "->" + idx_result 

1003 

1004 contraction = ( 

1005 contract_inds, idx_removed, einsum_str, input_list[:], do_blas 

1006 ) 

1007 contraction_list.append(contraction) 

1008 

1009 opt_cost = sum(cost_list) + 1 

1010 

1011 if len(input_list) != 1: 

1012 # Explicit "einsum_path" is usually trusted, but we detect this kind of 

1013 # mistake in order to prevent from returning an intermediate value. 

1014 raise RuntimeError( 

1015 "Invalid einsum_path is specified: {} more operands has to be " 

1016 "contracted.".format(len(input_list) - 1)) 

1017 

1018 if einsum_call_arg: 

1019 return (operands, contraction_list) 

1020 

1021 # Return the path along with a nice string representation 

1022 overall_contraction = input_subscripts + "->" + output_subscript 

1023 header = ("scaling", "current", "remaining") 

1024 

1025 speedup = naive_cost / opt_cost 

1026 max_i = max(size_list) 

1027 

1028 path_print = " Complete contraction: %s\n" % overall_contraction 

1029 path_print += " Naive scaling: %d\n" % len(indices) 

1030 path_print += " Optimized scaling: %d\n" % max(scale_list) 

1031 path_print += " Naive FLOP count: %.3e\n" % naive_cost 

1032 path_print += " Optimized FLOP count: %.3e\n" % opt_cost 

1033 path_print += " Theoretical speedup: %3.3f\n" % speedup 

1034 path_print += " Largest intermediate: %.3e elements\n" % max_i 

1035 path_print += "-" * 74 + "\n" 

1036 path_print += "%6s %24s %40s\n" % header 

1037 path_print += "-" * 74 

1038 

1039 for n, contraction in enumerate(contraction_list): 

1040 inds, idx_rm, einsum_str, remaining, blas = contraction 

1041 remaining_str = ",".join(remaining) + "->" + output_subscript 

1042 path_run = (scale_list[n], einsum_str, remaining_str) 

1043 path_print += "\n%4d %24s %40s" % path_run 

1044 

1045 path = ['einsum_path'] + path 

1046 return (path, path_print) 

1047 

1048 

def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
    """Yield the arguments used for ``__array_function__`` dispatch of einsum.

    Every positional operand (including subscript strings / sublists) is
    yielded, followed by ``out``.
    """
    # Arguably we dispatch on more arguments than we really should; see note in
    # _einsum_path_dispatcher for why.
    yield from operands
    yield out


# Rewrite einsum to handle different cases
@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, out=None, optimize=False, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations can be represented in a simple fashion.
    In *implicit* mode `einsum` computes these values.

    In *explicit* mode, `einsum` provides further flexibility to compute
    other array operations that might not be considered classical Einstein
    summation operations, by disabling, or forcing summation over specified
    subscript labels.

    See the notes and examples for clarification.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as comma separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : list of array_like
        These are the arrays for the operation.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. The requested data type is
        honoured whether or not ``optimize`` is enabled. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout as the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur.  Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

        * 'no' means the data types should not be cast at all.
        * 'equiv' means only byte-order changes are allowed.
        * 'safe' means only casts which can preserve values are allowed.
        * 'same_kind' means only safe casts or casts within a kind,
          like float64 to float32, are allowed.
        * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
    einsum:
        Similar verbose interface is provided by the
        `einops <https://github.com/arogozhnikov/einops>`_ package to cover
        additional operations: transpose, reshape/flatten, repeat/tile,
        squeeze/unsqueeze and reductions.
        The `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
        optimizes contraction order for einsum-like expressions
        in a backend-agnostic manner.

    Notes
    -----
    .. versionadded:: 1.6.0

    The Einstein summation convention can be used to compute
    many multi-dimensional, linear algebraic array operations. `einsum`
    provides a succinct way of representing these.

    A non-exhaustive list of these operations,
    which can be computed by `einsum`, is shown below along with examples:

    * Trace of an array, :py:func:`numpy.trace`.
    * Return a diagonal, :py:func:`numpy.diag`.
    * Array axis summations, :py:func:`numpy.sum`.
    * Transpositions and permutations, :py:func:`numpy.transpose`.
    * Matrix multiplication and dot product, :py:func:`numpy.matmul`
      :py:func:`numpy.dot`.
    * Vector inner and outer products, :py:func:`numpy.inner`
      :py:func:`numpy.outer`.
    * Broadcasting, element-wise and scalar multiplication,
      :py:func:`numpy.multiply`.
    * Tensor contractions, :py:func:`numpy.tensordot`.
    * Chained array operations, in efficient calculation order,
      :py:func:`numpy.einsum_path`.

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
    appears only once, it is not summed, so ``np.einsum('i', a)``
    produces a view of ``a`` with no changes. A further example
    ``np.einsum('ij,jk', a, b)`` describes traditional matrix multiplication
    and is equivalent to :py:func:`np.matmul(a,b) <numpy.matmul>`.
    Repeated subscript labels in one operand take the diagonal.
    For example, ``np.einsum('ii', a)`` is equivalent to
    :py:func:`np.trace(a) <numpy.trace>`.

    In *implicit mode*, the chosen subscripts are important
    since the axes of the output are reordered alphabetically.  This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose. Additionally,
    ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
    ``np.einsum('ij,jh', a, b)`` returns the transpose of the
    multiplication since subscript 'h' precedes subscript 'i'.

    In *explicit mode* the output can be directly controlled by
    specifying output subscript labels.  This requires the
    identifier '->' as well as the list of output subscript labels.
    This feature increases the flexibility of the function since
    summing can be disabled or forced when required. The call
    ``np.einsum('i->', a)`` is like :py:func:`np.sum(a) <numpy.sum>`
    if ``a`` is a 1-D array, and ``np.einsum('ii->i', a)``
    is like :py:func:`np.diag(a) <numpy.diag>` if ``a`` is a square 2-D array.
    The difference is that `einsum` does not allow broadcasting by default.
    Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
    order of the output subscript labels and therefore returns matrix
    multiplication, unlike the example above in implicit mode.

    To enable and control broadcasting, use an ellipsis.  Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    ``np.einsum('...i->...', a)`` is like
    :py:func:`np.sum(a, axis=-1) <numpy.sum>` for array ``a`` of any shape.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, one can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array.  Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view (changed in version 1.10.0).

    `einsum` also provides an alternative way to provide the subscripts and
    operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    If the output shape is not provided in this format `einsum` will be
    calculated in implicit mode, otherwise it will be performed explicitly.
    The examples below have corresponding `einsum` calls with the two
    parameter methods.

    .. versionadded:: 1.10.0

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
    and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    .. versionadded:: 1.12.0

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands
    this can greatly increase the computational efficiency at the cost of
    a larger memory footprint during computation.

    Typically a 'greedy' algorithm is applied which empirical tests have shown
    returns the optimal path in the majority of cases. In some cases 'optimal'
    will return the superlative path through a more expensive, exhaustive
    search. For iterative calculations it may be advisable to calculate
    the optimal path once and reuse that path by supplying it as an argument.
    An example is given below.

    See :py:func:`numpy.einsum_path` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    Trace of a matrix:

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    Extract the diagonal (requires explicit form):

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    Sum over an axis (requires explicit form):

    >>> np.einsum('ij->i', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [0,1], [0])
    array([ 10,  35,  60,  85, 110])
    >>> np.sum(a, axis=1)
    array([ 10,  35,  60,  85, 110])

    For higher dimensional arrays summing a single axis can be done
    with ellipsis:

    >>> np.einsum('...j->...', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
    array([ 10,  35,  60,  85, 110])

    Compute a matrix transpose, or reorder any number of axes:

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum('ij->ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.transpose(c)
    array([[0, 3],
           [1, 4],
           [2, 5]])

    Vector inner products:

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    Matrix vector multiplication:

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    Broadcasting and scalar multiplication:

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(',ij', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    Vector outer product:

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    Tensor contraction:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    Writeable returned arrays (since version 1.10.0):

    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])

    Example of ellipsis use:

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    Chained array operations. For more complicated contractions, speed ups
    might be achieved by repeatedly computing a 'greedy' path or pre-computing
    the 'optimal' path and repeatedly applying it, using an `einsum_path`
    insertion (since version 1.12.0). Performance improvements can be
    particularly significant with larger arrays:

    >>> a = np.ones(64).reshape(2,4,8)

    Basic `einsum`: ~1520ms  (benchmarked on 3.1GHz Intel i5.)

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)

    Sub-optimal `einsum` (due to repeated path calculation time): ~330ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
    ...         optimize='optimal')

    Greedy `einsum` (faster optimal path approximation): ~160ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')

    Optimal `einsum` (best usage pattern in some use cases): ~110ms

    >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
    ...     optimize='optimal')[0]
    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)

    """
    # Special handling if out is specified
    specified_out = out is not None

    # If no optimization, run pure einsum
    if optimize is False:
        if specified_out:
            kwargs['out'] = out
        return c_einsum(*operands, **kwargs)

    # Check the kwargs to avoid a more cryptic error later, without having to
    # repeat default values here
    valid_einsum_kwargs = ['dtype', 'order', 'casting']
    unknown_kwargs = [k for k in kwargs if k not in valid_einsum_kwargs]
    if unknown_kwargs:
        raise TypeError("Did not understand the following kwargs: %s"
                        % unknown_kwargs)

    # Build the contraction list and operand
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                             einsum_call=True)

    # Handle order kwarg for output array, c_einsum allows mixed case
    output_order = kwargs.pop('order', 'K')
    if output_order.upper() == 'A':
        if all(arr.flags.f_contiguous for arr in operands):
            output_order = 'F'
        else:
            output_order = 'C'

    # ``tensordot`` has no ``dtype`` argument: it always computes in the
    # operands' natural result type. When the caller forces a computation
    # dtype, route every contraction through ``c_einsum`` instead so that the
    # requested dtype (and its ``casting`` rule) is honoured exactly as it is
    # with ``optimize=False``.
    allow_blas = kwargs.get('dtype') is None

    # Start contraction loop
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        # ``inds`` is sorted in descending order by einsum_path, so popping
        # one operand never shifts the position of those still to be popped.
        tmp_operands = [operands.pop(x) for x in inds]

        # Do we need to deal with the output?
        handle_out = specified_out and ((num + 1) == len(contraction_list))

        # Call tensordot if still possible
        if blas and allow_blas:
            # Checks have already been handled
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            tensor_result = input_left + input_right
            for s in idx_rm:
                tensor_result = tensor_result.replace(s, "")

            # Find indices to contract over
            left_pos, right_pos = [], []
            for s in sorted(idx_rm):
                left_pos.append(input_left.find(s))
                right_pos.append(input_right.find(s))

            # Contract!
            new_view = tensordot(
                *tmp_operands, axes=(tuple(left_pos), tuple(right_pos))
            )

            # Build a new view if needed (axis reorder or write into `out`)
            if (tensor_result != results_index) or handle_out:
                if handle_out:
                    kwargs["out"] = out
                new_view = c_einsum(
                    tensor_result + '->' + results_index, new_view, **kwargs
                )

        # Call einsum
        else:
            # If out was specified
            if handle_out:
                kwargs["out"] = out

            # Do the contraction
            new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out
    else:
        return asanyarray(operands[0], order=output_order)