1# util/topological.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""Topological sorting algorithms."""
9
10from __future__ import annotations
11
12from typing import Any
13from typing import Collection
14from typing import DefaultDict
15from typing import Iterable
16from typing import Iterator
17from typing import Sequence
18from typing import Set
19from typing import Tuple
20from typing import TypeVar
21
22from .. import util
23from ..exc import CircularDependencyError
24
25_T = TypeVar("_T", bound=Any)
26
27__all__ = ["sort", "sort_as_subsets", "find_cycles"]
28
29
def sort_as_subsets(
    tuples: Collection[Tuple[_T, _T]], allitems: Collection[_T]
) -> Iterator[Sequence[_T]]:
    """Yield successive batches of items whose parents were already yielded.

    Each element of ``tuples`` is unpacked as ``(parent, child)``; a child
    is held back until none of its parents is still pending.  Within a
    batch, items keep the relative order they have in ``allitems``.

    :param tuples: ``(parent, child)`` pairs describing the dependencies.
    :param allitems: the items to be sorted.  Parents named in ``tuples``
     that are not pending members of ``allitems`` do not block a child.
    :return: an iterator of lists; every list holds the items that became
     ready in the same pass.
    :raises CircularDependencyError: if items remain but a pass finds none
     of them free of pending parents.

    """
    # child -> set of parents that must be emitted before it
    parents_of: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
    for parent, child in tuples:
        parents_of[child].add(parent)

    # the list preserves the caller's ordering; the set gives fast
    # membership tests for "is this parent still pending?"
    ordered = list(allitems)
    remaining = set(allitems)

    while remaining:
        ready = [
            item for item in ordered if remaining.isdisjoint(parents_of[item])
        ]

        if not ready:
            # nothing could be released although items are still pending
            cycles = find_cycles(tuples, allitems)
            edge_pairs = _gen_edges(parents_of)
            raise CircularDependencyError(
                "Circular dependency detected.", cycles, edge_pairs
            )

        # update the bookkeeping before handing the batch to the consumer
        remaining.difference_update(ready)
        ordered = [item for item in ordered if item in remaining]
        yield ready
56
57
def sort(
    tuples: Collection[Tuple[_T, _T]],
    allitems: Collection[_T],
    deterministic_order: bool = True,
) -> Iterator[_T]:
    """Yield the given items one at a time, ordered by dependency.

    This flattens the batches produced by :func:`sort_as_subsets` into a
    single stream of items.

    :param tuples: a collection of 2-tuples representing a partial
     ordering.
    :param allitems: the items to be sorted.
    :param deterministic_order: no longer used; the order is now always
     deterministic given the order of ``allitems``.  The flag remains
     for backwards compatibility with Alembic.

    """

    for batch in sort_as_subsets(tuples, allitems):
        yield from batch
75
76
def find_cycles(
    tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T]
) -> Set[_T]:
    """Collect the nodes that lie on a cycle of the given directed graph.

    :param tuples: ``(parent, child)`` pairs; each one is treated as a
     directed edge from parent to child.
    :param allitems: accepted as part of the signature; the body does not
     read it.
    :return: a set of the nodes found on cycles, empty if none were found.

    """
    # adapted from:
    # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html

    edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
    for parent, child in tuples:
        edges[parent].add(child)

    # only nodes that act as a parent are worth testing: a node that is
    # only ever a child, or that has no edges at all, can't be part of
    # a cycle.
    candidates = set(edges)

    in_cycle: Set[_T] = set()

    # the goal is to find every node involved in a cycle, so a complete
    # walk is started from each candidate in turn.
    for start in candidates:
        path = [start]
        unvisited = candidates.difference(path)
        while path:
            current = path[-1]
            descended = False
            for successor in edges[current]:
                if successor in path:
                    # the walk has come back to a node on the current
                    # path; everything from that node onward is a cycle
                    loop = path[path.index(successor) :]
                    unvisited.difference_update(loop)
                    in_cycle.update(loop)

                if successor in unvisited:
                    path.append(successor)
                    unvisited.remove(successor)
                    descended = True
                    break
            if not descended:
                # no successor left to descend into; backtrack
                path.pop()
    return in_cycle
117
118
119def _gen_edges(edges: DefaultDict[_T, Set[_T]]) -> Set[Tuple[_T, _T]]:
120 return {(right, left) for left in edges for right in edges[left]}