1"""Algorithms for finding the lowest common ancestor of trees and DAGs."""
2
3from collections import defaultdict
4from collections.abc import Mapping, Set
5from itertools import combinations_with_replacement
6
7import networkx as nx
8from networkx.utils import UnionFind, arbitrary_element, not_implemented_for
9
# Public names exported by a star-import of this module; all three functions
# are defined below.
__all__ = [
    "all_pairs_lowest_common_ancestor",
    "tree_all_pairs_lowest_common_ancestor",
    "lowest_common_ancestor",
]
15
16
@not_implemented_for("undirected")
@nx._dispatchable
def all_pairs_lowest_common_ancestor(G, pairs=None):
    """Return the lowest common ancestor of all pairs or the provided pairs

    Parameters
    ----------
    G : NetworkX directed graph

    pairs : iterable of pairs of nodes, optional (default: all pairs)
        The pairs of nodes of interest. Repeated pairs are reported once.
        If None, the LCA of every combination of two nodes of `G`
        (self-pairings included) is computed.

    Yields
    ------
    ((node1, node2), lca) : 2-tuple
        Where lca is least common ancestor of node1 and node2.
        Pairs that have no common ancestor are not yielded at all.
        Note that for the default case, the order of the node pair is not
        considered, e.g. you will not get both ``(a, b)`` and ``(b, a)``

    Raises
    ------
    NetworkXPointlessConcept
        If `G` is null.
    NetworkXError
        If `G` is not a DAG.
    NodeNotFound
        If a node of one of the provided `pairs` is not in `G`.

    Examples
    --------
    >>> from pprint import pprint

    The default behavior is to yield the lowest common ancestor for all
    possible combinations of nodes in `G`, including self-pairings:

    >>> G = nx.DiGraph([(0, 1), (0, 3), (1, 2)])
    >>> pprint(dict(nx.all_pairs_lowest_common_ancestor(G)))
    {(0, 0): 0,
     (0, 1): 0,
     (0, 2): 0,
     (0, 3): 0,
     (1, 1): 1,
     (1, 2): 1,
     (1, 3): 0,
     (2, 2): 2,
     (3, 2): 0,
     (3, 3): 3}

    The pairs argument can be used to limit the output to only the
    specified node pairings:

    >>> dict(nx.all_pairs_lowest_common_ancestor(G, pairs=[(1, 2), (2, 3)]))
    {(1, 2): 1, (2, 3): 0}

    Notes
    -----
    Only defined on non-null directed acyclic graphs.

    See Also
    --------
    lowest_common_ancestor
    """
    if not nx.is_directed_acyclic_graph(G):
        raise nx.NetworkXError("LCA only defined on directed acyclic graphs.")
    if len(G) == 0:
        raise nx.NetworkXPointlessConcept("LCA meaningless on null graphs.")

    if pairs is not None:
        # dict.fromkeys consumes a one-shot iterator into something reusable
        # and drops repeated pairs while keeping their first-seen order.
        pairs = dict.fromkeys(pairs)
        known_nodes = set(G)
        for pair in pairs:
            missing = set(pair) - known_nodes
            if missing:
                raise nx.NodeNotFound(
                    f"Node(s) {missing} from pair {pair} not in G."
                )
    else:
        pairs = combinations_with_replacement(G, 2)

    # Only the inner function is a generator, so all of the checks above run
    # as soon as this function body executes rather than on first iteration.
    def _lca_stream(graph, wanted):
        # node -> set holding the node itself plus all of its ancestors
        closure = {}

        for v, w in wanted:
            for endpoint in (v, w):
                if endpoint not in closure:
                    reach = nx.ancestors(graph, endpoint)
                    reach.add(endpoint)
                    closure[endpoint] = reach

            shared = closure[v] & closure[w]
            if not shared:
                # No common ancestor: nothing is reported for this pair.
                continue

            # Start from any common ancestor and keep stepping to a successor
            # that is still a common ancestor; the walk stops at a common
            # ancestor none of whose successors is a common ancestor.
            lowest = next(iter(shared))
            while (
                deeper := next(
                    (child for child in graph.successors(lowest) if child in shared),
                    None,
                )
            ) is not None:
                lowest = deeper
            yield ((v, w), lowest)

    return _lca_stream(G, pairs)
124
125
@not_implemented_for("undirected")
@nx._dispatchable
def lowest_common_ancestor(G, node1, node2, default=None):
    """Compute the lowest common ancestor of the given pair of nodes.

    This is a single-pair convenience wrapper around
    `all_pairs_lowest_common_ancestor`.

    Parameters
    ----------
    G : NetworkX directed graph

    node1, node2 : nodes in the graph.

    default : object
        Returned if no common ancestor between `node1` and `node2`

    Returns
    -------
    The lowest common ancestor of node1 and node2,
    or default if they have no common ancestors.

    Raises
    ------
    NetworkXPointlessConcept
        If `G` is null.
    NetworkXError
        If `G` is not a DAG.
    NodeNotFound
        If `node1` or `node2` is not in `G`.

    Examples
    --------
    >>> G = nx.DiGraph()
    >>> nx.add_path(G, (0, 1, 2, 3))
    >>> nx.add_path(G, (0, 4, 3))
    >>> nx.lowest_common_ancestor(G, 2, 4)
    0

    See Also
    --------
    all_pairs_lowest_common_ancestor"""

    found = list(all_pairs_lowest_common_ancestor(G, pairs=[(node1, node2)]))
    if not found:
        return default
    # Exactly one pair was requested, so at most one result can come back.
    assert len(found) == 1
    _pair, lca = found[0]
    return lca
162
163
@not_implemented_for("undirected")
@nx._dispatchable
def tree_all_pairs_lowest_common_ancestor(G, root=None, pairs=None):
    r"""Yield the lowest common ancestor for sets of pairs in a tree.

    Parameters
    ----------
    G : NetworkX directed graph (must be a tree)

    root : node, optional (default: None)
        The root of the subtree to operate on.
        If None, assume the entire graph has exactly one source and use that.

    pairs : iterable or iterator of pairs of nodes, optional (default: None)
        The pairs of interest. If None, defaults to all pairs of nodes
        under `root` that have a lowest common ancestor.

    Returns
    -------
    lcas : generator of tuples `((u, v), lca)` where `u` and `v` are nodes
        in `pairs` and `lca` is their lowest common ancestor.
        If both ``(u, v)`` and ``(v, u)`` were requested, both are yielded.
        A requested self-pair ``(u, u)`` is yielded exactly once.

    Raises
    ------
    NetworkXPointlessConcept
        If `G` is null.
    NodeNotFound
        If a node of one of the provided `pairs` is not in `G`.
    NetworkXError
        If `root` is None and `G` has more than one source, has no source,
        or has a node with more than one distinct predecessor.

    Examples
    --------
    >>> import pprint
    >>> G = nx.DiGraph([(1, 3), (2, 4), (1, 2)])
    >>> pprint.pprint(dict(nx.tree_all_pairs_lowest_common_ancestor(G)))
    {(1, 1): 1,
     (2, 1): 1,
     (2, 2): 2,
     (3, 1): 1,
     (3, 2): 1,
     (3, 3): 3,
     (3, 4): 1,
     (4, 1): 1,
     (4, 2): 2,
     (4, 4): 4}

    We can also use `pairs` argument to specify the pairs of nodes for which we
    want to compute lowest common ancestors. Here is an example:

    >>> dict(nx.tree_all_pairs_lowest_common_ancestor(G, pairs=[(1, 4), (2, 3)]))
    {(2, 3): 1, (1, 4): 1}

    A node paired with itself is reported a single time:

    >>> list(nx.tree_all_pairs_lowest_common_ancestor(G, pairs=[(2, 2)]))
    [((2, 2), 2)]

    Notes
    -----
    Only defined on non-null trees represented with directed edges from
    parents to children. Uses Tarjan's off-line lowest-common-ancestors
    algorithm. Runs in $O(4 \times (V + E + P))$ time, where 4 is the largest
    value of the inverse Ackermann function likely to ever come up in actual
    use, and $P$ is the number of pairs requested (or $V^2$ if all are needed).

    The body of this function is a generator, so the input checks described
    under *Raises* are performed when iteration starts.

    Tarjan, R. E. (1979), "Applications of path compression on balanced trees",
    Journal of the ACM 26 (4): 690-715, doi:10.1145/322154.322161.

    See Also
    --------
    all_pairs_lowest_common_ancestor: similar routine for general DAGs
    lowest_common_ancestor: just a single pair for general DAGs
    """
    if len(G) == 0:
        raise nx.NetworkXPointlessConcept("LCA meaningless on null graphs.")

    # Index pairs of interest for efficient lookup from either side.
    if pairs is not None:
        pair_dict = defaultdict(set)
        # Materialize one-shot iterators (and drop duplicates) so that `pairs`
        # can be iterated here and still support fast membership tests below.
        if not isinstance(pairs, Mapping | Set):
            pairs = set(pairs)
        for u, v in pairs:
            for n in (u, v):
                if n not in G:
                    msg = f"The node {str(n)} is not in the digraph."
                    raise nx.NodeNotFound(msg)
            pair_dict[u].add(v)
            pair_dict[v].add(u)

    # If root is not specified, find the exactly one node with in degree 0 and
    # use it. Raise an error if none are found, or more than one is. Also check
    # for any nodes with in degree larger than 1, which would imply G is not a
    # tree.
    if root is None:
        for n, deg in G.in_degree:
            if deg == 0:
                if root is not None:
                    msg = "No root specified and tree has multiple sources."
                    raise nx.NetworkXError(msg)
                root = n
            # checking deg>1 is not sufficient for MultiDiGraphs
            elif deg > 1 and len(G.pred[n]) > 1:
                msg = "Tree LCA only defined on trees; use DAG routine."
                raise nx.NetworkXError(msg)
    if root is None:
        raise nx.NetworkXError("Graph contains a cycle.")

    # Iterative implementation of Tarjan's offline lca algorithm
    # as described in CLRS on page 521 (2nd edition)/page 584 (3rd edition)
    uf = UnionFind()
    # Maps the union-find representative of each merged set to the tree node
    # that currently acts as that set's ancestor.
    ancestors = {}
    for node in G:
        ancestors[node] = uf[node]

    # colors[x] is True once x has been finished by the postorder traversal.
    colors = defaultdict(bool)
    for node in nx.dfs_postorder_nodes(G, root):
        colors[node] = True
        for v in pair_dict[node] if pairs is not None else G:
            if not colors[v]:
                continue
            lca = ancestors[uf[v]]
            if pairs is None:
                yield (v, node), lca
                continue
            # If the user requested both directions of a pair, give it.
            # Otherwise, just give one.
            if (node, v) in pairs:
                yield (node, v), lca
            # For a self-pair both orientations are the same tuple, which was
            # already reported just above; do not report it a second time.
            if v != node and (v, node) in pairs:
                yield (v, node), lca
        if node != root:
            parent = arbitrary_element(G.pred[node])
            uf.union(parent, node)
            ancestors[uf[parent]] = parent