Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/networkx/drawing/nx_pylab.py: 5%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

704 statements  

1""" 

2********** 

3Matplotlib 

4********** 

5 

6Draw networks with matplotlib. 

7 

8Examples 

9-------- 

10>>> G = nx.complete_graph(5) 

11>>> nx.draw(G) 

12 

13See Also 

14-------- 

15 - :doc:`matplotlib <matplotlib:index>` 

16 - :func:`matplotlib.pyplot.scatter` 

17 - :obj:`matplotlib.patches.FancyArrowPatch` 

18""" 

19 

20import collections 

21import itertools 

22import math 

23from numbers import Number 

24 

25import networkx as nx 

26 

# Public names exported by ``from <this module> import *``: the drawing
# entry points plus the colormap helper.
__all__ = [
    "display",
    "apply_matplotlib_colors",
    "draw",
    "draw_networkx",
    "draw_networkx_nodes",
    "draw_networkx_edges",
    "draw_networkx_labels",
    "draw_networkx_edge_labels",
    "draw_bipartite",
    "draw_circular",
    "draw_kamada_kawai",
    "draw_random",
    "draw_spectral",
    "draw_spring",
    "draw_planar",
    "draw_shell",
    "draw_forceatlas2",
]

46 

47 

def apply_matplotlib_colors(
    G, src_attr, dest_attr, map, vmin=None, vmax=None, nodes=True
):
    """
    Apply colors from a matplotlib colormap to a graph.

    Reads the value of `src_attr` on every node (or edge), maps it through a
    matplotlib colormap and writes the resulting RGBA tuple of floats to
    `dest_attr`.

    Parameters
    ----------
    G : nx.Graph
        The graph to read and compute colors for.

    src_attr : str or other attribute name
        The name of the attribute to read from the graph.

    dest_attr : str or other attribute name
        The name of the attribute to write to on the graph.

    map : matplotlib.colormap
        The matplotlib colormap to use.

    vmin : float, default None
        The minimum value for scaling the colormap. If `None`, find the
        minimum value of `src_attr`.

    vmax : float, default None
        The maximum value for scaling the colormap. If `None`, find the
        maximum value of `src_attr`.

    nodes : bool, default True
        If True, `src_attr` and `dest_attr` are node attributes. If False,
        they are edge attributes; for multigraphs edges are addressed as
        ``(u, v, key)``.

    Raises
    ------
    KeyError
        If a node or edge does not have `src_attr`.

    Notes
    -----
    If there are no nodes (or no edges when ``nodes=False``) there is nothing
    to color and the graph is left unchanged.
    """
    import matplotlib as mpl

    if nodes:
        type_iter = G.nodes()
    elif G.is_multigraph():
        type_iter = G.edges(keys=True)
    else:
        type_iter = G.edges()

    # Nothing to color. Returning here also avoids calling min()/max() on an
    # empty sequence when vmin/vmax have to be inferred from the data.
    if len(type_iter) == 0:
        return

    if vmin is None or vmax is None:
        vals = [type_iter[a][src_attr] for a in type_iter]
        if vmin is None:
            vmin = min(vals)
        if vmax is None:
            vmax = max(vals)

    mapper = mpl.cm.ScalarMappable(cmap=map)
    mapper.set_clim(vmin, vmax)

    def do_map(value):
        # Cast numpy scalars to plain Python floats
        return tuple(float(channel) for channel in mapper.to_rgba(value))

    if nodes:
        nx.set_node_attributes(
            G, {n: do_map(G.nodes[n][src_attr]) for n in G.nodes()}, dest_attr
        )
    else:
        nx.set_edge_attributes(
            G, {e: do_map(G.edges[e][src_attr]) for e in type_iter}, dest_attr
        )

113 

114 

class CurvedArrowTextBase:
    """Mixin that keeps a text label attached to a ``FancyArrowPatch``.

    Meant to be combined with ``matplotlib.text.Text``, e.g.
    ``class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text)``. The label's
    position and rotation are computed from the arrow's path on construction
    and recomputed every time the label is drawn.

    Parameters
    ----------
    arrow : matplotlib.patches.FancyArrowPatch
        The arrow to label. It must have been created from ``posA``/``posB``
        rather than from a custom path.
    *args
        Forwarded to the ``Text`` constructor after the computed x and y.
    label_pos : float, default 0.5
        How far along the arrow to place the text; 0 is the start and 1 the
        end.
    labels_horizontal : bool, default False
        If True the text stays horizontal; otherwise it is rotated to be
        parallel to the arrow at ``label_pos``.
    ax : matplotlib.axes.Axes, optional
        Axes the text is added to. Defaults to the current pyplot axes.
    **kwargs
        Forwarded to the ``Text`` constructor.
    """

    def __init__(
        self,
        arrow,
        *args,
        label_pos=0.5,
        labels_horizontal=False,
        ax=None,
        **kwargs,
    ):
        # Bind to FancyArrowPatch
        self.arrow = arrow
        # how far along the text should be on the curve,
        # 0 is at start, 1 is at end etc.
        self.label_pos = label_pos
        self.labels_horizontal = labels_horizontal
        if ax is None:
            # pyplot is not imported at module scope in this file, so a bare
            # ``plt`` reference here would raise NameError.
            import matplotlib.pyplot as plt

            ax = plt.gca()
        self.ax = ax
        self.x, self.y, self.angle = self._update_text_pos_angle(arrow)

        # Create the Text object at the computed position and rotation
        super().__init__(self.x, self.y, *args, rotation=self.angle, **kwargs)
        # Bind to axis
        self.ax.add_artist(self)

    def _get_arrow_path_disp(self, arrow):
        """Return the arrow's connection path in display coordinates.

        This is the first part of
        ``FancyArrowPatch._get_path_in_displaycoord``. It omits the second
        part of that method, where the path is converted to a polygon based
        on the arrow width. The transform is taken from ``self.ax`` rather
        than from the arrow, because the arrow may not have been added to the
        axes yet and so may not have a transform.

        Raises
        ------
        ValueError
            If the arrow was built from a custom path instead of posA/posB.
        """
        dpi_cor = arrow._dpi_cor
        trans_data = self.ax.transData
        if arrow._posA_posB is None:
            raise ValueError(
                "Can only draw labels for fancy arrows with "
                "posA and posB inputs, not custom path"
            )
        posA = arrow._convert_xy_units(arrow._posA_posB[0])
        posB = arrow._convert_xy_units(arrow._posA_posB[1])
        (posA, posB) = trans_data.transform((posA, posB))
        _path = arrow.get_connectionstyle()(
            posA,
            posB,
            patchA=arrow.patchA,
            patchB=arrow.patchB,
            shrinkA=arrow.shrinkA * dpi_cor,
            shrinkB=arrow.shrinkB * dpi_cor,
        )
        # Return is in display coordinates
        return _path

    def _update_text_pos_angle(self, arrow):
        """Compute the label position (data coordinates) and rotation (degrees).

        Arc3 and Angle3 connection styles:
            The path is a quadratic Bezier curve, evaluated at
            ``t = label_pos``.

        Angle, Arc and Bar connection styles:
            The path is a chain of straight segments. ``label_pos`` is treated
            as a fraction of the total path length, and the point is
            interpolated on the segment that contains it.

        Rotation:
            The angle follows the path direction at that point and is folded
            into [-90, 90] so the text is never upside down. It is 0 when
            ``labels_horizontal`` is set.

        Raises
        ------
        TypeError
            For any other connection style.
        """
        import matplotlib as mpl
        import numpy as np

        t = self.label_pos
        tt = 1 - t
        path_disp = self._get_arrow_path_disp(arrow)
        conn = arrow.get_connectionstyle()
        # 1. Calculate x and y
        points = path_disp.vertices
        if is_curve := isinstance(
            conn,
            mpl.patches.ConnectionStyle.Angle3 | mpl.patches.ConnectionStyle.Arc3,
        ):
            # Arc3 or Angle3 type Connection Styles - Bezier curve
            (x1, y1), (cx, cy), (x2, y2) = points
            x = tt**2 * x1 + 2 * t * tt * cx + t**2 * x2
            y = tt**2 * y1 + 2 * t * tt * cy + t**2 * y2
        else:
            if not isinstance(
                conn,
                mpl.patches.ConnectionStyle.Angle
                | mpl.patches.ConnectionStyle.Arc
                | mpl.patches.ConnectionStyle.Bar,
            ):
                msg = f"invalid connection style: {type(conn)}"
                raise TypeError(msg)
            # A. Collect the straight segments (each LINETO ends a segment)
            codes = path_disp.codes
            lines = [
                points[i - 1 : i + 1]
                for i in range(1, len(points))
                if codes[i] == mpl.path.Path.LINETO
            ]
            # B. If more than one segment, find the one containing the
            # requested fraction of total length and rescale t to it.
            if (nlines := len(lines)) != 1:
                dists = [math.dist(*line) for line in lines]
                dist_tot = sum(dists)
                cdist = 0
                last_cut = 0
                i_last = nlines - 1
                for i, dist in enumerate(dists):
                    cdist += dist
                    cut = cdist / dist_tot
                    if i == i_last or t < cut:
                        t = (t - last_cut) / (dist / dist_tot)
                        tt = 1 - t
                        lines = [lines[i]]
                        break
                    last_cut = cut
            [[(cx1, cy1), (cx2, cy2)]] = lines
            x = cx1 * tt + cx2 * t
            y = cy1 * tt + cy2 * t

        # 2. Calculate Angle
        if self.labels_horizontal:
            # Horizontal text labels
            angle = 0
        else:
            # Labels parallel to curve
            if is_curve:
                # Derivative of the quadratic Bezier curve at t
                change_x = 2 * tt * (cx - x1) + 2 * t * (x2 - cx)
                change_y = 2 * tt * (cy - y1) + 2 * t * (y2 - cy)
            else:
                change_x = (cx2 - cx1) / 2
                change_y = (cy2 - cy1) / 2
            angle = np.arctan2(change_y, change_x) / (2 * np.pi) * 360
            # Text is "right way up"
            if angle > 90:
                angle -= 180
            elif angle < -90:
                angle += 180
        (x, y) = self.ax.transData.inverted().transform((x, y))
        return x, y, angle

    def draw(self, renderer):
        """Recompute position and rotation from the arrow, then draw the text."""
        self.x, self.y, self.angle = self._update_text_pos_angle(self.arrow)
        self.set_position((self.x, self.y))
        self.set_rotation(self.angle)
        # redraw text
        super().draw(renderer)

255 

256 

257def display( 

258 G, 

259 canvas=None, 

260 **kwargs, 

261): 

262 """Draw the graph G. 

263 

264 Draw the graph as a collection of nodes connected by edges. 

265 The exact details of what the graph looks like are controlled by the below 

266 attributes. All nodes and nodes at the end of visible edges must have a 

267 position set, but nearly all other node and edge attributes are options and 

268 nodes or edges missing the attribute will use the default listed below. A more 

269 complete description of each parameter is given below this summary. 

270 

271 .. list-table:: Default Visualization Attributes 

272 :widths: 25 25 50 

273 :header-rows: 1 

274 

275 * - Parameter 

276 - Default Attribute 

277 - Default Value 

278 * - node_pos 

279 - `"pos"` 

280 - If there is no position, a layout will be calculated with `nx.spring_layout`.

281 * - node_visible 

282 - `"visible"` 

283 - True 

284 * - node_color 

285 - `"color"` 

286 - #1f78b4 

287 * - node_size 

288 - `"size"` 

289 - 300 

290 * - node_label 

291 - `"label"` 

292 - Dict describing the node label. Defaults create a black text with 

293 the node name as the label. The dict respects these keys and defaults: 

294 

295 * size : 12 

296 * color : black 

297 * family : sans serif 

298 * weight : normal 

299 * alpha : 1.0 

300 * h_align : center 

301 * v_align : center 

302 * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`. 

303 Default is None. 

304 

305 * - node_shape 

306 - `"shape"` 

307 - "o" 

308 * - node_alpha 

309 - `"alpha"` 

310 - 1.0 

311 * - node_border_width 

312 - `"border_width"` 

313 - 1.0 

314 * - node_border_color 

315 - `"border_color"` 

316 - Matching node_color 

317 * - edge_visible 

318 - `"visible"` 

319 - True 

320 * - edge_width 

321 - `"width"` 

322 - 1.0 

323 * - edge_color 

324 - `"color"` 

325 - Black (#000000) 

326 * - edge_label 

327 - `"label"` 

328 - Dict describing the edge label. Defaults create black text with a 

329 white bounding box. The dictionary respects these keys and defaults: 

330 

331 * size : 12 

332 * color : black 

333 * family : sans serif 

334 * weight : normal 

335 * alpha : 1.0 

336 * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`. 

337 Default {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)} 

338 * h_align : "center" 

339 * v_align : "center" 

340 * pos : 0.5 

341 * rotate : True 

342 

343 * - edge_style 

344 - `"style"` 

345 - "-" 

346 * - edge_alpha 

347 - `"alpha"` 

348 - 1.0 

349 * - edge_arrowstyle 

350 - `"arrowstyle"` 

351 - ``"-|>"`` if `G` is directed else ``"-"`` 

352 * - edge_arrowsize 

353 - `"arrowsize"` 

354 - 10 if `G` is directed else 0 

355 * - edge_curvature 

356 - `"curvature"` 

357 - arc3 

358 * - edge_source_margin 

359 - `"source_margin"` 

360 - 0 

361 * - edge_target_margin 

362 - `"target_margin"` 

363 - 0 

364 

365 Parameters 

366 ---------- 

367 G : graph 

368 A networkx graph 

369 

370 canvas : Matplotlib Axes object, optional 

371 Draw the graph in specified Matplotlib axes 

372 

373 node_pos : string or function, default "pos" 

374 A string naming the node attribute storing the position of nodes as a tuple. 

375 Or a function to be called with input `G` which returns the layout as a dict keyed 

376 by node to position tuple like the NetworkX layout functions. 

377 If no node in the graph has the attribute, a spring layout is calculated.

378 

379 node_visible : string or bool, default visible 

380 A string naming the node attribute which stores if a node should be drawn. 

381 If `True`, all nodes will be visible while if `False` no nodes will be visible. 

382 If incomplete, nodes missing this attribute will be shown by default. 

383 

384 node_color : string, default "color" 

385 A string naming the node attribute which stores the color of each node. 

386 Visible nodes without this attribute will use '#1f78b4' as a default. 

387 

388 node_size : string or number, default "size" 

389 A string naming the node attribute which stores the size of each node. 

390 Visible nodes without this attribute will use a default size of 300. 

391 

392 node_label : string or bool, default "label" 

393 A string naming the node attribute which stores the label of each node. 

394 The attribute value can be a string, False (no label for that node), 

395 True (the node is the label) or a dict keyed by node to the label. 

396 

397 If a dict is specified, these keys are read to further control the label: 

398 

399 * label : The text of the label; default: name of the node 

400 * size : Font size of the label; default: 12 

401 * color : Font color of the label; default: black 

402 * family : Font family of the label; default: "sans-serif" 

403 * weight : Font weight of the label; default: "normal" 

404 * alpha : Alpha value of the label; default: 1.0 

405 * h_align : The horizontal alignment of the label. 

406 one of "left", "center", "right"; default: "center" 

407 * v_align : The vertical alignment of the label. 

408 one of "top", "center", "bottom"; default: "center" 

409 * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`. 

410 

411 Visible nodes without this attribute will be treated as if the value was True. 

412 

413 node_shape : string, default "shape" 

414 A string naming the node attribute which stores the shape of each node.

415 The values of this attribute are expected to be one of the matplotlib shapes, 

416 one of 'so^>v<dph8'. Visible nodes without this attribute will use 'o'. 

417 

418 node_alpha : string, default "alpha" 

419 A string naming the node attribute which stores the alpha of each node. 

420 The values of this attribute are expected to be floats between 0.0 and 1.0. 

421 Visible nodes without this attribute will be treated as if the value was 1.0. 

422 

423 node_border_width : string, default "border_width" 

424 A string naming the node attribute storing the width of the border of the node. 

425 The values of this attribute are expected to be numeric. Visible nodes without 

426 this attribute will use the assumed default of 1.0. 

427 

428 node_border_color : string, default "border_color" 

429 A string naming the node attribute storing the color of the border of the node.

430 Visible nodes missing this attribute will use the final node_color value. 

431 

432 edge_visible : string or bool, default "visible" 

433 A string naming the edge attribute which stores if an edge should be drawn.

434 If `True`, all edges will be drawn while if `False` no edges will be visible. 

435 If incomplete, edges missing this attribute will be shown by default. Values 

436 of this attribute are expected to be booleans. 

437 

438 edge_width : string or int, default "width" 

439 A string naming the edge attribute which stores the width of each edge.

440 Visible edges without this attribute will use a default width of 1.0. 

441 

442 edge_color : string or color, default "color" 

443 A string naming the edge attribute which stores the color of each edge.

444 Visible edges without this attribute will be drawn black. Each color can be 

445 a string or rgb (or rgba) tuple of floats from 0.0 to 1.0. 

446 

447 edge_label : string, default "label" 

448 A string naming the edge attribute which stores the label of each edge. 

449 The values of this attribute can be a string, number or False or None. In 

450 the latter two cases, no edge label is displayed. 

451 

452 If a dict is specified, these keys are read to further control the label: 

453 

454 * label : The text of the label, or the name of an edge attribute holding the label. 

455 * size : Font size of the label; default: 12 

456 * color : Font color of the label; default: black 

457 * family : Font family of the label; default: "sans-serif" 

458 * weight : Font weight of the label; default: "normal" 

459 * alpha : Alpha value of the label; default: 1.0 

460 * h_align : The horizontal alignment of the label. 

461 one of "left", "center", "right"; default: "center" 

462 * v_align : The vertical alignment of the label. 

463 one of "top", "center", "bottom"; default: "center" 

464 * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`. 

465 * rotate : Whether to rotate labels to lie parallel to the edge, default: True. 

466 * pos : A float showing how far along the edge to put the label; default: 0.5. 

467 

468 edge_style : string, default "style" 

469 A string naming the edge attribute which stores the style of each edge. 

470 Visible edges without this attribute will be drawn solid. Values of this 

471 attribute can be line styles, e.g. '-', '--', '-.' or ':' or words like 'solid' 

472 or 'dashed'. If no edge in the graph has this attribute and it is a non-default 

473 value, assume that it describes the edge style for all edges in the graph. 

474 

475 edge_alpha : string or float, default "alpha" 

476 A string naming the edge attribute which stores the alpha value of each edge. 

477 Visible edges without this attribute will use an alpha value of 1.0. 

478 

479 edge_arrowstyle : string, default "arrowstyle" 

480 A string naming the edge attribute which stores the type of arrowhead to use for 

481 each edge. Visible edges without this attribute use ``"-"`` for undirected graphs 

482 and ``"-|>"`` for directed graphs. 

483 

484 See `matplotlib.patches.ArrowStyle` for more options 

485 

486 edge_arrowsize : string or int, default "arrowsize" 

487 A string naming the edge attribute which stores the size of the arrowhead for each 

488 edge. Visible edges without this attribute will use a default value of 10 for directed graphs and 0 otherwise.

489 

490 edge_curvature : string, default "curvature" 

491 A string naming the edge attribute storing the curvature and connection style 

492 of each edge. Visible edges without this attribute will use "arc3" as a default 

493 value, resulting in a straight line between the two nodes. Curvature can be given

494 as 'arc3,rad=0.2' to specify both the style and radius of curvature. 

495 

496 Please see `matplotlib.patches.ConnectionStyle` and 

497 `matplotlib.patches.FancyArrowPatch` for more information. 

498 

499 edge_source_margin : string or int, default "source_margin" 

500 A string naming the edge attribute which stores the minimum margin (gap) between 

501 the source node and the start of the edge. Visible edges without this attribute 

502 will use a default value of 0. 

503 

504 edge_target_margin : string or int, default "target_margin" 

505 A string naming the edge attribute which stores the minimum margin (gap) between

506 the target node and the end of the edge. Visible edges without this attribute 

507 will use a default value of 0. 

508 

509 hide_ticks : bool, default True 

510 Whether to remove the ticks from the axes of the matplotlib object.

511 

512 Raises 

513 ------ 

514 NetworkXError 

515 If a node or edge is missing a required parameter such as `pos` or 

516 if `display` receives an argument not listed above. 

517 

518 ValueError 

519 If a node or edge has an invalid color format, i.e. not a color string, 

520 rgb tuple or rgba tuple. 

521 

522 Returns 

523 ------- 

524 The input graph. This is potentially useful for dispatching visualization 

525 functions. 

526 """ 

527 from collections import Counter 

528 

529 import matplotlib as mpl 

530 import matplotlib.pyplot as plt 

531 import numpy as np 

532 

533 defaults = { 

534 "node_pos": None, 

535 "node_visible": True, 

536 "node_color": "#1f78b4", 

537 "node_size": 300, 

538 "node_label": { 

539 "size": 12, 

540 "color": "#000000", 

541 "family": "sans-serif", 

542 "weight": "normal", 

543 "alpha": 1.0, 

544 "h_align": "center", 

545 "v_align": "center", 

546 "bbox": None, 

547 }, 

548 "node_shape": "o", 

549 "node_alpha": 1.0, 

550 "node_border_width": 1.0, 

551 "node_border_color": "face", 

552 "edge_visible": True, 

553 "edge_width": 1.0, 

554 "edge_color": "#000000", 

555 "edge_label": { 

556 "size": 12, 

557 "color": "#000000", 

558 "family": "sans-serif", 

559 "weight": "normal", 

560 "alpha": 1.0, 

561 "bbox": {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}, 

562 "h_align": "center", 

563 "v_align": "center", 

564 "pos": 0.5, 

565 "rotate": True, 

566 }, 

567 "edge_style": "-", 

568 "edge_alpha": 1.0, 

569 "edge_arrowstyle": "-|>" if G.is_directed() else "-", 

570 "edge_arrowsize": 10 if G.is_directed() else 0, 

571 "edge_curvature": "arc3", 

572 "edge_source_margin": 0, 

573 "edge_target_margin": 0, 

574 "hide_ticks": True, 

575 } 

576 

577 # Check arguments 

578 for kwarg in kwargs: 

579 if kwarg not in defaults: 

580 raise nx.NetworkXError( 

581 f"Unrecognized visualization keyword argument: {kwarg}" 

582 ) 

583 

584 if canvas is None: 

585 canvas = plt.gca() 

586 

587 if kwargs.get("hide_ticks", defaults["hide_ticks"]): 

588 canvas.tick_params( 

589 axis="both", 

590 which="both", 

591 bottom=False, 

592 left=False, 

593 labelbottom=False, 

594 labelleft=False, 

595 ) 

596 

597 ### Helper methods and classes 

598 

def node_property_sequence(seq, attr):
    """Collect the value of node attribute `attr` for every node in `seq`.

    The attribute name can be overridden with the ``node_<attr>`` keyword
    argument. Nodes that lack the attribute get the value from ``defaults``.
    If a ``node_*`` keyword value is not an attribute name on any visible
    node, it is used as a literal value for every node (``None`` disables
    this).

    Raises
    ------
    NetworkXError
        If ``defaults`` has no value for this parameter and a node in `seq`
        lacks the attribute.
    """
    param_name = f"node_{attr}"
    fallback = defaults[param_name]
    key = kwargs.get(param_name, attr)

    if fallback is None:
        # No default exists, so every node must carry the attribute itself
        for node in seq:
            if key not in node_subgraph.nodes[node]:
                raise nx.NetworkXError(f"Attribute '{key}' missing for node {node}")

    use_literal = (
        key is not None
        and not nx.get_node_attributes(node_subgraph, key)
        and any(key == value for name, value in kwargs.items() if "node" in name)
    )
    if use_literal:
        return [key for _ in seq]

    node_data = node_subgraph.nodes
    return [node_data[node].get(key, fallback) for node in seq]

624 

def compute_colors(color, alpha):
    """Normalize `color` to an RGBA tuple, applying `alpha` where appropriate.

    Color strings are parsed by matplotlib, and a non-default `alpha`
    replaces the parsed alpha channel. RGB 3-tuples get `alpha` appended.
    RGBA 4-tuples are returned unchanged.

    Raises
    ------
    ValueError
        If `color` is neither a string nor a 3- or 4-tuple.
    """
    if isinstance(color, str):
        rgba = mpl.colors.colorConverter.to_rgba(color)
        # Only a non-default alpha overrides the alpha implied by the string
        if alpha == defaults["node_alpha"]:
            return rgba
        return (rgba[0], rgba[1], rgba[2], alpha)

    if isinstance(color, tuple):
        size = len(color)
        if size == 4:
            return color
        if size == 3:
            return (*color, alpha)

    raise ValueError(f"Invalid format for color: {color}")

640 

641 # Find which edges can be plotted as a line collection 

642 # 

643 # Non-default values for these attributes require fancy arrow patches: 

644 # - any arrow style (including the default -|> for directed graphs) 

645 # - arrow size (by extension of style) 

646 # - connection style 

647 # - min_source_margin 

648 # - min_target_margin 

649 

def collection_compatible(e):
    """Return whether edge `e` can be drawn inside a plain LineCollection.

    That requires the plain "-" arrow style, the straight "arc3" connection,
    zero source and target margins, and that `e` is not a self-loop (those
    are drawn with FancyArrowPatch). The checks run in that order and stop
    at the first failure.
    """
    plain_settings = (
        ("arrowstyle", "-"),
        ("curvature", "arc3"),
        ("source_margin", 0),
        ("target_margin", 0),
    )
    straight = all(get_edge_attr(e, name) == wanted for name, wanted in plain_settings)
    return straight and e[0] != e[1]

659 

def edge_property_sequence(seq, attr):
    """Collect the value of edge attribute `attr` for every edge in `seq`.

    The attribute name can be overridden with the ``edge_<attr>`` keyword
    argument. Edges that lack the attribute get the value from ``defaults``.
    If an ``edge_*`` keyword value is not an attribute name on any visible
    edge, it is used as a literal value for every edge (``None`` disables
    this).

    Raises
    ------
    NetworkXError
        If ``defaults`` has no value for this parameter and an edge in `seq`
        lacks the attribute.
    """
    param_name = f"edge_{attr}"
    fallback = defaults[param_name]
    key = kwargs.get(param_name, attr)

    if fallback is None:
        # No default exists, so every edge must carry the attribute itself
        for edge in seq:
            if key not in edge_subgraph.edges[edge]:
                raise nx.NetworkXError(f"Attribute '{key}' missing for edge {edge}")

    use_literal = (
        key is not None
        and not nx.get_edge_attributes(edge_subgraph, key)
        and any(key == value for name, value in kwargs.items() if "edge" in name)
    )
    if use_literal:
        return [key for _ in seq]

    edge_data = edge_subgraph.edges
    return [edge_data[edge].get(key, fallback) for edge in seq]

681 

def get_edge_attr(e, attr):
    """Return the final value of edge attribute `attr` for edge `e`.

    The attribute name can be overridden with the ``edge_<attr>`` keyword
    argument. If no visible edge carries that name and it equals the value of
    an ``edge_*`` keyword argument, the name itself is returned as a literal
    value. Otherwise the edge's own value is returned, falling back to
    ``defaults``.

    Raises
    ------
    NetworkXError
        If ``defaults`` has no value for this parameter and `e` lacks the
        attribute.
    """

    param_name = f"edge_{attr}"
    default = defaults[param_name]
    attr = kwargs.get(param_name, attr)

    if default is None and attr not in edge_subgraph.edges[e]:
        raise nx.NetworkXError(f"Attribute '{attr}' missing from edge {e}")

    # Only edge_* keyword values count as literals, matching
    # edge_property_sequence. Comparing against every keyword value let e.g.
    # node_label="color" turn the string "color" into an arrow color.
    if (
        attr is not None
        and nx.get_edge_attributes(edge_subgraph, attr) == {}
        and any(attr == v for k, v in kwargs.items() if "edge" in k)
    ):
        return attr

    return edge_subgraph.edges[e].get(attr, default)

700 

def get_node_attr(n, attr, use_edge_subgraph=True):
    """Return the final value of node attribute `attr` for node `n`.

    The value is read from the subgraph of visible edges when
    `use_edge_subgraph` is True, otherwise from the subgraph of visible
    nodes. The attribute name can be overridden with the ``node_<attr>``
    keyword argument.

    If no node of that subgraph carries the name and it equals the value of
    a ``node_*`` keyword argument, the name itself is returned as a literal
    value. Otherwise the node's own value is returned, falling back to
    ``defaults``.

    Raises
    ------
    NetworkXError
        If ``defaults`` has no value for this parameter and `n` lacks the
        attribute.
    """
    subgraph = edge_subgraph if use_edge_subgraph else node_subgraph

    param_name = f"node_{attr}"
    default = defaults[param_name]
    attr = kwargs.get(param_name, attr)

    if default is None and attr not in subgraph.nodes[n]:
        raise nx.NetworkXError(f"Attribute '{attr}' missing from node {n}")

    # Only node_* keyword values count as literals, matching
    # node_property_sequence. Comparing against every keyword value let e.g.
    # edge_label="size" turn the string "size" into a node size.
    if (
        attr is not None
        and nx.get_node_attributes(subgraph, attr) == {}
        and any(attr == v for k, v in kwargs.items() if "node" in k)
    ):
        return attr

    return subgraph.nodes[n].get(attr, default)

720 

721 # Taken from ConnectionStyleFactory 

def self_loop(edge_index, node_size):
    """Return a connection-style callable that draws a self-loop.

    Parameters
    ----------
    edge_index : int
        ``edge_index % 4`` quarter turns are applied to the loop, so up to
        four loops on the same node stay distinguishable.
    node_size : float
        Node marker size (points ** 2); the loop is scaled by its square root.

    Returns
    -------
    callable
        ``f(posA, posB, *args, **kwargs)`` returning a ``matplotlib.path.Path``
        in display coordinates, usable as a FancyArrowPatch connectionstyle.
    """

    def self_loop_connection(posA, posB, *args, **kwargs):
        if not np.all(posA == posB):
            raise nx.NetworkXError(
                "`self_loop` connection style method "
                "is only to be used for self-loops"
            )
        # this is called with _screen space_ values
        # so convert back to data space
        data_loc = canvas.transData.inverted().transform(posA)
        # Scale self loop based on the size of the base node
        # Size of nodes are given in points ** 2 and each point is 1/72 of an inch
        v_shift = np.sqrt(node_size) / 72
        h_shift = v_shift * 0.5
        # put the top of the loop first so arrow is not hidden by node
        path = np.asarray(
            [
                # code 1 (MOVETO)
                [0, v_shift],
                # codes 4 4 4 (first cubic Bezier)
                [h_shift, v_shift],
                [h_shift, 0],
                [0, 0],
                # codes 4 4 4 (second cubic Bezier)
                [-h_shift, 0],
                [-h_shift, v_shift],
                [0, v_shift],
            ]
        )
        # Rotate self loop 90 deg. per index step
        # This will allow for maximum of 4 visible self loops
        if edge_index % 4:
            x, y = path.T
            for _ in range(edge_index % 4):
                x, y = y, -x
            path = np.array([x, y]).T
        return mpl.path.Path(
            canvas.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4]
        )

    return self_loop_connection

763 

def to_marker_edge(size, marker):
    """Return the distance from a marker's center to its edge, in points.

    Parameters
    ----------
    size : float
        Marker area in points ** 2 (matplotlib scatter ``s`` units).
    marker : str or other matplotlib marker spec
        The single-character markers "s", "^", ">", "v", "<" and "d" use
        ``sqrt(2 * size) / 2``. Every other marker, including non-string
        specs such as ``(numsides, style, angle)`` tuples, uses
        ``sqrt(size) / 2``.
    """
    # Exact membership; a substring test wrongly matched "" or "s^" and raised
    # TypeError for non-string markers.
    if isinstance(marker, str) and marker in {"s", "^", ">", "v", "<", "d"}:
        return np.sqrt(2 * size) / 2
    return np.sqrt(size) / 2

769 

def build_fancy_arrow(e):
    """Create the FancyArrowPatch used to draw edge `e`.

    Each end is shrunk by the larger of two values:
      * the node's marker extent from ``to_marker_edge``;
      * the edge's "source_margin" / "target_margin".

    Self-loops use the ``self_loop`` connection style. Other edges use the
    edge's "curvature" value as the matplotlib connection style.
    """
    from numbers import Integral

    source_margin = to_marker_edge(
        get_node_attr(e[0], "size"),
        get_node_attr(e[0], "shape"),
    )
    source_margin = max(
        source_margin,
        get_edge_attr(e, "source_margin"),
    )

    target_margin = to_marker_edge(
        get_node_attr(e[1], "size"),
        get_node_attr(e[1], "shape"),
    )
    target_margin = max(
        target_margin,
        get_edge_attr(e, "target_margin"),
    )

    if e[0] != e[1]:
        connectionstyle = get_edge_attr(e, "curvature")
    else:
        # Multigraph keys need not be integers; for a string key ``k % 4``
        # is string formatting and fails. Use the key's position among the
        # node's self-loop keys in that case.
        if len(e) == 2:
            loop_index = 0
        elif isinstance(e[2], Integral):
            loop_index = e[2] % 4
        else:
            loop_index = list(edge_subgraph[e[0]][e[1]]).index(e[2])
        connectionstyle = self_loop(loop_index, get_node_attr(e[0], "size"))

    return mpl.patches.FancyArrowPatch(
        edge_subgraph.nodes[e[0]][pos],
        edge_subgraph.nodes[e[1]][pos],
        arrowstyle=get_edge_attr(e, "arrowstyle"),
        connectionstyle=connectionstyle,
        color=get_edge_attr(e, "color"),
        linestyle=get_edge_attr(e, "style"),
        linewidth=get_edge_attr(e, "width"),
        mutation_scale=get_edge_attr(e, "arrowsize"),
        shrinkA=source_margin,
        # The end at e[1] must use the target node's margin
        shrinkB=target_margin,
        zorder=1,
    )

808 

class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text):
    """matplotlib Text artist whose position and rotation follow a FancyArrowPatch."""

811 

812 ### Draw the nodes first 

813 node_visible = kwargs.get("node_visible", "visible") 

814 if isinstance(node_visible, bool): 

815 if node_visible: 

816 visible_nodes = G.nodes() 

817 else: 

818 visible_nodes = [] 

819 else: 

820 visible_nodes = [ 

821 n for n, v in nx.get_node_attributes(G, node_visible, True).items() if v 

822 ] 

823 

824 node_subgraph = G.subgraph(visible_nodes) 

825 

826 # Ignore the default dict value since that's for default values to use, not 

827 # default attribute name 

828 pos = kwargs.get("node_pos", "pos") 

829 

830 default_display_pos_attr = "display's position attribute name" 

831 if callable(pos): 

832 nx.set_node_attributes( 

833 node_subgraph, pos(node_subgraph), default_display_pos_attr 

834 ) 

835 pos = default_display_pos_attr 

836 kwargs["node_pos"] = default_display_pos_attr 

837 elif nx.get_node_attributes(G, pos) == {}: 

838 nx.set_node_attributes( 

839 node_subgraph, nx.spring_layout(node_subgraph), default_display_pos_attr 

840 ) 

841 pos = default_display_pos_attr 

842 kwargs["node_pos"] = default_display_pos_attr 

843 

844 # Each shape requires a new scatter object since they can't have different 

845 # shapes. 

846 if len(visible_nodes) > 0: 

847 node_shape = kwargs.get("node_shape", "shape") 

848 for shape in Counter( 

849 nx.get_node_attributes( 

850 node_subgraph, node_shape, defaults["node_shape"] 

851 ).values() 

852 ): 

853 # Filter position just on this shape. 

854 nodes_with_shape = [ 

855 n 

856 for n, s in node_subgraph.nodes(data=node_shape) 

857 if s == shape or (s is None and shape == defaults["node_shape"]) 

858 ] 

859 # There are two property sequences to create before hand. 

860 # 1. position, since it is used for x and y parameters to scatter 

861 # 2. edgecolor, since the special 'face' parameter value can only

862 # be passed in as the sole string, not part of a list of strings. 

863 position = np.asarray(node_property_sequence(nodes_with_shape, "pos")) 

864 color = np.asarray( 

865 [ 

866 compute_colors(c, a) 

867 for c, a in zip( 

868 node_property_sequence(nodes_with_shape, "color"), 

869 node_property_sequence(nodes_with_shape, "alpha"), 

870 ) 

871 ] 

872 ) 

873 border_color = np.asarray( 

874 [ 

875 ( 

876 c 

877 if ( 

878 c := get_node_attr( 

879 n, 

880 "border_color", 

881 False, 

882 ) 

883 ) 

884 != "face" 

885 else color[i] 

886 ) 

887 for i, n in enumerate(nodes_with_shape) 

888 ] 

889 ) 

890 canvas.scatter( 

891 position[:, 0], 

892 position[:, 1], 

893 s=node_property_sequence(nodes_with_shape, "size"), 

894 c=color, 

895 marker=shape, 

896 linewidths=node_property_sequence(nodes_with_shape, "border_width"), 

897 edgecolors=border_color, 

898 zorder=2, 

899 ) 

900 

901 ### Draw node labels 

902 node_label = kwargs.get("node_label", "label") 

903 # Plot labels if node_label is not None and not False 

904 if node_label is not None and node_label is not False: 

905 default_dict = {} 

906 if isinstance(node_label, dict): 

907 default_dict = node_label 

908 node_label = None 

909 

910 for n, lbl in node_subgraph.nodes(data=node_label): 

911 if lbl is False: 

912 continue 

913 

914 # We work with label dicts down here... 

915 if not isinstance(lbl, dict): 

916 lbl = {"label": lbl if lbl is not None else n} 

917 

918 lbl_text = lbl.get("label", n) 

919 if not isinstance(lbl_text, str): 

920 lbl_text = str(lbl_text) 

921 

922 lbl.update(default_dict) 

923 x, y = node_subgraph.nodes[n][pos] 

924 canvas.text( 

925 x, 

926 y, 

927 lbl_text, 

928 size=lbl.get("size", defaults["node_label"]["size"]), 

929 color=lbl.get("color", defaults["node_label"]["color"]), 

930 family=lbl.get("family", defaults["node_label"]["family"]), 

931 weight=lbl.get("weight", defaults["node_label"]["weight"]), 

932 horizontalalignment=lbl.get( 

933 "h_align", defaults["node_label"]["h_align"] 

934 ), 

935 verticalalignment=lbl.get("v_align", defaults["node_label"]["v_align"]), 

936 transform=canvas.transData, 

937 bbox=lbl.get("bbox", defaults["node_label"]["bbox"]), 

938 ) 

939 

940 ### Draw edges 

941 

942 edge_visible = kwargs.get("edge_visible", "visible") 

943 if isinstance(edge_visible, bool): 

944 if edge_visible: 

945 visible_edges = G.edges() 

946 else: 

947 visible_edges = [] 

948 else: 

949 visible_edges = [ 

950 e for e, v in nx.get_edge_attributes(G, edge_visible, True).items() if v 

951 ] 

952 

953 edge_subgraph = G.edge_subgraph(visible_edges) 

954 nx.set_node_attributes( 

955 edge_subgraph, nx.get_node_attributes(node_subgraph, pos), name=pos 

956 ) 

957 

958 collection_edges = ( 

959 [e for e in edge_subgraph.edges(keys=True) if collection_compatible(e)] 

960 if edge_subgraph.is_multigraph() 

961 else [e for e in edge_subgraph.edges() if collection_compatible(e)] 

962 ) 

963 non_collection_edges = ( 

964 [e for e in edge_subgraph.edges(keys=True) if not collection_compatible(e)] 

965 if edge_subgraph.is_multigraph() 

966 else [e for e in edge_subgraph.edges() if not collection_compatible(e)] 

967 ) 

968 edge_position = np.asarray( 

969 [ 

970 ( 

971 get_node_attr(u, "pos", use_edge_subgraph=True), 

972 get_node_attr(v, "pos", use_edge_subgraph=True), 

973 ) 

974 for u, v, *_ in collection_edges 

975 ] 

976 ) 

977 

978 # Only plot a line collection if needed 

979 if len(collection_edges) > 0: 

980 edge_collection = mpl.collections.LineCollection( 

981 edge_position, 

982 colors=edge_property_sequence(collection_edges, "color"), 

983 linewidths=edge_property_sequence(collection_edges, "width"), 

984 linestyle=edge_property_sequence(collection_edges, "style"), 

985 alpha=edge_property_sequence(collection_edges, "alpha"), 

986 antialiaseds=(1,), 

987 zorder=1, 

988 ) 

989 canvas.add_collection(edge_collection) 

990 

991 fancy_arrows = {} 

992 if len(non_collection_edges) > 0: 

993 for e in non_collection_edges: 

994 # Cache results for use in edge labels 

995 fancy_arrows[e] = build_fancy_arrow(e) 

996 canvas.add_patch(fancy_arrows[e]) 

997 

998 ### Draw edge labels 

999 edge_label = kwargs.get("edge_label", "label") 

1000 default_dict = {} 

1001 if isinstance(edge_label, dict): 

1002 default_dict = edge_label 

1003 # Restore the default label attribute key of 'label' 

1004 edge_label = "label" 

1005 

1006 # Handle multigraphs 

1007 edge_label_data = ( 

1008 edge_subgraph.edges(data=edge_label, keys=True) 

1009 if edge_subgraph.is_multigraph() 

1010 else edge_subgraph.edges(data=edge_label) 

1011 ) 

1012 if edge_label is not None and edge_label is not False: 

1013 for *e, lbl in edge_label_data: 

1014 e = tuple(e) 

1015 # I'm not sure how I want to handle None here... For now it means no label 

1016 if lbl is False or lbl is None: 

1017 continue 

1018 

1019 if not isinstance(lbl, dict): 

1020 lbl = {"label": lbl} 

1021 

1022 lbl.update(default_dict) 

1023 lbl_text = lbl.get("label") 

1024 if not isinstance(lbl_text, str): 

1025 lbl_text = str(lbl_text) 

1026 

1027 # In the old code, every non-self-loop is placed via a fancy arrow patch 

1028 # Only compute a new fancy arrow if needed by caching the results from 

1029 # edge placement. 

1030 try: 

1031 arrow = fancy_arrows[e] 

1032 except KeyError: 

1033 arrow = build_fancy_arrow(e) 

1034 

1035 if e[0] == e[1]: 

1036 # Taken directly from draw_networkx_edge_labels 

1037 connectionstyle_obj = arrow.get_connectionstyle() 

1038 posA = canvas.transData.transform(edge_subgraph.nodes[e[0]][pos]) 

1039 path_disp = connectionstyle_obj(posA, posA) 

1040 path_data = canvas.transData.inverted().transform_path(path_disp) 

1041 x, y = path_data.vertices[0] 

1042 canvas.text( 

1043 x, 

1044 y, 

1045 lbl_text, 

1046 size=lbl.get("size", defaults["edge_label"]["size"]), 

1047 color=lbl.get("color", defaults["edge_label"]["color"]), 

1048 family=lbl.get("family", defaults["edge_label"]["family"]), 

1049 weight=lbl.get("weight", defaults["edge_label"]["weight"]), 

1050 alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]), 

1051 horizontalalignment=lbl.get( 

1052 "h_align", defaults["edge_label"]["h_align"] 

1053 ), 

1054 verticalalignment=lbl.get( 

1055 "v_align", defaults["edge_label"]["v_align"] 

1056 ), 

1057 rotation=0, 

1058 transform=canvas.transData, 

1059 bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]), 

1060 zorder=1, 

1061 ) 

1062 continue 

1063 

1064 CurvedArrowText( 

1065 arrow, 

1066 lbl_text, 

1067 size=lbl.get("size", defaults["edge_label"]["size"]), 

1068 color=lbl.get("color", defaults["edge_label"]["color"]), 

1069 family=lbl.get("family", defaults["edge_label"]["family"]), 

1070 weight=lbl.get("weight", defaults["edge_label"]["weight"]), 

1071 alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]), 

1072 bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]), 

1073 horizontalalignment=lbl.get( 

1074 "h_align", defaults["edge_label"]["h_align"] 

1075 ), 

1076 verticalalignment=lbl.get("v_align", defaults["edge_label"]["v_align"]), 

1077 label_pos=lbl.get("pos", defaults["edge_label"]["pos"]), 

1078 labels_horizontal=lbl.get("rotate", defaults["edge_label"]["rotate"]), 

1079 transform=canvas.transData, 

1080 zorder=1, 

1081 ax=canvas, 

1082 ) 

1083 

1084 # If we had to add an attribute, remove it here 

1085 if pos == default_display_pos_attr: 

1086 nx.remove_node_attributes(G, default_display_pos_attr) 

1087 

1088 return G 

1089 

1090 

def draw(G, pos=None, ax=None, **kwds):
    """Draw the graph G with Matplotlib.

    Draw the graph as a simple representation with no node
    labels or edge labels and using the full Matplotlib figure area
    and no axis labels by default. See draw_networkx() for more
    full-featured drawing that allows title, axis labels etc.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed.
        See :py:mod:`networkx.drawing.layout` for functions that
        compute node positions.

    ax : Matplotlib Axes object, optional
        Draw the graph in specified Matplotlib axes.

    kwds : optional keywords
        See networkx.draw_networkx() for a description of optional keywords.
        Unless ``with_labels`` is given explicitly, node labels are drawn
        only when a ``labels`` keyword is supplied.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    See Also
    --------
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels

    Notes
    -----
    This function has the same name as pylab.draw and pyplot.draw
    so beware when using `from networkx import *`
    since you might overwrite the pylab.draw function.

    With pyplot use

    >>> import matplotlib.pyplot as plt
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)  # networkx draw()
    >>> plt.draw()  # pyplot draw()

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html
    """
    import matplotlib.pyplot as plt

    # Work on the figure that owns the target axes (or the current figure).
    fig = plt.gcf() if ax is None else ax.get_figure()
    fig.set_facecolor("w")

    if ax is None:
        # Reuse existing axes when present; otherwise span the whole figure.
        ax = fig.gca() if fig.axes else fig.add_axes((0, 0, 1, 1))

    # Labels are opt-in here: only draw them when the caller passed `labels`.
    kwds.setdefault("with_labels", "labels" in kwds)

    draw_networkx(G, pos=pos, ax=ax, **kwds)
    ax.set_axis_off()
    plt.draw_if_interactive()

1168 

1169 

def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
    r"""Draw the graph G using Matplotlib.

    Draw the graph with Matplotlib with options for node positions,
    labeling, titles, and many other drawing features.
    See draw() for simple drawing without labels or axes.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed.
        See :py:mod:`networkx.drawing.layout` for functions that
        compute node positions.

    arrows : bool or None, optional (default=None)
        If `None`, directed graphs draw arrowheads with
        `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
        via `~matplotlib.collections.LineCollection` for speed.
        If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
        If `False`, draw edges using LineCollection (linear and fast).
        For directed graphs, if True draw arrowheads.
        Note: Arrows will be the same color as edges.

    arrowstyle : str (default='-\|>' for directed graphs)
        For directed graphs, choose the style of the arrowheads.
        For undirected graphs default to '-'

        See `matplotlib.patches.ArrowStyle` for more options.

    arrowsize : int or list (default=10)
        For directed graphs, choose the size of the arrow head's length and
        width. A list of values can be passed in to assign a different size
        for each arrow head's length and width.
        See `matplotlib.patches.FancyArrowPatch` for attribute `mutation_scale`
        for more info.

    with_labels : bool (default=True)
        Set to True to draw labels on the nodes.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    nodelist : list (default=list(G))
        Draw only specified nodes

    edgelist : list (default=list(G.edges()))
        Draw only specified edges

    node_size : scalar or array (default=300)
        Size of nodes. If an array is specified it must be the
        same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape : string (default='o')
        The shape of the node. Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'.

    alpha : float or None (default=None)
        The node and edge transparency

    cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of nodes

    vmin,vmax : float, optional
        Minimum and maximum for node colormap scaling

    linewidths : scalar or sequence (default=1.0)
        Line width of symbol border

    width : float or array of floats (default=1.0)
        Line width of edges

    edge_color : color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edgelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    edge_cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of edges

    edge_vmin,edge_vmax : floats, optional
        Minimum and maximum for edge colormap scaling

    style : string (default=solid line)
        Edge line style e.g.: '-', '--', '-.', ':'
        or words like 'solid' or 'dashed'.
        (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)

    labels : dictionary (default=None)
        Node labels in a dictionary of text labels keyed by node

    font_size : int (default=12 for nodes, 10 for edges)
        Font size for text labels

    font_color : color (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    label : string, optional
        Label for graph legend

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    kwds : optional keywords
        See networkx.draw_networkx_nodes(), networkx.draw_networkx_edges(), and
        networkx.draw_networkx_labels() for a description of optional keywords.
        Each keyword is forwarded to every one of those functions whose
        signature accepts it.

    Raises
    ------
    ValueError
        If a keyword is not accepted by any of draw_networkx_nodes,
        draw_networkx_edges or draw_networkx_labels.

    Notes
    -----
    For directed graphs, arrows are drawn at the head end. Arrows can be
    turned off with keyword arrows=False.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    >>> import matplotlib.pyplot as plt
    >>> limits = plt.axis("off")  # turn off axis

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    from inspect import signature

    import matplotlib.pyplot as plt

    def accepted_by(func):
        # Parameter names of `func`, used to route **kwds to the right drawer.
        return signature(func).parameters.keys()

    node_params = accepted_by(draw_networkx_nodes)
    edge_params = accepted_by(draw_networkx_edges)
    label_params = accepted_by(draw_networkx_labels)

    # Anything any of the three drawers understands is allowed, except the
    # arguments draw_networkx consumes itself.
    own_params = {"G", "pos", "arrows", "with_labels"}
    allowed = (node_params | edge_params | label_params) - own_params

    rejected = [name for name in kwds if name not in allowed]
    if rejected:
        invalid_args = ", ".join(rejected)
        raise ValueError(f"Received invalid argument(s): {invalid_args}")

    def forwarded(params):
        # Subset of the caller's keywords that `params` accepts.
        return {name: value for name, value in kwds.items() if name in params}

    if pos is None:
        # Fall back to a spring layout when no positions were given.
        pos = nx.drawing.spring_layout(G)

    draw_networkx_nodes(G, pos, **forwarded(node_params))
    draw_networkx_edges(G, pos, arrows=arrows, **forwarded(edge_params))
    if with_labels:
        draw_networkx_labels(G, pos, **forwarded(label_params))
    plt.draw_if_interactive()

1357 

1358 

def draw_networkx_nodes(
    G,
    pos,
    nodelist=None,
    node_size=300,
    node_color="#1f78b4",
    node_shape="o",
    alpha=None,
    cmap=None,
    vmin=None,
    vmax=None,
    ax=None,
    linewidths=None,
    edgecolors=None,
    label=None,
    margins=None,
    hide_ticks=True,
):
    """Draw the nodes of the graph G.

    This draws only the nodes of the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    nodelist : list (default list(G))
        Draw only specified nodes

    node_size : scalar or array (default=300)
        Size of nodes. If an array it must be the same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape : string or sequence of strings (default='o')
        The shape of the node. Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'. A list or array with one marker per
        node in nodelist draws each node with its own marker.

    alpha : float or array of floats (default=None)
        The node transparency. This can be a single alpha value,
        in which case it will be applied to all the nodes of color. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    cmap : Matplotlib colormap (default=None)
        Colormap for mapping intensities of nodes

    vmin,vmax : floats or None (default=None)
        Minimum and maximum for node colormap scaling. When several node
        shapes are drawn and numeric colors are given, missing limits default
        to the minimum/maximum over all nodes so every shape shares one scale.

    linewidths : [None | scalar | sequence] (default=1.0)
        Line width of symbol border

    edgecolors : [None | scalar | sequence] (default = node_color)
        Colors of node borders. Can be a single color or a sequence of colors with the
        same length as nodelist. Color can be string or rgb (or rgba) tuple of floats
        from 0-1. If numeric values are specified they will be mapped to colors
        using the cmap and vmin,vmax parameters. See `~matplotlib.pyplot.scatter` for more details.

    label : [None | string]
        Label for legend

    margins : float or 2-tuple, optional
        Sets the padding for axis autoscaling. Increase margin to prevent
        clipping for nodes that are near the edges of an image. Values should
        be in the range ``[0, 1]``. See :meth:`matplotlib.axes.Axes.margins`
        for details. The default is `None`, which uses the Matplotlib default.

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Returns
    -------
    matplotlib.collections.PathCollection
        `PathCollection` of the nodes. When several node shapes are used, one
        collection is drawn per shape and the last one drawn is returned.

    Raises
    ------
    NetworkXError
        If a node in nodelist has no entry in pos.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    from collections.abc import Iterable
    from numbers import Number

    import matplotlib as mpl
    import matplotlib.collections  # call as mpl.collections
    import matplotlib.pyplot as plt
    import numpy as np

    if ax is None:
        ax = plt.gca()

    if nodelist is None:
        nodelist = list(G)

    if len(nodelist) == 0:  # empty nodelist, no drawing
        return mpl.collections.PathCollection(None)

    try:
        xy = np.asarray([pos[v] for v in nodelist])
    except KeyError as err:
        raise nx.NetworkXError(f"Node {err} has no position.") from err

    if isinstance(alpha, Iterable):
        node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax)
        alpha = None

    if not isinstance(node_shape, np.ndarray) and not isinstance(node_shape, list):
        node_shape = np.array([node_shape for _ in range(len(nodelist))])
    elif isinstance(node_shape, list):
        node_shape = np.asarray(node_shape)

    n_nodes = len(nodelist)
    shapes = np.unique(node_shape)
    multiple_shapes = len(shapes) > 1

    def per_shape(values, mask):
        """Restrict a per-node sequence to the nodes selected by ``mask``.

        Only list/ndarray values whose length equals the number of nodes are
        treated as per-node; scalars, strings and tuples (e.g. one RGB color)
        are passed through unchanged. With a single shape nothing is touched.
        """
        if not multiple_shapes:
            return values
        if isinstance(values, np.ndarray):
            if values.ndim > 0 and values.shape[0] == n_nodes:
                return values[mask]
            return values
        if isinstance(values, list) and len(values) == n_nodes:
            return [v for v, keep in zip(values, mask) if keep]
        return values

    # Splitting numeric colors by shape would otherwise let each scatter call
    # autoscale its own subset, making colors incomparable across shapes.
    if multiple_shapes and (vmin is None or vmax is None):
        if (
            isinstance(node_color, (list, np.ndarray))
            and len(node_color) == n_nodes
            and all(isinstance(c, Number) for c in node_color)
        ):
            if vmin is None:
                vmin = min(node_color)
            if vmax is None:
                vmax = max(node_color)

    for shape in shapes:
        mask = node_shape == shape
        node_collection = ax.scatter(
            xy[mask, 0],
            xy[mask, 1],
            s=per_shape(node_size, mask),
            c=per_shape(node_color, mask),
            marker=shape,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            alpha=alpha,
            linewidths=per_shape(linewidths, mask),
            edgecolors=per_shape(edgecolors, mask),
            label=label,
        )
        # Every shape's collection (not only the last) must sit above edges.
        node_collection.set_zorder(2)

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    if margins is not None:
        if isinstance(margins, Iterable):
            ax.margins(*margins)
        else:
            ax.margins(margins)

    return node_collection

1529 

1530 

class FancyArrowFactory:
    """Draw arrows with `matplotlib.patches.FancyArrowPatch`.

    Instances are callable: ``factory(i)`` builds the patch for the ``i``-th
    entry of ``edge_pos``/``edgelist``. Per-edge properties (colors, line
    widths, styles, arrow styles, arrow sizes, margins) may be given either
    as a single value or as a sequence indexed by edge position.
    """

    class ConnectionStyleFactory:
        """Produce matplotlib connection styles for edges and self-loops.

        Parameters
        ----------
        connectionstyles : iterable of str
            Base connection styles; edge index ``k`` uses
            ``connectionstyles[k % len(connectionstyles)]``.
        selfloop_height : float
            Scale for self-loops; a loop extends ``0.1 * selfloop_height``
            in data coordinates from its node.
        ax : Matplotlib Axes, optional
            Axes whose data transform converts self-loop paths between data
            and display coordinates.
        """

        def __init__(self, connectionstyles, selfloop_height, ax=None):
            import matplotlib as mpl
            import matplotlib.patches  # call as mpl.patches
            import matplotlib.path  # call as mpl.path
            import numpy as np

            self.ax = ax
            self.mpl = mpl
            self.np = np
            self.base_connection_styles = [
                mpl.patches.ConnectionStyle(cs) for cs in connectionstyles
            ]
            self.n = len(self.base_connection_styles)
            self.selfloop_height = selfloop_height

        def curved(self, edge_index):
            # Cycle through the base styles by edge index.
            return self.base_connection_styles[edge_index % self.n]

        def self_loop(self, edge_index):
            def self_loop_connection(posA, posB, *args, **kwargs):
                if not self.np.all(posA == posB):
                    raise nx.NetworkXError(
                        "`self_loop` connection style method "
                        "is only to be used for self-loops"
                    )
                # this is called with _screen space_ values
                # so convert back to data space
                data_loc = self.ax.transData.inverted().transform(posA)
                v_shift = 0.1 * self.selfloop_height
                h_shift = v_shift * 0.5
                # put the top of the loop first so arrow is not hidden by node
                path = self.np.asarray(
                    [
                        # 1
                        [0, v_shift],
                        # 4 4 4
                        [h_shift, v_shift],
                        [h_shift, 0],
                        [0, 0],
                        # 4 4 4
                        [-h_shift, 0],
                        [-h_shift, v_shift],
                        [0, v_shift],
                    ]
                )
                # Rotate self loop 90 deg. if more than 1
                # This will allow for maximum of 4 visible self loops
                if edge_index % 4:
                    x, y = path.T
                    for _ in range(edge_index % 4):
                        x, y = y, -x
                    path = self.np.array([x, y]).T
                # Path codes: 1 = MOVETO, 4 = CURVE4 (cubic Bezier segments).
                return self.mpl.path.Path(
                    self.ax.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4]
                )

            return self_loop_connection

    def __init__(
        self,
        edge_pos,
        edgelist,
        nodelist,
        edge_indices,
        node_size,
        selfloop_height,
        connectionstyle="arc3",
        node_shape="o",
        arrowstyle="-",
        arrowsize=10,
        edge_color="k",
        alpha=None,
        linewidth=1.0,
        style="solid",
        min_source_margin=0,
        min_target_margin=0,
        ax=None,
    ):
        import matplotlib as mpl
        import matplotlib.patches  # call as mpl.patches
        import matplotlib.pyplot as plt
        import numpy as np

        if isinstance(connectionstyle, str):
            connectionstyle = [connectionstyle]
        elif np.iterable(connectionstyle):
            connectionstyle = list(connectionstyle)
        else:
            msg = "ConnectionStyleFactory arg `connectionstyle` must be str or iterable"
            raise nx.NetworkXError(msg)
        self.ax = ax
        self.mpl = mpl
        self.np = np
        self.edge_pos = edge_pos
        self.edgelist = edgelist
        self.nodelist = nodelist
        # Lazily built {node: first index in nodelist}; see _node_position.
        self._node_index = None
        self.node_shape = node_shape
        self.min_source_margin = min_source_margin
        self.min_target_margin = min_target_margin
        self.edge_indices = edge_indices
        self.node_size = node_size
        self.connectionstyle_factory = self.ConnectionStyleFactory(
            connectionstyle, selfloop_height, ax
        )
        self.arrowstyle = arrowstyle
        self.arrowsize = arrowsize
        self.arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha)
        self.linewidth = linewidth
        self.style = style
        if isinstance(arrowsize, list) and len(arrowsize) != len(edge_pos):
            raise ValueError("arrowsize should have the same length as edgelist")

    def _node_position(self, node):
        """Return the index of the first occurrence of ``node`` in nodelist.

        A lookup table is built on first use so per-edge queries are O(1)
        instead of a linear ``list.index`` scan. Nodes missing from the table
        fall back to ``nodelist.index``, which raises ValueError as before.
        """
        if self._node_index is None:
            index = {}
            for position, n in enumerate(self.nodelist):
                index.setdefault(n, position)
            self._node_index = index
        try:
            return self._node_index[node]
        except (KeyError, TypeError):
            return self.nodelist.index(node)

    def __call__(self, i):
        """Build the FancyArrowPatch for the ``i``-th edge."""
        (x1, y1), (x2, y2) = self.edge_pos[i]
        shrink_source = 0  # space from source to tail
        shrink_target = 0  # space from head to target
        if (
            self.np.iterable(self.min_source_margin)
            and not isinstance(self.min_source_margin, str)
            and not isinstance(self.min_source_margin, tuple)
        ):
            min_source_margin = self.min_source_margin[i]
        else:
            min_source_margin = self.min_source_margin

        if (
            self.np.iterable(self.min_target_margin)
            and not isinstance(self.min_target_margin, str)
            and not isinstance(self.min_target_margin, tuple)
        ):
            min_target_margin = self.min_target_margin[i]
        else:
            min_target_margin = self.min_target_margin

        if self.np.iterable(self.node_size):  # many node sizes
            source, target = self.edgelist[i][:2]
            source_node_size = self.node_size[self._node_position(source)]
            target_node_size = self.node_size[self._node_position(target)]
            shrink_source = self.to_marker_edge(source_node_size, self.node_shape)
            shrink_target = self.to_marker_edge(target_node_size, self.node_shape)
        else:
            shrink_source = self.to_marker_edge(self.node_size, self.node_shape)
            shrink_target = shrink_source
        shrink_source = max(shrink_source, min_source_margin)
        shrink_target = max(shrink_target, min_target_margin)

        # scale factor of arrow head
        if isinstance(self.arrowsize, list):
            mutation_scale = self.arrowsize[i]
        else:
            mutation_scale = self.arrowsize

        if len(self.arrow_colors) > i:
            arrow_color = self.arrow_colors[i]
        elif len(self.arrow_colors) == 1:
            arrow_color = self.arrow_colors[0]
        else:  # Cycle through colors
            arrow_color = self.arrow_colors[i % len(self.arrow_colors)]

        if self.np.iterable(self.linewidth):
            if len(self.linewidth) > i:
                linewidth = self.linewidth[i]
            else:
                linewidth = self.linewidth[i % len(self.linewidth)]
        else:
            linewidth = self.linewidth

        if (
            self.np.iterable(self.style)
            and not isinstance(self.style, str)
            and not isinstance(self.style, tuple)
        ):
            if len(self.style) > i:
                linestyle = self.style[i]
            else:  # Cycle through styles
                linestyle = self.style[i % len(self.style)]
        else:
            linestyle = self.style

        if x1 == x2 and y1 == y2:
            connectionstyle = self.connectionstyle_factory.self_loop(
                self.edge_indices[i]
            )
        else:
            connectionstyle = self.connectionstyle_factory.curved(self.edge_indices[i])

        if (
            self.np.iterable(self.arrowstyle)
            and not isinstance(self.arrowstyle, str)
            and not isinstance(self.arrowstyle, tuple)
        ):
            arrowstyle = self.arrowstyle[i]
        else:
            arrowstyle = self.arrowstyle

        return self.mpl.patches.FancyArrowPatch(
            (x1, y1),
            (x2, y2),
            arrowstyle=arrowstyle,
            shrinkA=shrink_source,
            shrinkB=shrink_target,
            mutation_scale=mutation_scale,
            color=arrow_color,
            linewidth=linewidth,
            connectionstyle=connectionstyle,
            linestyle=linestyle,
            zorder=1,  # arrows go behind nodes
        )

    def to_marker_edge(self, marker_size, marker):
        """Return the distance from a marker's center to its edge (points).

        ``marker_size`` is a scatter ``s`` value (area); markers listed below
        are treated as "large" and get extra room.
        """
        # Exact membership (not substring) so "" or non-str markers don't
        # match or raise TypeError.
        if marker in ("s", "^", ">", "v", "<", "d"):  # `large` markers need extra space
            return self.np.sqrt(2 * marker_size) / 2
        else:
            return self.np.sqrt(marker_size) / 2

1748 

1749 

1750def draw_networkx_edges( 

1751 G, 

1752 pos, 

1753 edgelist=None, 

1754 width=1.0, 

1755 edge_color="k", 

1756 style="solid", 

1757 alpha=None, 

1758 arrowstyle=None, 

1759 arrowsize=10, 

1760 edge_cmap=None, 

1761 edge_vmin=None, 

1762 edge_vmax=None, 

1763 ax=None, 

1764 arrows=None, 

1765 label=None, 

1766 node_size=300, 

1767 nodelist=None, 

1768 node_shape="o", 

1769 connectionstyle="arc3", 

1770 min_source_margin=0, 

1771 min_target_margin=0, 

1772 hide_ticks=True, 

1773): 

1774 r"""Draw the edges of the graph G. 

1775 

1776 This draws only the edges of the graph G. 

1777 

1778 Parameters 

1779 ---------- 

1780 G : graph 

1781 A networkx graph 

1782 

1783 pos : dictionary 

1784 A dictionary with nodes as keys and positions as values. 

1785 Positions should be sequences of length 2. 

1786 

1787 edgelist : collection of edge tuples (default=G.edges()) 

1788 Draw only specified edges 

1789 

1790 width : float or array of floats (default=1.0) 

1791 Line width of edges 

1792 

1793 edge_color : color or array of colors (default='k') 

1794 Edge color. Can be a single color or a sequence of colors with the same 

1795 length as edgelist. Color can be string or rgb (or rgba) tuple of 

1796 floats from 0-1. If numeric values are specified they will be 

1797 mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters. 

1798 

1799 style : string or array of strings (default='solid') 

1800 Edge line style e.g.: '-', '--', '-.', ':' 

1801 or words like 'solid' or 'dashed'. 

1802 Can be a single style or a sequence of styles with the same 

1803 length as the edge list. 

1804 If less styles than edges are given the styles will cycle. 

1805 If more styles than edges are given the styles will be used sequentially 

1806 and not be exhausted. 

1807 Also, `(offset, onoffseq)` tuples can be used as style instead of a strings. 

1808 (See `matplotlib.patches.FancyArrowPatch`: `linestyle`) 

1809 

1810 alpha : float or array of floats (default=None) 

1811 The edge transparency. This can be a single alpha value, 

1812 in which case it will be applied to all specified edges. Otherwise, 

1813 if it is an array, the elements of alpha will be applied to the colors 

1814 in order (cycling through alpha multiple times if necessary). 

1815 

1816 edge_cmap : Matplotlib colormap, optional 

1817 Colormap for mapping intensities of edges 

1818 

1819 edge_vmin,edge_vmax : floats, optional 

1820 Minimum and maximum for edge colormap scaling 

1821 

1822 ax : Matplotlib Axes object, optional 

1823 Draw the graph in the specified Matplotlib axes. 

1824 

1825 arrows : bool or None, optional (default=None) 

1826 If `None`, directed graphs draw arrowheads with 

1827 `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges 

1828 via `~matplotlib.collections.LineCollection` for speed. 

1829 If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish). 

1830 If `False`, draw edges using LineCollection (linear and fast). 

1831 

1832 Note: Arrowheads will be the same color as edges. 

1833 

1834 arrowstyle : str or list of strs (default='-\|>' for directed graphs) 

1835 For directed graphs and `arrows==True` defaults to '-\|>', 

1836 For undirected graphs default to '-'. 

1837 

1838 See `matplotlib.patches.ArrowStyle` for more options. 

1839 

1840 arrowsize : int or list of ints(default=10) 

1841 For directed graphs, choose the size of the arrow head's length and 

1842 width. See `matplotlib.patches.FancyArrowPatch` for attribute 

1843 `mutation_scale` for more info. 

1844 

1845 connectionstyle : string or iterable of strings (default="arc3") 

1846 Pass the connectionstyle parameter to create curved arc of rounding 

1847 radius rad. For example, connectionstyle='arc3,rad=0.2'. 

1848 See `matplotlib.patches.ConnectionStyle` and 

1849 `matplotlib.patches.FancyArrowPatch` for more info. 

1850 If Iterable, index indicates i'th edge key of MultiGraph 

1851 

1852 node_size : scalar or array (default=300) 

1853 Size of nodes. Though the nodes are not drawn with this function, the 

1854 node size is used in determining edge positioning. 

1855 

1856 nodelist : list, optional (default=G.nodes()) 

1857 This provides the node order for the `node_size` array (if it is an array). 

1858 

1859 node_shape : string (default='o') 

1860 The marker used for nodes, used in determining edge positioning. 

1861 Specification is as a `matplotlib.markers` marker, e.g. one of 'so^>v<dph8'. 

1862 

1863 label : None or string 

1864 Label for legend 

1865 

1866 min_source_margin : int or list of ints (default=0) 

1867 The minimum margin (gap) at the beginning of the edge at the source. 

1868 

1869 min_target_margin : int or list of ints (default=0) 

1870 The minimum margin (gap) at the end of the edge at the target. 

1871 

1872 hide_ticks : bool, optional 

1873 Hide ticks of axes. When `True` (the default), ticks and ticklabels 

1874 are removed from the axes. To set ticks and tick labels to the pyplot default, 

1875 use ``hide_ticks=False``. 

1876 

1877 Returns 

1878 ------- 

1879 matplotlib.collections.LineCollection or a list of matplotlib.patches.FancyArrowPatch 

1880 If ``arrows=True``, a list of FancyArrowPatches is returned. 

1881 If ``arrows=False``, a LineCollection is returned. 

1882 If ``arrows=None`` (the default), then a LineCollection is returned if 

1883 `G` is undirected, otherwise returns a list of FancyArrowPatches. 

1884 

1885 Notes 

1886 ----- 

1887 For directed graphs, arrows are drawn at the head end. Arrows can be 

1888 turned off with keyword arrows=False or by passing an arrowstyle without 

1889 an arrow on the end. 

1890 

1891 Be sure to include `node_size` as a keyword argument; arrows are 

1892 drawn considering the size of nodes. 

1893 

1894 Self-loops are always drawn with `~matplotlib.patches.FancyArrowPatch` 

1895 regardless of the value of `arrows` or whether `G` is directed. 

1896 When ``arrows=False`` or ``arrows=None`` and `G` is undirected, the 

1897 FancyArrowPatches corresponding to the self-loops are not explicitly 

1898 returned. They should instead be accessed via the ``Axes.patches`` 

1899 attribute (see examples). 

1900 

1901 Examples 

1902 -------- 

1903 >>> G = nx.dodecahedral_graph() 

1904 >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G)) 

1905 

1906 >>> G = nx.DiGraph() 

1907 >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)]) 

1908 >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G)) 

1909 >>> alphas = [0.3, 0.4, 0.5] 

1910 >>> for i, arc in enumerate(arcs): # change alpha values of arcs 

1911 ... arc.set_alpha(alphas[i]) 

1912 

1913 The FancyArrowPatches corresponding to self-loops are not always 

1914 returned, but can always be accessed via the ``patches`` attribute of the 

1915 `matplotlib.Axes` object. 

1916 

1917 >>> import matplotlib.pyplot as plt 

1918 >>> fig, ax = plt.subplots() 

1919 >>> G = nx.Graph([(0, 1), (0, 0)]) # Self-loop at node 0 

1920 >>> edge_collection = nx.draw_networkx_edges(G, pos=nx.circular_layout(G), ax=ax) 

1921 >>> self_loop_fap = ax.patches[0] 

1922 

1923 Also see the NetworkX drawing examples at 

1924 https://networkx.org/documentation/latest/auto_examples/index.html 

1925 

1926 See Also 

1927 -------- 

1928 draw 

1929 draw_networkx 

1930 draw_networkx_nodes 

1931 draw_networkx_labels 

1932 draw_networkx_edge_labels 

1933 

1934 """ 

1935 import warnings 

1936 

1937 import matplotlib as mpl 

1938 import matplotlib.collections # call as mpl.collections 

1939 import matplotlib.colors # call as mpl.colors 

1940 import matplotlib.pyplot as plt 

1941 import numpy as np 

1942 

1943 # The default behavior is to use LineCollection to draw edges for 

1944 # undirected graphs (for performance reasons) and use FancyArrowPatches 

1945 # for directed graphs. 

1946 # The `arrows` keyword can be used to override the default behavior 

1947 if arrows is None: 

1948 use_linecollection = not (G.is_directed() or G.is_multigraph()) 

1949 else: 

1950 if not isinstance(arrows, bool): 

1951 raise TypeError("Argument `arrows` must be of type bool or None") 

1952 use_linecollection = not arrows 

1953 

1954 if isinstance(connectionstyle, str): 

1955 connectionstyle = [connectionstyle] 

1956 elif np.iterable(connectionstyle): 

1957 connectionstyle = list(connectionstyle) 

1958 else: 

1959 msg = "draw_networkx_edges arg `connectionstyle` must be str or iterable" 

1960 raise nx.NetworkXError(msg) 

1961 

1962 # Some kwargs only apply to FancyArrowPatches. Warn users when they use 

1963 # non-default values for these kwargs when LineCollection is being used 

1964 # instead of silently ignoring the specified option 

1965 if use_linecollection: 

1966 msg = ( 

1967 "\n\nThe {0} keyword argument is not applicable when drawing edges\n" 

1968 "with LineCollection.\n\n" 

1969 "To make this warning go away, either specify `arrows=True` to\n" 

1970 "force FancyArrowPatches or use the default values.\n" 

1971 "Note that using FancyArrowPatches may be slow for large graphs.\n" 

1972 ) 

1973 if arrowstyle is not None: 

1974 warnings.warn(msg.format("arrowstyle"), category=UserWarning, stacklevel=2) 

1975 if arrowsize != 10: 

1976 warnings.warn(msg.format("arrowsize"), category=UserWarning, stacklevel=2) 

1977 if min_source_margin != 0: 

1978 warnings.warn( 

1979 msg.format("min_source_margin"), category=UserWarning, stacklevel=2 

1980 ) 

1981 if min_target_margin != 0: 

1982 warnings.warn( 

1983 msg.format("min_target_margin"), category=UserWarning, stacklevel=2 

1984 ) 

1985 if any(cs != "arc3" for cs in connectionstyle): 

1986 warnings.warn( 

1987 msg.format("connectionstyle"), category=UserWarning, stacklevel=2 

1988 ) 

1989 

1990 # NOTE: Arrowstyle modification must occur after the warnings section 

1991 if arrowstyle is None: 

1992 arrowstyle = "-|>" if G.is_directed() else "-" 

1993 

1994 if ax is None: 

1995 ax = plt.gca() 

1996 

1997 if edgelist is None: 

1998 edgelist = list(G.edges) # (u, v, k) for multigraph (u, v) otherwise 

1999 

2000 if len(edgelist): 

2001 if G.is_multigraph(): 

2002 key_count = collections.defaultdict(lambda: itertools.count(0)) 

2003 edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist] 

2004 else: 

2005 edge_indices = [0] * len(edgelist) 

2006 else: # no edges! 

2007 return [] 

2008 

2009 if nodelist is None: 

2010 nodelist = list(G.nodes()) 

2011 

2012 # FancyArrowPatch handles color=None different from LineCollection 

2013 if edge_color is None: 

2014 edge_color = "k" 

2015 

2016 # set edge positions 

2017 edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist]) 

2018 

2019 # Check if edge_color is an array of floats and map to edge_cmap. 

2020 # This is the only case handled differently from matplotlib 

2021 if ( 

2022 np.iterable(edge_color) 

2023 and (len(edge_color) == len(edge_pos)) 

2024 and np.all([isinstance(c, Number) for c in edge_color]) 

2025 ): 

2026 if edge_cmap is not None: 

2027 assert isinstance(edge_cmap, mpl.colors.Colormap) 

2028 else: 

2029 edge_cmap = plt.get_cmap() 

2030 if edge_vmin is None: 

2031 edge_vmin = min(edge_color) 

2032 if edge_vmax is None: 

2033 edge_vmax = max(edge_color) 

2034 color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax) 

2035 edge_color = [edge_cmap(color_normal(e)) for e in edge_color] 

2036 

2037 # compute initial view 

2038 minx = np.amin(np.ravel(edge_pos[:, :, 0])) 

2039 maxx = np.amax(np.ravel(edge_pos[:, :, 0])) 

2040 miny = np.amin(np.ravel(edge_pos[:, :, 1])) 

2041 maxy = np.amax(np.ravel(edge_pos[:, :, 1])) 

2042 w = maxx - minx 

2043 h = maxy - miny 

2044 

2045 # Self-loops are scaled by view extent, except in cases the extent 

2046 # is 0, e.g. for a single node. In this case, fall back to scaling 

2047 # by the maximum node size 

2048 selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max() 

2049 fancy_arrow_factory = FancyArrowFactory( 

2050 edge_pos, 

2051 edgelist, 

2052 nodelist, 

2053 edge_indices, 

2054 node_size, 

2055 selfloop_height, 

2056 connectionstyle, 

2057 node_shape, 

2058 arrowstyle, 

2059 arrowsize, 

2060 edge_color, 

2061 alpha, 

2062 width, 

2063 style, 

2064 min_source_margin, 

2065 min_target_margin, 

2066 ax=ax, 

2067 ) 

2068 

2069 # Draw the edges 

2070 if use_linecollection: 

2071 edge_collection = mpl.collections.LineCollection( 

2072 edge_pos, 

2073 colors=edge_color, 

2074 linewidths=width, 

2075 antialiaseds=(1,), 

2076 linestyle=style, 

2077 alpha=alpha, 

2078 ) 

2079 edge_collection.set_cmap(edge_cmap) 

2080 edge_collection.set_clim(edge_vmin, edge_vmax) 

2081 edge_collection.set_zorder(1) # edges go behind nodes 

2082 edge_collection.set_label(label) 

2083 ax.add_collection(edge_collection) 

2084 edge_viz_obj = edge_collection 

2085 

2086 # Make sure selfloop edges are also drawn 

2087 # --------------------------------------- 

2088 selfloops_to_draw = [loop for loop in nx.selfloop_edges(G) if loop in edgelist] 

2089 if selfloops_to_draw: 

2090 edgelist_tuple = list(map(tuple, edgelist)) 

2091 arrow_collection = [] 

2092 for loop in selfloops_to_draw: 

2093 i = edgelist_tuple.index(loop) 

2094 arrow = fancy_arrow_factory(i) 

2095 arrow_collection.append(arrow) 

2096 ax.add_patch(arrow) 

2097 else: 

2098 edge_viz_obj = [] 

2099 for i in range(len(edgelist)): 

2100 arrow = fancy_arrow_factory(i) 

2101 ax.add_patch(arrow) 

2102 edge_viz_obj.append(arrow) 

2103 

2104 # update view after drawing 

2105 padx, pady = 0.05 * w, 0.05 * h 

2106 corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady) 

2107 ax.update_datalim(corners) 

2108 ax.autoscale_view() 

2109 

2110 if hide_ticks: 

2111 ax.tick_params( 

2112 axis="both", 

2113 which="both", 

2114 bottom=False, 

2115 left=False, 

2116 labelbottom=False, 

2117 labelleft=False, 

2118 ) 

2119 

2120 return edge_viz_obj 

2121 

2122 

def draw_networkx_labels(
    G,
    pos,
    labels=None,
    font_size=12,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    clip_on=True,
    hide_ticks=True,
):
    """Draw node labels on the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    labels : dictionary (default={n: n for n in G})
        Node labels in a dictionary of text labels keyed by node.
        Node-keys in labels should appear as keys in `pos`.
        If needed use: `{n:lab for n,lab in labels.items() if n in pos}`
        Labels that are not strings are converted with ``str``.

    font_size : int or dictionary of nodes to ints (default=12)
        Font size for text labels.

    font_color : color or dictionary of nodes to colors (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string or dictionary of nodes to strings (default='normal')
        Font weight.

    font_family : string or dictionary of nodes to strings (default='sans-serif')
        Font family.

    alpha : float or None or dictionary of nodes to floats (default=None)
        The text transparency.

    bbox : Matplotlib bbox, (default is Matplotlib's ax.text default)
        Specify text box properties (e.g. shape, color etc.) for node labels.

    horizontalalignment : string or dictionary of nodes to strings (default='center')
        Horizontal alignment {'center', 'right', 'left'}.

    verticalalignment : string or dictionary of nodes to strings (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    clip_on : bool (default=True)
        Turn on clipping of node labels at axis boundaries

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Any parameter given as a dictionary must have one entry per node in
    `labels`; its value for each node is looked up by that node.

    Returns
    -------
    dict
        `dict` of labels keyed on the nodes

    Raises
    ------
    ValueError
        If a dictionary-valued parameter does not have the same length as
        `labels`.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_edge_labels
    """
    import matplotlib.pyplot as plt

    if ax is None:
        ax = plt.gca()

    if labels is None:
        labels = {n: n for n in G.nodes()}

    # Names of the parameters that were given per node (as dicts).
    individual_params = set()

    def check_individual_params(p_value, p_name):
        # A dict value means "one value per node"; validate and remember it.
        if isinstance(p_value, dict):
            if len(p_value) != len(labels):
                raise ValueError(f"{p_name} must have the same length as labels.")
            individual_params.add(p_name)

    def get_param_value(node, p_value, p_name):
        # Per-node parameters are looked up by node; scalars apply to all.
        if p_name in individual_params:
            return p_value[node]
        return p_value

    check_individual_params(font_size, "font_size")
    check_individual_params(font_color, "font_color")
    check_individual_params(font_weight, "font_weight")
    check_individual_params(font_family, "font_family")
    check_individual_params(alpha, "alpha")
    # Alignments may also be given per node, consistent with the font options.
    check_individual_params(horizontalalignment, "horizontalalignment")
    check_individual_params(verticalalignment, "verticalalignment")

    text_items = {}  # there is no text collection so we'll fake one
    for n, label in labels.items():
        (x, y) = pos[n]
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same
        t = ax.text(
            x,
            y,
            label,
            size=get_param_value(n, font_size, "font_size"),
            color=get_param_value(n, font_color, "font_color"),
            family=get_param_value(n, font_family, "font_family"),
            weight=get_param_value(n, font_weight, "font_weight"),
            alpha=get_param_value(n, alpha, "alpha"),
            horizontalalignment=get_param_value(
                n, horizontalalignment, "horizontalalignment"
            ),
            verticalalignment=get_param_value(
                n, verticalalignment, "verticalalignment"
            ),
            transform=ax.transData,
            bbox=bbox,
            clip_on=clip_on,
        )
        text_items[n] = t

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    return text_items

2274 

2275 

def draw_networkx_edge_labels(
    G,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
    node_size=300,
    nodelist=None,
    connectionstyle="arc3",
    hide_ticks=True,
):
    """Draw edge labels.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    edge_labels : dictionary (default=None)
        Edge labels in a dictionary of labels keyed by edge tuple
        (``(u, v)``, or ``(u, v, k)`` for multigraphs).
        Only labels for the keys in the dictionary are drawn.
        If None, every edge of `G` is labeled with its edge data
        dictionary. Labels that are not strings are converted with ``str``.

    label_pos : float or list of floats (default=0.5)
        Position of edge label along edge (0=head, 0.5=center, 1=tail)

    font_size : int or list of ints (default=10)
        Font size for text labels

    font_color : color or list of colors (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string or list of strings (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    alpha : float or None or list of floats (default=None)
        The text transparency

    bbox : Matplotlib bbox, optional
        Specify text box properties (e.g. shape, color etc.) for edge labels.
        Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.

    horizontalalignment : string or list of strings (default='center')
        Horizontal alignment {'center', 'right', 'left'}

    verticalalignment : string or list of strings (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    rotate : bool or list of bools (default=True)
        Rotate edge labels to lie parallel to edges

    clip_on : bool (default=True)
        Turn on clipping of edge labels at axis boundaries

    node_size : scalar or array (default=300)
        Size of nodes. If an array it must be the same length as nodelist.

    nodelist : list, optional (default=G.nodes())
        This provides the node order for the `node_size` array (if it is an array).

    connectionstyle : string or iterable of strings (default="arc3")
        Pass the connectionstyle parameter to create curved arc of rounding
        radius rad. For example, connectionstyle='arc3,rad=0.2'.
        See `matplotlib.patches.ConnectionStyle` and
        `matplotlib.patches.FancyArrowPatch` for more info.
        If Iterable, index indicates i'th edge key of MultiGraph

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Parameters documented as accepting a list take one value per label, in
    the iteration order of `edge_labels`. Labels of self-loops are never
    rotated and do not use `label_pos`.

    Returns
    -------
    dict
        `dict` of labels keyed by edge

    Raises
    ------
    NetworkXError
        If `connectionstyle` is neither a string nor an iterable.
    ValueError
        If a list-valued parameter does not have one entry per label.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    """
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import numpy as np

    class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text):
        pass

    # use default box of white with white border
    if bbox is None:
        bbox = {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}

    if isinstance(connectionstyle, str):
        connectionstyle = [connectionstyle]
    elif np.iterable(connectionstyle):
        connectionstyle = list(connectionstyle)
    else:
        raise nx.NetworkXError(
            "draw_networkx_edge_labels arg `connectionstyle` must be "
            "string or iterable of strings"
        )

    if ax is None:
        ax = plt.gca()

    if edge_labels is None:
        kwds = {"keys": True} if G.is_multigraph() else {}
        edge_labels = {tuple(edge): d for *edge, d in G.edges(data=True, **kwds)}
    # NOTHING TO PLOT
    if not edge_labels:
        return {}
    edgelist, labels = zip(*edge_labels.items())

    if nodelist is None:
        nodelist = list(G.nodes())

    # set edge positions
    edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])

    # Per (u, v) pair, number parallel edges 0, 1, 2, ... so each one picks
    # its own connectionstyle entry.
    if G.is_multigraph():
        key_count = collections.defaultdict(lambda: itertools.count(0))
        edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist]
    else:
        edge_indices = [0] * len(edgelist)

    # Used to determine self loop mid-point.
    # Note, that this will not be accurate,
    # if not drawing edge_labels for all edges drawn.
    # edge_labels is non-empty here, so edge_pos has at least one row.
    miny = np.amin(np.ravel(edge_pos[:, :, 1]))
    maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
    h = maxy - miny
    selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max()
    fancy_arrow_factory = FancyArrowFactory(
        edge_pos,
        edgelist,
        nodelist,
        edge_indices,
        node_size,
        selfloop_height,
        connectionstyle,
        ax=ax,
    )

    # TODO should this be list or array (as in a numpy array)?
    def check_individual_params(p_value, p_name):
        # A list value means "one value per label"; validate its length.
        if isinstance(p_value, list) and len(p_value) != len(edgelist):
            raise ValueError(f"{p_name} must have the same length as edgelist.")

    def get_param_value(p_value, i):
        # Index per-label lists by edge position instead of advancing a shared
        # iterator: the self-loop branch below ignores some parameters
        # (label_pos, rotate), which would otherwise shift later edges' values.
        if isinstance(p_value, list):
            return p_value[i]
        return p_value

    check_individual_params(font_size, "font_size")
    check_individual_params(font_color, "font_color")
    check_individual_params(font_weight, "font_weight")
    check_individual_params(alpha, "alpha")
    check_individual_params(horizontalalignment, "horizontalalignment")
    check_individual_params(verticalalignment, "verticalalignment")
    check_individual_params(rotate, "rotate")
    check_individual_params(label_pos, "label_pos")

    text_items = {}
    for i, (edge, label) in enumerate(zip(edgelist, labels)):
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same

        n1, n2 = edge[:2]
        arrow = fancy_arrow_factory(i)
        if n1 == n2:
            # Self-loop: place the label at the first vertex of the loop path,
            # computed in display coordinates and mapped back to data space.
            connectionstyle_obj = arrow.get_connectionstyle()
            posA = ax.transData.transform(pos[n1])
            path_disp = connectionstyle_obj(posA, posA)
            path_data = ax.transData.inverted().transform_path(path_disp)
            x, y = path_data.vertices[0]
            text_items[edge] = ax.text(
                x,
                y,
                label,
                size=get_param_value(font_size, i),
                color=get_param_value(font_color, i),
                family=font_family,
                weight=get_param_value(font_weight, i),
                alpha=get_param_value(alpha, i),
                horizontalalignment=get_param_value(horizontalalignment, i),
                verticalalignment=get_param_value(verticalalignment, i),
                rotation=0,
                transform=ax.transData,
                bbox=bbox,
                zorder=1,
                clip_on=clip_on,
            )
        else:
            text_items[edge] = CurvedArrowText(
                arrow,
                label,
                size=get_param_value(font_size, i),
                color=get_param_value(font_color, i),
                family=font_family,
                weight=get_param_value(font_weight, i),
                alpha=get_param_value(alpha, i),
                horizontalalignment=get_param_value(horizontalalignment, i),
                verticalalignment=get_param_value(verticalalignment, i),
                transform=ax.transData,
                bbox=bbox,
                zorder=1,
                clip_on=clip_on,
                label_pos=get_param_value(label_pos, i),
                labels_horizontal=not get_param_value(rotate, i),
                ax=ax,
            )

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    return text_items

2546 

2547 

def draw_bipartite(G, **kwargs):
    """Draw the graph `G` using a bipartite layout.

    Shorthand for::

        nx.draw(G, pos=nx.bipartite_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Raises
    ------
    NetworkXError :
        If `G` is not bipartite.

    Notes
    -----
    Node positions are recomputed on every call. When the same graph (or
    parts of it) is drawn repeatedly, compute the positions once with
    `~networkx.drawing.layout.bipartite_layout` and pass them in::

        >>> G = nx.complete_bipartite_graph(3, 3)
        >>> pos = nx.bipartite_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.complete_bipartite_graph(2, 5)
    >>> nx.draw_bipartite(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.bipartite_layout`
    """
    layout = nx.bipartite_layout(G)
    draw(G, pos=layout, **kwargs)

2590 

2591 

def draw_circular(G, **kwargs):
    """Draw the graph `G` using a circular layout.

    Shorthand for::

        nx.draw(G, pos=nx.circular_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Notes
    -----
    Node positions are recomputed on every call. When the same graph (or
    parts of it) is drawn repeatedly, compute the positions once with
    `~networkx.drawing.layout.circular_layout` and pass them in::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.circular_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_circular(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.circular_layout`
    """
    layout = nx.circular_layout(G)
    draw(G, pos=layout, **kwargs)

2629 

2630 

def draw_kamada_kawai(G, **kwargs):
    """Draw the graph `G` using the Kamada-Kawai force-directed layout.

    Shorthand for::

        nx.draw(G, pos=nx.kamada_kawai_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Notes
    -----
    Node positions are recomputed on every call. When the same graph (or
    parts of it) is drawn repeatedly, compute the positions once with
    `~networkx.drawing.layout.kamada_kawai_layout` and pass them in::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.kamada_kawai_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_kamada_kawai(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.kamada_kawai_layout`
    """
    layout = nx.kamada_kawai_layout(G)
    draw(G, pos=layout, **kwargs)

2669 

2670 

def draw_random(G, **kwargs):
    """Draw the graph `G` using a random layout.

    Shorthand for::

        nx.draw(G, pos=nx.random_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Notes
    -----
    Node positions are recomputed on every call. When the same graph (or
    parts of it) is drawn repeatedly, compute the positions once with
    `~networkx.drawing.layout.random_layout` and pass them in::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.random_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.lollipop_graph(4, 3)
    >>> nx.draw_random(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.random_layout`
    """
    layout = nx.random_layout(G)
    draw(G, pos=layout, **kwargs)

2708 

2709 

def draw_spectral(G, **kwargs):
    """Draw the graph `G` using a 2D spectral layout.

    Shorthand for::

        nx.draw(G, pos=nx.spectral_layout(G), **kwargs)

    See `~networkx.drawing.layout.spectral_layout` for how the node
    positions are computed.

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Notes
    -----
    Node positions are recomputed on every call. When the same graph (or
    parts of it) is drawn repeatedly, compute the positions once with
    `~networkx.drawing.layout.spectral_layout` and pass them in::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.spectral_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_spectral(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.spectral_layout`
    """
    layout = nx.spectral_layout(G)
    draw(G, pos=layout, **kwargs)

2750 

2751 

def draw_spring(G, **kwargs):
    """Draw the graph `G` using a spring layout.

    Shorthand for::

        nx.draw(G, pos=nx.spring_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Notes
    -----
    `~networkx.drawing.layout.spring_layout` is also what `draw` uses by
    default, so this function behaves the same as `draw`.

    Node positions are recomputed on every call. When the same graph (or
    parts of it) is drawn repeatedly, compute the positions once with
    `~networkx.drawing.layout.spring_layout` and pass them in::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.spring_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(20)
    >>> nx.draw_spring(G)

    See Also
    --------
    draw
    :func:`~networkx.drawing.layout.spring_layout`
    """
    layout = nx.spring_layout(G)
    draw(G, pos=layout, **kwargs)

2793 

2794 

def draw_shell(G, nlist=None, **kwargs):
    """Draw the networkx graph `G` using a shell layout.

    Shorthand for::

        nx.draw(G, pos=nx.shell_layout(G, nlist=nlist), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    nlist : list of list of nodes, optional
        The shells, each given as a list of nodes.
        With the default `None`, all nodes share a single shell.
        See `~networkx.drawing.layout.shell_layout` for details.

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Notes
    -----
    Node positions are recomputed on every call. When the same graph (or
    parts of it) is drawn repeatedly, compute the positions once with
    `~networkx.drawing.layout.shell_layout` and pass them in::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.shell_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> shells = [[0], [1, 2, 3]]
    >>> nx.draw_shell(G, nlist=shells)

    See Also
    --------
    :func:`~networkx.drawing.layout.shell_layout`
    """
    layout = nx.shell_layout(G, nlist=nlist)
    draw(G, pos=layout, **kwargs)

2838 

2839 

def draw_planar(G, **kwargs):
    """Draw the planar graph `G` using node positions from a planar layout.

    Shorthand for::

        nx.draw(G, pos=nx.planar_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A planar networkx graph

    kwargs : optional keywords
        Forwarded unchanged to `draw`; see `draw_networkx` for the
        available keywords. Do not pass `pos` -- the positions are computed
        here, so a second `pos` keyword makes the call fail with TypeError.

    Raises
    ------
    NetworkXException
        If `G` is not planar (raised by the layout computation).

    Notes
    -----
    A new layout is computed on every call. To draw several times with the
    same positions, compute `~networkx.drawing.layout.planar_layout` once
    and pass the result along::

        >>> G = nx.path_graph(5)
        >>> pos = nx.planar_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.draw_planar(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.planar_layout`
    """
    planar_pos = nx.planar_layout(G)
    draw(G, pos=planar_pos, **kwargs)

2882 

2883 

def draw_forceatlas2(G, **kwargs):
    """Draw the graph `G` using node positions from a ForceAtlas2 layout.

    Shorthand for::

        nx.draw(G, pos=nx.forceatlas2_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded unchanged to `draw`; see `draw_networkx` for the
        available keywords. Do not pass `pos` -- the positions are computed
        here, so a second `pos` keyword makes the call fail with TypeError.

    Notes
    -----
    A new layout is computed on every call. To draw several times with the
    same positions, compute `~networkx.drawing.layout.forceatlas2_layout`
    once and pass the result to `draw` via `pos`.

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_forceatlas2(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.forceatlas2_layout`
    """
    fa2_pos = nx.forceatlas2_layout(G)
    draw(G, pos=fa2_pos, **kwargs)

2902 

2903 

def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None):
    """Apply an alpha (or list of alphas) to the colors provided.

    Parameters
    ----------

    colors : color string or array of floats
        Color of element. Can be a single color format string,
        or a sequence of colors with the same length as `elem_list`.
        If numeric values are specified (one per element) they will be
        mapped to colors using the `cmap` and `vmin`, `vmax` parameters.
        See matplotlib.scatter for more details.

    alpha : float or array of floats
        Alpha values for elements. This can be a single alpha value, in
        which case it will be applied to all the elements of color. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    elem_list : array of networkx objects
        The list of elements which are being colored. These could be nodes,
        edges or labels.

    cmap : matplotlib colormap
        Color map for use if colors is a list of floats corresponding to points
        on a color mapping.

    vmin, vmax : float
        Minimum and maximum values for normalizing colors if a colormap is used

    Returns
    -------

    rgba_colors : numpy ndarray
        Array of RGBA rows, one column per channel. When `alpha` is a
        sequence longer than the number of colors, or when the flattened
        color array happens to contain exactly ``len(elem_list)`` values,
        the array is resized to one row per element of `elem_list`. In that
        case the supplied colors are repeated in order, so a single color is
        used for every element.
    """
    from itertools import cycle, islice

    import matplotlib as mpl
    import matplotlib.cm  # call as mpl.cm
    import matplotlib.colors  # call as mpl.colors
    import numpy as np

    # One number per element: map the values through the colormap.
    if len(colors) == len(elem_list) and isinstance(colors[0], Number):
        mapper = mpl.cm.ScalarMappable(cmap=cmap)
        mapper.set_clim(vmin, vmax)
        rgba_colors = mapper.to_rgba(colors)
    else:
        # Otherwise treat `colors` as one color first, then as a sequence of
        # colors. Wrap in ndarrays to match ScalarMappable.to_rgba's output.
        try:
            rgba_colors = np.array([mpl.colors.to_rgba(colors)])
        except ValueError:
            rgba_colors = np.array([mpl.colors.to_rgba(color) for color in colors])

    # A scalar alpha (no len()) is applied to every color as-is.
    try:
        n_alpha = len(alpha)
    except TypeError:
        rgba_colors[:, -1] = alpha
        return rgba_colors

    # Resize to one row per element when alpha is longer than the colors, or
    # when rgba_colors.size equals the element count (scatter() would
    # otherwise read the flattened array as colormap values). np.resize
    # repeats rows cyclically, so the caller's colors are kept (cycled in
    # order) rather than all being replaced by the first color.
    if n_alpha > len(rgba_colors) or rgba_colors.size == len(elem_list):
        rgba_colors = np.resize(rgba_colors, (len(elem_list), 4))
    rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors)))
    return rgba_colors