Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/networkx/algorithms/tree/branchings.py: 11%

589 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-20 07:00 +0000

1""" 

2Algorithms for finding optimum branchings and spanning arborescences. 

3 

4This implementation is based on: 

5 

6 J. Edmonds, Optimum branchings, J. Res. Natl. Bur. Standards 71B (1967), 

7 233–240. URL: http://archive.org/details/jresv71Bn4p233 

8 

9""" 

10# TODO: Implement method from Gabow, Galil, Spencer and Tarjan: 

11# 

12# @article{ 

13# year={1986}, 

14# issn={0209-9683}, 

15# journal={Combinatorica}, 

16# volume={6}, 

17# number={2}, 

18# doi={10.1007/BF02579168}, 

19# title={Efficient algorithms for finding minimum spanning trees in 

20# undirected and directed graphs}, 

21# url={https://doi.org/10.1007/BF02579168}, 

22# publisher={Springer-Verlag}, 

23# keywords={68 B 15; 68 C 05}, 

24# author={Gabow, Harold N. and Galil, Zvi and Spencer, Thomas and Tarjan, 

25# Robert E.}, 

26# pages={109-122}, 

27# language={English} 

28# } 

29import string 

30from dataclasses import dataclass, field 

31from enum import Enum 

32from operator import itemgetter 

33from queue import PriorityQueue 

34 

35import networkx as nx 

36from networkx.utils import py_random_state 

37 

38from .recognition import is_arborescence, is_branching 

39 

# Names exported by ``from ... import *``.
__all__ = [
    "branching_weight",
    "greedy_branching",
    "maximum_branching",
    "minimum_branching",
    "minimal_branching",
    "maximum_spanning_arborescence",
    "minimum_spanning_arborescence",
    "ArborescenceIterator",
    "Edmonds",
]

# The values accepted for a `kind` argument; anything else is rejected with a
# NetworkXException by the functions that validate against this set.
KINDS = {"max", "min"}

# Maps style names to a canonical style; "spanning arborescence" is an alias
# for "arborescence".
STYLES = {
    "branching": "branching",
    "arborescence": "arborescence",
    "spanning arborescence": "arborescence",
}

# Positive infinity, used as the starting bound when scanning for an extreme
# edge weight (negated for a maximum search).
INF = float("inf")

61 

62 

@py_random_state(1)
def random_string(L=15, seed=None):
    """
    Return a string of `L` ASCII letters drawn with the random state `seed`.

    Parameters
    ----------
    L : int
        Number of characters to generate.
    seed : integer, random_state, or None (default)
        Indicator of random number generation state; the decorator converts
        it into an object exposing ``choice`` before the body runs.

    Returns
    -------
    str
        The generated string.
    """
    alphabet = string.ascii_letters
    letters = []
    # One draw per character, in order, so the consumed random sequence is
    # exactly `L` calls to ``choice``.
    for _ in range(L):
        letters.append(seed.choice(alphabet))
    return "".join(letters)

66 

67 

68def _min_weight(weight): 

69 return -weight 

70 

71 

72def _max_weight(weight): 

73 return weight 

74 

75 

@nx._dispatch(edge_attrs={"attr": "default"})
def branching_weight(G, attr="weight", default=1):
    """
    Sum the weights of all edges of a branching.

    You must access this function through the networkx.algorithms.tree module.

    Parameters
    ----------
    G : DiGraph
        The directed graph whose edges are summed.
    attr : str
        Name of the edge attribute holding the weight.
    default : float
        Value contributed by any edge that does not carry `attr`.

    Returns
    -------
    weight: int or float
        The accumulated weight over every edge of `G`.

    Examples
    --------
    >>> G = nx.DiGraph()
    >>> G.add_weighted_edges_from([(0, 1, 2), (1, 2, 4), (2, 3, 3), (3, 4, 2)])
    >>> nx.tree.branching_weight(G)
    11

    """
    edge_weights = (data.get(attr, default) for _, _, data in G.edges(data=True))
    return sum(edge_weights)

108 

109 

@py_random_state(4)
@nx._dispatch(edge_attrs={"attr": "default"})
def greedy_branching(G, attr="weight", default=1, kind="max", seed=None):
    """
    Returns a branching obtained through a greedy algorithm.

    This algorithm is wrong, and cannot give a proper optimal branching.
    However, we include it for pedagogical reasons, as it can be helpful to
    see what its outputs are.

    The output is a branching, and possibly, a spanning arborescence. However,
    it is not guaranteed to be optimal in either case.

    Parameters
    ----------
    G : DiGraph
        The directed graph to scan.
    attr : str
        The attribute to use as weights. If None, then each edge will be
        treated equally (every edge takes the value `default`) and no weight
        attribute is written to the edges of the returned branching.
    default : float
        When `attr` is not None, then if an edge does not have that attribute,
        `default` specifies what value it should take.
    kind : str
        The type of optimum to search for: 'min' or 'max' greedy branching.
    seed : integer, random_state, or None (default)
        Indicator of random number generation state.
        See :ref:`Randomness<randomness>`.

    Returns
    -------
    B : directed graph
        The greedily obtained branching.

    Raises
    ------
    NetworkXException
        If `kind` is not one of the values in ``KINDS``.

    """
    if kind not in KINDS:
        raise nx.NetworkXException("Unknown value for `kind`.")

    # 'max' considers the heaviest edges first, 'min' the lightest.
    reverse = kind != "min"

    # Record the caller's intent *before* `attr` is replaced below; otherwise
    # the "don't insert weights" branch further down can never trigger.
    store_weights = attr is not None
    if attr is None:
        # Generate a random string the graph probably won't have, so that
        # every edge falls back to `default`.
        attr = random_string(seed=seed)

    edges = [(u, v, data.get(attr, default)) for (u, v, data) in G.edges(data=True)]

    # We sort by weight, but also by nodes to normalize behavior across runs.
    try:
        edges.sort(key=itemgetter(2, 0, 1), reverse=reverse)
    except TypeError:
        # This will fail in Python 3.x if the nodes are of varying types.
        # In that case, we use the arbitrary order.
        edges.sort(key=itemgetter(2), reverse=reverse)

    # The branching begins with a forest of no edges.
    B = nx.DiGraph()
    B.add_nodes_from(G)

    # Now we add edges greedily so long we maintain the branching.
    uf = nx.utils.UnionFind()
    for u, v, w in edges:
        if uf[u] == uf[v]:
            # Adding this edge would form a directed cycle.
            continue
        if B.in_degree(v) == 1:
            # The edge would increase the degree to be greater than one.
            continue

        # If attr was None, then don't insert weights...
        data = {attr: w} if store_weights else {}
        B.add_edge(u, v, **data)
        uf.union(u, v)

    return B

189 

190 

class MultiDiGraph_EdgeKey(nx.MultiDiGraph):
    """
    MultiDiGraph which tracks every edge under a unique, caller-supplied key.

    Adds a dictionary edge_index which maps edge keys to (u, v, data) tuples.

    This is not a complete implementation. For Edmonds algorithm, we only use
    a handful of mutators, so that is all that is implemented here. Every
    edge must be added with an explicit key, and a key may not be reused for
    a different (u, v) pair.

    Why do we need this? Edmonds algorithm requires that we track edges, even
    as we change the head and tail of an edge, and even changing the weight
    of edges. We must reliably track edges across graph mutations.
    """

    def __init__(self, incoming_graph_data=None, **attr):
        """Initialise the graph with an empty edge index and warn about deprecation."""
        cls = super()
        cls.__init__(incoming_graph_data=incoming_graph_data, **attr)

        # The super() proxy is kept so the overridden mutators below can reach
        # the parent class implementations.
        self._cls = cls
        # Maps edge key -> (u, v, data dict of that edge).
        self.edge_index = {}

        import warnings

        msg = "MultiDiGraph_EdgeKey has been deprecated and will be removed in NetworkX 3.4."
        # stacklevel=2 attributes the warning to the code instantiating the
        # class rather than to this constructor.
        warnings.warn(msg, DeprecationWarning, stacklevel=2)

    def remove_node(self, n):
        """Remove node `n` and drop every incident edge key from edge_index."""
        keys = set()
        # Collect keys of incoming and outgoing edges; a set avoids deleting
        # the same key twice (e.g. for a self-loop).
        for keydict in self.pred[n].values():
            keys.update(keydict)
        for keydict in self.succ[n].values():
            keys.update(keydict)

        for key in keys:
            del self.edge_index[key]

        self._cls.remove_node(n)

    def remove_nodes_from(self, nbunch):
        """Remove each node in `nbunch` via :meth:`remove_node`."""
        for n in nbunch:
            self.remove_node(n)

    def add_edge(self, u_for_edge, v_for_edge, key_for_edge, **attr):
        """
        Add the edge (u, v) under `key_for_edge` and record it in edge_index.

        Key is now required.

        Raises
        ------
        Exception
            If the key is already registered for a different (u, v) pair.
        """
        u, v, key = u_for_edge, v_for_edge, key_for_edge
        if key in self.edge_index:
            uu, vv, _ = self.edge_index[key]
            if (u != uu) or (v != vv):
                raise Exception(f"Key {key!r} is already in use.")

        self._cls.add_edge(u, v, key, **attr)
        # Store the graph's own data dict, so later attribute changes on the
        # edge are visible through edge_index as well.
        self.edge_index[key] = (u, v, self.succ[u][v][key])

    def add_edges_from(self, ebunch_to_add, **attr):
        """
        Add every ``(u, v, key, data)`` 4-tuple of `ebunch_to_add`.

        The extra keyword arguments `attr` are accepted but not used.
        """
        for u, v, k, d in ebunch_to_add:
            self.add_edge(u, v, k, **d)

    def remove_edge_with_key(self, key):
        """
        Remove the edge registered under `key`.

        Raises
        ------
        KeyError
            If `key` is not present in edge_index.
        """
        try:
            u, v, _ = self.edge_index[key]
        except KeyError as err:
            raise KeyError(f"Invalid edge key {key!r}") from err
        else:
            del self.edge_index[key]
            self._cls.remove_edge(u, v, key)

    def remove_edges_from(self, ebunch):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError

264 

265 

def get_path(G, u, v):
    """
    Return the nodes and edge keys along the path from `u` to `v`.

    This is not a generic function. G must be a branching and an instance of
    MultiDiGraph_EdgeKey.

    Returns
    -------
    nodes : list
        The nodes of the path as produced by ``nx.shortest_path``.
    edges : list
        For each consecutive pair of path nodes, the first key found among
        the edges joining them.

    """
    nodes = nx.shortest_path(G, u, v)

    # In a branching there is a single edge between consecutive path nodes,
    # so taking the first key of each hop identifies that edge.
    edges = [list(G[tail][head].keys())[0] for tail, head in zip(nodes, nodes[1:])]
    return nodes, edges

288 

289 

class Edmonds:
    """
    Edmonds algorithm [1]_ for finding optimal branchings and spanning
    arborescences.

    This algorithm can find both minimum and maximum spanning arborescences and
    branchings.

    Notes
    -----
    While this algorithm can find a minimum branching, since it isn't required
    to be spanning, the minimum branching is always from the set of negative
    weight edges which is most likely the empty set for most graphs.

    References
    ----------
    .. [1] J. Edmonds, Optimum Branchings, Journal of Research of the National
           Bureau of Standards, 1967, Vol. 71B, p.233-240,
           https://archive.org/details/jresv71Bn4p233

    """

    def __init__(self, G, seed=None):
        """
        Store the input graph and prepare naming for contracted nodes.

        Parameters
        ----------
        G : graph
            The graph on which the optimum is searched. It is kept as is; the
            algorithm works on a separate multidigraph built from it.
        seed : integer, random_state, or None (default)
            Indicator of random number generation state, used to generate the
            name template of the nodes created by circuit contraction.
        """
        self.G_original = G

        # Need to fix this. We need the whole tree.
        self.store = True

        # The final answer.
        self.edges = []

        # Since we will be creating graphs with new nodes, we need to make
        # sure that our node names do not conflict with the real node names.
        self.template = random_string(seed=seed) + "_{0}"

        import warnings

        msg = "Edmonds has been deprecated and will be removed in NetworkX 3.4. Please use the appropriate minimum or maximum branching or arborescence function directly."
        # stacklevel=2 attributes the warning to the code instantiating the
        # class rather than to this constructor.
        warnings.warn(msg, DeprecationWarning, stacklevel=2)

    def _init(self, attr, default, kind, style, preserve_attrs, seed, partition):
        """
        Prepare the state needed by `find_optimum` to run Edmonds' algorithm.

        Responsibilities of the _init function:

        - Check that the kind argument is in {min, max} or raise a
          NetworkXException.
        - Transform the graph if we need a minimum arborescence/branching.
          The current method is to map weight -> -weight. This is NOT a good
          approach since the algorithm can and does choose to ignore negative
          weights when creating a branching since that is always optimal when
          maximizing the weights. I think we should set the edge weights to
          be (max_weight + 1) - edge_weight.
        - Transform the graph into a MultiDiGraph, adding the partition
          information and potentially other edge attributes if we set
          preserve_attrs = True.
        - Setup the buckets and union find data structures required for the
          algorithm.

        Raises
        ------
        NetworkXException
            If `kind` is not one of the values in ``KINDS``.
        """
        if kind not in KINDS:
            raise nx.NetworkXException("Unknown value for `kind`.")

        # Store inputs. Note that `self.attr` keeps the caller's value (which
        # may be None) while `self._attr` below is what the algorithm uses.
        self.attr = attr
        self.default = default
        self.kind = kind
        self.style = style

        # Determine how we are going to transform the weights.
        if kind == "min":
            self.trans = trans = _min_weight
        else:
            self.trans = trans = _max_weight

        if attr is None:
            # Generate a random attr the graph probably won't have.
            attr = random_string(seed=seed)

        # This is the actual attribute used by the algorithm.
        self._attr = attr

        # This attribute is used to store whether a particular edge is still
        # a candidate. We generate a random attr to remove clashes with
        # preserved edges
        self.candidate_attr = "candidate_" + random_string(seed=seed)

        # The object we manipulate at each step is a multidigraph.
        self.G = G = MultiDiGraph_EdgeKey()
        for key, (u, v, data) in enumerate(self.G_original.edges(data=True)):
            d = {attr: trans(data.get(attr, default))}

            if data.get(partition) is not None:
                d[partition] = data.get(partition)

            if preserve_attrs:
                for d_k, d_v in data.items():
                    if d_k != attr:
                        d[d_k] = d_v

            G.add_edge(u, v, key, **d)

        self.level = 0

        # These are the "buckets" from the paper.
        #
        # As in the paper, G^i are modified versions of the original graph.
        # D^i and E^i are nodes and edges of the maximal edges that are
        # consistent with G^i.  These are dashed edges in figures A-F of the
        # paper. In this implementation, we store D^i and E^i together as a
        # graph B^i. So we will have strictly more B^i than the paper does.
        self.B = MultiDiGraph_EdgeKey()
        self.B.edge_index = {}
        self.graphs = []  # G^i
        self.branchings = []  # B^i
        self.uf = nx.utils.UnionFind()

        # A list of lists of edge indexes. Each list is a circuit for graph G^i.
        # Note the edge list will not, in general, be a circuit in graph G^0.
        self.circuits = []
        # Stores the index of the minimum edge in the circuit found in G^i
        # and B^i. The ordering of the edges seems to preserve the weight
        # ordering from G^0. So even if the circuit does not form a circuit
        # in G^0, it is still true that the minimum edge of the circuit in
        # G^i is still the minimum edge in circuit G^0 (despite their weights
        # being different).
        self.minedge_circuit = []

    # TODO: separate each step into an inner function. Then the overall loop would become
    # while True:
    #     step_I1()
    #     if cycle detected:
    #         step_I2()
    #     elif every node of G is in D and E is a branching
    #         break

    def find_optimum(
        self,
        attr="weight",
        default=1,
        kind="max",
        style="branching",
        preserve_attrs=False,
        partition=None,
        seed=None,
    ):
        """
        Returns a branching from G.

        Parameters
        ----------
        attr : str
            The edge attribute used to in determining optimality. If None,
            every edge is weighted with `default` and the returned graph
            carries no weight attribute.
        default : float
            The value of the edge attribute used if an edge does not have
            the attribute `attr`.
        kind : {'min', 'max'}
            The type of optimum to search for, either 'min' or 'max'.
        style : {'branching', 'arborescence'}
            If 'branching', then an optimal branching is found. If `style` is
            'arborescence', then a branching is found, such that if the
            branching is also an arborescence, then the branching is an
            optimal spanning arborescences. A given graph G need not have
            an optimal spanning arborescence.
        preserve_attrs : bool
            If True, preserve the other edge attributes of the original
            graph (that are not the one passed to `attr`)
        partition : str
            The edge attribute holding edge partition data. Used in the
            spanning arborescence iterator.
        seed : integer, random_state, or None (default)
            Indicator of random number generation state.
            See :ref:`Randomness<randomness>`.

        Returns
        -------
        H : (multi)digraph
            The branching.

        """
        self._init(attr, default, kind, style, preserve_attrs, seed, partition)
        uf = self.uf

        # This enormous while loop could use some refactoring...

        G, B = self.G, self.B
        D = set()
        nodes = iter(list(G.nodes()))
        attr = self._attr
        G_pred = G.pred

        def desired_edge(v):
            """
            Find the edge directed toward v with maximal weight.

            If an edge partition exists in this graph, return the included edge
            if it exists and do not return any excluded edges. There can only
            be one included edge for each vertex otherwise the edge partition is
            empty.
            """
            edge = None
            weight = -INF
            for u, _, key, data in G.in_edges(v, data=True, keys=True):
                # Skip excluded edges
                if data.get(partition) == nx.EdgePartition.EXCLUDED:
                    continue
                new_weight = data[attr]
                # Return the included edge
                if data.get(partition) == nx.EdgePartition.INCLUDED:
                    weight = new_weight
                    edge = (u, v, key, new_weight, data)
                    return edge, weight
                # Find the best open edge
                if new_weight > weight:
                    weight = new_weight
                    edge = (u, v, key, new_weight, data)

            return edge, weight

        while True:
            # (I1): Choose a node v in G^i not in D^i.
            try:
                v = next(nodes)
            except StopIteration:
                # If there are no more new nodes to consider, then we *should*
                # meet the break condition (b) from the paper:
                #   (b) every node of G^i is in D^i and E^i is a branching
                # Construction guarantees that it's a branching.
                assert len(G) == len(B)
                if len(B):
                    assert is_branching(B)

                if self.store:
                    self.graphs.append(G.copy())
                    self.branchings.append(B.copy())

                    # Add these to keep the lengths equal. Element i is the
                    # circuit at level i that was merged to form branching i+1.
                    # There is no circuit for the last level.
                    self.circuits.append([])
                    self.minedge_circuit.append(None)
                break
            else:
                if v in D:
                    continue

                # Put v into bucket D^i.
                D.add(v)
                B.add_node(v)
                # End (I1)

                # Start cycle detection
                edge, weight = desired_edge(v)
                if edge is None:
                    # If there is no edge, continue with a new node at (I1).
                    continue
                else:
                    # Determine if adding the edge to E^i would mean its no longer
                    # a branching. Presently, v has indegree 0 in B---it is a root.
                    u = edge[0]

                    if uf[u] == uf[v]:
                        # Then adding the edge will create a circuit. Then B
                        # contains a unique path P from v to u. So condition (a)
                        # from the paper does hold. We need to store the circuit
                        # for future reference.
                        Q_nodes, Q_edges = get_path(B, v, u)
                        Q_edges.append(edge[2])  # Edge key
                    else:
                        # Then B with the edge is still a branching and condition
                        # (a) from the paper does not hold.
                        Q_nodes, Q_edges = None, None
                    # End cycle detection

                    # THIS WILL PROBABLY BE REMOVED? MAYBE A NEW ARG FOR THIS FEATURE?
                    # Conditions for adding the edge.
                    # If weight < 0, then it cannot help in finding a maximum branching.
                    # This is the root of the problem with minimum branching.
                    if self.style == "branching" and weight <= 0:
                        acceptable = False
                    else:
                        acceptable = True

                    if acceptable:
                        dd = {attr: weight}
                        if edge[4].get(partition) is not None:
                            dd[partition] = edge[4].get(partition)
                        B.add_edge(u, v, edge[2], **dd)
                        G[u][v][edge[2]][self.candidate_attr] = True
                        uf.union(u, v)
                        if Q_edges is not None:
                            # Move to method
                            # Previous meaning of u and v is no longer important.

                            # Apply (I2).
                            # Get the edge in the cycle with the minimum weight.
                            # Also, save the incoming weights for each node.
                            minweight = INF
                            minedge = None
                            Q_incoming_weight = {}
                            for edge_key in Q_edges:
                                u, v, data = B.edge_index[edge_key]
                                # We cannot remove an included edges, even if it is
                                # the minimum edge in the circuit
                                w = data[attr]
                                Q_incoming_weight[v] = w
                                if data.get(partition) == nx.EdgePartition.INCLUDED:
                                    continue
                                if w < minweight:
                                    minweight = w
                                    minedge = edge_key

                            self.circuits.append(Q_edges)
                            self.minedge_circuit.append(minedge)

                            if self.store:
                                self.graphs.append(G.copy())
                            # Always need the branching with circuits.
                            self.branchings.append(B.copy())

                            # Now we mutate it.
                            new_node = self.template.format(self.level)

                            G.add_node(new_node)
                            new_edges = []
                            for u, v, key, data in G.edges(data=True, keys=True):
                                if u in Q_incoming_weight:
                                    if v in Q_incoming_weight:
                                        # Circuit edge, do nothing for now.
                                        # Eventually delete it.
                                        continue
                                    else:
                                        # Outgoing edge. Make it from new node
                                        dd = data.copy()
                                        new_edges.append((new_node, v, key, dd))
                                else:
                                    if v in Q_incoming_weight:
                                        # Incoming edge. Change its weight
                                        w = data[attr]
                                        w += minweight - Q_incoming_weight[v]
                                        dd = data.copy()
                                        dd[attr] = w
                                        new_edges.append((u, new_node, key, dd))
                                    else:
                                        # Outside edge. No modification necessary.
                                        continue

                            G.remove_nodes_from(Q_nodes)
                            B.remove_nodes_from(Q_nodes)
                            D.difference_update(set(Q_nodes))

                            for u, v, key, data in new_edges:
                                G.add_edge(u, v, key, **data)
                                # Only edges that were already selected as
                                # candidates are carried over into B.
                                if self.candidate_attr in data:
                                    del data[self.candidate_attr]
                                    B.add_edge(u, v, key, **data)
                                    uf.union(u, v)

                            # The node set changed, so restart the scan.
                            nodes = iter(list(G.nodes()))
                            self.level += 1
                        # END STEP (I2)?

        # (I3) Branch construction.
        H = self.G_original.__class__()

        def is_root(G, u, edgekeys):
            """
            Returns True if `u` is a root node in G.

            Node `u` will be a root node if its in-degree, restricted to the
            specified edges, is equal to 0.

            The result is a pair: ``(True, None)`` for a root, otherwise
            ``(False, key)`` with the key of an incoming edge from `edgekeys`.

            """
            if u not in G:
                raise Exception(f"{u!r} not in G")
            for v in G.pred[u]:
                for edgekey in G.pred[u][v]:
                    if edgekey in edgekeys:
                        return False, edgekey
            else:
                return True, None

        # Start with the branching edges in the last level.
        edges = set(self.branchings[self.level].edge_index)
        while self.level > 0:
            self.level -= 1

            # The current level is i, and we start counting from 0.

            # We need the node at level i+1 that results from merging a circuit
            # at level i. randomname_0 is the first merged node and this
            # happens at level 1. That is, randomname_0 is a node at level 1
            # that results from merging a circuit at level 0.
            merged_node = self.template.format(self.level)

            # The circuit at level i that was merged as a node the graph
            # at level i+1.
            circuit = self.circuits[self.level]
            # Note, we ask if it is a root in the full graph, not the branching.
            # The branching alone doesn't have all the edges.
            isroot, edgekey = is_root(self.graphs[self.level + 1], merged_node, edges)
            edges.update(circuit)
            if isroot:
                minedge = self.minedge_circuit[self.level]
                if minedge is None:
                    raise Exception

                # Remove the edge in the cycle with minimum weight.
                edges.remove(minedge)
            else:
                # We have identified an edge at next higher level that
                # transitions into the merged node at the level. That edge
                # transitions to some corresponding node at the current level.
                # We want to remove an edge from the cycle that transitions
                # into the corresponding node.
                # The branching at level i
                G = self.graphs[self.level]
                target = G.edge_index[edgekey][1]
                for edgekey in circuit:
                    u, v, data = G.edge_index[edgekey]
                    if v == target:
                        break
                else:
                    raise Exception("Couldn't find edge incoming to merged node.")

                edges.remove(edgekey)

        self.edges = edges

        H.add_nodes_from(self.G_original)
        for edgekey in edges:
            u, v, d = self.graphs[0].edge_index[edgekey]

            # Weights live under the internal name `self._attr`, which differs
            # from `self.attr` when the caller passed attr=None. Read from the
            # internal name, and only write a weight back when the caller
            # named an attribute to write it to.
            dd = {}
            if self.attr is not None:
                dd[self.attr] = self.trans(d[self._attr])

            # Optionally, preserve the other edge attributes of the original
            # graph
            if preserve_attrs:
                for key, value in d.items():
                    if key not in [self._attr, self.candidate_attr]:
                        dd[key] = value

            # TODO: make this preserve the key.
            H.add_edge(u, v, **dd)

        return H

746 

747 

748@nx._dispatch( 

749 edge_attrs={"attr": "default", "partition": 0}, 

750 preserve_edge_attrs="preserve_attrs", 

751) 

752def maximum_branching( 

753 G, 

754 attr="weight", 

755 default=1, 

756 preserve_attrs=False, 

757 partition=None, 

758): 

759 ####################################### 

760 ### Data Structure Helper Functions ### 

761 ####################################### 

762 

763 def edmonds_add_edge(G, edge_index, u, v, key, **d): 

764 """ 

765 Adds an edge to `G` while also updating the edge index. 

766 

767 This algorithm requires the use of an external dictionary to track 

768 the edge keys since it is possible that the source or destination 

769 node of an edge will be changed and the default key-handling 

770 capabilities of the MultiDiGraph class do not account for this. 

771 

772 Parameters 

773 ---------- 

774 G : MultiDiGraph 

775 The graph to insert an edge into. 

776 edge_index : dict 

777 A mapping from integers to the edges of the graph. 

778 u : node 

779 The source node of the new edge. 

780 v : node 

781 The destination node of the new edge. 

782 key : int 

783 The key to use from `edge_index`. 

784 d : keyword arguments, optional 

785 Other attributes to store on the new edge. 

786 """ 

787 

788 if key in edge_index: 

789 uu, vv, _ = edge_index[key] 

790 if (u != uu) or (v != vv): 

791 raise Exception(f"Key {key!r} is already in use.") 

792 

793 G.add_edge(u, v, key, **d) 

794 edge_index[key] = (u, v, G.succ[u][v][key]) 

795 

796 def edmonds_remove_node(G, edge_index, n): 

797 """ 

798 Remove a node from the graph, updating the edge index to match. 

799 

800 Parameters 

801 ---------- 

802 G : MultiDiGraph 

803 The graph to remove an edge from. 

804 edge_index : dict 

805 A mapping from integers to the edges of the graph. 

806 n : node 

807 The node to remove from `G`. 

808 """ 

809 keys = set() 

810 for keydict in G.pred[n].values(): 

811 keys.update(keydict) 

812 for keydict in G.succ[n].values(): 

813 keys.update(keydict) 

814 

815 for key in keys: 

816 del edge_index[key] 

817 

818 G.remove_node(n) 

819 

820 ####################### 

821 ### Algorithm Setup ### 

822 ####################### 

823 

824 # Pick an attribute name that the original graph is unlikely to have 

825 candidate_attr = "edmonds' secret candidate attribute" 

826 new_node_base_name = "edmonds new node base name " 

827 

828 G_original = G 

829 G = nx.MultiDiGraph() 

830 # A dict to reliably track mutations to the edges using the key of the edge. 

831 G_edge_index = {} 

832 # Each edge is given an arbitrary numerical key 

833 for key, (u, v, data) in enumerate(G_original.edges(data=True)): 

834 d = {attr: data.get(attr, default)} 

835 

836 if data.get(partition) is not None: 

837 d[partition] = data.get(partition) 

838 

839 if preserve_attrs: 

840 for d_k, d_v in data.items(): 

841 if d_k != attr: 

842 d[d_k] = d_v 

843 

844 edmonds_add_edge(G, G_edge_index, u, v, key, **d) 

845 

846 level = 0 # Stores the number of contracted nodes 

847 

848 # These are the buckets from the paper. 

849 # 

850 # In the paper, G^i are modified versions of the original graph. 

851 # D^i and E^i are the nodes and edges of the maximal edges that are 

852 # consistent with G^i. In this implementation, D^i and E^i are stored 

853 # together as the graph B^i. We will have strictly more B^i then the 

854 # paper will have. 

855 # 

856 # Note that the data in graphs and branchings are tuples with the graph as 

857 # the first element and the edge index as the second. 

858 B = nx.MultiDiGraph() 

859 B_edge_index = {} 

860 graphs = [] # G^i list 

861 branchings = [] # B^i list 

862 selected_nodes = set() # D^i bucket 

863 uf = nx.utils.UnionFind() 

864 

865 # A list of lists of edge indices. Each list is a circuit for graph G^i. 

866 # Note the edge list is not required to be a circuit in G^0. 

867 circuits = [] 

868 

869 # Stores the index of the minimum edge in the circuit found in G^i and B^i. 

870 # The ordering of the edges seems to preserver the weight ordering from 

871 # G^0. So even if the circuit does not form a circuit in G^0, it is still 

872 # true that the minimum edges in circuit G^0 (despite their weights being 

873 # different) 

874 minedge_circuit = [] 

875 

876 ########################### 

877 ### Algorithm Structure ### 

878 ########################### 

879 

880 # Each step listed in the algorithm is an inner function. Thus, the overall 

881 # loop structure is: 

882 # 

883 # while True: 

884 # step_I1() 

885 # if cycle detected: 

886 # step_I2() 

887 # elif every node of G is in D and E is a branching: 

888 # break 

889 

890 ################################## 

891 ### Algorithm Helper Functions ### 

892 ################################## 

893 

894 def edmonds_find_desired_edge(v): 

895 """ 

896 Find the edge directed towards v with maximal weight. 

897 

898 If an edge partition exists in this graph, return the included 

899 edge if it exists and never return any excluded edge. 

900 

901 Note: There can only be one included edge for each vertex otherwise 

902 the edge partition is empty. 

903 

904 Parameters 

905 ---------- 

906 v : node 

907 The node to search for the maximal weight incoming edge. 

908 """ 

909 edge = None 

910 max_weight = -INF 

911 for u, _, key, data in G.in_edges(v, data=True, keys=True): 

912 # Skip excluded edges 

913 if data.get(partition) == nx.EdgePartition.EXCLUDED: 

914 continue 

915 

916 new_weight = data[attr] 

917 

918 # Return the included edge 

919 if data.get(partition) == nx.EdgePartition.INCLUDED: 

920 max_weight = new_weight 

921 edge = (u, v, key, new_weight, data) 

922 break 

923 

924 # Find the best open edge 

925 if new_weight > max_weight: 

926 max_weight = new_weight 

927 edge = (u, v, key, new_weight, data) 

928 

929 return edge, max_weight 

930 

def edmonds_step_I2(v, desired_edge, level):
    """
    Perform step I2 from Edmonds' paper: record and contract a circuit.

    This helper does not itself test whether a circuit exists; the caller
    invokes it only when adding `desired_edge` to the branching `B` closed
    a circuit. The circuit is stored for the later expansion step and then
    contracted into a single new node in both `G` and `B`.

    Parameters
    ----------
    v : node
        The current node to consider, i.e. the head of `desired_edge`.
    desired_edge : tuple
        The edge that was just selected for `v`. Only index 0 (the source
        node) and index 2 (the edge key) are read here.
    level : int
        The current level, i.e. the number of cycles that have already been
        removed. It is used to build the name of the contracted node.

    Notes
    -----
    Works entirely through side effects on the enclosing scope: it appends
    to `circuits`, `minedge_circuit`, `graphs` and `branchings`, and mutates
    `G`, `B`, their edge indices, `selected_nodes` and `uf`.
    """
    u = desired_edge[0]

    # The path v -> ... -> u inside B, together with the new edge u -> v,
    # forms the circuit. For each consecutive pair of path nodes the first
    # edge key stored between them in B is used.
    Q_nodes = nx.shortest_path(B, v, u)
    Q_edges = [
        list(B[Q_nodes[i]][vv].keys())[0] for i, vv in enumerate(Q_nodes[1:])
    ]
    Q_edges.append(desired_edge[2])  # Add the new edge key to complete the circuit

    # Get the edge in the circuit with the minimum weight.
    # Also, save the incoming weights for each node.
    # NOTE: `u` and `v` are rebound by the loops below and no longer refer to
    # the arguments from here on.
    minweight = INF
    minedge = None
    Q_incoming_weight = {}
    for edge_key in Q_edges:
        u, v, data = B_edge_index[edge_key]
        w = data[attr]
        # Every circuit node records the weight of its circuit edge, whether
        # or not that edge is allowed to be removed.
        Q_incoming_weight[v] = w
        # We cannot remove an included edge, even if it is the
        # minimum edge in the circuit
        if data.get(partition) == nx.EdgePartition.INCLUDED:
            continue
        if w < minweight:
            minweight = w
            minedge = edge_key

    # Snapshot the state of this level before contracting. `minedge` stays
    # None when every circuit edge is marked INCLUDED.
    circuits.append(Q_edges)
    minedge_circuit.append(minedge)
    graphs.append((G.copy(), G_edge_index.copy()))
    branchings.append((B.copy(), B_edge_index.copy()))

    # Mutate the graph to contract the circuit
    new_node = new_node_base_name + str(level)
    G.add_node(new_node)
    new_edges = []
    for u, v, key, data in G.edges(data=True, keys=True):
        if u in Q_incoming_weight:
            if v in Q_incoming_weight:
                # Circuit edge. For the moment do nothing,
                # eventually it will be removed.
                continue
            else:
                # Outgoing edge from a node in the circuit.
                # Make it come from the new node instead
                dd = data.copy()
                new_edges.append((new_node, v, key, dd))
        else:
            if v in Q_incoming_weight:
                # Incoming edge to the circuit.
                # Update its weight: shift it by the difference between the
                # minimum circuit weight and the weight of the circuit edge
                # that currently enters `v`.
                w = data[attr]
                w += minweight - Q_incoming_weight[v]
                dd = data.copy()
                dd[attr] = w
                new_edges.append((u, new_node, key, dd))
            else:
                # Outside edge. No modification needed
                continue

    # Drop the circuit nodes from both the graph and the branching.
    for node in Q_nodes:
        edmonds_remove_node(G, G_edge_index, node)
        edmonds_remove_node(B, B_edge_index, node)

    selected_nodes.difference_update(set(Q_nodes))

    # Re-attach the redirected edges to the contracted node. Edges flagged
    # with `candidate_attr` are also re-added to B (without the flag) and
    # their end points are merged in the union-find structure.
    for u, v, key, data in new_edges:
        edmonds_add_edge(G, G_edge_index, u, v, key, **data)
        if candidate_attr in data:
            del data[candidate_attr]
            edmonds_add_edge(B, B_edge_index, u, v, key, **data)
            uf.union(u, v)

1017 

def is_root(G, u, edgekeys):
    """
    Report whether `u` is a root of `G` with respect to `edgekeys`.

    `u` counts as a root when none of its incoming edges has a key that
    belongs to `edgekeys`.

    Parameters
    ----------
    G : Graph
        The current graph.
    u : node
        The node in `G` to check if it is a root.
    edgekeys : iterable of edges
        The edge keys that are taken into account.

    Returns
    -------
    (bool, edge key or None)
        ``(False, key)`` for the first incoming edge key found in
        `edgekeys`, otherwise ``(True, None)``.

    Raises
    ------
    Exception
        If `u` is not a node of `G`.
    """
    if u not in G:
        raise Exception(f"{u!r} not in G")

    # Walk every incoming edge key of `u`, predecessor by predecessor.
    incoming_keys = (key for source in G.pred[u] for key in G.pred[u][source])
    for key in incoming_keys:
        if key in edgekeys:
            return False, key
    return True, None

1042 

1043 nodes = iter(list(G.nodes)) 

1044 while True: 

1045 try: 

1046 v = next(nodes) 

1047 except StopIteration: 

1048 # If there are no more new nodes to consider, then we should 

1049 # meet stopping condition (b) from the paper: 

1050 # (b) every node of G^i is in D^i and E^i is a branching 

1051 assert len(G) == len(B) 

1052 if len(B): 

1053 assert is_branching(B) 

1054 

1055 graphs.append((G.copy(), G_edge_index.copy())) 

1056 branchings.append((B.copy(), B_edge_index.copy())) 

1057 circuits.append([]) 

1058 minedge_circuit.append(None) 

1059 

1060 break 

1061 else: 

1062 ##################### 

1063 ### BEGIN STEP I1 ### 

1064 ##################### 

1065 

1066 # This is a very simple step, so I don't think it needs a method of its own

1067 if v in selected_nodes: 

1068 continue 

1069 

1070 selected_nodes.add(v) 

1071 B.add_node(v) 

1072 desired_edge, desired_edge_weight = edmonds_find_desired_edge(v) 

1073 

1074 # There might be no desired edge if all edges are excluded or 

1075 # v is the last node to be added to B, the ultimate root of the branching 

1076 if desired_edge is not None and desired_edge_weight > 0: 

1077 u = desired_edge[0] 

1078 # Flag adding the edge will create a circuit before merging the two 

1079 # connected components of u and v in B 

1080 circuit = uf[u] == uf[v] 

1081 dd = {attr: desired_edge_weight} 

1082 if desired_edge[4].get(partition) is not None: 

1083 dd[partition] = desired_edge[4].get(partition) 

1084 

1085 edmonds_add_edge(B, B_edge_index, u, v, desired_edge[2], **dd) 

1086 G[u][v][desired_edge[2]][candidate_attr] = True 

1087 uf.union(u, v) 

1088 

1089 ################### 

1090 ### END STEP I1 ### 

1091 ################### 

1092 

1093 ##################### 

1094 ### BEGIN STEP I2 ### 

1095 ##################### 

1096 

1097 if circuit: 

1098 edmonds_step_I2(v, desired_edge, level) 

1099 nodes = iter(list(G.nodes())) 

1100 level += 1 

1101 

1102 ################### 

1103 ### END STEP I2 ### 

1104 ################### 

1105 

1106 ##################### 

1107 ### BEGIN STEP I3 ### 

1108 ##################### 

1109 

1110 # Create a new graph of the same class as the input graph 

1111 H = G_original.__class__() 

1112 

1113 # Start with the branching edges in the last level. 

1114 edges = set(branchings[level][1]) 

1115 while level > 0: 

1116 level -= 1 

1117 

1118 # The current level is i, and we start counting from 0. 

1119 # 

1120 # We need the node at level i+1 that results from merging a circuit 

1121 # at level i. basename_0 is the first merged node and this happens 

1122 # at level 1. That is basename_0 is a node at level 1 that results 

1123 # from merging a circuit at level 0. 

1124 

1125 merged_node = new_node_base_name + str(level) 

1126 circuit = circuits[level] 

1127 isroot, edgekey = is_root(graphs[level + 1][0], merged_node, edges) 

1128 edges.update(circuit) 

1129 

1130 if isroot: 

1131 minedge = minedge_circuit[level] 

1132 if minedge is None: 

1133 raise Exception 

1134 

1135 # Remove the edge in the cycle with minimum weight 

1136 edges.remove(minedge) 

1137 else: 

1138 # We have identified an edge at the next higher level that 

1139 # transitions into the merged node at this level. That edge 

1140 # transitions to some corresponding node at the current level. 

1141 # 

1142 # We want to remove an edge from the cycle that transitions 

1143 # into the corresponding node, otherwise the result would not 

1144 # be a branching. 

1145 

1146 G, G_edge_index = graphs[level] 

1147 target = G_edge_index[edgekey][1] 

1148 for edgekey in circuit: 

1149 u, v, data = G_edge_index[edgekey] 

1150 if v == target: 

1151 break 

1152 else: 

1153 raise Exception("Couldn't find edge incoming to merged node.") 

1154 

1155 edges.remove(edgekey) 

1156 

1157 H.add_nodes_from(G_original) 

1158 for edgekey in edges: 

1159 u, v, d = graphs[0][1][edgekey] 

1160 dd = {attr: d[attr]} 

1161 

1162 if preserve_attrs: 

1163 for key, value in d.items(): 

1164 if key not in [attr, candidate_attr]: 

1165 dd[key] = value 

1166 

1167 H.add_edge(u, v, **dd) 

1168 

1169 ################### 

1170 ### END STEP I3 ### 

1171 ################### 

1172 

1173 return H 

1174 

1175 

@nx._dispatch(
    edge_attrs={"attr": "default", "partition": None},
    preserve_edge_attrs="preserve_attrs",
)
def minimum_branching(
    G, attr="weight", default=1, preserve_attrs=False, partition=None
):
    # The public docstring is attached to this function further down in the
    # module (``minimum_branching.__doc__ = ...``), so only comments are used
    # here.
    #
    # A minimum branching is found by negating every weight, computing a
    # maximum branching on the negated weights and negating the result back.

    # Remember each edge's data dict, whether it carried `attr`, and the
    # weight it stands for (`default` when `attr` is absent). This lets the
    # caller's graph be restored exactly, even if the search raises.
    snapshot = [
        (d, attr in d, d.get(attr, default)) for _, _, d in G.edges(data=True)
    ]

    try:
        for d, _, w in snapshot:
            d[attr] = -w
        B = maximum_branching(G, attr, default, preserve_attrs, partition)
    finally:
        # Undo the temporary modification of the input graph.
        for d, had_attr, w in snapshot:
            if had_attr:
                d[attr] = w
            else:
                d.pop(attr, None)

    # The branching was built from negated weights; flip them back.
    for _, _, d in B.edges(data=True):
        d[attr] = -d[attr]

    return B

1195 

1196 

@nx._dispatch(
    edge_attrs={"attr": "default", "partition": None},
    preserve_edge_attrs="preserve_attrs",
)
def minimal_branching(
    G, /, *, attr="weight", default=1, preserve_attrs=False, partition=None
):
    """
    Returns a minimal branching from `G`.

    A minimal branching is a branching similar to a minimal arborescence but
    without the requirement that the result is actually a spanning arborescence.
    This allows minimal branchings to be computed over graphs which may not
    have an arborescence (such as multiple components).

    Parameters
    ----------
    G : (multi)digraph-like
        The graph to be searched.
    attr : str
        The edge attribute used in determining optimality.
    default : float
        The value of the edge attribute used if an edge does not have
        the attribute `attr`.
    preserve_attrs : bool
        If True, preserve the other attributes of the original graph (that are not
        passed to `attr`)
    partition : str
        The key for the edge attribute containing the partition
        data on the graph. Edges can be included, excluded or open using the
        `EdgePartition` enum.

    Returns
    -------
    B : (multi)digraph-like
        A minimal branching.

    Notes
    -----
    The weights of `G` are modified temporarily while the branching is
    computed and are restored before returning, also when an exception is
    raised.
    """
    # Remember each edge's data dict, whether it carried `attr`, and the
    # weight it stands for (`default` when `attr` is absent).
    snapshot = [
        (d, attr in d, d.get(attr, default)) for _, _, d in G.edges(data=True)
    ]
    weights = [w for _, _, w in snapshot]
    max_weight = max(weights, default=-INF)
    min_weight = min(weights, default=INF)

    # Transform the weights so that the minimum weight is larger than
    # the difference between the max and min weights. This is important
    # in order to prevent the edge weights from becoming negative during
    # computation. `pivot - w` reverses the order of the weights and is its
    # own inverse.
    pivot = max_weight + 1 + (max_weight - min_weight)

    try:
        for d, _, w in snapshot:
            d[attr] = pivot - w
        B = maximum_branching(G, attr, default, preserve_attrs, partition)
    finally:
        # Put back the caller's original values rather than re-applying the
        # transformation, which may not round-trip exactly for floats.
        for d, had_attr, w in snapshot:
            if had_attr:
                d[attr] = w
            else:
                d.pop(attr, None)

    # Reverse the weight transformation on the result
    for _, _, d in B.edges(data=True):
        d[attr] = pivot - d[attr]

    return B

1259 

1260 

@nx._dispatch(
    edge_attrs={"attr": "default", "partition": None},
    preserve_edge_attrs="preserve_attrs",
)
def maximum_spanning_arborescence(
    G, attr="weight", default=1, preserve_attrs=False, partition=None
):
    # The public docstring is attached to this function further down in the
    # module (``maximum_spanning_arborescence.__doc__ = ...``).
    #
    # In order to use the same algorithm as the maximum branching, we need to adjust
    # the weights of the graph. The branching algorithm can choose to not include an
    # edge if it doesn't help find a branching, mainly triggered by edges with negative
    # weights.
    #
    # To prevent this from happening while trying to find a spanning arborescence, we
    # just have to tweak the edge weights so that they are all positive and cannot
    # become negative during the branching algorithm, find the maximum branching and
    # then return them to their original values.

    # Remember each edge's data dict, whether it carried `attr`, and the
    # weight it stands for (`default` when `attr` is absent).
    snapshot = [
        (d, attr in d, d.get(attr, default)) for _, _, d in G.edges(data=True)
    ]
    weights = [w for _, _, w in snapshot]
    min_weight = min(weights, default=INF)
    max_weight = max(weights, default=-INF)

    try:
        for d, _, w in snapshot:
            d[attr] = w - min_weight + 1 - (min_weight - max_weight)
        B = maximum_branching(G, attr, default, preserve_attrs, partition)
    finally:
        # Put back the caller's original values, also when the search raised.
        for d, had_attr, w in snapshot:
            if had_attr:
                d[attr] = w
            else:
                d.pop(attr, None)

    # Undo the shift on the weights stored in the result.
    for _, _, d in B.edges(data=True):
        d[attr] = d[attr] + min_weight - 1 + (min_weight - max_weight)

    if not is_arborescence(B):
        raise nx.exception.NetworkXException("No maximum spanning arborescence in G.")

    return B

1301 

1302 

@nx._dispatch(
    edge_attrs={"attr": "default", "partition": None},
    preserve_edge_attrs="preserve_attrs",
)
def minimum_spanning_arborescence(
    G, attr="weight", default=1, preserve_attrs=False, partition=None
):
    # The public docstring is attached to this function further down in the
    # module (``minimum_spanning_arborescence.__doc__ = ...``).
    #
    # A minimal branching that happens to be an arborescence is the answer;
    # anything else means that `G` has no spanning arborescence.
    options = {
        "attr": attr,
        "default": default,
        "preserve_attrs": preserve_attrs,
        "partition": partition,
    }
    candidate = minimal_branching(G, **options)

    if is_arborescence(candidate):
        return candidate

    raise nx.exception.NetworkXException("No minimum spanning arborescence in G.")

1322 

1323 

# Shared docstring template for the branching / arborescence functions. The
# `{kind}` and `{style}` placeholders are filled in with `str.format` below.
docstring_branching = """
Returns a {kind} {style} from G.

Parameters
----------
G : (multi)digraph-like
    The graph to be searched.
attr : str
    The edge attribute used in determining optimality.
default : float
    The value of the edge attribute used if an edge does not have
    the attribute `attr`.
preserve_attrs : bool
    If True, preserve the other attributes of the original graph (that are not
    passed to `attr`)
partition : str
    The key for the edge attribute containing the partition
    data on the graph. Edges can be included, excluded or open using the
    `EdgePartition` enum.

Returns
-------
B : (multi)digraph-like
    A {kind} {style}.
"""

# The arborescence variants additionally document the exception raised when
# no spanning arborescence exists.
docstring_arborescence = (
    docstring_branching
    + """
Raises
------
NetworkXException
    If the graph does not contain a {kind} {style}.

"""
)

maximum_branching.__doc__ = docstring_branching.format(
    kind="maximum", style="branching"
)

minimum_branching.__doc__ = (
    docstring_branching.format(kind="minimum", style="branching")
    + """
See Also
--------
    minimal_branching
"""
)

maximum_spanning_arborescence.__doc__ = docstring_arborescence.format(
    kind="maximum", style="spanning arborescence"
)

minimum_spanning_arborescence.__doc__ = docstring_arborescence.format(
    kind="minimum", style="spanning arborescence"
)

1381 

1382 

class ArborescenceIterator:
    """
    Iterate over all spanning arborescences of a graph in either increasing or
    decreasing cost.

    Notes
    -----
    This iterator uses the partition scheme from [1]_ (included edges,
    excluded edges and open edges). It generates minimum spanning
    arborescences using a modified Edmonds' Algorithm which respects the
    partition of edges. For arborescences with the same weight, ties are
    broken arbitrarily.

    References
    ----------
    .. [1] G.K. Janssens, K. Sörensen, An algorithm to generate all spanning
           trees in order of increasing cost, Pesquisa Operacional, 2005-08,
           Vol. 25 (2), p. 219-229,
           https://www.scielo.br/j/pope/a/XHswBwRwJyrfL88dmMwYNWp/?lang=en
    """

    @dataclass(order=True)
    class Partition:
        """
        This dataclass represents a partition and stores a dict with the edge
        data and the weight of the minimum spanning arborescence of the
        partition dict.

        Only `mst_weight` takes part in comparisons, so instances can be
        ordered inside a priority queue.
        """

        mst_weight: float
        partition_dict: dict = field(compare=False)

        def __copy__(self):
            # Copy the dict as well so that the copy can be edited
            # independently of the original partition.
            return ArborescenceIterator.Partition(
                self.mst_weight, self.partition_dict.copy()
            )

    def __init__(self, G, weight="weight", minimum=True, init_partition=None):
        """
        Initialize the iterator

        Parameters
        ----------
        G : nx.DiGraph
            The directed graph which we need to iterate trees over

        weight : String, default = "weight"
            The edge attribute used to store the weight of the edge

        minimum : bool, default = True
            Return the trees in increasing order while true and decreasing order
            while false.

        init_partition : tuple, default = None
            In the case that certain edges have to be included or excluded from
            the arborescences, `init_partition` should be in the form
            `(included_edges, excluded_edges)` where each edge is a
            `(u, v)`-tuple inside an iterable such as a list or set.

        """
        # Work on a private copy; partition data is written into its edges.
        self.G = G.copy()
        self.weight = weight
        self.minimum = minimum
        self.method = (
            minimum_spanning_arborescence if minimum else maximum_spanning_arborescence
        )
        # Fixed edge attribute name under which the partition data is stored
        # on the edges of `self.G`.
        self.partition_key = (
            "ArborescenceIterators super secret partition attribute name"
        )
        if init_partition is not None:
            partition_dict = {}
            for e in init_partition[0]:
                partition_dict[e] = nx.EdgePartition.INCLUDED
            # Written second, so an edge listed in both ends up excluded.
            for e in init_partition[1]:
                partition_dict[e] = nx.EdgePartition.EXCLUDED
            self.init_partition = ArborescenceIterator.Partition(0, partition_dict)
        else:
            self.init_partition = None

    def __iter__(self):
        """
        Returns
        -------
        ArborescenceIterator
            The iterator object for this graph
        """
        self.partition_queue = PriorityQueue()
        self._clear_partition(self.G)

        # Write the initial partition if it exists.
        if self.init_partition is not None:
            self._write_partition(self.init_partition)

        mst_weight = self.method(
            self.G,
            self.weight,
            partition=self.partition_key,
            preserve_attrs=True,
        ).size(weight=self.weight)

        # The queue always pops the smallest entry, so weights are negated
        # when iterating in decreasing order.
        self.partition_queue.put(
            self.Partition(
                mst_weight if self.minimum else -mst_weight,
                {}
                if self.init_partition is None
                else self.init_partition.partition_dict,
            )
        )

        return self

    def __next__(self):
        """
        Returns
        -------
        (multi)Graph
            The next spanning arborescence in the configured order, with ties
            broken arbitrarily.

        Raises
        ------
        StopIteration
            When no partition is left. Further calls keep raising
            `StopIteration`.
        """
        if self.partition_queue.empty():
            # Release the graph copy, but keep the (empty) queue so that
            # repeated calls keep raising StopIteration instead of failing
            # on a missing attribute.
            if hasattr(self, "G"):
                del self.G
            raise StopIteration

        partition = self.partition_queue.get()
        self._write_partition(partition)
        next_arborescence = self.method(
            self.G,
            self.weight,
            partition=self.partition_key,
            preserve_attrs=True,
        )
        # Queue the sub-partitions derived from this arborescence before the
        # partition data is stripped from the returned graph.
        self._partition(partition, next_arborescence)

        self._clear_partition(next_arborescence)
        return next_arborescence

    def _partition(self, partition, partition_arborescence):
        """
        Create new partitions based on the minimum spanning tree of the
        current minimum partition.

        Parameters
        ----------
        partition : Partition
            The Partition instance used to generate the current minimum spanning
            tree.
        partition_arborescence : nx.Graph
            The minimum spanning arborescence of the input partition.
        """
        # create two new partitions with the data from the input partition dict
        p1 = self.Partition(0, partition.partition_dict.copy())
        p2 = self.Partition(0, partition.partition_dict.copy())
        for e in partition_arborescence.edges:
            # determine if the edge was open or included
            if e not in partition.partition_dict:
                # This is an open edge
                p1.partition_dict[e] = nx.EdgePartition.EXCLUDED
                p2.partition_dict[e] = nx.EdgePartition.INCLUDED

                self._write_partition(p1)
                try:
                    p1_mst = self.method(
                        self.G,
                        self.weight,
                        partition=self.partition_key,
                        preserve_attrs=True,
                    )

                    p1_mst_weight = p1_mst.size(weight=self.weight)
                    p1.mst_weight = p1_mst_weight if self.minimum else -p1_mst_weight
                    self.partition_queue.put(p1.__copy__())
                except nx.NetworkXException:
                    # No spanning arborescence exists for this partition, so
                    # it is not queued.
                    pass

                # Continue from the partition in which `e` is included.
                p1.partition_dict = p2.partition_dict.copy()

    def _write_partition(self, partition):
        """
        Writes the desired partition into the graph to calculate the minimum
        spanning tree. Also, if one incoming edge is included, mark all others
        as excluded so that if that vertex is merged during Edmonds' algorithm
        we cannot still pick another of that vertex's included edges.

        Parameters
        ----------
        partition : Partition
            A Partition dataclass describing a partition on the edges of the
            graph.
        """
        # Edges that are not mentioned in the partition are open.
        for u, v, d in self.G.edges(data=True):
            if (u, v) in partition.partition_dict:
                d[self.partition_key] = partition.partition_dict[(u, v)]
            else:
                d[self.partition_key] = nx.EdgePartition.OPEN

        for n in self.G:
            included_count = 0
            excluded_count = 0
            for u, v, d in self.G.in_edges(nbunch=n, data=True):
                if d.get(self.partition_key) == nx.EdgePartition.INCLUDED:
                    included_count += 1
                elif d.get(self.partition_key) == nx.EdgePartition.EXCLUDED:
                    excluded_count += 1
            # Check that if there is an included edge, all other incoming ones
            # are excluded. If not fix it!
            if included_count == 1 and excluded_count != self.G.in_degree(n) - 1:
                for u, v, d in self.G.in_edges(nbunch=n, data=True):
                    if d.get(self.partition_key) != nx.EdgePartition.INCLUDED:
                        d[self.partition_key] = nx.EdgePartition.EXCLUDED

    def _clear_partition(self, G):
        """
        Removes partition data from the graph
        """
        for u, v, d in G.edges(data=True):
            if self.partition_key in d:
                del d[self.partition_key]