Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/paths.py: 16%

440 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1""" 

2Contains the path technology behind opt_einsum in addition to several path helpers 

3""" 

4 

5import functools 

6import heapq 

7import itertools 

8import random 

9from collections import Counter, OrderedDict, defaultdict 

10 

11import numpy as np 

12 

13from . import helpers 

14 

# public API of this module
__all__ = [
    "optimal", "BranchBound", "branch", "greedy", "auto", "auto_hq", "get_path_fn", "DynamicProgramming",
    "dynamic_programming"
]

# sentinel values for ``memory_limit`` that all mean "no limit at all"
_UNLIMITED_MEM = {-1, None, float('inf')}

21 

22 

class PathOptimizer(object):
    """Base class for different path optimizers to inherit from.

    Subclassed optimizers should define a call method with signature::

        def __call__(self, inputs, output, size_dict, memory_limit=None):
            \"\"\"
            Parameters
            ----------
            inputs : list[set[str]]
                The indices of each input array.
            outputs : set[str]
                The output indices
            size_dict : dict[str, int]
                The size of each index
            memory_limit : int, optional
                If given, the maximum allowed memory.
            \"\"\"
            # ... compute path here ...
            return path

    where ``path`` is a list of int-tuples specifying a contraction order.
    """

    def _check_args_against_first_call(self, inputs, output, size_dict):
        """Utility that stateful optimizers can use to ensure they are not
        called with different contractions across separate runs.

        Raises
        ------
        ValueError
            If this instance was previously called with a different
            ``(inputs, output, size_dict)`` triple.
        """
        args = (inputs, output, size_dict)
        if not hasattr(self, '_first_call_args'):
            # simply set the attribute as currently there is no global PathOptimizer init
            self._first_call_args = args
        elif args != self._first_call_args:
            raise ValueError("The arguments specifiying the contraction that this path optimizer "
                             "instance was called with have changed - try creating a new instance.")

    def __call__(self, inputs, output, size_dict, memory_limit=None):
        # subclasses must implement the actual path-finding logic
        raise NotImplementedError

61 

62 

def ssa_to_linear(ssa_path):
    """
    Convert a path with static single assignment ids to a path with recycled
    linear ids. For example::

        >>> ssa_to_linear([(0, 3), (2, 4), (1, 5)])
        [(0, 3), (1, 2), (0, 1)]
    """
    # current linear position of every ssa id
    num_ids = 1 + max(map(max, ssa_path))
    positions = np.arange(num_ids, dtype=np.int32)

    linear_path = []
    for step in ssa_path:
        linear_path.append(tuple(int(positions[ssa_id]) for ssa_id in step))
        # removing a tensor shifts every later id down one slot
        for ssa_id in step:
            positions[ssa_id:] -= 1

    return linear_path

78 

79 

def linear_to_ssa(path):
    """
    Convert a path with recycled linear ids to a path with static single
    assignment ids. For example::

        >>> linear_to_ssa([(0, 3), (1, 2), (0, 1)])
        [(0, 3), (2, 4), (1, 5)]
    """
    # each contraction removes len(ids) tensors and adds one back
    num_inputs = sum(map(len, path)) - len(path) + 1

    # pos_to_ssa[linear position] -> current ssa id
    pos_to_ssa = list(range(num_inputs))
    fresh_ids = itertools.count(num_inputs)

    ssa_path = []
    for ids in path:
        ssa_path.append(tuple(pos_to_ssa[pos] for pos in ids))
        # delete from the back so earlier positions stay valid
        for pos in sorted(ids, reverse=True):
            del pos_to_ssa[pos]
        pos_to_ssa.append(next(fresh_ids))

    return ssa_path

98 

99 

def calc_k12_flops(inputs, output, remaining, i, j, size_dict):
    """
    Calculate the resulting indices and flops for a potential pairwise
    contraction - used in the recursive (optimal/branch) algorithms.

    Parameters
    ----------
    inputs : tuple[frozenset[str]]
        The indices of each tensor in this contraction, note this includes
        tensors unavailable to contract as static single assignment is used ->
        contracted tensors are not removed from the list.
    output : frozenset[str]
        The set of output indices for the whole contraction.
    remaining : frozenset[int]
        The set of indices (corresponding to ``inputs``) of tensors still
        available to contract.
    i : int
        Index of potential tensor to contract.
    j : int
        Index of potential tensor to contract.
    size_dict : dict[str, int]
        Size mapping of all the indices.

    Returns
    -------
    k12 : frozenset
        The resulting indices of the potential tensor.
    cost : int
        Estimated flop count of operation.
    """
    k1, k2 = inputs[i], inputs[j]
    either = k1 | k2
    shared = k1 & k2
    # indices that must survive: the final output plus anything still needed
    # by tensors not involved in this contraction
    keep = frozenset.union(output, *map(inputs.__getitem__, remaining - {i, j}))

    k12 = either & keep
    # shared indices appearing nowhere else are summed over in this step
    cost = helpers.flop_count(either, shared - keep, 2, size_dict)

    return k12, cost

139 

140 

def _compute_oversize_flops(inputs, remaining, output, size_dict):
    """
    Compute the flop count for a contraction of all remaining arguments. This
    is used when a memory limit means that no pairwise contractions can be made.
    """
    involved = frozenset.union(*(inputs[r] for r in remaining))
    summed = involved - output
    return helpers.flop_count(involved, summed, len(remaining), size_dict)

150 

151 

def optimal(inputs, output, size_dict, memory_limit=None):
    """
    Computes all possible pair contractions in a depth-first recursive manner,
    sieving results based on ``memory_limit`` and the best path found so far.
    Returns the lowest cost path. This algorithm scales factorially with
    respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    inputs : list
        List of sets that represent the lhs side of the einsum subscript.
    output : set
        Set that represents the rhs side of the overall einsum subscript.
    size_dict : dictionary
        Dictionary of index sizes.
    memory_limit : int
        The maximum number of elements in a temporary array.

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('')
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> optimal(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    # frozensets are hashable so they can key the caches below
    inputs = tuple(map(frozenset, inputs))
    output = frozenset(output)

    # fallback is a single all-terms einsum over every input
    best = {'flops': float('inf'), 'ssa_path': (tuple(range(len(inputs))), )}
    size_cache = {}
    result_cache = {}

    def _optimal_iterate(path, remaining, inputs, flops):

        # reached end of path (only ever get here if flops is best found so far)
        if len(remaining) == 1:
            best['flops'] = flops
            best['ssa_path'] = path
            return

        # check all possible remaining paths
        for i, j in itertools.combinations(remaining, 2):
            if i > j:
                i, j = j, i
            key = (inputs[i], inputs[j])
            try:
                k12, flops12 = result_cache[key]
            except KeyError:
                k12, flops12 = result_cache[key] = calc_k12_flops(inputs, output, remaining, i, j, size_dict)

            # sieve based on current best flops
            new_flops = flops + flops12
            if new_flops >= best['flops']:
                continue

            # sieve based on memory limit
            if memory_limit not in _UNLIMITED_MEM:
                try:
                    size12 = size_cache[k12]
                except KeyError:
                    size12 = size_cache[k12] = helpers.compute_size_by_dict(k12, size_dict)

                # possibly terminate this path with an all-terms einsum
                if size12 > memory_limit:
                    new_flops = flops + _compute_oversize_flops(inputs, remaining, output, size_dict)
                    if new_flops < best['flops']:
                        best['flops'] = new_flops
                        best['ssa_path'] = path + (tuple(remaining), )
                    continue

            # add contraction and recurse into all remaining
            _optimal_iterate(path=path + ((i, j), ),
                             inputs=inputs + (k12, ),
                             remaining=remaining - {i, j} | {len(inputs)},
                             flops=new_flops)

    _optimal_iterate(path=(), inputs=inputs, remaining=set(range(len(inputs))), flops=0)

    return ssa_to_linear(best['ssa_path'])

237 

238 

239# functions for comparing which of two paths is 'better' 

240 

241 

def better_flops_first(flops, size, better_flops_ref=None, better_size_ref=None, *, best_flops=None, best_size=None):
    """Return True if ``(flops, size)`` beats the current best, flops primary."""
    # keep positional compatibility with the original signature
    if best_flops is None:
        best_flops = better_flops_ref
    if best_size is None:
        best_size = better_size_ref
    if flops != best_flops:
        return flops < best_flops
    return size < best_size

244 

245 

def better_size_first(flops, size, best_flops, best_size):
    """Return True if ``(size, flops)`` beats the current best, size primary."""
    if size != best_size:
        return size < best_size
    return flops < best_flops

248 

249 

# dispatch table mapping the ``minimize`` option to its comparison function
_BETTER_FNS = {
    'flops': better_flops_first,
    'size': better_size_first,
}

254 

255 

def get_better_fn(key):
    """Look up the comparison function for ``key`` ('flops' or 'size')."""
    fn = _BETTER_FNS[key]
    return fn

258 

259 

260# functions for assigning a heuristic 'cost' to a potential contraction 

261 

262 

def cost_memory_removed(size12, size1, size2, k12, k1, k2):
    """The default heuristic cost, corresponding to the total reduction in
    memory of performing a contraction.
    """
    removed = size1 + size2
    return size12 - removed

268 

269 

def cost_memory_removed_jitter(size12, size1, size2, k12, k1, k2):
    """Like memory-removed, but with a slight amount of noise that breaks ties
    and thus jumbles the contractions a bit.
    """
    noise = random.gauss(1.0, 0.01)
    base = size12 - size1 - size2
    return noise * base

275 

276 

# dispatch table mapping the ``cost_fn`` option name to its heuristic function
_COST_FNS = {
    'memory-removed': cost_memory_removed,
    'memory-removed-jitter': cost_memory_removed_jitter,
}

281 

282 

class BranchBound(PathOptimizer):
    """
    Explores possible pair contractions in a depth-first recursive manner like
    the ``optimal`` approach, but with extra heuristic early pruning of branches
    as well sieving by ``memory_limit`` and the best path found so far. Returns
    the lowest cost path. This algorithm still scales factorially with respect
    to the elements in the list ``input_sets`` if ``nbranch`` is not set, but it
    scales exponentially like ``nbranch**len(input_sets)`` otherwise.

    Parameters
    ----------
    nbranch : None or int, optional
        How many branches to explore at each contraction step. If None, explore
        all possible branches. If an integer, branch into this many paths at
        each step. Defaults to None.
    cutoff_flops_factor : float, optional
        If at any point, a path is doing this much worse than the best path
        found so far was, terminate it. The larger this is made, the more paths
        will be fully explored and the slower the algorithm. Defaults to 4.
    minimize : {'flops', 'size'}, optional
        Whether to optimize the path with regard primarily to the total
        estimated flop-count, or the size of the largest intermediate. The
        option not chosen will still be used as a secondary criterion.
    cost_fn : callable, optional
        A function that returns a heuristic 'cost' of a potential contraction
        with which to sort candidates. Should have signature
        ``cost_fn(size12, size1, size2, k12, k1, k2)``.
    """
    def __init__(self, nbranch=None, cutoff_flops_factor=4, minimize='flops', cost_fn='memory-removed'):
        self.nbranch = nbranch
        self.cutoff_flops_factor = cutoff_flops_factor
        self.minimize = minimize
        # accept either a key into _COST_FNS or a custom callable
        self.cost_fn = _COST_FNS.get(cost_fn, cost_fn)

        self.better = get_better_fn(minimize)
        # best path found so far - persists across calls (stateful optimizer)
        self.best = {'flops': float('inf'), 'size': float('inf')}
        # best flops seen at each search depth, keyed by number of inputs
        self.best_progress = defaultdict(lambda: float('inf'))

    @property
    def path(self):
        # the best path found so far, converted from SSA to linear ids
        return ssa_to_linear(self.best['ssa_path'])

    def __call__(self, inputs, output, size_dict, memory_limit=None):
        """

        Parameters
        ----------
        input_sets : list
            List of sets that represent the lhs side of the einsum subscript
        output_set : set
            Set that represents the rhs side of the overall einsum subscript
        idx_dict : dictionary
            Dictionary of index sizes
        memory_limit : int
            The maximum number of elements in a temporary array

        Returns
        -------
        path : list
            The contraction order within the memory limit constraint.

        Examples
        --------
        >>> isets = [set('abd'), set('ac'), set('bdc')]
        >>> oset = set('')
        >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
        >>> optimal(isets, oset, idx_sizes, 5000)
        [(0, 2), (0, 1)]
        """
        # guard against reuse on a different contraction (stateful instance)
        self._check_args_against_first_call(inputs, output, size_dict)

        inputs = tuple(map(frozenset, inputs))
        output = frozenset(output)

        size_cache = {k: helpers.compute_size_by_dict(k, size_dict) for k in inputs}
        result_cache = {}

        def _branch_iterate(path, inputs, remaining, flops, size):

            # reached end of path (only ever get here if flops is best found so far)
            if len(remaining) == 1:
                self.best['size'] = size
                self.best['flops'] = flops
                self.best['ssa_path'] = path
                return

            def _assess_candidate(k1, k2, i, j):
                # find resulting indices and flops
                try:
                    k12, flops12 = result_cache[k1, k2]
                except KeyError:
                    k12, flops12 = result_cache[k1, k2] = calc_k12_flops(inputs, output, remaining, i, j, size_dict)

                try:
                    size12 = size_cache[k12]
                except KeyError:
                    size12 = size_cache[k12] = helpers.compute_size_by_dict(k12, size_dict)

                new_flops = flops + flops12
                new_size = max(size, size12)

                # sieve based on current best i.e. check flops and size still better
                if not self.better(new_flops, new_size, self.best['flops'], self.best['size']):
                    return None

                # compare to how the best method was doing as this point
                if new_flops < self.best_progress[len(inputs)]:
                    self.best_progress[len(inputs)] = new_flops
                # sieve based on current progress relative to best
                elif new_flops > self.cutoff_flops_factor * self.best_progress[len(inputs)]:
                    return None

                # sieve based on memory limit
                if (memory_limit not in _UNLIMITED_MEM) and (size12 > memory_limit):
                    # terminate path here, but check all-terms contract first
                    new_flops = flops + _compute_oversize_flops(inputs, remaining, output, size_dict)
                    if new_flops < self.best['flops']:
                        self.best['flops'] = new_flops
                        self.best['ssa_path'] = path + (tuple(remaining), )
                    return None

                # set cost heuristic in order to locally sort possible contractions
                size1, size2 = size_cache[inputs[i]], size_cache[inputs[j]]
                cost = self.cost_fn(size12, size1, size2, k12, k1, k2)

                return cost, flops12, new_flops, new_size, (i, j), k12

            # check all possible remaining paths
            candidates = []
            for i, j in itertools.combinations(remaining, 2):
                if i > j:
                    i, j = j, i
                k1, k2 = inputs[i], inputs[j]

                # initially ignore outer products
                if k1.isdisjoint(k2):
                    continue

                candidate = _assess_candidate(k1, k2, i, j)
                if candidate:
                    heapq.heappush(candidates, candidate)

            # assess outer products if nothing left
            if not candidates:
                for i, j in itertools.combinations(remaining, 2):
                    if i > j:
                        i, j = j, i
                    k1, k2 = inputs[i], inputs[j]
                    candidate = _assess_candidate(k1, k2, i, j)
                    if candidate:
                        heapq.heappush(candidates, candidate)

            # recurse into all or some of the best candidate contractions
            bi = 0
            while (self.nbranch is None or bi < self.nbranch) and candidates:
                _, _, new_flops, new_size, (i, j), k12 = heapq.heappop(candidates)
                _branch_iterate(path=path + ((i, j), ),
                                inputs=inputs + (k12, ),
                                remaining=(remaining - {i, j}) | {len(inputs)},
                                flops=new_flops,
                                size=new_size)
                bi += 1

        _branch_iterate(path=(), inputs=inputs, remaining=set(range(len(inputs))), flops=0, size=0)

        return self.path

449 

450 

def branch(inputs, output, size_dict, memory_limit=None, **optimizer_kwargs):
    """Find a contraction path with a freshly constructed ``BranchBound``
    optimizer, forwarding any extra keyword arguments to it.
    """
    return BranchBound(**optimizer_kwargs)(inputs, output, size_dict, memory_limit)

454 

455 

# pre-configured branch-and-bound variants: explore every branch, the best
# two, or only the single best candidate at each contraction step
branch_all = functools.partial(branch, nbranch=None)
branch_2 = functools.partial(branch, nbranch=2)
branch_1 = functools.partial(branch, nbranch=1)

459 

460 

def _get_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2, cost_fn):
    """Score the potential pairwise contraction of ``k1`` and ``k2``,
    returning ``(cost, k1, k2, k12)`` with ``k1``/``k2`` ordered by ssa id.
    """
    either = k1 | k2
    shared = k1 & k2
    exclusive = either - shared
    # an index survives if it is in the output, or still referenced elsewhere
    k12 = (either & output) | (shared & dim_ref_counts[3]) | (exclusive & dim_ref_counts[2])
    size12 = helpers.compute_size_by_dict(k12, sizes)
    cost = cost_fn(size12, footprints[k1], footprints[k2], k12, k1, k2)
    id1, id2 = remaining[k1], remaining[k2]
    if id1 > id2:
        k1, id1, k2, id2 = k2, id2, k1, id1
    # break ties on the ids to ensure determinism
    return (cost, id2, id1), k1, k2, k12

473 

474 

def _push_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2s, queue, push_all, cost_fn):
    """Score contracting ``k1`` against each key in ``k2s`` and push either
    every candidate or just the cheapest one onto the heap ``queue``.
    """
    scored = [
        _get_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2, cost_fn)
        for k2 in k2s
    ]
    if push_all:
        # want to do this if we e.g. are using a custom 'choose_fn'
        for candidate in scored:
            heapq.heappush(queue, candidate)
    else:
        heapq.heappush(queue, min(scored))

483 

484 

def _update_ref_counts(dim_to_keys, dim_ref_counts, dims):
    """Refresh which of ``dims`` currently have >= 2 and >= 3 tensors
    referencing them, mutating the sets in ``dim_ref_counts`` in place.
    """
    at_least_two = dim_ref_counts[2]
    at_least_three = dim_ref_counts[3]
    for dim in dims:
        n = len(dim_to_keys[dim])
        if n >= 2:
            at_least_two.add(dim)
        else:
            at_least_two.discard(dim)
        if n >= 3:
            at_least_three.add(dim)
        else:
            at_least_three.discard(dim)

497 

498 

def _simple_chooser(queue, remaining):
    """Default contraction chooser that simply takes the minimum cost option.
    """
    candidate = heapq.heappop(queue)
    cost, k1, k2, k12 = candidate
    stale = (k1 not in remaining) or (k2 not in remaining)
    if stale:
        # candidate is obsolete
        return None
    return cost, k1, k2, k12

506 

507 

def ssa_greedy_optimize(inputs, output, sizes, choose_fn=None, cost_fn='memory-removed'):
    """
    This is the core function for :func:`greedy` but produces a path with
    static single assignment ids rather than recycled linear ids.
    SSA ids are cheaper to work with and easier to reason about.
    """
    if len(inputs) == 1:
        # Perform a single contraction to match output shape.
        return [(0, )]

    # set the function that assigns a heuristic cost to a possible contraction
    cost_fn = _COST_FNS.get(cost_fn, cost_fn)

    # set the function that chooses which contraction to take
    if choose_fn is None:
        choose_fn = _simple_chooser
        push_all = False
    else:
        # assume chooser wants access to all possible contractions
        push_all = True

    # A dim that is common to all tensors might as well be an output dim, since it
    # cannot be contracted until the final step. This avoids an expensive all-pairs
    # comparison to search for possible contractions at each step, leading to speedup
    # in many practical problems where all tensors share a common batch dimension.
    inputs = list(map(frozenset, inputs))
    output = frozenset(output) | frozenset.intersection(*inputs)

    # Deduplicate shapes by eagerly computing Hadamard products.
    remaining = {}  # key -> ssa_id
    ssa_ids = itertools.count(len(inputs))
    ssa_path = []
    for ssa_id, key in enumerate(inputs):
        if key in remaining:
            ssa_path.append((remaining[key], ssa_id))
            remaining[key] = next(ssa_ids)
        else:
            remaining[key] = ssa_id

    # Keep track of possible contraction dims.
    dim_to_keys = defaultdict(set)
    for key in remaining:
        for dim in key - output:
            dim_to_keys[dim].add(key)

    # Keep track of the number of tensors using each dim; when the dim is no longer
    # used it can be contracted. Since we specialize to binary ops, we only care about
    # ref counts of >=2 or >=3.
    dim_ref_counts = {
        count: set(dim for dim, keys in dim_to_keys.items() if len(keys) >= count) - output
        for count in [2, 3]
    }

    # Compute separable part of the objective function for contractions.
    footprints = {key: helpers.compute_size_by_dict(key, sizes) for key in remaining}

    # Find initial candidate contractions.
    queue = []
    for dim, keys in dim_to_keys.items():
        keys = sorted(keys, key=remaining.__getitem__)
        for i, k1 in enumerate(keys[:-1]):
            k2s = keys[1 + i:]
            _push_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2s, queue, push_all, cost_fn)

    # Greedily contract pairs of tensors.
    while queue:

        con = choose_fn(queue, remaining)
        if con is None:
            continue  # allow choose_fn to flag all candidates obsolete
        cost, k1, k2, k12 = con

        ssa_id1 = remaining.pop(k1)
        ssa_id2 = remaining.pop(k2)
        for dim in k1 - output:
            dim_to_keys[dim].remove(k1)
        for dim in k2 - output:
            dim_to_keys[dim].remove(k2)
        ssa_path.append((ssa_id1, ssa_id2))
        if k12 in remaining:
            # result duplicates an existing shape -> eager Hadamard product
            ssa_path.append((remaining[k12], next(ssa_ids)))
        else:
            for dim in k12 - output:
                dim_to_keys[dim].add(k12)
        remaining[k12] = next(ssa_ids)
        # NOTE(review): '-' binds tighter than '|', so this passes
        # k1 | (k2 - output) rather than (k1 | k2) - output; output dims of k1
        # have no dim_to_keys entries so are discarded downstream - confirm
        # the intended expression.
        _update_ref_counts(dim_to_keys, dim_ref_counts, k1 | k2 - output)
        footprints[k12] = helpers.compute_size_by_dict(k12, sizes)

        # Find new candidate contractions.
        k1 = k12
        k2s = set(k2 for dim in k1 for k2 in dim_to_keys[dim])
        k2s.discard(k1)
        if k2s:
            _push_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2s, queue, push_all, cost_fn)

    # Greedily compute pairwise outer products.
    queue = [(helpers.compute_size_by_dict(key & output, sizes), ssa_id, key) for key, ssa_id in remaining.items()]
    heapq.heapify(queue)
    _, ssa_id1, k1 = heapq.heappop(queue)
    while queue:
        _, ssa_id2, k2 = heapq.heappop(queue)
        ssa_path.append((min(ssa_id1, ssa_id2), max(ssa_id1, ssa_id2)))
        k12 = (k1 | k2) & output
        cost = helpers.compute_size_by_dict(k12, sizes)
        ssa_id12 = next(ssa_ids)
        _, ssa_id1, k1 = heapq.heappushpop(queue, (cost, ssa_id12, k12))

    return ssa_path

616 

617 

def greedy(inputs, output, size_dict, memory_limit=None, choose_fn=None, cost_fn='memory-removed'):
    """
    Finds the path by a three stage algorithm:

    1. Eagerly compute Hadamard products.
    2. Greedily compute contractions to maximize ``removed_size``
    3. Greedily compute outer products.

    This algorithm scales quadratically with respect to the
    maximum number of elements sharing a common dim.

    Parameters
    ----------
    inputs : list
        List of sets that represent the lhs side of the einsum subscript
    output : set
        Set that represents the rhs side of the overall einsum subscript
    size_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array
    choose_fn : callable, optional
        A function that chooses which contraction to perform from the queue
    cost_fn : callable, optional
        A function that assigns a potential contraction a cost.

    Returns
    -------
    path : list
        The contraction order (a list of tuples of ints).

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('')
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> greedy(isets, oset, idx_sizes)
    [(0, 2), (0, 1)]
    """
    # the greedy algorithm cannot respect a memory limit - fall back to
    # single-branch branch & bound, which can
    if memory_limit not in _UNLIMITED_MEM:
        return branch(inputs, output, size_dict, memory_limit, nbranch=1, cost_fn=cost_fn)

    ssa_path = ssa_greedy_optimize(inputs, output, size_dict, cost_fn=cost_fn, choose_fn=choose_fn)
    return ssa_to_linear(ssa_path)

662 

663 

def _tree_to_sequence(c):
    """
    Converts a contraction tree to a contraction path as it has to be
    returned by path optimizers. A contraction tree can either be an int
    (=no contraction) or a tuple containing the terms to be contracted. An
    arbitrary number (>= 1) of terms can be contracted at once. Note that
    contractions are commutative, e.g. (j, k, l) = (k, l, j). Note that in
    general, solutions are not unique.

    Parameters
    ----------
    c : tuple or int
        Contraction tree

    Returns
    -------
    path : list[set[int]]
        Contraction path

    Examples
    --------
    >>> _tree_to_sequence(((1,2),(0,(4,5,3))))
    [(1, 2), (1, 2, 3), (0, 2), (0, 1)]
    """
    # a bare int means there is nothing to contract at all
    if type(c) == int:
        return []

    pending = [c]   # contractions still to be expanded (processed back to front)
    leaves = []     # elementary tensor ids, kept in their current linear order
    sequence = []   # resulting contraction sequence, built from the back

    while pending:
        node = pending.pop()
        step = ()

        # elementary tensors first, in ascending id order
        for leaf in sorted(i for i in node if type(i) == int):
            pos = sum(1 for q in leaves if q < leaf)
            step += (pos, )
            leaves.insert(pos, leaf)

        # sub-contractions are referenced past the elementary tensors
        for sub in node:
            if type(sub) != int:
                step += (len(leaves) + len(pending), )
                pending.append(sub)

        sequence.insert(0, step)

    return sequence

720 

721 

def _find_disconnected_subgraphs(inputs, output):
    """
    Finds disconnected subgraphs in the given list of inputs. Inputs are
    connected if they share summation indices. Note: Disconnected subgraphs
    can be contracted independently before forming outer products.

    Parameters
    ----------
    inputs : list[set]
        List of sets that represent the lhs side of the einsum subscript
    output : set
        Set that represents the rhs side of the overall einsum subscript

    Returns
    -------
    subgraphs : list[set[int]]
        List containing sets of indices for each subgraph

    Examples
    --------
    >>> _find_disconnected_subgraphs([set("ab"), set("c"), set("ad")], set("bd"))
    [{0, 2}, {1}]

    >>> _find_disconnected_subgraphs([set("ab"), set("c"), set("ad")], set("abd"))
    [{0}, {1}, {2}]
    """
    # only summation indices connect tensors
    sum_inds = set.union(*inputs) - output

    unvisited = set(range(len(inputs)))
    subgraphs = []

    # flood-fill from an arbitrary unvisited tensor until all are grouped
    while unvisited:
        group = set()
        frontier = [unvisited.pop()]
        while frontier:
            node = frontier.pop()
            group.add(node)
            connecting = sum_inds & inputs[node]
            neighbours = {k for k in unvisited if len(connecting & inputs[k]) > 0}
            frontier.extend(neighbours)
            unvisited.difference_update(neighbours)

        subgraphs.append(group)

    return subgraphs

768 

769 

def _bitmap_select(s, seq):
    """Select elements of ``seq`` which are marked by the bitmap set ``s``.

    E.g.:

    >>> list(_bitmap_select(0b11010, ['A', 'B', 'C', 'D', 'E']))
    ['B', 'D', 'E']
    """
    # reversing the binary string gives the bits in little-endian order,
    # lining bit i up with seq[i]
    bits = bin(s)[:1:-1]
    return (item for item, flag in zip(seq, bits) if flag == '1')

779 

780 

def _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2):
    """Calculates the effective outer indices of the intermediate tensor
    corresponding to the subgraph ``s``.
    """
    # bitmap of the tensors in g but not in s (i.e. g - s)
    rest = g & (all_tensors ^ s)

    # union of indices appearing on those remaining tensors
    if rest:
        i_rest = set.union(*_bitmap_select(rest, inputs))
    else:
        i_rest = set()

    # an index can only be contracted away if no remaining tensor needs it
    contracted = i1_cut_i2_wo_output - i_rest
    return i1_union_i2 - contracted

795 

796 

def _dp_compare_flops(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2, xn, g, all_tensors, inputs,
                      i1_cut_i2_wo_output, memory_limit, cntrct1, cntrct2):
    """Performs the inner comparison of whether the two subgraphs (the bitmaps
    ``s1`` and ``s2``) should be merged and added to the dynamic programming
    search. Will skip for a number of reasons:

    1. If the number of operations to form ``s = s1 | s2`` including previous
       contractions is above the cost-cap.
    2. If we've already found a better way of making ``s``.
    3. If the intermediate tensor corresponding to ``s`` is going to break the
       memory limit.
    """
    # total cost: building s1, building s2, plus this pairwise contraction
    cost = cost1 + cost2 + helpers.compute_size_by_dict(i1_union_i2, size_dict)
    if cost <= cost_cap:
        s = s1 | s2
        # only keep the cheapest known way of forming subgraph s
        if s not in xn or cost < xn[s][1]:
            i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
            mem = helpers.compute_size_by_dict(i, size_dict)
            if memory_limit is None or mem <= memory_limit:
                xn[s] = (i, cost, (cntrct1, cntrct2))

817 

818 

def _dp_compare_size(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2, xn, g, all_tensors, inputs,
                     i1_cut_i2_wo_output, memory_limit, cntrct1, cntrct2):
    """Like ``_dp_compare_flops`` but sieves the potential contraction based
    on the size of the intermediate tensor created, rather than the number of
    operations, and so calculates that first.
    """
    s = s1 | s2
    i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
    mem = helpers.compute_size_by_dict(i, size_dict)
    # 'cost' here is the largest intermediate seen anywhere in building s
    cost = max(cost1, cost2, mem)
    if cost <= cost_cap:
        # only keep the cheapest known way of forming subgraph s
        if s not in xn or cost < xn[s][1]:
            if memory_limit is None or mem <= memory_limit:
                xn[s] = (i, cost, (cntrct1, cntrct2))

833 

834 

def simple_tree_tuple(seq):
    """Make a simple left to right binary tree out of iterable ``seq``.

    >>> simple_tree_tuple([1, 2, 3, 4])
    (((1, 2), 3), 4)

    """
    # fold pairwise from the left: ((a, b), c), ...
    # (fixed docstring example: it previously referenced a nonexistent
    # ``tuple_nest`` function)
    return functools.reduce(lambda x, y: (x, y), seq)

843 

844 

def _dp_parse_out_single_term_ops(inputs, all_inds, ind_counts):
    """Take ``inputs`` and parse for single term index operations, i.e. where
    an index appears on one tensor and nowhere else.

    If a term is completely reduced to a scalar in this way it can be removed
    to ``inputs_done``. If only some indices can be summed then add a 'single
    term contraction' that will perform this summation.
    """
    # positions of indices that appear exactly once across all inputs
    lone_inds = {i for i, c in enumerate(all_inds) if ind_counts[c] == 1}

    inputs_parsed = []
    inputs_done = []
    inputs_contractions = []

    for j, term in enumerate(inputs):
        reduced = term - lone_inds
        if reduced:
            inputs_parsed.append(reduced)
            # mark a single-term contraction only if something was summed away
            if reduced != term:
                inputs_contractions.append((j, ))
            else:
                inputs_contractions.append(j)
        else:
            # term sums to a scalar immediately - set it aside
            inputs_done.append((j, ))

    return inputs_parsed, inputs_done, inputs_contractions

866 

867 

class DynamicProgramming(PathOptimizer):
    """
    Finds the optimal path of pairwise contractions without intermediate outer
    products based a dynamic programming approach presented in
    Phys. Rev. E 90, 033315 (2014) (the corresponding preprint is publically
    available at https://arxiv.org/abs/1304.6112). This method is especially
    well-suited in the area of tensor network states, where it usually
    outperforms all the other optimization strategies.

    This algorithm shows exponential scaling with the number of inputs
    in the worst case scenario (see example below). If the graph to be
    contracted consists of disconnected subgraphs, the algorithm scales
    linearly in the number of disconnected subgraphs and only exponentially
    with the number of inputs per subgraph.

    Parameters
    ----------
    minimize : {'flops', 'size'}, optional
        Whether to find the contraction that minimizes the number of
        operations or the size of the largest intermediate tensor.
    cost_cap : {True, False, int}, optional
        How to implement cost-capping:

        * True - iteratively increase the cost-cap
        * False - implement no cost-cap at all
        * int - use explicit cost cap

    search_outer : bool, optional
        In rare circumstances the optimal contraction may involve an outer
        product, this option allows searching such contractions but may well
        slow down the path finding considerably on all but very small graphs.
    """
    def __init__(self, minimize='flops', cost_cap=True, search_outer=False):

        # set whether inner function minimizes against flops or size
        self.minimize = minimize
        self._check_contraction = {
            'flops': _dp_compare_flops,
            'size': _dp_compare_size,
        }[self.minimize]

        # set whether inner function considers outer products
        self.search_outer = search_outer
        self._check_outer = {
            # False: truthiness of the shared-index set skips pure outer products
            False: lambda x: x,
            # True: every disjoint pair is considered a candidate contraction
            True: lambda x: True,
        }[self.search_outer]

        self.cost_cap = cost_cap

    def __call__(self, inputs, output, size_dict, memory_limit=None):
        """
        Parameters
        ----------
        inputs : list
            List of sets that represent the lhs side of the einsum subscript
        output : set
            Set that represents the rhs side of the overall einsum subscript
        size_dict : dictionary
            Dictionary of index sizes
        memory_limit : int
            The maximum number of elements in a temporary array

        Returns
        -------
        path : list
            The contraction order (a list of tuples of ints).

        Examples
        --------
        >>> n_in = 3  # exponential scaling
        >>> n_out = 2  # linear scaling
        >>> s = dict()
        >>> i_all = []
        >>> for _ in range(n_out):
        >>>     i = [set() for _ in range(n_in)]
        >>>     for j in range(n_in):
        >>>         for k in range(j+1, n_in):
        >>>             c = oe.get_symbol(len(s))
        >>>             i[j].add(c)
        >>>             i[k].add(c)
        >>>             s[c] = 2
        >>>     i_all.extend(i)
        >>> o = DynamicProgramming()
        >>> o(i_all, set(), s)
        [(1, 2), (0, 4), (1, 2), (0, 2), (0, 1)]
        """
        ind_counts = Counter(itertools.chain(*inputs, output))
        all_inds = tuple(ind_counts)

        # convert all indices to integers (makes set operations ~10 % faster)
        symbol2int = {c: j for j, c in enumerate(all_inds)}
        inputs = [set(symbol2int[c] for c in i) for i in inputs]
        output = set(symbol2int[c] for c in output)
        size_dict = {symbol2int[c]: v for c, v in size_dict.items() if c in symbol2int}
        # size_dict becomes a list indexed by the integer index labels
        size_dict = [size_dict[j] for j in range(len(size_dict))]

        inputs, inputs_done, inputs_contractions = _dp_parse_out_single_term_ops(inputs, all_inds, ind_counts)

        if not inputs:
            # nothing left to do after single axis reductions!
            return _tree_to_sequence(simple_tree_tuple(inputs_done))

        # a list of all neccessary contraction expressions for each of the
        # disconnected subgraphs and their size
        subgraph_contractions = inputs_done
        subgraph_contractions_size = [1] * len(inputs_done)

        if self.search_outer:
            # optimize everything together if we are considering outer products
            subgraphs = [set(range(len(inputs)))]
        else:
            subgraphs = _find_disconnected_subgraphs(inputs, output)

        # the bitmap set of all tensors is computed as it is needed to
        # compute set differences: s1 - s2 transforms into
        # s1 & (all_tensors ^ s2)
        all_tensors = (1 << len(inputs)) - 1

        for g in subgraphs:

            # dynamic programming approach to compute x[n] for subgraph g;
            # x[n][set of n tensors] = (indices, cost, contraction)
            # the set of n tensors is represented by a bitmap: if bit j is 1,
            # tensor j is in the set, e.g. 0b100101 = {0,2,5}; set unions
            # (intersections) can then be computed by bitwise or (and);
            x = [None] * 2 + [dict() for j in range(len(g) - 1)]
            x[1] = OrderedDict((1 << j, (inputs[j], 0, inputs_contractions[j])) for j in g)

            # convert set of tensors g to a bitmap set:
            g = functools.reduce(lambda x, y: x | y, (1 << j for j in g))

            # try to find contraction with cost <= cost_cap and increase
            # cost_cap successively if no such contraction is found;
            # this is a major performance improvement; start with product of
            # output index dimensions as initial cost_cap
            subgraph_inds = set.union(*_bitmap_select(g, inputs))
            if self.cost_cap is True:
                cost_cap = helpers.compute_size_by_dict(subgraph_inds & output, size_dict)
            elif self.cost_cap is False:
                cost_cap = float('inf')
            else:
                cost_cap = self.cost_cap
            # set the factor to increase the cost by each iteration (ensure > 1)
            cost_increment = max(min(map(size_dict.__getitem__, subgraph_inds)), 2)

            # x[-1] is only populated once a contraction of the full subgraph
            # fits under the current cost cap; otherwise relax the cap and retry
            while len(x[-1]) == 0:
                for n in range(2, len(x[1]) + 1):
                    xn = x[n]

                    # try to combine solutions from x[m] and x[n-m]
                    for m in range(1, n // 2 + 1):
                        for s1, (i1, cost1, cntrct1) in x[m].items():
                            for s2, (i2, cost2, cntrct2) in x[n - m].items():

                                # can only merge if s1 and s2 are disjoint
                                # and avoid e.g. s1={0}, s2={1} and s1={1}, s2={0}
                                if (not s1 & s2) and (m != n - m or s1 < s2):
                                    i1_cut_i2_wo_output = (i1 & i2) - output

                                    # maybe ignore outer products:
                                    if self._check_outer(i1_cut_i2_wo_output):

                                        i1_union_i2 = i1 | i2
                                        # record the merged candidate in xn if it
                                        # beats any existing entry under the cap
                                        self._check_contraction(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2,
                                                                xn, g, all_tensors, inputs, i1_cut_i2_wo_output,
                                                                memory_limit, cntrct1, cntrct2)

                # increase cost cap for next iteration:
                cost_cap = cost_increment * cost_cap

            # x[-1] now holds exactly the entry covering the whole subgraph
            i, cost, contraction = list(x[-1].values())[0]
            subgraph_contractions.append(contraction)
            subgraph_contractions_size.append(helpers.compute_size_by_dict(i, size_dict))

        # sort the subgraph contractions by the size of the subgraphs in
        # ascending order (will give the cheapest contractions); note that
        # outer products should be performed pairwise (to use BLAS functions)
        subgraph_contractions = [
            subgraph_contractions[j]
            for j in sorted(range(len(subgraph_contractions_size)), key=subgraph_contractions_size.__getitem__)
        ]

        # build the final contraction tree
        tree = simple_tree_tuple(subgraph_contractions)
        return _tree_to_sequence(tree)

1054 

1055 

def dynamic_programming(inputs, output, size_dict, memory_limit=None, **kwargs):
    """Functional wrapper that builds a :class:`DynamicProgramming` optimizer
    (forwarding ``**kwargs`` to its constructor) and runs it once.
    """
    return DynamicProgramming(**kwargs)(inputs, output, size_dict, memory_limit)

1059 

1060 

# map small input counts to progressively cheaper optimizers; anything larger
# falls back to ``greedy`` via ``_AUTO_CHOICES.get`` in ``auto``
_AUTO_CHOICES = {}
_AUTO_CHOICES.update(dict.fromkeys(range(1, 5), optimal))
_AUTO_CHOICES.update(dict.fromkeys(range(5, 7), branch_all))
_AUTO_CHOICES.update(dict.fromkeys(range(7, 9), branch_2))
_AUTO_CHOICES.update(dict.fromkeys(range(9, 15), branch_1))

1070 

1071 

def auto(inputs, output, size_dict, memory_limit=None):
    """Finds the contraction path by automatically choosing the method based on
    how many input arguments there are.
    """
    path_fn = _AUTO_CHOICES.get(len(inputs), greedy)
    return path_fn(inputs, output, size_dict, memory_limit)

1078 

1079 

# higher-quality counterpart of ``_AUTO_CHOICES``: exact methods up to 16
# inputs, with the fallback (``random_greedy_128``) chosen inside ``auto_hq``
_AUTO_HQ_CHOICES = {}
_AUTO_HQ_CHOICES.update(dict.fromkeys(range(1, 6), optimal))
_AUTO_HQ_CHOICES.update(dict.fromkeys(range(6, 17), dynamic_programming))

1085 

1086 

def auto_hq(inputs, output, size_dict, memory_limit=None):
    """Finds the contraction path by automatically choosing the method based on
    how many input arguments there are, but targeting a more generous
    amount of search time than ``'auto'``.
    """
    # imported lazily to avoid a circular import at module load time
    from .path_random import random_greedy_128

    path_fn = _AUTO_HQ_CHOICES.get(len(inputs), random_greedy_128)
    return path_fn(inputs, output, size_dict, memory_limit)

1096 

1097 

# registry of named path optimizers; 'eager' and 'opportunistic' are aliases
# of 'greedy', and 'dp' abbreviates 'dynamic-programming'. Extended at runtime
# via ``register_path_fn`` and queried via ``get_path_fn``.
_PATH_OPTIONS = {
    'auto': auto,
    'auto-hq': auto_hq,
    'optimal': optimal,
    'branch-all': branch_all,
    'branch-2': branch_2,
    'branch-1': branch_1,
    'greedy': greedy,
    'eager': greedy,
    'opportunistic': greedy,
    'dp': dynamic_programming,
    'dynamic-programming': dynamic_programming
}

1111 

1112 

def register_path_fn(name, fn):
    """Add path finding function ``fn`` as an option with ``name``.

    Parameters
    ----------
    name : str
        Name to register the function under - stored lowercased.
    fn : callable
        Path finding function with the ``PathOptimizer`` call signature,
        ``fn(inputs, output, size_dict, memory_limit=None)``.

    Raises
    ------
    KeyError
        If an optimizer is already registered under (the lowercased) ``name``.
    """
    # check the lowercased name, since that is the key actually stored below -
    # checking the raw ``name`` let e.g. 'Auto' slip past and clobber 'auto'
    if name.lower() in _PATH_OPTIONS:
        raise KeyError("Path optimizer '{}' already exists.".format(name))

    _PATH_OPTIONS[name.lower()] = fn

1120 

1121 

def get_path_fn(path_type):
    """Get the correct path finding function from str ``path_type``.
    """
    if path_type in _PATH_OPTIONS:
        return _PATH_OPTIONS[path_type]

    raise KeyError("Path optimizer '{}' not found, valid options are {}.".format(
        path_type, set(_PATH_OPTIONS.keys())))