1"""
2Algorithms for finding optimum branchings and spanning arborescences.
3
4This implementation is based on:
5
6 J. Edmonds, Optimum branchings, J. Res. Natl. Bur. Standards 71B (1967),
7 233–240. URL: http://archive.org/details/jresv71Bn4p233
8
9"""
10
# TODO: Implement method from Gabow, Galil, Spencer and Tarjan:
12#
13# @article{
14# year={1986},
15# issn={0209-9683},
16# journal={Combinatorica},
17# volume={6},
18# number={2},
19# doi={10.1007/BF02579168},
20# title={Efficient algorithms for finding minimum spanning trees in
21# undirected and directed graphs},
22# url={https://doi.org/10.1007/BF02579168},
23# publisher={Springer-Verlag},
24# keywords={68 B 15; 68 C 05},
25# author={Gabow, Harold N. and Galil, Zvi and Spencer, Thomas and Tarjan,
26# Robert E.},
27# pages={109-122},
28# language={English}
29# }
30import string
31from dataclasses import dataclass, field
32from operator import itemgetter
33from queue import PriorityQueue
34
35import networkx as nx
36from networkx.utils import py_random_state
37
38from .recognition import is_arborescence, is_branching
39
# Public API of this module. Helpers defined here but not listed below
# (e.g. ``random_string``, ``_min_weight``, ``_max_weight``) are internal.
__all__ = [
    "branching_weight",
    "greedy_branching",
    "maximum_branching",
    "minimum_branching",
    "minimal_branching",
    "maximum_spanning_arborescence",
    "minimum_spanning_arborescence",
    "ArborescenceIterator",
]
50
# Accepted values for the ``kind`` argument of ``greedy_branching``.
KINDS = {"min", "max"}

# Maps accepted style spellings to their canonical style name; both
# "arborescence" spellings collapse onto the same canonical value.
STYLES = {"branching": "branching"}
STYLES.update(dict.fromkeys(("arborescence", "spanning arborescence"), "arborescence"))

# Positive infinity, used as the starting value for min/max weight scans.
INF = float("+inf")
60
61
@py_random_state(1)
def random_string(L=15, seed=None):
    """Return a string of `L` characters drawn from ``string.ascii_letters``.

    Each character is picked with one ``seed.choice`` call; `seed` is
    normalized to a random-state object by the ``py_random_state`` decorator.
    """
    alphabet = string.ascii_letters
    return "".join(seed.choice(alphabet) for _ in range(L))
65
66
67def _min_weight(weight):
68 return -weight
69
70
71def _max_weight(weight):
72 return weight
73
74
@nx._dispatchable(edge_attrs={"attr": "default"})
def branching_weight(G, attr="weight", default=1):
    """
    Returns the total weight of a branching.

    You must access this function through the networkx.algorithms.tree module.

    Parameters
    ----------
    G : DiGraph
        The directed graph.
    attr : str
        The edge attribute holding each edge's weight. Every edge that lacks
        this attribute (which, for ``attr=None``, is normally every edge)
        contributes `default` instead.
    default : float
        The weight used for an edge that does not have the attribute `attr`.

    Returns
    -------
    weight: int or float
        The sum of the edge weights; 0 for a graph without edges.

    Examples
    --------
    >>> G = nx.DiGraph()
    >>> G.add_weighted_edges_from([(0, 1, 2), (1, 2, 4), (2, 3, 3), (3, 4, 2)])
    >>> nx.tree.branching_weight(G)
    11

    """
    total = 0
    for _, _, edge_data in G.edges(data=True):
        total += edge_data.get(attr, default)
    return total
107
108
@py_random_state(4)
@nx._dispatchable(edge_attrs={"attr": "default"}, returns_graph=True)
def greedy_branching(G, attr="weight", default=1, kind="max", seed=None):
    """
    Returns a branching obtained through a greedy algorithm.

    This algorithm is wrong, and cannot give a proper optimal branching.
    However, we include it for pedagogical reasons, as it can be helpful to
    see what its outputs are.

    The output is a branching, and possibly, a spanning arborescence. However,
    it is not guaranteed to be optimal in either case.

    Parameters
    ----------
    G : DiGraph
        The directed graph to scan.
    attr : str
        The attribute to use as weights. If None, then each edge will be
        treated equally with a weight of `default`, and the edges of the
        returned branching carry no weight attribute.
    default : float
        When `attr` is not None, then if an edge does not have that attribute,
        `default` specifies what value it should take.
    kind : str
        The type of optimum to search for: 'min' or 'max' greedy branching.
    seed : integer, random_state, or None (default)
        Indicator of random number generation state.
        See :ref:`Randomness<randomness>`.

    Returns
    -------
    B : directed graph
        The greedily obtained branching.

    Raises
    ------
    NetworkXException
        If `kind` is not one of 'min' or 'max'.

    """
    if kind not in KINDS:
        raise nx.NetworkXException("Unknown value for `kind`.")

    # "max" scans the heaviest edges first, "min" the lightest.
    reverse = kind != "min"

    # Attribute name used to *read* weights from G. When `attr` is None we
    # generate a random string the graph probably won't have, so every edge
    # falls back to `default`. `attr` itself is left untouched so that the
    # output branching does not receive a bogus random-named attribute.
    weight_attr = attr
    if weight_attr is None:
        weight_attr = random_string(seed=seed)

    edges = [
        (u, v, data.get(weight_attr, default)) for (u, v, data) in G.edges(data=True)
    ]

    # We sort by weight, but also by nodes to normalize behavior across runs.
    try:
        edges.sort(key=itemgetter(2, 0, 1), reverse=reverse)
    except TypeError:
        # Nodes of mutually incomparable types cannot be ordered; fall back
        # to sorting by weight only (ties keep their arbitrary order).
        edges.sort(key=itemgetter(2), reverse=reverse)

    # The branching begins with a forest of no edges.
    B = nx.DiGraph()
    B.add_nodes_from(G)

    # Now we add edges greedily so long as we maintain the branching.
    uf = nx.utils.UnionFind()
    for u, v, w in edges:
        if uf[u] == uf[v]:
            # Adding this edge would form a directed cycle.
            continue
        if B.in_degree(v) == 1:
            # The edge would increase the in-degree of v to more than one.
            continue
        if attr is None:
            # No weight attribute was requested, so don't insert weights.
            B.add_edge(u, v)
        else:
            B.add_edge(u, v, **{attr: w})
        uf.union(u, v)

    return B
188
189
@nx._dispatchable(preserve_edge_attrs=True, returns_graph=True)
def maximum_branching(
    G,
    attr="weight",
    default=1,
    preserve_attrs=False,
    partition=None,
):
    # Edmonds' algorithm (steps I1-I3 of the paper cited in the module
    # docstring). The public docstring is attached after this definition via
    # ``maximum_branching.__doc__ = docstring_branching.format(...)``.

    #######################################
    ### Data Structure Helper Functions ###
    #######################################

    def edmonds_add_edge(G, edge_index, u, v, key, **d):
        """
        Adds an edge to `G` while also updating the edge index.

        This algorithm requires the use of an external dictionary to track
        the edge keys since it is possible that the source or destination
        node of an edge will be changed and the default key-handling
        capabilities of the MultiDiGraph class do not account for this.

        Parameters
        ----------
        G : MultiDiGraph
            The graph to insert an edge into.
        edge_index : dict
            A mapping from integers to the edges of the graph.
        u : node
            The source node of the new edge.
        v : node
            The destination node of the new edge.
        key : int
            The key to use from `edge_index`.
        d : keyword arguments, optional
            Other attributes to store on the new edge.

        Raises
        ------
        Exception
            If `key` is already recorded in `edge_index` for a different
            (source, destination) pair.
        """

        if key in edge_index:
            uu, vv, _ = edge_index[key]
            if (u != uu) or (v != vv):
                raise Exception(f"Key {key!r} is already in use.")

        G.add_edge(u, v, key, **d)
        # Record (source, destination, edge attribute dict) under `key`.
        edge_index[key] = (u, v, G.succ[u][v][key])

    def edmonds_remove_node(G, edge_index, n):
        """
        Remove a node from the graph, updating the edge index to match.

        Every key of an edge entering or leaving `n` is deleted from
        `edge_index` before the node itself is removed from `G`.

        Parameters
        ----------
        G : MultiDiGraph
            The graph to remove a node from.
        edge_index : dict
            A mapping from integers to the edges of the graph.
        n : node
            The node to remove from `G`.
        """
        keys = set()
        for keydict in G.pred[n].values():
            keys.update(keydict)
        for keydict in G.succ[n].values():
            keys.update(keydict)

        for key in keys:
            del edge_index[key]

        G.remove_node(n)

    #######################
    ### Algorithm Setup ###
    #######################

    # Pick an attribute name that the original graph is unlikely to have.
    # It flags edges of G that were selected into the branching B.
    candidate_attr = "edmonds' secret candidate attribute"
    # Contracted circuits become nodes named new_node_base_name + str(level).
    new_node_base_name = "edmonds new node base name "

    # Work on a private MultiDiGraph copy; G_original is never mutated here.
    G_original = G
    G = nx.MultiDiGraph()
    G.__networkx_cache__ = None  # Disable caching

    # A dict to reliably track mutations to the edges using the key of the edge.
    G_edge_index = {}
    # Each edge is given an arbitrary numerical key (its enumeration index).
    for key, (u, v, data) in enumerate(G_original.edges(data=True)):
        d = {attr: data.get(attr, default)}

        if data.get(partition) is not None:
            d[partition] = data.get(partition)

        if preserve_attrs:
            for d_k, d_v in data.items():
                if d_k != attr:
                    d[d_k] = d_v

        edmonds_add_edge(G, G_edge_index, u, v, key, **d)

    level = 0  # Stores the number of contracted nodes

    # These are the buckets from the paper.
    #
    # In the paper, G^i are modified versions of the original graph.
    # D^i and E^i are the nodes and edges of the maximal edges that are
    # consistent with G^i. In this implementation, D^i and E^i are stored
    # together as the graph B^i. We will have strictly more B^i than the
    # paper will have.
    #
    # Note that the data in graphs and branchings are tuples with the graph as
    # the first element and the edge index as the second.
    B = nx.MultiDiGraph()
    B_edge_index = {}
    graphs = []  # G^i list
    branchings = []  # B^i list
    selected_nodes = set()  # D^i bucket
    # Tracks the connected components of B so step I1 can detect circuits.
    uf = nx.utils.UnionFind()

    # A list of lists of edge indices. Each list is a circuit for graph G^i.
    # Note the edge list is not required to be a circuit in G^0.
    circuits = []

    # Stores the index of the minimum edge in the circuit found in G^i and B^i.
    # The ordering of the edges seems to preserve the weight ordering from
    # G^0. So even if the circuit does not form a circuit in G^0, it is still
    # true that the minimum edges in circuit G^0 (despite their weights being
    # different)
    minedge_circuit = []

    ###########################
    ### Algorithm Structure ###
    ###########################

    # Each step listed in the algorithm is an inner function. Thus, the overall
    # loop structure is:
    #
    # while True:
    #     step_I1()
    #     if cycle detected:
    #         step_I2()
    #     elif every node of G is in D and E is a branching:
    #         break

    ##################################
    ### Algorithm Helper Functions ###
    ##################################

    def edmonds_find_desired_edge(v):
        """
        Find the edge directed towards v with maximal weight.

        If an edge partition exists in this graph, return the included
        edge if it exists and never return any excluded edge.

        Note: There can only be one included edge for each vertex otherwise
        the edge partition is empty.

        Parameters
        ----------
        v : node
            The node to search for the maximal weight incoming edge.

        Returns
        -------
        (edge, weight)
            `edge` is ``(u, v, key, weight, data)`` or None when `v` has no
            usable incoming edge (in which case `weight` is ``-INF``). An
            INCLUDED edge is returned as soon as it is seen, whatever its
            weight.
        """
        edge = None
        max_weight = -INF
        for u, _, key, data in G.in_edges(v, data=True, keys=True):
            # Skip excluded edges
            if data.get(partition) == nx.EdgePartition.EXCLUDED:
                continue

            new_weight = data[attr]

            # Return the included edge
            if data.get(partition) == nx.EdgePartition.INCLUDED:
                max_weight = new_weight
                edge = (u, v, key, new_weight, data)
                break

            # Find the best open edge
            if new_weight > max_weight:
                max_weight = new_weight
                edge = (u, v, key, new_weight, data)

        return edge, max_weight

    def edmonds_step_I2(v, desired_edge, level):
        """
        Perform step I2 from Edmonds' paper

        Called when the edge just added in step I1 closed a circuit in B.
        Store the circuit (and snapshots of G and B) for later reference,
        then contract the circuit into a single new node.

        Parameters
        ----------
        v : node
            The current node to consider
        desired_edge : edge
            The edge just added in step I1 that closed the circuit, as
            returned by ``edmonds_find_desired_edge``.
        level : int
            The current level, i.e. the number of cycles that have already been removed.
        """
        u = desired_edge[0]

        # The circuit is the existing B-path v -> ... -> u plus the new edge
        # u -> v. B has at most one edge between consecutive path nodes, so
        # its first key is the one on the path.
        Q_nodes = nx.shortest_path(B, v, u)
        Q_edges = [
            list(B[Q_nodes[i]][vv].keys())[0] for i, vv in enumerate(Q_nodes[1:])
        ]
        Q_edges.append(desired_edge[2])  # Add the new edge key to complete the circuit

        # Get the edge in the circuit with the minimum weight.
        # Also, save the incoming weights for each node.
        minweight = INF
        minedge = None
        Q_incoming_weight = {}
        for edge_key in Q_edges:
            u, v, data = B_edge_index[edge_key]
            w = data[attr]
            Q_incoming_weight[v] = w
            # We cannot remove an included edge, even if it is the
            # minimum edge in the circuit
            if data.get(partition) == nx.EdgePartition.INCLUDED:
                continue
            if w < minweight:
                minweight = w
                minedge = edge_key

        # Snapshot this level before mutating G and B; step I3 unwinds these.
        circuits.append(Q_edges)
        minedge_circuit.append(minedge)
        graphs.append((G.copy(), G_edge_index.copy()))
        branchings.append((B.copy(), B_edge_index.copy()))

        # Mutate the graph to contract the circuit
        new_node = new_node_base_name + str(level)
        G.add_node(new_node)
        new_edges = []
        for u, v, key, data in G.edges(data=True, keys=True):
            if u in Q_incoming_weight:
                if v in Q_incoming_weight:
                    # Circuit edge. For the moment do nothing,
                    # eventually it will be removed.
                    continue
                else:
                    # Outgoing edge from a node in the circuit.
                    # Make it come from the new node instead
                    dd = data.copy()
                    new_edges.append((new_node, v, key, dd))
            else:
                if v in Q_incoming_weight:
                    # Incoming edge to the circuit.
                    # Update its weight: w + minweight - (weight of the
                    # circuit edge that currently enters v).
                    w = data[attr]
                    w += minweight - Q_incoming_weight[v]
                    dd = data.copy()
                    dd[attr] = w
                    new_edges.append((u, new_node, key, dd))
                else:
                    # Outside edge. No modification needed
                    continue

        for node in Q_nodes:
            edmonds_remove_node(G, G_edge_index, node)
            edmonds_remove_node(B, B_edge_index, node)

        selected_nodes.difference_update(set(Q_nodes))

        # Re-attach the redirected edges (keeping their keys). Only edges
        # previously selected into B (flagged with candidate_attr) go back
        # into B, without the flag.
        for u, v, key, data in new_edges:
            edmonds_add_edge(G, G_edge_index, u, v, key, **data)
            if candidate_attr in data:
                del data[candidate_attr]
                edmonds_add_edge(B, B_edge_index, u, v, key, **data)
                uf.union(u, v)

    def is_root(G, u, edgekeys):
        """
        Returns whether `u` is a root node in G.

        Node `u` is a root node if its in-degree over the specified edges is zero.

        Parameters
        ----------
        G : Graph
            The current graph.
        u : node
            The node in `G` to check if it is a root.
        edgekeys : iterable of edges
            The edge keys for which to check if `u` is a root of.

        Returns
        -------
        (bool, key)
            ``(True, None)`` if no edge entering `u` has its key in
            `edgekeys`; otherwise ``(False, edgekey)`` for the first such key.

        Raises
        ------
        Exception
            If `u` is not a node of `G`.
        """
        if u not in G:
            raise Exception(f"{u!r} not in G")

        for v in G.pred[u]:
            for edgekey in G.pred[u][v]:
                if edgekey in edgekeys:
                    return False, edgekey
        else:
            return True, None

    nodes = iter(list(G.nodes))
    while True:
        try:
            v = next(nodes)
        except StopIteration:
            # If there are no more new nodes to consider, then we should
            # meet stopping condition (b) from the paper:
            #   (b) every node of G^i is in D^i and E^i is a branching
            assert len(G) == len(B)
            if len(B):
                assert is_branching(B)

            # Record the final level so step I3 can start from it.
            graphs.append((G.copy(), G_edge_index.copy()))
            branchings.append((B.copy(), B_edge_index.copy()))
            circuits.append([])
            minedge_circuit.append(None)

            break
        else:
            #####################
            ### BEGIN STEP I1 ###
            #####################

            # This is a very simple step, so I don't think it needs a method of its own
            if v in selected_nodes:
                continue

            selected_nodes.add(v)
            B.add_node(v)
            desired_edge, desired_edge_weight = edmonds_find_desired_edge(v)

            # There might be no desired edge if all edges are excluded or
            # v is the last node to be added to B, the ultimate root of the branching.
            # Only positive-weight edges are ever selected; NOTE(review): this
            # applies to INCLUDED edges too, which is presumably why the
            # spanning-arborescence wrappers shift weights to be positive first.
            if desired_edge is not None and desired_edge_weight > 0:
                u = desired_edge[0]
                # Flag adding the edge will create a circuit before merging the two
                # connected components of u and v in B
                circuit = uf[u] == uf[v]
                dd = {attr: desired_edge_weight}
                if desired_edge[4].get(partition) is not None:
                    dd[partition] = desired_edge[4].get(partition)

                edmonds_add_edge(B, B_edge_index, u, v, desired_edge[2], **dd)
                G[u][v][desired_edge[2]][candidate_attr] = True
                uf.union(u, v)

                ###################
                ### END STEP I1 ###
                ###################

                #####################
                ### BEGIN STEP I2 ###
                #####################

                if circuit:
                    edmonds_step_I2(v, desired_edge, level)
                    # The graph changed: restart the scan over its nodes.
                    nodes = iter(list(G.nodes()))
                    level += 1

                ###################
                ### END STEP I2 ###
                ###################

    #####################
    ### BEGIN STEP I3 ###
    #####################

    # Create a new graph of the same class as the input graph
    H = G_original.__class__()

    # Start with the branching edges in the last level.
    edges = set(branchings[level][1])
    while level > 0:
        level -= 1

        # The current level is i, and we start counting from 0.
        #
        # We need the node at level i+1 that results from merging a circuit
        # at level i. basename_0 is the first merged node and this happens
        # at level 1. That is basename_0 is a node at level 1 that results
        # from merging a circuit at level 0.

        merged_node = new_node_base_name + str(level)
        circuit = circuits[level]
        isroot, edgekey = is_root(graphs[level + 1][0], merged_node, edges)
        edges.update(circuit)

        if isroot:
            minedge = minedge_circuit[level]
            if minedge is None:
                # Every edge of this circuit was INCLUDED, so none can be dropped.
                raise Exception

            # Remove the edge in the cycle with minimum weight
            edges.remove(minedge)
        else:
            # We have identified an edge at the next higher level that
            # transitions into the merged node at this level. That edge
            # transitions to some corresponding node at the current level.
            #
            # We want to remove an edge from the cycle that transitions
            # into the corresponding node, otherwise the result would not
            # be a branching.

            G, G_edge_index = graphs[level]
            target = G_edge_index[edgekey][1]
            for edgekey in circuit:
                u, v, data = G_edge_index[edgekey]
                if v == target:
                    break
            else:
                raise Exception("Couldn't find edge incoming to merged node.")

            edges.remove(edgekey)

    # Translate the surviving edge keys back to the original endpoints/data.
    H.add_nodes_from(G_original)
    for edgekey in edges:
        u, v, d = graphs[0][1][edgekey]
        dd = {attr: d[attr]}

        if preserve_attrs:
            for key, value in d.items():
                if key not in [attr, candidate_attr]:
                    dd[key] = value

        H.add_edge(u, v, **dd)

    ###################
    ### END STEP I3 ###
    ###################

    return H
615
616
@nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True)
def minimum_branching(
    G, attr="weight", default=1, preserve_attrs=False, partition=None
):
    # The public docstring is attached after the definition via
    # ``minimum_branching.__doc__``.
    #
    # A minimum branching is a maximum branching over negated weights. The
    # weights of G are negated in place (edges lacking `attr` receive it,
    # set from `default`) and negated back afterwards.
    for _, _, d in G.edges(data=True):
        d[attr] = -d.get(attr, default)
    nx._clear_cache(G)

    try:
        B = maximum_branching(G, attr, default, preserve_attrs, partition)
    finally:
        # Undo the negation even if the search raises, so the caller's graph
        # is never left with flipped weights.
        for _, _, d in G.edges(data=True):
            d[attr] = -d.get(attr, default)
        nx._clear_cache(G)

    # The result carries the negated weights; restore their sign.
    for _, _, d in B.edges(data=True):
        d[attr] = -d.get(attr, default)
    nx._clear_cache(B)

    return B
636
637
@nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True)
def minimal_branching(
    G, /, *, attr="weight", default=1, preserve_attrs=False, partition=None
):
    """
    Returns a minimal branching from `G`.

    A minimal branching is a branching similar to a minimal arborescence but
    without the requirement that the result is actually a spanning arborescence.
    This allows minimal branchings to be computed over graphs which may not
    have an arborescence (such as multiple components).

    The edge weights of `G` are transformed in place during the computation
    and transformed back afterwards, also when the computation raises. Edges
    that lacked `attr` end up with `attr` set from `default`.

    Parameters
    ----------
    G : (multi)digraph-like
        The graph to be searched.
    attr : str
        The edge attribute used in determining optimality.
    default : float
        The value of the edge attribute used if an edge does not have
        the attribute `attr`.
    preserve_attrs : bool
        If True, preserve the other attributes of the original graph (that are not
        passed to `attr`)
    partition : str
        The key for the edge attribute containing the partition
        data on the graph. Edges can be included, excluded or open using the
        `EdgePartition` enum.

    Returns
    -------
    B : (multi)digraph-like
        A minimal branching.
    """
    max_weight = -INF
    min_weight = INF
    for _, _, w in G.edges(data=attr, default=default):
        if w > max_weight:
            max_weight = w
        if w < min_weight:
            min_weight = w

    # The map w -> c - w (an involution, so it also undoes itself) reverses
    # the weight order. With c = max + 1 + (max - min), the minimum transformed
    # weight is larger than the difference between the max and min weights,
    # which prevents edge weights from becoming negative during computation.
    for _, _, d in G.edges(data=True):
        d[attr] = max_weight + 1 + (max_weight - min_weight) - d.get(attr, default)
    nx._clear_cache(G)

    try:
        B = maximum_branching(G, attr, default, preserve_attrs, partition)
    finally:
        # Reverse the weight transformation on the input even if the search
        # raises, so the caller's graph is not left with transformed weights.
        for _, _, d in G.edges(data=True):
            d[attr] = max_weight + 1 + (max_weight - min_weight) - d.get(attr, default)
        nx._clear_cache(G)

    for _, _, d in B.edges(data=True):
        d[attr] = max_weight + 1 + (max_weight - min_weight) - d.get(attr, default)
    nx._clear_cache(B)

    return B
700
701
@nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True)
def maximum_spanning_arborescence(
    G, attr="weight", default=1, preserve_attrs=False, partition=None
):
    # The public docstring is attached after the definition via
    # ``maximum_spanning_arborescence.__doc__``.
    #
    # In order to use the same algorithm as the maximum branching, we need to
    # adjust the weights of the graph. The branching algorithm can choose to
    # not include an edge if it doesn't help find a branching, mainly triggered
    # by edges with negative weights.
    #
    # To prevent this from happening while trying to find a spanning
    # arborescence, we tweak the edge weights so that they are all positive
    # and cannot become negative during the branching algorithm, find the
    # maximum branching and then return them to their original values.

    min_weight = INF
    max_weight = -INF
    for _, _, w in G.edges(data=attr, default=default):
        if w < min_weight:
            min_weight = w
        if w > max_weight:
            max_weight = w

    for _, _, d in G.edges(data=True):
        d[attr] = d.get(attr, default) - min_weight + 1 - (min_weight - max_weight)
    nx._clear_cache(G)

    try:
        B = maximum_branching(G, attr, default, preserve_attrs, partition)
    finally:
        # Undo the shift on the input even if the search raises, so the
        # caller's graph is not left with shifted weights.
        for _, _, d in G.edges(data=True):
            d[attr] = d.get(attr, default) + min_weight - 1 + (min_weight - max_weight)
        nx._clear_cache(G)

    for _, _, d in B.edges(data=True):
        d[attr] = d.get(attr, default) + min_weight - 1 + (min_weight - max_weight)
    nx._clear_cache(B)

    if not is_arborescence(B):
        raise nx.exception.NetworkXException("No maximum spanning arborescence in G.")

    return B
742
743
@nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True)
def minimum_spanning_arborescence(
    G, attr="weight", default=1, preserve_attrs=False, partition=None
):
    # The public docstring is attached after the definition via
    # ``minimum_spanning_arborescence.__doc__``.
    #
    # Compute a minimal branching and accept it only if it spans the graph
    # as an arborescence.
    options = {
        "attr": attr,
        "default": default,
        "preserve_attrs": preserve_attrs,
        "partition": partition,
    }
    candidate = minimal_branching(G, **options)

    if is_arborescence(candidate):
        return candidate
    raise nx.exception.NetworkXException("No minimum spanning arborescence in G.")
760
761
# Shared numpydoc template for the branching functions; ``{kind}`` and
# ``{style}`` are filled in with ``str.format`` when attached to each function.
docstring_branching = """
Returns a {kind} {style} from G.

Parameters
----------
G : (multi)digraph-like
    The graph to be searched.
attr : str
    The edge attribute used in determining optimality.
default : float
    The value of the edge attribute used if an edge does not have
    the attribute `attr`.
preserve_attrs : bool
    If True, preserve the other attributes of the original graph (that are not
    passed to `attr`)
partition : str
    The key for the edge attribute containing the partition
    data on the graph. Edges can be included, excluded or open using the
    `EdgePartition` enum.

Returns
-------
B : (multi)digraph-like
    A {kind} {style}.
"""

# The arborescence variant additionally documents the exception raised when
# no spanning arborescence exists.
docstring_arborescence = (
    docstring_branching
    + """
Raises
------
NetworkXException
    If the graph does not contain a {kind} {style}.

"""
)
798
# Attach the generated numpydoc docstrings to the public functions.
maximum_branching.__doc__ = docstring_branching.format(kind="maximum", style="branching")

minimum_branching.__doc__ = "".join(
    (
        docstring_branching.format(kind="minimum", style="branching"),
        "\nSee Also\n--------\n    minimal_branching\n",
    )
)

maximum_spanning_arborescence.__doc__ = docstring_arborescence.format(
    style="spanning arborescence", kind="maximum"
)

minimum_spanning_arborescence.__doc__ = docstring_arborescence.format(
    style="spanning arborescence", kind="minimum"
)
819
820
class ArborescenceIterator:
    """
    Iterate over all spanning arborescences of a graph in either increasing or
    decreasing cost.

    Notes
    -----
    This iterator uses the partition scheme from [1]_ (included edges,
    excluded edges and open edges). It generates minimum spanning
    arborescences using a modified Edmonds' Algorithm which respects the
    partition of edges. For arborescences with the same weight, ties are
    broken arbitrarily.

    References
    ----------
    .. [1] G.K. Janssens, K. Sörensen, An algorithm to generate all spanning
           trees in order of increasing cost, Pesquisa Operacional, 2005-08,
           Vol. 25 (2), p. 219-229,
           https://www.scielo.br/j/pope/a/XHswBwRwJyrfL88dmMwYNWp/?lang=en
    """

    @dataclass(order=True)
    class Partition:
        """
        This dataclass represents a partition and stores a dict with the edge
        data and the weight of the minimum spanning arborescence of the
        partition dict.
        """

        # Queue priority: the arborescence weight (negated by the iterator
        # when it runs in maximum mode). Only this field is compared.
        mst_weight: float
        # Maps ``(u, v)`` edge tuples to an ``nx.EdgePartition`` member;
        # excluded from ordering and equality.
        partition_dict: dict = field(compare=False)

        def __copy__(self):
            # Copy the dict so later edits to the original don't leak in.
            return ArborescenceIterator.Partition(
                self.mst_weight, self.partition_dict.copy()
            )

    def __init__(self, G, weight="weight", minimum=True, init_partition=None):
        """
        Initialize the iterator

        Parameters
        ----------
        G : nx.DiGraph
            The directed graph which we need to iterate trees over

        weight : String, default = "weight"
            The edge attribute used to store the weight of the edge

        minimum : bool, default = True
            Return the trees in increasing order while true and decreasing order
            while false.

        init_partition : tuple, default = None
            In the case that certain edges have to be included or excluded from
            the arborescences, `init_partition` should be in the form
            `(included_edges, excluded_edges)` where each edges is a
            `(u, v)`-tuple inside an iterable such as a list or set.

        """
        # Work on a copy: partition data is written onto its edges.
        self.G = G.copy()
        self.weight = weight
        self.minimum = minimum
        self.method = (
            minimum_spanning_arborescence if minimum else maximum_spanning_arborescence
        )
        # A fixed edge attribute name, unlikely to collide with user data,
        # that holds the partition data
        self.partition_key = (
            "ArborescenceIterators super secret partition attribute name"
        )
        if init_partition is not None:
            partition_dict = {}
            for e in init_partition[0]:
                partition_dict[e] = nx.EdgePartition.INCLUDED
            for e in init_partition[1]:
                partition_dict[e] = nx.EdgePartition.EXCLUDED
            self.init_partition = ArborescenceIterator.Partition(0, partition_dict)
        else:
            self.init_partition = None

    def __iter__(self):
        """
        Returns
        -------
        ArborescenceIterator
            The iterator object for this graph
        """
        self.partition_queue = PriorityQueue()
        self._clear_partition(self.G)

        # Write the initial partition if it exists.
        if self.init_partition is not None:
            self._write_partition(self.init_partition)

        mst_weight = self.method(
            self.G,
            self.weight,
            partition=self.partition_key,
            preserve_attrs=True,
        ).size(weight=self.weight)

        # The queue pops the smallest priority first, so maximum mode stores
        # negated weights.
        self.partition_queue.put(
            self.Partition(
                mst_weight if self.minimum else -mst_weight,
                (
                    {}
                    if self.init_partition is None
                    else self.init_partition.partition_dict
                ),
            )
        )

        return self

    def __next__(self):
        """
        Returns
        -------
        (multi)Graph
            The spanning tree of next greatest weight, with ties broken
            arbitrarily.

        Raises
        ------
        StopIteration
            When no partitions remain, and on every call after that.
        """
        queue = getattr(self, "partition_queue", None)
        if queue is None or queue.empty():
            # Release the working graph and queue on first exhaustion; later
            # calls must keep raising StopIteration (iterator protocol)
            # rather than failing on the deleted attributes.
            if queue is not None:
                del self.G, self.partition_queue
            raise StopIteration

        partition = queue.get()
        self._write_partition(partition)
        next_arborescence = self.method(
            self.G,
            self.weight,
            partition=self.partition_key,
            preserve_attrs=True,
        )
        self._partition(partition, next_arborescence)

        self._clear_partition(next_arborescence)
        return next_arborescence

    def _partition(self, partition, partition_arborescence):
        """
        Create new partitions based of the minimum spanning tree of the
        current minimum partition.

        For each open edge of `partition_arborescence` in turn, a partition
        excluding that edge (and including the open edges handled before it)
        is evaluated and queued if it still admits an arborescence.

        Parameters
        ----------
        partition : Partition
            The Partition instance used to generate the current minimum spanning
            tree.
        partition_arborescence : nx.Graph
            The minimum spanning arborescence of the input partition.
        """
        # create two new partitions with the data from the input partition dict
        p1 = self.Partition(0, partition.partition_dict.copy())
        p2 = self.Partition(0, partition.partition_dict.copy())
        for e in partition_arborescence.edges:
            # determine if the edge was open or included
            if e not in partition.partition_dict:
                # This is an open edge
                p1.partition_dict[e] = nx.EdgePartition.EXCLUDED
                p2.partition_dict[e] = nx.EdgePartition.INCLUDED

                self._write_partition(p1)
                try:
                    p1_mst = self.method(
                        self.G,
                        self.weight,
                        partition=self.partition_key,
                        preserve_attrs=True,
                    )

                    p1_mst_weight = p1_mst.size(weight=self.weight)
                    p1.mst_weight = p1_mst_weight if self.minimum else -p1_mst_weight
                    self.partition_queue.put(p1.__copy__())
                except nx.NetworkXException:
                    # This partition admits no spanning arborescence; skip it.
                    pass

                # The next branch keeps this edge included.
                p1.partition_dict = p2.partition_dict.copy()

    def _write_partition(self, partition):
        """
        Writes the desired partition into the graph to calculate the minimum
        spanning tree. Also, if one incoming edge is included, mark all others
        as excluded so that if that vertex is merged during Edmonds' algorithm
        we cannot still pick another of that vertex's included edges.

        Parameters
        ----------
        partition : Partition
            A Partition dataclass describing a partition on the edges of the
            graph.
        """
        for u, v, d in self.G.edges(data=True):
            if (u, v) in partition.partition_dict:
                d[self.partition_key] = partition.partition_dict[(u, v)]
            else:
                d[self.partition_key] = nx.EdgePartition.OPEN

        for n in self.G:
            included_count = 0
            excluded_count = 0
            for u, v, d in self.G.in_edges(nbunch=n, data=True):
                if d.get(self.partition_key) == nx.EdgePartition.INCLUDED:
                    included_count += 1
                elif d.get(self.partition_key) == nx.EdgePartition.EXCLUDED:
                    excluded_count += 1
            # Check that if there is an included edge, all other incoming ones
            # are excluded. If not fix it!
            if included_count == 1 and excluded_count != self.G.in_degree(n) - 1:
                for u, v, d in self.G.in_edges(nbunch=n, data=True):
                    if d.get(self.partition_key) != nx.EdgePartition.INCLUDED:
                        d[self.partition_key] = nx.EdgePartition.EXCLUDED

        # Clear the cache only after all edge data mutations are done.
        nx._clear_cache(self.G)

    def _clear_partition(self, G):
        """
        Removes partition data from the graph `G` and clears the cache of `G`
        (the graph that was actually mutated).
        """
        for u, v, d in G.edges(data=True):
            if self.partition_key in d:
                del d[self.partition_key]
        nx._clear_cache(G)