Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/networkx/drawing/nx_pylab.py: 5%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

702 statements  

"""
**********
Matplotlib
**********

Draw networks with matplotlib.

Examples
--------
>>> G = nx.complete_graph(5)
>>> nx.draw(G)

See Also
--------
 - :doc:`matplotlib <matplotlib:index>`
 - :func:`matplotlib.pyplot.scatter`
 - :obj:`matplotlib.patches.FancyArrowPatch`
"""

19 

20import collections 

21import itertools 

22import math 

23from numbers import Number 

24 

25import networkx as nx 

26 

# Public names exported by a star-import of this module.
__all__ = [
    # Attribute-driven drawing and its color helper
    "display",
    "apply_matplotlib_colors",
    # draw / draw_networkx family
    "draw",
    "draw_networkx",
    "draw_networkx_nodes",
    "draw_networkx_edges",
    "draw_networkx_labels",
    "draw_networkx_edge_labels",
    # draw_* variants named after a layout
    "draw_bipartite",
    "draw_circular",
    "draw_kamada_kawai",
    "draw_random",
    "draw_spectral",
    "draw_spring",
    "draw_planar",
    "draw_shell",
    "draw_forceatlas2",
]

46 

47 

def apply_matplotlib_colors(
    G, src_attr, dest_attr, map, vmin=None, vmax=None, nodes=True
):
    """
    Apply colors from a matplotlib colormap to a graph.

    Reads the value stored under `src_attr` for every node (or edge), maps it
    through a matplotlib colormap to an RGBA color, and writes that color to
    `dest_attr`.

    Parameters
    ----------
    G : nx.Graph
        The graph to read and compute colors for.

    src_attr : str or other attribute name
        The name of the attribute to read from the graph.

    dest_attr : str or other attribute name
        The name of the attribute to write to on the graph.

    map : matplotlib.colormap
        The matplotlib colormap to use (passed as ``cmap`` to
        ``matplotlib.cm.ScalarMappable``).

    vmin : float, default None
        The minimum value for scaling the colormap. If `None`, find the
        minimum value of `src_attr`.

    vmax : float, default None
        The maximum value for scaling the colormap. If `None`, find the
        maximum value of `src_attr`.

    nodes : bool, default True
        If True, `src_attr` and `dest_attr` are node attributes. Otherwise
        they are edge attributes (edges include keys for multigraphs).

    Raises
    ------
    KeyError
        If a node (or edge) does not have the `src_attr` attribute.

    Notes
    -----
    A graph with no nodes (or no edges when ``nodes=False``) is left
    unchanged.
    """
    if nodes:
        type_iter = G.nodes()
    elif G.is_multigraph():
        type_iter = G.edges(keys=True)
    else:
        type_iter = G.edges()

    # Nothing to color. Returning early also avoids calling min()/max() on an
    # empty sequence when vmin/vmax have to be inferred from the data.
    if len(type_iter) == 0:
        return

    import matplotlib as mpl

    if vmin is None or vmax is None:
        vals = [type_iter[a][src_attr] for a in type_iter]
        if vmin is None:
            vmin = min(vals)
        if vmax is None:
            vmax = max(vals)

    mapper = mpl.cm.ScalarMappable(cmap=map)
    mapper.set_clim(vmin, vmax)

    def do_map(value):
        # Cast numpy scalars to float
        return tuple(float(channel) for channel in mapper.to_rgba(value))

    if nodes:
        nx.set_node_attributes(
            G, {n: do_map(G.nodes[n][src_attr]) for n in G.nodes()}, dest_attr
        )
    else:
        nx.set_edge_attributes(
            G, {e: do_map(G.edges[e][src_attr]) for e in type_iter}, dest_attr
        )

113 

114 

class CurvedArrowTextBase:
    """Text placed along a ``FancyArrowPatch``; a mixin for ``matplotlib.text.Text``.

    This class does not subclass ``Text`` itself. It is meant to be combined
    with it, e.g. ``class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text)``.
    The label is placed ``label_pos`` of the way along the arrow's path and,
    unless ``labels_horizontal`` is set, rotated to lie parallel to the path.
    Position and rotation are recomputed every time the text is drawn.

    Parameters
    ----------
    arrow : matplotlib.patches.FancyArrowPatch
        The arrow to attach the text to. It must have been created from
        ``posA``/``posB``, not from a custom path.
    *args, **kwargs
        Forwarded to the ``Text`` constructor after the x/y position.
    label_pos : float, default 0.5
        Fraction of the way along the path: 0 is the start, 1 is the end.
    labels_horizontal : bool, default False
        If True, keep the text horizontal instead of rotating it with the path.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into. Defaults to the current pyplot axes.
    """

    def __init__(
        self,
        arrow,
        *args,
        label_pos=0.5,
        labels_horizontal=False,
        ax=None,
        **kwargs,
    ):
        # Bind to FancyArrowPatch
        self.arrow = arrow
        # How far along the text should be on the curve:
        # 0 is at the start, 1 is at the end, etc.
        self.label_pos = label_pos
        self.labels_horizontal = labels_horizontal
        if ax is None:
            # pyplot is not imported at module level, so a bare ``plt`` here
            # would raise NameError. Import it lazily, only when needed.
            import matplotlib.pyplot as plt

            ax = plt.gca()
        self.ax = ax
        # The position and angle are needed before the Text constructor runs,
        # because they are passed to it below.
        self.x, self.y, self.angle = self._update_text_pos_angle(arrow)

        # Create text object
        super().__init__(self.x, self.y, *args, rotation=self.angle, **kwargs)
        # Bind to axis
        self.ax.add_artist(self)

    def _get_arrow_path_disp(self, arrow):
        """Return the arrow's connection path in display coordinates.

        This mirrors the first part of
        ``FancyArrowPatch._get_path_in_displaycoord``. It omits the second
        part, where the path is converted to a polygon based on its width.
        The data transform is taken from ``self.ax`` rather than from the
        arrow, because the arrow has not been added yet and does not have a
        transform.

        Raises
        ------
        ValueError
            If the arrow was built from a custom path rather than
            ``posA``/``posB``.
        """
        dpi_cor = arrow._dpi_cor
        trans_data = self.ax.transData
        if arrow._posA_posB is None:
            raise ValueError(
                "Can only draw labels for fancy arrows with "
                "posA and posB inputs, not custom path"
            )
        posA = arrow._convert_xy_units(arrow._posA_posB[0])
        posB = arrow._convert_xy_units(arrow._posA_posB[1])
        (posA, posB) = trans_data.transform((posA, posB))
        _path = arrow.get_connectionstyle()(
            posA,
            posB,
            patchA=arrow.patchA,
            patchB=arrow.patchB,
            shrinkA=arrow.shrinkA * dpi_cor,
            shrinkB=arrow.shrinkB * dpi_cor,
        )
        # Return is in display coordinates
        return _path

    def _update_text_pos_angle(self, arrow):
        """Compute the label position and rotation for ``arrow``.

        Returns
        -------
        x, y
            Label position in data coordinates, ``self.label_pos`` of the way
            along the arrow's path.
        angle
            Rotation in degrees, folded into [-90, 90] so the text is never
            upside down. It is 0 when ``self.labels_horizontal`` is True.

        Raises
        ------
        TypeError
            If the connection style is not Arc3, Angle3, Angle, Arc or Bar.
        """
        # Fractional label position
        # Text position at a proportion t along the line in display coords
        # default is 0.5 so text appears at the halfway point
        import matplotlib as mpl
        import numpy as np

        t = self.label_pos
        tt = 1 - t
        path_disp = self._get_arrow_path_disp(arrow)
        conn = arrow.get_connectionstyle()
        # 1. Calculate x and y
        points = path_disp.vertices
        if is_curve := isinstance(
            conn,
            mpl.patches.ConnectionStyle.Angle3 | mpl.patches.ConnectionStyle.Arc3,
        ):
            # Arc3 or Angle3 type Connection Styles - quadratic Bezier curve
            # through start, control and end points.
            (x1, y1), (cx, cy), (x2, y2) = points
            x = tt**2 * x1 + 2 * t * tt * cx + t**2 * x2
            y = tt**2 * y1 + 2 * t * tt * cy + t**2 * y2
        else:
            if not isinstance(
                conn,
                mpl.patches.ConnectionStyle.Angle
                | mpl.patches.ConnectionStyle.Arc
                | mpl.patches.ConnectionStyle.Bar,
            ):
                msg = f"invalid connection style: {type(conn)}"
                raise TypeError(msg)
            # A. Collect the straight segments (LINETO) of the path
            codes = path_disp.codes
            lines = [
                points[i - 1 : i + 1]
                for i in range(1, len(points))
                if codes[i] == mpl.path.Path.LINETO
            ]
            # B. If more than one line, find the one containing the fraction
            # t of the total length and rescale t to a position within it.
            if (nlines := len(lines)) != 1:
                dists = [math.dist(*line) for line in lines]
                dist_tot = sum(dists)
                cdist = 0
                last_cut = 0
                i_last = nlines - 1
                for i, dist in enumerate(dists):
                    cdist += dist
                    cut = cdist / dist_tot
                    if i == i_last or t < cut:
                        t = (t - last_cut) / (dist / dist_tot)
                        tt = 1 - t
                        lines = [lines[i]]
                        break
                    last_cut = cut
            [[(cx1, cy1), (cx2, cy2)]] = lines
            x = cx1 * tt + cx2 * t
            y = cy1 * tt + cy2 * t

        # 2. Calculate Angle
        if self.labels_horizontal:
            # Horizontal text labels
            angle = 0
        else:
            # Labels parallel to curve: use the tangent direction at t
            if is_curve:
                change_x = 2 * tt * (cx - x1) + 2 * t * (x2 - cx)
                change_y = 2 * tt * (cy - y1) + 2 * t * (y2 - cy)
            else:
                change_x = (cx2 - cx1) / 2
                change_y = (cy2 - cy1) / 2
            angle = np.arctan2(change_y, change_x) / (2 * np.pi) * 360
            # Text is "right way up"
            if angle > 90:
                angle -= 180
            elif angle < -90:
                angle += 180
        (x, y) = self.ax.transData.inverted().transform((x, y))
        return x, y, angle

    def draw(self, renderer):
        """Recompute position and rotation from the arrow, then draw the text."""
        # recalculate the text position and angle
        self.x, self.y, self.angle = self._update_text_pos_angle(self.arrow)
        self.set_position((self.x, self.y))
        self.set_rotation(self.angle)
        # redraw text
        super().draw(renderer)

255 

256 

257def display( 

258 G, 

259 canvas=None, 

260 **kwargs, 

261): 

262 """Draw the graph G. 

263 

264 Draw the graph as a collection of nodes connected by edges. 

265 The exact details of what the graph looks like are controlled by the below 

266 attributes. All nodes and nodes at the end of visible edges must have a 

267 position set, but nearly all other node and edge attributes are options and 

268 nodes or edges missing the attribute will use the default listed below. A more 

269 complete description of each parameter is given below this summary. 

270 

271 .. list-table:: Default Visualization Attributes 

272 :widths: 25 25 50 

273 :header-rows: 1 

274 

275 * - Parameter 

276 - Default Attribute 

277 - Default Value 

278 * - node_pos 

279 - `"pos"` 

280 - If there is not position, a layout will be calculated with `nx.spring_layout`. 

281 * - node_visible 

282 - `"visible"` 

283 - True 

284 * - node_color 

285 - `"color"` 

286 - #1f78b4 

287 * - node_size 

288 - `"size"` 

289 - 300 

290 * - node_label 

291 - `"label"` 

292 - Dict describing the node label. Defaults create a black text with 

293 the node name as the label. The dict respects these keys and defaults: 

294 

295 * size : 12 

296 * color : black 

297 * family : sans serif 

298 * weight : normal 

299 * alpha : 1.0 

300 * h_align : center 

301 * v_align : center 

302 * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`. 

303 Default is None. 

304 

305 * - node_shape 

306 - `"shape"` 

307 - "o" 

308 * - node_alpha 

309 - `"alpha"` 

310 - 1.0 

311 * - node_border_width 

312 - `"border_width"` 

313 - 1.0 

314 * - node_border_color 

315 - `"border_color"` 

316 - Matching node_color 

317 * - edge_visible 

318 - `"visible"` 

319 - True 

320 * - edge_width 

321 - `"width"` 

322 - 1.0 

323 * - edge_color 

324 - `"color"` 

325 - Black (#000000) 

326 * - edge_label 

327 - `"label"` 

328 - Dict describing the edge label. Defaults create black text with a 

329 white bounding box. The dictionary respects these keys and defaults: 

330 

331 * size : 12 

332 * color : black 

333 * family : sans serif 

334 * weight : normal 

335 * alpha : 1.0 

336 * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`. 

337 Default {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)} 

338 * h_align : "center" 

339 * v_align : "center" 

340 * pos : 0.5 

341 * rotate : True 

342 

343 * - edge_style 

344 - `"style"` 

345 - "-" 

346 * - edge_alpha 

347 - `"alpha"` 

348 - 1.0 

349 * - edge_arrowstyle 

350 - `"arrowstyle"` 

351 - ``"-|>"`` if `G` is directed else ``"-"`` 

352 * - edge_arrowsize 

353 - `"arrowsize"` 

354 - 10 if `G` is directed else 0 

355 * - edge_curvature 

356 - `"curvature"` 

357 - arc3 

358 * - edge_source_margin 

359 - `"source_margin"` 

360 - 0 

361 * - edge_target_margin 

362 - `"target_margin"` 

363 - 0 

364 

365 Parameters 

366 ---------- 

367 G : graph 

368 A networkx graph 

369 

370 canvas : Matplotlib Axes object, optional 

371 Draw the graph in specified Matplotlib axes 

372 

373 node_pos : string or function, default "pos" 

374 A string naming the node attribute storing the position of nodes as a tuple. 

375 Or a function to be called with input `G` which returns the layout as a dict keyed 

376 by node to position tuple like the NetworkX layout functions. 

377 If no nodes in the graph has the attribute, a spring layout is calculated. 

378 

379 node_visible : string or bool, default visible 

380 A string naming the node attribute which stores if a node should be drawn. 

381 If `True`, all nodes will be visible while if `False` no nodes will be visible. 

382 If incomplete, nodes missing this attribute will be shown by default. 

383 

384 node_color : string, default "color" 

385 A string naming the node attribute which stores the color of each node. 

386 Visible nodes without this attribute will use '#1f78b4' as a default. 

387 

388 node_size : string or number, default "size" 

389 A string naming the node attribute which stores the size of each node. 

390 Visible nodes without this attribute will use a default size of 300. 

391 

392 node_label : string or bool, default "label" 

393 A string naming the node attribute which stores the label of each node. 

394 The attribute value can be a string, False (no label for that node), 

395 True (the node is the label) or a dict keyed by node to the label. 

396 

397 If a dict is specified, these keys are read to further control the label: 

398 

399 * label : The text of the label; default: name of the node 

400 * size : Font size of the label; default: 12 

401 * color : Font color of the label; default: black 

402 * family : Font family of the label; default: "sans-serif" 

403 * weight : Font weight of the label; default: "normal" 

404 * alpha : Alpha value of the label; default: 1.0 

405 * h_align : The horizontal alignment of the label. 

406 one of "left", "center", "right"; default: "center" 

407 * v_align : The vertical alignment of the label. 

408 one of "top", "center", "bottom"; default: "center" 

409 * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`. 

410 

411 Visible nodes without this attribute will be treated as if the value was True. 

412 

413 node_shape : string, default "shape" 

414 A string naming the node attribute which stores the label of each node. 

415 The values of this attribute are expected to be one of the matplotlib shapes, 

416 one of 'so^>v<dph8'. Visible nodes without this attribute will use 'o'. 

417 

418 node_alpha : string, default "alpha" 

419 A string naming the node attribute which stores the alpha of each node. 

420 The values of this attribute are expected to be floats between 0.0 and 1.0. 

421 Visible nodes without this attribute will be treated as if the value was 1.0. 

422 

423 node_border_width : string, default "border_width" 

424 A string naming the node attribute storing the width of the border of the node. 

425 The values of this attribute are expected to be numeric. Visible nodes without 

426 this attribute will use the assumed default of 1.0. 

427 

428 node_border_color : string, default "border_color" 

429 A string naming the node attribute which storing the color of the border of the node. 

430 Visible nodes missing this attribute will use the final node_color value. 

431 

432 edge_visible : string or bool, default "visible" 

433 A string nameing the edge attribute which stores if an edge should be drawn. 

434 If `True`, all edges will be drawn while if `False` no edges will be visible. 

435 If incomplete, edges missing this attribute will be shown by default. Values 

436 of this attribute are expected to be booleans. 

437 

438 edge_width : string or int, default "width" 

439 A string nameing the edge attribute which stores the width of each edge. 

440 Visible edges without this attribute will use a default width of 1.0. 

441 

442 edge_color : string or color, default "color" 

443 A string nameing the edge attribute which stores of color of each edge. 

444 Visible edges without this attribute will be drawn black. Each color can be 

445 a string or rgb (or rgba) tuple of floats from 0.0 to 1.0. 

446 

447 edge_label : string, default "label" 

448 A string naming the edge attribute which stores the label of each edge. 

449 The values of this attribute can be a string, number or False or None. In 

450 the latter two cases, no edge label is displayed. 

451 

452 If a dict is specified, these keys are read to further control the label: 

453 

454 * label : The text of the label, or the name of an edge attribute holding the label. 

455 * size : Font size of the label; default: 12 

456 * color : Font color of the label; default: black 

457 * family : Font family of the label; default: "sans-serif" 

458 * weight : Font weight of the label; default: "normal" 

459 * alpha : Alpha value of the label; default: 1.0 

460 * h_align : The horizontal alignment of the label. 

461 one of "left", "center", "right"; default: "center" 

462 * v_align : The vertical alignment of the label. 

463 one of "top", "center", "bottom"; default: "center" 

464 * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`. 

465 * rotate : Whether to rotate labels to lie parallel to the edge, default: True. 

466 * pos : A float showing how far along the edge to put the label; default: 0.5. 

467 

468 edge_style : string, default "style" 

469 A string naming the edge attribute which stores the style of each edge. 

470 Visible edges without this attribute will be drawn solid. Values of this 

471 attribute can be line styles, e.g. '-', '--', '-.' or ':' or words like 'solid' 

472 or 'dashed'. If no edge in the graph has this attribute and it is a non-default 

473 value, assume that it describes the edge style for all edges in the graph. 

474 

475 edge_alpha : string or float, default "alpha" 

476 A string naming the edge attribute which stores the alpha value of each edge. 

477 Visible edges without this attribute will use an alpha value of 1.0. 

478 

479 edge_arrowstyle : string, default "arrowstyle" 

480 A string naming the edge attribute which stores the type of arrowhead to use for 

481 each edge. Visible edges without this attribute use ``"-"`` for undirected graphs 

482 and ``"-|>"`` for directed graphs. 

483 

484 See `matplotlib.patches.ArrowStyle` for more options 

485 

486 edge_arrowsize : string or int, default "arrowsize" 

487 A string naming the edge attribute which stores the size of the arrowhead for each 

488 edge. Visible edges without this attribute will use a default value of 10. 

489 

490 edge_curvature : string, default "curvature" 

491 A string naming the edge attribute storing the curvature and connection style 

492 of each edge. Visible edges without this attribute will use "arc3" as a default 

493 value, resulting an a straight line between the two nodes. Curvature can be given 

494 as 'arc3,rad=0.2' to specify both the style and radius of curvature. 

495 

496 Please see `matplotlib.patches.ConnectionStyle` and 

497 `matplotlib.patches.FancyArrowPatch` for more information. 

498 

499 edge_source_margin : string or int, default "source_margin" 

500 A string naming the edge attribute which stores the minimum margin (gap) between 

501 the source node and the start of the edge. Visible edges without this attribute 

502 will use a default value of 0. 

503 

504 edge_target_margin : string or int, default "target_margin" 

505 A string naming the edge attribute which stores the minimumm margin (gap) between 

506 the target node and the end of the edge. Visible edges without this attribute 

507 will use a default value of 0. 

508 

509 hide_ticks : bool, default True 

510 Weather to remove the ticks from the axes of the matplotlib object. 

511 

512 Raises 

513 ------ 

514 NetworkXError 

515 If a node or edge is missing a required parameter such as `pos` or 

516 if `display` receives an argument not listed above. 

517 

518 ValueError 

519 If a node or edge has an invalid color format, i.e. not a color string, 

520 rgb tuple or rgba tuple. 

521 

522 Returns 

523 ------- 

524 The input graph. This is potentially useful for dispatching visualization 

525 functions. 

526 """ 

527 from collections import Counter 

528 

529 import matplotlib as mpl 

530 import matplotlib.pyplot as plt 

531 import numpy as np 

532 

533 defaults = { 

534 "node_pos": None, 

535 "node_visible": True, 

536 "node_color": "#1f78b4", 

537 "node_size": 300, 

538 "node_label": { 

539 "size": 12, 

540 "color": "#000000", 

541 "family": "sans-serif", 

542 "weight": "normal", 

543 "alpha": 1.0, 

544 "h_align": "center", 

545 "v_align": "center", 

546 "bbox": None, 

547 }, 

548 "node_shape": "o", 

549 "node_alpha": 1.0, 

550 "node_border_width": 1.0, 

551 "node_border_color": "face", 

552 "edge_visible": True, 

553 "edge_width": 1.0, 

554 "edge_color": "#000000", 

555 "edge_label": { 

556 "size": 12, 

557 "color": "#000000", 

558 "family": "sans-serif", 

559 "weight": "normal", 

560 "alpha": 1.0, 

561 "bbox": {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}, 

562 "h_align": "center", 

563 "v_align": "center", 

564 "pos": 0.5, 

565 "rotate": True, 

566 }, 

567 "edge_style": "-", 

568 "edge_alpha": 1.0, 

569 "edge_arrowstyle": "-|>" if G.is_directed() else "-", 

570 "edge_arrowsize": 10 if G.is_directed() else 0, 

571 "edge_curvature": "arc3", 

572 "edge_source_margin": 0, 

573 "edge_target_margin": 0, 

574 "hide_ticks": True, 

575 } 

576 

577 # Check arguments 

578 for kwarg in kwargs: 

579 if kwarg not in defaults: 

580 raise nx.NetworkXError( 

581 f"Unrecognized visualization keyword argument: {kwarg}" 

582 ) 

583 

584 if canvas is None: 

585 canvas = plt.gca() 

586 

587 if kwargs.get("hide_ticks", defaults["hide_ticks"]): 

588 canvas.tick_params( 

589 axis="both", 

590 which="both", 

591 bottom=False, 

592 left=False, 

593 labelbottom=False, 

594 labelleft=False, 

595 ) 

596 

597 ### Helper methods and classes 

598 

599 def node_property_sequence(seq, attr): 

600 """Return a list of attribute values for `seq`, using a default if needed""" 

601 

602 # All node attribute parameters start with "node_" 

603 param_name = f"node_{attr}" 

604 default = defaults[param_name] 

605 attr = kwargs.get(param_name, attr) 

606 

607 if default is None: 

608 # raise instead of using non-existant default value 

609 for n in seq: 

610 if attr not in node_subgraph.nodes[n]: 

611 raise nx.NetworkXError(f"Attribute '{attr}' missing for node {n}") 

612 

613 # If `attr` is not a graph attr and was explicitly passed as an argument 

614 # it must be a user-default value. Allow attr=None to tell draw to skip 

615 # attributes which are on the graph 

616 if ( 

617 attr is not None 

618 and nx.get_node_attributes(node_subgraph, attr) == {} 

619 and any(attr == v for k, v in kwargs.items() if "node" in k) 

620 ): 

621 return [attr for _ in seq] 

622 

623 return [node_subgraph.nodes[n].get(attr, default) for n in seq] 

624 

625 def compute_colors(color, alpha): 

626 if isinstance(color, str): 

627 rgba = mpl.colors.colorConverter.to_rgba(color) 

628 # Using a non-default alpha value overrides any alpha value in the color 

629 if alpha != defaults["node_alpha"]: 

630 return (rgba[0], rgba[1], rgba[2], alpha) 

631 return rgba 

632 

633 if isinstance(color, tuple) and len(color) == 3: 

634 return (color[0], color[1], color[2], alpha) 

635 

636 if isinstance(color, tuple) and len(color) == 4: 

637 return color 

638 

639 raise ValueError(f"Invalid format for color: {color}") 

640 

641 # Find which edges can be plotted as a line collection 

642 # 

643 # Non-default values for these attributes require fancy arrow patches: 

644 # - any arrow style (including the default -|> for directed graphs) 

645 # - arrow size (by extension of style) 

646 # - connection style 

647 # - min_source_margin 

648 # - min_target_margin 

649 

650 def collection_compatible(e): 

651 return ( 

652 get_edge_attr(e, "arrowstyle") == "-" 

653 and get_edge_attr(e, "curvature") == "arc3" 

654 and get_edge_attr(e, "source_margin") == 0 

655 and get_edge_attr(e, "target_margin") == 0 

656 # Self-loops will use fancy arrow patches 

657 and e[0] != e[1] 

658 ) 

659 

660 def edge_property_sequence(seq, attr): 

661 """Return a list of attribute values for `seq`, using a default if needed""" 

662 

663 param_name = f"edge_{attr}" 

664 default = defaults[param_name] 

665 attr = kwargs.get(param_name, attr) 

666 

667 if default is None: 

668 # raise instead of using non-existant default value 

669 for e in seq: 

670 if attr not in edge_subgraph.edges[e]: 

671 raise nx.NetworkXError(f"Attribute '{attr}' missing for edge {e}") 

672 

673 if ( 

674 attr is not None 

675 and nx.get_edge_attributes(edge_subgraph, attr) == {} 

676 and any(attr == v for k, v in kwargs.items() if "edge" in k) 

677 ): 

678 return [attr for _ in seq] 

679 

680 return [edge_subgraph.edges[e].get(attr, default) for e in seq] 

681 

682 def get_edge_attr(e, attr): 

683 """Return the final edge attribute value, using default if not None""" 

684 

685 param_name = f"edge_{attr}" 

686 default = defaults[param_name] 

687 attr = kwargs.get(param_name, attr) 

688 

689 if default is None and attr not in edge_subgraph.edges[e]: 

690 raise nx.NetworkXError(f"Attribute '{attr}' missing from edge {e}") 

691 

692 if ( 

693 attr is not None 

694 and nx.get_edge_attributes(edge_subgraph, attr) == {} 

695 and attr in kwargs.values() 

696 ): 

697 return attr 

698 

699 return edge_subgraph.edges[e].get(attr, default) 

700 

701 def get_node_attr(n, attr, use_edge_subgraph=True): 

702 """Return the final node attribute value, using default if not None""" 

703 subgraph = edge_subgraph if use_edge_subgraph else node_subgraph 

704 

705 param_name = f"node_{attr}" 

706 default = defaults[param_name] 

707 attr = kwargs.get(param_name, attr) 

708 

709 if default is None and attr not in subgraph.nodes[n]: 

710 raise nx.NetworkXError(f"Attribute '{attr}' missing from node {n}") 

711 

712 if ( 

713 attr is not None 

714 and nx.get_node_attributes(subgraph, attr) == {} 

715 and attr in kwargs.values() 

716 ): 

717 return attr 

718 

719 return subgraph.nodes[n].get(attr, default) 

720 

721 # Taken from ConnectionStyleFactory 

722 def self_loop(edge_index, node_size): 

723 def self_loop_connection(posA, posB, *args, **kwargs): 

724 if not np.all(posA == posB): 

725 raise nx.NetworkXError( 

726 "`self_loop` connection style method" 

727 "is only to be used for self-loops" 

728 ) 

729 # this is called with _screen space_ values 

730 # so convert back to data space 

731 data_loc = canvas.transData.inverted().transform(posA) 

732 # Scale self loop based on the size of the base node 

733 # Size of nodes are given in points ** 2 and each point is 1/72 of an inch 

734 v_shift = np.sqrt(node_size) / 72 

735 h_shift = v_shift * 0.5 

736 # put the top of the loop first so arrow is not hidden by node 

737 path = np.asarray( 

738 [ 

739 # 1 

740 [0, v_shift], 

741 # 4 4 4 

742 [h_shift, v_shift], 

743 [h_shift, 0], 

744 [0, 0], 

745 # 4 4 4 

746 [-h_shift, 0], 

747 [-h_shift, v_shift], 

748 [0, v_shift], 

749 ] 

750 ) 

751 # Rotate self loop 90 deg. if more than 1 

752 # This will allow for maximum of 4 visible self loops 

753 if edge_index % 4: 

754 x, y = path.T 

755 for _ in range(edge_index % 4): 

756 x, y = y, -x 

757 path = np.array([x, y]).T 

758 return mpl.path.Path( 

759 canvas.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4] 

760 ) 

761 

762 return self_loop_connection 

763 

764 def to_marker_edge(size, marker): 

765 if marker in "s^>v<d": 

766 return np.sqrt(2 * size) / 2 

767 else: 

768 return np.sqrt(size) / 2 

769 

770 def build_fancy_arrow(e): 

771 source_margin = to_marker_edge( 

772 get_node_attr(e[0], "size"), 

773 get_node_attr(e[0], "shape"), 

774 ) 

775 source_margin = max( 

776 source_margin, 

777 get_edge_attr(e, "source_margin"), 

778 ) 

779 

780 target_margin = to_marker_edge( 

781 get_node_attr(e[1], "size"), 

782 get_node_attr(e[1], "shape"), 

783 ) 

784 target_margin = max( 

785 target_margin, 

786 get_edge_attr(e, "target_margin"), 

787 ) 

788 return mpl.patches.FancyArrowPatch( 

789 edge_subgraph.nodes[e[0]][pos], 

790 edge_subgraph.nodes[e[1]][pos], 

791 arrowstyle=get_edge_attr(e, "arrowstyle"), 

792 connectionstyle=( 

793 get_edge_attr(e, "curvature") 

794 if e[0] != e[1] 

795 else self_loop( 

796 0 if len(e) == 2 else e[2] % 4, 

797 get_node_attr(e[0], "size"), 

798 ) 

799 ), 

800 color=get_edge_attr(e, "color"), 

801 linestyle=get_edge_attr(e, "style"), 

802 linewidth=get_edge_attr(e, "width"), 

803 mutation_scale=get_edge_attr(e, "arrowsize"), 

804 shrinkA=source_margin, 

805 shrinkB=source_margin, 

806 zorder=1, 

807 ) 

808 

809 class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text): 

810 pass 

811 

812 ### Draw the nodes first 

813 node_visible = kwargs.get("node_visible", "visible") 

814 if isinstance(node_visible, bool): 

815 if node_visible: 

816 visible_nodes = G.nodes() 

817 else: 

818 visible_nodes = [] 

819 else: 

820 visible_nodes = [ 

821 n for n, v in nx.get_node_attributes(G, node_visible, True).items() if v 

822 ] 

823 

824 node_subgraph = G.subgraph(visible_nodes) 

825 

826 # Ignore the default dict value since that's for default values to use, not 

827 # default attribute name 

828 pos = kwargs.get("node_pos", "pos") 

829 

830 default_display_pos_attr = "display's position attribute name" 

831 if callable(pos): 

832 nx.set_node_attributes( 

833 node_subgraph, pos(node_subgraph), default_display_pos_attr 

834 ) 

835 pos = default_display_pos_attr 

836 kwargs["node_pos"] = default_display_pos_attr 

837 elif nx.get_node_attributes(G, pos) == {}: 

838 nx.set_node_attributes( 

839 node_subgraph, nx.spring_layout(node_subgraph), default_display_pos_attr 

840 ) 

841 pos = default_display_pos_attr 

842 kwargs["node_pos"] = default_display_pos_attr 

843 

844 # Each shape requires a new scatter object since they can't have different 

845 # shapes. 

846 if len(visible_nodes) > 0: 

847 node_shape = kwargs.get("node_shape", "shape") 

848 for shape in Counter( 

849 nx.get_node_attributes( 

850 node_subgraph, node_shape, defaults["node_shape"] 

851 ).values() 

852 ): 

853 # Filter position just on this shape. 

854 nodes_with_shape = [ 

855 n 

856 for n, s in node_subgraph.nodes(data=node_shape) 

857 if s == shape or (s is None and shape == defaults["node_shape"]) 

858 ] 

859 # There are two property sequences to create before hand. 

860 # 1. position, since it is used for x and y parameters to scatter 

861 # 2. edgecolor, since the spaeical 'face' parameter value can only be 

862 # be passed in as the sole string, not part of a list of strings. 

863 position = np.asarray(node_property_sequence(nodes_with_shape, "pos")) 

864 color = np.asarray( 

865 [ 

866 compute_colors(c, a) 

867 for c, a in zip( 

868 node_property_sequence(nodes_with_shape, "color"), 

869 node_property_sequence(nodes_with_shape, "alpha"), 

870 ) 

871 ] 

872 ) 

873 border_color = np.asarray( 

874 [ 

875 ( 

876 c 

877 if ( 

878 c := get_node_attr( 

879 n, 

880 "border_color", 

881 False, 

882 ) 

883 ) 

884 != "face" 

885 else color[i] 

886 ) 

887 for i, n in enumerate(nodes_with_shape) 

888 ] 

889 ) 

890 canvas.scatter( 

891 position[:, 0], 

892 position[:, 1], 

893 s=node_property_sequence(nodes_with_shape, "size"), 

894 c=color, 

895 marker=shape, 

896 linewidths=node_property_sequence(nodes_with_shape, "border_width"), 

897 edgecolors=border_color, 

898 zorder=2, 

899 ) 

900 

901 ### Draw node labels 

902 node_label = kwargs.get("node_label", "label") 

903 # Plot labels if node_label is not None and not False 

904 if node_label is not None and node_label is not False: 

905 default_dict = {} 

906 if isinstance(node_label, dict): 

907 default_dict = node_label 

908 node_label = None 

909 

910 for n, lbl in node_subgraph.nodes(data=node_label): 

911 if lbl is False: 

912 continue 

913 

914 # We work with label dicts down here... 

915 if not isinstance(lbl, dict): 

916 lbl = {"label": lbl if lbl is not None else n} 

917 

918 lbl_text = lbl.get("label", n) 

919 if not isinstance(lbl_text, str): 

920 lbl_text = str(lbl_text) 

921 

922 lbl.update(default_dict) 

923 x, y = node_subgraph.nodes[n][pos] 

924 canvas.text( 

925 x, 

926 y, 

927 lbl_text, 

928 size=lbl.get("size", defaults["node_label"]["size"]), 

929 color=lbl.get("color", defaults["node_label"]["color"]), 

930 family=lbl.get("family", defaults["node_label"]["family"]), 

931 weight=lbl.get("weight", defaults["node_label"]["weight"]), 

932 horizontalalignment=lbl.get( 

933 "h_align", defaults["node_label"]["h_align"] 

934 ), 

935 verticalalignment=lbl.get("v_align", defaults["node_label"]["v_align"]), 

936 transform=canvas.transData, 

937 bbox=lbl.get("bbox", defaults["node_label"]["bbox"]), 

938 ) 

939 

940 ### Draw edges 

941 

942 edge_visible = kwargs.get("edge_visible", "visible") 

943 if isinstance(edge_visible, bool): 

944 if edge_visible: 

945 visible_edges = G.edges() 

946 else: 

947 visible_edges = [] 

948 else: 

949 visible_edges = [ 

950 e for e, v in nx.get_edge_attributes(G, edge_visible, True).items() if v 

951 ] 

952 

953 edge_subgraph = G.edge_subgraph(visible_edges) 

954 nx.set_node_attributes( 

955 edge_subgraph, nx.get_node_attributes(node_subgraph, pos), name=pos 

956 ) 

957 

958 collection_edges = ( 

959 [e for e in edge_subgraph.edges(keys=True) if collection_compatible(e)] 

960 if edge_subgraph.is_multigraph() 

961 else [e for e in edge_subgraph.edges() if collection_compatible(e)] 

962 ) 

963 non_collection_edges = ( 

964 [e for e in edge_subgraph.edges(keys=True) if not collection_compatible(e)] 

965 if edge_subgraph.is_multigraph() 

966 else [e for e in edge_subgraph.edges() if not collection_compatible(e)] 

967 ) 

968 edge_position = np.asarray( 

969 [ 

970 ( 

971 get_node_attr(u, "pos", use_edge_subgraph=True), 

972 get_node_attr(v, "pos", use_edge_subgraph=True), 

973 ) 

974 for u, v, *_ in collection_edges 

975 ] 

976 ) 

977 

978 # Only plot a line collection if needed 

979 if len(collection_edges) > 0: 

980 edge_collection = mpl.collections.LineCollection( 

981 edge_position, 

982 colors=edge_property_sequence(collection_edges, "color"), 

983 linewidths=edge_property_sequence(collection_edges, "width"), 

984 linestyle=edge_property_sequence(collection_edges, "style"), 

985 alpha=edge_property_sequence(collection_edges, "alpha"), 

986 antialiaseds=(1,), 

987 zorder=1, 

988 ) 

989 canvas.add_collection(edge_collection) 

990 

991 fancy_arrows = {} 

992 if len(non_collection_edges) > 0: 

993 for e in non_collection_edges: 

994 # Cache results for use in edge labels 

995 fancy_arrows[e] = build_fancy_arrow(e) 

996 canvas.add_patch(fancy_arrows[e]) 

997 

998 ### Draw edge labels 

999 edge_label = kwargs.get("edge_label", "label") 

1000 default_dict = {} 

1001 if isinstance(edge_label, dict): 

1002 default_dict = edge_label 

1003 # Restore the default label attribute key of 'label' 

1004 edge_label = "label" 

1005 

1006 # Handle multigraphs 

1007 edge_label_data = ( 

1008 edge_subgraph.edges(data=edge_label, keys=True) 

1009 if edge_subgraph.is_multigraph() 

1010 else edge_subgraph.edges(data=edge_label) 

1011 ) 

1012 if edge_label is not None and edge_label is not False: 

1013 for *e, lbl in edge_label_data: 

1014 e = tuple(e) 

1015 # I'm not sure how I want to handle None here... For now it means no label 

1016 if lbl is False or lbl is None: 

1017 continue 

1018 

1019 if not isinstance(lbl, dict): 

1020 lbl = {"label": lbl} 

1021 

1022 lbl.update(default_dict) 

1023 lbl_text = lbl.get("label") 

1024 if not isinstance(lbl_text, str): 

1025 lbl_text = str(lbl_text) 

1026 

1027 # In the old code, every non-self-loop is placed via a fancy arrow patch 

1028 # Only compute a new fancy arrow if needed by caching the results from 

1029 # edge placement. 

1030 try: 

1031 arrow = fancy_arrows[e] 

1032 except KeyError: 

1033 arrow = build_fancy_arrow(e) 

1034 

1035 if e[0] == e[1]: 

1036 # Taken directly from draw_networkx_edge_labels 

1037 connectionstyle_obj = arrow.get_connectionstyle() 

1038 posA = canvas.transData.transform(edge_subgraph.nodes[e[0]][pos]) 

1039 path_disp = connectionstyle_obj(posA, posA) 

1040 path_data = canvas.transData.inverted().transform_path(path_disp) 

1041 x, y = path_data.vertices[0] 

1042 canvas.text( 

1043 x, 

1044 y, 

1045 lbl_text, 

1046 size=lbl.get("size", defaults["edge_label"]["size"]), 

1047 color=lbl.get("color", defaults["edge_label"]["color"]), 

1048 family=lbl.get("family", defaults["edge_label"]["family"]), 

1049 weight=lbl.get("weight", defaults["edge_label"]["weight"]), 

1050 alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]), 

1051 horizontalalignment=lbl.get( 

1052 "h_align", defaults["edge_label"]["h_align"] 

1053 ), 

1054 verticalalignment=lbl.get( 

1055 "v_align", defaults["edge_label"]["v_align"] 

1056 ), 

1057 rotation=0, 

1058 transform=canvas.transData, 

1059 bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]), 

1060 zorder=1, 

1061 ) 

1062 continue 

1063 

1064 CurvedArrowText( 

1065 arrow, 

1066 lbl_text, 

1067 size=lbl.get("size", defaults["edge_label"]["size"]), 

1068 color=lbl.get("color", defaults["edge_label"]["color"]), 

1069 family=lbl.get("family", defaults["edge_label"]["family"]), 

1070 weight=lbl.get("weight", defaults["edge_label"]["weight"]), 

1071 alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]), 

1072 bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]), 

1073 horizontalalignment=lbl.get( 

1074 "h_align", defaults["edge_label"]["h_align"] 

1075 ), 

1076 verticalalignment=lbl.get("v_align", defaults["edge_label"]["v_align"]), 

1077 label_pos=lbl.get("pos", defaults["edge_label"]["pos"]), 

1078 labels_horizontal=lbl.get("rotate", defaults["edge_label"]["rotate"]), 

1079 transform=canvas.transData, 

1080 zorder=1, 

1081 ax=canvas, 

1082 ) 

1083 

1084 # If we had to add an attribute, remove it here 

1085 if pos == default_display_pos_attr: 

1086 nx.remove_node_attributes(G, default_display_pos_attr) 

1087 

1088 return G 

1089 

1090 

def draw(G, pos=None, ax=None, **kwds):
    """Draw the graph G with Matplotlib.

    Draw the graph as a simple representation with no node
    labels or edge labels and using the full Matplotlib figure area
    and no axis labels by default. See draw_networkx() for more
    full-featured drawing that allows title, axis labels etc.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed.
        See :py:mod:`networkx.drawing.layout` for functions that
        compute node positions.

    ax : Matplotlib Axes object, optional
        Draw the graph in specified Matplotlib axes. If omitted, the current
        axes of the current figure are used, or an axes covering the whole
        figure is added when the figure has no axes yet.

    kwds : optional keywords
        See networkx.draw_networkx() for a description of optional keywords.
        Unless ``with_labels`` is given explicitly, node labels are drawn
        only when ``labels`` is passed.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    See Also
    --------
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels

    Notes
    -----
    This function has the same name as pylab.draw and pyplot.draw
    so beware when using `from networkx import *`, since you might
    overwrite the pylab.draw function.

    With pyplot use

    >>> import matplotlib.pyplot as plt
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)  # networkx draw()
    >>> plt.draw()  # pyplot draw()

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html
    """
    import matplotlib.pyplot as plt

    # The figure background is forced to white whichever figure is used.
    figure = plt.gcf() if ax is None else ax.get_figure()
    figure.set_facecolor("w")

    if ax is None:
        # Reuse the figure's current axes when it has any; otherwise add
        # one spanning the entire figure area.
        ax = figure.gca() if figure.axes else figure.add_axes((0, 0, 1, 1))

    # Labels are opt-in here: only show them when the caller supplied some.
    kwds.setdefault("with_labels", "labels" in kwds)

    draw_networkx(G, pos=pos, ax=ax, **kwds)
    ax.set_axis_off()
    plt.draw_if_interactive()

1168 

1169 

def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
    r"""Draw the graph G using Matplotlib.

    Draw the graph with Matplotlib with options for node positions,
    labeling, titles, and many other drawing features.
    See draw() for simple drawing without labels or axes.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed.
        See :py:mod:`networkx.drawing.layout` for functions that
        compute node positions.

    arrows : bool or None, optional (default=None)
        If `None`, directed graphs draw arrowheads with
        `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
        via `~matplotlib.collections.LineCollection` for speed.
        If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
        If `False`, draw edges using LineCollection (linear and fast).
        For directed graphs, if True draw arrowheads.
        Note: Arrows will be the same color as edges.

    arrowstyle : str (default='-\|>' for directed graphs)
        For directed graphs, choose the style of the arrowheads.
        For undirected graphs default to '-'

        See `matplotlib.patches.ArrowStyle` for more options.

    arrowsize : int or list (default=10)
        For directed graphs, choose the size of the arrow head's length and
        width. A list of values can be passed in to assign a different arrow
        head size to each edge; it must have the same length as the edge list.
        See `matplotlib.patches.FancyArrowPatch` for attribute `mutation_scale`
        for more info.

    with_labels : bool (default=True)
        Set to True to draw labels on the nodes.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    nodelist : list (default=list(G))
        Draw only specified nodes

    edgelist : list (default=list(G.edges()))
        Draw only specified edges

    node_size : scalar or array (default=300)
        Size of nodes. If an array is specified it must be the
        same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape : string (default='o')
        The shape of the node. Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'.

    alpha : float or None (default=None)
        The node and edge transparency

    cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of nodes

    vmin,vmax : float, optional
        Minimum and maximum for node colormap scaling

    linewidths : scalar or sequence (default=1.0)
        Line width of symbol border

    width : float or array of floats (default=1.0)
        Line width of edges

    edge_color : color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edgelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    edge_cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of edges

    edge_vmin,edge_vmax : floats, optional
        Minimum and maximum for edge colormap scaling

    style : string (default=solid line)
        Edge line style e.g.: '-', '--', '-.', ':'
        or words like 'solid' or 'dashed'.
        (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)

    labels : dictionary (default=None)
        Node labels in a dictionary of text labels keyed by node

    font_size : int (default=12 for nodes, 10 for edges)
        Font size for text labels

    font_color : color (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    label : string, optional
        Label for graph legend

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    kwds : optional keywords
        See networkx.draw_networkx_nodes(), networkx.draw_networkx_edges(), and
        networkx.draw_networkx_labels() for a description of optional keywords.
        Each keyword is forwarded to every one of those functions that accepts it.

    Raises
    ------
    ValueError
        If a keyword is accepted by none of draw_networkx_nodes,
        draw_networkx_edges and draw_networkx_labels.

    Notes
    -----
    For directed graphs, arrows are drawn at the head end. Arrows can be
    turned off with keyword arrows=False.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    >>> import matplotlib.pyplot as plt
    >>> limits = plt.axis("off")  # turn off axis

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    from inspect import signature

    import matplotlib.pyplot as plt

    # Get all valid keywords by inspecting the signatures of draw_networkx_nodes,
    # draw_networkx_edges, draw_networkx_labels
    valid_node_kwds = signature(draw_networkx_nodes).parameters.keys()
    valid_edge_kwds = signature(draw_networkx_edges).parameters.keys()
    valid_label_kwds = signature(draw_networkx_labels).parameters.keys()

    # Create a set with all valid keywords across the three functions and
    # remove the arguments of this function (draw_networkx)
    valid_kwds = (valid_node_kwds | valid_edge_kwds | valid_label_kwds) - {
        "G",
        "pos",
        "arrows",
        "with_labels",
    }

    # Scan the keywords once; the same list drives both the check and the message.
    invalid_args = [k for k in kwds if k not in valid_kwds]
    if invalid_args:
        raise ValueError(f"Received invalid argument(s): {', '.join(invalid_args)}")

    node_kwds = {k: v for k, v in kwds.items() if k in valid_node_kwds}
    edge_kwds = {k: v for k, v in kwds.items() if k in valid_edge_kwds}
    label_kwds = {k: v for k, v in kwds.items() if k in valid_label_kwds}

    if pos is None:
        pos = nx.drawing.spring_layout(G)  # default to spring layout

    draw_networkx_nodes(G, pos, **node_kwds)
    draw_networkx_edges(G, pos, arrows=arrows, **edge_kwds)
    if with_labels:
        draw_networkx_labels(G, pos, **label_kwds)
    plt.draw_if_interactive()

1357 

1358 

def draw_networkx_nodes(
    G,
    pos,
    nodelist=None,
    node_size=300,
    node_color="#1f78b4",
    node_shape="o",
    alpha=None,
    cmap=None,
    vmin=None,
    vmax=None,
    ax=None,
    linewidths=None,
    edgecolors=None,
    label=None,
    margins=None,
    hide_ticks=True,
):
    """Draw the nodes of the graph G.

    This draws only the nodes of the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    nodelist : list (default list(G))
        Draw only specified nodes

    node_size : scalar or array (default=300)
        Size of nodes. If an array it must be the same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape : string or sequence of strings (default='o')
        The shape of the node. Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'. A list or array gives one marker per
        node and must have the same length as nodelist.

    alpha : float or array of floats (default=None)
        The node transparency. This can be a single alpha value,
        in which case it will be applied to all the nodes of color. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    cmap : Matplotlib colormap (default=None)
        Colormap for mapping intensities of nodes

    vmin,vmax : floats or None (default=None)
        Minimum and maximum for node colormap scaling. When several node
        shapes are drawn and these are `None`, the range of the numeric
        `node_color` values over all nodes is used so that every shape
        shares one color scale.

    linewidths : [None | scalar | sequence] (default=1.0)
        Line width of symbol border

    edgecolors : [None | scalar | sequence] (default = node_color)
        Colors of node borders. Can be a single color or a sequence of colors with the
        same length as nodelist. Color can be string or rgb (or rgba) tuple of floats
        from 0-1. If numeric values are specified they will be mapped to colors
        using the cmap and vmin,vmax parameters. See `~matplotlib.pyplot.scatter` for more details.

    label : [None | string]
        Label for legend

    margins : float or 2-tuple, optional
        Sets the padding for axis autoscaling. Increase margin to prevent
        clipping for nodes that are near the edges of an image. Values should
        be in the range ``[0, 1]``. See :meth:`matplotlib.axes.Axes.margins`
        for details. The default is `None`, which uses the Matplotlib default.

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Returns
    -------
    matplotlib.collections.PathCollection
        `PathCollection` of the nodes. When several node shapes are used, one
        collection is drawn per distinct shape and the last one is returned.

    Raises
    ------
    NetworkXError
        If a node in `nodelist` has no entry in `pos`.
    ValueError
        If `node_shape` is a list or array whose length differs from `nodelist`.

    Notes
    -----
    When several node shapes are used, per-node sequences (`node_size`,
    `node_color`, `linewidths`, `edgecolors`) with the same length as
    `nodelist` are split so that each shape receives the values of its own
    nodes.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    from collections.abc import Iterable

    import matplotlib as mpl
    import matplotlib.collections  # call as mpl.collections
    import matplotlib.pyplot as plt
    import numpy as np

    if ax is None:
        ax = plt.gca()

    if nodelist is None:
        nodelist = list(G)

    if len(nodelist) == 0:  # empty nodelist, no drawing
        return mpl.collections.PathCollection(None)

    try:
        xy = np.asarray([pos[v] for v in nodelist])
    except KeyError as err:
        raise nx.NetworkXError(f"Node {err} has no position.") from err

    if isinstance(alpha, Iterable):
        node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax)
        alpha = None

    n_nodes = len(nodelist)
    if isinstance(node_shape, (np.ndarray, list)):
        # Lists must become arrays: `list == "o"` is a single False rather
        # than an element-wise mask, which would select no nodes at all.
        node_shape = np.asarray(node_shape)
        if node_shape.shape != (n_nodes,):
            raise ValueError(
                "node_shape must be a single marker or a sequence with the "
                "same length as nodelist"
            )
    else:
        node_shape = np.array([node_shape for _ in range(n_nodes)])

    shapes = np.unique(node_shape)
    # With a single marker all nodes go into one scatter call and every
    # argument is passed through untouched (the common, unchanged path).
    multiple_shapes = len(shapes) > 1

    def _subset(values, mask):
        # Restrict a per-node sequence to the nodes selected by `mask`;
        # scalars, strings and sequences of another length pass through.
        if (
            not multiple_shapes
            or isinstance(values, str)
            or not np.iterable(values)
            or not hasattr(values, "__len__")
            or len(values) != n_nodes
        ):
            return values
        if isinstance(values, np.ndarray):
            return values[mask]
        return [value for value, keep in zip(values, mask) if keep]

    if multiple_shapes and (vmin is None or vmax is None):
        # Numeric colors are normalized per scatter call; share one range
        # across all shapes so equal values get equal colors.
        try:
            color_values = np.asarray(node_color, dtype=float)
        except (TypeError, ValueError):
            color_values = None
        if color_values is not None and color_values.shape == (n_nodes,):
            finite = color_values[np.isfinite(color_values)]
            if finite.size:
                if vmin is None:
                    vmin = finite.min()
                if vmax is None:
                    vmax = finite.max()

    node_collection = None
    for shape in shapes:
        mask = node_shape == shape
        node_collection = ax.scatter(
            xy[mask, 0],
            xy[mask, 1],
            s=_subset(node_size, mask),
            c=_subset(node_color, mask),
            marker=shape,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            alpha=alpha,
            linewidths=_subset(linewidths, mask),
            edgecolors=_subset(edgecolors, mask),
            label=label,
        )
        # Every per-shape collection, not only the last, must sit above edges.
        node_collection.set_zorder(2)

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    if margins is not None:
        if isinstance(margins, Iterable):
            ax.margins(*margins)
        else:
            ax.margins(margins)

    return node_collection

1527 

1528 

class FancyArrowFactory:
    """Draw arrows with `matplotlib.patches.FancyArrowPatch`.

    Calling an instance with an edge index ``i`` builds the
    `~matplotlib.patches.FancyArrowPatch` for ``edgelist[i]``, drawn between
    the two points stored in ``edge_pos[i]``. Per-edge settings (margins,
    arrow size, color, line width, line style, arrow style) may be single
    values or sequences indexed by ``i``. Node sizes and node shapes may be
    single values or sequences aligned with ``nodelist``.
    """

    # Markers whose corners reach further from the center than a circle of
    # the same size, so arrows must stop further away from the node.
    _LARGE_MARKERS = ("s", "^", ">", "v", "<", "d")

    class ConnectionStyleFactory:
        """Produce Matplotlib connection styles for edges and self-loops."""

        def __init__(self, connectionstyles, selfloop_height, ax=None):
            import matplotlib as mpl
            import matplotlib.patches  # call as mpl.patches
            import matplotlib.path  # call as mpl.path
            import numpy as np

            self.ax = ax
            self.mpl = mpl
            self.np = np
            self.base_connection_styles = [
                mpl.patches.ConnectionStyle(cs) for cs in connectionstyles
            ]
            self.n = len(self.base_connection_styles)
            self.selfloop_height = selfloop_height

        def curved(self, edge_index):
            """Return the style for `edge_index`, cycling through the styles."""
            return self.base_connection_styles[edge_index % self.n]

        def self_loop(self, edge_index):
            """Return a connection-style callable that draws a self-loop.

            The loop is rotated by a quarter turn per `edge_index` (modulo 4),
            so up to four self-loops on the same node are visible.
            """

            def self_loop_connection(posA, posB, *args, **kwargs):
                if not self.np.all(posA == posB):
                    raise nx.NetworkXError(
                        "`self_loop` connection style method "
                        "is only to be used for self-loops"
                    )
                # this is called with _screen space_ values
                # so convert back to data space
                data_loc = self.ax.transData.inverted().transform(posA)
                v_shift = 0.1 * self.selfloop_height
                h_shift = v_shift * 0.5
                # put the top of the loop first so arrow is not hidden by node
                path = self.np.asarray(
                    [
                        # MOVETO (path code 1)
                        [0, v_shift],
                        # first cubic Bezier segment (path codes 4 4 4)
                        [h_shift, v_shift],
                        [h_shift, 0],
                        [0, 0],
                        # second cubic Bezier segment (path codes 4 4 4)
                        [-h_shift, 0],
                        [-h_shift, v_shift],
                        [0, v_shift],
                    ]
                )
                # Rotate the loop a quarter turn clockwise per extra self-loop;
                # this allows for a maximum of 4 visible self loops.
                if edge_index % 4:
                    x, y = path.T
                    for _ in range(edge_index % 4):
                        x, y = y, -x
                    path = self.np.array([x, y]).T
                return self.mpl.path.Path(
                    self.ax.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4]
                )

            return self_loop_connection

    def __init__(
        self,
        edge_pos,
        edgelist,
        nodelist,
        edge_indices,
        node_size,
        selfloop_height,
        connectionstyle="arc3",
        node_shape="o",
        arrowstyle="-",
        arrowsize=10,
        edge_color="k",
        alpha=None,
        linewidth=1.0,
        style="solid",
        min_source_margin=0,
        min_target_margin=0,
        ax=None,
    ):
        import matplotlib as mpl
        import matplotlib.colors  # call as mpl.colors
        import matplotlib.patches  # call as mpl.patches
        import numpy as np

        if isinstance(connectionstyle, str):
            connectionstyle = [connectionstyle]
        elif np.iterable(connectionstyle):
            connectionstyle = list(connectionstyle)
        else:
            msg = "ConnectionStyleFactory arg `connectionstyle` must be str or iterable"
            raise nx.NetworkXError(msg)
        self.ax = ax
        self.mpl = mpl
        self.np = np
        self.edge_pos = edge_pos
        self.edgelist = edgelist
        self.nodelist = nodelist
        # node -> position in nodelist, built on first use by _node_position
        self._node_index = None
        self.node_shape = node_shape
        self.min_source_margin = min_source_margin
        self.min_target_margin = min_target_margin
        self.edge_indices = edge_indices
        self.node_size = node_size
        self.connectionstyle_factory = self.ConnectionStyleFactory(
            connectionstyle, selfloop_height, ax
        )
        self.arrowstyle = arrowstyle
        self.arrowsize = arrowsize
        self.arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha)
        self.linewidth = linewidth
        self.style = style
        if isinstance(arrowsize, list) and len(arrowsize) != len(edge_pos):
            raise ValueError("arrowsize should have the same length as edgelist")

    def _is_sequence(self, value):
        # Per-edge sequence test; strings and tuples are single settings
        # (tuples can be dash patterns such as ``(offset, onoffseq)``).
        return (
            self.np.iterable(value)
            and not isinstance(value, str)
            and not isinstance(value, tuple)
        )

    def _node_position(self, node):
        """Return the index of the first occurrence of `node` in `nodelist`.

        The lookup table is built once, instead of scanning the list for
        every edge. Raises ValueError when `node` is absent, like list.index.
        """
        index = getattr(self, "_node_index", None)
        if index is None:
            index = {}
            for position, n in enumerate(self.nodelist):
                index.setdefault(n, position)
            self._node_index = index
        try:
            return index[node]
        except KeyError:
            raise ValueError(f"{node!r} is not in list") from None

    def __call__(self, i):
        """Build the FancyArrowPatch for the `i`-th edge."""
        (x1, y1), (x2, y2) = self.edge_pos[i]

        if self._is_sequence(self.min_source_margin):
            min_source_margin = self.min_source_margin[i]
        else:
            min_source_margin = self.min_source_margin

        if self._is_sequence(self.min_target_margin):
            min_target_margin = self.min_target_margin[i]
        else:
            min_target_margin = self.min_target_margin

        # Distance from each node center to its marker edge, so the arrow
        # starts/ends at the border of the drawn node.
        per_node_size = self.np.iterable(self.node_size)
        per_node_shape = isinstance(self.node_shape, (self.np.ndarray, list))
        if per_node_size or per_node_shape:
            source, target = self.edgelist[i][:2]
            source_idx = self._node_position(source)
            target_idx = self._node_position(target)
            if per_node_size:
                source_size = self.node_size[source_idx]
                target_size = self.node_size[target_idx]
            else:
                source_size = target_size = self.node_size
            if per_node_shape:
                source_shape = self.node_shape[source_idx]
                target_shape = self.node_shape[target_idx]
            else:
                source_shape = target_shape = self.node_shape
            shrink_source = self.to_marker_edge(source_size, source_shape)
            shrink_target = self.to_marker_edge(target_size, target_shape)
        else:
            shrink_source = self.to_marker_edge(self.node_size, self.node_shape)
            shrink_target = shrink_source
        shrink_source = max(shrink_source, min_source_margin)
        shrink_target = max(shrink_target, min_target_margin)

        # scale factor of arrow head
        if isinstance(self.arrowsize, list):
            mutation_scale = self.arrowsize[i]
        else:
            mutation_scale = self.arrowsize

        # Cycle through colors (a single color is reused for every edge).
        arrow_color = self.arrow_colors[i % len(self.arrow_colors)]

        if self.np.iterable(self.linewidth):
            linewidth = self.linewidth[i % len(self.linewidth)]
        else:
            linewidth = self.linewidth

        if self._is_sequence(self.style):
            # Cycle through styles
            linestyle = self.style[i % len(self.style)]
        else:
            linestyle = self.style

        if x1 == x2 and y1 == y2:
            connectionstyle = self.connectionstyle_factory.self_loop(
                self.edge_indices[i]
            )
        else:
            connectionstyle = self.connectionstyle_factory.curved(self.edge_indices[i])

        if self._is_sequence(self.arrowstyle):
            arrowstyle = self.arrowstyle[i]
        else:
            arrowstyle = self.arrowstyle

        return self.mpl.patches.FancyArrowPatch(
            (x1, y1),
            (x2, y2),
            arrowstyle=arrowstyle,
            shrinkA=shrink_source,
            shrinkB=shrink_target,
            mutation_scale=mutation_scale,
            color=arrow_color,
            linewidth=linewidth,
            connectionstyle=connectionstyle,
            linestyle=linestyle,
            zorder=1,  # arrows go behind nodes
        )

    def to_marker_edge(self, marker_size, marker):
        """Return the distance from a node's center to its marker edge.

        `marker_size` is a scatter-style size (points**2). Markers listed in
        ``_LARGE_MARKERS`` get ``sqrt(2 * marker_size) / 2`` to clear their
        corners; every other marker gets ``sqrt(marker_size) / 2``. The
        exact-membership test also accepts non-string markers.
        """
        if marker in self._LARGE_MARKERS:  # `large` markers need extra space
            return self.np.sqrt(2 * marker_size) / 2
        else:
            return self.np.sqrt(marker_size) / 2

1746 

1747 

1748def draw_networkx_edges( 

1749 G, 

1750 pos, 

1751 edgelist=None, 

1752 width=1.0, 

1753 edge_color="k", 

1754 style="solid", 

1755 alpha=None, 

1756 arrowstyle=None, 

1757 arrowsize=10, 

1758 edge_cmap=None, 

1759 edge_vmin=None, 

1760 edge_vmax=None, 

1761 ax=None, 

1762 arrows=None, 

1763 label=None, 

1764 node_size=300, 

1765 nodelist=None, 

1766 node_shape="o", 

1767 connectionstyle="arc3", 

1768 min_source_margin=0, 

1769 min_target_margin=0, 

1770 hide_ticks=True, 

1771): 

1772 r"""Draw the edges of the graph G. 

1773 

1774 This draws only the edges of the graph G. 

1775 

1776 Parameters 

1777 ---------- 

1778 G : graph 

1779 A networkx graph 

1780 

1781 pos : dictionary 

1782 A dictionary with nodes as keys and positions as values. 

1783 Positions should be sequences of length 2. 

1784 

1785 edgelist : collection of edge tuples (default=G.edges()) 

1786 Draw only specified edges 

1787 

1788 width : float or array of floats (default=1.0) 

1789 Line width of edges 

1790 

1791 edge_color : color or array of colors (default='k') 

1792 Edge color. Can be a single color or a sequence of colors with the same 

1793 length as edgelist. Color can be string or rgb (or rgba) tuple of 

1794 floats from 0-1. If numeric values are specified they will be 

1795 mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters. 

1796 

1797 style : string or array of strings (default='solid') 

1798 Edge line style e.g.: '-', '--', '-.', ':' 

1799 or words like 'solid' or 'dashed'. 

1800 Can be a single style or a sequence of styles with the same 

1801 length as the edge list. 

1802 If less styles than edges are given the styles will cycle. 

1803 If more styles than edges are given the styles will be used sequentially 

1804 and not be exhausted. 

1805 Also, `(offset, onoffseq)` tuples can be used as style instead of a strings. 

1806 (See `matplotlib.patches.FancyArrowPatch`: `linestyle`) 

1807 

1808 alpha : float or array of floats (default=None) 

1809 The edge transparency. This can be a single alpha value, 

1810 in which case it will be applied to all specified edges. Otherwise, 

1811 if it is an array, the elements of alpha will be applied to the colors 

1812 in order (cycling through alpha multiple times if necessary). 

1813 

1814 edge_cmap : Matplotlib colormap, optional 

1815 Colormap for mapping intensities of edges 

1816 

1817 edge_vmin,edge_vmax : floats, optional 

1818 Minimum and maximum for edge colormap scaling 

1819 

1820 ax : Matplotlib Axes object, optional 

1821 Draw the graph in the specified Matplotlib axes. 

1822 

1823 arrows : bool or None, optional (default=None) 

1824 If `None`, directed graphs draw arrowheads with 

1825 `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges 

1826 via `~matplotlib.collections.LineCollection` for speed. 

1827 If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish). 

1828 If `False`, draw edges using LineCollection (linear and fast). 

1829 

1830 Note: Arrowheads will be the same color as edges. 

1831 

1832 arrowstyle : str or list of strs (default='-\|>' for directed graphs) 

1833 For directed graphs and `arrows==True` defaults to '-\|>', 

1834 For undirected graphs default to '-'. 

1835 

1836 See `matplotlib.patches.ArrowStyle` for more options. 

1837 

1838 arrowsize : int or list of ints(default=10) 

1839 For directed graphs, choose the size of the arrow head's length and 

1840 width. See `matplotlib.patches.FancyArrowPatch` for attribute 

1841 `mutation_scale` for more info. 

1842 

1843 connectionstyle : string or iterable of strings (default="arc3") 

1844 Pass the connectionstyle parameter to create curved arc of rounding 

1845 radius rad. For example, connectionstyle='arc3,rad=0.2'. 

1846 See `matplotlib.patches.ConnectionStyle` and 

1847 `matplotlib.patches.FancyArrowPatch` for more info. 

1848 If Iterable, index indicates i'th edge key of MultiGraph 

1849 

1850 node_size : scalar or array (default=300) 

1851 Size of nodes. Though the nodes are not drawn with this function, the 

1852 node size is used in determining edge positioning. 

1853 

1854 nodelist : list, optional (default=G.nodes()) 

1855 This provides the node order for the `node_size` array (if it is an array). 

1856 

1857 node_shape : string (default='o') 

1858 The marker used for nodes, used in determining edge positioning. 

1859 Specification is as a `matplotlib.markers` marker, e.g. one of 'so^>v<dph8'. 

1860 

1861 label : None or string 

1862 Label for legend 

1863 

1864 min_source_margin : int or list of ints (default=0) 

1865 The minimum margin (gap) at the beginning of the edge at the source. 

1866 

1867 min_target_margin : int or list of ints (default=0) 

1868 The minimum margin (gap) at the end of the edge at the target. 

1869 

1870 hide_ticks : bool, optional 

1871 Hide ticks of axes. When `True` (the default), ticks and ticklabels 

1872 are removed from the axes. To set ticks and tick labels to the pyplot default, 

1873 use ``hide_ticks=False``. 

1874 

1875 Returns 

1876 ------- 

1877 matplotlib.collections.LineCollection or a list of matplotlib.patches.FancyArrowPatch 

1878 If ``arrows=True``, a list of FancyArrowPatches is returned. 

1879 If ``arrows=False``, a LineCollection is returned. 

1880 If ``arrows=None`` (the default), then a LineCollection is returned if 

1881 `G` is undirected, otherwise returns a list of FancyArrowPatches. 

1882 

1883 Notes 

1884 ----- 

1885 For directed graphs, arrows are drawn at the head end. Arrows can be 

1886 turned off with keyword arrows=False or by passing an arrowstyle without 

1887 an arrow on the end. 

1888 

1889 Be sure to include `node_size` as a keyword argument; arrows are 

1890 drawn considering the size of nodes. 

1891 

1892 Self-loops are always drawn with `~matplotlib.patches.FancyArrowPatch` 

1893 regardless of the value of `arrows` or whether `G` is directed. 

1894 When ``arrows=False`` or ``arrows=None`` and `G` is undirected, the 

1895 FancyArrowPatches corresponding to the self-loops are not explicitly 

1896 returned. They should instead be accessed via the ``Axes.patches`` 

1897 attribute (see examples). 

1898 

1899 Examples 

1900 -------- 

1901 >>> G = nx.dodecahedral_graph() 

1902 >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G)) 

1903 

1904 >>> G = nx.DiGraph() 

1905 >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)]) 

1906 >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G)) 

1907 >>> alphas = [0.3, 0.4, 0.5] 

1908 >>> for i, arc in enumerate(arcs): # change alpha values of arcs 

1909 ... arc.set_alpha(alphas[i]) 

1910 

1911 The FancyArrowPatches corresponding to self-loops are not always 

1912 returned, but can always be accessed via the ``patches`` attribute of the 

1913 `matplotlib.Axes` object. 

1914 

1915 >>> import matplotlib.pyplot as plt 

1916 >>> fig, ax = plt.subplots() 

1917 >>> G = nx.Graph([(0, 1), (0, 0)]) # Self-loop at node 0 

1918 >>> edge_collection = nx.draw_networkx_edges(G, pos=nx.circular_layout(G), ax=ax) 

1919 >>> self_loop_fap = ax.patches[0] 

1920 

1921 Also see the NetworkX drawing examples at 

1922 https://networkx.org/documentation/latest/auto_examples/index.html 

1923 

1924 See Also 

1925 -------- 

1926 draw 

1927 draw_networkx 

1928 draw_networkx_nodes 

1929 draw_networkx_labels 

1930 draw_networkx_edge_labels 

1931 

1932 """ 

1933 import warnings 

1934 

1935 import matplotlib as mpl 

1936 import matplotlib.collections # call as mpl.collections 

1937 import matplotlib.colors # call as mpl.colors 

1938 import matplotlib.pyplot as plt 

1939 import numpy as np 

1940 

1941 # The default behavior is to use LineCollection to draw edges for 

1942 # undirected graphs (for performance reasons) and use FancyArrowPatches 

1943 # for directed graphs. 

1944 # The `arrows` keyword can be used to override the default behavior 

1945 if arrows is None: 

1946 use_linecollection = not (G.is_directed() or G.is_multigraph()) 

1947 else: 

1948 if not isinstance(arrows, bool): 

1949 raise TypeError("Argument `arrows` must be of type bool or None") 

1950 use_linecollection = not arrows 

1951 

1952 if isinstance(connectionstyle, str): 

1953 connectionstyle = [connectionstyle] 

1954 elif np.iterable(connectionstyle): 

1955 connectionstyle = list(connectionstyle) 

1956 else: 

1957 msg = "draw_networkx_edges arg `connectionstyle` must be str or iterable" 

1958 raise nx.NetworkXError(msg) 

1959 

1960 # Some kwargs only apply to FancyArrowPatches. Warn users when they use 

1961 # non-default values for these kwargs when LineCollection is being used 

1962 # instead of silently ignoring the specified option 

1963 if use_linecollection: 

1964 msg = ( 

1965 "\n\nThe {0} keyword argument is not applicable when drawing edges\n" 

1966 "with LineCollection.\n\n" 

1967 "To make this warning go away, either specify `arrows=True` to\n" 

1968 "force FancyArrowPatches or use the default values.\n" 

1969 "Note that using FancyArrowPatches may be slow for large graphs.\n" 

1970 ) 

1971 if arrowstyle is not None: 

1972 warnings.warn(msg.format("arrowstyle"), category=UserWarning, stacklevel=2) 

1973 if arrowsize != 10: 

1974 warnings.warn(msg.format("arrowsize"), category=UserWarning, stacklevel=2) 

1975 if min_source_margin != 0: 

1976 warnings.warn( 

1977 msg.format("min_source_margin"), category=UserWarning, stacklevel=2 

1978 ) 

1979 if min_target_margin != 0: 

1980 warnings.warn( 

1981 msg.format("min_target_margin"), category=UserWarning, stacklevel=2 

1982 ) 

1983 if any(cs != "arc3" for cs in connectionstyle): 

1984 warnings.warn( 

1985 msg.format("connectionstyle"), category=UserWarning, stacklevel=2 

1986 ) 

1987 

1988 # NOTE: Arrowstyle modification must occur after the warnings section 

1989 if arrowstyle is None: 

1990 arrowstyle = "-|>" if G.is_directed() else "-" 

1991 

1992 if ax is None: 

1993 ax = plt.gca() 

1994 

1995 if edgelist is None: 

1996 edgelist = list(G.edges) # (u, v, k) for multigraph (u, v) otherwise 

1997 

1998 if len(edgelist): 

1999 if G.is_multigraph(): 

2000 key_count = collections.defaultdict(lambda: itertools.count(0)) 

2001 edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist] 

2002 else: 

2003 edge_indices = [0] * len(edgelist) 

2004 else: # no edges! 

2005 return [] 

2006 

2007 if nodelist is None: 

2008 nodelist = list(G.nodes()) 

2009 

2010 # FancyArrowPatch handles color=None different from LineCollection 

2011 if edge_color is None: 

2012 edge_color = "k" 

2013 

2014 # set edge positions 

2015 edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist]) 

2016 

2017 # Check if edge_color is an array of floats and map to edge_cmap. 

2018 # This is the only case handled differently from matplotlib 

2019 if ( 

2020 np.iterable(edge_color) 

2021 and (len(edge_color) == len(edge_pos)) 

2022 and np.all([isinstance(c, Number) for c in edge_color]) 

2023 ): 

2024 if edge_cmap is not None: 

2025 assert isinstance(edge_cmap, mpl.colors.Colormap) 

2026 else: 

2027 edge_cmap = plt.get_cmap() 

2028 if edge_vmin is None: 

2029 edge_vmin = min(edge_color) 

2030 if edge_vmax is None: 

2031 edge_vmax = max(edge_color) 

2032 color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax) 

2033 edge_color = [edge_cmap(color_normal(e)) for e in edge_color] 

2034 

2035 # compute initial view 

2036 minx = np.amin(np.ravel(edge_pos[:, :, 0])) 

2037 maxx = np.amax(np.ravel(edge_pos[:, :, 0])) 

2038 miny = np.amin(np.ravel(edge_pos[:, :, 1])) 

2039 maxy = np.amax(np.ravel(edge_pos[:, :, 1])) 

2040 w = maxx - minx 

2041 h = maxy - miny 

2042 

2043 # Self-loops are scaled by view extent, except in cases the extent 

2044 # is 0, e.g. for a single node. In this case, fall back to scaling 

2045 # by the maximum node size 

2046 selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max() 

2047 fancy_arrow_factory = FancyArrowFactory( 

2048 edge_pos, 

2049 edgelist, 

2050 nodelist, 

2051 edge_indices, 

2052 node_size, 

2053 selfloop_height, 

2054 connectionstyle, 

2055 node_shape, 

2056 arrowstyle, 

2057 arrowsize, 

2058 edge_color, 

2059 alpha, 

2060 width, 

2061 style, 

2062 min_source_margin, 

2063 min_target_margin, 

2064 ax=ax, 

2065 ) 

2066 

2067 # Draw the edges 

2068 if use_linecollection: 

2069 edge_collection = mpl.collections.LineCollection( 

2070 edge_pos, 

2071 colors=edge_color, 

2072 linewidths=width, 

2073 antialiaseds=(1,), 

2074 linestyle=style, 

2075 alpha=alpha, 

2076 ) 

2077 edge_collection.set_cmap(edge_cmap) 

2078 edge_collection.set_clim(edge_vmin, edge_vmax) 

2079 edge_collection.set_zorder(1) # edges go behind nodes 

2080 edge_collection.set_label(label) 

2081 ax.add_collection(edge_collection) 

2082 edge_viz_obj = edge_collection 

2083 

2084 # Make sure selfloop edges are also drawn 

2085 # --------------------------------------- 

2086 selfloops_to_draw = [loop for loop in nx.selfloop_edges(G) if loop in edgelist] 

2087 if selfloops_to_draw: 

2088 edgelist_tuple = list(map(tuple, edgelist)) 

2089 arrow_collection = [] 

2090 for loop in selfloops_to_draw: 

2091 i = edgelist_tuple.index(loop) 

2092 arrow = fancy_arrow_factory(i) 

2093 arrow_collection.append(arrow) 

2094 ax.add_patch(arrow) 

2095 else: 

2096 edge_viz_obj = [] 

2097 for i in range(len(edgelist)): 

2098 arrow = fancy_arrow_factory(i) 

2099 ax.add_patch(arrow) 

2100 edge_viz_obj.append(arrow) 

2101 

2102 # update view after drawing 

2103 padx, pady = 0.05 * w, 0.05 * h 

2104 corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady) 

2105 ax.update_datalim(corners) 

2106 ax.autoscale_view() 

2107 

2108 if hide_ticks: 

2109 ax.tick_params( 

2110 axis="both", 

2111 which="both", 

2112 bottom=False, 

2113 left=False, 

2114 labelbottom=False, 

2115 labelleft=False, 

2116 ) 

2117 

2118 return edge_viz_obj 

2119 

2120 

def draw_networkx_labels(
    G,
    pos,
    labels=None,
    font_size=12,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    clip_on=True,
    hide_ticks=True,
):
    """Draw node labels on the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    labels : dictionary (default={n: n for n in G})
        Node labels in a dictionary of text labels keyed by node.
        Node-keys in labels should appear as keys in `pos`.
        If needed use: `{n:lab for n,lab in labels.items() if n in pos}`

    font_size : int or dictionary of nodes to ints (default=12)
        Font size for text labels.

    font_color : color or dictionary of nodes to colors (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string or dictionary of nodes to strings (default='normal')
        Font weight.

    font_family : string or dictionary of nodes to strings (default='sans-serif')
        Font family.

    alpha : float or None or dictionary of nodes to floats (default=None)
        The text transparency.

    bbox : Matplotlib bbox, (default is Matplotlib's ax.text default)
        Specify text box properties (e.g. shape, color etc.) for node labels.

    horizontalalignment : string or dictionary of nodes to strings (default='center')
        Horizontal alignment {'center', 'right', 'left'}.

    verticalalignment : string or dictionary of nodes to strings (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    clip_on : bool (default=True)
        Turn on clipping of node labels at axis boundaries

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Returns
    -------
    dict
        `dict` of labels keyed on the nodes

    Raises
    ------
    ValueError
        If a parameter given as a dictionary does not have the same number of
        entries as `labels`.

    Notes
    -----
    Every styling parameter may be given either as a single value applied to
    all labels or as a dictionary keyed by node. A dictionary must contain an
    entry for every node in `labels`. Non-string labels are converted with
    ``str`` before drawing.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_edge_labels
    """
    import matplotlib.pyplot as plt

    if ax is None:
        ax = plt.gca()

    if labels is None:
        labels = {n: n for n in G.nodes()}

    # Names of the parameters that were supplied as per-node dictionaries.
    individual_params = set()

    def check_individual_params(p_value, p_name):
        # A dict means "one value per labeled node"; reject obvious mismatches.
        if isinstance(p_value, dict):
            if len(p_value) != len(labels):
                raise ValueError(f"{p_name} must have the same length as labels.")
            individual_params.add(p_name)

    def get_param_value(node, p_value, p_name):
        # Per-node parameters are looked up by node; scalars apply to all.
        if p_name in individual_params:
            return p_value[node]
        return p_value

    check_individual_params(font_size, "font_size")
    check_individual_params(font_color, "font_color")
    check_individual_params(font_weight, "font_weight")
    check_individual_params(font_family, "font_family")
    check_individual_params(alpha, "alpha")
    check_individual_params(horizontalalignment, "horizontalalignment")
    check_individual_params(verticalalignment, "verticalalignment")

    text_items = {}  # there is no text collection so we'll fake one
    for n, label in labels.items():
        (x, y) = pos[n]
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same
        t = ax.text(
            x,
            y,
            label,
            size=get_param_value(n, font_size, "font_size"),
            color=get_param_value(n, font_color, "font_color"),
            family=get_param_value(n, font_family, "font_family"),
            weight=get_param_value(n, font_weight, "font_weight"),
            alpha=get_param_value(n, alpha, "alpha"),
            horizontalalignment=get_param_value(
                n, horizontalalignment, "horizontalalignment"
            ),
            verticalalignment=get_param_value(
                n, verticalalignment, "verticalalignment"
            ),
            transform=ax.transData,
            bbox=bbox,
            clip_on=clip_on,
        )
        text_items[n] = t

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    return text_items

2272 

2273 

def draw_networkx_edge_labels(
    G,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
    node_size=300,
    nodelist=None,
    connectionstyle="arc3",
    hide_ticks=True,
):
    """Draw edge labels.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    edge_labels : dictionary (default=None)
        Edge labels in a dictionary of labels keyed by edge tuple
        (``(u, v)``, or ``(u, v, key)`` for multigraphs).
        Only labels for the keys in the dictionary are drawn.
        If None, every edge of `G` is labeled with the string form of its
        edge data dictionary.

    label_pos : float or list of floats (default=0.5)
        Position of edge label along edge (0=head, 0.5=center, 1=tail)

    font_size : int or list of ints (default=10)
        Font size for text labels

    font_color : color or list of colors (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. Give RGB(A) colors as tuples: a list is always
        interpreted as one color per label.

    font_weight : string or list of strings (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    alpha : float or None or list of floats (default=None)
        The text transparency

    bbox : Matplotlib bbox, optional
        Specify text box properties (e.g. shape, color etc.) for edge labels.
        Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.

    horizontalalignment : string or list of strings (default='center')
        Horizontal alignment {'center', 'right', 'left'}

    verticalalignment : string or list of strings (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    rotate : bool or list of bools (default=True)
        Rotate edge labels to lie parallel to edges

    clip_on : bool (default=True)
        Turn on clipping of edge labels at axis boundaries

    node_size : scalar or array (default=300)
        Size of nodes. If an array it must be the same length as nodelist.

    nodelist : list, optional (default=G.nodes())
        This provides the node order for the `node_size` array (if it is an array).

    connectionstyle : string or iterable of strings (default="arc3")
        Pass the connectionstyle parameter to create curved arc of rounding
        radius rad. For example, connectionstyle='arc3,rad=0.2'.
        See `matplotlib.patches.ConnectionStyle` and
        `matplotlib.patches.FancyArrowPatch` for more info.
        If Iterable, index indicates i'th edge key of MultiGraph

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Returns
    -------
    dict
        `dict` of labels keyed by edge

    Raises
    ------
    NetworkXError
        If `connectionstyle` is neither a string nor an iterable.
    ValueError
        If a parameter given as a list does not have exactly one entry per
        label.

    Notes
    -----
    Parameters documented as accepting a list (`label_pos`, `font_size`,
    `font_color`, `font_weight`, `alpha`, `horizontalalignment`,
    `verticalalignment`, `rotate`) take one value per drawn label, in the
    iteration order of `edge_labels`.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    """
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import numpy as np

    class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text):
        pass

    # use default box of white with white border
    if bbox is None:
        bbox = {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}

    if isinstance(connectionstyle, str):
        connectionstyle = [connectionstyle]
    elif np.iterable(connectionstyle):
        connectionstyle = list(connectionstyle)
    else:
        # Adjacent string literals are concatenated verbatim, so the first
        # one needs its trailing space.
        raise nx.NetworkXError(
            "draw_networkx_edge_labels arg `connectionstyle` must be "
            "string or iterable of strings"
        )

    if ax is None:
        ax = plt.gca()

    if edge_labels is None:
        kwds = {"keys": True} if G.is_multigraph() else {}
        edge_labels = {tuple(edge): d for *edge, d in G.edges(data=True, **kwds)}
    # NOTHING TO PLOT
    if not edge_labels:
        return {}
    edgelist, labels = zip(*edge_labels.items())

    if nodelist is None:
        nodelist = list(G.nodes())

    # set edge positions
    edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])

    # Number parallel (u, v) edges 0, 1, 2, ... in order of appearance; these
    # indices are handed to FancyArrowFactory together with `connectionstyle`.
    if G.is_multigraph():
        key_count = collections.defaultdict(lambda: itertools.count(0))
        edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist]
    else:
        edge_indices = [0] * len(edgelist)

    # Used to determine the self-loop mid-point. `edge_labels` is non-empty
    # here (early return above). Note that this will not be accurate if
    # edge_labels are not drawn for all edges that are drawn.
    miny = np.amin(np.ravel(edge_pos[:, :, 1]))
    maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
    h = maxy - miny
    selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max()
    fancy_arrow_factory = FancyArrowFactory(
        edge_pos,
        edgelist,
        nodelist,
        edge_indices,
        node_size,
        selfloop_height,
        connectionstyle,
        ax=ax,
    )

    # Names of the parameters that were supplied as per-edge lists.
    individual_params = set()

    def check_individual_params(p_value, p_name):
        # TODO should this be list or array (as in a numpy array)?
        if isinstance(p_value, list):
            if len(p_value) != len(edgelist):
                raise ValueError(f"{p_name} must have the same length as edgelist.")
            individual_params.add(p_name)

    def get_param_value(i, p_value, p_name):
        # Index by the edge's position instead of consuming a shared
        # iterator: the self-loop and regular branches read different
        # subsets of parameters, and an iterator would drift out of step.
        if p_name in individual_params:
            return p_value[i]
        return p_value

    check_individual_params(font_size, "font_size")
    check_individual_params(font_color, "font_color")
    check_individual_params(font_weight, "font_weight")
    check_individual_params(alpha, "alpha")
    check_individual_params(horizontalalignment, "horizontalalignment")
    check_individual_params(verticalalignment, "verticalalignment")
    check_individual_params(rotate, "rotate")
    check_individual_params(label_pos, "label_pos")

    text_items = {}
    for i, (edge, label) in enumerate(zip(edgelist, labels)):
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same

        # Text properties shared by the self-loop and regular-edge branches.
        text_kwargs = {
            "size": get_param_value(i, font_size, "font_size"),
            "color": get_param_value(i, font_color, "font_color"),
            "family": get_param_value(i, font_family, "font_family"),
            "weight": get_param_value(i, font_weight, "font_weight"),
            "alpha": get_param_value(i, alpha, "alpha"),
            "horizontalalignment": get_param_value(
                i, horizontalalignment, "horizontalalignment"
            ),
            "verticalalignment": get_param_value(
                i, verticalalignment, "verticalalignment"
            ),
            "transform": ax.transData,
            "bbox": bbox,
            "zorder": 1,
            "clip_on": clip_on,
        }

        n1, n2 = edge[:2]
        arrow = fancy_arrow_factory(i)
        if n1 == n2:
            # Self-loop: place an unrotated label at the first vertex of the
            # loop path produced by the arrow's connection style.
            connectionstyle_obj = arrow.get_connectionstyle()
            posA = ax.transData.transform(pos[n1])
            path_disp = connectionstyle_obj(posA, posA)
            path_data = ax.transData.inverted().transform_path(path_disp)
            x, y = path_data.vertices[0]
            text_items[edge] = ax.text(x, y, label, rotation=0, **text_kwargs)
        else:
            text_items[edge] = CurvedArrowText(
                arrow,
                label,
                label_pos=get_param_value(i, label_pos, "label_pos"),
                labels_horizontal=not get_param_value(i, rotate, "rotate"),
                ax=ax,
                **text_kwargs,
            )

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    return text_items

2544 

2545 

def draw_bipartite(G, **kwargs):
    """Draw the graph `G` using a bipartite layout.

    Shorthand for::

        nx.draw(G, pos=nx.bipartite_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded unchanged to `draw`; see `draw_networkx` for the accepted
        keywords.

    Raises
    ------
    NetworkXError :
        If `G` is not bipartite.

    Notes
    -----
    A fresh layout is computed on every call. When drawing repeatedly, it is
    much cheaper to compute the positions once with
    `~networkx.drawing.layout.bipartite_layout` and reuse them::

        >>> G = nx.complete_bipartite_graph(3, 3)
        >>> pos = nx.bipartite_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.complete_bipartite_graph(2, 5)
    >>> nx.draw_bipartite(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.bipartite_layout`
    """
    layout = nx.bipartite_layout(G)
    draw(G, pos=layout, **kwargs)

2588 

2589 

def draw_circular(G, **kwargs):
    """Draw the graph `G` using a circular layout.

    Shorthand for::

        nx.draw(G, pos=nx.circular_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded unchanged to `draw`; see `draw_networkx` for the accepted
        keywords.

    Notes
    -----
    A fresh layout is computed on every call. When drawing repeatedly, it is
    much cheaper to compute the positions once with
    `~networkx.drawing.layout.circular_layout` and reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.circular_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_circular(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.circular_layout`
    """
    layout = nx.circular_layout(G)
    draw(G, pos=layout, **kwargs)

2627 

2628 

def draw_kamada_kawai(G, **kwargs):
    """Draw the graph `G` using a Kamada-Kawai force-directed layout.

    Shorthand for::

        nx.draw(G, pos=nx.kamada_kawai_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded unchanged to `draw`; see `draw_networkx` for the accepted
        keywords.

    Notes
    -----
    A fresh layout is computed on every call. When drawing repeatedly, it is
    much cheaper to compute the positions once with
    `~networkx.drawing.layout.kamada_kawai_layout` and reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.kamada_kawai_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_kamada_kawai(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.kamada_kawai_layout`
    """
    layout = nx.kamada_kawai_layout(G)
    draw(G, pos=layout, **kwargs)

2667 

2668 

def draw_random(G, **kwargs):
    """Draw the graph `G` using a random layout.

    Shorthand for::

        nx.draw(G, pos=nx.random_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded unchanged to `draw`; see `draw_networkx` for the accepted
        keywords.

    Notes
    -----
    A fresh layout is computed on every call. When drawing repeatedly, it is
    much cheaper to compute the positions once with
    `~networkx.drawing.layout.random_layout` and reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.random_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.lollipop_graph(4, 3)
    >>> nx.draw_random(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.random_layout`
    """
    layout = nx.random_layout(G)
    draw(G, pos=layout, **kwargs)

2706 

2707 

def draw_spectral(G, **kwargs):
    """Draw the graph `G` using a spectral 2D layout.

    Shorthand for::

        nx.draw(G, pos=nx.spectral_layout(G), **kwargs)

    See `~networkx.drawing.layout.spectral_layout` for how the node
    positions are computed.

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded unchanged to `draw`; see `draw_networkx` for the accepted
        keywords.

    Notes
    -----
    A fresh layout is computed on every call. When drawing repeatedly, it is
    much cheaper to compute the positions once with
    `~networkx.drawing.layout.spectral_layout` and reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.spectral_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_spectral(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.spectral_layout`
    """
    layout = nx.spectral_layout(G)
    draw(G, pos=layout, **kwargs)

2748 

2749 

def draw_spring(G, **kwargs):
    """Draw the graph `G` using a spring layout.

    Shorthand for::

        nx.draw(G, pos=nx.spring_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded unchanged to `draw`; see `draw_networkx` for the accepted
        keywords.

    Notes
    -----
    `~networkx.drawing.layout.spring_layout` is also the default layout used
    by `draw`, so this function is equivalent to `draw`.

    A fresh layout is computed on every call. When drawing repeatedly, it is
    much cheaper to compute the positions once with
    `~networkx.drawing.layout.spring_layout` and reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.spring_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(20)
    >>> nx.draw_spring(G)

    See Also
    --------
    draw
    :func:`~networkx.drawing.layout.spring_layout`
    """
    layout = nx.spring_layout(G)
    draw(G, pos=layout, **kwargs)

2791 

2792 

def draw_shell(G, nlist=None, **kwargs):
    """Draw the networkx graph `G` using a shell layout.

    Shorthand for::

        nx.draw(G, pos=nx.shell_layout(G, nlist=nlist), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    nlist : list of list of nodes, optional
        A list containing lists of nodes representing the shells.
        Default is `None`, meaning all nodes are in a single shell.
        See `~networkx.drawing.layout.shell_layout` for details.

    kwargs : optional keywords
        Forwarded unchanged to `draw`; see `draw_networkx` for the accepted
        keywords.

    Notes
    -----
    A fresh layout is computed on every call. When drawing repeatedly, it is
    much cheaper to compute the positions once with
    `~networkx.drawing.layout.shell_layout` and reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.shell_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> shells = [[0], [1, 2, 3]]
    >>> nx.draw_shell(G, nlist=shells)

    See Also
    --------
    :func:`~networkx.drawing.layout.shell_layout`
    """
    layout = nx.shell_layout(G, nlist=nlist)
    draw(G, pos=layout, **kwargs)

2836 

2837 

2838def draw_planar(G, **kwargs): 

2839 """Draw a planar networkx graph `G` with planar layout. 

2840 

2841 This is a convenience function equivalent to:: 

2842 

2843 nx.draw(G, pos=nx.planar_layout(G), **kwargs) 

2844 

2845 Parameters 

2846 ---------- 

2847 G : graph 

2848 A planar networkx graph 

2849 

2850 kwargs : optional keywords 

2851 See `draw_networkx` for a description of optional keywords. 

2852 

2853 Raises 

2854 ------ 

2855 NetworkXException 

2856 When `G` is not planar 

2857 

2858 Notes 

2859 ----- 

2860 The layout is computed each time this function is called. 

2861 For repeated drawing it is much more efficient to call 

2862 `~networkx.drawing.layout.planar_layout` directly and reuse the result:: 

2863 

2864 >>> G = nx.path_graph(5) 

2865 >>> pos = nx.planar_layout(G) 

2866 >>> nx.draw(G, pos=pos) # Draw the original graph 

2867 >>> # Draw a subgraph, reusing the same node positions 

2868 >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red") 

2869 

2870 Examples 

2871 -------- 

2872 >>> G = nx.path_graph(4) 

2873 >>> nx.draw_planar(G) 

2874 

2875 See Also 

2876 -------- 

2877 :func:`~networkx.drawing.layout.planar_layout` 

2878 """ 

2879 draw(G, pos=nx.planar_layout(G), **kwargs) 

2880 

2881 

2882def draw_forceatlas2(G, **kwargs): 

2883 """Draw a networkx graph with forceatlas2 layout. 

2884 

2885 This is a convenience function equivalent to:: 

2886 

2887 nx.draw(G, pos=nx.forceatlas2_layout(G), **kwargs) 

2888 

2889 Parameters 

2890 ---------- 

2891 G : graph 

2892 A networkx graph 

2893 

2894 kwargs : optional keywords 

2895 See networkx.draw_networkx() for a description of optional keywords, 

2896 with the exception of the pos parameter which is not used by this 

2897 function. 

2898 """ 

2899 draw(G, pos=nx.forceatlas2_layout(G), **kwargs) 

2900 

2901 

2902def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None): 

2903 """Apply an alpha (or list of alphas) to the colors provided. 

2904 

2905 Parameters 

2906 ---------- 

2907 

2908 colors : color string or array of floats (default='r') 

2909 Color of element. Can be a single color format string, 

2910 or a sequence of colors with the same length as nodelist. 

2911 If numeric values are specified they will be mapped to 

2912 colors using the cmap and vmin,vmax parameters. See 

2913 matplotlib.scatter for more details. 

2914 

2915 alpha : float or array of floats 

2916 Alpha values for elements. This can be a single alpha value, in 

2917 which case it will be applied to all the elements of color. Otherwise, 

2918 if it is an array, the elements of alpha will be applied to the colors 

2919 in order (cycling through alpha multiple times if necessary). 

2920 

2921 elem_list : array of networkx objects 

2922 The list of elements which are being colored. These could be nodes, 

2923 edges or labels. 

2924 

2925 cmap : matplotlib colormap 

2926 Color map for use if colors is a list of floats corresponding to points 

2927 on a color mapping. 

2928 

2929 vmin, vmax : float 

2930 Minimum and maximum values for normalizing colors if a colormap is used 

2931 

2932 Returns 

2933 ------- 

2934 

2935 rgba_colors : numpy ndarray 

2936 Array containing RGBA format values for each of the node colours. 

2937 

2938 """ 

2939 from itertools import cycle, islice 

2940 

2941 import matplotlib as mpl 

2942 import matplotlib.cm # call as mpl.cm 

2943 import matplotlib.colors # call as mpl.colors 

2944 import numpy as np 

2945 

2946 # If we have been provided with a list of numbers as long as elem_list, 

2947 # apply the color mapping. 

2948 if len(colors) == len(elem_list) and isinstance(colors[0], Number): 

2949 mapper = mpl.cm.ScalarMappable(cmap=cmap) 

2950 mapper.set_clim(vmin, vmax) 

2951 rgba_colors = mapper.to_rgba(colors) 

2952 # Otherwise, convert colors to matplotlib's RGB using the colorConverter 

2953 # object. These are converted to numpy ndarrays to be consistent with the 

2954 # to_rgba method of ScalarMappable. 

2955 else: 

2956 try: 

2957 rgba_colors = np.array([mpl.colors.colorConverter.to_rgba(colors)]) 

2958 except ValueError: 

2959 rgba_colors = np.array( 

2960 [mpl.colors.colorConverter.to_rgba(color) for color in colors] 

2961 ) 

2962 # Set the final column of the rgba_colors to have the relevant alpha values 

2963 try: 

2964 # If alpha is longer than the number of colors, resize to the number of 

2965 # elements. Also, if rgba_colors.size (the number of elements of 

2966 # rgba_colors) is the same as the number of elements, resize the array, 

2967 # to avoid it being interpreted as a colormap by scatter() 

2968 if len(alpha) > len(rgba_colors) or rgba_colors.size == len(elem_list): 

2969 rgba_colors = np.resize(rgba_colors, (len(elem_list), 4)) 

2970 rgba_colors[1:, 0] = rgba_colors[0, 0] 

2971 rgba_colors[1:, 1] = rgba_colors[0, 1] 

2972 rgba_colors[1:, 2] = rgba_colors[0, 2] 

2973 rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors))) 

2974 except TypeError: 

2975 rgba_colors[:, -1] = alpha 

2976 return rgba_colors