Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/numpy/_core/einsumfunc.py: 6%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

494 statements  

1""" 

2Implementation of optimized einsum. 

3 

4""" 

5import functools 

6import itertools 

7import operator 

8 

9from numpy._core.multiarray import c_einsum, matmul 

10from numpy._core.numeric import asanyarray, reshape 

11from numpy._core.overrides import array_function_dispatch 

12from numpy._core.umath import multiply 

13 

__all__ = ['einsum', 'einsum_path']

# Valid single-character einsum labels. The alphabet is spelled out by hand
# instead of using ``string.ascii_letters`` because importing ``string`` is
# comparatively slow: the first (uncached) import was measured at ~800 µs
# (#23777). Upper-case letters come first so the order mirrors ASCII code
# points, which avoids sorting issues.
einsum_symbols = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
# Set form of ``einsum_symbols``, used for set arithmetic on labels.
einsum_symbols_set = {*einsum_symbols}

21 

22 

def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Estimate the number of FLOPS needed for one contraction.

    The estimate is the product of all index sizes involved, multiplied by
    ``max(1, num_terms - 1)`` (one multiplication per extra term) plus one
    more operation when the contraction needs an inner product (summation).

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction.
    inner : bool
        Does this contraction require an inner product? Only its truthiness
        is used, so callers may pass the set of summed indices.
    num_terms : int
        The number of terms in the contraction.
    size_dictionary : dict
        The size of each of the indices in ``idx_contraction``.

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    volume = _compute_size_by_dict(idx_contraction, size_dictionary)
    ops_per_element = max(1, num_terms - 1) + (1 if inner else 0)
    return volume * ops_per_element

60 

61def _compute_size_by_dict(indices, idx_dict): 

62 """ 

63 Computes the product of the elements in indices based on the dictionary 

64 idx_dict. 

65 

66 Parameters 

67 ---------- 

68 indices : iterable 

69 Indices to base the product on. 

70 idx_dict : dictionary 

71 Dictionary of index sizes 

72 

73 Returns 

74 ------- 

75 ret : int 

76 The resulting product. 

77 

78 Examples 

79 -------- 

80 >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5}) 

81 90 

82 

83 """ 

84 ret = 1 

85 for i in indices: 

86 ret *= idx_dict[i] 

87 return ret 

88 

89 

90def _find_contraction(positions, input_sets, output_set): 

91 """ 

92 Finds the contraction for a given set of input and output sets. 

93 

94 Parameters 

95 ---------- 

96 positions : iterable 

97 Integer positions of terms used in the contraction. 

98 input_sets : list 

99 List of sets that represent the lhs side of the einsum subscript 

100 output_set : set 

101 Set that represents the rhs side of the overall einsum subscript 

102 

103 Returns 

104 ------- 

105 new_result : set 

106 The indices of the resulting contraction 

107 remaining : list 

108 List of sets that have not been contracted, the new set is appended to 

109 the end of this list 

110 idx_removed : set 

111 Indices removed from the entire contraction 

112 idx_contraction : set 

113 The indices used in the current contraction 

114 

115 Examples 

116 -------- 

117 

118 # A simple dot product test case 

119 >>> pos = (0, 1) 

120 >>> isets = [set('ab'), set('bc')] 

121 >>> oset = set('ac') 

122 >>> _find_contraction(pos, isets, oset) 

123 ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'}) 

124 

125 # A more complex case with additional terms in the contraction 

126 >>> pos = (0, 2) 

127 >>> isets = [set('abd'), set('ac'), set('bdc')] 

128 >>> oset = set('ac') 

129 >>> _find_contraction(pos, isets, oset) 

130 ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'}) 

131 """ 

132 

133 idx_contract = set() 

134 idx_remain = output_set.copy() 

135 remaining = [] 

136 for ind, value in enumerate(input_sets): 

137 if ind in positions: 

138 idx_contract |= value 

139 else: 

140 remaining.append(value) 

141 idx_remain |= value 

142 

143 new_result = idx_remain & idx_contract 

144 idx_removed = (idx_contract - new_result) 

145 remaining.append(new_result) 

146 

147 return (new_result, remaining, idx_removed, idx_contract) 

148 

149 

def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.
        If at some step no pair contraction fits within ``memory_limit``,
        the cheapest partial path found so far is returned with a final
        entry that contracts all remaining terms at once.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    # Each candidate is (total_cost, path_so_far, remaining_input_sets).
    # The list is only ever replaced by a non-empty list, so it is never
    # empty and min() below is always well defined.
    full_results = [(0, [], input_sets)]
    for iteration in range(len(input_sets) - 1):
        num_remaining = len(input_sets) - iteration
        iter_results = []

        # Extend every surviving candidate by every possible pair
        for cost, positions, remaining in full_results:
            for con in itertools.combinations(range(num_remaining), 2):

                # Find the contraction
                cont = _find_contraction(con, remaining, output_set)
                new_result, new_input_sets, idx_removed, idx_contract = cont

                # Sieve the results based on memory_limit
                new_size = _compute_size_by_dict(new_result, idx_dict)
                if new_size > memory_limit:
                    continue

                # Build (total_cost, positions, indices_remaining)
                total_cost = cost + _flop_count(
                    idx_contract, idx_removed, len(con), idx_dict
                )
                iter_results.append(
                    (total_cost, positions + [con], new_input_sets)
                )

        if not iter_results:
            # Nothing fits in memory: take the best path so far and finish
            # with one contraction over all remaining terms. Build a new
            # list instead of mutating the candidate's stored path.
            best_path = min(full_results, key=operator.itemgetter(0))[1]
            return best_path + [tuple(range(num_remaining))]

        full_results = iter_results

    return min(full_results, key=operator.itemgetter(0))[1]

223 

def _parse_possible_contraction(
    positions, input_sets, output_set, idx_dict,
    memory_limit, path_cost, naive_cost
):
    """Score the candidate contraction of the terms at ``positions``.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensor.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    list or None
        ``None`` when the intermediate would exceed ``memory_limit`` or the
        running cost ``path_cost + cost`` would exceed ``naive_cost``.
        Otherwise ``[sort, positions, new_input_sets]``:

        * ``sort`` is ``(-removed_size, cost)``. ``removed_size`` is the
          summed size of the contracted terms minus the size of their
          result, and ``cost`` is the flop count of the contraction.
        * ``new_input_sets`` is the list of index sets after the
          contraction, with the result appended last.
    """
    new_set, new_input_sets, removed_labels, involved_labels = (
        _find_contraction(positions, input_sets, output_set)
    )

    # Reject intermediates that are too large
    result_size = _compute_size_by_dict(new_set, idx_dict)
    if result_size > memory_limit:
        return None

    # NB: removed_size used to be just the size of any removed indices,
    # i.e. the size of ``removed_labels`` by itself.
    consumed = sum(
        _compute_size_by_dict(input_sets[p], idx_dict) for p in positions
    )
    removed_size = consumed - result_size
    flops = _flop_count(
        involved_labels, removed_labels, len(positions), idx_dict
    )

    # Reject candidates that would make the path worse than doing it naively
    if path_cost + flops > naive_cost:
        return None

    return [(-removed_size, flops), positions, new_input_sets]

286 

287 

288def _update_other_results(results, best): 

289 """Update the positions and provisional input_sets of ``results`` 

290 based on performing the contraction result ``best``. Remove any 

291 involving the tensors contracted. 

292 

293 Parameters 

294 ---------- 

295 results : list 

296 List of contraction results produced by 

297 ``_parse_possible_contraction``. 

298 best : list 

299 The best contraction of ``results`` i.e. the one that 

300 will be performed. 

301 

302 Returns 

303 ------- 

304 mod_results : list 

305 The list of modified results, updated with outcome of 

306 ``best`` contraction. 

307 """ 

308 

309 best_con = best[1] 

310 bx, by = best_con 

311 mod_results = [] 

312 

313 for cost, (x, y), con_sets in results: 

314 

315 # Ignore results involving tensors just contracted 

316 if x in best_con or y in best_con: 

317 continue 

318 

319 # Update the input_sets 

320 del con_sets[by - int(by > x) - int(by > y)] 

321 del con_sets[bx - int(bx > x) - int(bx > y)] 

322 con_sets.insert(-1, best[2][-1]) 

323 

324 # Update the position indices 

325 mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by) 

326 mod_results.append((cost, mod_con, con_sets)) 

327 

328 return mod_results 

329 

def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted.

    The best pair is the candidate with the smallest sort key
    ``(-removed_size, cost)`` built by ``_parse_possible_contraction``.
    ``removed_size`` is the summed size of the two input terms minus the
    size of their result, and ``cost`` is the flop count.

    Pairs that share at least one index are considered first. Pairs with no
    shared index (outer products) are only scanned when no other candidate
    is acceptable. Candidates are discarded if their intermediate exceeds
    ``memory_limit``, or if they would push the running cost above the cost
    of doing the whole expression in one contraction. This algorithm scales
    cubically with respect to the number of elements in the list
    ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.
        Each entry is a tuple of positions into the current term list; the
        result of each contraction is appended to the end of that list. If
        at some step no pair is acceptable, the final entry contracts all
        remaining terms at once.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost: all terms contracted in a single step. Used as
    # an upper bound that every candidate path must stay under.
    contract = _find_contraction(
        range(len(input_sets)), input_sets, output_set
    )
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(
        idx_contract, idx_removed, len(input_sets), idx_dict
    )

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    # Candidates carried between iterations; entries are the lists returned
    # by _parse_possible_contraction, re-indexed by _update_other_results.
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # Iterate over all pairs on the first step, only previously
        # found pairs on subsequent steps
        for positions in comb_iter:

            # Always initially ignore outer products
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(
                positions, input_sets, output_set, idx_dict,
                memory_limit, path_cost, naive_cost
            )
            if result is not None:
                known_contractions.append(result)

        # If we do not have a inner contraction, rescan pairs
        # including outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(
                range(len(input_sets)), 2
            ):
                result = _parse_possible_contraction(
                    positions, input_sets, output_set, idx_dict,
                    memory_limit, path_cost, naive_cost
                )
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions,
            # default back to einsum like behavior
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Pick the candidate with the smallest sort key (x[0]); min() keeps
        # the first one on ties, so candidate order matters here
        best = min(known_contractions, key=lambda x: x[0])

        # Now propagate as many unused contractions as possible
        # to the next iteration
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor
        # All other contractions have been accounted for
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost; best[0] is (-removed_size, cost)
        path.append(best[1])
        path_cost += best[0][1]

    return path

443 

444 

def _subscripts_from_list(sub):
    """Convert one einsum sublist (ints and ``Ellipsis``) to a label string.

    Used for the interleaved ``(op0, sublist0, op1, sublist1, ...)`` form.

    Raises
    ------
    TypeError
        If an element is neither ``Ellipsis`` nor an integer-like object.
    ValueError
        If an integer is outside ``[0, len(einsum_symbols))``.
    """
    subscripts = ""
    for s in sub:
        if s is Ellipsis:
            subscripts += "..."
            continue
        try:
            s = operator.index(s)
        except TypeError as e:
            raise TypeError(
                "For this input type lists must contain "
                "either int or Ellipsis"
            ) from e
        # Plain indexing would silently wrap negative values around to the
        # end of the alphabet (e.g. -1 -> 'z'), so check the range explicitly.
        if not 0 <= s < len(einsum_symbols):
            raise ValueError(
                "subscript is not within the valid range "
                f"[0, {len(einsum_symbols)})"
            )
        subscripts += einsum_symbols[s]
    return subscripts


def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Parameters
    ----------
    operands : sequence
        Either ``(subscripts, op0, op1, ...)`` where ``subscripts`` is a
        string, or the interleaved form
        ``(op0, sublist0, op1, sublist1, ..., [sublistout])``. In the
        interleaved form each sublist contains ints and/or ``Ellipsis``.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        If any of the following hold:

        * there are no operands;
        * a label or sublist value is invalid;
        * "->" or an ellipsis is malformed;
        * an output label is repeated or missing from the inputs;
        * the number of terms differs from the number of operands.
    TypeError
        If a sublist element is neither an int nor ``Ellipsis``.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b]) # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError(f"Character {s} is not a valid symbol.")

    else:
        # Interleaved form: operands and sublists alternate; an odd trailing
        # element is the explicit output sublist.
        num_pairs = len(operands) // 2
        operand_list = operands[0:2 * num_pairs:2]
        subscript_list = operands[1:2 * num_pairs:2]
        output_list = operands[-1] if len(operands) % 2 else None

        operands = [asanyarray(v) for v in operand_list]
        subscripts = ",".join(
            _subscripts_from_list(sub) for sub in subscript_list
        )
        if output_list is not None:
            subscripts += "->" + _subscripts_from_list(output_list)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        # Sorted so the labels chosen for ellipsis dimensions do not depend
        # on set iteration order (which varies with string hash seeding).
        unused = sorted(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    # Right-aligned suffix of the same label pool, so
                    # broadcast dimensions line up across operands.
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError(f"Character {s} is not a valid symbol.")
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Build output subscripts: labels appearing exactly once, sorted
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError(f"Character {s} is not a valid symbol.")
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if output_subscript.count(char) != 1:
            raise ValueError("Output character %s appeared more than once in "
                             "the output." % char)
        if char not in input_subscripts:
            raise ValueError(f"Output character {char} did not appear in the input")

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)

622 

623 

624def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None): 

625 # NOTE: technically, we should only dispatch on array-like arguments, not 

626 # subscripts (given as strings). But separating operands into 

627 # arrays/subscripts is a little tricky/slow (given einsum's two supported 

628 # signatures), so as a practical shortcut we dispatch on everything. 

629 # Strings will be ignored for dispatching since they don't define 

630 # __array_function__. 

631 return operands 

632 

633 

@array_function_dispatch(_einsum_path_dispatcher, module='numpy')
def einsum_path(*operands, optimize='greedy', einsum_call=False):
    """
    einsum_path(subscripts, *operands, optimize='greedy')

    Evaluates the lowest cost contraction order for an einsum expression by
    considering the creation of intermediate arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : {bool, list, tuple, 'greedy', 'optimal'}
        Choose the type of path. If a tuple is provided, the second argument is
        assumed to be the maximum intermediate size created. If only a single
        argument is provided the largest input or output array size is used
        as a maximum intermediate size.

        * if a list is given that starts with ``einsum_path``, uses this as the
          contraction path
        * if False no optimization is taken
        * if True defaults to the 'greedy' algorithm
        * 'optimal' An algorithm that combinatorially explores all possible
          ways of contracting the listed tensors and chooses the least costly
          path. Scales exponentially with the number of terms in the
          contraction.
        * 'greedy' An algorithm that chooses the best pair contraction
          at each step. Effectively, this algorithm searches the largest inner,
          Hadamard, and then outer products at each step. Scales cubically with
          the number of terms in the contraction. Equivalent to the 'optimal'
          path for most contractions.

        Default is 'greedy'.

    Returns
    -------
    path : list of tuples
        A list representation of the einsum path.
    string_repr : str
        A printable representation of the einsum path.

    Raises
    ------
    TypeError
        If ``optimize`` is not understood.
    KeyError
        If a named path algorithm is unknown.
    ValueError
        If the subscripts are malformed or do not match the operands'
        dimensions or sizes.
    RuntimeError
        If an explicit path leaves operands uncontracted.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should be
    contracted first, the result of this contraction is then appended to the
    end of the contraction list. This list can then be iterated over until all
    intermediate contractions are complete.

    The hidden ``einsum_call`` flag (used by ``einsum``) skips the cost
    bookkeeping and returns ``(operands, contraction_list)`` instead.

    See Also
    --------
    einsum, linalg.multi_dot

    Examples
    --------

    We can begin with a chain dot example. In this case, it is optimal to
    contract the ``b`` and ``c`` tensors first as represented by the first
    element of the path ``(1, 2)``. The resulting tensor is added to the end
    of the contraction and the remaining contraction ``(0, 1)`` is then
    completed.

    >>> np.random.seed(123)
    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
    >>> print(path_info[0])
    ['einsum_path', (1, 2), (0, 1)]
    >>> print(path_info[1])
     Complete contraction: ij,jk,kl->il # may vary
     Naive scaling: 4
     Optimized scaling: 3
     Naive FLOP count: 1.600e+02
     Optimized FLOP count: 5.600e+01
     Theoretical speedup: 2.857
     Largest intermediate: 4.000e+00 elements
    -------------------------------------------------------------------------
    scaling current remaining
    -------------------------------------------------------------------------
    3 kl,jk->jl ij,jl->il
    3 jl,ij->il il->il


    A more complex index transformation example.

    >>> I = np.random.rand(10, 10, 10, 10)
    >>> C = np.random.rand(10, 10)
    >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
    ...                            optimize='greedy')

    >>> print(path_info[0])
    ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
    >>> print(path_info[1])
     Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary
     Naive scaling: 8
     Optimized scaling: 5
     Naive FLOP count: 8.000e+08
     Optimized FLOP count: 8.000e+05
     Theoretical speedup: 1000.000
     Largest intermediate: 1.000e+04 elements
    --------------------------------------------------------------------------
    scaling current remaining
    --------------------------------------------------------------------------
    5 abcd,ea->bcde fb,gc,hd,bcde->efgh
    5 bcde,fb->cdef gc,hd,cdef->efgh
    5 cdef,gc->defg hd,defg->efgh
    5 defg,hd->efgh efgh->efgh
    """

    # Figure out what the path really is
    path_type = optimize
    if path_type is True:
        path_type = 'greedy'
    if path_type is None:
        path_type = False

    explicit_einsum_path = False
    memory_limit = None

    # No optimization or a named path algorithm
    if (path_type is False) or isinstance(path_type, str):
        pass

    # Given an explicit path
    elif len(path_type) and (path_type[0] == 'einsum_path'):
        explicit_einsum_path = True

    # Path tuple with memory limit
    elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
            isinstance(path_type[1], (int, float))):
        memory_limit = int(path_type[1])
        path_type = path_type[0]

    else:
        raise TypeError(f"Did not understand the path: {str(path_type)}")

    # Hidden option, only einsum should call this
    einsum_call_arg = einsum_call

    # Python side parsing
    input_subscripts, output_subscript, operands = (
        _parse_einsum_input(operands)
    )

    # Build a few useful list and sets
    input_list = input_subscripts.split(',')
    num_inputs = len(input_list)
    input_sets = [set(x) for x in input_list]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))
    num_indices = len(indices)

    # Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape
        if len(sh) != len(term):
            # Report this operand's whole term (indexing input_subscripts
            # would yield a single character of the comma-joined string).
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d."
                             % (term, tnum))
        for cnum, char in enumerate(term):
            dim = sh[cnum]

            if char in dimension_dict:
                # For broadcasting cases we always want the largest dim size
                if dimension_dict[char] == 1:
                    dimension_dict[char] = dim
                elif dim not in (1, dimension_dict[char]):
                    # First size is this operand's, second is the size
                    # recorded from previous terms.
                    raise ValueError("Size of label '%s' for operand %d (%d) "
                                     "does not match previous terms (%d)."
                                     % (char, tnum, dim, dimension_dict[char]))
            else:
                dimension_dict[char] = dim

    # Compute size of each input array plus the output array
    size_list = [_compute_size_by_dict(term, dimension_dict)
                 for term in input_list + [output_subscript]]
    max_size = max(size_list)

    # Without an explicit limit, cap intermediates at the largest
    # input/output array size
    if memory_limit is None:
        memory_arg = max_size
    else:
        memory_arg = memory_limit

    # Compute the path
    if explicit_einsum_path:
        path = path_type[1:]
    elif (
        (path_type is False)
        or (num_inputs in [1, 2])
        or (indices == output_set)
    ):
        # Nothing to be optimized, leave it to einsum
        path = [tuple(range(num_inputs))]
    elif path_type == "greedy":
        path = _greedy_path(
            input_sets, output_set, dimension_dict, memory_arg
        )
    elif path_type == "optimal":
        path = _optimal_path(
            input_sets, output_set, dimension_dict, memory_arg
        )
    else:
        # Format the message; passing path_type as a second argument would
        # only store a (message, value) tuple in the exception.
        raise KeyError(f"Path name {path_type} not found")

    # size_list is reused below for the intermediate sizes along the path
    cost_list, scale_list, size_list, contraction_list = [], [], [], []

    # Build contraction tuple (positions, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(contract_inds, reverse=True))

        contract = _find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract

        if not einsum_call_arg:
            # these are only needed for printing info
            cost = _flop_count(
                idx_contract, idx_removed, len(contract_inds), dimension_dict
            )
            cost_list.append(cost)
            scale_list.append(len(idx_contract))
            size_list.append(_compute_size_by_dict(out_inds, dimension_dict))

        tmp_inputs = []
        for x in contract_inds:
            tmp_inputs.append(input_list.pop(x))

        # Last contraction
        if cnum == len(path) - 1:
            idx_result = output_subscript
        else:
            # Intermediate labels ordered by (dimension size, label)
            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
            idx_result = "".join([x[1] for x in sorted(sort_result)])

        input_list.append(idx_result)
        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        contraction = (contract_inds, einsum_str, input_list[:])
        contraction_list.append(contraction)

    if len(input_list) != 1:
        # Explicit "einsum_path" is usually trusted, but we detect this kind of
        # mistake in order to prevent from returning an intermediate value.
        raise RuntimeError(
            f"Invalid einsum_path is specified: {len(input_list) - 1} more "
            "operands has to be contracted.")

    if einsum_call_arg:
        return (operands, contraction_list)

    # Return the path along with a nice string representation
    overall_contraction = input_subscripts + "->" + output_subscript
    header = ("scaling", "current", "remaining")

    # Compute naive cost
    # This isn't quite right, need to look into exactly how einsum does this
    inner_product = (
        sum(len(set(x)) for x in input_subscripts.split(',')) - num_indices
    ) > 0
    naive_cost = _flop_count(
        indices, inner_product, num_inputs, dimension_dict
    )

    # NOTE(review): the +1 presumably guards the division below against a
    # zero optimized cost -- confirm before changing.
    opt_cost = sum(cost_list) + 1
    speedup = naive_cost / opt_cost
    max_i = max(size_list)

    path_print = f" Complete contraction: {overall_contraction}\n"
    path_print += f" Naive scaling: {num_indices}\n"
    path_print += " Optimized scaling: %d\n" % max(scale_list)
    path_print += f" Naive FLOP count: {naive_cost:.3e}\n"
    path_print += f" Optimized FLOP count: {opt_cost:.3e}\n"
    path_print += f" Theoretical speedup: {speedup:3.3f}\n"
    path_print += f" Largest intermediate: {max_i:.3e} elements\n"
    path_print += "-" * 74 + "\n"
    path_print += "%6s %24s %40s\n" % header
    path_print += "-" * 74

    for n, contraction in enumerate(contraction_list):
        _, einsum_str, remaining = contraction
        remaining_str = ",".join(remaining) + "->" + output_subscript
        path_run = (scale_list[n], einsum_str, remaining_str)
        path_print += "\n%4d %24s %40s" % path_run

    path = ['einsum_path'] + path
    return (path, path_print)

923 

924 

925def _parse_eq_to_pure_multiplication(a_term, shape_a, b_term, shape_b, out): 

926 """If there are no contracted indices, then we can directly transpose and 

927 insert singleton dimensions into ``a`` and ``b`` such that (broadcast) 

928 elementwise multiplication performs the einsum. 

929 

930 No need to cache this as it is within the cached 

931 ``_parse_eq_to_batch_matmul``. 

932 

933 """ 

934 desired_a = "" 

935 desired_b = "" 

936 new_shape_a = [] 

937 new_shape_b = [] 

938 for ix in out: 

939 if ix in a_term: 

940 desired_a += ix 

941 new_shape_a.append(shape_a[a_term.index(ix)]) 

942 else: 

943 new_shape_a.append(1) 

944 if ix in b_term: 

945 desired_b += ix 

946 new_shape_b.append(shape_b[b_term.index(ix)]) 

947 else: 

948 new_shape_b.append(1) 

949 

950 if desired_a != a_term: 

951 eq_a = f"{a_term}->{desired_a}" 

952 else: 

953 eq_a = None 

954 if desired_b != b_term: 

955 eq_b = f"{b_term}->{desired_b}" 

956 else: 

957 eq_b = None 

958 

959 return ( 

960 eq_a, 

961 eq_b, 

962 new_shape_a, 

963 new_shape_b, 

964 None, # new_shape_ab, not needed since not fusing 

965 None, # perm_ab, not needed as we transpose a and b first 

966 True, # pure_multiplication=True 

967 ) 

968 

969 

970@functools.lru_cache(2**12) 

971def _parse_eq_to_batch_matmul(eq, shape_a, shape_b): 

972 """Cached parsing of a two term einsum equation into the necessary 

973 sequence of arguments for contracttion via batched matrix multiplication. 

974 The steps we need to specify are: 

975 

976 1. Remove repeated and trivial indices from the left and right terms, 

977 and transpose them, done as a single einsum. 

978 2. Fuse the remaining indices so we have two 3D tensors. 

979 3. Perform the batched matrix multiplication. 

980 4. Unfuse the output to get the desired final index order. 

981 

982 """ 

983 lhs, out = eq.split("->") 

984 a_term, b_term = lhs.split(",") 

985 

986 if len(a_term) != len(shape_a): 

987 raise ValueError(f"Term '{a_term}' does not match shape {shape_a}.") 

988 if len(b_term) != len(shape_b): 

989 raise ValueError(f"Term '{b_term}' does not match shape {shape_b}.") 

990 

991 sizes = {} 

992 singletons = set() 

993 

994 # parse left term to unique indices with size > 1 

995 left = {} 

996 for ix, d in zip(a_term, shape_a): 

997 if d == 1: 

998 # everything (including broadcasting) works nicely if simply ignore 

999 # such dimensions, but we do need to track if they appear in output 

1000 # and thus should be reintroduced later 

1001 singletons.add(ix) 

1002 continue 

1003 if sizes.setdefault(ix, d) != d: 

1004 # set and check size 

1005 raise ValueError( 

1006 f"Index {ix} has mismatched sizes {sizes[ix]} and {d}." 

1007 ) 

1008 left[ix] = True 

1009 

1010 # parse right term to unique indices with size > 1 

1011 right = {} 

1012 for ix, d in zip(b_term, shape_b): 

1013 # broadcast indices (size 1 on one input and size != 1 

1014 # on the other) should not be treated as singletons 

1015 if d == 1: 

1016 if ix not in left: 

1017 singletons.add(ix) 

1018 continue 

1019 singletons.discard(ix) 

1020 

1021 if sizes.setdefault(ix, d) != d: 

1022 # set and check size 

1023 raise ValueError( 

1024 f"Index {ix} has mismatched sizes {sizes[ix]} and {d}." 

1025 ) 

1026 right[ix] = True 

1027 

1028 # now we classify the unique size > 1 indices only 

1029 bat_inds = [] # appears on A, B, O 

1030 con_inds = [] # appears on A, B, . 

1031 a_keep = [] # appears on A, ., O 

1032 b_keep = [] # appears on ., B, O 

1033 # other indices (appearing on A or B only) will 

1034 # be summed or traced out prior to the matmul 

1035 for ix in left: 

1036 if right.pop(ix, False): 

1037 if ix in out: 

1038 bat_inds.append(ix) 

1039 else: 

1040 con_inds.append(ix) 

1041 elif ix in out: 

1042 a_keep.append(ix) 

1043 # now only indices unique to right remain 

1044 for ix in right: 

1045 if ix in out: 

1046 b_keep.append(ix) 

1047 

1048 if not con_inds: 

1049 # contraction is pure multiplication, prepare inputs differently 

1050 return _parse_eq_to_pure_multiplication( 

1051 a_term, shape_a, b_term, shape_b, out 

1052 ) 

1053 

1054 # only need the size one indices that appear in the output 

1055 singletons = [ix for ix in out if ix in singletons] 

1056 

1057 # take diagonal, remove any trivial axes and transpose left 

1058 desired_a = "".join((*bat_inds, *a_keep, *con_inds)) 

1059 if a_term != desired_a: 

1060 eq_a = f"{a_term}->{desired_a}" 

1061 else: 

1062 eq_a = None 

1063 

1064 # take diagonal, remove any trivial axes and transpose right 

1065 desired_b = "".join((*bat_inds, *con_inds, *b_keep)) 

1066 if b_term != desired_b: 

1067 eq_b = f"{b_term}->{desired_b}" 

1068 else: 

1069 eq_b = None 

1070 

1071 # then we want to reshape 

1072 if bat_inds: 

1073 lgroups = (bat_inds, a_keep, con_inds) 

1074 rgroups = (bat_inds, con_inds, b_keep) 

1075 ogroups = (bat_inds, a_keep, b_keep) 

1076 else: 

1077 # avoid size 1 batch dimension if no batch indices 

1078 lgroups = (a_keep, con_inds) 

1079 rgroups = (con_inds, b_keep) 

1080 ogroups = (a_keep, b_keep) 

1081 

1082 if any(len(group) != 1 for group in lgroups): 

1083 # need to fuse 'kept' and contracted indices 

1084 # (though could allow batch indices to be broadcast) 

1085 new_shape_a = tuple( 

1086 functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1) 

1087 for ix_group in lgroups 

1088 ) 

1089 else: 

1090 new_shape_a = None 

1091 

1092 if any(len(group) != 1 for group in rgroups): 

1093 # need to fuse 'kept' and contracted indices 

1094 # (though could allow batch indices to be broadcast) 

1095 new_shape_b = tuple( 

1096 functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1) 

1097 for ix_group in rgroups 

1098 ) 

1099 else: 

1100 new_shape_b = None 

1101 

1102 if any(len(group) != 1 for group in ogroups) or singletons: 

1103 new_shape_ab = (1,) * len(singletons) + tuple( 

1104 sizes[ix] for ix_group in ogroups for ix in ix_group 

1105 ) 

1106 else: 

1107 new_shape_ab = None 

1108 

1109 # then we might need to permute the matmul produced output: 

1110 out_produced = "".join((*singletons, *bat_inds, *a_keep, *b_keep)) 

1111 if out_produced != out: 

1112 perm_ab = tuple(out_produced.index(ix) for ix in out) 

1113 else: 

1114 perm_ab = None 

1115 

1116 return ( 

1117 eq_a, 

1118 eq_b, 

1119 new_shape_a, 

1120 new_shape_b, 

1121 new_shape_ab, 

1122 perm_ab, 

1123 False, # pure_multiplication=False 

1124 ) 

1125 

1126 

1127@functools.lru_cache(maxsize=64) 

1128def _parse_output_order(order, a_is_fcontig, b_is_fcontig): 

1129 order = order.upper() 

1130 if order == "K": 

1131 return None 

1132 elif order in "CF": 

1133 return order 

1134 elif order == "A": 

1135 if a_is_fcontig and b_is_fcontig: 

1136 return "F" 

1137 else: 

1138 return "C" 

1139 else: 

1140 raise ValueError( 

1141 "ValueError: order must be one of " 

1142 f"'C', 'F', 'A', or 'K' (got '{order}')" 

1143 ) 

1144 

1145 

1146def bmm_einsum(eq, a, b, out=None, **kwargs): 

1147 """Perform arbitrary pairwise einsums using only ``matmul``, or 

1148 ``multiply`` if no contracted indices are involved (plus maybe single term 

1149 ``einsum`` to prepare the terms individually). The logic for each is cached 

1150 based on the equation and array shape, and each step is only performed if 

1151 necessary. 

1152 

1153 Parameters 

1154 ---------- 

1155 eq : str 

1156 The einsum equation. 

1157 a : array_like 

1158 The first array to contract. 

1159 b : array_like 

1160 The second array to contract. 

1161 

1162 Returns 

1163 ------- 

1164 array_like 

1165 

1166 Notes 

1167 ----- 

1168 A fuller description of this algorithm, and original source for this 

1169 implementation, can be found at https://github.com/jcmgray/einsum_bmm. 

1170 """ 

1171 ( 

1172 eq_a, 

1173 eq_b, 

1174 new_shape_a, 

1175 new_shape_b, 

1176 new_shape_ab, 

1177 perm_ab, 

1178 pure_multiplication, 

1179 ) = _parse_eq_to_batch_matmul(eq, a.shape, b.shape) 

1180 

1181 # n.b. one could special case various cases to call c_einsum directly here 

1182 

1183 # need to handle `order` a little manually, since we do transpose 

1184 # operations before and potentially after the ufunc calls 

1185 output_order = _parse_output_order( 

1186 kwargs.pop("order", "K"), a.flags.f_contiguous, b.flags.f_contiguous 

1187 ) 

1188 

1189 # prepare left 

1190 if eq_a is not None: 

1191 # diagonals, sums, and tranpose 

1192 a = c_einsum(eq_a, a) 

1193 if new_shape_a is not None: 

1194 a = reshape(a, new_shape_a) 

1195 

1196 # prepare right 

1197 if eq_b is not None: 

1198 # diagonals, sums, and tranpose 

1199 b = c_einsum(eq_b, b) 

1200 if new_shape_b is not None: 

1201 b = reshape(b, new_shape_b) 

1202 

1203 if pure_multiplication: 

1204 # no contracted indices 

1205 if output_order is not None: 

1206 kwargs["order"] = output_order 

1207 

1208 # do the 'contraction' via multiplication! 

1209 return multiply(a, b, out=out, **kwargs) 

1210 

1211 # can only supply out here if no other reshaping / transposing 

1212 matmul_out_compatible = (new_shape_ab is None) and (perm_ab is None) 

1213 if matmul_out_compatible: 

1214 kwargs["out"] = out 

1215 

1216 # do the contraction! 

1217 ab = matmul(a, b, **kwargs) 

1218 

1219 # prepare the output 

1220 if new_shape_ab is not None: 

1221 ab = reshape(ab, new_shape_ab) 

1222 if perm_ab is not None: 

1223 ab = ab.transpose(perm_ab) 

1224 

1225 if (out is not None) and (not matmul_out_compatible): 

1226 # handle case where out is specified, but we also needed 

1227 # to reshape / transpose ``ab`` after the matmul 

1228 out[...] = ab 

1229 ab = out 

1230 elif output_order is not None: 

1231 ab = asanyarray(ab, order=output_order) 

1232 

1233 return ab 

1234 

1235 

1236def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs): 

1237 # Arguably we dispatch on more arguments than we really should; see note in 

1238 # _einsum_path_dispatcher for why. 

1239 yield from operands 

1240 yield out 

1241 

1242 

1243# Rewrite einsum to handle different cases 

1244@array_function_dispatch(_einsum_dispatcher, module='numpy') 

1245def einsum(*operands, out=None, optimize=False, **kwargs): 

1246 """ 

1247 einsum(subscripts, *operands, out=None, dtype=None, order='K', 

1248 casting='safe', optimize=False) 

1249 

1250 Evaluates the Einstein summation convention on the operands. 

1251 

1252 Using the Einstein summation convention, many common multi-dimensional, 

1253 linear algebraic array operations can be represented in a simple fashion. 

1254 In *implicit* mode `einsum` computes these values. 

1255 

1256 In *explicit* mode, `einsum` provides further flexibility to compute 

1257 other array operations that might not be considered classical Einstein 

1258 summation operations, by disabling, or forcing summation over specified 

1259 subscript labels. 

1260 

1261 See the notes and examples for clarification. 

1262 

1263 Parameters 

1264 ---------- 

1265 subscripts : str 

1266 Specifies the subscripts for summation as comma separated list of 

1267 subscript labels. An implicit (classical Einstein summation) 

1268 calculation is performed unless the explicit indicator '->' is 

1269 included as well as subscript labels of the precise output form. 

1270 operands : list of array_like 

1271 These are the arrays for the operation. 

1272 out : ndarray, optional 

1273 If provided, the calculation is done into this array. 

1274 dtype : {data-type, None}, optional 

1275 If provided, forces the calculation to use the data type specified. 

1276 Note that you may have to also give a more liberal `casting` 

1277 parameter to allow the conversions. Default is None. 

1278 order : {'C', 'F', 'A', 'K'}, optional 

1279 Controls the memory layout of the output. 'C' means it should 

1280 be C contiguous. 'F' means it should be Fortran contiguous, 

1281 'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise. 

1282 'K' means it should be as close to the layout as the inputs as 

1283 is possible, including arbitrarily permuted axes. 

1284 Default is 'K'. 

1285 casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional 

1286 Controls what kind of data casting may occur. Setting this to 

1287 'unsafe' is not recommended, as it can adversely affect accumulations. 

1288 

1289 * 'no' means the data types should not be cast at all. 

1290 * 'equiv' means only byte-order changes are allowed. 

1291 * 'safe' means only casts which can preserve values are allowed. 

1292 * 'same_kind' means only safe casts or casts within a kind, 

1293 like float64 to float32, are allowed. 

1294 * 'unsafe' means any data conversions may be done. 

1295 

1296 Default is 'safe'. 

1297 optimize : {False, True, 'greedy', 'optimal'}, optional 

1298 Controls if intermediate optimization should occur. No optimization 

1299 will occur if False and True will default to the 'greedy' algorithm. 

1300 Also accepts an explicit contraction list from the ``np.einsum_path`` 

1301 function. See ``np.einsum_path`` for more details. Defaults to False. 

1302 

1303 Returns 

1304 ------- 

1305 output : ndarray 

1306 The calculation based on the Einstein summation convention. 

1307 

1308 See Also 

1309 -------- 

1310 einsum_path, dot, inner, outer, tensordot, linalg.multi_dot 

1311 einsum: 

1312 Similar verbose interface is provided by the 

1313 `einops <https://github.com/arogozhnikov/einops>`_ package to cover 

1314 additional operations: transpose, reshape/flatten, repeat/tile, 

1315 squeeze/unsqueeze and reductions. 

1316 The `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ 

1317 optimizes contraction order for einsum-like expressions 

1318 in backend-agnostic manner. 

1319 

1320 Notes 

1321 ----- 

1322 The Einstein summation convention can be used to compute 

1323 many multi-dimensional, linear algebraic array operations. `einsum` 

1324 provides a succinct way of representing these. 

1325 

1326 A non-exhaustive list of these operations, 

1327 which can be computed by `einsum`, is shown below along with examples: 

1328 

1329 * Trace of an array, :py:func:`numpy.trace`. 

1330 * Return a diagonal, :py:func:`numpy.diag`. 

1331 * Array axis summations, :py:func:`numpy.sum`. 

1332 * Transpositions and permutations, :py:func:`numpy.transpose`. 

1333 * Matrix multiplication and dot product, :py:func:`numpy.matmul` 

1334 :py:func:`numpy.dot`. 

1335 * Vector inner and outer products, :py:func:`numpy.inner` 

1336 :py:func:`numpy.outer`. 

1337 * Broadcasting, element-wise and scalar multiplication, 

1338 :py:func:`numpy.multiply`. 

1339 * Tensor contractions, :py:func:`numpy.tensordot`. 

1340 * Chained array operations, in efficient calculation order, 

1341 :py:func:`numpy.einsum_path`. 

1342 

1343 The subscripts string is a comma-separated list of subscript labels, 

1344 where each label refers to a dimension of the corresponding operand. 

1345 Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)`` 

1346 is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label 

1347 appears only once, it is not summed, so ``np.einsum('i', a)`` 

1348 produces a view of ``a`` with no changes. A further example 

1349 ``np.einsum('ij,jk', a, b)`` describes traditional matrix multiplication 

1350 and is equivalent to :py:func:`np.matmul(a,b) <numpy.matmul>`. 

1351 Repeated subscript labels in one operand take the diagonal. 

1352 For example, ``np.einsum('ii', a)`` is equivalent to 

1353 :py:func:`np.trace(a) <numpy.trace>`. 

1354 

1355 In *implicit mode*, the chosen subscripts are important 

1356 since the axes of the output are reordered alphabetically. This 

1357 means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while 

1358 ``np.einsum('ji', a)`` takes its transpose. Additionally, 

1359 ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while, 

1360 ``np.einsum('ij,jh', a, b)`` returns the transpose of the 

1361 multiplication since subscript 'h' precedes subscript 'i'. 

1362 

1363 In *explicit mode* the output can be directly controlled by 

1364 specifying output subscript labels. This requires the 

1365 identifier '->' as well as the list of output subscript labels. 

1366 This feature increases the flexibility of the function since 

1367 summing can be disabled or forced when required. The call 

1368 ``np.einsum('i->', a)`` is like :py:func:`np.sum(a) <numpy.sum>` 

1369 if ``a`` is a 1-D array, and ``np.einsum('ii->i', a)`` 

1370 is like :py:func:`np.diag(a) <numpy.diag>` if ``a`` is a square 2-D array. 

1371 The difference is that `einsum` does not allow broadcasting by default. 

1372 Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the 

1373 order of the output subscript labels and therefore returns matrix 

1374 multiplication, unlike the example above in implicit mode. 

1375 

1376 To enable and control broadcasting, use an ellipsis. Default 

1377 NumPy-style broadcasting is done by adding an ellipsis 

1378 to the left of each term, like ``np.einsum('...ii->...i', a)``. 

1379 ``np.einsum('...i->...', a)`` is like 

1380 :py:func:`np.sum(a, axis=-1) <numpy.sum>` for array ``a`` of any shape. 

1381 To take the trace along the first and last axes, 

1382 you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix 

1383 product with the left-most indices instead of rightmost, one can do 

1384 ``np.einsum('ij...,jk...->ik...', a, b)``. 

1385 

1386 When there is only one operand, no axes are summed, and no output 

1387 parameter is provided, a view into the operand is returned instead 

1388 of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)`` 

1389 produces a view (changed in version 1.10.0). 

1390 

1391 `einsum` also provides an alternative way to provide the subscripts and 

1392 operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``. 

1393 If the output shape is not provided in this format `einsum` will be 

1394 calculated in implicit mode, otherwise it will be performed explicitly. 

1395 The examples below have corresponding `einsum` calls with the two 

1396 parameter methods. 

1397 

1398 Views returned from einsum are now writeable whenever the input array 

1399 is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now 

1400 have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>` 

1401 and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal 

1402 of a 2D array. 

1403 

1404 Added the ``optimize`` argument which will optimize the contraction order 

1405 of an einsum expression. For a contraction with three or more operands 

1406 this can greatly increase the computational efficiency at the cost of 

1407 a larger memory footprint during computation. 

1408 

1409 Typically a 'greedy' algorithm is applied which empirical tests have shown 

1410 returns the optimal path in the majority of cases. In some cases 'optimal' 

1411 will return the superlative path through a more expensive, exhaustive 

1412 search. For iterative calculations it may be advisable to calculate 

1413 the optimal path once and reuse that path by supplying it as an argument. 

1414 An example is given below. 

1415 

1416 See :py:func:`numpy.einsum_path` for more details. 

1417 

1418 Examples 

1419 -------- 

1420 >>> a = np.arange(25).reshape(5,5) 

1421 >>> b = np.arange(5) 

1422 >>> c = np.arange(6).reshape(2,3) 

1423 

1424 Trace of a matrix: 

1425 

1426 >>> np.einsum('ii', a) 

1427 60 

1428 >>> np.einsum(a, [0,0]) 

1429 60 

1430 >>> np.trace(a) 

1431 60 

1432 

1433 Extract the diagonal (requires explicit form): 

1434 

1435 >>> np.einsum('ii->i', a) 

1436 array([ 0, 6, 12, 18, 24]) 

1437 >>> np.einsum(a, [0,0], [0]) 

1438 array([ 0, 6, 12, 18, 24]) 

1439 >>> np.diag(a) 

1440 array([ 0, 6, 12, 18, 24]) 

1441 

1442 Sum over an axis (requires explicit form): 

1443 

1444 >>> np.einsum('ij->i', a) 

1445 array([ 10, 35, 60, 85, 110]) 

1446 >>> np.einsum(a, [0,1], [0]) 

1447 array([ 10, 35, 60, 85, 110]) 

1448 >>> np.sum(a, axis=1) 

1449 array([ 10, 35, 60, 85, 110]) 

1450 

1451 For higher dimensional arrays summing a single axis can be done 

1452 with ellipsis: 

1453 

1454 >>> np.einsum('...j->...', a) 

1455 array([ 10, 35, 60, 85, 110]) 

1456 >>> np.einsum(a, [Ellipsis,1], [Ellipsis]) 

1457 array([ 10, 35, 60, 85, 110]) 

1458 

1459 Compute a matrix transpose, or reorder any number of axes: 

1460 

1461 >>> np.einsum('ji', c) 

1462 array([[0, 3], 

1463 [1, 4], 

1464 [2, 5]]) 

1465 >>> np.einsum('ij->ji', c) 

1466 array([[0, 3], 

1467 [1, 4], 

1468 [2, 5]]) 

1469 >>> np.einsum(c, [1,0]) 

1470 array([[0, 3], 

1471 [1, 4], 

1472 [2, 5]]) 

1473 >>> np.transpose(c) 

1474 array([[0, 3], 

1475 [1, 4], 

1476 [2, 5]]) 

1477 

1478 Vector inner products: 

1479 

1480 >>> np.einsum('i,i', b, b) 

1481 30 

1482 >>> np.einsum(b, [0], b, [0]) 

1483 30 

1484 >>> np.inner(b,b) 

1485 30 

1486 

1487 Matrix vector multiplication: 

1488 

1489 >>> np.einsum('ij,j', a, b) 

1490 array([ 30, 80, 130, 180, 230]) 

1491 >>> np.einsum(a, [0,1], b, [1]) 

1492 array([ 30, 80, 130, 180, 230]) 

1493 >>> np.dot(a, b) 

1494 array([ 30, 80, 130, 180, 230]) 

1495 >>> np.einsum('...j,j', a, b) 

1496 array([ 30, 80, 130, 180, 230]) 

1497 

1498 Broadcasting and scalar multiplication: 

1499 

1500 >>> np.einsum('..., ...', 3, c) 

1501 array([[ 0, 3, 6], 

1502 [ 9, 12, 15]]) 

1503 >>> np.einsum(',ij', 3, c) 

1504 array([[ 0, 3, 6], 

1505 [ 9, 12, 15]]) 

1506 >>> np.einsum(3, [Ellipsis], c, [Ellipsis]) 

1507 array([[ 0, 3, 6], 

1508 [ 9, 12, 15]]) 

1509 >>> np.multiply(3, c) 

1510 array([[ 0, 3, 6], 

1511 [ 9, 12, 15]]) 

1512 

1513 Vector outer product: 

1514 

1515 >>> np.einsum('i,j', np.arange(2)+1, b) 

1516 array([[0, 1, 2, 3, 4], 

1517 [0, 2, 4, 6, 8]]) 

1518 >>> np.einsum(np.arange(2)+1, [0], b, [1]) 

1519 array([[0, 1, 2, 3, 4], 

1520 [0, 2, 4, 6, 8]]) 

1521 >>> np.outer(np.arange(2)+1, b) 

1522 array([[0, 1, 2, 3, 4], 

1523 [0, 2, 4, 6, 8]]) 

1524 

1525 Tensor contraction: 

1526 

1527 >>> a = np.arange(60.).reshape(3,4,5) 

1528 >>> b = np.arange(24.).reshape(4,3,2) 

1529 >>> np.einsum('ijk,jil->kl', a, b) 

1530 array([[4400., 4730.], 

1531 [4532., 4874.], 

1532 [4664., 5018.], 

1533 [4796., 5162.], 

1534 [4928., 5306.]]) 

1535 >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3]) 

1536 array([[4400., 4730.], 

1537 [4532., 4874.], 

1538 [4664., 5018.], 

1539 [4796., 5162.], 

1540 [4928., 5306.]]) 

1541 >>> np.tensordot(a,b, axes=([1,0],[0,1])) 

1542 array([[4400., 4730.], 

1543 [4532., 4874.], 

1544 [4664., 5018.], 

1545 [4796., 5162.], 

1546 [4928., 5306.]]) 

1547 

1548 Writeable returned arrays (since version 1.10.0): 

1549 

1550 >>> a = np.zeros((3, 3)) 

1551 >>> np.einsum('ii->i', a)[:] = 1 

1552 >>> a 

1553 array([[1., 0., 0.], 

1554 [0., 1., 0.], 

1555 [0., 0., 1.]]) 

1556 

1557 Example of ellipsis use: 

1558 

1559 >>> a = np.arange(6).reshape((3,2)) 

1560 >>> b = np.arange(12).reshape((4,3)) 

1561 >>> np.einsum('ki,jk->ij', a, b) 

1562 array([[10, 28, 46, 64], 

1563 [13, 40, 67, 94]]) 

1564 >>> np.einsum('ki,...k->i...', a, b) 

1565 array([[10, 28, 46, 64], 

1566 [13, 40, 67, 94]]) 

1567 >>> np.einsum('k...,jk', a, b) 

1568 array([[10, 28, 46, 64], 

1569 [13, 40, 67, 94]]) 

1570 

1571 Chained array operations. For more complicated contractions, speed ups 

1572 might be achieved by repeatedly computing a 'greedy' path or pre-computing 

1573 the 'optimal' path and repeatedly applying it, using an `einsum_path` 

1574 insertion (since version 1.12.0). Performance improvements can be 

1575 particularly significant with larger arrays: 

1576 

1577 >>> a = np.ones(64).reshape(2,4,8) 

1578 

1579 Basic `einsum`: ~1520ms (benchmarked on 3.1GHz Intel i5.) 

1580 

1581 >>> for iteration in range(500): 

1582 ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a) 

1583 

1584 Sub-optimal `einsum` (due to repeated path calculation time): ~330ms 

1585 

1586 >>> for iteration in range(500): 

1587 ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, 

1588 ... optimize='optimal') 

1589 

1590 Greedy `einsum` (faster optimal path approximation): ~160ms 

1591 

1592 >>> for iteration in range(500): 

1593 ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy') 

1594 

1595 Optimal `einsum` (best usage pattern in some use cases): ~110ms 

1596 

1597 >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, 

1598 ... optimize='optimal')[0] 

1599 >>> for iteration in range(500): 

1600 ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path) 

1601 

1602 """ 

1603 # Special handling if out is specified 

1604 specified_out = out is not None 

1605 

1606 # If no optimization, run pure einsum 

1607 if optimize is False: 

1608 if specified_out: 

1609 kwargs['out'] = out 

1610 return c_einsum(*operands, **kwargs) 

1611 

1612 # Check the kwargs to avoid a more cryptic error later, without having to 

1613 # repeat default values here 

1614 valid_einsum_kwargs = ['dtype', 'order', 'casting'] 

1615 unknown_kwargs = [k for (k, v) in kwargs.items() if 

1616 k not in valid_einsum_kwargs] 

1617 if len(unknown_kwargs): 

1618 raise TypeError(f"Did not understand the following kwargs: {unknown_kwargs}") 

1619 

1620 # Build the contraction list and operand 

1621 operands, contraction_list = einsum_path(*operands, optimize=optimize, 

1622 einsum_call=True) 

1623 

1624 # Start contraction loop 

1625 for num, contraction in enumerate(contraction_list): 

1626 inds, einsum_str, _ = contraction 

1627 tmp_operands = [operands.pop(x) for x in inds] 

1628 

1629 # Do we need to deal with the output? 

1630 handle_out = specified_out and ((num + 1) == len(contraction_list)) 

1631 

1632 # If out was specified 

1633 if handle_out: 

1634 kwargs["out"] = out 

1635 

1636 if len(tmp_operands) == 2: 

1637 # Call (batched) matrix multiplication if possible 

1638 new_view = bmm_einsum(einsum_str, *tmp_operands, **kwargs) 

1639 else: 

1640 # Call einsum 

1641 new_view = c_einsum(einsum_str, *tmp_operands, **kwargs) 

1642 

1643 # Append new items and dereference what we can 

1644 operands.append(new_view) 

1645 del tmp_operands, new_view 

1646 

1647 if specified_out: 

1648 return out 

1649 else: 

1650 return operands[0]