Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/paths.py: 16%
440 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""
2Contains the path technology behind opt_einsum in addition to several path helpers
3"""
5import functools
6import heapq
7import itertools
8import random
9from collections import Counter, OrderedDict, defaultdict
11import numpy as np
13from . import helpers
# names exported as the public API of this module
__all__ = [
    "optimal", "BranchBound", "branch", "greedy", "auto", "auto_hq", "get_path_fn", "DynamicProgramming",
    "dynamic_programming"
]

# sentinel values of ``memory_limit`` that all mean "no limit at all"
_UNLIMITED_MEM = {-1, None, float('inf')}
class PathOptimizer(object):
    """Base class for different path optimizers to inherit from.

    Subclassed optimizers should define a call method with signature::

        def __call__(self, inputs, output, size_dict, memory_limit=None):
            \"\"\"
            Parameters
            ----------
            inputs : list[set[str]]
                The indices of each input array.
            outputs : set[str]
                The output indices
            size_dict : dict[str, int]
                The size of each index
            memory_limit : int, optional
                If given, the maximum allowed memory.
            \"\"\"
            # ... compute path here ...
            return path

    where ``path`` is a list of int-tuples specifying a contraction order.
    """

    def _check_args_against_first_call(self, inputs, output, size_dict):
        """Utility that stateful optimizers can use to ensure they are not
        called with different contractions across separate runs.

        Raises
        ------
        ValueError
            If this instance was previously called with a different
            ``(inputs, output, size_dict)`` triple.
        """
        args = (inputs, output, size_dict)
        if not hasattr(self, '_first_call_args'):
            # simply set the attribute as currently there is no global PathOptimizer init
            self._first_call_args = args
        elif args != self._first_call_args:
            # fixed typo in user-facing message: 'specifiying' -> 'specifying'
            raise ValueError("The arguments specifying the contraction that this path optimizer "
                             "instance was called with have changed - try creating a new instance.")

    def __call__(self, inputs, output, size_dict, memory_limit=None):
        # subclasses must implement the actual path computation
        raise NotImplementedError
def ssa_to_linear(ssa_path):
    """
    Convert a path with static single assignment ids to a path with recycled
    linear ids. For example::

        >>> ssa_to_linear([(0, 3), (2, 4), (1, 5)])
        [(0, 3), (1, 2), (0, 1)]
    """
    n = 1 + max(map(max, ssa_path))
    # id_map[ssa_id] -> current linear position of that tensor
    id_map = list(range(n))
    linear_path = []
    for pair in ssa_path:
        linear_path.append(tuple(id_map[s] for s in pair))
        # every tensor after a contracted one shifts down a slot
        for s in pair:
            for k in range(s, n):
                id_map[k] -= 1
    return linear_path
def linear_to_ssa(path):
    """
    Convert a path with recycled linear ids to a path with static single
    assignment ids. For example::

        >>> linear_to_ssa([(0, 3), (1, 2), (0, 1)])
        [(0, 3), (2, 4), (1, 5)]
    """
    # each contraction of k tensors removes k-1, so infer the input count
    num_inputs = sum(map(len, path)) - len(path) + 1
    # alive[linear_pos] -> current ssa id at that position
    alive = list(range(num_inputs))
    fresh_ids = itertools.count(num_inputs)
    ssa_path = []
    for contraction in path:
        ssa_path.append(tuple(alive[pos] for pos in contraction))
        # drop contracted positions (back to front so indices stay valid)
        for pos in sorted(contraction, reverse=True):
            del alive[pos]
        alive.append(next(fresh_ids))
    return ssa_path
def calc_k12_flops(inputs, output, remaining, i, j, size_dict):
    """
    Calculate the resulting indices and flops for a potential pairwise
    contraction - used in the recursive (optimal/branch) algorithms.

    Parameters
    ----------
    inputs : tuple[frozenset[str]]
        The indices of each tensor in this contraction, note this includes
        tensors unavailable to contract as static single assignment is used ->
        contracted tensors are not removed from the list.
    output : frozenset[str]
        The set of output indices for the whole contraction.
    remaining : frozenset[int]
        The set of indices (corresponding to ``inputs``) of tensors still
        available to contract.
    i : int
        Index of potential tensor to contract.
    j : int
        Index of potential tensor to contract.
    size_dict : dict[str, int]
        Size mapping of all the indices.

    Returns
    -------
    k12 : frozenset
        The resulting indices of the potential tensor.
    cost : int
        Estimated flop count of operation.
    """
    k1, k2 = inputs[i], inputs[j]
    union = k1 | k2
    shared = k1 & k2
    # indices appearing on the output or any other still-active tensor survive
    keep = frozenset.union(output, *map(inputs.__getitem__, remaining - {i, j}))

    k12 = union & keep
    cost = helpers.flop_count(union, shared - keep, 2, size_dict)

    return k12, cost
def _compute_oversize_flops(inputs, remaining, output, size_dict):
    """
    Compute the flop count for a contraction of all remaining arguments. This
    is used when a memory limit means that no pairwise contractions can be made.
    """
    involved = frozenset.union(*map(inputs.__getitem__, remaining))
    summed = involved - output
    return helpers.flop_count(involved, summed, len(remaining), size_dict)
def optimal(inputs, output, size_dict, memory_limit=None):
    """
    Computes all possible pair contractions in a depth-first recursive manner,
    sieving results based on ``memory_limit`` and the best path found so far.
    Returns the lowest cost path. This algorithm scales factorially with
    respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    inputs : list
        List of sets that represent the lhs side of the einsum subscript.
    output : set
        Set that represents the rhs side of the overall einsum subscript.
    size_dict : dictionary
        Dictionary of index sizes.
    memory_limit : int
        The maximum number of elements in a temporary array.

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('')
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> optimal(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    inputs = tuple(map(frozenset, inputs))
    output = frozenset(output)

    # initial fallback: contract every term in a single einsum call
    best = {'flops': float('inf'), 'ssa_path': (tuple(range(len(inputs))), )}
    # memoize intermediate sizes and pairwise contraction results across branches
    size_cache = {}
    result_cache = {}

    def _optimal_iterate(path, remaining, inputs, flops):

        # reached end of path (only ever get here if flops is best found so far)
        if len(remaining) == 1:
            best['flops'] = flops
            best['ssa_path'] = path
            return

        # check all possible remaining paths
        for i, j in itertools.combinations(remaining, 2):
            if i > j:
                i, j = j, i
            key = (inputs[i], inputs[j])
            try:
                k12, flops12 = result_cache[key]
            except KeyError:
                k12, flops12 = result_cache[key] = calc_k12_flops(inputs, output, remaining, i, j, size_dict)

            # sieve based on current best flops
            new_flops = flops + flops12
            if new_flops >= best['flops']:
                continue

            # sieve based on memory limit
            if memory_limit not in _UNLIMITED_MEM:
                try:
                    size12 = size_cache[k12]
                except KeyError:
                    size12 = size_cache[k12] = helpers.compute_size_by_dict(k12, size_dict)

                # possibly terminate this path with an all-terms einsum
                if size12 > memory_limit:
                    new_flops = flops + _compute_oversize_flops(inputs, remaining, output, size_dict)
                    if new_flops < best['flops']:
                        best['flops'] = new_flops
                        best['ssa_path'] = path + (tuple(remaining), )
                    continue

            # add contraction and recurse into all remaining
            # (the new intermediate gets the next ssa id, i.e. len(inputs))
            _optimal_iterate(path=path + ((i, j), ),
                             inputs=inputs + (k12, ),
                             remaining=remaining - {i, j} | {len(inputs)},
                             flops=new_flops)

    _optimal_iterate(path=(), inputs=inputs, remaining=set(range(len(inputs))), flops=0)

    return ssa_to_linear(best['ssa_path'])
239# functions for comparing which of two paths is 'better'
def better_flops_first(flops, size, best_flops, best_size):
    """Return whether ``(flops, size)`` beats the current best, comparing the
    flop count first and tie-breaking on intermediate size.
    """
    if flops == best_flops:
        return size < best_size
    return flops < best_flops
def better_size_first(flops, size, best_flops, best_size):
    """Return whether ``(size, flops)`` beats the current best, comparing the
    largest intermediate size first and tie-breaking on flop count.
    """
    if size == best_size:
        return flops < best_flops
    return size < best_size
# dispatch table mapping the ``minimize`` option to its comparison function
_BETTER_FNS = {
    'flops': better_flops_first,
    'size': better_size_first,
}
def get_better_fn(key):
    """Return the path comparison function for ``key`` ('flops' or 'size')."""
    return _BETTER_FNS[key]
260# functions for assigning a heuristic 'cost' to a potential contraction
def cost_memory_removed(size12, size1, size2, k12, k1, k2):
    """The default heuristic cost, corresponding to the total reduction in
    memory of performing a contraction (negative when memory is freed).
    """
    return size12 - (size1 + size2)
def cost_memory_removed_jitter(size12, size1, size2, k12, k1, k2):
    """Like memory-removed, but with a slight amount of noise that breaks ties
    and thus jumbles the contractions a bit.
    """
    removed = size12 - size1 - size2
    # ~1% gaussian noise around the true memory-removed value
    return random.gauss(1.0, 0.01) * removed
# dispatch table of named heuristic cost functions for candidate sorting
_COST_FNS = {
    'memory-removed': cost_memory_removed,
    'memory-removed-jitter': cost_memory_removed_jitter,
}
class BranchBound(PathOptimizer):
    """
    Explores possible pair contractions in a depth-first recursive manner like
    the ``optimal`` approach, but with extra heuristic early pruning of branches
    as well as sieving by ``memory_limit`` and the best path found so far. Returns
    the lowest cost path. This algorithm still scales factorially with respect
    to the elements in the list ``input_sets`` if ``nbranch`` is not set, but it
    scales exponentially like ``nbranch**len(input_sets)`` otherwise.

    Parameters
    ----------
    nbranch : None or int, optional
        How many branches to explore at each contraction step. If None, explore
        all possible branches. If an integer, branch into this many paths at
        each step. Defaults to None.
    cutoff_flops_factor : float, optional
        If at any point, a path is doing this much worse than the best path
        found so far was, terminate it. The larger this is made, the more paths
        will be fully explored and the slower the algorithm. Defaults to 4.
    minimize : {'flops', 'size'}, optional
        Whether to optimize the path with regard primarily to the total
        estimated flop-count, or the size of the largest intermediate. The
        option not chosen will still be used as a secondary criterion.
    cost_fn : callable, optional
        A function that returns a heuristic 'cost' of a potential contraction
        with which to sort candidates. Should have signature
        ``cost_fn(size12, size1, size2, k12, k1, k2)``.
    """
    def __init__(self, nbranch=None, cutoff_flops_factor=4, minimize='flops', cost_fn='memory-removed'):
        self.nbranch = nbranch
        self.cutoff_flops_factor = cutoff_flops_factor
        self.minimize = minimize
        # ``cost_fn`` may be a registered name or a custom callable
        self.cost_fn = _COST_FNS.get(cost_fn, cost_fn)

        self.better = get_better_fn(minimize)
        # best complete path found so far (shared across the recursion)
        self.best = {'flops': float('inf'), 'size': float('inf')}
        # best flops seen at each recursion depth, keyed by len(inputs)
        self.best_progress = defaultdict(lambda: float('inf'))

    @property
    def path(self):
        # best path found so far, converted from ssa ids to linear ids
        return ssa_to_linear(self.best['ssa_path'])

    def __call__(self, inputs, output, size_dict, memory_limit=None):
        """

        Parameters
        ----------
        input_sets : list
            List of sets that represent the lhs side of the einsum subscript
        output_set : set
            Set that represents the rhs side of the overall einsum subscript
        idx_dict : dictionary
            Dictionary of index sizes
        memory_limit : int
            The maximum number of elements in a temporary array

        Returns
        -------
        path : list
            The contraction order within the memory limit constraint.

        Examples
        --------
        >>> isets = [set('abd'), set('ac'), set('bdc')]
        >>> oset = set('')
        >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
        >>> optimal(isets, oset, idx_sizes, 5000)
        [(0, 2), (0, 1)]
        """
        # stateful optimizer: refuse to be reused on a different contraction
        self._check_args_against_first_call(inputs, output, size_dict)

        inputs = tuple(map(frozenset, inputs))
        output = frozenset(output)

        # memoize intermediate sizes and pairwise results across branches
        size_cache = {k: helpers.compute_size_by_dict(k, size_dict) for k in inputs}
        result_cache = {}

        def _branch_iterate(path, inputs, remaining, flops, size):

            # reached end of path (only ever get here if flops is best found so far)
            if len(remaining) == 1:
                self.best['size'] = size
                self.best['flops'] = flops
                self.best['ssa_path'] = path
                return

            def _assess_candidate(k1, k2, i, j):
                # find resulting indices and flops
                try:
                    k12, flops12 = result_cache[k1, k2]
                except KeyError:
                    k12, flops12 = result_cache[k1, k2] = calc_k12_flops(inputs, output, remaining, i, j, size_dict)

                try:
                    size12 = size_cache[k12]
                except KeyError:
                    size12 = size_cache[k12] = helpers.compute_size_by_dict(k12, size_dict)

                new_flops = flops + flops12
                new_size = max(size, size12)

                # sieve based on current best i.e. check flops and size still better
                if not self.better(new_flops, new_size, self.best['flops'], self.best['size']):
                    return None

                # compare to how the best method was doing at this point
                if new_flops < self.best_progress[len(inputs)]:
                    self.best_progress[len(inputs)] = new_flops
                # sieve based on current progress relative to best
                elif new_flops > self.cutoff_flops_factor * self.best_progress[len(inputs)]:
                    return None

                # sieve based on memory limit
                if (memory_limit not in _UNLIMITED_MEM) and (size12 > memory_limit):
                    # terminate path here, but check all-terms contract first
                    new_flops = flops + _compute_oversize_flops(inputs, remaining, output, size_dict)
                    if new_flops < self.best['flops']:
                        self.best['flops'] = new_flops
                        self.best['ssa_path'] = path + (tuple(remaining), )
                    return None

                # set cost heuristic in order to locally sort possible contractions
                size1, size2 = size_cache[inputs[i]], size_cache[inputs[j]]
                cost = self.cost_fn(size12, size1, size2, k12, k1, k2)

                return cost, flops12, new_flops, new_size, (i, j), k12

            # check all possible remaining paths
            candidates = []
            for i, j in itertools.combinations(remaining, 2):
                if i > j:
                    i, j = j, i
                k1, k2 = inputs[i], inputs[j]

                # initially ignore outer products
                if k1.isdisjoint(k2):
                    continue

                candidate = _assess_candidate(k1, k2, i, j)
                if candidate:
                    heapq.heappush(candidates, candidate)

            # assess outer products if nothing left
            if not candidates:
                for i, j in itertools.combinations(remaining, 2):
                    if i > j:
                        i, j = j, i
                    k1, k2 = inputs[i], inputs[j]
                    candidate = _assess_candidate(k1, k2, i, j)
                    if candidate:
                        heapq.heappush(candidates, candidate)

            # recurse into all or some of the best candidate contractions
            bi = 0
            while (self.nbranch is None or bi < self.nbranch) and candidates:
                _, _, new_flops, new_size, (i, j), k12 = heapq.heappop(candidates)
                _branch_iterate(path=path + ((i, j), ),
                                inputs=inputs + (k12, ),
                                remaining=(remaining - {i, j}) | {len(inputs)},
                                flops=new_flops,
                                size=new_size)
                bi += 1

        _branch_iterate(path=(), inputs=inputs, remaining=set(range(len(inputs))), flops=0, size=0)

        return self.path
def branch(inputs, output, size_dict, memory_limit=None, **optimizer_kwargs):
    """Find a contraction path using a fresh :class:`BranchBound` instance
    configured with ``optimizer_kwargs``.
    """
    return BranchBound(**optimizer_kwargs)(inputs, output, size_dict, memory_limit)
# common branch-and-bound variants: explore every branch, or only the best 2 / 1
branch_all = functools.partial(branch, nbranch=None)
branch_2 = functools.partial(branch, nbranch=2)
branch_1 = functools.partial(branch, nbranch=1)
def _get_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2, cost_fn):
    """Build a single heap entry ``(cost, k1, k2, k12)`` for the potential
    contraction of ``k1`` with ``k2``.
    """
    union = k1 | k2
    shared = k1 & k2
    exclusive = union - shared
    # keep: output dims, shared dims still referenced >=3 times, and
    # exclusive dims still referenced >=2 times
    k12 = (union & output) | (shared & dim_ref_counts[3]) | (exclusive & dim_ref_counts[2])
    base_cost = cost_fn(helpers.compute_size_by_dict(k12, sizes), footprints[k1], footprints[k2], k12, k1, k2)
    id1, id2 = remaining[k1], remaining[k2]
    if id1 > id2:
        k1, k2 = k2, k1
        id1, id2 = id2, id1
    # append the ssa ids so equal-cost candidates compare deterministically
    return (base_cost, id2, id1), k1, k2, k12
def _push_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2s, queue, push_all, cost_fn):
    """Generate the candidate contraction of ``k1`` with each key in ``k2s``
    and push either all of them or only the locally cheapest onto ``queue``.
    """
    candidates = [_get_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2, cost_fn) for k2 in k2s]
    if not push_all:
        # default chooser only ever needs the minimum cost option
        candidates = [min(candidates)]
    # push_all is used when e.g. a custom 'choose_fn' wants every option
    for candidate in candidates:
        heapq.heappush(queue, candidate)
485def _update_ref_counts(dim_to_keys, dim_ref_counts, dims):
486 for dim in dims:
487 count = len(dim_to_keys[dim])
488 if count <= 1:
489 dim_ref_counts[2].discard(dim)
490 dim_ref_counts[3].discard(dim)
491 elif count == 2:
492 dim_ref_counts[2].add(dim)
493 dim_ref_counts[3].discard(dim)
494 else:
495 dim_ref_counts[2].add(dim)
496 dim_ref_counts[3].add(dim)
499def _simple_chooser(queue, remaining):
500 """Default contraction chooser that simply takes the minimum cost option.
501 """
502 cost, k1, k2, k12 = heapq.heappop(queue)
503 if k1 not in remaining or k2 not in remaining:
504 return None # candidate is obsolete
505 return cost, k1, k2, k12
def ssa_greedy_optimize(inputs, output, sizes, choose_fn=None, cost_fn='memory-removed'):
    """
    This is the core function for :func:`greedy` but produces a path with
    static single assignment ids rather than recycled linear ids.
    SSA ids are cheaper to work with and easier to reason about.
    """
    if len(inputs) == 1:
        # Perform a single contraction to match output shape.
        return [(0, )]

    # set the function that assigns a heuristic cost to a possible contraction
    cost_fn = _COST_FNS.get(cost_fn, cost_fn)

    # set the function that chooses which contraction to take
    if choose_fn is None:
        choose_fn = _simple_chooser
        push_all = False
    else:
        # assume chooser wants access to all possible contractions
        push_all = True

    # A dim that is common to all tensors might as well be an output dim, since it
    # cannot be contracted until the final step. This avoids an expensive all-pairs
    # comparison to search for possible contractions at each step, leading to speedup
    # in many practical problems where all tensors share a common batch dimension.
    inputs = list(map(frozenset, inputs))
    output = frozenset(output) | frozenset.intersection(*inputs)

    # Deduplicate shapes by eagerly computing Hadamard products.
    remaining = {}  # key -> ssa_id
    ssa_ids = itertools.count(len(inputs))
    ssa_path = []
    for ssa_id, key in enumerate(inputs):
        if key in remaining:
            # identical index set seen before: contract the pair element-wise now
            ssa_path.append((remaining[key], ssa_id))
            remaining[key] = next(ssa_ids)
        else:
            remaining[key] = ssa_id

    # Keep track of possible contraction dims.
    dim_to_keys = defaultdict(set)
    for key in remaining:
        for dim in key - output:
            dim_to_keys[dim].add(key)

    # Keep track of the number of tensors using each dim; when the dim is no longer
    # used it can be contracted. Since we specialize to binary ops, we only care about
    # ref counts of >=2 or >=3.
    dim_ref_counts = {
        count: set(dim for dim, keys in dim_to_keys.items() if len(keys) >= count) - output
        for count in [2, 3]
    }

    # Compute separable part of the objective function for contractions.
    footprints = {key: helpers.compute_size_by_dict(key, sizes) for key in remaining}

    # Find initial candidate contractions.
    queue = []
    for dim, keys in dim_to_keys.items():
        # sort by ssa id so candidate generation is deterministic
        keys = sorted(keys, key=remaining.__getitem__)
        for i, k1 in enumerate(keys[:-1]):
            k2s = keys[1 + i:]
            _push_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2s, queue, push_all, cost_fn)

    # Greedily contract pairs of tensors.
    while queue:

        con = choose_fn(queue, remaining)
        if con is None:
            continue  # allow choose_fn to flag all candidates obsolete
        cost, k1, k2, k12 = con

        ssa_id1 = remaining.pop(k1)
        ssa_id2 = remaining.pop(k2)
        # the contracted operands no longer offer their dims for contraction
        for dim in k1 - output:
            dim_to_keys[dim].remove(k1)
        for dim in k2 - output:
            dim_to_keys[dim].remove(k2)
        ssa_path.append((ssa_id1, ssa_id2))
        if k12 in remaining:
            # the result duplicates an existing shape - Hadamard-contract them
            ssa_path.append((remaining[k12], next(ssa_ids)))
        else:
            for dim in k12 - output:
                dim_to_keys[dim].add(k12)
        remaining[k12] = next(ssa_ids)
        # NOTE: parses as ``k1 | (k2 - output)``; output dims of k1 have no
        # dim_to_keys entries so the net effect matches ``(k1 | k2) - output``
        _update_ref_counts(dim_to_keys, dim_ref_counts, k1 | k2 - output)
        footprints[k12] = helpers.compute_size_by_dict(k12, sizes)

        # Find new candidate contractions.
        k1 = k12
        k2s = set(k2 for dim in k1 for k2 in dim_to_keys[dim])
        k2s.discard(k1)
        if k2s:
            _push_candidate(output, sizes, remaining, footprints, dim_ref_counts, k1, k2s, queue, push_all, cost_fn)

    # Greedily compute pairwise outer products.
    queue = [(helpers.compute_size_by_dict(key & output, sizes), ssa_id, key) for key, ssa_id in remaining.items()]
    heapq.heapify(queue)
    _, ssa_id1, k1 = heapq.heappop(queue)
    while queue:
        _, ssa_id2, k2 = heapq.heappop(queue)
        ssa_path.append((min(ssa_id1, ssa_id2), max(ssa_id1, ssa_id2)))
        k12 = (k1 | k2) & output
        cost = helpers.compute_size_by_dict(k12, sizes)
        ssa_id12 = next(ssa_ids)
        # push the product back so the smallest intermediates combine first
        _, ssa_id1, k1 = heapq.heappushpop(queue, (cost, ssa_id12, k12))

    return ssa_path
def greedy(inputs, output, size_dict, memory_limit=None, choose_fn=None, cost_fn='memory-removed'):
    """
    Finds the path by a three stage algorithm:

    1. Eagerly compute Hadamard products.
    2. Greedily compute contractions to maximize ``removed_size``
    3. Greedily compute outer products.

    This algorithm scales quadratically with respect to the
    maximum number of elements sharing a common dim.

    Parameters
    ----------
    inputs : list
        List of sets that represent the lhs side of the einsum subscript
    output : set
        Set that represents the rhs side of the overall einsum subscript
    size_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array
    choose_fn : callable, optional
        A function that chooses which contraction to perform from the queue
    cost_fn : callable, optional
        A function that assigns a potential contraction a cost.

    Returns
    -------
    path : list
        The contraction order (a list of tuples of ints).

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('')
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> greedy(isets, oset, idx_sizes)
    [(0, 2), (0, 1)]
    """
    if memory_limit in _UNLIMITED_MEM:
        ssa_path = ssa_greedy_optimize(inputs, output, size_dict, cost_fn=cost_fn, choose_fn=choose_fn)
        return ssa_to_linear(ssa_path)

    # a finite memory limit requires sieving by intermediate size, which the
    # pure greedy algorithm cannot do - delegate to single-branch branch & bound
    return branch(inputs, output, size_dict, memory_limit, nbranch=1, cost_fn=cost_fn)
664def _tree_to_sequence(c):
665 """
666 Converts a contraction tree to a contraction path as it has to be
667 returned by path optimizers. A contraction tree can either be an int
668 (=no contraction) or a tuple containing the terms to be contracted. An
669 arbitrary number (>= 1) of terms can be contracted at once. Note that
670 contractions are commutative, e.g. (j, k, l) = (k, l, j). Note that in
671 general, solutions are not unique.
673 Parameters
674 ----------
675 c : tuple or int
676 Contraction tree
678 Returns
679 -------
680 path : list[set[int]]
681 Contraction path
683 Examples
684 --------
685 >>> _tree_to_sequence(((1,2),(0,(4,5,3))))
686 [(1, 2), (1, 2, 3), (0, 2), (0, 1)]
687 """
689 # ((1,2),(0,(4,5,3))) --> [(1, 2), (1, 2, 3), (0, 2), (0, 1)]
690 #
691 # 0 0 0 (1,2) --> ((1,2),(0,(3,4,5)))
692 # 1 3 (1,2) --> (0,(3,4,5))
693 # 2 --> 4 --> (3,4,5)
694 # 3 5
695 # 4 (1,2)
696 # 5
697 #
698 # this function iterates through the table shown above from right to left;
700 if type(c) == int:
701 return []
703 c = [c] # list of remaining contractions (lower part of columns shown above)
704 t = [] # list of elementary tensors (upper part of colums)
705 s = [] # resulting contraction sequence
707 while len(c) > 0:
708 j = c.pop(-1)
709 s.insert(0, tuple())
711 for i in sorted([i for i in j if type(i) == int]):
712 s[0] += (sum(1 for q in t if q < i), )
713 t.insert(s[0][-1], i)
715 for i in [i for i in j if type(i) != int]:
716 s[0] += (len(t) + len(c), )
717 c.append(i)
719 return s
722def _find_disconnected_subgraphs(inputs, output):
723 """
724 Finds disconnected subgraphs in the given list of inputs. Inputs are
725 connected if they share summation indices. Note: Disconnected subgraphs
726 can be contracted independently before forming outer products.
728 Parameters
729 ----------
730 inputs : list[set]
731 List of sets that represent the lhs side of the einsum subscript
732 output : set
733 Set that represents the rhs side of the overall einsum subscript
735 Returns
736 -------
737 subgraphs : list[set[int]]
738 List containing sets of indices for each subgraph
740 Examples
741 --------
742 >>> _find_disconnected_subgraphs([set("ab"), set("c"), set("ad")], set("bd"))
743 [{0, 2}, {1}]
745 >>> _find_disconnected_subgraphs([set("ab"), set("c"), set("ad")], set("abd"))
746 [{0}, {1}, {2}]
747 """
749 subgraphs = []
750 unused_inputs = set(range(len(inputs)))
752 i_sum = set.union(*inputs) - output # all summation indices
754 while len(unused_inputs) > 0:
755 g = set()
756 q = [unused_inputs.pop()]
757 while len(q) > 0:
758 j = q.pop()
759 g.add(j)
760 i_tmp = i_sum & inputs[j]
761 n = {k for k in unused_inputs if len(i_tmp & inputs[k]) > 0}
762 q.extend(n)
763 unused_inputs.difference_update(n)
765 subgraphs.append(g)
767 return subgraphs
770def _bitmap_select(s, seq):
771 """Select elements of ``seq`` which are marked by the bitmap set ``s``.
773 E.g.:
775 >>> list(_bitmap_select(0b11010, ['A', 'B', 'C', 'D', 'E']))
776 ['B', 'D', 'E']
777 """
778 return (x for x, b in zip(seq, bin(s)[:1:-1]) if b == '1')
def _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2):
    """Calculates the effective outer indices of the intermediate tensor
    corresponding to the subgraph ``s``.
    """
    # tensors of the subgraph not inside ``s`` (bitmap form of g - s)
    rest = g & (all_tensors ^ s)
    # indices still referenced by those remaining tensors
    if rest:
        i_rest = set.union(*_bitmap_select(rest, inputs))
    else:
        i_rest = set()
    # only indices no remaining tensor needs can actually be contracted away
    return i1_union_i2 - (i1_cut_i2_wo_output - i_rest)
def _dp_compare_flops(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2, xn, g, all_tensors, inputs,
                      i1_cut_i2_wo_output, memory_limit, cntrct1, cntrct2):
    """Performs the inner comparison of whether the two subgraphs (the bitmaps
    ``s1`` and ``s2``) should be merged and added to the dynamic programming
    search. Will skip for a number of reasons:

    1. If the number of operations to form ``s = s1 | s2`` including previous
       contractions is above the cost-cap.
    2. If we've already found a better way of making ``s``.
    3. If the intermediate tensor corresponding to ``s`` is going to break the
       memory limit.
    """
    cost = cost1 + cost2 + helpers.compute_size_by_dict(i1_union_i2, size_dict)
    if cost > cost_cap:
        return
    s = s1 | s2
    if s in xn and xn[s][1] <= cost:
        return
    i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
    mem = helpers.compute_size_by_dict(i, size_dict)
    if memory_limit is not None and mem > memory_limit:
        return
    xn[s] = (i, cost, (cntrct1, cntrct2))
def _dp_compare_size(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2, xn, g, all_tensors, inputs,
                     i1_cut_i2_wo_output, memory_limit, cntrct1, cntrct2):
    """Like ``_dp_compare_flops`` but sieves the potential contraction based
    on the size of the intermediate tensor created, rather than the number of
    operations, and so calculates that first.
    """
    s = s1 | s2
    i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2)
    mem = helpers.compute_size_by_dict(i, size_dict)
    # 'cost' here is the largest intermediate encountered along the way
    cost = max(cost1, cost2, mem)
    if cost > cost_cap:
        return
    if s in xn and xn[s][1] <= cost:
        return
    if memory_limit is not None and mem > memory_limit:
        return
    xn[s] = (i, cost, (cntrct1, cntrct2))
def simple_tree_tuple(seq):
    """Make a simple left to right binary tree out of iterable ``seq``.

    >>> simple_tree_tuple([1, 2, 3, 4])
    (((1, 2), 3), 4)
    """
    # fix: the doctest previously referenced a non-existent name 'tuple_nest'
    return functools.reduce(lambda x, y: (x, y), seq)
845def _dp_parse_out_single_term_ops(inputs, all_inds, ind_counts):
846 """Take ``inputs`` and parse for single term index operations, i.e. where
847 an index appears on one tensor and nowhere else.
849 If a term is completely reduced to a scalar in this way it can be removed
850 to ``inputs_done``. If only some indices can be summed then add a 'single
851 term contraction' that will perform this summation.
852 """
853 i_single = {i for i, c in enumerate(all_inds) if ind_counts[c] == 1}
854 inputs_parsed, inputs_done, inputs_contractions = [], [], []
855 for j, i in enumerate(inputs):
856 i_reduced = i - i_single
857 if not i_reduced:
858 # input reduced to scalar already - remove
859 inputs_done.append((j, ))
860 else:
861 # if the input has any index reductions, add single contraction
862 inputs_parsed.append(i_reduced)
863 inputs_contractions.append((j, ) if i_reduced != i else j)
865 return inputs_parsed, inputs_done, inputs_contractions
868class DynamicProgramming(PathOptimizer):
869 """
870 Finds the optimal path of pairwise contractions without intermediate outer
871 products based a dynamic programming approach presented in
872 Phys. Rev. E 90, 033315 (2014) (the corresponding preprint is publically
873 available at https://arxiv.org/abs/1304.6112). This method is especially
874 well-suited in the area of tensor network states, where it usually
875 outperforms all the other optimization strategies.
877 This algorithm shows exponential scaling with the number of inputs
878 in the worst case scenario (see example below). If the graph to be
879 contracted consists of disconnected subgraphs, the algorithm scales
880 linearly in the number of disconnected subgraphs and only exponentially
881 with the number of inputs per subgraph.
883 Parameters
884 ----------
885 minimize : {'flops', 'size'}, optional
886 Whether to find the contraction that minimizes the number of
887 operations or the size of the largest intermediate tensor.
888 cost_cap : {True, False, int}, optional
889 How to implement cost-capping:
891 * True - iteratively increase the cost-cap
892 * False - implement no cost-cap at all
893 * int - use explicit cost cap
895 search_outer : bool, optional
896 In rare circumstances the optimal contraction may involve an outer
897 product, this option allows searching such contractions but may well
898 slow down the path finding considerably on all but very small graphs.
899 """
900 def __init__(self, minimize='flops', cost_cap=True, search_outer=False):
902 # set whether inner function minimizes against flops or size
903 self.minimize = minimize
904 self._check_contraction = {
905 'flops': _dp_compare_flops,
906 'size': _dp_compare_size,
907 }[self.minimize]
909 # set whether inner function considers outer products
910 self.search_outer = search_outer
911 self._check_outer = {
912 False: lambda x: x,
913 True: lambda x: True,
914 }[self.search_outer]
916 self.cost_cap = cost_cap
    def __call__(self, inputs, output, size_dict, memory_limit=None):
        """Find a contraction path via dynamic programming over subgraphs.

        Parameters
        ----------
        inputs : list
            List of sets that represent the lhs side of the einsum subscript
        output : set
            Set that represents the rhs side of the overall einsum subscript
        size_dict : dictionary
            Dictionary of index sizes
        memory_limit : int
            The maximum number of elements in a temporary array

        Returns
        -------
        path : list
            The contraction order (a list of tuples of ints).

        Examples
        --------
        >>> n_in = 3  # exponential scaling
        >>> n_out = 2  # linear scaling
        >>> s = dict()
        >>> i_all = []
        >>> for _ in range(n_out):
        >>>     i = [set() for _ in range(n_in)]
        >>>     for j in range(n_in):
        >>>         for k in range(j+1, n_in):
        >>>             c = oe.get_symbol(len(s))
        >>>             i[j].add(c)
        >>>             i[k].add(c)
        >>>             s[c] = 2
        >>>     i_all.extend(i)
        >>> o = DynamicProgramming()
        >>> o(i_all, set(), s)
        [(1, 2), (0, 4), (1, 2), (0, 2), (0, 1)]
        """
        # count every index occurrence across inputs and output; the tuple
        # of distinct indices fixes a stable symbol ordering
        ind_counts = Counter(itertools.chain(*inputs, output))
        all_inds = tuple(ind_counts)

        # convert all indices to integers (makes set operations ~10 % faster)
        symbol2int = {c: j for j, c in enumerate(all_inds)}
        inputs = [set(symbol2int[c] for c in i) for i in inputs]
        output = set(symbol2int[c] for c in output)
        # size_dict becomes a plain list indexed by the integer symbol
        size_dict = {symbol2int[c]: v for c, v in size_dict.items() if c in symbol2int}
        size_dict = [size_dict[j] for j in range(len(size_dict))]

        # strip out single-term reductions first; they never benefit from
        # being interleaved with the pairwise search
        inputs, inputs_done, inputs_contractions = _dp_parse_out_single_term_ops(inputs, all_inds, ind_counts)

        if not inputs:
            # nothing left to do after single axis reductions!
            return _tree_to_sequence(simple_tree_tuple(inputs_done))

        # a list of all necessary contraction expressions for each of the
        # disconnected subgraphs and their size
        subgraph_contractions = inputs_done
        subgraph_contractions_size = [1] * len(inputs_done)

        if self.search_outer:
            # optimize everything together if we are considering outer products
            subgraphs = [set(range(len(inputs)))]
        else:
            subgraphs = _find_disconnected_subgraphs(inputs, output)

        # the bitmap set of all tensors is computed as it is needed to
        # compute set differences: s1 - s2 transforms into
        # s1 & (all_tensors ^ s2)
        all_tensors = (1 << len(inputs)) - 1

        for g in subgraphs:

            # dynamic programming approach to compute x[n] for subgraph g;
            # x[n][set of n tensors] = (indices, cost, contraction)
            # the set of n tensors is represented by a bitmap: if bit j is 1,
            # tensor j is in the set, e.g. 0b100101 = {0,2,5}; set unions
            # (intersections) can then be computed by bitwise or (and);
            x = [None] * 2 + [dict() for j in range(len(g) - 1)]
            x[1] = OrderedDict((1 << j, (inputs[j], 0, inputs_contractions[j])) for j in g)

            # convert set of tensors g to a bitmap set:
            g = functools.reduce(lambda x, y: x | y, (1 << j for j in g))

            # try to find contraction with cost <= cost_cap and increase
            # cost_cap successively if no such contraction is found;
            # this is a major performance improvement; start with product of
            # output index dimensions as initial cost_cap
            subgraph_inds = set.union(*_bitmap_select(g, inputs))
            if self.cost_cap is True:
                cost_cap = helpers.compute_size_by_dict(subgraph_inds & output, size_dict)
            elif self.cost_cap is False:
                cost_cap = float('inf')
            else:
                cost_cap = self.cost_cap
            # set the factor to increase the cost by each iteration (ensure > 1)
            cost_increment = max(min(map(size_dict.__getitem__, subgraph_inds)), 2)

            # repeat the full bottom-up pass until some full-subgraph
            # solution lands under the (growing) cost cap
            while len(x[-1]) == 0:
                for n in range(2, len(x[1]) + 1):
                    xn = x[n]

                    # try to combine solutions from x[m] and x[n-m]
                    for m in range(1, n // 2 + 1):
                        for s1, (i1, cost1, cntrct1) in x[m].items():
                            for s2, (i2, cost2, cntrct2) in x[n - m].items():

                                # can only merge if s1 and s2 are disjoint
                                # and avoid e.g. s1={0}, s2={1} and s1={1}, s2={0}
                                if (not s1 & s2) and (m != n - m or s1 < s2):
                                    i1_cut_i2_wo_output = (i1 & i2) - output

                                    # maybe ignore outer products:
                                    if self._check_outer(i1_cut_i2_wo_output):

                                        i1_union_i2 = i1 | i2
                                        # NOTE: _check_contraction mutates xn in
                                        # place when it finds a cheaper candidate
                                        self._check_contraction(cost1, cost2, i1_union_i2, size_dict, cost_cap, s1, s2,
                                                                xn, g, all_tensors, inputs, i1_cut_i2_wo_output,
                                                                memory_limit, cntrct1, cntrct2)

                # increase cost cap for next iteration:
                cost_cap = cost_increment * cost_cap

            # the single entry of x[-1] is the optimal solution for subgraph g
            i, cost, contraction = list(x[-1].values())[0]
            subgraph_contractions.append(contraction)
            subgraph_contractions_size.append(helpers.compute_size_by_dict(i, size_dict))

        # sort the subgraph contractions by the size of the subgraphs in
        # ascending order (will give the cheapest contractions); note that
        # outer products should be performed pairwise (to use BLAS functions)
        subgraph_contractions = [
            subgraph_contractions[j]
            for j in sorted(range(len(subgraph_contractions_size)), key=subgraph_contractions_size.__getitem__)
        ]

        # build the final contraction tree
        tree = simple_tree_tuple(subgraph_contractions)
        return _tree_to_sequence(tree)
def dynamic_programming(inputs, output, size_dict, memory_limit=None, **kwargs):
    """Functional front-end for :class:`DynamicProgramming`.

    Builds a one-shot optimizer from ``kwargs`` and immediately runs it on
    the given contraction, returning the resulting path.
    """
    return DynamicProgramming(**kwargs)(inputs, output, size_dict, memory_limit)
# number-of-inputs -> optimizer used by ``auto``: exact search for tiny
# contractions, progressively cheaper branching as N grows
_AUTO_CHOICES = {
    n: fn
    for fn, sizes in (
        (optimal, range(1, 5)),
        (branch_all, range(5, 7)),
        (branch_2, range(7, 9)),
        (branch_1, range(9, 15)),
    )
    for n in sizes
}
def auto(inputs, output, size_dict, memory_limit=None):
    """Finds the contraction path by automatically choosing the method based on
    how many input arguments there are.
    """
    # fall back to greedy for any contraction wider than the lookup table
    chosen_fn = _AUTO_CHOICES.get(len(inputs), greedy)
    return chosen_fn(inputs, output, size_dict, memory_limit)
# number-of-inputs -> optimizer used by ``auto_hq``; spends more search
# time than the plain ``auto`` table
_AUTO_HQ_CHOICES = {
    n: fn
    for fn, sizes in (
        (optimal, range(1, 6)),
        (dynamic_programming, range(6, 17)),
    )
    for n in sizes
}
def auto_hq(inputs, output, size_dict, memory_limit=None):
    """Finds the contraction path by automatically choosing the method based on
    how many input arguments there are, but targeting a more generous
    amount of search time than ``'auto'``.
    """
    # local import — presumably deferred to avoid a circular import at
    # module load; confirm against path_random's own imports
    from .path_random import random_greedy_128

    # wide contractions fall back to the randomized greedy search
    optimizer = _AUTO_HQ_CHOICES.get(len(inputs), random_greedy_128)
    return optimizer(inputs, output, size_dict, memory_limit)
# registry of named path-finding functions: ``get_path_fn`` resolves string
# names here and ``register_path_fn`` adds new entries; note several aliases
# ('eager', 'opportunistic') map to the same greedy implementation
_PATH_OPTIONS = {
    'auto': auto,
    'auto-hq': auto_hq,
    'optimal': optimal,
    'branch-all': branch_all,
    'branch-2': branch_2,
    'branch-1': branch_1,
    'greedy': greedy,
    'eager': greedy,
    'opportunistic': greedy,
    'dp': dynamic_programming,
    'dynamic-programming': dynamic_programming
}
def register_path_fn(name, fn):
    """Add path finding function ``fn`` as an option with ``name``.

    Names are stored lowercased. The duplicate check is performed on the
    lowercased name as well — previously ``register_path_fn('Greedy', fn)``
    passed the check against the original-case name and then silently
    overwrote the existing ``'greedy'`` entry.

    Raises
    ------
    KeyError
        If an optimizer is already registered under (the lowercased) ``name``.
    """
    name = name.lower()
    if name in _PATH_OPTIONS:
        raise KeyError("Path optimizer '{}' already exists.".format(name))

    _PATH_OPTIONS[name] = fn
def get_path_fn(path_type):
    """Get the correct path finding function from str ``path_type``.
    """
    # guard clause: reject unknown names with the full list of valid options
    if path_type in _PATH_OPTIONS:
        return _PATH_OPTIONS[path_type]

    raise KeyError("Path optimizer '{}' not found, valid options are {}.".format(
        path_type, set(_PATH_OPTIONS.keys())))