1"""
2Algorithms for computing distance measures on trees.
3"""
4
5import networkx as nx
6
# Public API of this module: the tree center and tree centroid routines.
__all__ = ["center", "centroid"]
11
12
@nx.utils.not_implemented_for("directed")
def center(G):
    """Returns the center of an undirected tree graph.

    The center of a tree consists of the nodes with minimum eccentricity,
    where the eccentricity of a node is its maximum distance to any other
    node. That is, these nodes minimize the maximum distance to all other
    nodes. This implementation currently only works for unweighted edges.

    If the input graph is not a tree, results are not guaranteed to be
    correct. Some non-trees will raise an ``nx.NotATree`` exception, but not
    all non-trees will be discovered. Thus, this function should not be used
    if the caller is unsure whether the input graph is a tree. Use
    ``nx.is_tree(G)`` to check.

    Parameters
    ----------
    G : NetworkX graph
        A tree graph (undirected, connected, acyclic graph).

    Returns
    -------
    center : list
        A list of nodes forming the center of the tree. This can be one or two nodes.

    Raises
    ------
    NetworkXNotImplemented
        If the input graph is directed.

    NotATree
        If the algorithm detects the input graph is not a tree (this includes
        the null graph with no nodes). There is no guarantee this error will
        always raise if a non-tree is passed.

    Notes
    -----
    This algorithm iteratively removes leaves (nodes with degree 1) from the
    tree until only 1 or 2 nodes are left. The remaining nodes form the center
    of the tree.

    This algorithm's time complexity is ``O(N)`` where ``N`` is the number of
    nodes in the tree.

    Examples
    --------
    >>> G = nx.Graph([(1, 2), (1, 3), (2, 4), (2, 5)])
    >>> nx.tree.center(G)
    [1, 2]

    >>> G = nx.path_graph(5)
    >>> nx.tree.center(G)
    [2]
    """
    # Nodes not yet peeled off, mapped to their degree among the remaining nodes.
    center_candidates_degree = dict(G.degree)
    leaves = {node for node, degree in center_candidates_degree.items() if degree == 1}

    # It's better to fail than to loop forever, so require at least one leaf
    # per round to guarantee progress.
    while len(center_candidates_degree) > 2 and leaves:
        new_leaves = set()
        for leaf in leaves:
            del center_candidates_degree[leaf]
            for neighbor in G.neighbors(leaf):
                # Neighbors peeled off earlier (or earlier in this round) are gone.
                if neighbor not in center_candidates_degree:
                    continue
                center_candidates_degree[neighbor] -= 1
                remaining_degree = center_candidates_degree[neighbor]
                if remaining_degree == 1:
                    new_leaves.add(neighbor)
                elif remaining_degree == 0 and len(center_candidates_degree) != 1:
                    # Peeling leaves off a tree never isolates a node while
                    # other nodes remain, so G cannot be a tree.
                    raise nx.NotATree("input graph is not a tree")
        leaves = new_leaves

    n = len(center_candidates_degree)
    # Check for disconnected or cyclic input: a tree peels down to a single
    # node, or to two nodes that are both (mutually adjacent) leaves.
    if (n == 2 and not leaves) or n not in {1, 2}:
        # We have detected that the graph is not a tree. This check does not
        # cover all cases. For example, `nx.Graph([(0, 0)])` will not raise an error.
        raise nx.NotATree("input graph is not a tree")

    return list(center_candidates_degree)
88
89
def _subtree_sizes(G, root):
    """Map every node of a tree rooted at `root` to the size of its subtree.

    With the tree rooted at `root`, the subtree of a node is the node
    together with all of its descendants. Equivalently, it is the component
    that contains the node once the edge to its parent (if any) is cut. The
    value stored for each node is the number of nodes in that component.

    Parameters
    ----------
    G : NetworkX graph
        A tree.

    root : node
        A node in `G`; the tree is rooted here.

    Returns
    -------
    s : dict
        Maps each node reached by a depth-first search from `root` to the
        number of nodes in its subtree. Keys appear in depth-first discovery
        order, starting with `root`.

    Examples
    --------
    >>> _subtree_sizes(nx.path_graph(4), 0)
    {0: 4, 1: 3, 2: 2, 3: 1}

    >>> _subtree_sizes(nx.path_graph(4), 2)
    {2: 4, 1: 2, 0: 1, 3: 1}

    """
    # Record a depth-first preorder together with each node's DFS parent.
    preorder = [root]
    parent_of = {}
    for parent, child in nx.dfs_edges(G, root):
        parent_of[child] = parent
        preorder.append(child)

    # Every node counts itself. Descendants always follow their ancestors in
    # preorder, so sweeping it backwards finalizes each subtree size before
    # that size is added to the parent's total.
    sizes = dict.fromkeys(preorder, 1)
    for node in reversed(preorder[1:]):
        sizes[parent_of[node]] += sizes[node]
    return sizes
133
134
@nx.utils.not_implemented_for("directed")
@nx._dispatchable
def centroid(G):
    """Return the centroid of an unweighted tree.

    The centroid of a tree is the set of nodes such that removing any
    one of them would split the tree into a forest of subtrees, each
    with at most ``N / 2`` nodes, where ``N`` is the number of nodes
    in the original tree. This set may contain two nodes if removing
    an edge between them results in two trees of size exactly
    ``N / 2``.

    Parameters
    ----------
    G : NetworkX graph
        A tree.

    Returns
    -------
    c : list
        List of nodes in centroid of the tree. This could be one or two nodes.

    Raises
    ------
    NotATree
        If the input graph is not a tree.
    NetworkXNotImplemented
        If the input graph is directed.
    NetworkXPointlessConcept
        If `G` is the null graph (has no nodes).

    See Also
    --------
    :func:`~networkx.algorithms.distance_measures.barycenter`
    :func:`~networkx.algorithms.distance_measures.center`
    center : tree center

    Notes
    -----
    This algorithm's time complexity is ``O(N)`` where ``N`` is the
    number of nodes in the tree.

    In unweighted trees the centroid coincides with the barycenter,
    the node or nodes that minimize the sum of distances to all other
    nodes. However, this concept is different from that of the graph
    center, which is the set of nodes minimizing the maximum distance
    to all other nodes.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.tree.centroid(G)
    [1, 2]

    A star-shaped tree with one long branch illustrates the difference
    between the centroid and the center. The center lies near the
    middle of the long branch, minimizing maximum distance. The
    centroid, however, limits the size of any resulting subtree to at
    most half the total nodes, forcing it to remain near the hub when
    enough short branches are present.

    >>> G = nx.star_graph(6)
    >>> nx.add_path(G, [6, 7, 8, 9, 10])
    >>> nx.tree.centroid(G), nx.tree.center(G)
    ([0], [7])
    """
    if not nx.is_tree(G):
        raise nx.NotATree("provided graph is not a tree")
    prev, root = None, nx.utils.arbitrary_element(G)
    # Subtree sizes relative to the initial (arbitrary) root.
    sizes = _subtree_sizes(G, root)
    total_size = G.number_of_nodes()

    def _heaviest_child(parent, node):
        # Largest child subtree of ``node`` when entered from ``parent``;
        # ``None`` when ``node`` has no children in that direction.
        return max(
            (x for x in G.neighbors(node) if x != parent), key=sizes.get, default=None
        )

    # Walk away from the initial root toward the heaviest subtree until no
    # component left by removing ``root`` exceeds half the nodes. The
    # component on the ``prev`` side has ``total_size - sizes[root]`` nodes.
    hc = _heaviest_child(prev, root)
    while max(total_size - sizes[root], sizes.get(hc, 0)) > total_size / 2:
        prev, root = root, hc
        hc = _heaviest_child(prev, root)

    # A second centroid exists exactly when a child subtree holds N / 2 nodes.
    return [root] + [
        x for x in G.neighbors(root) if x != prev and sizes[x] == total_size / 2
    ]