1"""
2Algorithms for calculating min/max spanning trees/forests.
3
4"""
5
6from dataclasses import dataclass, field
7from enum import Enum
8from heapq import heappop, heappush
9from itertools import count
10from math import isnan
11from operator import itemgetter
12from queue import PriorityQueue
13
14import networkx as nx
15from networkx.utils import UnionFind, not_implemented_for, py_random_state
16
# Names exported by ``from <this module> import *``.
# The per-algorithm edge generators (`boruvka_mst_edges`, `kruskal_mst_edges`
# and `prim_mst_edges`) are intentionally absent from this list; callers select
# them by name through the ``algorithm`` argument of the public functions.
__all__ = [
    "minimum_spanning_edges",
    "maximum_spanning_edges",
    "minimum_spanning_tree",
    "maximum_spanning_tree",
    "number_of_spanning_trees",
    "random_spanning_tree",
    "partition_spanning_tree",
    "EdgePartition",
    "SpanningTreeIterator",
]
28
29
class EdgePartition(Enum):
    """
    Partition state of a single edge.

    A member of this enum is stored on the edges of a graph (under the edge
    attribute named by the ``partition`` argument) before the graph is passed
    to `kruskal_mst_edges`. The available states are:

    - EdgePartition.OPEN
    - EdgePartition.INCLUDED
    - EdgePartition.EXCLUDED
    """

    OPEN = 0
    INCLUDED = 1
    EXCLUDED = 2
43
44
@not_implemented_for("multigraph")
@nx._dispatchable(edge_attrs="weight", preserve_edge_attrs="data")
def boruvka_mst_edges(
    G, minimum=True, weight="weight", keys=False, data=True, ignore_nan=False
):
    """Iterate over edges of a Borůvka's algorithm min/max spanning tree.

    Parameters
    ----------
    G : NetworkX Graph
        The edges of `G` must have distinct weights,
        otherwise the edges may not form a tree.

    minimum : bool (default: True)
        Find the minimum (True) or maximum (False) spanning tree.

    weight : string (default: 'weight')
        The name of the edge attribute holding the edge weights.
        Edges without this attribute are treated as having weight 1.

    keys : bool (default: False)
        This argument is ignored since this function is not
        implemented for multigraphs; it exists only for consistency
        with the other minimum spanning tree functions.

    data : bool (default: True)
        Flag for whether to yield edge attribute dicts.
        If True, yield edges `(u, v, d)`, where `d` is the attribute dict.
        If False, yield edges `(u, v)`.

    ignore_nan : bool (default: False)
        If a NaN is found as an edge weight normally an exception is raised.
        If `ignore_nan is True` then that edge is ignored instead.

    Yields
    ------
    edge tuple
        `(u, v, d)` if `data` is True, otherwise `(u, v)`.

    Raises
    ------
    ValueError
        If an edge weight is NaN and `ignore_nan` is False.

    """
    # Initialize a forest, assuming initially that it is the discrete
    # partition of the nodes of the graph.
    forest = UnionFind(G)

    # Maximisation is handled by minimising the negated weights.
    sign = 1 if minimum else -1

    def best_edge(component):
        """Returns the optimum (minimum or maximum) edge on the edge
        boundary of the given set of nodes.

        A return value of ``None`` indicates an empty boundary.

        """
        minwt = float("inf")
        boundary = None
        for e in nx.edge_boundary(G, component, data=True):
            wt = e[-1].get(weight, 1) * sign
            if isnan(wt):
                if ignore_nan:
                    continue
                msg = f"NaN found as an edge weight. Edge {e}"
                raise ValueError(msg)
            if wt < minwt:
                minwt = wt
                boundary = e
        return boundary

    while True:
        # Determine the optimum edge in the edge boundary of each
        # component in the forest.
        #
        # This must be a sequence, not an iterator. In this list, the
        # same edge may appear twice, in different orientations (but
        # that's okay, since a union operation will be called on the
        # endpoints the first time it is seen, but not the second time).
        #
        # Any ``None`` indicates that the edge boundary for that
        # component was empty, so that part of the forest has been
        # completed.
        #
        # TODO This can be parallelized, both in the outer loop over
        # each component in the forest and in the computation of the
        # minimum.
        candidates = (best_edge(component) for component in forest.to_sets())
        best_edges = [edge for edge in candidates if edge is not None]

        # If each entry was ``None``, every component has an empty boundary
        # (the forest is complete, or the graph is disconnected), so we are
        # done generating the forest.
        if not best_edges:
            return

        # Join trees in the forest using the best edges, and yield that
        # edge, since it is part of the spanning tree.
        #
        # TODO This loop can be parallelized, to an extent (the union
        # operation must be atomic).
        for u, v, d in best_edges:
            if forest[u] != forest[v]:
                if data:
                    yield u, v, d
                else:
                    yield u, v
                forest.union(u, v)
141
142
@nx._dispatchable(
    edge_attrs={"weight": None, "partition": None}, preserve_edge_attrs="data"
)
def kruskal_mst_edges(
    G, minimum, weight="weight", keys=True, data=True, ignore_nan=False, partition=None
):
    """
    Iterate over edges of a Kruskal's algorithm min/max spanning tree.

    Parameters
    ----------
    G : NetworkX Graph
        The graph holding the tree of interest.

    minimum : bool
        Find the minimum (True) or maximum (False) spanning tree.

    weight : string (default: 'weight')
        The name of the edge attribute holding the edge weights.
        Edges without this attribute are treated as having weight 1.

    keys : bool (default: True)
        If `G` is a multigraph, `keys` controls whether edge keys are yielded.
        Otherwise `keys` is ignored.

    data : bool (default: True)
        Flag for whether to yield edge attribute dicts.
        If True, yield edges `(u, v, d)`, where `d` is the attribute dict.
        If False, yield edges `(u, v)`.

    ignore_nan : bool (default: False)
        If a NaN is found as an edge weight normally an exception is raised.
        If `ignore_nan is True` then that edge is ignored instead.

    partition : string (default: None)
        The name of the edge attribute holding the partition data, if it exists.
        Partition data is written to the edges using the `EdgePartition` enum.
        Included edges are considered before every open edge and excluded
        edges are never considered. Open edges may or may not be used.

    Yields
    ------
    edge tuple
        The edges as discovered by Kruskal's method. Each edge can
        take the following forms: `(u, v)`, `(u, v, d)`, `(u, v, k)` or
        `(u, v, k, d)` depending on the `keys` and `data` parameters.

    Raises
    ------
    ValueError
        If an edge weight is NaN and `ignore_nan` is False.
    """
    components = UnionFind()
    multigraph = G.is_multigraph()
    if multigraph:
        edge_view = G.edges(keys=True, data=True)
    else:
        edge_view = G.edges(data=True)

    # Bucket the edges by partition state. Each record is the weight followed
    # by the original edge tuple, so the weight is always at index 0.
    forced = []  # EdgePartition.INCLUDED edges, kept in graph order
    candidates = []  # open edges, sorted by weight below
    for e in edge_view:
        attrs = e[-1]
        wt = attrs.get(weight, 1)
        if isnan(wt):
            if ignore_nan:
                continue
            raise ValueError(f"NaN found as an edge weight. Edge {e}")

        record = (wt, *e)
        if attrs.get(partition) == EdgePartition.INCLUDED:
            forced.append(record)
        elif attrs.get(partition) == EdgePartition.EXCLUDED:
            # Excluded edges never enter the candidate list.
            continue
        else:
            candidates.append(record)

    # Ascending weights for a minimum tree, descending for a maximum tree.
    candidates.sort(key=itemgetter(0), reverse=not minimum)

    # Included edges are tried first, then the open edges in weight order.
    ordered = forced + candidates

    target = len(G) - 1
    accepted = 0

    for _, u, v, *tail in ordered:
        # An edge whose endpoints already share a component would close a cycle.
        if components[u] == components[v]:
            continue
        d = tail[-1]
        # Multigraphs need to handle edge keys in addition to edge data;
        # for them ``tail`` is ``[k, d]``, otherwise it is just ``[d]``.
        if multigraph and keys:
            k = tail[0]
            if data:
                yield u, v, k, d
            else:
                yield u, v, k
        elif data:
            yield u, v, d
        else:
            yield u, v
        components.union(u, v)
        accepted += 1
        # A spanning tree on len(G) nodes has len(G) - 1 edges: stop early.
        if accepted == target:
            return
261
262
@nx._dispatchable(edge_attrs="weight", preserve_edge_attrs="data")
def prim_mst_edges(G, minimum, weight="weight", keys=True, data=True, ignore_nan=False):
    """Iterate over edges of Prim's algorithm min/max spanning tree.

    Parameters
    ----------
    G : NetworkX Graph
        The graph holding the tree of interest.

    minimum : bool
        Find the minimum (True) or maximum (False) spanning tree.

    weight : string (default: 'weight')
        The name of the edge attribute holding the edge weights.
        Edges without this attribute are treated as having weight 1.

    keys : bool (default: True)
        If `G` is a multigraph, `keys` controls whether edge keys are yielded.
        Otherwise `keys` is ignored.

    data : bool (default: True)
        Flag for whether to yield edge attribute dicts.
        If True, yield edges `(u, v, d)`, where `d` is the attribute dict.
        If False, yield edges `(u, v)`.

    ignore_nan : bool (default: False)
        If a NaN is found as an edge weight normally an exception is raised.
        If `ignore_nan is True` then that edge is ignored instead.

    Yields
    ------
    edge tuple
        `(u, v)`, `(u, v, d)`, `(u, v, k)` or `(u, v, k, d)` depending on
        `keys`, `data` and whether `G` is a multigraph.

    Raises
    ------
    ValueError
        If an edge weight is NaN and `ignore_nan` is False.

    """
    multigraph = G.is_multigraph()

    unreached = set(G)
    # Monotonic counter used as a heap tie-breaker so that entries with equal
    # weight never fall through to comparing nodes or attribute dicts.
    tiebreak = count()

    # Maximisation is handled by minimising the negated weights.
    sign = 1 if minimum else -1

    def scan(src, skip, heap):
        """Push every edge leaving `src` onto `heap`, ignoring neighbors in `skip`."""
        for nbr, entry in G.adj[src].items():
            if nbr in skip:
                continue
            # For a multigraph ``entry`` maps key -> attribute dict; for a
            # simple graph it is the attribute dict itself (key is unused).
            parallel = entry.items() if multigraph else ((None, entry),)
            for key, attrs in parallel:
                wt = attrs.get(weight, 1) * sign
                if isnan(wt):
                    if ignore_nan:
                        continue
                    if multigraph:
                        edge = (src, nbr, key, attrs)
                    else:
                        edge = (src, nbr, attrs)
                    msg = f"NaN found as an edge weight. Edge {edge}"
                    raise ValueError(msg)
                heappush(heap, (wt, next(tiebreak), src, nbr, key, attrs))

    # Each pass of the outer loop grows the tree of one connected component.
    while unreached:
        root = unreached.pop()
        frontier = []
        visited = {root}
        # The start node's edges are all examined (nothing is skipped).
        scan(root, (), frontier)
        while unreached and frontier:
            _, _, u, v, k, d = heappop(frontier)
            if v in visited or v not in unreached:
                continue
            # Multigraphs need to handle edge keys in addition to edge data.
            if multigraph and keys:
                if data:
                    yield u, v, k, d
                else:
                    yield u, v, k
            elif data:
                yield u, v, d
            else:
                yield u, v
            # update frontier
            visited.add(v)
            unreached.discard(v)
            scan(v, visited, frontier)
366
367
# Dispatch table from the user-facing ``algorithm`` name to the edge generator
# implementing it. Borůvka's algorithm is reachable under both the ASCII
# spelling and the spelling with the diacritic.
ALGORITHMS = {
    "boruvka": boruvka_mst_edges,
    "borůvka": boruvka_mst_edges,
    "kruskal": kruskal_mst_edges,
    "prim": prim_mst_edges,
}
374
375
@not_implemented_for("directed")
@nx._dispatchable(edge_attrs="weight", preserve_edge_attrs="data")
def minimum_spanning_edges(
    G, algorithm="kruskal", weight="weight", keys=True, data=True, ignore_nan=False
):
    """Generate edges in a minimum spanning forest of an undirected
    weighted graph.

    A minimum spanning tree is a subgraph of the graph (a tree)
    with the minimum sum of edge weights.  A spanning forest is a
    union of the spanning trees for each connected component of the graph.

    Parameters
    ----------
    G : undirected Graph
       An undirected graph. If `G` is connected, then the algorithm finds a
       spanning tree. Otherwise, a spanning forest is found.

    algorithm : string
       The algorithm to use when finding a minimum spanning tree. Valid
       choices are 'kruskal', 'prim', or 'boruvka'. The default is 'kruskal'.

    weight : string
       Edge data key to use for weight (default 'weight').

    keys : bool
       Whether to yield edge key in multigraphs in addition to the edge.
       If `G` is not a multigraph, this is ignored.

    data : bool, optional
       If True yield the edge data along with the edge.

    ignore_nan : bool (default: False)
        If a NaN is found as an edge weight normally an exception is raised.
        If `ignore_nan is True` then that edge is ignored instead.

    Returns
    -------
    edges : iterator
       An iterator over edges in a minimum spanning tree of `G`.
       Edges connecting nodes `u` and `v` are represented as tuples:
       `(u, v, k, d)` or `(u, v, k)` or `(u, v, d)` or `(u, v)`

       If `G` is a multigraph, `keys` indicates whether the edge key `k` will
       be reported in the third position in the edge tuple. `data` indicates
       whether the edge datadict `d` will appear at the end of the edge tuple.

       If `G` is not a multigraph, the tuples are `(u, v, d)` if `data` is True
       or `(u, v)` if `data` is False.

    Raises
    ------
    ValueError
        If `algorithm` is not one of the supported algorithm names.

    Examples
    --------
    >>> from networkx.algorithms import tree

    Find minimum spanning edges by Kruskal's algorithm

    >>> G = nx.cycle_graph(4)
    >>> G.add_edge(0, 3, weight=2)
    >>> mst = tree.minimum_spanning_edges(G, algorithm="kruskal", data=False)
    >>> edgelist = list(mst)
    >>> sorted(sorted(e) for e in edgelist)
    [[0, 1], [1, 2], [2, 3]]

    Find minimum spanning edges by Prim's algorithm

    >>> G = nx.cycle_graph(4)
    >>> G.add_edge(0, 3, weight=2)
    >>> mst = tree.minimum_spanning_edges(G, algorithm="prim", data=False)
    >>> edgelist = list(mst)
    >>> sorted(sorted(e) for e in edgelist)
    [[0, 1], [1, 2], [2, 3]]

    Notes
    -----
    For Borůvka's algorithm, each edge must have a weight attribute, and
    each edge weight must be distinct.

    For the other algorithms, if the graph edges do not have a weight
    attribute a default weight of 1 will be used.

    Modified code from David Eppstein, April 2006
    http://www.ics.uci.edu/~eppstein/PADS/

    """
    # Resolve the algorithm name; an unknown name surfaces as ValueError
    # chained to the underlying KeyError.
    try:
        edge_generator = ALGORITHMS[algorithm]
    except KeyError as err:
        raise ValueError(
            f"{algorithm} is not a valid choice for an algorithm."
        ) from err

    options = {
        "weight": weight,
        "keys": keys,
        "data": data,
        "ignore_nan": ignore_nan,
    }
    return edge_generator(G, minimum=True, **options)
469
470
@not_implemented_for("directed")
@nx._dispatchable(edge_attrs="weight", preserve_edge_attrs="data")
def maximum_spanning_edges(
    G, algorithm="kruskal", weight="weight", keys=True, data=True, ignore_nan=False
):
    """Generate edges in a maximum spanning forest of an undirected
    weighted graph.

    A maximum spanning tree is a subgraph of the graph (a tree)
    with the maximum possible sum of edge weights.  A spanning forest is a
    union of the spanning trees for each connected component of the graph.

    Parameters
    ----------
    G : undirected Graph
       An undirected graph. If `G` is connected, then the algorithm finds a
       spanning tree. Otherwise, a spanning forest is found.

    algorithm : string
       The algorithm to use when finding a maximum spanning tree. Valid
       choices are 'kruskal', 'prim', or 'boruvka'. The default is 'kruskal'.

    weight : string
       Edge data key to use for weight (default 'weight').

    keys : bool
       Whether to yield edge key in multigraphs in addition to the edge.
       If `G` is not a multigraph, this is ignored.

    data : bool, optional
       If True yield the edge data along with the edge.

    ignore_nan : bool (default: False)
        If a NaN is found as an edge weight normally an exception is raised.
        If `ignore_nan is True` then that edge is ignored instead.

    Returns
    -------
    edges : iterator
       An iterator over edges in a maximum spanning tree of `G`.
       Edges connecting nodes `u` and `v` are represented as tuples:
       `(u, v, k, d)` or `(u, v, k)` or `(u, v, d)` or `(u, v)`

       If `G` is a multigraph, `keys` indicates whether the edge key `k` will
       be reported in the third position in the edge tuple. `data` indicates
       whether the edge datadict `d` will appear at the end of the edge tuple.

       If `G` is not a multigraph, the tuples are `(u, v, d)` if `data` is True
       or `(u, v)` if `data` is False.

    Raises
    ------
    ValueError
        If `algorithm` is not one of the supported algorithm names.

    Examples
    --------
    >>> from networkx.algorithms import tree

    Find maximum spanning edges by Kruskal's algorithm

    >>> G = nx.cycle_graph(4)
    >>> G.add_edge(0, 3, weight=2)
    >>> mst = tree.maximum_spanning_edges(G, algorithm="kruskal", data=False)
    >>> edgelist = list(mst)
    >>> sorted(sorted(e) for e in edgelist)
    [[0, 1], [0, 3], [1, 2]]

    Find maximum spanning edges by Prim's algorithm

    >>> G = nx.cycle_graph(4)
    >>> G.add_edge(0, 3, weight=2)  # assign weight 2 to edge 0-3
    >>> mst = tree.maximum_spanning_edges(G, algorithm="prim", data=False)
    >>> edgelist = list(mst)
    >>> sorted(sorted(e) for e in edgelist)
    [[0, 1], [0, 3], [2, 3]]

    Notes
    -----
    For Borůvka's algorithm, each edge must have a weight attribute, and
    each edge weight must be distinct.

    For the other algorithms, if the graph edges do not have a weight
    attribute a default weight of 1 will be used.

    Modified code from David Eppstein, April 2006
    http://www.ics.uci.edu/~eppstein/PADS/
    """
    # Resolve the algorithm name; an unknown name surfaces as ValueError
    # chained to the underlying KeyError.
    try:
        edge_generator = ALGORITHMS[algorithm]
    except KeyError as err:
        raise ValueError(
            f"{algorithm} is not a valid choice for an algorithm."
        ) from err

    options = {
        "weight": weight,
        "keys": keys,
        "data": data,
        "ignore_nan": ignore_nan,
    }
    return edge_generator(G, minimum=False, **options)
563
564
@nx._dispatchable(preserve_all_attrs=True, returns_graph=True)
def minimum_spanning_tree(G, weight="weight", algorithm="kruskal", ignore_nan=False):
    """Returns a minimum spanning tree or forest on an undirected graph `G`.

    Parameters
    ----------
    G : undirected graph
        An undirected graph. If `G` is connected, then the algorithm finds a
        spanning tree. Otherwise, a spanning forest is found.

    weight : str
       Data key to use for edge weights.

    algorithm : string
       The algorithm to use when finding a minimum spanning tree. Valid
       choices are 'kruskal', 'prim', or 'boruvka'. The default is
       'kruskal'.

    ignore_nan : bool (default: False)
        If a NaN is found as an edge weight normally an exception is raised.
        If `ignore_nan is True` then that edge is ignored instead.

    Returns
    -------
    G : NetworkX Graph
       A minimum spanning tree or forest. It is a new graph of the same
       class as `G` carrying `G`'s graph attributes, all of its nodes (with
       their data) and the selected edges (with their data).

    Examples
    --------
    >>> G = nx.cycle_graph(4)
    >>> G.add_edge(0, 3, weight=2)
    >>> T = nx.minimum_spanning_tree(G)
    >>> sorted(T.edges(data=True))
    [(0, 1, {}), (1, 2, {}), (2, 3, {})]


    Notes
    -----
    For Borůvka's algorithm, each edge must have a weight attribute, and
    each edge weight must be distinct.

    For the other algorithms, if the graph edges do not have a weight
    attribute a default weight of 1 will be used.

    There may be more than one tree with the same minimum or maximum weight.
    See :mod:`networkx.tree.recognition` for more detailed definitions.

    Isolated nodes with self-loops are in the tree as edgeless isolated nodes.

    """
    # Keys and data are always requested so that the edge tuples can be fed
    # straight into ``add_edges_from`` for both graphs and multigraphs.
    tree_edges = minimum_spanning_edges(
        G,
        algorithm=algorithm,
        weight=weight,
        keys=True,
        data=True,
        ignore_nan=ignore_nan,
    )
    tree = G.__class__()  # Same graph class as G
    tree.graph.update(G.graph)
    tree.add_nodes_from(G.nodes.items())
    tree.add_edges_from(tree_edges)
    return tree
623
624
@nx._dispatchable(preserve_all_attrs=True, returns_graph=True)
def partition_spanning_tree(
    G, minimum=True, weight="weight", partition="partition", ignore_nan=False
):
    """
    Find a spanning tree while respecting a partition of edges.

    Edges can be flagged as either `INCLUDED` which are required to be in the
    returned tree, `EXCLUDED`, which cannot be in the returned tree and `OPEN`.

    This is used in the SpanningTreeIterator to create new partitions following
    the algorithm of Sörensen and Janssens [1]_.

    Parameters
    ----------
    G : undirected graph
        An undirected graph.

    minimum : bool (default: True)
        Determines whether the returned tree is the minimum spanning tree of
        the partition or the maximum one.

    weight : str
        Data key to use for edge weights.

    partition : str
        The key for the edge attribute containing the partition
        data on the graph. Edges can be included, excluded or open using the
        `EdgePartition` enum.

    ignore_nan : bool (default: False)
        If a NaN is found as an edge weight normally an exception is raised.
        If `ignore_nan is True` then that edge is ignored instead.


    Returns
    -------
    G : NetworkX Graph
        A minimum (or maximum, if `minimum` is False) spanning tree using all
        of the included edges in the graph and none of the excluded edges.

    References
    ----------
    .. [1] G.K. Janssens, K. Sörensen, An algorithm to generate all spanning
           trees in order of increasing cost, Pesquisa Operacional, 2005-08,
           Vol. 25 (2), p. 219-229,
           https://www.scielo.br/j/pope/a/XHswBwRwJyrfL88dmMwYNWp/?lang=en
    """
    # Only Kruskal's implementation understands the partition attribute.
    tree_edges = kruskal_mst_edges(
        G,
        minimum=minimum,
        weight=weight,
        keys=True,
        data=True,
        ignore_nan=ignore_nan,
        partition=partition,
    )
    tree = G.__class__()  # Same graph class as G
    tree.graph.update(G.graph)
    tree.add_nodes_from(G.nodes.items())
    tree.add_edges_from(tree_edges)
    return tree
687
688
@nx._dispatchable(preserve_all_attrs=True, returns_graph=True)
def maximum_spanning_tree(G, weight="weight", algorithm="kruskal", ignore_nan=False):
    """Returns a maximum spanning tree or forest on an undirected graph `G`.

    Parameters
    ----------
    G : undirected graph
        An undirected graph. If `G` is connected, then the algorithm finds a
        spanning tree. Otherwise, a spanning forest is found.

    weight : str
       Data key to use for edge weights.

    algorithm : string
       The algorithm to use when finding a maximum spanning tree. Valid
       choices are 'kruskal', 'prim', or 'boruvka'. The default is
       'kruskal'.

    ignore_nan : bool (default: False)
        If a NaN is found as an edge weight normally an exception is raised.
        If `ignore_nan is True` then that edge is ignored instead.


    Returns
    -------
    G : NetworkX Graph
       A maximum spanning tree or forest. It is a new graph of the same
       class as `G` carrying `G`'s graph attributes, all of its nodes (with
       their data) and the selected edges (with their data).


    Examples
    --------
    >>> G = nx.cycle_graph(4)
    >>> G.add_edge(0, 3, weight=2)
    >>> T = nx.maximum_spanning_tree(G)
    >>> sorted(T.edges(data=True))
    [(0, 1, {}), (0, 3, {'weight': 2}), (1, 2, {})]


    Notes
    -----
    For Borůvka's algorithm, each edge must have a weight attribute, and
    each edge weight must be distinct.

    For the other algorithms, if the graph edges do not have a weight
    attribute a default weight of 1 will be used.

    There may be more than one tree with the same minimum or maximum weight.
    See :mod:`networkx.tree.recognition` for more detailed definitions.

    Isolated nodes with self-loops are in the tree as edgeless isolated nodes.

    """
    # The edge iterator is fully consumed before the result graph is created.
    tree_edges = list(
        maximum_spanning_edges(
            G,
            algorithm=algorithm,
            weight=weight,
            keys=True,
            data=True,
            ignore_nan=ignore_nan,
        )
    )
    tree = G.__class__()  # Same graph class as G
    tree.graph.update(G.graph)
    tree.add_nodes_from(G.nodes.items())
    tree.add_edges_from(tree_edges)
    return tree
750
751
@py_random_state(3)
@nx._dispatchable(preserve_edge_attrs=True, returns_graph=True)
def random_spanning_tree(G, weight=None, *, multiplicative=True, seed=None):
    """
    Sample a random spanning tree using the edge weights of `G`.

    This function supports two different methods for determining the
    probability of a tree. If ``multiplicative=True``, the probability
    is based on the product of edge weights, and if ``multiplicative=False``
    it is based on the sum of the edge weights. However, since it is
    easier to determine the total weight of all spanning trees for the
    multiplicative version, that is significantly faster and should be used if
    possible. Additionally, setting `weight` to `None` will cause a spanning tree
    to be selected with uniform probability.

    The function uses algorithm A8 in [1]_ .

    Parameters
    ----------
    G : nx.Graph
        An undirected version of the original graph.

    weight : string
        The edge key for the edge attribute holding edge weight.
        If `None`, every edge is given weight 1.

    multiplicative : bool, default=True
        If `True`, the probability of each tree is the product of its edge weight
        over the sum of the product of all the spanning trees in the graph. If
        `False`, the probability is the sum of its edge weight over the sum of
        the sum of weights for all spanning trees in the graph.

    seed : integer, random_state, or None (default)
        Indicator of random number generation state.
        See :ref:`Randomness<randomness>`.

    Returns
    -------
    nx.Graph
        A spanning tree using the distribution defined by the weight of the tree.
        The returned graph is built from the chosen `(u, v)` pairs only, so it
        carries no edge attributes. If `G` has fewer than two nodes, an
        edgeless graph on the nodes of `G` is returned.

    Raises
    ------
    Exception
        If every edge has been examined without collecting
        ``G.number_of_nodes() - 1`` edges.

    References
    ----------
    .. [1] V. Kulkarni, Generating random combinatorial objects, Journal of
       Algorithms, 11 (1990), pp. 185–207
    """

    def find_node(merged_nodes, node):
        """
        We can think of clusters of contracted nodes as having one
        representative in the graph. Each node which is not in merged_nodes
        is still its own representative. Since a representative can be later
        contracted, we need to recursively search through the dict to find
        the final representative, but once we know it we can use path
        compression to speed up the access of the representative for next time.

        This cannot be replaced by the standard NetworkX union_find since that
        data structure will merge nodes with less representing nodes into the
        one with more representing nodes but this function requires we merge
        them using the order that contract_edges contracts using.

        Parameters
        ----------
        merged_nodes : dict
            The dict storing the mapping from node to representative.
            It is updated in place (path compression).
        node
            The node whose representative we seek

        Returns
        -------
        The representative of the `node`
        """
        if node not in merged_nodes:
            return node
        else:
            rep = find_node(merged_nodes, merged_nodes[node])
            # Path compression: point `node` directly at its final representative.
            merged_nodes[node] = rep
            return rep

    def prepare_graph():
        """
        For the graph `G`, remove all edges not in the set `V` and then
        contract all edges in the set `U`.

        `G`, `U` and `V` are read from the enclosing function's scope.

        Returns
        -------
        merged_nodes : dict
            Mapping from each contracted-away node to the node it was merged
            into, for use with `find_node`.
        result : nx.MultiGraph
            A copy of `G` which has had all edges not in `V` removed and all
            edges in `U` contracted.
        """

        # The result is a MultiGraph version of G so that parallel edges are
        # allowed during edge contraction
        result = nx.MultiGraph(incoming_graph_data=G)

        # Remove all edges not in V
        edges_to_remove = set(result.edges()).difference(V)
        result.remove_edges_from(edges_to_remove)

        # Contract all edges in U
        #
        # Imagine that you have two edges to contract and they share an
        # endpoint like this:
        #                        [0] ----- [1] ----- [2]
        # If we contract (0, 1) first, the contraction function will always
        # delete the second node it is passed so the resulting graph would be
        #                             [0] ----- [2]
        # and edge (1, 2) no longer exists but (0, 2) would need to be contracted
        # in its place now. That is why I use the below dict as a merge-find
        # data structure with path compression to track how the nodes are merged.
        merged_nodes = {}

        for u, v in U:
            u_rep = find_node(merged_nodes, u)
            v_rep = find_node(merged_nodes, v)
            # We cannot contract a node with itself
            if u_rep == v_rep:
                continue
            nx.contracted_nodes(result, u_rep, v_rep, self_loops=False, copy=False)
            # Record that v_rep has been merged into u_rep.
            merged_nodes[v_rep] = u_rep

        return merged_nodes, result

    def spanning_tree_total_weight(G, weight):
        """
        Find the sum of weights of the spanning trees of `G` using the
        method selected by the enclosing function's `multiplicative` flag.

        This is easy if the chosen method is 'multiplicative', since we can
        use Kirchhoff's Tree Matrix Theorem directly. However, with the
        'additive' method, this process is slightly more complex and less
        computationally efficient as we have to find the number of spanning
        trees which contain each possible edge in the graph.

        Parameters
        ----------
        G : NetworkX Graph
            The graph to find the total weight of all spanning trees on.

        weight : string
            The key for the weight edge attribute of the graph.

        Returns
        -------
        float
            The sum of either the multiplicative or additive weight for all
            spanning trees in the graph.
        """
        if multiplicative:
            return number_of_spanning_trees(G, weight=weight)
        else:
            # There are two cases for the total spanning tree additive weight.
            # 1. There is one edge in the graph. Then the only spanning tree is
            #    that edge itself, which will have a total weight of that edge
            #    itself.
            if G.number_of_edges() == 1:
                # Take the weight (third item) of the first and only edge.
                return G.edges(data=weight).__iter__().__next__()[2]
            # 2. There are no edges or two or more edges in the graph. Then, we find the
            #    total weight of the spanning trees using the formula in the
            #    reference paper: take the weight of each edge and multiply it by
            #    the number of spanning trees which include that edge. This
            #    can be accomplished by contracting the edge and finding the
            #    multiplicative total spanning tree weight if the weight of each edge
            #    is assumed to be 1, which is conveniently built into networkx already,
            #    by calling number_of_spanning_trees with weight=None.
            #    Note that with no edges the returned value is just zero.
            else:
                total = 0
                for u, v, w in G.edges(data=weight):
                    total += w * nx.number_of_spanning_trees(
                        nx.contracted_edge(G, edge=(u, v), self_loops=False),
                        weight=None,
                    )
                return total

    if G.number_of_nodes() < 2:
        # no edges in the spanning tree
        return nx.empty_graph(G.nodes)

    # U: edges chosen for the tree so far.
    # V: edges that have not been rejected (starts as every edge of G).
    # st_cached_value: running sum of the weights of the edges in U.
    U = set()
    st_cached_value = 0
    V = set(G.edges())
    shuffled_edges = list(G.edges())
    # NOTE(review): `seed` is used as a random number generator object here
    # (`.shuffle`, `.uniform`); presumably `py_random_state` converts the
    # caller's value before the body runs.
    seed.shuffle(shuffled_edges)

    for u, v in shuffled_edges:
        e_weight = G[u][v][weight] if weight is not None else 1
        node_map, prepared_G = prepare_graph()
        G_total_tree_weight = spanning_tree_total_weight(prepared_G, weight)
        # The acceptance threshold compares the total tree weight with this
        # edge contracted (i.e. assumed included) against the total tree
        # weight of the current prepared graph.
        # Now, if (u, v) cannot exist in G because it is fully contracted out
        # of existence, then it by definition cannot influence G_e's Kirchhoff
        # value. But, we also cannot pick it.
        rep_edge = (find_node(node_map, u), find_node(node_map, v))
        # Check to see if the 'representative edge' for the current edge is
        # in prepared_G. If so, then we can pick it.
        if rep_edge in prepared_G.edges:
            prepared_G_e = nx.contracted_edge(
                prepared_G, edge=rep_edge, self_loops=False
            )
            G_e_total_tree_weight = spanning_tree_total_weight(prepared_G_e, weight)
            if multiplicative:
                threshold = e_weight * G_e_total_tree_weight / G_total_tree_weight
            else:
                numerator = (st_cached_value + e_weight) * nx.number_of_spanning_trees(
                    prepared_G_e
                ) + G_e_total_tree_weight
                denominator = (
                    st_cached_value * nx.number_of_spanning_trees(prepared_G)
                    + G_total_tree_weight
                )
                threshold = numerator / denominator
        else:
            # The edge cannot be picked, so it is always rejected below.
            threshold = 0.0
        z = seed.uniform(0.0, 1.0)
        if z > threshold:
            # Remove the edge from V since we did not pick it.
            V.remove((u, v))
        else:
            # Add the edge to U since we picked it.
            st_cached_value += e_weight
            U.add((u, v))
        # If we decide to keep an edge, it may complete the spanning tree.
        if len(U) == G.number_of_nodes() - 1:
            spanning_tree = nx.Graph()
            spanning_tree.add_edges_from(U)
            return spanning_tree
    # Reached only if all edges were examined without completing a tree.
    raise Exception(f"Something went wrong! Only {len(U)} edges in the spanning tree!")
979
980
class SpanningTreeIterator:
    """
    Iterate over all spanning trees of a graph in either increasing or
    decreasing cost.

    Notes
    -----
    This iterator uses the partition scheme from [1]_ (included edges,
    excluded edges and open edges) as well as a modified Kruskal's Algorithm
    to generate minimum spanning trees which respect the partition of edges.
    For spanning trees with the same weight, ties are broken arbitrarily.

    The iterator works on a private copy of the input graph and keeps that
    copy for its whole lifetime. Calling ``iter`` on the object (re)starts
    the enumeration, and once the enumeration is exhausted every further
    ``next`` call keeps raising ``StopIteration``.

    References
    ----------
    .. [1] G.K. Janssens, K. Sörensen, An algorithm to generate all spanning
           trees in order of increasing cost, Pesquisa Operacional, 2005-08,
           Vol. 25 (2), p. 219-229,
           https://www.scielo.br/j/pope/a/XHswBwRwJyrfL88dmMwYNWp/?lang=en
    """

    @dataclass(order=True)
    class Partition:
        """
        This dataclass represents a partition and stores a dict with the edge
        data and the weight of the minimum spanning tree of the partition dict.

        Only `mst_weight` takes part in comparisons, so partitions are ordered
        in the priority queue purely by that value.
        """

        mst_weight: float
        partition_dict: dict = field(compare=False)

        def __copy__(self):
            # Copy the dict as well so the clone can be modified independently.
            return SpanningTreeIterator.Partition(
                self.mst_weight, self.partition_dict.copy()
            )

    def __init__(self, G, weight="weight", minimum=True, ignore_nan=False):
        """
        Initialize the iterator

        Parameters
        ----------
        G : nx.Graph
            The undirected graph which we need to iterate trees over.
            The iterator works on a copy, so `G` itself is never modified.

        weight : String, default = "weight"
            The edge attribute used to store the weight of the edge

        minimum : bool, default = True
            Return the trees in increasing order while true and decreasing order
            while false.

        ignore_nan : bool, default = False
            If a NaN is found as an edge weight normally an exception is raised.
            If `ignore_nan is True` then that edge is ignored instead.
        """
        self.G = G.copy()
        self.G.__networkx_cache__ = None  # Disable caching
        self.weight = weight
        self.minimum = minimum
        self.ignore_nan = ignore_nan
        # Fixed, deliberately unusual edge attribute name used to hold the
        # partition data so that it does not clash with user attributes.
        self.partition_key = (
            "SpanningTreeIterators super secret partition attribute name"
        )

    def __iter__(self):
        """
        Start (or restart) the enumeration.

        Returns
        -------
        SpanningTreeIterator
            The iterator object for this graph
        """
        self.partition_queue = PriorityQueue()
        self._clear_partition(self.G)
        mst_weight = self._partition_mst().size(weight=self.weight)

        # The queue always pops the smallest value, so weights are negated
        # when trees are wanted in decreasing order.
        self.partition_queue.put(
            self.Partition(mst_weight if self.minimum else -mst_weight, {})
        )

        return self

    def __next__(self):
        """
        Returns
        -------
        (multi)Graph
            The spanning tree of next greatest weight, with ties broken
            arbitrarily.

        Raises
        ------
        StopIteration
            When no partition is left in the queue. The iterator state is
            left untouched, so repeated calls keep raising `StopIteration`
            and `iter` can be used to start over.
        """
        if self.partition_queue.empty():
            # Do not delete ``self.G`` / ``self.partition_queue`` here: doing
            # so would turn the next ``next()`` call into an AttributeError
            # and make the object impossible to iterate a second time.
            raise StopIteration

        partition = self.partition_queue.get()
        self._write_partition(partition)
        next_tree = self._partition_mst()
        self._partition(partition, next_tree)

        # The returned tree should not expose the internal bookkeeping key.
        self._clear_partition(next_tree)
        return next_tree

    def _partition_mst(self):
        """
        Return the spanning tree produced by `partition_spanning_tree` for
        the partition currently written on the edges of ``self.G``.
        """
        return partition_spanning_tree(
            self.G, self.minimum, self.weight, self.partition_key, self.ignore_nan
        )

    def _partition(self, partition, partition_tree):
        """
        Create new partitions based on the minimum spanning tree of the
        current minimum partition.

        Parameters
        ----------
        partition : Partition
            The Partition instance used to generate the current minimum spanning
            tree.
        partition_tree : nx.Graph
            The minimum spanning tree of the input partition.
        """
        # create two new partitions with the data from the input partition dict
        p1 = self.Partition(0, partition.partition_dict.copy())
        p2 = self.Partition(0, partition.partition_dict.copy())
        for e in partition_tree.edges:
            # Edges already present in the partition dict are not open and
            # therefore do not generate a new partition.
            if e in partition.partition_dict:
                continue

            # This is an open edge
            p1.partition_dict[e] = EdgePartition.EXCLUDED
            p2.partition_dict[e] = EdgePartition.INCLUDED

            self._write_partition(p1)
            p1_mst = self._partition_mst()
            # Only partitions whose tree is connected are queued.
            if nx.is_connected(p1_mst):
                p1_mst_weight = p1_mst.size(weight=self.weight)
                p1.mst_weight = p1_mst_weight if self.minimum else -p1_mst_weight
                self.partition_queue.put(p1.__copy__())
            # The next partition starts from p2, i.e. with `e` included.
            p1.partition_dict = p2.partition_dict.copy()

    @staticmethod
    def _edges_with_data(G):
        """
        Return an edge view of `G` whose items end with the edge data dict;
        multigraph edges also carry their key.
        """
        if G.is_multigraph():
            return G.edges(keys=True, data=True)
        return G.edges(data=True)

    def _write_partition(self, partition):
        """
        Writes the desired partition into the graph to calculate the minimum
        spanning tree.

        Parameters
        ----------
        partition : Partition
            A Partition dataclass describing a partition on the edges of the
            graph.
        """
        partition_dict = partition.partition_dict
        partition_key = self.partition_key

        for *e, d in self._edges_with_data(self.G):
            # Edges missing from the partition dict are open.
            d[partition_key] = partition_dict.get(tuple(e), EdgePartition.OPEN)

    def _clear_partition(self, G):
        """
        Removes partition data from the graph
        """
        partition_key = self.partition_key
        for *_, d in self._edges_with_data(G):
            d.pop(partition_key, None)
1157
1158
@nx._dispatchable(edge_attrs="weight")
def number_of_spanning_trees(G, *, root=None, weight=None):
    """Returns the number of spanning trees in `G`.

    For an undirected graph a spanning tree is a tree connecting every
    node of the graph. The directed counterpart is a (spanning)
    arborescence, which contains a unique directed path from the `root`
    node to each other node. A directed graph must be weakly connected
    and `root` must have all nodes as successors [3]_. To avoid talking
    about sink-roots and reverse-arborescences, the edge orientation is
    reversed with respect to [3]_ and the in-degree laplacian is used.

    With `weight` set to `None` the result is the number of spanning
    trees (undirected graphs) or the number of arborescences from a
    single root node (directed graphs). When `weight` names an edge
    attribute holding edge weights, the result is the sum over all trees
    of the multiplicative weight of each tree, i.e. the weight of a tree
    is the product of its edge weights.

    The computation relies on Kirchhoff's Matrix Tree Theorem: any
    cofactor of the Laplacian matrix of a graph equals the number of
    spanning trees of the graph. (A cofactor of a diagonal entry is used
    here, so it is the determinant of the matrix with one row and its
    matching column removed.) With a weighted Laplacian the cofactor is
    the sum across all spanning trees of the product of the edge weights
    of each tree. The theorem is also known as Kirchhoff's theorem [1]_
    and the Matrix-Tree theorem [2]_.

    A similar result (Tutte's Theorem) holds for directed graphs when the
    removed row and column are the ones that correspond to the root. The
    cofactor then counts the arborescences rooted at that node, and the
    weighted version gives the sum of the arborescence weights with root
    `root`, each arborescence weight being the product of its edge
    weights.

    Parameters
    ----------
    G : NetworkX graph

    root : node
        A node in the directed graph `G` that has all nodes as descendants.
        (This is ignored for undirected graphs.)

    weight : string or None, optional (default=None)
        The name of the edge attribute holding the edge weight.
        If `None`, then each edge is assumed to have a weight of 1.

    Returns
    -------
    Number
        Undirected graphs:
            The number of spanning trees of the graph `G`.
            Or the sum of all spanning tree weights of the graph `G`
            where the weight of a tree is the product of its edge weights.
        Directed graphs:
            The number of arborescences of `G` rooted at node `root`.
            Or the sum of all arborescence weights of the graph `G` with
            specified root where the weight of an arborescence is the product
            of its edge weights.

    Raises
    ------
    NetworkXPointlessConcept
        If `G` does not contain any nodes.

    NetworkXError
        If the graph `G` is directed and the root node
        is not specified or is not in G.

    Examples
    --------
    >>> G = nx.complete_graph(5)
    >>> round(nx.number_of_spanning_trees(G))
    125

    >>> G = nx.Graph()
    >>> G.add_edge(1, 2, weight=2)
    >>> G.add_edge(1, 3, weight=1)
    >>> G.add_edge(2, 3, weight=1)
    >>> round(nx.number_of_spanning_trees(G, weight="weight"))
    5

    Notes
    -----
    Self-loops are excluded. Multi-edges are contracted in one edge
    equal to the sum of the weights.

    References
    ----------
    .. [1] Wikipedia
       "Kirchhoff's theorem."
       https://en.wikipedia.org/wiki/Kirchhoff%27s_theorem
    .. [2] Kirchhoff, G. R.
        Über die Auflösung der Gleichungen, auf welche man
        bei der Untersuchung der linearen Vertheilung
        Galvanischer Ströme geführt wird
        Annalen der Physik und Chemie, vol. 72, pp. 497-508, 1847.
    .. [3] Margoliash, J.
        "Matrix-Tree Theorem for Directed Graphs"
        https://www.math.uchicago.edu/~may/VIGRE/VIGRE2010/REUPapers/Margoliash.pdf
    """
    import numpy as np

    if len(G) == 0:
        raise nx.NetworkXPointlessConcept("Graph G must contain at least one node.")

    if nx.is_directed(G):
        # A root is mandatory for arborescences and must belong to the graph.
        if root is None:
            raise nx.NetworkXError("Input `root` must be provided when G is directed")
        if root not in G:
            raise nx.NetworkXError("The node root is not in the graph G.")
        if not nx.is_weakly_connected(G):
            return 0

        # Put the root first so that dropping row/column 0 below removes it.
        ordering = [root]
        ordering.extend(node for node in G if node != root)
        adjacency = nx.adjacency_matrix(G, nodelist=ordering, weight=weight)
        # Column sums of the adjacency matrix form the in-degree diagonal.
        in_degree = np.diag(adjacency.sum(axis=0))
        laplacian = in_degree - adjacency
    else:
        if not nx.is_connected(G):
            return 0
        laplacian = nx.laplacian_matrix(G, weight=weight).toarray()

    # Cofactor of the first diagonal entry: drop row 0 and column 0.
    minor = laplacian[1:, 1:]
    return float(np.linalg.det(minor))