1"""
2**********
3Matplotlib
4**********
5
6Draw networks with matplotlib.
7
8Examples
9--------
10>>> G = nx.complete_graph(5)
11>>> nx.draw(G)
12
13See Also
14--------
15 - :doc:`matplotlib <matplotlib:index>`
16 - :func:`matplotlib.pyplot.scatter`
17 - :obj:`matplotlib.patches.FancyArrowPatch`
18"""
19
20import collections
21import itertools
22import math
23from numbers import Number
24
25import networkx as nx
26
# Names exported by ``from <this module> import *``. Helpers such as
# ``CurvedArrowTextBase`` are deliberately left out of this list.
__all__ = [
    "display",
    "apply_matplotlib_colors",
    "draw",
    "draw_networkx",
    "draw_networkx_nodes",
    "draw_networkx_edges",
    "draw_networkx_labels",
    "draw_networkx_edge_labels",
    "draw_bipartite",
    "draw_circular",
    "draw_kamada_kawai",
    "draw_random",
    "draw_spectral",
    "draw_spring",
    "draw_planar",
    "draw_shell",
    "draw_forceatlas2",
]
46
47
def apply_matplotlib_colors(
    G, src_attr, dest_attr, map, vmin=None, vmax=None, nodes=True
):
    """
    Apply colors from a matplotlib colormap to a graph.

    Reads values from the `src_attr` attribute of every node (or edge), maps
    them through a matplotlib colormap to an RGBA tuple of floats, and writes
    that tuple to the `dest_attr` attribute of the same node (or edge).

    Parameters
    ----------
    G : nx.Graph
        The graph to read and compute colors for.

    src_attr : str or other attribute name
        The name of the attribute to read from the graph.

    dest_attr : str or other attribute name
        The name of the attribute to write to on the graph.

    map : matplotlib.colormap
        The matplotlib colormap to use.

    vmin : float, default None
        The minimum value for scaling the colormap. If `None`, find the
        minimum value of `src_attr`.

    vmax : float, default None
        The maximum value for scaling the colormap. If `None`, find the
        maximum value of `src_attr`.

    nodes : bool, default True
        If True, `src_attr` and `dest_attr` are node attributes. If False,
        they are edge attributes (edges of multigraphs are addressed by
        ``(u, v, key)``).

    Raises
    ------
    KeyError
        If a node (or edge) does not have the `src_attr` attribute.

    Notes
    -----
    If the graph has no nodes (or no edges, when ``nodes=False``), the graph
    is left unchanged.
    """
    import matplotlib as mpl

    # The view used both to iterate items and to look up their data.
    if nodes:
        items = G.nodes()
    elif G.is_multigraph():
        items = G.edges(keys=True)
    else:
        items = G.edges()

    # Nothing to color; this also avoids min()/max() on an empty sequence.
    if not items:
        return

    if vmin is None or vmax is None:
        vals = [items[item][src_attr] for item in items]
        if vmin is None:
            vmin = min(vals)
        if vmax is None:
            vmax = max(vals)

    mapper = mpl.cm.ScalarMappable(cmap=map)
    mapper.set_clim(vmin, vmax)

    def do_map(value):
        # Cast numpy scalars to plain Python floats
        return tuple(float(channel) for channel in mapper.to_rgba(value))

    colors = {item: do_map(items[item][src_attr]) for item in items}
    if nodes:
        nx.set_node_attributes(G, colors, dest_attr)
    else:
        nx.set_edge_attributes(G, colors, dest_attr)
113
114
class CurvedArrowTextBase:
    """Mixin that keeps a text label attached to a `FancyArrowPatch`.

    Intended to be combined with `matplotlib.text.Text`, e.g.
    ``class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text)``, so that the
    ``super()`` calls below resolve to the ``Text`` implementation. The label
    position and rotation are recomputed from the arrow's path on every
    ``draw`` so the label follows the arrow.

    Parameters
    ----------
    arrow : matplotlib.patches.FancyArrowPatch
        The arrow to label. It must have been created from ``posA``/``posB``,
        not from a custom path.
    *args, **kwargs
        Forwarded to the text constructor after the computed x and y.
    label_pos : float, default 0.5
        How far along the arrow the label is placed; 0 is the start and 1 the
        end.
    labels_horizontal : bool, default False
        If True the label is not rotated; otherwise it is drawn parallel to
        the arrow at ``label_pos``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. Defaults to the current axes.
    """

    def __init__(
        self,
        arrow,
        *args,
        label_pos=0.5,
        labels_horizontal=False,
        ax=None,
        **kwargs,
    ):
        # Bind to FancyArrowPatch
        self.arrow = arrow
        # how far along the text should be on the curve,
        # 0 is at start, 1 is at end etc.
        self.label_pos = label_pos
        self.labels_horizontal = labels_horizontal
        if ax is None:
            # `plt` is not imported at module level; import it here so the
            # default-axes path does not raise NameError.
            import matplotlib.pyplot as plt

            ax = plt.gca()
        self.ax = ax
        self.x, self.y, self.angle = self._update_text_pos_angle(arrow)

        # Create text object
        super().__init__(self.x, self.y, *args, rotation=self.angle, **kwargs)
        # Bind to axis
        self.ax.add_artist(self)

    def _get_arrow_path_disp(self, arrow):
        """
        Return the arrow's connection path in display coordinates.

        This is the first part of FancyArrowPatch._get_path_in_displaycoord.
        It omits the second part of that method, where the path is converted
        to a polygon based on width. The transform is taken from the axes,
        not the arrow, because the arrow may not have been added yet and may
        not have a transform.

        Raises
        ------
        ValueError
            If the arrow was built from a custom path instead of posA/posB.
        """
        dpi_cor = arrow._dpi_cor
        trans_data = self.ax.transData
        if arrow._posA_posB is None:
            raise ValueError(
                "Can only draw labels for fancy arrows with "
                "posA and posB inputs, not custom path"
            )
        posA = arrow._convert_xy_units(arrow._posA_posB[0])
        posB = arrow._convert_xy_units(arrow._posA_posB[1])
        (posA, posB) = trans_data.transform((posA, posB))
        _path = arrow.get_connectionstyle()(
            posA,
            posB,
            patchA=arrow.patchA,
            patchB=arrow.patchB,
            shrinkA=arrow.shrinkA * dpi_cor,
            shrinkB=arrow.shrinkB * dpi_cor,
        )
        # Return is in display coordinates
        return _path

    def _update_text_pos_angle(self, arrow):
        """Return ``(x, y, angle)`` for the label.

        ``x`` and ``y`` are in data coordinates and ``angle`` is in degrees.
        The point lies at fraction ``self.label_pos`` along the arrow's path
        in display coordinates (default 0.5, the halfway point). For
        ``Arc3``/``Angle3`` connection styles the path is treated as a
        quadratic Bezier curve. For ``Angle``/``Arc``/``Bar`` it is treated
        as a polyline, and the fraction is measured by segment length.

        Raises
        ------
        TypeError
            If the arrow uses any other connection style.
        """
        import matplotlib as mpl
        import numpy as np

        t = self.label_pos
        tt = 1 - t
        path_disp = self._get_arrow_path_disp(arrow)
        conn = arrow.get_connectionstyle()
        # 1. Calculate x and y
        points = path_disp.vertices
        if is_curve := isinstance(
            conn,
            mpl.patches.ConnectionStyle.Angle3 | mpl.patches.ConnectionStyle.Arc3,
        ):
            # Arc3 or Angle3: quadratic Bezier (start, control point, end)
            (x1, y1), (cx, cy), (x2, y2) = points
            x = tt**2 * x1 + 2 * t * tt * cx + t**2 * x2
            y = tt**2 * y1 + 2 * t * tt * cy + t**2 * y2
        else:
            if not isinstance(
                conn,
                mpl.patches.ConnectionStyle.Angle
                | mpl.patches.ConnectionStyle.Arc
                | mpl.patches.ConnectionStyle.Bar,
            ):
                msg = f"invalid connection style: {type(conn)}"
                raise TypeError(msg)
            # A. Collect the straight segments of the polyline
            codes = path_disp.codes
            lines = [
                points[i - 1 : i + 1]
                for i in range(1, len(points))
                if codes[i] == mpl.path.Path.LINETO
            ]
            # B. If more than one segment, find the one containing fraction t
            #    of the total length and rescale t to that segment
            if (nlines := len(lines)) != 1:
                dists = [math.dist(*line) for line in lines]
                dist_tot = sum(dists)
                if dist_tot == 0:
                    # Degenerate path (all vertices coincide): every segment
                    # yields the same point, and the division below would fail.
                    lines = lines[:1]
                else:
                    cdist = 0
                    last_cut = 0
                    i_last = nlines - 1
                    for i, dist in enumerate(dists):
                        cdist += dist
                        cut = cdist / dist_tot
                        if i == i_last or t < cut:
                            # A zero-length segment has no extent to
                            # interpolate along; use its (single) point.
                            t = (t - last_cut) / (dist / dist_tot) if dist else 0
                            tt = 1 - t
                            lines = [lines[i]]
                            break
                        last_cut = cut
            [[(cx1, cy1), (cx2, cy2)]] = lines
            x = cx1 * tt + cx2 * t
            y = cy1 * tt + cy2 * t

        # 2. Calculate Angle
        if self.labels_horizontal:
            # Horizontal text labels
            angle = 0
        else:
            # Labels parallel to the path
            if is_curve:
                # Derivative of the Bezier curve at t
                change_x = 2 * tt * (cx - x1) + 2 * t * (x2 - cx)
                change_y = 2 * tt * (cy - y1) + 2 * t * (y2 - cy)
            else:
                change_x = cx2 - cx1
                change_y = cy2 - cy1
            angle = np.degrees(np.arctan2(change_y, change_x))
            # Keep the text "right way up"
            if angle > 90:
                angle -= 180
            elif angle < -90:
                angle += 180
        (x, y) = self.ax.transData.inverted().transform((x, y))
        return x, y, angle

    def draw(self, renderer):
        # recalculate the text position and angle
        self.x, self.y, self.angle = self._update_text_pos_angle(self.arrow)
        self.set_position((self.x, self.y))
        self.set_rotation(self.angle)
        # redraw text
        super().draw(renderer)
255
256
257def display(
258 G,
259 canvas=None,
260 **kwargs,
261):
262 """Draw the graph G.
263
264 Draw the graph as a collection of nodes connected by edges.
265 The exact details of what the graph looks like are controlled by the below
266 attributes. All nodes and nodes at the end of visible edges must have a
267 position set, but nearly all other node and edge attributes are options and
268 nodes or edges missing the attribute will use the default listed below. A more
269 complete description of each parameter is given below this summary.
270
271 .. list-table:: Default Visualization Attributes
272 :widths: 25 25 50
273 :header-rows: 1
274
275 * - Parameter
276 - Default Attribute
277 - Default Value
278 * - node_pos
279 - `"pos"`
280 - If there is not position, a layout will be calculated with `nx.spring_layout`.
281 * - node_visible
282 - `"visible"`
283 - True
284 * - node_color
285 - `"color"`
286 - #1f78b4
287 * - node_size
288 - `"size"`
289 - 300
290 * - node_label
291 - `"label"`
292 - Dict describing the node label. Defaults create a black text with
293 the node name as the label. The dict respects these keys and defaults:
294
295 * size : 12
296 * color : black
297 * family : sans serif
298 * weight : normal
299 * alpha : 1.0
300 * h_align : center
301 * v_align : center
302 * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`.
303 Default is None.
304
305 * - node_shape
306 - `"shape"`
307 - "o"
308 * - node_alpha
309 - `"alpha"`
310 - 1.0
311 * - node_border_width
312 - `"border_width"`
313 - 1.0
314 * - node_border_color
315 - `"border_color"`
316 - Matching node_color
317 * - edge_visible
318 - `"visible"`
319 - True
320 * - edge_width
321 - `"width"`
322 - 1.0
323 * - edge_color
324 - `"color"`
325 - Black (#000000)
326 * - edge_label
327 - `"label"`
328 - Dict describing the edge label. Defaults create black text with a
329 white bounding box. The dictionary respects these keys and defaults:
330
331 * size : 12
332 * color : black
333 * family : sans serif
334 * weight : normal
335 * alpha : 1.0
336 * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`.
337 Default {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}
338 * h_align : "center"
339 * v_align : "center"
340 * pos : 0.5
341 * rotate : True
342
343 * - edge_style
344 - `"style"`
345 - "-"
346 * - edge_alpha
347 - `"alpha"`
348 - 1.0
349 * - edge_arrowstyle
350 - `"arrowstyle"`
351 - ``"-|>"`` if `G` is directed else ``"-"``
352 * - edge_arrowsize
353 - `"arrowsize"`
354 - 10 if `G` is directed else 0
355 * - edge_curvature
356 - `"curvature"`
357 - arc3
358 * - edge_source_margin
359 - `"source_margin"`
360 - 0
361 * - edge_target_margin
362 - `"target_margin"`
363 - 0
364
365 Parameters
366 ----------
367 G : graph
368 A networkx graph
369
370 canvas : Matplotlib Axes object, optional
371 Draw the graph in specified Matplotlib axes
372
373 node_pos : string or function, default "pos"
374 A string naming the node attribute storing the position of nodes as a tuple.
375 Or a function to be called with input `G` which returns the layout as a dict keyed
376 by node to position tuple like the NetworkX layout functions.
377 If no nodes in the graph has the attribute, a spring layout is calculated.
378
379 node_visible : string or bool, default visible
380 A string naming the node attribute which stores if a node should be drawn.
381 If `True`, all nodes will be visible while if `False` no nodes will be visible.
382 If incomplete, nodes missing this attribute will be shown by default.
383
384 node_color : string, default "color"
385 A string naming the node attribute which stores the color of each node.
386 Visible nodes without this attribute will use '#1f78b4' as a default.
387
388 node_size : string or number, default "size"
389 A string naming the node attribute which stores the size of each node.
390 Visible nodes without this attribute will use a default size of 300.
391
392 node_label : string or bool, default "label"
393 A string naming the node attribute which stores the label of each node.
394 The attribute value can be a string, False (no label for that node),
395 True (the node is the label) or a dict keyed by node to the label.
396
397 If a dict is specified, these keys are read to further control the label:
398
399 * label : The text of the label; default: name of the node
400 * size : Font size of the label; default: 12
401 * color : Font color of the label; default: black
402 * family : Font family of the label; default: "sans-serif"
403 * weight : Font weight of the label; default: "normal"
404 * alpha : Alpha value of the label; default: 1.0
405 * h_align : The horizontal alignment of the label.
406 one of "left", "center", "right"; default: "center"
407 * v_align : The vertical alignment of the label.
408 one of "top", "center", "bottom"; default: "center"
409 * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`.
410
411 Visible nodes without this attribute will be treated as if the value was True.
412
413 node_shape : string, default "shape"
414 A string naming the node attribute which stores the label of each node.
415 The values of this attribute are expected to be one of the matplotlib shapes,
416 one of 'so^>v<dph8'. Visible nodes without this attribute will use 'o'.
417
418 node_alpha : string, default "alpha"
419 A string naming the node attribute which stores the alpha of each node.
420 The values of this attribute are expected to be floats between 0.0 and 1.0.
421 Visible nodes without this attribute will be treated as if the value was 1.0.
422
423 node_border_width : string, default "border_width"
424 A string naming the node attribute storing the width of the border of the node.
425 The values of this attribute are expected to be numeric. Visible nodes without
426 this attribute will use the assumed default of 1.0.
427
428 node_border_color : string, default "border_color"
429 A string naming the node attribute which storing the color of the border of the node.
430 Visible nodes missing this attribute will use the final node_color value.
431
432 edge_visible : string or bool, default "visible"
433 A string nameing the edge attribute which stores if an edge should be drawn.
434 If `True`, all edges will be drawn while if `False` no edges will be visible.
435 If incomplete, edges missing this attribute will be shown by default. Values
436 of this attribute are expected to be booleans.
437
438 edge_width : string or int, default "width"
439 A string nameing the edge attribute which stores the width of each edge.
440 Visible edges without this attribute will use a default width of 1.0.
441
442 edge_color : string or color, default "color"
443 A string nameing the edge attribute which stores of color of each edge.
444 Visible edges without this attribute will be drawn black. Each color can be
445 a string or rgb (or rgba) tuple of floats from 0.0 to 1.0.
446
447 edge_label : string, default "label"
448 A string naming the edge attribute which stores the label of each edge.
449 The values of this attribute can be a string, number or False or None. In
450 the latter two cases, no edge label is displayed.
451
452 If a dict is specified, these keys are read to further control the label:
453
454 * label : The text of the label, or the name of an edge attribute holding the label.
455 * size : Font size of the label; default: 12
456 * color : Font color of the label; default: black
457 * family : Font family of the label; default: "sans-serif"
458 * weight : Font weight of the label; default: "normal"
459 * alpha : Alpha value of the label; default: 1.0
460 * h_align : The horizontal alignment of the label.
461 one of "left", "center", "right"; default: "center"
462 * v_align : The vertical alignment of the label.
463 one of "top", "center", "bottom"; default: "center"
464 * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`.
465 * rotate : Whether to rotate labels to lie parallel to the edge, default: True.
466 * pos : A float showing how far along the edge to put the label; default: 0.5.
467
468 edge_style : string, default "style"
469 A string naming the edge attribute which stores the style of each edge.
470 Visible edges without this attribute will be drawn solid. Values of this
471 attribute can be line styles, e.g. '-', '--', '-.' or ':' or words like 'solid'
472 or 'dashed'. If no edge in the graph has this attribute and it is a non-default
473 value, assume that it describes the edge style for all edges in the graph.
474
475 edge_alpha : string or float, default "alpha"
476 A string naming the edge attribute which stores the alpha value of each edge.
477 Visible edges without this attribute will use an alpha value of 1.0.
478
479 edge_arrowstyle : string, default "arrowstyle"
480 A string naming the edge attribute which stores the type of arrowhead to use for
481 each edge. Visible edges without this attribute use ``"-"`` for undirected graphs
482 and ``"-|>"`` for directed graphs.
483
484 See `matplotlib.patches.ArrowStyle` for more options
485
486 edge_arrowsize : string or int, default "arrowsize"
487 A string naming the edge attribute which stores the size of the arrowhead for each
488 edge. Visible edges without this attribute will use a default value of 10.
489
490 edge_curvature : string, default "curvature"
491 A string naming the edge attribute storing the curvature and connection style
492 of each edge. Visible edges without this attribute will use "arc3" as a default
493 value, resulting an a straight line between the two nodes. Curvature can be given
494 as 'arc3,rad=0.2' to specify both the style and radius of curvature.
495
496 Please see `matplotlib.patches.ConnectionStyle` and
497 `matplotlib.patches.FancyArrowPatch` for more information.
498
499 edge_source_margin : string or int, default "source_margin"
500 A string naming the edge attribute which stores the minimum margin (gap) between
501 the source node and the start of the edge. Visible edges without this attribute
502 will use a default value of 0.
503
504 edge_target_margin : string or int, default "target_margin"
505 A string naming the edge attribute which stores the minimumm margin (gap) between
506 the target node and the end of the edge. Visible edges without this attribute
507 will use a default value of 0.
508
509 hide_ticks : bool, default True
510 Weather to remove the ticks from the axes of the matplotlib object.
511
512 Raises
513 ------
514 NetworkXError
515 If a node or edge is missing a required parameter such as `pos` or
516 if `display` receives an argument not listed above.
517
518 ValueError
519 If a node or edge has an invalid color format, i.e. not a color string,
520 rgb tuple or rgba tuple.
521
522 Returns
523 -------
524 The input graph. This is potentially useful for dispatching visualization
525 functions.
526 """
527 from collections import Counter
528
529 import matplotlib as mpl
530 import matplotlib.pyplot as plt
531 import numpy as np
532
533 defaults = {
534 "node_pos": None,
535 "node_visible": True,
536 "node_color": "#1f78b4",
537 "node_size": 300,
538 "node_label": {
539 "size": 12,
540 "color": "#000000",
541 "family": "sans-serif",
542 "weight": "normal",
543 "alpha": 1.0,
544 "h_align": "center",
545 "v_align": "center",
546 "bbox": None,
547 },
548 "node_shape": "o",
549 "node_alpha": 1.0,
550 "node_border_width": 1.0,
551 "node_border_color": "face",
552 "edge_visible": True,
553 "edge_width": 1.0,
554 "edge_color": "#000000",
555 "edge_label": {
556 "size": 12,
557 "color": "#000000",
558 "family": "sans-serif",
559 "weight": "normal",
560 "alpha": 1.0,
561 "bbox": {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)},
562 "h_align": "center",
563 "v_align": "center",
564 "pos": 0.5,
565 "rotate": True,
566 },
567 "edge_style": "-",
568 "edge_alpha": 1.0,
569 "edge_arrowstyle": "-|>" if G.is_directed() else "-",
570 "edge_arrowsize": 10 if G.is_directed() else 0,
571 "edge_curvature": "arc3",
572 "edge_source_margin": 0,
573 "edge_target_margin": 0,
574 "hide_ticks": True,
575 }
576
577 # Check arguments
578 for kwarg in kwargs:
579 if kwarg not in defaults:
580 raise nx.NetworkXError(
581 f"Unrecognized visualization keyword argument: {kwarg}"
582 )
583
584 if canvas is None:
585 canvas = plt.gca()
586
587 if kwargs.get("hide_ticks", defaults["hide_ticks"]):
588 canvas.tick_params(
589 axis="both",
590 which="both",
591 bottom=False,
592 left=False,
593 labelbottom=False,
594 labelleft=False,
595 )
596
597 ### Helper methods and classes
598
599 def node_property_sequence(seq, attr):
600 """Return a list of attribute values for `seq`, using a default if needed"""
601
602 # All node attribute parameters start with "node_"
603 param_name = f"node_{attr}"
604 default = defaults[param_name]
605 attr = kwargs.get(param_name, attr)
606
607 if default is None:
608 # raise instead of using non-existant default value
609 for n in seq:
610 if attr not in node_subgraph.nodes[n]:
611 raise nx.NetworkXError(f"Attribute '{attr}' missing for node {n}")
612
613 # If `attr` is not a graph attr and was explicitly passed as an argument
614 # it must be a user-default value. Allow attr=None to tell draw to skip
615 # attributes which are on the graph
616 if (
617 attr is not None
618 and nx.get_node_attributes(node_subgraph, attr) == {}
619 and any(attr == v for k, v in kwargs.items() if "node" in k)
620 ):
621 return [attr for _ in seq]
622
623 return [node_subgraph.nodes[n].get(attr, default) for n in seq]
624
625 def compute_colors(color, alpha):
626 if isinstance(color, str):
627 rgba = mpl.colors.colorConverter.to_rgba(color)
628 # Using a non-default alpha value overrides any alpha value in the color
629 if alpha != defaults["node_alpha"]:
630 return (rgba[0], rgba[1], rgba[2], alpha)
631 return rgba
632
633 if isinstance(color, tuple) and len(color) == 3:
634 return (color[0], color[1], color[2], alpha)
635
636 if isinstance(color, tuple) and len(color) == 4:
637 return color
638
639 raise ValueError(f"Invalid format for color: {color}")
640
641 # Find which edges can be plotted as a line collection
642 #
643 # Non-default values for these attributes require fancy arrow patches:
644 # - any arrow style (including the default -|> for directed graphs)
645 # - arrow size (by extension of style)
646 # - connection style
647 # - min_source_margin
648 # - min_target_margin
649
650 def collection_compatible(e):
651 return (
652 get_edge_attr(e, "arrowstyle") == "-"
653 and get_edge_attr(e, "curvature") == "arc3"
654 and get_edge_attr(e, "source_margin") == 0
655 and get_edge_attr(e, "target_margin") == 0
656 # Self-loops will use fancy arrow patches
657 and e[0] != e[1]
658 )
659
660 def edge_property_sequence(seq, attr):
661 """Return a list of attribute values for `seq`, using a default if needed"""
662
663 param_name = f"edge_{attr}"
664 default = defaults[param_name]
665 attr = kwargs.get(param_name, attr)
666
667 if default is None:
668 # raise instead of using non-existant default value
669 for e in seq:
670 if attr not in edge_subgraph.edges[e]:
671 raise nx.NetworkXError(f"Attribute '{attr}' missing for edge {e}")
672
673 if (
674 attr is not None
675 and nx.get_edge_attributes(edge_subgraph, attr) == {}
676 and any(attr == v for k, v in kwargs.items() if "edge" in k)
677 ):
678 return [attr for _ in seq]
679
680 return [edge_subgraph.edges[e].get(attr, default) for e in seq]
681
682 def get_edge_attr(e, attr):
683 """Return the final edge attribute value, using default if not None"""
684
685 param_name = f"edge_{attr}"
686 default = defaults[param_name]
687 attr = kwargs.get(param_name, attr)
688
689 if default is None and attr not in edge_subgraph.edges[e]:
690 raise nx.NetworkXError(f"Attribute '{attr}' missing from edge {e}")
691
692 if (
693 attr is not None
694 and nx.get_edge_attributes(edge_subgraph, attr) == {}
695 and attr in kwargs.values()
696 ):
697 return attr
698
699 return edge_subgraph.edges[e].get(attr, default)
700
701 def get_node_attr(n, attr, use_edge_subgraph=True):
702 """Return the final node attribute value, using default if not None"""
703 subgraph = edge_subgraph if use_edge_subgraph else node_subgraph
704
705 param_name = f"node_{attr}"
706 default = defaults[param_name]
707 attr = kwargs.get(param_name, attr)
708
709 if default is None and attr not in subgraph.nodes[n]:
710 raise nx.NetworkXError(f"Attribute '{attr}' missing from node {n}")
711
712 if (
713 attr is not None
714 and nx.get_node_attributes(subgraph, attr) == {}
715 and attr in kwargs.values()
716 ):
717 return attr
718
719 return subgraph.nodes[n].get(attr, default)
720
721 # Taken from ConnectionStyleFactory
722 def self_loop(edge_index, node_size):
723 def self_loop_connection(posA, posB, *args, **kwargs):
724 if not np.all(posA == posB):
725 raise nx.NetworkXError(
726 "`self_loop` connection style method"
727 "is only to be used for self-loops"
728 )
729 # this is called with _screen space_ values
730 # so convert back to data space
731 data_loc = canvas.transData.inverted().transform(posA)
732 # Scale self loop based on the size of the base node
733 # Size of nodes are given in points ** 2 and each point is 1/72 of an inch
734 v_shift = np.sqrt(node_size) / 72
735 h_shift = v_shift * 0.5
736 # put the top of the loop first so arrow is not hidden by node
737 path = np.asarray(
738 [
739 # 1
740 [0, v_shift],
741 # 4 4 4
742 [h_shift, v_shift],
743 [h_shift, 0],
744 [0, 0],
745 # 4 4 4
746 [-h_shift, 0],
747 [-h_shift, v_shift],
748 [0, v_shift],
749 ]
750 )
751 # Rotate self loop 90 deg. if more than 1
752 # This will allow for maximum of 4 visible self loops
753 if edge_index % 4:
754 x, y = path.T
755 for _ in range(edge_index % 4):
756 x, y = y, -x
757 path = np.array([x, y]).T
758 return mpl.path.Path(
759 canvas.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4]
760 )
761
762 return self_loop_connection
763
764 def to_marker_edge(size, marker):
765 if marker in "s^>v<d":
766 return np.sqrt(2 * size) / 2
767 else:
768 return np.sqrt(size) / 2
769
770 def build_fancy_arrow(e):
771 source_margin = to_marker_edge(
772 get_node_attr(e[0], "size"),
773 get_node_attr(e[0], "shape"),
774 )
775 source_margin = max(
776 source_margin,
777 get_edge_attr(e, "source_margin"),
778 )
779
780 target_margin = to_marker_edge(
781 get_node_attr(e[1], "size"),
782 get_node_attr(e[1], "shape"),
783 )
784 target_margin = max(
785 target_margin,
786 get_edge_attr(e, "target_margin"),
787 )
788 return mpl.patches.FancyArrowPatch(
789 edge_subgraph.nodes[e[0]][pos],
790 edge_subgraph.nodes[e[1]][pos],
791 arrowstyle=get_edge_attr(e, "arrowstyle"),
792 connectionstyle=(
793 get_edge_attr(e, "curvature")
794 if e[0] != e[1]
795 else self_loop(
796 0 if len(e) == 2 else e[2] % 4,
797 get_node_attr(e[0], "size"),
798 )
799 ),
800 color=get_edge_attr(e, "color"),
801 linestyle=get_edge_attr(e, "style"),
802 linewidth=get_edge_attr(e, "width"),
803 mutation_scale=get_edge_attr(e, "arrowsize"),
804 shrinkA=source_margin,
805 shrinkB=source_margin,
806 zorder=1,
807 )
808
809 class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text):
810 pass
811
812 ### Draw the nodes first
813 node_visible = kwargs.get("node_visible", "visible")
814 if isinstance(node_visible, bool):
815 if node_visible:
816 visible_nodes = G.nodes()
817 else:
818 visible_nodes = []
819 else:
820 visible_nodes = [
821 n for n, v in nx.get_node_attributes(G, node_visible, True).items() if v
822 ]
823
824 node_subgraph = G.subgraph(visible_nodes)
825
826 # Ignore the default dict value since that's for default values to use, not
827 # default attribute name
828 pos = kwargs.get("node_pos", "pos")
829
830 default_display_pos_attr = "display's position attribute name"
831 if callable(pos):
832 nx.set_node_attributes(
833 node_subgraph, pos(node_subgraph), default_display_pos_attr
834 )
835 pos = default_display_pos_attr
836 kwargs["node_pos"] = default_display_pos_attr
837 elif nx.get_node_attributes(G, pos) == {}:
838 nx.set_node_attributes(
839 node_subgraph, nx.spring_layout(node_subgraph), default_display_pos_attr
840 )
841 pos = default_display_pos_attr
842 kwargs["node_pos"] = default_display_pos_attr
843
844 # Each shape requires a new scatter object since they can't have different
845 # shapes.
846 if len(visible_nodes) > 0:
847 node_shape = kwargs.get("node_shape", "shape")
848 for shape in Counter(
849 nx.get_node_attributes(
850 node_subgraph, node_shape, defaults["node_shape"]
851 ).values()
852 ):
853 # Filter position just on this shape.
854 nodes_with_shape = [
855 n
856 for n, s in node_subgraph.nodes(data=node_shape)
857 if s == shape or (s is None and shape == defaults["node_shape"])
858 ]
859 # There are two property sequences to create before hand.
860 # 1. position, since it is used for x and y parameters to scatter
861 # 2. edgecolor, since the spaeical 'face' parameter value can only be
862 # be passed in as the sole string, not part of a list of strings.
863 position = np.asarray(node_property_sequence(nodes_with_shape, "pos"))
864 color = np.asarray(
865 [
866 compute_colors(c, a)
867 for c, a in zip(
868 node_property_sequence(nodes_with_shape, "color"),
869 node_property_sequence(nodes_with_shape, "alpha"),
870 )
871 ]
872 )
873 border_color = np.asarray(
874 [
875 (
876 c
877 if (
878 c := get_node_attr(
879 n,
880 "border_color",
881 False,
882 )
883 )
884 != "face"
885 else color[i]
886 )
887 for i, n in enumerate(nodes_with_shape)
888 ]
889 )
890 canvas.scatter(
891 position[:, 0],
892 position[:, 1],
893 s=node_property_sequence(nodes_with_shape, "size"),
894 c=color,
895 marker=shape,
896 linewidths=node_property_sequence(nodes_with_shape, "border_width"),
897 edgecolors=border_color,
898 zorder=2,
899 )
900
901 ### Draw node labels
902 node_label = kwargs.get("node_label", "label")
903 # Plot labels if node_label is not None and not False
904 if node_label is not None and node_label is not False:
905 default_dict = {}
906 if isinstance(node_label, dict):
907 default_dict = node_label
908 node_label = None
909
910 for n, lbl in node_subgraph.nodes(data=node_label):
911 if lbl is False:
912 continue
913
914 # We work with label dicts down here...
915 if not isinstance(lbl, dict):
916 lbl = {"label": lbl if lbl is not None else n}
917
918 lbl_text = lbl.get("label", n)
919 if not isinstance(lbl_text, str):
920 lbl_text = str(lbl_text)
921
922 lbl.update(default_dict)
923 x, y = node_subgraph.nodes[n][pos]
924 canvas.text(
925 x,
926 y,
927 lbl_text,
928 size=lbl.get("size", defaults["node_label"]["size"]),
929 color=lbl.get("color", defaults["node_label"]["color"]),
930 family=lbl.get("family", defaults["node_label"]["family"]),
931 weight=lbl.get("weight", defaults["node_label"]["weight"]),
932 horizontalalignment=lbl.get(
933 "h_align", defaults["node_label"]["h_align"]
934 ),
935 verticalalignment=lbl.get("v_align", defaults["node_label"]["v_align"]),
936 transform=canvas.transData,
937 bbox=lbl.get("bbox", defaults["node_label"]["bbox"]),
938 )
939
940 ### Draw edges
941
942 edge_visible = kwargs.get("edge_visible", "visible")
943 if isinstance(edge_visible, bool):
944 if edge_visible:
945 visible_edges = G.edges()
946 else:
947 visible_edges = []
948 else:
949 visible_edges = [
950 e for e, v in nx.get_edge_attributes(G, edge_visible, True).items() if v
951 ]
952
953 edge_subgraph = G.edge_subgraph(visible_edges)
954 nx.set_node_attributes(
955 edge_subgraph, nx.get_node_attributes(node_subgraph, pos), name=pos
956 )
957
958 collection_edges = (
959 [e for e in edge_subgraph.edges(keys=True) if collection_compatible(e)]
960 if edge_subgraph.is_multigraph()
961 else [e for e in edge_subgraph.edges() if collection_compatible(e)]
962 )
963 non_collection_edges = (
964 [e for e in edge_subgraph.edges(keys=True) if not collection_compatible(e)]
965 if edge_subgraph.is_multigraph()
966 else [e for e in edge_subgraph.edges() if not collection_compatible(e)]
967 )
968 edge_position = np.asarray(
969 [
970 (
971 get_node_attr(u, "pos", use_edge_subgraph=True),
972 get_node_attr(v, "pos", use_edge_subgraph=True),
973 )
974 for u, v, *_ in collection_edges
975 ]
976 )
977
978 # Only plot a line collection if needed
979 if len(collection_edges) > 0:
980 edge_collection = mpl.collections.LineCollection(
981 edge_position,
982 colors=edge_property_sequence(collection_edges, "color"),
983 linewidths=edge_property_sequence(collection_edges, "width"),
984 linestyle=edge_property_sequence(collection_edges, "style"),
985 alpha=edge_property_sequence(collection_edges, "alpha"),
986 antialiaseds=(1,),
987 zorder=1,
988 )
989 canvas.add_collection(edge_collection)
990
991 fancy_arrows = {}
992 if len(non_collection_edges) > 0:
993 for e in non_collection_edges:
994 # Cache results for use in edge labels
995 fancy_arrows[e] = build_fancy_arrow(e)
996 canvas.add_patch(fancy_arrows[e])
997
998 ### Draw edge labels
999 edge_label = kwargs.get("edge_label", "label")
1000 default_dict = {}
1001 if isinstance(edge_label, dict):
1002 default_dict = edge_label
1003 # Restore the default label attribute key of 'label'
1004 edge_label = "label"
1005
1006 # Handle multigraphs
1007 edge_label_data = (
1008 edge_subgraph.edges(data=edge_label, keys=True)
1009 if edge_subgraph.is_multigraph()
1010 else edge_subgraph.edges(data=edge_label)
1011 )
1012 if edge_label is not None and edge_label is not False:
1013 for *e, lbl in edge_label_data:
1014 e = tuple(e)
1015 # I'm not sure how I want to handle None here... For now it means no label
1016 if lbl is False or lbl is None:
1017 continue
1018
1019 if not isinstance(lbl, dict):
1020 lbl = {"label": lbl}
1021
1022 lbl.update(default_dict)
1023 lbl_text = lbl.get("label")
1024 if not isinstance(lbl_text, str):
1025 lbl_text = str(lbl_text)
1026
1027 # In the old code, every non-self-loop is placed via a fancy arrow patch
1028 # Only compute a new fancy arrow if needed by caching the results from
1029 # edge placement.
1030 try:
1031 arrow = fancy_arrows[e]
1032 except KeyError:
1033 arrow = build_fancy_arrow(e)
1034
1035 if e[0] == e[1]:
1036 # Taken directly from draw_networkx_edge_labels
1037 connectionstyle_obj = arrow.get_connectionstyle()
1038 posA = canvas.transData.transform(edge_subgraph.nodes[e[0]][pos])
1039 path_disp = connectionstyle_obj(posA, posA)
1040 path_data = canvas.transData.inverted().transform_path(path_disp)
1041 x, y = path_data.vertices[0]
1042 canvas.text(
1043 x,
1044 y,
1045 lbl_text,
1046 size=lbl.get("size", defaults["edge_label"]["size"]),
1047 color=lbl.get("color", defaults["edge_label"]["color"]),
1048 family=lbl.get("family", defaults["edge_label"]["family"]),
1049 weight=lbl.get("weight", defaults["edge_label"]["weight"]),
1050 alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]),
1051 horizontalalignment=lbl.get(
1052 "h_align", defaults["edge_label"]["h_align"]
1053 ),
1054 verticalalignment=lbl.get(
1055 "v_align", defaults["edge_label"]["v_align"]
1056 ),
1057 rotation=0,
1058 transform=canvas.transData,
1059 bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]),
1060 zorder=1,
1061 )
1062 continue
1063
1064 CurvedArrowText(
1065 arrow,
1066 lbl_text,
1067 size=lbl.get("size", defaults["edge_label"]["size"]),
1068 color=lbl.get("color", defaults["edge_label"]["color"]),
1069 family=lbl.get("family", defaults["edge_label"]["family"]),
1070 weight=lbl.get("weight", defaults["edge_label"]["weight"]),
1071 alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]),
1072 bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]),
1073 horizontalalignment=lbl.get(
1074 "h_align", defaults["edge_label"]["h_align"]
1075 ),
1076 verticalalignment=lbl.get("v_align", defaults["edge_label"]["v_align"]),
1077 label_pos=lbl.get("pos", defaults["edge_label"]["pos"]),
1078 labels_horizontal=lbl.get("rotate", defaults["edge_label"]["rotate"]),
1079 transform=canvas.transData,
1080 zorder=1,
1081 ax=canvas,
1082 )
1083
1084 # If we had to add an attribute, remove it here
1085 if pos == default_display_pos_attr:
1086 nx.remove_node_attributes(G, default_display_pos_attr)
1087
1088 return G
1089
1090
def draw(G, pos=None, ax=None, **kwds):
    """Draw the graph G with Matplotlib.

    Draw the graph as a simple representation with no node
    labels or edge labels and using the full Matplotlib figure area
    and no axis labels by default. See draw_networkx() for more
    full-featured drawing that allows title, axis labels etc.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed.
        See :py:mod:`networkx.drawing.layout` for functions that
        compute node positions.

    ax : Matplotlib Axes object, optional
        Draw the graph in specified Matplotlib axes.

    kwds : optional keywords
        See networkx.draw_networkx() for a description of optional keywords.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    See Also
    --------
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels

    Notes
    -----
    This function has the same name as pylab.draw and pyplot.draw
    so beware when using `from networkx import *`, since you might
    overwrite the pylab.draw function.

    With pyplot use

    >>> import matplotlib.pyplot as plt
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)  # networkx draw()
    >>> plt.draw()  # pyplot draw()

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html
    """

    import matplotlib.pyplot as plt

    figure = plt.gcf() if ax is None else ax.get_figure()
    figure.set_facecolor("w")

    if ax is None:
        # Reuse the figure's current axes when it already has some;
        # otherwise create axes covering the whole figure.
        ax = figure.gca() if figure.axes else figure.add_axes((0, 0, 1, 1))

    # Unless told otherwise, show node labels only when a ``labels``
    # mapping was supplied.
    kwds.setdefault("with_labels", "labels" in kwds)

    draw_networkx(G, pos=pos, ax=ax, **kwds)
    ax.set_axis_off()
    plt.draw_if_interactive()
1168
1169
def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
    r"""Draw the graph G using Matplotlib.

    Draw the graph with Matplotlib with options for node positions,
    labeling, titles, and many other drawing features.
    See draw() for simple drawing without labels or axes.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed.
        See :py:mod:`networkx.drawing.layout` for functions that
        compute node positions.

    arrows : bool or None, optional (default=None)
        If `None`, directed graphs draw arrowheads with
        `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
        via `~matplotlib.collections.LineCollection` for speed.
        If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
        If `False`, draw edges using LineCollection (linear and fast).
        Note: Arrows will be the same color as edges.

    arrowstyle : str (default='-\|>' for directed graphs)
        For directed graphs, choose the style of the arrowheads.
        For undirected graphs default to '-'

        See `matplotlib.patches.ArrowStyle` for more options.

    arrowsize : int or list (default=10)
        For directed graphs, choose the size of the arrow head's length and
        width. A list of values, one per edge, can be passed in to give each
        edge's arrowhead its own size.
        See `matplotlib.patches.FancyArrowPatch` for attribute `mutation_scale`
        for more info.

    with_labels : bool (default=True)
        Set to True to draw labels on the nodes.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    nodelist : list (default=list(G))
        Draw only specified nodes

    edgelist : list (default=list(G.edges()))
        Draw only specified edges

    node_size : scalar or array (default=300)
        Size of nodes. If an array is specified it must be the
        same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape : string (default='o')
        The shape of the node. Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'.

    alpha : float or None (default=None)
        The node and edge transparency

    cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of nodes

    vmin,vmax : float, optional
        Minimum and maximum for node colormap scaling

    linewidths : scalar or sequence (default=1.0)
        Line width of symbol border

    width : float or array of floats (default=1.0)
        Line width of edges

    edge_color : color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edgelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    edge_cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of edges

    edge_vmin,edge_vmax : floats, optional
        Minimum and maximum for edge colormap scaling

    style : string (default=solid line)
        Edge line style e.g.: '-', '--', '-.', ':'
        or words like 'solid' or 'dashed'.
        (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)

    labels : dictionary (default=None)
        Node labels in a dictionary of text labels keyed by node

    font_size : int (default=12 for nodes, 10 for edges)
        Font size for text labels

    font_color : color (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    label : string, optional
        Label for graph legend

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    kwds : optional keywords
        See networkx.draw_networkx_nodes(), networkx.draw_networkx_edges(), and
        networkx.draw_networkx_labels() for a description of optional keywords.

    Raises
    ------
    ValueError
        If a keyword is accepted by none of draw_networkx_nodes,
        draw_networkx_edges and draw_networkx_labels.

    Notes
    -----
    For directed graphs, arrows are drawn at the head end. Arrows can be
    turned off with keyword arrows=False.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    >>> import matplotlib.pyplot as plt
    >>> limits = plt.axis("off")  # turn off axis

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    from inspect import signature

    import matplotlib.pyplot as plt

    # The keywords each helper understands are read from its signature.
    node_params = signature(draw_networkx_nodes).parameters.keys()
    edge_params = signature(draw_networkx_edges).parameters.keys()
    label_params = signature(draw_networkx_labels).parameters.keys()
    # Parameters this function names itself are not accepted through kwds.
    own_params = {"G", "pos", "arrows", "with_labels"}
    accepted = (node_params | edge_params | label_params) - own_params

    rejected = [name for name in kwds if name not in accepted]
    if rejected:
        raise ValueError(f"Received invalid argument(s): {', '.join(rejected)}")

    def forwarded(params):
        # The subset of kwds that the helper owning ``params`` accepts.
        return {name: value for name, value in kwds.items() if name in params}

    if pos is None:
        pos = nx.drawing.spring_layout(G)  # default to spring layout

    draw_networkx_nodes(G, pos, **forwarded(node_params))
    draw_networkx_edges(G, pos, arrows=arrows, **forwarded(edge_params))
    if with_labels:
        draw_networkx_labels(G, pos, **forwarded(label_params))
    plt.draw_if_interactive()
1357
1358
1359def draw_networkx_nodes(
1360 G,
1361 pos,
1362 nodelist=None,
1363 node_size=300,
1364 node_color="#1f78b4",
1365 node_shape="o",
1366 alpha=None,
1367 cmap=None,
1368 vmin=None,
1369 vmax=None,
1370 ax=None,
1371 linewidths=None,
1372 edgecolors=None,
1373 label=None,
1374 margins=None,
1375 hide_ticks=True,
1376):
1377 """Draw the nodes of the graph G.
1378
1379 This draws only the nodes of the graph G.
1380
1381 Parameters
1382 ----------
1383 G : graph
1384 A networkx graph
1385
1386 pos : dictionary
1387 A dictionary with nodes as keys and positions as values.
1388 Positions should be sequences of length 2.
1389
1390 ax : Matplotlib Axes object, optional
1391 Draw the graph in the specified Matplotlib axes.
1392
1393 nodelist : list (default list(G))
1394 Draw only specified nodes
1395
1396 node_size : scalar or array (default=300)
1397 Size of nodes. If an array it must be the same length as nodelist.
1398
1399 node_color : color or array of colors (default='#1f78b4')
1400 Node color. Can be a single color or a sequence of colors with the same
1401 length as nodelist. Color can be string or rgb (or rgba) tuple of
1402 floats from 0-1. If numeric values are specified they will be
1403 mapped to colors using the cmap and vmin,vmax parameters. See
1404 matplotlib.scatter for more details.
1405
1406 node_shape : string (default='o')
1407 The shape of the node. Specification is as matplotlib.scatter
1408 marker, one of 'so^>v<dph8'.
1409
1410 alpha : float or array of floats (default=None)
1411 The node transparency. This can be a single alpha value,
1412 in which case it will be applied to all the nodes of color. Otherwise,
1413 if it is an array, the elements of alpha will be applied to the colors
1414 in order (cycling through alpha multiple times if necessary).
1415
1416 cmap : Matplotlib colormap (default=None)
1417 Colormap for mapping intensities of nodes
1418
1419 vmin,vmax : floats or None (default=None)
1420 Minimum and maximum for node colormap scaling
1421
1422 linewidths : [None | scalar | sequence] (default=1.0)
1423 Line width of symbol border
1424
1425 edgecolors : [None | scalar | sequence] (default = node_color)
1426 Colors of node borders. Can be a single color or a sequence of colors with the
1427 same length as nodelist. Color can be string or rgb (or rgba) tuple of floats
1428 from 0-1. If numeric values are specified they will be mapped to colors
1429 using the cmap and vmin,vmax parameters. See `~matplotlib.pyplot.scatter` for more details.
1430
1431 label : [None | string]
1432 Label for legend
1433
1434 margins : float or 2-tuple, optional
1435 Sets the padding for axis autoscaling. Increase margin to prevent
1436 clipping for nodes that are near the edges of an image. Values should
1437 be in the range ``[0, 1]``. See :meth:`matplotlib.axes.Axes.margins`
1438 for details. The default is `None`, which uses the Matplotlib default.
1439
1440 hide_ticks : bool, optional
1441 Hide ticks of axes. When `True` (the default), ticks and ticklabels
1442 are removed from the axes. To set ticks and tick labels to the pyplot default,
1443 use ``hide_ticks=False``.
1444
1445 Returns
1446 -------
1447 matplotlib.collections.PathCollection
1448 `PathCollection` of the nodes.
1449
1450 Examples
1451 --------
1452 >>> G = nx.dodecahedral_graph()
1453 >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G))
1454
1455 Also see the NetworkX drawing examples at
1456 https://networkx.org/documentation/latest/auto_examples/index.html
1457
1458 See Also
1459 --------
1460 draw
1461 draw_networkx
1462 draw_networkx_edges
1463 draw_networkx_labels
1464 draw_networkx_edge_labels
1465 """
1466 from collections.abc import Iterable
1467
1468 import matplotlib as mpl
1469 import matplotlib.collections # call as mpl.collections
1470 import matplotlib.pyplot as plt
1471 import numpy as np
1472
1473 if ax is None:
1474 ax = plt.gca()
1475
1476 if nodelist is None:
1477 nodelist = list(G)
1478
1479 if len(nodelist) == 0: # empty nodelist, no drawing
1480 return mpl.collections.PathCollection(None)
1481
1482 try:
1483 xy = np.asarray([pos[v] for v in nodelist])
1484 except KeyError as err:
1485 raise nx.NetworkXError(f"Node {err} has no position.") from err
1486
1487 if isinstance(alpha, Iterable):
1488 node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax)
1489 alpha = None
1490
1491 if not isinstance(node_shape, np.ndarray) and not isinstance(node_shape, list):
1492 node_shape = np.array([node_shape for _ in range(len(nodelist))])
1493
1494 for shape in np.unique(node_shape):
1495 node_collection = ax.scatter(
1496 xy[node_shape == shape, 0],
1497 xy[node_shape == shape, 1],
1498 s=node_size,
1499 c=node_color,
1500 marker=shape,
1501 cmap=cmap,
1502 vmin=vmin,
1503 vmax=vmax,
1504 alpha=alpha,
1505 linewidths=linewidths,
1506 edgecolors=edgecolors,
1507 label=label,
1508 )
1509 if hide_ticks:
1510 ax.tick_params(
1511 axis="both",
1512 which="both",
1513 bottom=False,
1514 left=False,
1515 labelbottom=False,
1516 labelleft=False,
1517 )
1518
1519 if margins is not None:
1520 if isinstance(margins, Iterable):
1521 ax.margins(*margins)
1522 else:
1523 ax.margins(margins)
1524
1525 node_collection.set_zorder(2)
1526 return node_collection
1527
1528
class FancyArrowFactory:
    """Draw arrows with `matplotlib.patches.FancyArrowPatch`.

    The factory is configured once with the drawing options for a set of
    edges. It is then called with an edge position ``i`` to build the patch
    for ``edge_pos[i]``.

    Per-edge options may be a single value or a sequence indexed by edge
    position:

    - ``edge_color`` (after RGBA conversion), ``linewidth`` and ``style``
      cycle when they are shorter than the edge list.
    - ``min_source_margin``, ``min_target_margin``, ``arrowstyle`` and
      ``arrowsize`` are indexed directly.
    - A string or tuple is always treated as a single ``style``,
      ``arrowstyle`` or margin value.

    ``node_size`` (any iterable) and ``node_shape`` (a list or array) may be
    given per node, in the order of ``nodelist``.
    """

    class ConnectionStyleFactory:
        """Choose the connection style (the path shape) used for each edge."""

        def __init__(self, connectionstyles, selfloop_height, ax=None):
            import matplotlib as mpl
            import matplotlib.path  # call as mpl.path
            import numpy as np

            self.ax = ax
            self.mpl = mpl
            self.np = np
            # One ConnectionStyle per spec; ``curved`` cycles through them.
            self.base_connection_styles = [
                mpl.patches.ConnectionStyle(cs) for cs in connectionstyles
            ]
            self.n = len(self.base_connection_styles)
            self.selfloop_height = selfloop_height

        def curved(self, edge_index):
            """Return the base style for ``edge_index``, cycling through the
            configured styles (e.g. one per multiedge key index)."""
            return self.base_connection_styles[edge_index % self.n]

        def self_loop(self, edge_index):
            """Return a connection-style callable that draws a self-loop.

            The loop is rotated a quarter turn for each step of
            ``edge_index % 4``, so up to four loops on one node are visible.
            """

            def self_loop_connection(posA, posB, *args, **kwargs):
                if not self.np.all(posA == posB):
                    raise nx.NetworkXError(
                        "`self_loop` connection style method "
                        "is only to be used for self-loops"
                    )
                # this is called with _screen space_ values
                # so convert back to data space
                data_loc = self.ax.transData.inverted().transform(posA)
                v_shift = 0.1 * self.selfloop_height
                h_shift = v_shift * 0.5
                # put the top of the loop first so arrow is not hidden by node
                path = self.np.asarray(
                    [
                        # 1
                        [0, v_shift],
                        # 4 4 4
                        [h_shift, v_shift],
                        [h_shift, 0],
                        [0, 0],
                        # 4 4 4
                        [-h_shift, 0],
                        [-h_shift, v_shift],
                        [0, v_shift],
                    ]
                )
                # Rotate self loop 90 deg. if more than 1
                # This will allow for maximum of 4 visible self loops
                if edge_index % 4:
                    x, y = path.T
                    for _ in range(edge_index % 4):
                        x, y = y, -x
                    path = self.np.array([x, y]).T
                # Path codes: 1 = MOVETO, 4 = CURVE4 (two cubic Beziers).
                return self.mpl.path.Path(
                    self.ax.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4]
                )

            return self_loop_connection

    def __init__(
        self,
        edge_pos,
        edgelist,
        nodelist,
        edge_indices,
        node_size,
        selfloop_height,
        connectionstyle="arc3",
        node_shape="o",
        arrowstyle="-",
        arrowsize=10,
        edge_color="k",
        alpha=None,
        linewidth=1.0,
        style="solid",
        min_source_margin=0,
        min_target_margin=0,
        ax=None,
    ):
        import matplotlib as mpl
        import matplotlib.colors  # call as mpl.colors
        import matplotlib.patches  # call as mpl.patches
        import matplotlib.pyplot as plt
        import numpy as np

        if isinstance(connectionstyle, str):
            connectionstyle = [connectionstyle]
        elif np.iterable(connectionstyle):
            connectionstyle = list(connectionstyle)
        else:
            msg = "ConnectionStyleFactory arg `connectionstyle` must be str or iterable"
            raise nx.NetworkXError(msg)
        self.ax = ax
        self.mpl = mpl
        self.np = np
        self.edge_pos = edge_pos
        self.edgelist = edgelist
        self.nodelist = nodelist
        # node -> first index in nodelist, built lazily by _nodelist_index
        self._nodelist_positions = None
        self.node_shape = node_shape
        self.min_source_margin = min_source_margin
        self.min_target_margin = min_target_margin
        self.edge_indices = edge_indices
        self.node_size = node_size
        self.connectionstyle_factory = self.ConnectionStyleFactory(
            connectionstyle, selfloop_height, ax
        )
        self.arrowstyle = arrowstyle
        self.arrowsize = arrowsize
        self.arrow_colors = mpl.colors.to_rgba_array(edge_color, alpha)
        self.linewidth = linewidth
        self.style = style
        if isinstance(arrowsize, list) and len(arrowsize) != len(edge_pos):
            raise ValueError("arrowsize should have the same length as edgelist")

    def _is_sequence(self, value):
        """True for a sequence of per-edge values; strings and tuples count
        as a single value (e.g. a dash tuple for ``style``)."""
        return self.np.iterable(value) and not isinstance(value, (str, tuple))

    def _nodelist_index(self, node):
        """Return the index of the first occurrence of ``node`` in ``nodelist``.

        The lookup table is built once, so each call is O(1) instead of the
        linear scan of ``list.index``. Like ``list.index``, a missing node
        raises ``ValueError``.
        """
        positions = getattr(self, "_nodelist_positions", None)
        if positions is None:
            positions = {}
            for idx, member in enumerate(self.nodelist):
                positions.setdefault(member, idx)
            self._nodelist_positions = positions
        try:
            return positions[node]
        except KeyError:
            raise ValueError(f"{node!r} is not in list") from None

    def __call__(self, i):
        """Return the `FancyArrowPatch` for the ``i``-th edge of ``edge_pos``."""
        np = self.np
        (x1, y1), (x2, y2) = self.edge_pos[i]

        if self._is_sequence(self.min_source_margin):
            min_source_margin = self.min_source_margin[i]
        else:
            min_source_margin = self.min_source_margin

        if self._is_sequence(self.min_target_margin):
            min_target_margin = self.min_target_margin[i]
        else:
            min_target_margin = self.min_target_margin

        # Shrink each end so the arrow stops at the node marker's edge.
        per_node_size = np.iterable(self.node_size)
        per_node_shape = isinstance(self.node_shape, (list, np.ndarray))
        if per_node_size or per_node_shape:
            source, target = self.edgelist[i][:2]
            source_idx = self._nodelist_index(source)
            target_idx = self._nodelist_index(target)
            if per_node_size:
                source_size = self.node_size[source_idx]
                target_size = self.node_size[target_idx]
            else:
                source_size = target_size = self.node_size
            if per_node_shape:
                source_shape = self.node_shape[source_idx]
                target_shape = self.node_shape[target_idx]
            else:
                source_shape = target_shape = self.node_shape
            shrink_source = self.to_marker_edge(source_size, source_shape)
            shrink_target = self.to_marker_edge(target_size, target_shape)
        else:
            shrink_source = self.to_marker_edge(self.node_size, self.node_shape)
            shrink_target = shrink_source
        shrink_source = max(shrink_source, min_source_margin)
        shrink_target = max(shrink_target, min_target_margin)

        # scale factor of arrow head
        if isinstance(self.arrowsize, list):
            mutation_scale = self.arrowsize[i]
        else:
            mutation_scale = self.arrowsize

        # Colors and line widths cycle when fewer values than edges are given.
        arrow_color = self.arrow_colors[i % len(self.arrow_colors)]

        if np.iterable(self.linewidth):
            linewidth = self.linewidth[i % len(self.linewidth)]
        else:
            linewidth = self.linewidth

        if self._is_sequence(self.style):
            linestyle = self.style[i % len(self.style)]
        else:
            linestyle = self.style

        if x1 == x2 and y1 == y2:
            connectionstyle = self.connectionstyle_factory.self_loop(
                self.edge_indices[i]
            )
        else:
            connectionstyle = self.connectionstyle_factory.curved(self.edge_indices[i])

        if self._is_sequence(self.arrowstyle):
            arrowstyle = self.arrowstyle[i]
        else:
            arrowstyle = self.arrowstyle

        return self.mpl.patches.FancyArrowPatch(
            (x1, y1),
            (x2, y2),
            arrowstyle=arrowstyle,
            shrinkA=shrink_source,
            shrinkB=shrink_target,
            mutation_scale=mutation_scale,
            color=arrow_color,
            linewidth=linewidth,
            connectionstyle=connectionstyle,
            linestyle=linestyle,
            zorder=1,  # arrows go behind nodes
        )

    def to_marker_edge(self, marker_size, marker):
        """Return how far an arrow end is shrunk from a node's centre.

        ``marker_size`` is a scatter-style marker size (an area). The result
        is half the side of the equivalent square. For the markers listed
        below it is enlarged by a factor of sqrt(2).
        """
        if marker in "s^>v<d":  # `large` markers need extra space
            return self.np.sqrt(2 * marker_size) / 2
        else:
            return self.np.sqrt(marker_size) / 2
1746
1747
1748def draw_networkx_edges(
1749 G,
1750 pos,
1751 edgelist=None,
1752 width=1.0,
1753 edge_color="k",
1754 style="solid",
1755 alpha=None,
1756 arrowstyle=None,
1757 arrowsize=10,
1758 edge_cmap=None,
1759 edge_vmin=None,
1760 edge_vmax=None,
1761 ax=None,
1762 arrows=None,
1763 label=None,
1764 node_size=300,
1765 nodelist=None,
1766 node_shape="o",
1767 connectionstyle="arc3",
1768 min_source_margin=0,
1769 min_target_margin=0,
1770 hide_ticks=True,
1771):
1772 r"""Draw the edges of the graph G.
1773
1774 This draws only the edges of the graph G.
1775
1776 Parameters
1777 ----------
1778 G : graph
1779 A networkx graph
1780
1781 pos : dictionary
1782 A dictionary with nodes as keys and positions as values.
1783 Positions should be sequences of length 2.
1784
1785 edgelist : collection of edge tuples (default=G.edges())
1786 Draw only specified edges
1787
1788 width : float or array of floats (default=1.0)
1789 Line width of edges
1790
1791 edge_color : color or array of colors (default='k')
1792 Edge color. Can be a single color or a sequence of colors with the same
1793 length as edgelist. Color can be string or rgb (or rgba) tuple of
1794 floats from 0-1. If numeric values are specified they will be
1795 mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.
1796
1797 style : string or array of strings (default='solid')
1798 Edge line style e.g.: '-', '--', '-.', ':'
1799 or words like 'solid' or 'dashed'.
1800 Can be a single style or a sequence of styles with the same
1801 length as the edge list.
1802 If less styles than edges are given the styles will cycle.
1803 If more styles than edges are given the styles will be used sequentially
1804 and not be exhausted.
1805 Also, `(offset, onoffseq)` tuples can be used as style instead of a strings.
1806 (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)
1807
1808 alpha : float or array of floats (default=None)
1809 The edge transparency. This can be a single alpha value,
1810 in which case it will be applied to all specified edges. Otherwise,
1811 if it is an array, the elements of alpha will be applied to the colors
1812 in order (cycling through alpha multiple times if necessary).
1813
1814 edge_cmap : Matplotlib colormap, optional
1815 Colormap for mapping intensities of edges
1816
1817 edge_vmin,edge_vmax : floats, optional
1818 Minimum and maximum for edge colormap scaling
1819
1820 ax : Matplotlib Axes object, optional
1821 Draw the graph in the specified Matplotlib axes.
1822
1823 arrows : bool or None, optional (default=None)
1824 If `None`, directed graphs draw arrowheads with
1825 `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
1826 via `~matplotlib.collections.LineCollection` for speed.
1827 If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
1828 If `False`, draw edges using LineCollection (linear and fast).
1829
1830 Note: Arrowheads will be the same color as edges.
1831
1832 arrowstyle : str or list of strs (default='-\|>' for directed graphs)
1833 For directed graphs and `arrows==True` defaults to '-\|>',
1834 For undirected graphs default to '-'.
1835
1836 See `matplotlib.patches.ArrowStyle` for more options.
1837
1838 arrowsize : int or list of ints(default=10)
1839 For directed graphs, choose the size of the arrow head's length and
1840 width. See `matplotlib.patches.FancyArrowPatch` for attribute
1841 `mutation_scale` for more info.
1842
1843 connectionstyle : string or iterable of strings (default="arc3")
1844 Pass the connectionstyle parameter to create curved arc of rounding
1845 radius rad. For example, connectionstyle='arc3,rad=0.2'.
1846 See `matplotlib.patches.ConnectionStyle` and
1847 `matplotlib.patches.FancyArrowPatch` for more info.
1848 If Iterable, index indicates i'th edge key of MultiGraph
1849
1850 node_size : scalar or array (default=300)
1851 Size of nodes. Though the nodes are not drawn with this function, the
1852 node size is used in determining edge positioning.
1853
1854 nodelist : list, optional (default=G.nodes())
1855 This provides the node order for the `node_size` array (if it is an array).
1856
1857 node_shape : string (default='o')
1858 The marker used for nodes, used in determining edge positioning.
1859 Specification is as a `matplotlib.markers` marker, e.g. one of 'so^>v<dph8'.
1860
1861 label : None or string
1862 Label for legend
1863
1864 min_source_margin : int or list of ints (default=0)
1865 The minimum margin (gap) at the beginning of the edge at the source.
1866
1867 min_target_margin : int or list of ints (default=0)
1868 The minimum margin (gap) at the end of the edge at the target.
1869
1870 hide_ticks : bool, optional
1871 Hide ticks of axes. When `True` (the default), ticks and ticklabels
1872 are removed from the axes. To set ticks and tick labels to the pyplot default,
1873 use ``hide_ticks=False``.
1874
1875 Returns
1876 -------
1877 matplotlib.collections.LineCollection or a list of matplotlib.patches.FancyArrowPatch
1878 If ``arrows=True``, a list of FancyArrowPatches is returned.
1879 If ``arrows=False``, a LineCollection is returned.
1880 If ``arrows=None`` (the default), then a LineCollection is returned if
1881 `G` is undirected, otherwise returns a list of FancyArrowPatches.
1882
1883 Notes
1884 -----
1885 For directed graphs, arrows are drawn at the head end. Arrows can be
1886 turned off with keyword arrows=False or by passing an arrowstyle without
1887 an arrow on the end.
1888
1889 Be sure to include `node_size` as a keyword argument; arrows are
1890 drawn considering the size of nodes.
1891
1892 Self-loops are always drawn with `~matplotlib.patches.FancyArrowPatch`
1893 regardless of the value of `arrows` or whether `G` is directed.
1894 When ``arrows=False`` or ``arrows=None`` and `G` is undirected, the
1895 FancyArrowPatches corresponding to the self-loops are not explicitly
1896 returned. They should instead be accessed via the ``Axes.patches``
1897 attribute (see examples).
1898
1899 Examples
1900 --------
1901 >>> G = nx.dodecahedral_graph()
1902 >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
1903
1904 >>> G = nx.DiGraph()
1905 >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)])
1906 >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
1907 >>> alphas = [0.3, 0.4, 0.5]
1908 >>> for i, arc in enumerate(arcs): # change alpha values of arcs
1909 ... arc.set_alpha(alphas[i])
1910
1911 The FancyArrowPatches corresponding to self-loops are not always
1912 returned, but can always be accessed via the ``patches`` attribute of the
1913 `matplotlib.Axes` object.
1914
1915 >>> import matplotlib.pyplot as plt
1916 >>> fig, ax = plt.subplots()
1917 >>> G = nx.Graph([(0, 1), (0, 0)]) # Self-loop at node 0
1918 >>> edge_collection = nx.draw_networkx_edges(G, pos=nx.circular_layout(G), ax=ax)
1919 >>> self_loop_fap = ax.patches[0]
1920
1921 Also see the NetworkX drawing examples at
1922 https://networkx.org/documentation/latest/auto_examples/index.html
1923
1924 See Also
1925 --------
1926 draw
1927 draw_networkx
1928 draw_networkx_nodes
1929 draw_networkx_labels
1930 draw_networkx_edge_labels
1931
1932 """
1933 import warnings
1934
1935 import matplotlib as mpl
1936 import matplotlib.collections # call as mpl.collections
1937 import matplotlib.colors # call as mpl.colors
1938 import matplotlib.pyplot as plt
1939 import numpy as np
1940
1941 # The default behavior is to use LineCollection to draw edges for
1942 # undirected graphs (for performance reasons) and use FancyArrowPatches
1943 # for directed graphs.
1944 # The `arrows` keyword can be used to override the default behavior
1945 if arrows is None:
1946 use_linecollection = not (G.is_directed() or G.is_multigraph())
1947 else:
1948 if not isinstance(arrows, bool):
1949 raise TypeError("Argument `arrows` must be of type bool or None")
1950 use_linecollection = not arrows
1951
1952 if isinstance(connectionstyle, str):
1953 connectionstyle = [connectionstyle]
1954 elif np.iterable(connectionstyle):
1955 connectionstyle = list(connectionstyle)
1956 else:
1957 msg = "draw_networkx_edges arg `connectionstyle` must be str or iterable"
1958 raise nx.NetworkXError(msg)
1959
1960 # Some kwargs only apply to FancyArrowPatches. Warn users when they use
1961 # non-default values for these kwargs when LineCollection is being used
1962 # instead of silently ignoring the specified option
1963 if use_linecollection:
1964 msg = (
1965 "\n\nThe {0} keyword argument is not applicable when drawing edges\n"
1966 "with LineCollection.\n\n"
1967 "To make this warning go away, either specify `arrows=True` to\n"
1968 "force FancyArrowPatches or use the default values.\n"
1969 "Note that using FancyArrowPatches may be slow for large graphs.\n"
1970 )
1971 if arrowstyle is not None:
1972 warnings.warn(msg.format("arrowstyle"), category=UserWarning, stacklevel=2)
1973 if arrowsize != 10:
1974 warnings.warn(msg.format("arrowsize"), category=UserWarning, stacklevel=2)
1975 if min_source_margin != 0:
1976 warnings.warn(
1977 msg.format("min_source_margin"), category=UserWarning, stacklevel=2
1978 )
1979 if min_target_margin != 0:
1980 warnings.warn(
1981 msg.format("min_target_margin"), category=UserWarning, stacklevel=2
1982 )
1983 if any(cs != "arc3" for cs in connectionstyle):
1984 warnings.warn(
1985 msg.format("connectionstyle"), category=UserWarning, stacklevel=2
1986 )
1987
1988 # NOTE: Arrowstyle modification must occur after the warnings section
1989 if arrowstyle is None:
1990 arrowstyle = "-|>" if G.is_directed() else "-"
1991
1992 if ax is None:
1993 ax = plt.gca()
1994
1995 if edgelist is None:
1996 edgelist = list(G.edges) # (u, v, k) for multigraph (u, v) otherwise
1997
1998 if len(edgelist):
1999 if G.is_multigraph():
2000 key_count = collections.defaultdict(lambda: itertools.count(0))
2001 edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist]
2002 else:
2003 edge_indices = [0] * len(edgelist)
2004 else: # no edges!
2005 return []
2006
2007 if nodelist is None:
2008 nodelist = list(G.nodes())
2009
2010 # FancyArrowPatch handles color=None different from LineCollection
2011 if edge_color is None:
2012 edge_color = "k"
2013
2014 # set edge positions
2015 edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])
2016
2017 # Check if edge_color is an array of floats and map to edge_cmap.
2018 # This is the only case handled differently from matplotlib
2019 if (
2020 np.iterable(edge_color)
2021 and (len(edge_color) == len(edge_pos))
2022 and np.all([isinstance(c, Number) for c in edge_color])
2023 ):
2024 if edge_cmap is not None:
2025 assert isinstance(edge_cmap, mpl.colors.Colormap)
2026 else:
2027 edge_cmap = plt.get_cmap()
2028 if edge_vmin is None:
2029 edge_vmin = min(edge_color)
2030 if edge_vmax is None:
2031 edge_vmax = max(edge_color)
2032 color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax)
2033 edge_color = [edge_cmap(color_normal(e)) for e in edge_color]
2034
2035 # compute initial view
2036 minx = np.amin(np.ravel(edge_pos[:, :, 0]))
2037 maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
2038 miny = np.amin(np.ravel(edge_pos[:, :, 1]))
2039 maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
2040 w = maxx - minx
2041 h = maxy - miny
2042
2043 # Self-loops are scaled by view extent, except in cases the extent
2044 # is 0, e.g. for a single node. In this case, fall back to scaling
2045 # by the maximum node size
2046 selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max()
2047 fancy_arrow_factory = FancyArrowFactory(
2048 edge_pos,
2049 edgelist,
2050 nodelist,
2051 edge_indices,
2052 node_size,
2053 selfloop_height,
2054 connectionstyle,
2055 node_shape,
2056 arrowstyle,
2057 arrowsize,
2058 edge_color,
2059 alpha,
2060 width,
2061 style,
2062 min_source_margin,
2063 min_target_margin,
2064 ax=ax,
2065 )
2066
2067 # Draw the edges
2068 if use_linecollection:
2069 edge_collection = mpl.collections.LineCollection(
2070 edge_pos,
2071 colors=edge_color,
2072 linewidths=width,
2073 antialiaseds=(1,),
2074 linestyle=style,
2075 alpha=alpha,
2076 )
2077 edge_collection.set_cmap(edge_cmap)
2078 edge_collection.set_clim(edge_vmin, edge_vmax)
2079 edge_collection.set_zorder(1) # edges go behind nodes
2080 edge_collection.set_label(label)
2081 ax.add_collection(edge_collection)
2082 edge_viz_obj = edge_collection
2083
2084 # Make sure selfloop edges are also drawn
2085 # ---------------------------------------
2086 selfloops_to_draw = [loop for loop in nx.selfloop_edges(G) if loop in edgelist]
2087 if selfloops_to_draw:
2088 edgelist_tuple = list(map(tuple, edgelist))
2089 arrow_collection = []
2090 for loop in selfloops_to_draw:
2091 i = edgelist_tuple.index(loop)
2092 arrow = fancy_arrow_factory(i)
2093 arrow_collection.append(arrow)
2094 ax.add_patch(arrow)
2095 else:
2096 edge_viz_obj = []
2097 for i in range(len(edgelist)):
2098 arrow = fancy_arrow_factory(i)
2099 ax.add_patch(arrow)
2100 edge_viz_obj.append(arrow)
2101
2102 # update view after drawing
2103 padx, pady = 0.05 * w, 0.05 * h
2104 corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
2105 ax.update_datalim(corners)
2106 ax.autoscale_view()
2107
2108 if hide_ticks:
2109 ax.tick_params(
2110 axis="both",
2111 which="both",
2112 bottom=False,
2113 left=False,
2114 labelbottom=False,
2115 labelleft=False,
2116 )
2117
2118 return edge_viz_obj
2119
2120
def draw_networkx_labels(
    G,
    pos,
    labels=None,
    font_size=12,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    clip_on=True,
    hide_ticks=True,
):
    """Draw node labels on the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    labels : dictionary (default={n: n for n in G})
        Node labels in a dictionary of text labels keyed by node.
        Node-keys in labels should appear as keys in `pos`.
        If needed use: `{n:lab for n,lab in labels.items() if n in pos}`

    font_size : int or dictionary of nodes to ints (default=12)
        Font size for text labels.

    font_color : color or dictionary of nodes to colors (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string or dictionary of nodes to strings (default='normal')
        Font weight.

    font_family : string or dictionary of nodes to strings (default='sans-serif')
        Font family.

    alpha : float or None or dictionary of nodes to floats (default=None)
        The text transparency.

    bbox : Matplotlib bbox, (default is Matplotlib's ax.text default)
        Specify text box properties (e.g. shape, color etc.) for node labels.

    horizontalalignment : string or dictionary of nodes to strings (default='center')
        Horizontal alignment {'center', 'right', 'left'}.

    verticalalignment : string or dictionary of nodes to strings (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    clip_on : bool (default=True)
        Turn on clipping of node labels at axis boundaries

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Returns
    -------
    dict
        `dict` of labels keyed on the nodes

    Raises
    ------
    ValueError
        If a parameter given as a dictionary does not have the same number of
        entries as `labels`.

    Notes
    -----
    Every parameter documented as accepting a "dictionary of nodes" is looked
    up per node when a dictionary is given. Such a dictionary must have
    exactly one entry per label and contain every node that appears in
    `labels`. Any other value is applied to all labels.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_edge_labels
    """
    import matplotlib.pyplot as plt

    if ax is None:
        ax = plt.gca()

    if labels is None:
        labels = {n: n for n in G.nodes()}

    # Names of the parameters that were given as per-node dictionaries.
    individual_params = set()

    def check_individual_params(p_value, p_name):
        # A dict means "one value per labeled node"; anything else is shared.
        if isinstance(p_value, dict):
            if len(p_value) != len(labels):
                raise ValueError(f"{p_name} must have the same length as labels.")
            individual_params.add(p_name)

    def get_param_value(node, p_value, p_name):
        if p_name in individual_params:
            return p_value[node]
        return p_value

    check_individual_params(font_size, "font_size")
    check_individual_params(font_color, "font_color")
    check_individual_params(font_weight, "font_weight")
    check_individual_params(font_family, "font_family")
    check_individual_params(alpha, "alpha")
    # Alignments previously only worked as plain strings (a dict was handed
    # straight to matplotlib); per-node dicts are now resolved like the rest.
    check_individual_params(horizontalalignment, "horizontalalignment")
    check_individual_params(verticalalignment, "verticalalignment")

    text_items = {}  # there is no text collection so we'll fake one
    for n, label in labels.items():
        (x, y) = pos[n]
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same
        t = ax.text(
            x,
            y,
            label,
            size=get_param_value(n, font_size, "font_size"),
            color=get_param_value(n, font_color, "font_color"),
            family=get_param_value(n, font_family, "font_family"),
            weight=get_param_value(n, font_weight, "font_weight"),
            alpha=get_param_value(n, alpha, "alpha"),
            horizontalalignment=get_param_value(
                n, horizontalalignment, "horizontalalignment"
            ),
            verticalalignment=get_param_value(
                n, verticalalignment, "verticalalignment"
            ),
            transform=ax.transData,
            bbox=bbox,
            clip_on=clip_on,
        )
        text_items[n] = t

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    return text_items
2272
2273
def draw_networkx_edge_labels(
    G,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
    node_size=300,
    nodelist=None,
    connectionstyle="arc3",
    hide_ticks=True,
):
    """Draw edge labels.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    edge_labels : dictionary (default=None)
        Edge labels in a dictionary of labels keyed by edge two-tuple.
        Only labels for the keys in the dictionary are drawn.
        If None, every edge of `G` is labeled with its edge data dictionary
        (keyed by ``(u, v, key)`` for multigraphs).

    label_pos : float or list of floats (default=0.5)
        Position of edge label along edge (0=head, 0.5=center, 1=tail)

    font_size : int or list of ints (default=10)
        Font size for text labels

    font_color : color or list of colors (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string or list of strings (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family. This value is passed unchanged to every label.

    alpha : float or None or list of floats (default=None)
        The text transparency

    bbox : Matplotlib bbox, optional
        Specify text box properties (e.g. shape, color etc.) for edge labels.
        Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.

    horizontalalignment : string or list of strings (default='center')
        Horizontal alignment {'center', 'right', 'left'}

    verticalalignment : string or list of strings (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    rotate : bool or list of bools (default=True)
        Rotate edge labels to lie parallel to edges

    clip_on : bool (default=True)
        Turn on clipping of edge labels at axis boundaries

    node_size : scalar or array (default=300)
        Size of nodes. If an array it must be the same length as nodelist.

    nodelist : list, optional (default=G.nodes())
        This provides the node order for the `node_size` array (if it is an array).

    connectionstyle : string or iterable of strings (default="arc3")
        Pass the connectionstyle parameter to create curved arc of rounding
        radius rad. For example, connectionstyle='arc3,rad=0.2'.
        See `matplotlib.patches.ConnectionStyle` and
        `matplotlib.patches.FancyArrowPatch` for more info.
        If Iterable, index indicates i'th edge key of MultiGraph

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Returns
    -------
    dict
        `dict` of labels keyed by edge

    Raises
    ------
    NetworkXError
        If `connectionstyle` is neither a string nor an iterable.
    ValueError
        If a per-edge list parameter does not have one entry per label.

    Notes
    -----
    ``label_pos``, ``font_size``, ``font_color``, ``font_weight``, ``alpha``,
    ``horizontalalignment``, ``verticalalignment`` and ``rotate`` may be given
    as a list with one entry per label, in the iteration order of
    `edge_labels`. Labels of self-loops are always drawn unrotated, and
    ``label_pos`` and ``rotate`` do not apply to them.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    """
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import numpy as np

    class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text):
        pass

    # use default box of white with white border
    if bbox is None:
        bbox = {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}

    if isinstance(connectionstyle, str):
        connectionstyle = [connectionstyle]
    elif np.iterable(connectionstyle):
        connectionstyle = list(connectionstyle)
    else:
        raise nx.NetworkXError(
            "draw_networkx_edge_labels arg `connectionstyle` must be "
            "string or iterable of strings"
        )

    if ax is None:
        ax = plt.gca()

    if edge_labels is None:
        kwds = {"keys": True} if G.is_multigraph() else {}
        edge_labels = {tuple(edge): d for *edge, d in G.edges(data=True, **kwds)}
    # NOTHING TO PLOT
    if not edge_labels:
        return {}
    edgelist, labels = zip(*edge_labels.items())

    if nodelist is None:
        nodelist = list(G.nodes())

    # set edge positions
    edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])

    # For multigraphs, number the parallel edges between each node pair
    # (0, 1, 2, ...); per the docstring, an iterable `connectionstyle` is
    # indexed by this value.
    if G.is_multigraph():
        key_count = collections.defaultdict(lambda: itertools.count(0))
        edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist]
    else:
        edge_indices = [0] * len(edgelist)

    # Used to determine self loop mid-point.
    # Note that this will not be accurate if not drawing edge_labels for all
    # edges drawn. `edge_labels` is known to be non-empty at this point.
    miny = np.amin(np.ravel(edge_pos[:, :, 1]))
    maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
    h = maxy - miny
    selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max()
    fancy_arrow_factory = FancyArrowFactory(
        edge_pos,
        edgelist,
        nodelist,
        edge_indices,
        node_size,
        selfloop_height,
        connectionstyle,
        ax=ax,
    )

    # Names of the parameters that were given as one value per labeled edge.
    individual_params = set()

    def check_individual_params(p_value, p_name):
        # TODO should this be list or array (as in a numpy array)?
        if isinstance(p_value, list):
            if len(p_value) != len(edgelist):
                raise ValueError(f"{p_name} must have the same length as edgelist.")
            individual_params.add(p_name)

    # Look values up by edge position instead of consuming shared iterators:
    # the self-loop branch below does not use `label_pos` or `rotate`, and
    # iterator consumption would otherwise shift those lists onto later edges.
    def get_param_value(i, p_value, p_name):
        if p_name in individual_params:
            return p_value[i]
        return p_value

    check_individual_params(font_size, "font_size")
    check_individual_params(font_color, "font_color")
    check_individual_params(font_weight, "font_weight")
    check_individual_params(alpha, "alpha")
    check_individual_params(horizontalalignment, "horizontalalignment")
    check_individual_params(verticalalignment, "verticalalignment")
    check_individual_params(rotate, "rotate")
    check_individual_params(label_pos, "label_pos")
    # `font_family` is deliberately not registered: a list there has always
    # been handed to matplotlib unchanged for every label.

    text_items = {}
    for i, (edge, label) in enumerate(zip(edgelist, labels)):
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same

        n1, n2 = edge[:2]
        arrow = fancy_arrow_factory(i)
        if n1 == n2:
            # Self-loop: place the label at the first vertex of the loop path,
            # computed in display space and mapped back to data coordinates.
            connectionstyle_obj = arrow.get_connectionstyle()
            posA = ax.transData.transform(pos[n1])
            path_disp = connectionstyle_obj(posA, posA)
            path_data = ax.transData.inverted().transform_path(path_disp)
            x, y = path_data.vertices[0]
            text_items[edge] = ax.text(
                x,
                y,
                label,
                size=get_param_value(i, font_size, "font_size"),
                color=get_param_value(i, font_color, "font_color"),
                family=get_param_value(i, font_family, "font_family"),
                weight=get_param_value(i, font_weight, "font_weight"),
                alpha=get_param_value(i, alpha, "alpha"),
                horizontalalignment=get_param_value(
                    i, horizontalalignment, "horizontalalignment"
                ),
                verticalalignment=get_param_value(
                    i, verticalalignment, "verticalalignment"
                ),
                rotation=0,
                transform=ax.transData,
                bbox=bbox,
                zorder=1,
                clip_on=clip_on,
            )
        else:
            text_items[edge] = CurvedArrowText(
                arrow,
                label,
                size=get_param_value(i, font_size, "font_size"),
                color=get_param_value(i, font_color, "font_color"),
                family=get_param_value(i, font_family, "font_family"),
                weight=get_param_value(i, font_weight, "font_weight"),
                alpha=get_param_value(i, alpha, "alpha"),
                horizontalalignment=get_param_value(
                    i, horizontalalignment, "horizontalalignment"
                ),
                verticalalignment=get_param_value(
                    i, verticalalignment, "verticalalignment"
                ),
                transform=ax.transData,
                bbox=bbox,
                zorder=1,
                clip_on=clip_on,
                label_pos=get_param_value(i, label_pos, "label_pos"),
                labels_horizontal=not get_param_value(i, rotate, "rotate"),
                ax=ax,
            )

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    return text_items
2544
2545
def draw_bipartite(G, **kwargs):
    """Draw the graph `G` with its nodes arranged in a bipartite layout.

    Shorthand for::

        nx.draw(G, pos=nx.bipartite_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: it is computed by this function.

    Raises
    ------
    NetworkXError :
        If `G` is not bipartite.

    Notes
    -----
    A fresh layout is computed on every call. When drawing the same graph
    (or parts of it) repeatedly, it is more efficient to compute the layout
    once with `~networkx.drawing.layout.bipartite_layout` and reuse it::

        >>> G = nx.complete_bipartite_graph(3, 3)
        >>> pos = nx.bipartite_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.complete_bipartite_graph(2, 5)
    >>> nx.draw_bipartite(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.bipartite_layout`
    """
    node_positions = nx.bipartite_layout(G)
    draw(G, pos=node_positions, **kwargs)
2588
2589
def draw_circular(G, **kwargs):
    """Draw the graph `G` with its nodes placed on a circle.

    Shorthand for::

        nx.draw(G, pos=nx.circular_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: it is computed by this function.

    Notes
    -----
    A fresh layout is computed on every call. When drawing the same graph
    (or parts of it) repeatedly, it is more efficient to compute the layout
    once with `~networkx.drawing.layout.circular_layout` and reuse it::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.circular_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_circular(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.circular_layout`
    """
    node_positions = nx.circular_layout(G)
    draw(G, pos=node_positions, **kwargs)
2627
2628
def draw_kamada_kawai(G, **kwargs):
    """Draw the graph `G` using the Kamada-Kawai force-directed layout.

    Shorthand for::

        nx.draw(G, pos=nx.kamada_kawai_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: it is computed by this function.

    Notes
    -----
    A fresh layout is computed on every call. When drawing the same graph
    (or parts of it) repeatedly, it is more efficient to compute the layout
    once with `~networkx.drawing.layout.kamada_kawai_layout` and reuse it::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.kamada_kawai_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_kamada_kawai(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.kamada_kawai_layout`
    """
    node_positions = nx.kamada_kawai_layout(G)
    draw(G, pos=node_positions, **kwargs)
2667
2668
def draw_random(G, **kwargs):
    """Draw the graph `G` with nodes at randomly chosen positions.

    Shorthand for::

        nx.draw(G, pos=nx.random_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: it is computed by this function.

    Notes
    -----
    A fresh layout is computed on every call. When drawing the same graph
    (or parts of it) repeatedly, it is more efficient to compute the layout
    once with `~networkx.drawing.layout.random_layout` and reuse it::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.random_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.lollipop_graph(4, 3)
    >>> nx.draw_random(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.random_layout`
    """
    node_positions = nx.random_layout(G)
    draw(G, pos=node_positions, **kwargs)
2706
2707
def draw_spectral(G, **kwargs):
    """Draw the graph `G` using a two-dimensional spectral layout.

    Shorthand for::

        nx.draw(G, pos=nx.spectral_layout(G), **kwargs)

    How the node positions are derived is described in
    `~networkx.drawing.layout.spectral_layout`.

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: it is computed by this function.

    Notes
    -----
    A fresh layout is computed on every call. When drawing the same graph
    (or parts of it) repeatedly, it is more efficient to compute the layout
    once with `~networkx.drawing.layout.spectral_layout` and reuse it::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.spectral_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_spectral(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.spectral_layout`
    """
    node_positions = nx.spectral_layout(G)
    draw(G, pos=node_positions, **kwargs)
2748
2749
def draw_spring(G, **kwargs):
    """Draw the graph `G` using a spring (force-directed) layout.

    Shorthand for::

        nx.draw(G, pos=nx.spring_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: it is computed by this function.

    Notes
    -----
    `~networkx.drawing.layout.spring_layout` is also what `draw` uses by
    default, so this function is equivalent to `draw`.

    A fresh layout is computed on every call. When drawing the same graph
    (or parts of it) repeatedly, it is more efficient to compute the layout
    once with `~networkx.drawing.layout.spring_layout` and reuse it::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.spring_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(20)
    >>> nx.draw_spring(G)

    See Also
    --------
    draw
    :func:`~networkx.drawing.layout.spring_layout`
    """
    node_positions = nx.spring_layout(G)
    draw(G, pos=node_positions, **kwargs)
2791
2792
def draw_shell(G, nlist=None, **kwargs):
    """Draw the networkx graph `G` with its nodes arranged in concentric shells.

    Shorthand for::

        nx.draw(G, pos=nx.shell_layout(G, nlist=nlist), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    nlist : list of list of nodes, optional
        A list containing lists of nodes representing the shells.
        Default is `None`, meaning all nodes are in a single shell.
        See `~networkx.drawing.layout.shell_layout` for details.

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: it is computed by this function.

    Notes
    -----
    A fresh layout is computed on every call. When drawing the same graph
    (or parts of it) repeatedly, it is more efficient to compute the layout
    once with `~networkx.drawing.layout.shell_layout` and reuse it::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.shell_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> shells = [[0], [1, 2, 3]]
    >>> nx.draw_shell(G, nlist=shells)

    See Also
    --------
    :func:`~networkx.drawing.layout.shell_layout`
    """
    shell_positions = nx.shell_layout(G, nlist=nlist)
    draw(G, pos=shell_positions, **kwargs)
2836
2837
def draw_planar(G, **kwargs):
    """Draw the planar networkx graph `G` using a planar layout.

    Shorthand for::

        nx.draw(G, pos=nx.planar_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A planar networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: it is computed by this function.

    Raises
    ------
    NetworkXException
        When `G` is not planar

    Notes
    -----
    A fresh layout is computed on every call. When drawing the same graph
    (or parts of it) repeatedly, it is more efficient to compute the layout
    once with `~networkx.drawing.layout.planar_layout` and reuse it::

        >>> G = nx.path_graph(5)
        >>> pos = nx.planar_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.draw_planar(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.planar_layout`
    """
    node_positions = nx.planar_layout(G)
    draw(G, pos=node_positions, **kwargs)
2880
2881
def draw_forceatlas2(G, **kwargs):
    """Draw a networkx graph with forceatlas2 layout.

    This is a convenience function equivalent to::

        nx.draw(G, pos=nx.forceatlas2_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        See `draw_networkx` for a description of optional keywords, with the
        exception of ``pos``: positions are computed by this function, so
        passing ``pos`` raises a TypeError.

    Notes
    -----
    The layout is computed each time this function is called.
    For repeated drawing it is much more efficient to call
    `~networkx.drawing.layout.forceatlas2_layout` directly and reuse the
    result::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.forceatlas2_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_forceatlas2(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.forceatlas2_layout`
    """
    draw(G, pos=nx.forceatlas2_layout(G), **kwargs)
2900
2901
def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None):
    """Apply an alpha (or list of alphas) to the colors provided.

    Parameters
    ----------

    colors : color string, sequence of colors, or sequence of numbers
        Color of element. Can be a single color format string,
        or a sequence of colors with the same length as `elem_list`.
        If a sequence of numbers with the same length as `elem_list` is
        given, the numbers are mapped to colors using the `cmap`, `vmin` and
        `vmax` parameters. See matplotlib.scatter for more details.

    alpha : float or array of floats
        Alpha values for elements. This can be a single alpha value, in
        which case it will be applied to all the elements of color. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    elem_list : array of networkx objects
        The list of elements which are being colored. These could be nodes,
        edges or labels. Only its length is used.

    cmap : matplotlib colormap
        Color map for use if colors is a list of floats corresponding to points
        on a color mapping.

    vmin, vmax : float
        Minimum and maximum values for normalizing colors if a colormap is used

    Returns
    -------

    rgba_colors : numpy ndarray
        Array of RGBA rows (four columns), one row per converted color, or
        one row per element of `elem_list` when the array had to be resized.
    """
    from itertools import cycle, islice

    import matplotlib as mpl
    import matplotlib.cm  # call as mpl.cm
    import matplotlib.colors  # call as mpl.colors
    import numpy as np

    # If we have been provided with a list of numbers as long as elem_list,
    # apply the color mapping.
    # NOTE(review): if both `colors` and `elem_list` are empty the lengths
    # match and `colors[0]` raises IndexError; callers presumably never pass
    # empty inputs -- confirm.
    if len(colors) == len(elem_list) and isinstance(colors[0], Number):
        mapper = mpl.cm.ScalarMappable(cmap=cmap)
        mapper.set_clim(vmin, vmax)
        rgba_colors = mapper.to_rgba(colors)
    # Otherwise, convert colors to matplotlib's RGB using the colorConverter
    # object. These are converted to numpy ndarrays to be consistent with the
    # to_rgba method of ScalarMappable.
    else:
        # First treat `colors` as one color; if to_rgba rejects it with a
        # ValueError (expected for a sequence of colors), convert each entry.
        try:
            rgba_colors = np.array([mpl.colors.colorConverter.to_rgba(colors)])
        except ValueError:
            rgba_colors = np.array(
                [mpl.colors.colorConverter.to_rgba(color) for color in colors]
            )
    # Set the final column of the rgba_colors to have the relevant alpha values
    # A scalar `alpha` has no len(), so `len(alpha)` raises TypeError and the
    # except branch broadcasts it over the alpha column. Note the except also
    # catches a TypeError from any other statement in this try body.
    try:
        # If alpha is longer than the number of colors, resize to the number of
        # elements. Also, if rgba_colors.size (the number of elements of
        # rgba_colors) is the same as the number of elements, resize the array,
        # to avoid it being interpreted as a colormap by scatter()
        if len(alpha) > len(rgba_colors) or rgba_colors.size == len(elem_list):
            rgba_colors = np.resize(rgba_colors, (len(elem_list), 4))
            # NOTE(review): every row's RGB is overwritten with the first
            # color's RGB, so on this path only the first of several distinct
            # colors survives -- confirm this is intended.
            rgba_colors[1:, 0] = rgba_colors[0, 0]
            rgba_colors[1:, 1] = rgba_colors[0, 1]
            rgba_colors[1:, 2] = rgba_colors[0, 2]
        # Cycle through `alpha` until every row has a value.
        rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors)))
    except TypeError:
        rgba_colors[:, -1] = alpha
    return rgba_colors