Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/paths.py: 17%

501 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:41 +0000

1""" 

2Contains the path technology behind opt_einsum in addition to several path helpers 

3""" 

4 

5import functools 

6import heapq 

7import itertools 

8import operator 

9import random 

10import re 

11from collections import Counter, OrderedDict, defaultdict 

12from typing import Any, Callable 

13from typing import Counter as CounterType 

14from typing import Dict, FrozenSet, Generator, List, Optional, Sequence, Set, Tuple, Union 

15 

16import numpy as np 

17 

18from .helpers import compute_size_by_dict, flop_count 

19from .typing import ArrayIndexType, PathType 

20 

# Names exported by ``from <this module> import *``.
__all__ = [
    "optimal",
    "BranchBound",
    "branch",
    "greedy",
    "auto",
    "auto_hq",
    "get_path_fn",
    "DynamicProgramming",
    "dynamic_programming",
]

# ``memory_limit`` values for which the optimizers below skip their
# memory-limit sieve (checked with ``memory_limit not in _UNLIMITED_MEM``).
_UNLIMITED_MEM = {None, -1, float("inf")}

34 

35 

class PathOptimizer:
    """Common parent class of the path optimizers.

    A concrete optimizer overrides ``__call__`` using this signature:

    ```python
    def __call__(self, inputs, output, size_dict, memory_limit=None):
        \"\"\"
        **Parameters:**
        ----------
        inputs : list[set[str]]
            The indices of each input array.
        output : set[str]
            The output indices
        size_dict : dict[str, int]
            The size of each index
        memory_limit : int, optional
            If given, the maximum allowed memory.
        \"\"\"
        # ... compute path here ...
        return path
    ```

    where `path` is a list of int-tuples specifying a contraction order.
    """

    def _check_args_against_first_call(
        self,
        inputs: List[ArrayIndexType],
        output: ArrayIndexType,
        size_dict: Dict[str, int],
    ) -> None:
        """Guard for stateful optimizers against being reused on a different contraction.

        The first invocation records ``(inputs, output, size_dict)`` on the
        instance as ``_first_call_args``; every later invocation compares its
        arguments to that record and raises ``ValueError`` when they differ.
        """
        current = (inputs, output, size_dict)
        try:
            recorded = self._first_call_args
        except AttributeError:
            # there is no shared ``__init__`` to create the attribute, so it is created lazily here
            self._first_call_args = current
            return

        if current != recorded:
            raise ValueError(
                "The arguments specifying the contraction that this path optimizer "
                "instance was called with have changed - try creating a new instance."
            )

    def __call__(
        self,
        inputs: List[ArrayIndexType],
        output: ArrayIndexType,
        size_dict: Dict[str, int],
        memory_limit: Optional[int] = None,
    ) -> PathType:
        """Compute a contraction path; the base class always raises ``NotImplementedError``."""
        raise NotImplementedError

89 

90 

def ssa_to_linear(ssa_path: PathType) -> PathType:
    """
    Convert a path with static single assignment ids to a path with recycled
    linear ids. For example:

    ```python
    ssa_to_linear([(0, 3), (2, 4), (1, 5)])
    #> [(0, 3), (1, 2), (0, 1)]
    ```

    **Parameters:**

    - **ssa_path** - *(sequence of int-tuples)* Contractions expressed with SSA ids. May be empty.

    **Returns:**

    - **path** - *(list of int-tuples)* The same contractions expressed with linear ids; an
        empty list when ``ssa_path`` is empty.
    """
    # An empty path has no ids to translate; without this guard the ``max``
    # call below raises ``ValueError`` on an empty sequence.
    if len(ssa_path) == 0:
        return []

    # ids[k] holds the current linear position of SSA id ``k``.
    ids = np.arange(1 + max(map(max, ssa_path)), dtype=np.int32)  # type: ignore
    path = []
    for ssa_ids in ssa_path:
        path.append(tuple(int(ids[ssa_id]) for ssa_id in ssa_ids))
        # every id at or after a consumed id moves down by one position
        for ssa_id in ssa_ids:
            ids[ssa_id:] -= 1
    return path

108 

109 

def linear_to_ssa(path: PathType) -> PathType:
    """
    Translate a path written with recycled linear ids into one written with
    static single assignment ids. For example:

    ```python
    linear_to_ssa([(0, 3), (1, 2), (0, 1)])
    #> [(0, 3), (2, 4), (1, 5)]
    ```
    """
    # each step consumes len(step) operands and produces one, ending with one
    leaf_count = sum(len(step) for step in path) - len(path) + 1

    # live[pos] is the SSA id of the operand currently at linear position ``pos``
    live = list(range(leaf_count))
    fresh_id = leaf_count
    converted = []
    for step in path:
        converted.append(tuple(live[pos] for pos in step))
        # remove from the back so earlier positions stay valid
        for pos in sorted(step, reverse=True):
            live.pop(pos)
        live.append(fresh_id)
        fresh_id += 1
    return converted

130 

131 

def calc_k12_flops(
    inputs: Tuple[FrozenSet[str]],
    output: FrozenSet[str],
    remaining: FrozenSet[int],
    i: int,
    j: int,
    size_dict: Dict[str, int],
) -> Tuple[FrozenSet[str], int]:
    """
    Work out the index set and the flop estimate of contracting tensors ``i``
    and ``j`` - used in the recursive (optimal/branch) algorithms.

    **Parameters:**

    - **inputs** - *(tuple[frozenset[str]])* The indices of each tensor in this contraction, note this includes
        tensors unavailable to contract as static single assignment is used ->
        contracted tensors are not removed from the list.
    - **output** - *(frozenset[str])* The set of output indices for the whole contraction.
    - **remaining** - *(frozenset[int])* The set of indices (corresponding to ``inputs``) of tensors still available to contract.
    - **i** - *(int)* Index of potential tensor to contract.
    - **j** - *(int)* Index of potential tensor to contract.
    - **size_dict** - *(dict[str, int])* Size mapping of all the indices.

    **Returns:**

    - **k12** - *(frozenset)* The resulting indices of the potential tensor.
    - **cost** - *(int)* Value returned by ``flop_count`` for this pairwise operation.
    """
    left, right = inputs[i], inputs[j]

    # an index has to be kept if it is an output index or is carried by any
    # other tensor that is still waiting to be contracted
    bystanders = remaining - {i, j}
    keep = frozenset.union(output, *(inputs[pos] for pos in bystanders))

    union = left | right
    summed_out = (left & right) - keep

    return union & keep, flop_count(union, bool(summed_out), 2, size_dict)

169 

170 

def _compute_oversize_flops(
    inputs: Tuple[FrozenSet[str]],
    remaining: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
) -> int:
    """
    Flop estimate for contracting every remaining argument in one single step.
    The callers use it when the memory limit rules out a pairwise contraction.
    """
    all_indices = frozenset.union(*(inputs[pos] for pos in remaining))  # type: ignore
    has_inner = bool(all_indices - output)
    return flop_count(all_indices, has_inner, len(remaining), size_dict)

185 

186 

def optimal(
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
    memory_limit: Optional[int] = None,
) -> PathType:
    """
    Computes all possible pair contractions in a depth-first recursive manner,
    sieving results based on `memory_limit` and the best path found so far.
    **Returns:** the lowest cost path. This algorithm scales factorially with
    respect to the number of elements in `inputs`.

    **Parameters:**

    - **inputs** - *(list)* List of sets that represent the lhs side of the einsum subscript.
    - **output** - *(set)* Set that represents the rhs side of the overall einsum subscript.
    - **size_dict** - *(dictionary)* Dictionary of index sizes.
    - **memory_limit** - *(int)* The maximum number of elements in a temporary array.

    **Returns:**

    - **path** - *(list)* The optimal contraction order within the memory limit constraint.

    **Examples:**

    ```python
    isets = [set('abd'), set('ac'), set('bdc')]
    oset = set('')
    idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    optimal(isets, oset, idx_sizes, 5000)
    #> [(0, 2), (0, 1)]
    ```
    """
    inputs_set = tuple(map(frozenset, inputs))  # type: ignore
    output_set = frozenset(output)

    # best result so far; the fallback path contracts everything in one step
    best_flops = float("inf")
    best_ssa_path = (tuple(range(len(inputs))),)

    size_cache: Dict[FrozenSet[str], int] = {}
    result_cache: Dict[Tuple[ArrayIndexType, ArrayIndexType], Tuple[FrozenSet[str], int]] = {}

    def _search(path, remaining, tensors, flops):
        nonlocal best_flops, best_ssa_path

        # a complete path - pruning guarantees it beats the previous best
        if len(remaining) == 1:
            best_flops, best_ssa_path = flops, path
            return

        # try every pair of tensors that is still available
        for a, b in itertools.combinations(remaining, 2):
            i, j = (b, a) if a > b else (a, b)

            pair = (tensors[i], tensors[j])
            if pair not in result_cache:
                result_cache[pair] = calc_k12_flops(tensors, output_set, remaining, i, j, size_dict)
            k12, flops12 = result_cache[pair]

            # prune: already no better than the best complete path
            total = flops + flops12
            if total >= best_flops:
                continue

            # prune: intermediate would exceed the memory limit
            if memory_limit not in _UNLIMITED_MEM:
                if k12 not in size_cache:
                    size_cache[k12] = compute_size_by_dict(k12, size_dict)

                if size_cache[k12] > memory_limit:
                    # finish this branch with a single all-terms contraction instead
                    total = flops + _compute_oversize_flops(tensors, remaining, output_set, size_dict)
                    if total < best_flops:
                        best_flops = total
                        best_ssa_path = path + (tuple(remaining),)
                    continue

            # accept the pair and descend; the new tensor gets the next SSA id
            _search(
                path + ((i, j),),
                (remaining - {i, j}) | {len(tensors)},
                tensors + (k12,),
                total,
            )

    _search((), set(range(len(inputs))), inputs_set, 0)

    return ssa_to_linear(best_ssa_path)

277 

278 

279# functions for comparing which of two paths is 'better' 

280 

281 

def better_flops_first(flops: int, size: int, best_flops: int, best_size: int) -> bool:
    """Return whether ``(flops, size)`` sorts strictly before ``(best_flops, best_size)``."""
    candidate = (flops, size)
    incumbent = (best_flops, best_size)
    return candidate < incumbent

284 

285 

def better_size_first(flops: int, size: int, best_flops: int, best_size: int) -> bool:
    """Return whether ``(size, flops)`` sorts strictly before ``(best_size, best_flops)``."""
    candidate = (size, flops)
    incumbent = (best_size, best_flops)
    return candidate < incumbent

288 

289 

# Comparison functions selectable by name (see ``get_better_fn``).
_BETTER_FNS = dict(
    flops=better_flops_first,
    size=better_size_first,
)

294 

295 

def get_better_fn(key: str) -> Callable[[int, int, int, int], bool]:
    """Return the comparison function stored under ``key`` in ``_BETTER_FNS``.

    An unknown ``key`` raises ``KeyError``.
    """
    comparator = _BETTER_FNS[key]
    return comparator

298 

299 

300# functions for assigning a heuristic 'cost' to a potential contraction 

301 

302 

def cost_memory_removed(size12: int, size1: int, size2: int, k12: int, k1: int, k2: int) -> float:
    """The default heuristic cost: size of the result minus the sizes of both operands.

    ``k12``, ``k1`` and ``k2`` are accepted for signature compatibility and ignored.
    """
    after_first = size12 - size1
    return after_first - size2

308 

309 

def cost_memory_removed_jitter(size12: int, size1: int, size2: int, k12: int, k1: int, k2: int) -> float:
    """Memory-removed cost scaled by a random factor drawn from ``random.gauss(1.0, 0.01)``.

    The noise breaks ties between otherwise equal candidates and so jumbles
    the contraction order a little. ``k12``, ``k1`` and ``k2`` are ignored.
    """
    noise = random.gauss(1.0, 0.01)
    removed = size12 - size1 - size2
    return noise * removed

315 

316 

# Heuristic cost functions selectable by name; callers fall back to treating
# an unknown value as the cost function itself (``_COST_FNS.get(x, x)``).
_COST_FNS = dict(
    (
        ("memory-removed", cost_memory_removed),
        ("memory-removed-jitter", cost_memory_removed_jitter),
    )
)

321 

322 

class BranchBound(PathOptimizer):
    """
    Explores possible pair contractions in a depth-first recursive manner like
    the `optimal` approach, but with extra heuristic early pruning of branches
    as well sieving by `memory_limit` and the best path found so far. **Returns:**
    the lowest cost path. This algorithm still scales factorially with respect
    to the number of inputs if `nbranch` is not set, but it
    scales exponentially like `nbranch**len(inputs)` otherwise.

    **Parameters:**

    - **nbranch** - *(None or int, optional)* How many branches to explore at each contraction step. If None, explore
        all possible branches. If an integer, branch into this many paths at
        each step. Defaults to None.
    - **cutoff_flops_factor** - *(float, optional)* If at any point, a path is doing this much worse than the best path
        found so far was, terminate it. The larger this is made, the more paths
        will be fully explored and the slower the algorithm. Defaults to 4.
    - **minimize** - *({'flops', 'size'}, optional)* Whether to optimize the path with regard primarily to the total
        estimated flop-count, or the size of the largest intermediate. The
        option not chosen will still be used as a secondary criterion.
    - **cost_fn** - *(callable, optional)* A function that returns a heuristic 'cost' of a potential contraction
        with which to sort candidates. Should have signature
        `cost_fn(size12, size1, size2, k12, k1, k2)`.

    **Raises:**

    - **ValueError** - if `nbranch` is given and smaller than one.
    """

    def __init__(
        self,
        nbranch=None,
        cutoff_flops_factor=4,
        minimize="flops",
        cost_fn="memory-removed",
    ):
        if (nbranch is not None) and nbranch < 1:
            raise ValueError(f"The number of branches must be at least one, `nbranch={nbranch}`.")

        self.nbranch = nbranch
        self.cutoff_flops_factor = cutoff_flops_factor
        self.minimize = minimize
        # a known name is mapped to its function; anything else is used as given
        self.cost_fn = _COST_FNS.get(cost_fn, cost_fn)

        self.better = get_better_fn(minimize)
        # best complete path found so far ("ssa_path" is added once one exists)
        self.best = {"flops": float("inf"), "size": float("inf")}
        # lowest flop total seen at each search depth, keyed by ``len(inputs)``
        self.best_progress = defaultdict(lambda: float("inf"))

    @property
    def path(self) -> PathType:
        """The best path found so far, converted from SSA ids to linear ids."""
        return ssa_to_linear(self.best["ssa_path"])

    def __call__(
        self,
        inputs_: List[ArrayIndexType],
        output_: ArrayIndexType,
        size_dict: Dict[str, int],
        memory_limit: Optional[int] = None,
    ) -> PathType:
        """
        Search for a contraction path.

        **Parameters:**

        - **inputs_** - *(list)* List of sets that represent the lhs side of the einsum subscript
        - **output_** - *(set)* Set that represents the rhs side of the overall einsum subscript
        - **size_dict** - *(dictionary)* Dictionary of index sizes
        - **memory_limit** - *(int)* The maximum number of elements in a temporary array

        **Returns:**

        - **path** - *(list)* The contraction order within the memory limit constraint.

        **Raises:**

        - **ValueError** - if this instance was previously called with different
            `inputs_`, `output_` or `size_dict`.

        **Examples:**

        ```python
        isets = [set('abd'), set('ac'), set('bdc')]
        oset = set('')
        idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
        path = BranchBound()(isets, oset, idx_sizes, 5000)
        ```
        """
        self._check_args_against_first_call(inputs_, output_, size_dict)

        inputs: Tuple[FrozenSet[str]] = tuple(map(frozenset, inputs_))  # type: ignore
        output: FrozenSet[str] = frozenset(output_)

        size_cache = {k: compute_size_by_dict(k, size_dict) for k in inputs}
        result_cache: Dict[Tuple[FrozenSet[str], FrozenSet[str]], Tuple[FrozenSet[str], int]] = {}

        def _branch_iterate(path, inputs, remaining, flops, size):

            # reached end of path (only ever get here if flops is best found so far)
            if len(remaining) == 1:
                self.best["size"] = size
                self.best["flops"] = flops
                self.best["ssa_path"] = path
                return

            def _assess_candidate(k1: FrozenSet[str], k2: FrozenSet[str], i: int, j: int) -> Any:
                # find resulting indices and flops
                try:
                    k12, flops12 = result_cache[k1, k2]
                except KeyError:
                    k12, flops12 = result_cache[k1, k2] = calc_k12_flops(inputs, output, remaining, i, j, size_dict)

                try:
                    size12 = size_cache[k12]
                except KeyError:
                    size12 = size_cache[k12] = compute_size_by_dict(k12, size_dict)

                new_flops = flops + flops12
                new_size = max(size, size12)

                # sieve based on current best i.e. check flops and size still better
                if not self.better(new_flops, new_size, self.best["flops"], self.best["size"]):
                    return None

                # compare to how the best method was doing as this point
                if new_flops < self.best_progress[len(inputs)]:
                    self.best_progress[len(inputs)] = new_flops
                # sieve based on current progress relative to best
                elif new_flops > self.cutoff_flops_factor * self.best_progress[len(inputs)]:
                    return None

                # sieve based on memory limit
                if (memory_limit not in _UNLIMITED_MEM) and (size12 > memory_limit):  # type: ignore
                    # terminate path here, but check all-terms contract first.
                    # Use the normalised frozenset ``output`` (not the raw
                    # ``output_``) so the set difference inside the helper
                    # works for any iterable the caller passed in.
                    new_flops = flops + _compute_oversize_flops(inputs, remaining, output, size_dict)
                    if new_flops < self.best["flops"]:
                        self.best["flops"] = new_flops
                        self.best["ssa_path"] = path + (tuple(remaining),)
                    return None

                # set cost heuristic in order to locally sort possible contractions
                size1, size2 = size_cache[inputs[i]], size_cache[inputs[j]]
                cost = self.cost_fn(size12, size1, size2, k12, k1, k2)

                return cost, flops12, new_flops, new_size, (i, j), k12

            # check all possible remaining paths
            candidates = []
            for i, j in itertools.combinations(remaining, 2):
                if i > j:
                    i, j = j, i
                k1, k2 = inputs[i], inputs[j]

                # initially ignore outer products
                if k1.isdisjoint(k2):
                    continue

                candidate = _assess_candidate(k1, k2, i, j)
                if candidate:
                    heapq.heappush(candidates, candidate)

            # assess outer products if nothing left
            if not candidates:
                for i, j in itertools.combinations(remaining, 2):
                    if i > j:
                        i, j = j, i
                    k1, k2 = inputs[i], inputs[j]
                    candidate = _assess_candidate(k1, k2, i, j)
                    if candidate:
                        heapq.heappush(candidates, candidate)

            # recurse into all or some of the best candidate contractions
            bi = 0
            while (self.nbranch is None or bi < self.nbranch) and candidates:
                _, _, new_flops, new_size, (i, j), k12 = heapq.heappop(candidates)
                _branch_iterate(
                    path=path + ((i, j),),
                    inputs=inputs + (k12,),
                    remaining=(remaining - {i, j}) | {len(inputs)},
                    flops=new_flops,
                    size=new_size,
                )
                bi += 1

        _branch_iterate(path=(), inputs=inputs, remaining=set(range(len(inputs))), flops=0, size=0)

        return self.path

499 

500 

def branch(
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
    memory_limit: Optional[int] = None,
    **optimizer_kwargs: Dict[str, Any],
) -> PathType:
    """Functional wrapper around `BranchBound`.

    Builds a fresh `BranchBound(**optimizer_kwargs)` and immediately calls it
    with ``inputs``, ``output``, ``size_dict`` and ``memory_limit``, returning
    the resulting path.
    """
    return BranchBound(**optimizer_kwargs)(inputs, output, size_dict, memory_limit)

510 

511 

# Pre-configured variants of ``branch`` with a fixed ``nbranch``.
branch_all, branch_2, branch_1 = (functools.partial(branch, nbranch=width) for width in (None, 2, 1))

# Type aliases used by the greedy helpers below.
GreedyCostType = Tuple[int, int, int]
# Cost, t1,t2->t3
GreedyContractionType = Tuple[GreedyCostType, ArrayIndexType, ArrayIndexType, ArrayIndexType]

518 

519 

def _get_candidate(
    output: ArrayIndexType,
    sizes: Dict[str, int],
    remaining: Dict[ArrayIndexType, int],
    footprints: Dict[ArrayIndexType, int],
    dim_ref_counts: Dict[int, Set[str]],
    k1: ArrayIndexType,
    k2: ArrayIndexType,
    cost_fn: Any,
) -> GreedyContractionType:
    """Build the queue entry ``(cost, k1, k2, k12)`` for contracting ``k1`` with ``k2``.

    The result ``k12`` keeps a dim when it is an output dim, when both
    operands carry it and it is in ``dim_ref_counts[3]``, or when exactly one
    operand carries it and it is in ``dim_ref_counts[2]``. The returned cost
    is ``(cost_fn(...), larger_id, smaller_id)`` and the two keys are returned
    ordered by their ids in ``remaining``.
    """
    union = k1 | k2
    shared = k1 & k2
    exclusive = union - shared
    k12 = (union & output) | (shared & dim_ref_counts[3]) | (exclusive & dim_ref_counts[2])

    # the heuristic sees the operands in the order they were passed in
    score = cost_fn(compute_size_by_dict(k12, sizes), footprints[k1], footprints[k2], k12, k1, k2)

    id1, id2 = remaining[k1], remaining[k2]
    if id1 > id2:
        k1, id1, k2, id2 = k2, id2, k1, id1

    # ids appended to break ties and keep the ordering deterministic
    return (score, id2, id1), k1, k2, k12

548 

549 

def _push_candidate(
    output: ArrayIndexType,
    sizes: Dict[str, Any],
    remaining: Dict[ArrayIndexType, int],
    footprints: Dict[ArrayIndexType, int],
    dim_ref_counts: Dict[int, Set[str]],
    k1: ArrayIndexType,
    k2s: List[ArrayIndexType],
    queue: List[GreedyContractionType],
    push_all: bool,
    cost_fn: Any,
) -> None:
    """Push candidate contractions of ``k1`` with each of ``k2s`` onto the heap ``queue``.

    With ``push_all`` every candidate is pushed (wanted e.g. when a custom
    'choose_fn' is used); otherwise only the minimum candidate is pushed.
    """

    def _candidate_for(k2):
        return _get_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2, cost_fn)

    if not push_all:
        heapq.heappush(queue, min(map(_candidate_for, k2s)))
        return

    for k2 in k2s:
        heapq.heappush(queue, _candidate_for(k2))

569 

570 

def _update_ref_counts(
    dim_to_keys: Dict[str, Set[ArrayIndexType]],
    dim_ref_counts: Dict[int, Set[str]],
    dims: ArrayIndexType,
) -> None:
    """Refresh, in place, the membership of each dim of ``dims`` in ``dim_ref_counts``.

    After the call a dim is in ``dim_ref_counts[2]`` exactly when
    ``dim_to_keys[dim]`` has at least two entries, and in
    ``dim_ref_counts[3]`` exactly when it has at least three.
    """
    for dim in dims:
        users = len(dim_to_keys[dim])
        for threshold in (2, 3):
            bucket = dim_ref_counts[threshold]
            if users >= threshold:
                bucket.add(dim)
            else:
                bucket.discard(dim)

587 

588 

def _simple_chooser(queue, remaining):
    """Default contraction chooser: pop the minimum-cost entry from the heap ``queue``.

    Returns ``None`` when either operand of the popped entry is no longer in
    ``remaining`` (the candidate is obsolete); the entry stays popped.
    """
    cost, k1, k2, k12 = heapq.heappop(queue)
    if k1 in remaining and k2 in remaining:
        return cost, k1, k2, k12
    return None

595 

596 

def ssa_greedy_optimize(
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    sizes: Dict[str, int],
    choose_fn: Any = None,
    cost_fn: Any = "memory-removed",
) -> PathType:
    """
    This is the core function for :func:`greedy` but produces a path with
    static single assignment ids rather than recycled linear ids.
    SSA ids are cheaper to work with and easier to reason about.

    **Parameters:**

    - **inputs** - *(list)* Index collections, one per input tensor.
    - **output** - *(set)* Indices of the overall output.
    - **sizes** - *(dictionary)* Dictionary of index sizes.
    - **choose_fn** - *(callable, optional)* Called as `choose_fn(queue, remaining)`; must return
        `(cost, k1, k2, k12)` or `None`. When given, every candidate is pushed onto the queue.
    - **cost_fn** - *(callable or str, optional)* Heuristic cost, or the name of one in `_COST_FNS`.

    **Returns:**

    - **ssa_path** - *(list)* Contractions as tuples of SSA ids.
    """
    if len(inputs) == 1:
        # Perform a single contraction to match output shape.
        return [(0,)]

    # set the function that assigns a heuristic cost to a possible contraction
    cost_fn = _COST_FNS.get(cost_fn, cost_fn)

    # set the function that chooses which contraction to take
    if choose_fn is None:
        choose_fn = _simple_chooser
        push_all = False
    else:
        # assume chooser wants access to all possible contractions
        push_all = True

    # A dim that is common to all tensors might as well be an output dim, since it
    # cannot be contracted until the final step. This avoids an expensive all-pairs
    # comparison to search for possible contractions at each step, leading to speedup
    # in many practical problems where all tensors share a common batch dimension.
    fs_inputs = [frozenset(x) for x in inputs]
    output = frozenset(output) | frozenset.intersection(*fs_inputs)

    # Deduplicate shapes by eagerly computing Hadamard products.
    remaining: Dict[ArrayIndexType, int] = {}  # key -> ssa_id
    ssa_ids = itertools.count(len(fs_inputs))
    ssa_path = []
    for ssa_id, key in enumerate(fs_inputs):
        if key in remaining:
            ssa_path.append((remaining[key], ssa_id))
            remaining[key] = next(ssa_ids)
        else:
            remaining[key] = ssa_id

    # Keep track of possible contraction dims.
    dim_to_keys = defaultdict(set)
    for key in remaining:
        for dim in key - output:
            dim_to_keys[dim].add(key)

    # Keep track of the number of tensors using each dim; when the dim is no longer
    # used it can be contracted. Since we specialize to binary ops, we only care about
    # ref counts of >=2 or >=3.
    dim_ref_counts = {
        count: set(dim for dim, keys in dim_to_keys.items() if len(keys) >= count) - output for count in [2, 3]
    }

    # Compute separable part of the objective function for contractions.
    footprints = {key: compute_size_by_dict(key, sizes) for key in remaining}

    # Find initial candidate contractions.
    queue: List[GreedyContractionType] = []
    for dim, dim_keys in dim_to_keys.items():
        dim_keys_list = sorted(dim_keys, key=remaining.__getitem__)
        for i, k1 in enumerate(dim_keys_list[:-1]):
            k2s_guess = dim_keys_list[1 + i :]
            _push_candidate(
                output,
                sizes,
                remaining,
                footprints,
                dim_ref_counts,
                k1,
                k2s_guess,
                queue,
                push_all,
                cost_fn,
            )

    # Greedily contract pairs of tensors.
    while queue:

        con = choose_fn(queue, remaining)
        if con is None:
            continue  # allow choose_fn to flag all candidates obsolete
        cost, k1, k2, k12 = con

        ssa_id1 = remaining.pop(k1)
        ssa_id2 = remaining.pop(k2)
        for dim in k1 - output:
            dim_to_keys[dim].remove(k1)
        for dim in k2 - output:
            dim_to_keys[dim].remove(k2)
        ssa_path.append((ssa_id1, ssa_id2))
        if k12 in remaining:
            # an identical key already exists: fold the two together straight away
            ssa_path.append((remaining[k12], next(ssa_ids)))
        else:
            for dim in k12 - output:
                dim_to_keys[dim].add(k12)
        remaining[k12] = next(ssa_ids)
        # Parenthesised on purpose: ``-`` binds tighter than ``|``, so without
        # the parentheses the output dims of ``k1`` would not be excluded.
        _update_ref_counts(dim_to_keys, dim_ref_counts, (k1 | k2) - output)
        footprints[k12] = compute_size_by_dict(k12, sizes)

        # Find new candidate contractions.
        k1 = k12
        k2s = set(k2 for dim in k1 for k2 in dim_to_keys[dim])
        k2s.discard(k1)
        if k2s:
            _push_candidate(
                output,
                sizes,
                remaining,
                footprints,
                dim_ref_counts,
                k1,
                list(k2s),
                queue,
                push_all,
                cost_fn,
            )

    # Greedily compute pairwise outer products.
    final_queue = [(compute_size_by_dict(key & output, sizes), ssa_id, key) for key, ssa_id in remaining.items()]
    heapq.heapify(final_queue)
    _, ssa_id1, k1 = heapq.heappop(final_queue)
    while final_queue:
        _, ssa_id2, k2 = heapq.heappop(final_queue)
        ssa_path.append((min(ssa_id1, ssa_id2), max(ssa_id1, ssa_id2)))
        k12 = (k1 | k2) & output
        cost = compute_size_by_dict(k12, sizes)
        ssa_id12 = next(ssa_ids)
        _, ssa_id1, k1 = heapq.heappushpop(final_queue, (cost, ssa_id12, k12))

    return ssa_path

732 

733 

def greedy(
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
    memory_limit: Optional[int] = None,
    choose_fn: Any = None,
    cost_fn: str = "memory-removed",
) -> PathType:
    """
    Finds the path by a three stage algorithm:

    1. Eagerly compute Hadamard products.
    2. Greedily compute contractions to maximize `removed_size`
    3. Greedily compute outer products.

    This algorithm scales quadratically with respect to the
    maximum number of elements sharing a common dim.

    When `memory_limit` is not one of the "unlimited" values the work is
    delegated to `branch(..., nbranch=1, cost_fn=cost_fn)` instead, and
    `choose_fn` is not used.

    **Parameters:**

    - **inputs** - *(list)* List of sets that represent the lhs side of the einsum subscript
    - **output** - *(set)* Set that represents the rhs side of the overall einsum subscript
    - **size_dict** - *(dictionary)* Dictionary of index sizes
    - **memory_limit** - *(int)* The maximum number of elements in a temporary array
    - **choose_fn** - *(callable, optional)* A function that chooses which contraction to perform from the queue
    - **cost_fn** - *(callable, optional)* A function that assigns a potential contraction a cost.

    **Returns:**

    - **path** - *(list)* The contraction order (a list of tuples of ints).

    **Examples:**

    ```python
    isets = [set('abd'), set('ac'), set('bdc')]
    oset = set('')
    idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    greedy(isets, oset, idx_sizes)
    #> [(0, 2), (0, 1)]
    ```
    """
    if memory_limit in _UNLIMITED_MEM:
        ssa_path = ssa_greedy_optimize(inputs, output, size_dict, cost_fn=cost_fn, choose_fn=choose_fn)
        return ssa_to_linear(ssa_path)

    return branch(inputs, output, size_dict, memory_limit, nbranch=1, cost_fn=cost_fn)  # type: ignore

780 

781 

def _tree_to_sequence(tree: Tuple[Any, ...]) -> PathType:
    """
    Flattens a contraction tree into the contraction path format that path
    optimizers have to return. A contraction tree is either an int
    (=no contraction) or a tuple containing the terms to be contracted. An
    arbitrary number (>= 1) of terms can be contracted at once. Note that
    contractions are commutative, e.g. (j, k, l) = (k, l, j). Note that in
    general, solutions are not unique.

    **Parameters:**

    - **tree** - *(tuple or int)* Contraction tree

    **Returns:**

    - **path** - *(list[tuple[int]])* Contraction path

    **Examples:**

    ```python
    _tree_to_sequence(((1,2),(0,(4,5,3))))
    #> [(1, 2), (1, 2, 3), (0, 2), (0, 1)]
    ```
    """

    # ((1,2),(0,(4,5,3))) --> [(1, 2), (1, 2, 3), (0, 2), (0, 1)]
    #
    # 0     0         0           (1,2)       --> ((1,2),(0,(3,4,5)))
    # 1     3         (1,2)       --> (0,(3,4,5))
    # 2 --> 4     --> (3,4,5)
    # 3     5
    # 4     (1,2)
    # 5
    #
    # the table above is walked from right to left; steps are collected in
    # that (reverse) order and flipped once at the end

    if type(tree) is int:
        return []

    pending: List[Tuple[Any, ...]] = [tree]  # contractions still to expand (lower part of columns)
    leaves: List[int] = []  # elementary tensors, kept sorted (upper part of columns)
    reversed_steps: List[Tuple[int, ...]] = []

    while pending:
        node = pending.pop()
        step: Tuple[int, ...] = ()

        # plain tensors: their position is their rank among the known leaves
        for leaf in sorted(term for term in node if type(term) is int):
            position = sum(1 for known in leaves if known < leaf)
            step += (position,)
            leaves.insert(position, leaf)

        # sub-trees: they sit after all leaves and all pending contractions
        for subtree in [term for term in node if type(term) is not int]:
            step += (len(leaves) + len(pending),)
            pending.append(subtree)

        reversed_steps.append(step)

    reversed_steps.reverse()
    return reversed_steps

838 

839 

def _find_disconnected_subgraphs(inputs: List[FrozenSet[int]], output: FrozenSet[int]) -> List[FrozenSet[int]]:
    """
    Groups the inputs into disconnected subgraphs. Two inputs are
    connected if they share a summation index (an index that is not in
    ``output``). Note: Disconnected subgraphs can be contracted independently
    before forming outer products.

    **Parameters:**

    - **inputs** - *(list[frozenset])* List of sets that represent the lhs side of the einsum subscript
    - **output** - *(set)* Set that represents the rhs side of the overall einsum subscript

    **Returns:**

    - **subgraphs** - *(list[frozenset[int]])* List containing sets of input positions for each subgraph

    **Examples:**

    ```python
    _find_disconnected_subgraphs([frozenset("ab"), frozenset("c"), frozenset("ad")], frozenset("bd"))
    #> [{0, 2}, {1}]

    _find_disconnected_subgraphs([frozenset("ab"), frozenset("c"), frozenset("ad")], frozenset("abd"))
    #> [{0}, {1}, {2}]
    ```
    """
    unvisited = set(range(len(inputs)))
    summed = frozenset.union(*inputs) - output  # all summation indices

    components = []
    while unvisited:
        component = set()
        stack = [unvisited.pop()]
        while stack:
            current = stack.pop()
            component.add(current)
            # inputs not yet claimed that share a summation index with ``current``
            linking = summed & inputs[current]
            neighbours = {other for other in unvisited if linking & inputs[other]}
            stack.extend(neighbours)
            unvisited -= neighbours
        components.append(frozenset(component))

    return components

884 

885 

def _bitmap_select(s: int, seq: List[FrozenSet[int]]) -> Generator[FrozenSet[int], None, None]:
    """Yield the elements of ``seq`` whose position is set in the bitmap ``s``.

    Bit 0 (the least significant bit) corresponds to ``seq[0]``.

    E.g.:

        >>> list(_bitmap_select(0b11010, ['A', 'B', 'C', 'D', 'E']))
        ['B', 'D', 'E']
    """
    # binary digits without the "0b" prefix, least significant first
    flags = bin(s)[2:][::-1]
    return (item for item, flag in zip(seq, flags) if flag == "1")

895 

896 

def _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2):
    """Work out the effective outer indices of the intermediate tensor that
    corresponds to the subgraph ``s``.

    ``g``, ``all_tensors`` and ``s`` are bitmaps over ``inputs``. An index of
    ``i1_cut_i2_wo_output`` is contracted away (removed from ``i1_union_i2``)
    unless one of the remaining tensors of ``g`` still carries it.
    """
    # bitmap of the tensors of g that are not part of s (=g-s)
    rest = g & (all_tensors ^ s)

    # every index carried by one of those remaining tensors
    rest_indices = frozenset.union(*_bitmap_select(rest, inputs)) if rest else frozenset()

    contracted = i1_cut_i2_wo_output - rest_indices
    return i1_union_i2 - contracted

911 

912 

def _dp_compare_flops(
    cost1: int,
    cost2: int,
    i1_union_i2: Set[int],
    size_dict: List[int],
    cost_cap: int,
    s1: int,
    s2: int,
    xn: Dict[int, Any],
    g: int,
    all_tensors: int,
    inputs: List[FrozenSet[int]],
    i1_cut_i2_wo_output: Set[int],
    memory_limit: Optional[int],
    contract1: Union[int, Tuple[int]],
    contract2: Union[int, Tuple[int]],
) -> None:
    """Decides whether merging the two subgraphs (the bitmaps `s1` and `s2`)
    should be recorded in the dynamic programming table `xn`. Nothing is
    recorded when:

    1. The number of operations to form `s = s1 | s2` including previous
       contractions is above the cost-cap.
    2. A way of making `s` that is at least as cheap is already recorded.
    3. The intermediate tensor corresponding to `s` would break the
       memory limit.

    Otherwise `xn[s]` is set to `(indices, cost, (contract1, contract2))`.
    """

    # TODO: Odd usage with an Iterable[int] to map a dict of type List[int]
    cost = cost1 + cost2 + compute_size_by_dict(i1_union_i2, size_dict)
    if not cost <= cost_cap:
        return

    merged = s1 | s2
    if merged in xn and not cost < xn[merged][1]:
        return

    legs = _dp_calc_legs(g, all_tensors, merged, inputs, i1_cut_i2_wo_output, i1_union_i2)
    mem = compute_size_by_dict(legs, size_dict)
    if memory_limit is None or mem <= memory_limit:
        xn[merged] = (legs, cost, (contract1, contract2))

950 

951 

def _dp_compare_size(
    cost1: int,
    cost2: int,
    i1_union_i2: Set[int],
    size_dict: List[int],
    cost_cap: int,
    s1: int,
    s2: int,
    xn: Dict[int, Any],
    g: int,
    all_tensors: int,
    inputs: List[FrozenSet[int]],
    i1_cut_i2_wo_output: Set[int],
    memory_limit: Optional[int],
    contract1: Union[int, Tuple[int]],
    contract2: Union[int, Tuple[int]],
) -> None:
    """Variant of `_dp_compare_flops` that scores a candidate by the size of
    the largest intermediate tensor instead of the operation count.

    The size of the new intermediate is therefore computed up front, and the
    candidate's cost is the maximum of that size and the two previous costs.
    """
    merged = s1 | s2
    legs = _dp_calc_legs(g, all_tensors, merged, inputs, i1_cut_i2_wo_output, i1_union_i2)
    size = compute_size_by_dict(legs, size_dict)
    largest = max(cost1, cost2, size)

    within_cap = largest <= cost_cap
    if within_cap and (merged not in xn or largest < xn[merged][1]) and (
        memory_limit is None or size <= memory_limit
    ):
        xn[merged] = (legs, largest, (contract1, contract2))

982 

983 

def _dp_compare_write(
    cost1: int,
    cost2: int,
    i1_union_i2: Set[int],
    size_dict: List[int],
    cost_cap: int,
    s1: int,
    s2: int,
    xn: Dict[int, Any],
    g: int,
    all_tensors: int,
    inputs: List[FrozenSet[int]],
    i1_cut_i2_wo_output: Set[int],
    memory_limit: Optional[int],
    contract1: Union[int, Tuple[int]],
    contract2: Union[int, Tuple[int]],
) -> None:
    """Variant of ``_dp_compare_flops`` that scores a candidate by the total
    size of all intermediates created instead of the operation count.

    The size of the new intermediate is computed first and added to the two
    previous costs to give the candidate's cost.
    """
    merged = s1 | s2
    legs = _dp_calc_legs(g, all_tensors, merged, inputs, i1_cut_i2_wo_output, i1_union_i2)
    size = compute_size_by_dict(legs, size_dict)
    written = cost1 + cost2 + size

    if written > cost_cap:
        return
    if merged in xn and written >= xn[merged][1]:
        return
    if memory_limit is not None and size > memory_limit:
        return

    xn[merged] = (legs, written, (contract1, contract2))

1013 

1014 

# default weight applied to the intermediate size when it is combined with
# the flops in ``_dp_compare_combo`` (used for 'combo' and 'limit' when no
# custom factor is given)
DEFAULT_COMBO_FACTOR = 64

1016 

1017 

def _dp_compare_combo(
    cost1: int,
    cost2: int,
    i1_union_i2: Set[int],
    size_dict: List[int],
    cost_cap: int,
    s1: int,
    s2: int,
    xn: Dict[int, Any],
    g: int,
    all_tensors: int,
    inputs: List[FrozenSet[int]],
    i1_cut_i2_wo_output: Set[int],
    memory_limit: Optional[int],
    contract1: Union[int, Tuple[int]],
    contract2: Union[int, Tuple[int]],
    factor: Union[int, float] = DEFAULT_COMBO_FACTOR,
    combine: Callable = sum,
) -> None:
    """Variant of ``_dp_compare_flops`` that scores a candidate by a mix of
    operation count and intermediate size.

    The local score is ``combine((flops, factor * size))``, which is added to
    the two previous costs; ``combine`` receives the pair as a single tuple.
    """
    merged = s1 | s2
    legs = _dp_calc_legs(g, all_tensors, merged, inputs, i1_cut_i2_wo_output, i1_union_i2)
    size = compute_size_by_dict(legs, size_dict)
    flops = compute_size_by_dict(i1_union_i2, size_dict)
    score = cost1 + cost2 + combine((flops, factor * size))

    improves = score <= cost_cap and (merged not in xn or score < xn[merged][1])
    if improves and (memory_limit is None or size <= memory_limit):
        xn[merged] = (legs, score, (contract1, contract2))

1049 

1050 

# parses a ``minimize`` string: group 1 is the strategy name, group 2 the
# (possibly empty) run of digits that may follow it after any number of
# dashes, e.g. 'combo-256' -> ('combo', '256') and 'limit' -> ('limit', '')
minimize_finder = re.compile(r"(flops|size|write|combo|limit)-*(\d*)")

1052 

1053 

1054@functools.lru_cache(128) 

1055def _parse_minimize(minimize: Union[str, Callable]) -> Tuple[Callable, Union[int, float]]: 

1056 """This works out what local scoring function to use for the dp algorithm 

1057 as well as a `naive_scale` to account for the memory_limit checks. 

1058 """ 

1059 if minimize == "flops": 

1060 return _dp_compare_flops, 1 

1061 elif minimize == "size": 

1062 return _dp_compare_size, 1 

1063 elif minimize == "write": 

1064 return _dp_compare_write, 1 

1065 elif callable(minimize): 

1066 # default to naive_scale=inf for this and remaining options 

1067 # as otherwise memory_limit check can cause problems 

1068 return minimize, float("inf") 

1069 

1070 # parse out a customized value for the combination factor 

1071 match = minimize_finder.fullmatch(minimize) 

1072 if match is None: 

1073 raise ValueError(f"Couldn't parse `minimize` value: {minimize}.") 

1074 

1075 minimize, custom_factor = match.groups() 

1076 factor = float(custom_factor) if custom_factor else DEFAULT_COMBO_FACTOR 

1077 if minimize == "combo": 

1078 return functools.partial(_dp_compare_combo, factor=factor, combine=sum), float("inf") 

1079 elif minimize == "limit": 

1080 return functools.partial(_dp_compare_combo, factor=factor, combine=max), float("inf") 

1081 else: 

1082 raise ValueError(f"Couldn't parse `minimize` value: {minimize}.") 

1083 

1084 

def simple_tree_tuple(seq: Sequence[Tuple[int, ...]]) -> Tuple[Any, ...]:
    """Nest the items of `seq` into a left-leaning binary tree of 2-tuples.

    ```python
    simple_tree_tuple([1, 2, 3, 4])
    #> (((1, 2), 3), 4)
    ```

    A single item is returned unchanged.
    """

    def _graft(tree, leaf):
        # hang the next item to the right of everything gathered so far
        return (tree, leaf)

    return functools.reduce(_graft, seq)

1095 

1096 

1097def _dp_parse_out_single_term_ops( 

1098 inputs: List[FrozenSet[int]], all_inds: Tuple[str, ...], ind_counts: CounterType[str] 

1099) -> Tuple[List[FrozenSet[int]], List[Tuple[int]], List[Union[int, Tuple[int]]]]: 

1100 """Take `inputs` and parse for single term index operations, i.e. where 

1101 an index appears on one tensor and nowhere else. 

1102 

1103 If a term is completely reduced to a scalar in this way it can be removed 

1104 to `inputs_done`. If only some indices can be summed then add a 'single 

1105 term contraction' that will perform this summation. 

1106 """ 

1107 i_single = frozenset(i for i, c in enumerate(all_inds) if ind_counts[c] == 1) 

1108 inputs_parsed: List[FrozenSet[int]] = [] 

1109 inputs_done: List[Tuple[int]] = [] 

1110 inputs_contractions: List[Union[int, Tuple[int]]] = [] 

1111 for j, i in enumerate(inputs): 

1112 i_reduced = i - i_single 

1113 if (not i_reduced) and (len(i) > 0): 

1114 # input reduced to scalar already - remove 

1115 inputs_done.append((j,)) 

1116 else: 

1117 # if the input has any index reductions, add single contraction 

1118 inputs_parsed.append(i_reduced) 

1119 inputs_contractions.append((j,) if i_reduced != i else j) 

1120 

1121 return inputs_parsed, inputs_done, inputs_contractions 

1122 

1123 

class DynamicProgramming(PathOptimizer):
    """
    Finds the optimal path of pairwise contractions without intermediate outer
    products based on a dynamic programming approach presented in
    Phys. Rev. E 90, 033315 (2014) (the corresponding preprint is publicly
    available at https://arxiv.org/abs/1304.6112). This method is especially
    well-suited in the area of tensor network states, where it usually
    outperforms all the other optimization strategies.

    This algorithm shows exponential scaling with the number of inputs
    in the worst case scenario (see example below). If the graph to be
    contracted consists of disconnected subgraphs, the algorithm scales
    linearly in the number of disconnected subgraphs and only exponentially
    with the number of inputs per subgraph.

    **Parameters:**

    - **minimize** - *({'flops', 'size', 'write', 'combo', 'limit', callable}, optional)* What to minimize:

        - 'flops' - minimize the number of flops
        - 'size' - minimize the size of the largest intermediate
        - 'write' - minimize the size of all intermediate tensors
        - 'combo' - minimize `flops + alpha * write` summed over intermediates, a default ratio of alpha=64
          is used, or it can be customized with `f'combo-{alpha}'`
        - 'limit' - minimize `max(flops, alpha * write)` summed over intermediates, a default ratio of alpha=64
          is used, or it can be customized with `f'limit-{alpha}'`
        - callable - a custom local cost function

    - **cost_cap** - *({True, False, int}, optional)* How to implement cost-capping:

        - True - iteratively increase the cost-cap
        - False - implement no cost-cap at all
        - int - use explicit cost cap

    - **search_outer** - *(bool, optional)* In rare circumstances the optimal contraction may involve an outer
        product, this option allows searching such contractions but may well
        slow down the path finding considerably on all but very small graphs.
    """

    def __init__(
        self,
        minimize: Union[str, Callable] = "flops",
        cost_cap: Union[bool, int] = True,
        search_outer: bool = False,
    ) -> None:
        # the options are only stored here; ``minimize`` is parsed on each call
        self.minimize = minimize
        self.search_outer = search_outer
        self.cost_cap = cost_cap

    def __call__(
        self,
        inputs_: List[ArrayIndexType],
        output_: ArrayIndexType,
        size_dict_: Dict[str, int],
        memory_limit: Optional[int] = None,
    ) -> PathType:
        """
        **Parameters:**

        - **inputs** - *(list)* List of sets that represent the lhs side of the einsum subscript
        - **output** - *(set)* Set that represents the rhs side of the overall einsum subscript
        - **size_dict** - *(dictionary)* Dictionary of index sizes
        - **memory_limit** - *(int)* The maximum number of elements in a temporary array

        **Returns:**

        - **path** - *(list)* The contraction order (a list of tuples of ints).

        **Raises:**

        - **RuntimeError** - If the cost cap has grown beyond the naive cost
          estimate and still no contraction of a subgraph was accepted.

        **Examples:**

        ```python
        n_in = 3  # exponential scaling
        n_out = 2 # linear scaling
        s = dict()
        i_all = []
        for _ in range(n_out):
            i = [set() for _ in range(n_in)]
            for j in range(n_in):
                for k in range(j+1, n_in):
                    c = oe.get_symbol(len(s))
                    i[j].add(c)
                    i[k].add(c)
                    s[c] = 2
            i_all.extend(i)
        o = DynamicProgramming()
        o(i_all, set(), s)
        #> [(1, 2), (0, 4), (1, 2), (0, 2), (0, 1)]
        ```
        """
        _check_contraction, naive_scale = _parse_minimize(self.minimize)
        # with ``search_outer`` every pair is accepted; otherwise a pair is
        # only accepted if its set of shared non-output indices is non-empty
        _check_outer = (lambda x: True) if self.search_outer else (lambda x: x)

        # counts how often each index occurs over all inputs and the output
        ind_counts = Counter(itertools.chain(*inputs_, output_))
        all_inds = tuple(ind_counts)

        # convert all indices to integers (makes set operations ~10 % faster)
        symbol2int = {c: j for j, c in enumerate(all_inds)}
        inputs = [frozenset(symbol2int[c] for c in i) for i in inputs_]
        output = frozenset(symbol2int[c] for c in output_)
        size_dict_canonical = {symbol2int[c]: v for c, v in size_dict_.items() if c in symbol2int}
        # sizes as a list addressed by the integer id of each index
        size_dict = [size_dict_canonical[j] for j in range(len(size_dict_canonical))]
        # threshold for the cost cap beyond which the search gives up (see
        # the RuntimeError below); uses the input count before any are removed
        naive_cost = naive_scale * len(inputs) * functools.reduce(operator.mul, size_dict, 1)

        inputs, inputs_done, inputs_contractions = _dp_parse_out_single_term_ops(inputs, all_inds, ind_counts)

        if not inputs:
            # nothing left to do after single axis reductions!
            return _tree_to_sequence(simple_tree_tuple(inputs_done))

        # a list of all necessary contraction expressions for each of the
        # disconnected subgraphs and their size
        subgraph_contractions = inputs_done
        subgraph_contractions_size = [1] * len(inputs_done)

        if self.search_outer:
            # optimize everything together if we are considering outer products
            subgraphs = [frozenset(range(len(inputs)))]
        else:
            subgraphs = _find_disconnected_subgraphs(inputs, output)

        # the bitmap set of all tensors is computed as it is needed to
        # compute set differences: s1 - s2 transforms into
        # s1 & (all_tensors ^ s2)
        all_tensors = (1 << len(inputs)) - 1

        for g in subgraphs:

            # dynamic programming approach to compute x[n] for subgraph g;
            # x[n][set of n tensors] = (indices, cost, contraction)
            # the set of n tensors is represented by a bitmap: if bit j is 1,
            # tensor j is in the set, e.g. 0b100101 = {0,2,5}; set unions
            # (intersections) can then be computed by bitwise or (and);
            # x[0] is an unused placeholder so that x[n] can be indexed by n
            x: List[Any] = [None] * 2 + [dict() for j in range(len(g) - 1)]
            x[1] = OrderedDict((1 << j, (inputs[j], 0, inputs_contractions[j])) for j in g)

            # convert set of tensors g to a bitmap set:
            bitmap_g = functools.reduce(lambda x, y: x | y, (1 << j for j in g))

            # try to find contraction with cost <= cost_cap and increase
            # cost_cap successively if no such contraction is found;
            # this is a major performance improvement; start with product of
            # output index dimensions as initial cost_cap
            subgraph_inds = frozenset.union(*_bitmap_select(bitmap_g, inputs))
            if self.cost_cap is True:
                cost_cap = compute_size_by_dict(subgraph_inds & output, size_dict)
            elif self.cost_cap is False:
                cost_cap = float("inf")  # type: ignore
            else:
                cost_cap = self.cost_cap
            # set the factor to increase the cost by each iteration (ensure > 1)
            if len(subgraph_inds) == 0:
                cost_increment = 2
            else:
                cost_increment = max(min(map(size_dict.__getitem__, subgraph_inds)), 2)

            # x[-1] is the level holding all tensors of g; repeat the whole
            # sweep with a larger cost cap until it has received an entry
            while len(x[-1]) == 0:
                for n in range(2, len(x[1]) + 1):
                    xn = x[n]

                    # try to combine solutions from x[m] and x[n-m]
                    for m in range(1, n // 2 + 1):
                        for s1, (i1, cost1, contract1) in x[m].items():
                            for s2, (i2, cost2, contract2) in x[n - m].items():

                                # can only merge if s1 and s2 are disjoint
                                # and avoid e.g. s1={0}, s2={1} and s1={1}, s2={0}
                                if (not s1 & s2) and (m != n - m or s1 < s2):
                                    i1_cut_i2_wo_output = (i1 & i2) - output

                                    # maybe ignore outer products:
                                    if _check_outer(i1_cut_i2_wo_output):

                                        i1_union_i2 = i1 | i2
                                        # records the candidate in xn if it
                                        # passes the scoring function's checks
                                        _check_contraction(
                                            cost1,
                                            cost2,
                                            i1_union_i2,
                                            size_dict,
                                            cost_cap,
                                            s1,
                                            s2,
                                            xn,
                                            bitmap_g,
                                            all_tensors,
                                            inputs,
                                            i1_cut_i2_wo_output,
                                            memory_limit,
                                            contract1,
                                            contract2,
                                        )

                # give up once the cap exceeds the naive cost estimate and
                # still nothing was accepted
                if (cost_cap > naive_cost) and (len(x[-1]) == 0):
                    raise RuntimeError("No contraction found for given `memory_limit`.")

                # increase cost cap for next iteration:
                cost_cap = cost_increment * cost_cap

            # take the first entry stored for the complete subgraph
            i, cost, contraction = list(x[-1].values())[0]
            subgraph_contractions.append(contraction)
            subgraph_contractions_size.append(compute_size_by_dict(i, size_dict))

        # sort the subgraph contractions by the size of the subgraphs in
        # ascending order (will give the cheapest contractions); note that
        # outer products should be performed pairwise (to use BLAS functions)
        subgraph_contractions = [
            subgraph_contractions[j]
            for j in sorted(
                range(len(subgraph_contractions_size)),
                key=subgraph_contractions_size.__getitem__,
            )
        ]

        # build the final contraction tree
        tree = simple_tree_tuple(subgraph_contractions)
        return _tree_to_sequence(tree)

1334 

1335 

def dynamic_programming(
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
    memory_limit: Optional[int] = None,
    **kwargs: Any,
) -> PathType:
    """Function-style entry point for `DynamicProgramming`.

    Any extra keyword arguments are forwarded to the `DynamicProgramming`
    constructor; the freshly built optimizer is then called once with the
    remaining arguments and its path is returned.
    """
    return DynamicProgramming(**kwargs)(inputs, output, size_dict, memory_limit)

1345 

1346 

# path finder used by ``auto`` for a given number of inputs (1 to 14)
_AUTO_CHOICES = {}
for i in range(1, 15):
    if i < 5:
        _AUTO_CHOICES[i] = optimal
    elif i < 7:
        _AUTO_CHOICES[i] = branch_all
    elif i < 9:
        _AUTO_CHOICES[i] = branch_2
    else:
        _AUTO_CHOICES[i] = branch_1

1356 

1357 

def auto(
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
    memory_limit: Optional[int] = None,
) -> PathType:
    """Find a contraction path with a method picked from the number of
    input arguments.

    The method is looked up in ``_AUTO_CHOICES``; input counts without an
    entry there fall back to ``greedy``.
    """
    path_fn = _AUTO_CHOICES.get(len(inputs), greedy)
    return path_fn(inputs, output, size_dict, memory_limit)

1369 

1370 

# path finder used by ``auto_hq`` for a given number of inputs (1 to 16)
_AUTO_HQ_CHOICES = {}
for i in range(1, 17):
    _AUTO_HQ_CHOICES[i] = optimal if i < 6 else dynamic_programming

1376 

1377 

def auto_hq(
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
    memory_limit: Optional[int] = None,
) -> PathType:
    """Find a contraction path with a method picked from the number of
    input arguments, aiming for a more generous amount of search time than
    ``'auto'``.

    The method is looked up in ``_AUTO_HQ_CHOICES``; input counts without an
    entry there fall back to ``random_greedy_128``.
    """
    # imported inside the function rather than at module level
    from .path_random import random_greedy_128

    path_fn = _AUTO_HQ_CHOICES.get(len(inputs), random_greedy_128)
    return path_fn(inputs, output, size_dict, memory_limit)

1392 

1393 

# signature shared by the path finding functions: they are called as
# ``fn(inputs, output, size_dict, memory_limit)`` and return the path
PathSearchFunctionType = Callable[[List[ArrayIndexType], ArrayIndexType, Dict[str, int], Optional[int]], PathType]
# registry of path finding functions keyed by lower-case name; several names
# are aliases of the same function ('eager' and 'opportunistic' for
# ``greedy``, 'dp' for ``dynamic_programming``)
_PATH_OPTIONS: Dict[str, PathSearchFunctionType] = {
    "auto": auto,
    "auto-hq": auto_hq,
    "optimal": optimal,
    "branch-all": branch_all,
    "branch-2": branch_2,
    "branch-1": branch_1,
    "greedy": greedy,
    "eager": greedy,
    "opportunistic": greedy,
    "dp": dynamic_programming,
    "dynamic-programming": dynamic_programming,
}

1408 

1409 

def register_path_fn(name: str, fn: PathSearchFunctionType) -> None:
    """Add path finding function ``fn`` as an option with ``name``.

    Names are case-insensitive: the function is stored under the lower-cased
    ``name``, which is also the key ``get_path_fn`` looks up.

    **Raises:**

    - **KeyError** - If an optimizer is already registered under ``name``
      (compared case-insensitively).
    """
    # compare using the same lower-cased key that is stored, otherwise e.g.
    # 'Greedy' would slip past the check and silently replace 'greedy'
    key = name.lower()
    if key in _PATH_OPTIONS:
        raise KeyError("Path optimizer '{}' already exists.".format(name))

    _PATH_OPTIONS[key] = fn

1416 

1417 

def get_path_fn(path_type: str) -> PathSearchFunctionType:
    """Look up the path finding function registered under ``path_type``.

    The lookup is case-insensitive. A ``KeyError`` listing the valid options
    is raised if no such optimizer is registered.
    """
    key = path_type.lower()
    if key in _PATH_OPTIONS:
        return _PATH_OPTIONS[key]

    known = set(_PATH_OPTIONS.keys())
    raise KeyError("Path optimizer '{}' not found, valid options are {}.".format(key, known))