1"""
2Contains the path technology behind opt_einsum in addition to several path helpers
3"""
5import functools
6import heapq
7import itertools
8import operator
9import random
10import re
11from collections import Counter, OrderedDict, defaultdict
12from typing import Any, Callable
13from typing import Counter as CounterType
14from typing import Dict, FrozenSet, Generator, List, Optional, Sequence, Set, Tuple, Union
16import numpy as np
18from .helpers import compute_size_by_dict, flop_count
19from .typing import ArrayIndexType, PathType
21__all__ = [
22 "optimal",
23 "BranchBound",
24 "branch",
25 "greedy",
26 "auto",
27 "auto_hq",
28 "get_path_fn",
29 "DynamicProgramming",
30 "dynamic_programming",
31]
33_UNLIMITED_MEM = {-1, None, float("inf")}
36class PathOptimizer:
37 """Base class for different path optimizers to inherit from.
39 Subclassed optimizers should define a `__call__` method with the signature:
41 ```python
42 def __call__(self, inputs, output, size_dict, memory_limit=None):
43 \"\"\"
44 **Parameters:**
46 inputs : list[set[str]]
47 The indices of each input array.
48 output : set[str]
49 The output indices.
50 size_dict : dict[str, int]
51 The size of each index
52 memory_limit : int, optional
53 If given, the maximum allowed memory.
54 \"\"\"
55 # ... compute path here ...
56 return path
57 ```
59 where `path` is a list of int-tuples specifying a contraction order.
60 """
62 def _check_args_against_first_call(
63 self,
64 inputs: List[ArrayIndexType],
65 output: ArrayIndexType,
66 size_dict: Dict[str, int],
67 ) -> None:
68 """Utility that stateful optimizers can use to ensure they are not
69 called with different contractions across separate runs.
70 """
71 args = (inputs, output, size_dict)
72 if not hasattr(self, "_first_call_args"):
73 # simply set the attribute as currently there is no global PathOptimizer init
74 self._first_call_args = args
75 elif args != self._first_call_args:
76 raise ValueError(
77 "The arguments specifying the contraction that this path optimizer "
78 "instance was called with have changed - try creating a new instance."
79 )
81 def __call__(
82 self,
83 inputs: List[ArrayIndexType],
84 output: ArrayIndexType,
85 size_dict: Dict[str, int],
86 memory_limit: Optional[int] = None,
87 ) -> PathType:
88 raise NotImplementedError
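As an illustration of the interface described above, here is a minimal sketch of a custom optimizer (hypothetical, not part of this module) that produces a simple sequential path while using the base class's first-call check:

```python
class SequentialOptimizer(PathOptimizer):
    """Toy optimizer: always contract the first two operands in the current list."""

    def __call__(self, inputs, output, size_dict, memory_limit=None):
        # guard against being reused on a different contraction
        self._check_args_against_first_call(inputs, output, size_dict)
        if len(inputs) == 1:
            return [(0,)]
        # a valid (if naive) path: repeatedly combine the first two operands
        return [(0, 1)] * (len(inputs) - 1)

# usage sketch, assuming opt_einsum is importable as `oe`:
# path, info = oe.contract_path("ab,bc,cd->ad", x, y, z, optimize=SequentialOptimizer())
```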
91def ssa_to_linear(ssa_path: PathType) -> PathType:
92 """
93 Convert a path with static single assignment ids to a path with recycled
94 linear ids. For example:
96 ```python
97 ssa_to_linear([(0, 3), (2, 4), (1, 5)])
98 #> [(0, 3), (1, 2), (0, 1)]
99 ```
100 """
101 ids = np.arange(1 + max(map(max, ssa_path)), dtype=np.int32) # type: ignore
102 path = []
103 for ssa_ids in ssa_path:
104 path.append(tuple(int(ids[ssa_id]) for ssa_id in ssa_ids))
105 for ssa_id in ssa_ids:
106 ids[ssa_id:] -= 1
107 return path
110def linear_to_ssa(path: PathType) -> PathType:
111 """
112 Convert a path with recycled linear ids to a path with static single
113 assignment ids. For example:
115 ```python
116 linear_to_ssa([(0, 3), (1, 2), (0, 1)])
117 #> [(0, 3), (2, 4), (1, 5)]
118 ```
119 """
120 num_inputs = sum(map(len, path)) - len(path) + 1
121 linear_to_ssa = list(range(num_inputs))
122 new_ids = itertools.count(num_inputs)
123 ssa_path = []
124 for ids in path:
125 ssa_path.append(tuple(linear_to_ssa[id_] for id_ in ids))
126 for id_ in sorted(ids, reverse=True):
127 del linear_to_ssa[id_]
128 linear_to_ssa.append(next(new_ids))
129 return ssa_path
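The two helpers above are inverses of one another; a quick round trip using the values from their docstrings:

```python
ssa = [(0, 3), (2, 4), (1, 5)]
linear = ssa_to_linear(ssa)          # [(0, 3), (1, 2), (0, 1)]
assert linear_to_ssa(linear) == ssa  # converting back recovers the SSA ids
```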
132def calc_k12_flops(
133 inputs: Tuple[FrozenSet[str]],
134 output: FrozenSet[str],
135 remaining: FrozenSet[int],
136 i: int,
137 j: int,
138 size_dict: Dict[str, int],
139) -> Tuple[FrozenSet[str], int]:
140 """
141 Calculate the resulting indices and flops for a potential pairwise
142 contraction - used in the recursive (optimal/branch) algorithms.
144 **Parameters:**
146 - **inputs** - *(tuple[frozenset[str]])* The indices of each tensor in this contraction. Note that this includes
147 tensors unavailable to contract, since static single assignment is used and
148 contracted tensors are not removed from the list.
149 - **output** - *(frozenset[str])* The set of output indices for the whole contraction.
150 - **remaining** - *(frozenset[int])* The set of indices (corresponding to ``inputs``) of tensors still available to contract.
151 - **i** - *(int)* Index of potential tensor to contract.
152 - **j** - *(int)* Index of potential tensor to contract.
153 - **size_dict** - *(dict[str, int])* Size mapping of all the indices.
155 **Returns:**
157 - **k12** - *(frozenset)* The resulting indices of the potential tensor.
158 - **cost** - *(int)* Estimated flop count of operation.
159 """
160 k1, k2 = inputs[i], inputs[j]
161 either = k1 | k2
162 shared = k1 & k2
163 keep = frozenset.union(output, *map(inputs.__getitem__, remaining - {i, j}))
165 k12 = either & keep
166 cost = flop_count(either, bool(shared - keep), 2, size_dict)
168 return k12, cost
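A small worked example for the function above: contracting the first two tensors of `ab,bc,cd->ad`. The index `b` appears on no other tensor and not in the output, so it is summed away, and the cost reported by `flop_count` (from `.helpers`) here amounts to the product of the contraction's index dimensions, doubled for the inner summation over two terms:

```python
inputs = (frozenset("ab"), frozenset("bc"), frozenset("cd"))
output = frozenset("ad")
size_dict = {"a": 2, "b": 3, "c": 4, "d": 5}

k12, cost = calc_k12_flops(inputs, output, frozenset({0, 1, 2}), 0, 1, size_dict)
# k12  == frozenset({'a', 'c'})   -- 'b' is summed away
# cost == (2 * 3 * 4) * 2 == 48   -- sizes of {a, b, c}, times 2 for the inner sum
```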
171def _compute_oversize_flops(
172 inputs: Tuple[FrozenSet[str]],
173 remaining: List[ArrayIndexType],
174 output: ArrayIndexType,
175 size_dict: Dict[str, int],
176) -> int:
177 """
178 Compute the flop count for a contraction of all remaining arguments. This
179 is used when a memory limit means that no pairwise contractions can be made.
180 """
181 idx_contraction = frozenset.union(*map(inputs.__getitem__, remaining)) # type: ignore
182 inner = idx_contraction - output
183 num_terms = len(remaining)
184 return flop_count(idx_contraction, bool(inner), num_terms, size_dict)
187def optimal(
188 inputs: List[ArrayIndexType],
189 output: ArrayIndexType,
190 size_dict: Dict[str, int],
191 memory_limit: Optional[int] = None,
192) -> PathType:
193 """
194 Computes all possible pair contractions in a depth-first recursive manner,
195 sieving results based on `memory_limit` and the best path found so far.
196 Returns the lowest cost path. This algorithm scales factorially with
197 respect to the number of elements in the list `inputs`.
199 **Parameters:**
201 - **inputs** - *(list)* List of sets that represent the lhs side of the einsum subscript.
202 - **output** - *(set)* Set that represents the rhs side of the overall einsum subscript.
203 - **size_dict** - *(dictionary)* Dictionary of index sizes.
204 - **memory_limit** - *(int)* The maximum number of elements in a temporary array.
206 **Returns:**
208 - **path** - *(list)* The optimal contraction order within the memory limit constraint.
210 **Examples:**
212 ```python
213 isets = [set('abd'), set('ac'), set('bdc')]
214 oset = set('')
215 idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
216 optimal(isets, oset, idx_sizes, 5000)
217 #> [(0, 2), (0, 1)]
218 ```
219 """
220 inputs_set = tuple(map(frozenset, inputs)) # type: ignore
221 output_set = frozenset(output)
223 best_flops = {"flops": float("inf")}
224 best_ssa_path = {"ssa_path": (tuple(range(len(inputs))),)}
225 size_cache: Dict[FrozenSet[str], int] = {}
226 result_cache: Dict[Tuple[ArrayIndexType, ArrayIndexType], Tuple[FrozenSet[str], int]] = {}
228 def _optimal_iterate(path, remaining, inputs, flops):
230 # reached end of path (only ever get here if flops is best found so far)
231 if len(remaining) == 1:
232 best_flops["flops"] = flops
233 best_ssa_path["ssa_path"] = path
234 return
236 # check all possible remaining paths
237 for i, j in itertools.combinations(remaining, 2):
238 if i > j:
239 i, j = j, i
240 key = (inputs[i], inputs[j])
241 try:
242 k12, flops12 = result_cache[key]
243 except KeyError:
244 k12, flops12 = result_cache[key] = calc_k12_flops(inputs, output_set, remaining, i, j, size_dict)
246 # sieve based on current best flops
247 new_flops = flops + flops12
248 if new_flops >= best_flops["flops"]:
249 continue
251 # sieve based on memory limit
252 if memory_limit not in _UNLIMITED_MEM:
253 try:
254 size12 = size_cache[k12]
255 except KeyError:
256 size12 = size_cache[k12] = compute_size_by_dict(k12, size_dict)
258 # possibly terminate this path with an all-terms einsum
259 if size12 > memory_limit:
260 new_flops = flops + _compute_oversize_flops(inputs, remaining, output_set, size_dict)
261 if new_flops < best_flops["flops"]:
262 best_flops["flops"] = new_flops
263 best_ssa_path["ssa_path"] = path + (tuple(remaining),)
264 continue
266 # add contraction and recurse into all remaining
267 _optimal_iterate(
268 path=path + ((i, j),),
269 inputs=inputs + (k12,),
270 remaining=remaining - {i, j} | {len(inputs)},
271 flops=new_flops,
272 )
274 _optimal_iterate(path=(), inputs=inputs_set, remaining=set(range(len(inputs))), flops=0)
276 return ssa_to_linear(best_ssa_path["ssa_path"])
279# functions for comparing which of two paths is 'better'
282def better_flops_first(flops: int, size: int, best_flops: int, best_size: int) -> bool:
283 return (flops, size) < (best_flops, best_size)
286def better_size_first(flops: int, size: int, best_flops: int, best_size: int) -> bool:
287 return (size, flops) < (best_size, best_flops)
290_BETTER_FNS = {
291 "flops": better_flops_first,
292 "size": better_size_first,
293}
296def get_better_fn(key: str) -> Callable[[int, int, int, int], bool]:
297 return _BETTER_FNS[key]
300# functions for assigning a heuristic 'cost' to a potential contraction
303def cost_memory_removed(size12: int, size1: int, size2: int, k12: int, k1: int, k2: int) -> float:
304 """The default heuristic cost, corresponding to the total reduction in
305 memory of performing a contraction.
306 """
307 return size12 - size1 - size2
310def cost_memory_removed_jitter(size12: int, size1: int, size2: int, k12: int, k1: int, k2: int) -> float:
311 """Like memory-removed, but with a slight amount of noise that breaks ties
312 and thus jumbles the contractions a bit.
313 """
314 return random.gauss(1.0, 0.01) * (size12 - size1 - size2)
317_COST_FNS = {
318 "memory-removed": cost_memory_removed,
319 "memory-removed-jitter": cost_memory_removed_jitter,
320}
323class BranchBound(PathOptimizer):
324 """
325 Explores possible pair contractions in a depth-first recursive manner like
326 the `optimal` approach, but with extra heuristic early pruning of branches
327 as well as sieving by `memory_limit` and the best path found so far. Returns
328 the lowest cost path. This algorithm still scales factorially with respect
329 to the number of elements in the list `input_sets` if `nbranch` is not set, but it
330 scales exponentially like `nbranch**len(input_sets)` otherwise.
332 **Parameters:**
334 - **nbranch** - *(None or int, optional)* How many branches to explore at each contraction step. If None, explore
335 all possible branches. If an integer, branch into this many paths at
336 each step. Defaults to None.
337 - **cutoff_flops_factor** - *(float, optional)* If at any point a path is doing at least this many times worse than the
338 best path found so far, terminate it. The larger this is made, the more paths
339 will be fully explored and the slower the algorithm. Defaults to 4.
340 - **minimize** - *({'flops', 'size'}, optional)* Whether to optimize the path with regard primarily to the total
341 estimated flop-count, or the size of the largest intermediate. The
342 option not chosen will still be used as a secondary criterion.
343 - **cost_fn** - *(callable, optional)* A function that returns a heuristic 'cost' of a potential contraction
344 with which to sort candidates. Should have signature
345 `cost_fn(size12, size1, size2, k12, k1, k2)`.
346 """
348 def __init__(
349 self,
350 nbranch=None,
351 cutoff_flops_factor=4,
352 minimize="flops",
353 cost_fn="memory-removed",
354 ):
355 if (nbranch is not None) and nbranch < 1:
356 raise ValueError(f"The number of branches must be at least one, `nbranch={nbranch}`.")
358 self.nbranch = nbranch
359 self.cutoff_flops_factor = cutoff_flops_factor
360 self.minimize = minimize
361 self.cost_fn = _COST_FNS.get(cost_fn, cost_fn)
363 self.better = get_better_fn(minimize)
364 self.best = {"flops": float("inf"), "size": float("inf")}
365 self.best_progress = defaultdict(lambda: float("inf"))
367 @property
368 def path(self) -> PathType:
369 return ssa_to_linear(self.best["ssa_path"])
371 def __call__(
372 self,
373 inputs_: List[ArrayIndexType],
374 output_: ArrayIndexType,
375 size_dict: Dict[str, int],
376 memory_limit: Optional[int] = None,
377 ) -> PathType:
378 """
380 **Parameters:**
382 - **inputs_** - *(list)* List of sets that represent the lhs side of the einsum subscript
383 - **output_** - *(set)* Set that represents the rhs side of the overall einsum subscript
384 - **size_dict** - *(dictionary)* Dictionary of index sizes
385 - **memory_limit** - *(int)* The maximum number of elements in a temporary array
387 **Returns:**
389 - **path** - *(list)* The contraction order within the memory limit constraint.
391 **Examples:**
393 ```python
394 isets = [set('abd'), set('ac'), set('bdc')]
395 oset = set('')
396 idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
397 BranchBound()(isets, oset, idx_sizes, 5000)
398 #> [(0, 2), (0, 1)]
```
399 """
400 self._check_args_against_first_call(inputs_, output_, size_dict)
402 inputs: Tuple[FrozenSet[str]] = tuple(map(frozenset, inputs_)) # type: ignore
403 output: FrozenSet[str] = frozenset(output_)
405 size_cache = {k: compute_size_by_dict(k, size_dict) for k in inputs}
406 result_cache: Dict[Tuple[FrozenSet[str], FrozenSet[str]], Tuple[FrozenSet[str], int]] = {}
408 def _branch_iterate(path, inputs, remaining, flops, size):
410 # reached end of path (only ever get here if flops is best found so far)
411 if len(remaining) == 1:
412 self.best["size"] = size
413 self.best["flops"] = flops
414 self.best["ssa_path"] = path
415 return
417 def _assess_candidate(k1: FrozenSet[str], k2: FrozenSet[str], i: int, j: int) -> Any:
418 # find resulting indices and flops
419 try:
420 k12, flops12 = result_cache[k1, k2]
421 except KeyError:
422 k12, flops12 = result_cache[k1, k2] = calc_k12_flops(inputs, output, remaining, i, j, size_dict)
424 try:
425 size12 = size_cache[k12]
426 except KeyError:
427 size12 = size_cache[k12] = compute_size_by_dict(k12, size_dict)
429 new_flops = flops + flops12
430 new_size = max(size, size12)
432 # sieve based on current best i.e. check flops and size still better
433 if not self.better(new_flops, new_size, self.best["flops"], self.best["size"]):
434 return None
436 # compare to how the best method was doing at this point
437 if new_flops < self.best_progress[len(inputs)]:
438 self.best_progress[len(inputs)] = new_flops
439 # sieve based on current progress relative to best
440 elif new_flops > self.cutoff_flops_factor * self.best_progress[len(inputs)]:
441 return None
443 # sieve based on memory limit
444 if (memory_limit not in _UNLIMITED_MEM) and (size12 > memory_limit): # type: ignore
445 # terminate path here, but check all-terms contract first
446 new_flops = flops + _compute_oversize_flops(inputs, remaining, output_, size_dict)
447 if new_flops < self.best["flops"]:
448 self.best["flops"] = new_flops
449 self.best["ssa_path"] = path + (tuple(remaining),)
450 return None
452 # set cost heuristic in order to locally sort possible contractions
453 size1, size2 = size_cache[inputs[i]], size_cache[inputs[j]]
454 cost = self.cost_fn(size12, size1, size2, k12, k1, k2)
456 return cost, flops12, new_flops, new_size, (i, j), k12
458 # check all possible remaining paths
459 candidates = []
460 for i, j in itertools.combinations(remaining, 2):
461 if i > j:
462 i, j = j, i
463 k1, k2 = inputs[i], inputs[j]
465 # initially ignore outer products
466 if k1.isdisjoint(k2):
467 continue
469 candidate = _assess_candidate(k1, k2, i, j)
470 if candidate:
471 heapq.heappush(candidates, candidate)
473 # assess outer products if nothing left
474 if not candidates:
475 for i, j in itertools.combinations(remaining, 2):
476 if i > j:
477 i, j = j, i
478 k1, k2 = inputs[i], inputs[j]
479 candidate = _assess_candidate(k1, k2, i, j)
480 if candidate:
481 heapq.heappush(candidates, candidate)
483 # recurse into all or some of the best candidate contractions
484 bi = 0
485 while (self.nbranch is None or bi < self.nbranch) and candidates:
486 _, _, new_flops, new_size, (i, j), k12 = heapq.heappop(candidates)
487 _branch_iterate(
488 path=path + ((i, j),),
489 inputs=inputs + (k12,),
490 remaining=(remaining - {i, j}) | {len(inputs)},
491 flops=new_flops,
492 size=new_size,
493 )
494 bi += 1
496 _branch_iterate(path=(), inputs=inputs, remaining=set(range(len(inputs))), flops=0, size=0)
498 return self.path
501def branch(
502 inputs: List[ArrayIndexType],
503 output: ArrayIndexType,
504 size_dict: Dict[str, int],
505 memory_limit: Optional[int] = None,
506 **optimizer_kwargs: Dict[str, Any],
507) -> PathType:
508 optimizer = BranchBound(**optimizer_kwargs)
509 return optimizer(inputs, output, size_dict, memory_limit)
512branch_all = functools.partial(branch, nbranch=None)
513branch_2 = functools.partial(branch, nbranch=2)
514branch_1 = functools.partial(branch, nbranch=1)
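These presets correspond to the `'branch-all'`, `'branch-2'` and `'branch-1'` strings registered in `_PATH_OPTIONS` below; a configured `BranchBound` instance can also be passed directly. A usage sketch with placeholder arrays:

```python
import numpy as np
import opt_einsum as oe

x, y, z = (np.random.rand(8, 8) for _ in range(3))

# equivalent ways of requesting a branch-and-bound search with two branches per step
path, info = oe.contract_path("ab,bc,cd->ad", x, y, z, optimize="branch-2")
path, info = oe.contract_path("ab,bc,cd->ad", x, y, z, optimize=BranchBound(nbranch=2))
```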
516GreedyCostType = Tuple[int, int, int]
517GreedyContractionType = Tuple[GreedyCostType, ArrayIndexType, ArrayIndexType, ArrayIndexType] # Cost, t1,t2->t3
520def _get_candidate(
521 output: ArrayIndexType,
522 sizes: Dict[str, int],
523 remaining: Dict[ArrayIndexType, int],
524 footprints: Dict[ArrayIndexType, int],
525 dim_ref_counts: Dict[int, Set[str]],
526 k1: ArrayIndexType,
527 k2: ArrayIndexType,
528 cost_fn: Any,
529) -> GreedyContractionType:
530 either = k1 | k2
531 two = k1 & k2
532 one = either - two
533 k12 = (either & output) | (two & dim_ref_counts[3]) | (one & dim_ref_counts[2])
534 cost = cost_fn(
535 compute_size_by_dict(k12, sizes),
536 footprints[k1],
537 footprints[k2],
538 k12,
539 k1,
540 k2,
541 )
542 id1 = remaining[k1]
543 id2 = remaining[k2]
544 if id1 > id2:
545 k1, id1, k2, id2 = k2, id2, k1, id1
546 cost = cost, id2, id1 # break ties to ensure determinism
547 return cost, k1, k2, k12
550def _push_candidate(
551 output: ArrayIndexType,
552 sizes: Dict[str, Any],
553 remaining: Dict[ArrayIndexType, int],
554 footprints: Dict[ArrayIndexType, int],
555 dim_ref_counts: Dict[int, Set[str]],
556 k1: ArrayIndexType,
557 k2s: List[ArrayIndexType],
558 queue: List[GreedyContractionType],
559 push_all: bool,
560 cost_fn: Any,
561) -> None:
562 candidates = (_get_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2, cost_fn) for k2 in k2s)
563 if push_all:
564 # want to do this if we e.g. are using a custom 'choose_fn'
565 for candidate in candidates:
566 heapq.heappush(queue, candidate)
567 else:
568 heapq.heappush(queue, min(candidates))
571def _update_ref_counts(
572 dim_to_keys: Dict[str, Set[ArrayIndexType]],
573 dim_ref_counts: Dict[int, Set[str]],
574 dims: ArrayIndexType,
575) -> None:
576 for dim in dims:
577 count = len(dim_to_keys[dim])
578 if count <= 1:
579 dim_ref_counts[2].discard(dim)
580 dim_ref_counts[3].discard(dim)
581 elif count == 2:
582 dim_ref_counts[2].add(dim)
583 dim_ref_counts[3].discard(dim)
584 else:
585 dim_ref_counts[2].add(dim)
586 dim_ref_counts[3].add(dim)
589def _simple_chooser(queue, remaining):
590 """Default contraction chooser that simply takes the minimum cost option."""
591 cost, k1, k2, k12 = heapq.heappop(queue)
592 if k1 not in remaining or k2 not in remaining:
593 return None # candidate is obsolete
594 return cost, k1, k2, k12
597def ssa_greedy_optimize(
598 inputs: List[ArrayIndexType],
599 output: ArrayIndexType,
600 sizes: Dict[str, int],
601 choose_fn: Any = None,
602 cost_fn: Any = "memory-removed",
603) -> PathType:
604 """
605 This is the core function for `greedy` but produces a path with
606 static single assignment ids rather than recycled linear ids.
607 SSA ids are cheaper to work with and easier to reason about.
608 """
609 if len(inputs) == 1:
610 # Perform a single contraction to match output shape.
611 return [(0,)]
613 # set the function that assigns a heuristic cost to a possible contraction
614 cost_fn = _COST_FNS.get(cost_fn, cost_fn)
616 # set the function that chooses which contraction to take
617 if choose_fn is None:
618 choose_fn = _simple_chooser
619 push_all = False
620 else:
621 # assume chooser wants access to all possible contractions
622 push_all = True
624 # A dim that is common to all tensors might as well be an output dim, since it
625 # cannot be contracted until the final step. This avoids an expensive all-pairs
626 # comparison to search for possible contractions at each step, leading to speedup
627 # in many practical problems where all tensors share a common batch dimension.
628 fs_inputs = [frozenset(x) for x in inputs]
629 output = frozenset(output) | frozenset.intersection(*fs_inputs)
631 # Deduplicate shapes by eagerly computing Hadamard products.
632 remaining: Dict[ArrayIndexType, int] = {} # key -> ssa_id
633 ssa_ids = itertools.count(len(fs_inputs))
634 ssa_path = []
635 for ssa_id, key in enumerate(fs_inputs):
636 if key in remaining:
637 ssa_path.append((remaining[key], ssa_id))
638 remaining[key] = next(ssa_ids)
639 else:
640 remaining[key] = ssa_id
642 # Keep track of possible contraction dims.
643 dim_to_keys = defaultdict(set)
644 for key in remaining:
645 for dim in key - output:
646 dim_to_keys[dim].add(key)
648 # Keep track of the number of tensors using each dim; when the dim is no longer
649 # used it can be contracted. Since we specialize to binary ops, we only care about
650 # ref counts of >=2 or >=3.
651 dim_ref_counts = {
652 count: set(dim for dim, keys in dim_to_keys.items() if len(keys) >= count) - output for count in [2, 3]
653 }
655 # Compute separable part of the objective function for contractions.
656 footprints = {key: compute_size_by_dict(key, sizes) for key in remaining}
658 # Find initial candidate contractions.
659 queue: List[GreedyContractionType] = []
660 for dim, dim_keys in dim_to_keys.items():
661 dim_keys_list = sorted(dim_keys, key=remaining.__getitem__)
662 for i, k1 in enumerate(dim_keys_list[:-1]):
663 k2s_guess = dim_keys_list[1 + i :]
664 _push_candidate(
665 output,
666 sizes,
667 remaining,
668 footprints,
669 dim_ref_counts,
670 k1,
671 k2s_guess,
672 queue,
673 push_all,
674 cost_fn,
675 )
677 # Greedily contract pairs of tensors.
678 while queue:
680 con = choose_fn(queue, remaining)
681 if con is None:
682 continue # allow choose_fn to flag all candidates obsolete
683 cost, k1, k2, k12 = con
685 ssa_id1 = remaining.pop(k1)
686 ssa_id2 = remaining.pop(k2)
687 for dim in k1 - output:
688 dim_to_keys[dim].remove(k1)
689 for dim in k2 - output:
690 dim_to_keys[dim].remove(k2)
691 ssa_path.append((ssa_id1, ssa_id2))
692 if k12 in remaining:
693 ssa_path.append((remaining[k12], next(ssa_ids)))
694 else:
695 for dim in k12 - output:
696 dim_to_keys[dim].add(k12)
697 remaining[k12] = next(ssa_ids)
698 _update_ref_counts(dim_to_keys, dim_ref_counts, k1 | k2 - output)
699 footprints[k12] = compute_size_by_dict(k12, sizes)
701 # Find new candidate contractions.
702 k1 = k12
703 k2s = set(k2 for dim in k1 for k2 in dim_to_keys[dim])
704 k2s.discard(k1)
705 if k2s:
706 _push_candidate(
707 output,
708 sizes,
709 remaining,
710 footprints,
711 dim_ref_counts,
712 k1,
713 list(k2s),
714 queue,
715 push_all,
716 cost_fn,
717 )
719 # Greedily compute pairwise outer products.
720 final_queue = [(compute_size_by_dict(key & output, sizes), ssa_id, key) for key, ssa_id in remaining.items()]
721 heapq.heapify(final_queue)
722 _, ssa_id1, k1 = heapq.heappop(final_queue)
723 while final_queue:
724 _, ssa_id2, k2 = heapq.heappop(final_queue)
725 ssa_path.append((min(ssa_id1, ssa_id2), max(ssa_id1, ssa_id2)))
726 k12 = (k1 | k2) & output
727 cost = compute_size_by_dict(k12, sizes)
728 ssa_id12 = next(ssa_ids)
729 _, ssa_id1, k1 = heapq.heappushpop(final_queue, (cost, ssa_id12, k12))
731 return ssa_path
734def greedy(
735 inputs: List[ArrayIndexType],
736 output: ArrayIndexType,
737 size_dict: Dict[str, int],
738 memory_limit: Optional[int] = None,
739 choose_fn: Any = None,
740 cost_fn: str = "memory-removed",
741) -> PathType:
742 """
743 Finds the path by a three stage algorithm:
745 1. Eagerly compute Hadamard products.
746 2. Greedily compute contractions to maximize `removed_size`.
747 3. Greedily compute outer products.
749 This algorithm scales quadratically with respect to the
750 maximum number of elements sharing a common dim.
752 **Parameters:**
754 - **inputs** - *(list)* List of sets that represent the lhs side of the einsum subscript
755 - **output** - *(set)* Set that represents the rhs side of the overall einsum subscript
756 - **size_dict** - *(dictionary)* Dictionary of index sizes
757 - **memory_limit** - *(int)* The maximum number of elements in a temporary array
758 - **choose_fn** - *(callable, optional)* A function that chooses which contraction to perform from the queue
759 - **cost_fn** - *(callable, optional)* A function that assigns a potential contraction a cost.
761 **Returns:**
763 - **path** - *(list)* The contraction order (a list of tuples of ints).
765 **Examples:**
767 ```python
768 isets = [set('abd'), set('ac'), set('bdc')]
769 oset = set('')
770 idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
771 greedy(isets, oset, idx_sizes)
772 #> [(0, 2), (0, 1)]
773 ```
774 """
775 if memory_limit not in _UNLIMITED_MEM:
776 return branch(inputs, output, size_dict, memory_limit, nbranch=1, cost_fn=cost_fn) # type: ignore
778 ssa_path = ssa_greedy_optimize(inputs, output, size_dict, cost_fn=cost_fn, choose_fn=choose_fn)
779 return ssa_to_linear(ssa_path)
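The `choose_fn` hook lets callers override which queued candidate is taken at each greedy step (broadly how the randomized greedy optimizers in `path_random` work). A minimal sketch of a custom chooser, assuming the same `(cost, k1, k2, k12)` queue entries consumed by `_simple_chooser`; unused candidates are pushed back so they are not lost:

```python
import heapq
import random

def choose_random_of_best_two(queue, remaining):
    """Pop up to two still-valid candidates and pick one of them at random."""
    choices = []
    while queue and len(choices) < 2:
        cost, k1, k2, k12 = heapq.heappop(queue)
        if k1 in remaining and k2 in remaining:
            choices.append((cost, k1, k2, k12))
    if not choices:
        return None  # every popped candidate was obsolete
    chosen = random.choice(choices)
    for other in choices:
        if other is not chosen:
            heapq.heappush(queue, other)  # keep the runner-up for later steps
    return chosen

# path = greedy(inputs, output, size_dict, choose_fn=choose_random_of_best_two)
```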
782def _tree_to_sequence(tree: Tuple[Any, ...]) -> PathType:
783 """
784 Converts a contraction tree to a contraction path in the form expected
785 of path optimizers. A contraction tree is either an int
786 (= no contraction) or a tuple containing the terms to be contracted. An
787 arbitrary number (>= 1) of terms can be contracted at once. Note that
788 contractions are commutative, e.g. (j, k, l) = (k, l, j), and that in
789 general, solutions are not unique.
791 **Parameters:**
793 - **tree** - *(tuple or int)* Contraction tree
795 **Returns:**
797 - **path** - *(list[set[int]])* Contraction path
799 **Examples:**
801 ```python
802 _tree_to_sequence(((1,2),(0,(4,5,3))))
803 #> [(1, 2), (1, 2, 3), (0, 2), (0, 1)]
804 ```
805 """
807 # ((1,2),(0,(4,5,3))) --> [(1, 2), (1, 2, 3), (0, 2), (0, 1)]
808 #
809 # 0 0 0 (1,2) --> ((1,2),(0,(3,4,5)))
810 # 1 3 (1,2) --> (0,(3,4,5))
811 # 2 --> 4 --> (3,4,5)
812 # 3 5
813 # 4 (1,2)
814 # 5
815 #
816 # this function iterates through the table shown above from right to left;
818 if type(tree) == int:
819 return []
821 c: List[Tuple[Any, ...]] = [tree] # list of remaining contractions (lower part of columns shown above)
822 t: List[int] = [] # list of elementary tensors (upper part of columns)
823 s: List[Tuple[int, ...]] = [] # resulting contraction sequence
825 while len(c) > 0:
826 j = c.pop(-1)
827 s.insert(0, tuple())
829 for i in sorted([i for i in j if type(i) == int]):
830 s[0] += (sum(1 for q in t if q < i),)
831 t.insert(s[0][-1], i)
833 for i_tup in [i_tup for i_tup in j if type(i_tup) != int]:
834 s[0] += (len(t) + len(c),)
835 c.append(i_tup)
837 return s
840def _find_disconnected_subgraphs(inputs: List[FrozenSet[int]], output: FrozenSet[int]) -> List[FrozenSet[int]]:
841 """
842 Finds disconnected subgraphs in the given list of inputs. Inputs are
843 connected if they share summation indices. Note: Disconnected subgraphs
844 can be contracted independently before forming outer products.
846 **Parameters:**
847 - **inputs** - *(list[set])* List of sets that represent the lhs side of the einsum subscript
848 - **output** - *(set)* Set that represents the rhs side of the overall einsum subscript
850 **Returns:**
852 - **subgraphs** - *(list[set[int]])* List containing sets of indices for each subgraph
854 **Examples:**
856 ```python
857 _find_disconnected_subgraphs([set("ab"), set("c"), set("ad")], set("bd"))
858 #> [{0, 2}, {1}]
860 _find_disconnected_subgraphs([set("ab"), set("c"), set("ad")], set("abd"))
861 #> [{0}, {1}, {2}]
862 ```
863 """
865 subgraphs = []
866 unused_inputs = set(range(len(inputs)))
868 i_sum = frozenset.union(*inputs) - output # all summation indices
870 while len(unused_inputs) > 0:
871 g = set()
872 q = [unused_inputs.pop()]
873 while len(q) > 0:
874 j = q.pop()
875 g.add(j)
876 i_tmp = i_sum & inputs[j]
877 n = {k for k in unused_inputs if len(i_tmp & inputs[k]) > 0}
878 q.extend(n)
879 unused_inputs.difference_update(n)
881 subgraphs.append(g)
883 return [frozenset(x) for x in subgraphs]
886def _bitmap_select(s: int, seq: List[FrozenSet[int]]) -> Generator[FrozenSet[int], None, None]:
887 """Select elements of ``seq`` which are marked by the bitmap set ``s``.
889 E.g.:
891 >>> list(_bitmap_select(0b11010, ['A', 'B', 'C', 'D', 'E']))
892 ['B', 'D', 'E']
893 """
894 return (x for x, b in zip(seq, bin(s)[:1:-1]) if b == "1")
897def _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2):
898 """Calculates the effective outer indices of the intermediate tensor
899 corresponding to the subgraph ``s``.
900 """
901 # set of remaining tensors (=g-s)
902 r = g & (all_tensors ^ s)
903 # union of indices appearing on the remaining tensors:
904 if r:
905 i_r = frozenset.union(*_bitmap_select(r, inputs))
906 else:
907 i_r = frozenset()
908 # contraction indices:
909 i_contract = i1_cut_i2_wo_output - i_r
910 return i1_union_i2 - i_contract
913def _dp_compare_flops(
914 cost1: int,
915 cost2: int,
916 i1_union_i2: Set[int],
917 size_dict: List[int],
918 cost_cap: int,
919 s1: int,
920 s2: int,
921 xn: Dict[int, Any],
922 g: int,
923 all_tensors: int,
924 inputs: List[FrozenSet[int]],
925 i1_cut_i2_wo_output: Set[int],
926 memory_limit: Optional[int],
927 contract1: Union[int, Tuple[int]],
928 contract2: Union[int, Tuple[int]],
929) -> None:
930 """Performs the inner comparison of whether the two subgraphs (the bitmaps
931 `s1` and `s2`) should be merged and added to the dynamic programming
932 search. Will skip for a number of reasons:
934 1. If the number of operations to form `s = s1 | s2` including previous
935 contractions is above the cost-cap.
936 2. If we've already found a better way of making `s`.
937 3. If the intermediate tensor corresponding to `s` is going to break the
938 memory limit.
939 """
941 # TODO: Odd usage with an Iterable[int] to map a dict of type List[int]
942 cost = cost1 + cost2 + compute_size_by_dict(i1_union_i2, size_dict)
943 if cost <= cost_cap:
944 s = s1 | s2
945 if s not in xn or cost < xn[s][1]:
946 i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
947 mem = compute_size_by_dict(i, size_dict)
948 if memory_limit is None or mem <= memory_limit:
949 xn[s] = (i, cost, (contract1, contract2))
952def _dp_compare_size(
953 cost1: int,
954 cost2: int,
955 i1_union_i2: Set[int],
956 size_dict: List[int],
957 cost_cap: int,
958 s1: int,
959 s2: int,
960 xn: Dict[int, Any],
961 g: int,
962 all_tensors: int,
963 inputs: List[FrozenSet[int]],
964 i1_cut_i2_wo_output: Set[int],
965 memory_limit: Optional[int],
966 contract1: Union[int, Tuple[int]],
967 contract2: Union[int, Tuple[int]],
968) -> None:
969 """Like `_dp_compare_flops` but sieves the potential contraction based
970 on the size of the intermediate tensor created, rather than the number of
971 operations, and so calculates that first.
972 """
974 s = s1 | s2
975 i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
976 mem = compute_size_by_dict(i, size_dict)
977 cost = max(cost1, cost2, mem)
978 if cost <= cost_cap:
979 if s not in xn or cost < xn[s][1]:
980 if memory_limit is None or mem <= memory_limit:
981 xn[s] = (i, cost, (contract1, contract2))
984def _dp_compare_write(
985 cost1: int,
986 cost2: int,
987 i1_union_i2: Set[int],
988 size_dict: List[int],
989 cost_cap: int,
990 s1: int,
991 s2: int,
992 xn: Dict[int, Any],
993 g: int,
994 all_tensors: int,
995 inputs: List[FrozenSet[int]],
996 i1_cut_i2_wo_output: Set[int],
997 memory_limit: Optional[int],
998 contract1: Union[int, Tuple[int]],
999 contract2: Union[int, Tuple[int]],
1000) -> None:
1001 """Like ``_dp_compare_flops`` but sieves the potential contraction based
1002 on the total size of memory created, rather than the number of
1003 operations, and so calculates that first.
1004 """
1005 s = s1 | s2
1006 i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
1007 mem = compute_size_by_dict(i, size_dict)
1008 cost = cost1 + cost2 + mem
1009 if cost <= cost_cap:
1010 if s not in xn or cost < xn[s][1]:
1011 if memory_limit is None or mem <= memory_limit:
1012 xn[s] = (i, cost, (contract1, contract2))
1015DEFAULT_COMBO_FACTOR = 64
1018def _dp_compare_combo(
1019 cost1: int,
1020 cost2: int,
1021 i1_union_i2: Set[int],
1022 size_dict: List[int],
1023 cost_cap: int,
1024 s1: int,
1025 s2: int,
1026 xn: Dict[int, Any],
1027 g: int,
1028 all_tensors: int,
1029 inputs: List[FrozenSet[int]],
1030 i1_cut_i2_wo_output: Set[int],
1031 memory_limit: Optional[int],
1032 contract1: Union[int, Tuple[int]],
1033 contract2: Union[int, Tuple[int]],
1034 factor: Union[int, float] = DEFAULT_COMBO_FACTOR,
1035 combine: Callable = sum,
1036) -> None:
1037 """Like ``_dp_compare_flops`` but sieves the potential contraction based
1038 on some combination of both the flops and size,
1039 """
1040 s = s1 | s2
1041 i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
1042 mem = compute_size_by_dict(i, size_dict)
1043 f = compute_size_by_dict(i1_union_i2, size_dict)
1044 cost = cost1 + cost2 + combine((f, factor * mem))
1045 if cost <= cost_cap:
1046 if s not in xn or cost < xn[s][1]:
1047 if memory_limit is None or mem <= memory_limit:
1048 xn[s] = (i, cost, (contract1, contract2))
1051minimize_finder = re.compile(r"(flops|size|write|combo|limit)-*(\d*)")
1054@functools.lru_cache(128)
1055def _parse_minimize(minimize: Union[str, Callable]) -> Tuple[Callable, Union[int, float]]:
1056 """This works out what local scoring function to use for the dp algorithm
1057 as well as a `naive_scale` to account for the memory_limit checks.
1058 """
1059 if minimize == "flops":
1060 return _dp_compare_flops, 1
1061 elif minimize == "size":
1062 return _dp_compare_size, 1
1063 elif minimize == "write":
1064 return _dp_compare_write, 1
1065 elif callable(minimize):
1066 # default to naive_scale=inf for this and remaining options
1067 # as otherwise memory_limit check can cause problems
1068 return minimize, float("inf")
1070 # parse out a customized value for the combination factor
1071 match = minimize_finder.fullmatch(minimize)
1072 if match is None:
1073 raise ValueError(f"Couldn't parse `minimize` value: {minimize}.")
1075 minimize, custom_factor = match.groups()
1076 factor = float(custom_factor) if custom_factor else DEFAULT_COMBO_FACTOR
1077 if minimize == "combo":
1078 return functools.partial(_dp_compare_combo, factor=factor, combine=sum), float("inf")
1079 elif minimize == "limit":
1080 return functools.partial(_dp_compare_combo, factor=factor, combine=max), float("inf")
1081 else:
1082 raise ValueError(f"Couldn't parse `minimize` value: {minimize}.")
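For example, the string `'combo-256'` selects the combo scorer with `factor=256`, while a bare `'combo'` falls back to `DEFAULT_COMBO_FACTOR`; tracing the logic above:

```python
fn, naive_scale = _parse_minimize("flops")
# fn is _dp_compare_flops, naive_scale == 1

fn, naive_scale = _parse_minimize("combo-256")
# fn is functools.partial(_dp_compare_combo, factor=256.0, combine=sum)
# naive_scale == float('inf')
```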
1085def simple_tree_tuple(seq: Sequence[Tuple[int, ...]]) -> Tuple[Any, ...]:
1086 """Make a simple left to right binary tree out of iterable `seq`.
1088 ```python
1089 tuple_nest([1, 2, 3, 4])
1090 #> (((1, 2), 3), 4)
1091 ```
1093 """
1094 return functools.reduce(lambda x, y: (x, y), seq)
1097def _dp_parse_out_single_term_ops(
1098 inputs: List[FrozenSet[int]], all_inds: Tuple[str, ...], ind_counts: CounterType[str]
1099) -> Tuple[List[FrozenSet[int]], List[Tuple[int]], List[Union[int, Tuple[int]]]]:
1100 """Take `inputs` and parse for single term index operations, i.e. where
1101 an index appears on one tensor and nowhere else.
1103 If a term is completely reduced to a scalar in this way it is moved
1104 to `inputs_done`. If only some of its indices can be summed, a 'single
1105 term contraction' is added to perform that summation.
1106 """
1107 i_single = frozenset(i for i, c in enumerate(all_inds) if ind_counts[c] == 1)
1108 inputs_parsed: List[FrozenSet[int]] = []
1109 inputs_done: List[Tuple[int]] = []
1110 inputs_contractions: List[Union[int, Tuple[int]]] = []
1111 for j, i in enumerate(inputs):
1112 i_reduced = i - i_single
1113 if (not i_reduced) and (len(i) > 0):
1114 # input reduced to scalar already - remove
1115 inputs_done.append((j,))
1116 else:
1117 # if the input has any index reductions, add single contraction
1118 inputs_parsed.append(i_reduced)
1119 inputs_contractions.append((j,) if i_reduced != i else j)
1121 return inputs_parsed, inputs_done, inputs_contractions
1124class DynamicProgramming(PathOptimizer):
1125 """
1126 Finds the optimal path of pairwise contractions without intermediate outer
1127 products based a dynamic programming approach presented in
1128 Phys. Rev. E 90, 033315 (2014) (the corresponding preprint is publicly
1129 available at https://arxiv.org/abs/1304.6112). This method is especially
1130 well-suited in the area of tensor network states, where it usually
1131 outperforms all the other optimization strategies.
1133 This algorithm shows exponential scaling with the number of inputs
1134 in the worst case scenario (see example below). If the graph to be
1135 contracted consists of disconnected subgraphs, the algorithm scales
1136 linearly in the number of disconnected subgraphs and only exponentially
1137 with the number of inputs per subgraph.
1139 **Parameters:**
1141 - **minimize** - *({'flops', 'size', 'write', 'combo', 'limit', callable}, optional)* What to minimize:
1143 - 'flops' - minimize the number of flops
1144 - 'size' - minimize the size of the largest intermediate
1145 - 'write' - minimize the size of all intermediate tensors
1146 - 'combo' - minimize `flops + alpha * write` summed over intermediates, a default ratio of alpha=64
1147 is used, or it can be customized with `f'combo-{alpha}'`
1148 - 'limit' - minimize `max(flops, alpha * write)` summed over intermediates, a default ratio of alpha=64
1149 is used, or it can be customized with `f'limit-{alpha}'`
1150 - callable - a custom local cost function
1152 - **cost_cap** - *({True, False, int}, optional)* How to implement cost-capping:
1154 - True - iteratively increase the cost-cap
1155 - False - implement no cost-cap at all
1156 - int - use explicit cost cap
1158 - **search_outer** - *(bool, optional)* In rare circumstances the optimal contraction may involve an outer
1159 product; this option allows searching such contractions, but may
1160 slow down the path finding considerably on all but very small graphs.
1161 """
1163 def __init__(self, minimize: str = "flops", cost_cap: bool = True, search_outer: bool = False) -> None:
1164 self.minimize = minimize
1165 self.search_outer = search_outer
1166 self.cost_cap = cost_cap
1168 def __call__(
1169 self,
1170 inputs_: List[ArrayIndexType],
1171 output_: ArrayIndexType,
1172 size_dict_: Dict[str, int],
1173 memory_limit: Optional[int] = None,
1174 ) -> PathType:
1175 """
1176 **Parameters:**
1178 - **inputs** - *(list)* List of sets that represent the lhs side of the einsum subscript
1179 - **output** - *(set)* Set that represents the rhs side of the overall einsum subscript
1180 - **size_dict** - *(dictionary)* Dictionary of index sizes
1181 - **memory_limit** - *(int)* The maximum number of elements in a temporary array
1183 **Returns:**
1185 - **path** - *(list)* The contraction order (a list of tuples of ints).
1187 **Examples:**
1189 ```python
1190 n_in = 3 # exponential scaling
1191 n_out = 2 # linear scaling
1192 s = dict()
1193 i_all = []
1194 for _ in range(n_out):
1195 i = [set() for _ in range(n_in)]
1196 for j in range(n_in):
1197 for k in range(j+1, n_in):
1198 c = oe.get_symbol(len(s))
1199 i[j].add(c)
1200 i[k].add(c)
1201 s[c] = 2
1202 i_all.extend(i)
1203 o = DynamicProgramming()
1204 o(i_all, set(), s)
1205 #> [(1, 2), (0, 4), (1, 2), (0, 2), (0, 1)]
1206 ```
1207 """
1208 _check_contraction, naive_scale = _parse_minimize(self.minimize)
1209 _check_outer = (lambda x: True) if self.search_outer else (lambda x: x)
1211 ind_counts = Counter(itertools.chain(*inputs_, output_))
1212 all_inds = tuple(ind_counts)
1214 # convert all indices to integers (makes set operations ~10 % faster)
1215 symbol2int = {c: j for j, c in enumerate(all_inds)}
1216 inputs = [frozenset(symbol2int[c] for c in i) for i in inputs_]
1217 output = frozenset(symbol2int[c] for c in output_)
1218 size_dict_canonical = {symbol2int[c]: v for c, v in size_dict_.items() if c in symbol2int}
1219 size_dict = [size_dict_canonical[j] for j in range(len(size_dict_canonical))]
1220 naive_cost = naive_scale * len(inputs) * functools.reduce(operator.mul, size_dict, 1)
1222 inputs, inputs_done, inputs_contractions = _dp_parse_out_single_term_ops(inputs, all_inds, ind_counts)
1224 if not inputs:
1225 # nothing left to do after single axis reductions!
1226 return _tree_to_sequence(simple_tree_tuple(inputs_done))
1228 # a list of all necessary contraction expressions for each of the
1229 # disconnected subgraphs and their size
1230 subgraph_contractions = inputs_done
1231 subgraph_contractions_size = [1] * len(inputs_done)
1233 if self.search_outer:
1234 # optimize everything together if we are considering outer products
1235 subgraphs = [frozenset(range(len(inputs)))]
1236 else:
1237 subgraphs = _find_disconnected_subgraphs(inputs, output)
1239 # the bitmap set of all tensors is computed as it is needed to
1240 # compute set differences: s1 - s2 transforms into
1241 # s1 & (all_tensors ^ s2)
1242 all_tensors = (1 << len(inputs)) - 1
1244 for g in subgraphs:
1246 # dynamic programming approach to compute x[n] for subgraph g;
1247 # x[n][set of n tensors] = (indices, cost, contraction)
1248 # the set of n tensors is represented by a bitmap: if bit j is 1,
1249 # tensor j is in the set, e.g. 0b100101 = {0,2,5}; set unions
1250 # (intersections) can then be computed by bitwise or (and);
1251 x: List[Any] = [None] * 2 + [dict() for j in range(len(g) - 1)]
1252 x[1] = OrderedDict((1 << j, (inputs[j], 0, inputs_contractions[j])) for j in g)
1254 # convert set of tensors g to a bitmap set:
1255 bitmap_g = functools.reduce(lambda x, y: x | y, (1 << j for j in g))
1257 # try to find contraction with cost <= cost_cap and increase
1258 # cost_cap successively if no such contraction is found;
1259 # this is a major performance improvement; start with product of
1260 # output index dimensions as initial cost_cap
1261 subgraph_inds = frozenset.union(*_bitmap_select(bitmap_g, inputs))
1262 if self.cost_cap is True:
1263 cost_cap = compute_size_by_dict(subgraph_inds & output, size_dict)
1264 elif self.cost_cap is False:
1265 cost_cap = float("inf") # type: ignore
1266 else:
1267 cost_cap = self.cost_cap
1268 # set the factor to increase the cost by each iteration (ensure > 1)
1269 if len(subgraph_inds) == 0:
1270 cost_increment = 2
1271 else:
1272 cost_increment = max(min(map(size_dict.__getitem__, subgraph_inds)), 2)
1274 while len(x[-1]) == 0:
1275 for n in range(2, len(x[1]) + 1):
1276 xn = x[n]
1278 # try to combine solutions from x[m] and x[n-m]
1279 for m in range(1, n // 2 + 1):
1280 for s1, (i1, cost1, contract1) in x[m].items():
1281 for s2, (i2, cost2, contract2) in x[n - m].items():
1283 # can only merge if s1 and s2 are disjoint
1284 # and avoid e.g. s1={0}, s2={1} and s1={1}, s2={0}
1285 if (not s1 & s2) and (m != n - m or s1 < s2):
1286 i1_cut_i2_wo_output = (i1 & i2) - output
1288 # maybe ignore outer products:
1289 if _check_outer(i1_cut_i2_wo_output):
1291 i1_union_i2 = i1 | i2
1292 _check_contraction(
1293 cost1,
1294 cost2,
1295 i1_union_i2,
1296 size_dict,
1297 cost_cap,
1298 s1,
1299 s2,
1300 xn,
1301 bitmap_g,
1302 all_tensors,
1303 inputs,
1304 i1_cut_i2_wo_output,
1305 memory_limit,
1306 contract1,
1307 contract2,
1308 )
1310 if (cost_cap > naive_cost) and (len(x[-1]) == 0):
1311 raise RuntimeError("No contraction found for given `memory_limit`.")
1313 # increase cost cap for next iteration:
1314 cost_cap = cost_increment * cost_cap
1316 i, cost, contraction = list(x[-1].values())[0]
1317 subgraph_contractions.append(contraction)
1318 subgraph_contractions_size.append(compute_size_by_dict(i, size_dict))
1320 # sort the subgraph contractions by the size of the subgraphs in
1321 # ascending order (will give the cheapest contractions); note that
1322 # outer products should be performed pairwise (to use BLAS functions)
1323 subgraph_contractions = [
1324 subgraph_contractions[j]
1325 for j in sorted(
1326 range(len(subgraph_contractions_size)),
1327 key=subgraph_contractions_size.__getitem__,
1328 )
1329 ]
1331 # build the final contraction tree
1332 tree = simple_tree_tuple(subgraph_contractions)
1333 return _tree_to_sequence(tree)
1336def dynamic_programming(
1337 inputs: List[ArrayIndexType],
1338 output: ArrayIndexType,
1339 size_dict: Dict[str, int],
1340 memory_limit: Optional[int] = None,
1341 **kwargs: Any,
1342) -> PathType:
1343 optimizer = DynamicProgramming(**kwargs)
1344 return optimizer(inputs, output, size_dict, memory_limit)
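Usage mirrors the other optimizers: pass the registered string (`'dp'` or `'dynamic-programming'`, see `_PATH_OPTIONS` below) or a configured instance. A sketch with placeholder arrays:

```python
import numpy as np
import opt_einsum as oe

arrays = [np.random.rand(4, 4) for _ in range(6)]
eq = "ab,bc,cd,de,ef,fg->ag"

path, info = oe.contract_path(eq, *arrays, optimize="dp")
path, info = oe.contract_path(eq, *arrays, optimize=DynamicProgramming(minimize="combo-64"))
```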
1347_AUTO_CHOICES = {}
1348for i in range(1, 5):
1349 _AUTO_CHOICES[i] = optimal
1350for i in range(5, 7):
1351 _AUTO_CHOICES[i] = branch_all
1352for i in range(7, 9):
1353 _AUTO_CHOICES[i] = branch_2
1354for i in range(9, 15):
1355 _AUTO_CHOICES[i] = branch_1
1358def auto(
1359 inputs: List[ArrayIndexType],
1360 output: ArrayIndexType,
1361 size_dict: Dict[str, int],
1362 memory_limit: Optional[int] = None,
1363) -> PathType:
1364 """Finds the contraction path by automatically choosing the method based on
1365 how many input arguments there are.
1366 """
1367 N = len(inputs)
1368 return _AUTO_CHOICES.get(N, greedy)(inputs, output, size_dict, memory_limit)
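The effect of the table above: up to four inputs get the exact `optimal` search, mid-sized problems get progressively cheaper branch searches, and contractions with fifteen or more inputs fall through to `greedy`:

```python
_AUTO_CHOICES.get(3, greedy)   # optimal
_AUTO_CHOICES.get(8, greedy)   # branch_2
_AUTO_CHOICES.get(20, greedy)  # greedy (the fallback)
```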
1371_AUTO_HQ_CHOICES = {}
1372for i in range(1, 6):
1373 _AUTO_HQ_CHOICES[i] = optimal
1374for i in range(6, 17):
1375 _AUTO_HQ_CHOICES[i] = dynamic_programming
1378def auto_hq(
1379 inputs: List[ArrayIndexType],
1380 output: ArrayIndexType,
1381 size_dict: Dict[str, int],
1382 memory_limit: Optional[int] = None,
1383) -> PathType:
1384 """Finds the contraction path by automatically choosing the method based on
1385 how many input arguments there are, but targeting a more generous
1386 amount of search time than ``'auto'``.
1387 """
1388 from .path_random import random_greedy_128
1390 N = len(inputs)
1391 return _AUTO_HQ_CHOICES.get(N, random_greedy_128)(inputs, output, size_dict, memory_limit)
1394PathSearchFunctionType = Callable[[List[ArrayIndexType], ArrayIndexType, Dict[str, int], Optional[int]], PathType]
1395_PATH_OPTIONS: Dict[str, PathSearchFunctionType] = {
1396 "auto": auto,
1397 "auto-hq": auto_hq,
1398 "optimal": optimal,
1399 "branch-all": branch_all,
1400 "branch-2": branch_2,
1401 "branch-1": branch_1,
1402 "greedy": greedy,
1403 "eager": greedy,
1404 "opportunistic": greedy,
1405 "dp": dynamic_programming,
1406 "dynamic-programming": dynamic_programming,
1407}
1410def register_path_fn(name: str, fn: PathSearchFunctionType) -> None:
1411 """Add path finding function ``fn`` as an option with ``name``."""
1412 if name in _PATH_OPTIONS:
1413 raise KeyError("Path optimizer '{}' already exists.".format(name))
1415 _PATH_OPTIONS[name.lower()] = fn
1418def get_path_fn(path_type: str) -> PathSearchFunctionType:
1419 """Get the correct path finding function from str ``path_type``."""
1420 path_type = path_type.lower()
1421 if path_type not in _PATH_OPTIONS:
1422 raise KeyError(
1423 "Path optimizer '{}' not found, valid options are {}.".format(path_type, set(_PATH_OPTIONS.keys()))
1424 )
1426 return _PATH_OPTIONS[path_type]
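A sketch of registering a custom path function (the name and function here are hypothetical); once registered it can be requested by string like the built-in options:

```python
def small_exact(inputs, output, size_dict, memory_limit=None):
    """Hypothetical path function: exact search for small cases, greedy otherwise."""
    if len(inputs) <= 6:
        return optimal(inputs, output, size_dict, memory_limit)
    return greedy(inputs, output, size_dict, memory_limit)

register_path_fn("small-exact", small_exact)
assert get_path_fn("small-exact") is small_exact
```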