1"""
2**********
3Matplotlib
4**********
5
6Draw networks with matplotlib.
7
8Examples
9--------
10>>> G = nx.complete_graph(5)
11>>> nx.draw(G)
12
13See Also
14--------
15 - :doc:`matplotlib <matplotlib:index>`
16 - :func:`matplotlib.pyplot.scatter`
17 - :obj:`matplotlib.patches.FancyArrowPatch`
18"""
19
20import collections
21import itertools
22import math
23from numbers import Number
24
25import networkx as nx
26
# Names exported by ``from <this module> import *``: the ``display`` entry
# point, the ``apply_matplotlib_colors`` helper, and the ``draw*`` family of
# drawing functions.
__all__ = [
    "display",
    "apply_matplotlib_colors",
    "draw",
    "draw_networkx",
    "draw_networkx_nodes",
    "draw_networkx_edges",
    "draw_networkx_labels",
    "draw_networkx_edge_labels",
    "draw_bipartite",
    "draw_circular",
    "draw_kamada_kawai",
    "draw_random",
    "draw_spectral",
    "draw_spring",
    "draw_planar",
    "draw_shell",
    "draw_forceatlas2",
]
46
47
def apply_matplotlib_colors(
    G, src_attr, dest_attr, map, vmin=None, vmax=None, nodes=True
):
    """
    Apply colors from a matplotlib colormap to a graph.

    Reads the value of `src_attr` on every node (or edge), maps it through a
    matplotlib colormap, and writes the resulting RGBA color to `dest_attr`.
    The graph is modified in place. If the graph has no nodes (or edges, when
    ``nodes=False``) nothing is written.

    Parameters
    ----------
    G : nx.Graph
        The graph to read and compute colors for.

    src_attr : str or other attribute name
        The name of the attribute to read from the graph. Every node (or edge)
        must have this attribute.

    dest_attr : str or other attribute name
        The name of the attribute to write to on the graph. Each written value
        is an RGBA tuple of Python floats.

    map : matplotlib colormap or str
        The matplotlib colormap (or registered colormap name) to use.

    vmin : float, default None
        The minimum value for scaling the colormap. If `None`, find the
        minimum value of `src_attr`.

    vmax : float, default None
        The maximum value for scaling the colormap. If `None`, find the
        maximum value of `src_attr`.

    nodes : bool, default True
        If True, `src_attr` and `dest_attr` are node attributes. If False they
        are edge attributes; for multigraphs every keyed edge is colored
        independently.

    Raises
    ------
    KeyError
        If a node (or edge) is missing `src_attr`.
    """
    import matplotlib as mpl

    # `lookup` maps a node / edge tuple to its attribute dict.
    if nodes:
        items = G.nodes()
        lookup = G.nodes
    elif G.is_multigraph():
        # Include keys so parallel edges are handled individually.
        items = G.edges(keys=True)
        lookup = G.edges
    else:
        items = G.edges()
        lookup = G.edges

    elements = list(items)
    if not elements:
        # Nothing to color. Returning early also avoids min()/max() raising
        # on an empty sequence when vmin or vmax has to be inferred.
        return

    if vmin is None or vmax is None:
        vals = [lookup[el][src_attr] for el in elements]
        if vmin is None:
            vmin = min(vals)
        if vmax is None:
            vmax = max(vals)

    mapper = mpl.cm.ScalarMappable(cmap=map)
    mapper.set_clim(vmin, vmax)

    def do_map(value):
        # Cast numpy scalars returned by to_rgba to plain Python floats.
        return tuple(float(channel) for channel in mapper.to_rgba(value))

    colors = {el: do_map(lookup[el][src_attr]) for el in elements}
    if nodes:
        nx.set_node_attributes(G, colors, dest_attr)
    else:
        nx.set_edge_attributes(G, colors, dest_attr)
113
114
class CurvedArrowTextBase:
    """Mixin that places a text label along a ``FancyArrowPatch``.

    Meant to be combined with ``matplotlib.text.Text``, e.g.
    ``class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text)``, so that
    ``super().__init__`` and ``super().draw`` resolve to the ``Text``
    implementations. The label position and rotation are recomputed on every
    ``draw`` call, so the label follows the arrow.

    Parameters
    ----------
    arrow : matplotlib.patches.FancyArrowPatch
        The arrow to attach the label to. It must have been created from
        ``posA``/``posB`` rather than a custom path.
    *args
        Forwarded to the ``Text`` constructor after the x/y position.
    label_pos : float, default 0.5
        Fraction of the way along the arrow's path at which to place the
        label: 0 is the start, 1 is the end.
    labels_horizontal : bool, default False
        If True the label is not rotated; otherwise it is rotated to be
        parallel to the arrow at ``label_pos``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. Defaults to the current axes.
    **kwargs
        Forwarded to the ``Text`` constructor.
    """

    def __init__(
        self,
        arrow,
        *args,
        label_pos=0.5,
        labels_horizontal=False,
        ax=None,
        **kwargs,
    ):
        # Bind to FancyArrowPatch
        self.arrow = arrow
        # How far along the curve the text should be:
        # 0 is at the start, 1 is at the end.
        self.label_pos = label_pos
        self.labels_horizontal = labels_horizontal
        if ax is None:
            # pyplot is not imported at module level in this file, so import
            # it lazily here (the bare ``plt`` reference raised NameError).
            import matplotlib.pyplot as plt

            ax = plt.gca()
        self.ax = ax
        self.x, self.y, self.angle = self._update_text_pos_angle(arrow)

        # Create the text object
        super().__init__(self.x, self.y, *args, rotation=self.angle, **kwargs)
        # Bind to axis
        self.ax.add_artist(self)

    def _get_arrow_path_disp(self, arrow):
        """Return the arrow's connection path in display coordinates.

        This mirrors the first part of
        ``FancyArrowPatch._get_path_in_displaycoord``. It omits the second part,
        where the path is converted to a polygon based on the arrow width.
        The data transform is taken from ``self.ax`` rather than the arrow,
        because the arrow may not have been added to the axes yet and so may
        not have a transform.

        Raises
        ------
        ValueError
            If the arrow was built from a custom path rather than from
            ``posA`` and ``posB``.
        """
        dpi_cor = arrow._dpi_cor
        trans_data = self.ax.transData
        if arrow._posA_posB is None:
            raise ValueError(
                "Can only draw labels for fancy arrows with "
                "posA and posB inputs, not custom path"
            )
        posA = arrow._convert_xy_units(arrow._posA_posB[0])
        posB = arrow._convert_xy_units(arrow._posA_posB[1])
        (posA, posB) = trans_data.transform((posA, posB))
        _path = arrow.get_connectionstyle()(
            posA,
            posB,
            patchA=arrow.patchA,
            patchB=arrow.patchB,
            # Shrink amounts are scaled by the arrow's DPI correction factor.
            shrinkA=arrow.shrinkA * dpi_cor,
            shrinkB=arrow.shrinkB * dpi_cor,
        )
        # Return is in display coordinates
        return _path

    def _update_text_pos_angle(self, arrow):
        """Compute the label position (data coords) and rotation (degrees).

        The position is a fraction ``self.label_pos`` of the way along the
        arrow's path; the default 0.5 puts the text at the halfway point.
        The angle is 0 when ``self.labels_horizontal`` is True. Otherwise it
        follows the path direction, folded into [-90, 90] so the text is
        never upside down.

        Raises
        ------
        TypeError
            If the connection style is not Arc3, Angle3, Angle, Arc or Bar.
        """
        import matplotlib as mpl
        import numpy as np

        t = self.label_pos
        tt = 1 - t
        path_disp = self._get_arrow_path_disp(arrow)
        conn = arrow.get_connectionstyle()
        # 1. Calculate x and y (display coordinates)
        points = path_disp.vertices
        if is_curve := isinstance(
            conn,
            mpl.patches.ConnectionStyle.Angle3 | mpl.patches.ConnectionStyle.Arc3,
        ):
            # Arc3 or Angle3 connection styles: a quadratic Bezier curve with
            # three control points, evaluated at parameter t.
            (x1, y1), (cx, cy), (x2, y2) = points
            x = tt**2 * x1 + 2 * t * tt * cx + t**2 * x2
            y = tt**2 * y1 + 2 * t * tt * cy + t**2 * y2
        else:
            if not isinstance(
                conn,
                mpl.patches.ConnectionStyle.Angle
                | mpl.patches.ConnectionStyle.Arc
                | mpl.patches.ConnectionStyle.Bar,
            ):
                msg = f"invalid connection style: {type(conn)}"
                raise TypeError(msg)
            # A. Collect the straight (LINETO) segments of the path
            codes = path_disp.codes
            lines = [
                points[i - 1 : i + 1]
                for i in range(1, len(points))
                if codes[i] == mpl.path.Path.LINETO
            ]
            # B. With several segments, pick the one containing fraction t of
            # the total length and rescale t to a fraction within it.
            if (nlines := len(lines)) != 1:
                dists = [math.dist(*line) for line in lines]
                dist_tot = sum(dists)
                cdist = 0
                last_cut = 0
                i_last = nlines - 1
                for i, dist in enumerate(dists):
                    cdist += dist
                    cut = cdist / dist_tot
                    if i == i_last or t < cut:
                        t = (t - last_cut) / (dist / dist_tot)
                        tt = 1 - t
                        lines = [lines[i]]
                        break
                    last_cut = cut
            [[(cx1, cy1), (cx2, cy2)]] = lines
            x = cx1 * tt + cx2 * t
            y = cy1 * tt + cy2 * t

        # 2. Calculate Angle
        if self.labels_horizontal:
            # Horizontal text labels
            angle = 0
        else:
            # Labels parallel to the curve: use the path's tangent direction.
            if is_curve:
                # Derivative of the quadratic Bezier curve at t.
                change_x = 2 * tt * (cx - x1) + 2 * t * (x2 - cx)
                change_y = 2 * tt * (cy - y1) + 2 * t * (y2 - cy)
            else:
                change_x = (cx2 - cx1) / 2
                change_y = (cy2 - cy1) / 2
            angle = np.arctan2(change_y, change_x) / (2 * np.pi) * 360
            # Keep the text "right way up"
            if angle > 90:
                angle -= 180
            elif angle < -90:
                angle += 180
        (x, y) = self.ax.transData.inverted().transform((x, y))
        return x, y, angle

    def draw(self, renderer):
        """Recompute position and angle from the arrow, then draw the text."""
        # The axes limits may have changed since construction, so the
        # display-space path must be recomputed.
        self.x, self.y, self.angle = self._update_text_pos_angle(self.arrow)
        self.set_position((self.x, self.y))
        self.set_rotation(self.angle)
        # redraw text
        super().draw(renderer)
255
256
257def display(
258 G,
259 canvas=None,
260 **kwargs,
261):
262 """Draw the graph G.
263
264 Draw the graph as a collection of nodes connected by edges.
265 The exact details of what the graph looks like are controlled by the below
266 attributes. All nodes and nodes at the end of visible edges must have a
267 position set, but nearly all other node and edge attributes are options and
268 nodes or edges missing the attribute will use the default listed below. A more
269 complete description of each parameter is given below this summary.
270
271 .. list-table:: Default Visualization Attributes
272 :widths: 25 25 50
273 :header-rows: 1
274
275 * - Parameter
276 - Default Attribute
277 - Default Value
278 * - node_pos
279 - `"pos"`
280 - If there is not position, a layout will be calculated with `nx.spring_layout`.
281 * - node_visible
282 - `"visible"`
283 - True
284 * - node_color
285 - `"color"`
286 - #1f78b4
287 * - node_size
288 - `"size"`
289 - 300
290 * - node_label
291 - `"label"`
292 - Dict describing the node label. Defaults create a black text with
293 the node name as the label. The dict respects these keys and defaults:
294
295 * size : 12
296 * color : black
297 * family : sans serif
298 * weight : normal
299 * alpha : 1.0
300 * h_align : center
301 * v_align : center
302 * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`.
303 Default is None.
304
305 * - node_shape
306 - `"shape"`
307 - "o"
308 * - node_alpha
309 - `"alpha"`
310 - 1.0
311 * - node_border_width
312 - `"border_width"`
313 - 1.0
314 * - node_border_color
315 - `"border_color"`
316 - Matching node_color
317 * - edge_visible
318 - `"visible"`
319 - True
320 * - edge_width
321 - `"width"`
322 - 1.0
323 * - edge_color
324 - `"color"`
325 - Black (#000000)
326 * - edge_label
327 - `"label"`
328 - Dict describing the edge label. Defaults create black text with a
329 white bounding box. The dictionary respects these keys and defaults:
330
331 * size : 12
332 * color : black
333 * family : sans serif
334 * weight : normal
335 * alpha : 1.0
336 * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`.
337 Default {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}
338 * h_align : "center"
339 * v_align : "center"
340 * pos : 0.5
341 * rotate : True
342
343 * - edge_style
344 - `"style"`
345 - "-"
346 * - edge_alpha
347 - `"alpha"`
348 - 1.0
349 * - edge_arrowstyle
350 - `"arrowstyle"`
351 - ``"-|>"`` if `G` is directed else ``"-"``
352 * - edge_arrowsize
353 - `"arrowsize"`
354 - 10 if `G` is directed else 0
355 * - edge_curvature
356 - `"curvature"`
357 - arc3
358 * - edge_source_margin
359 - `"source_margin"`
360 - 0
361 * - edge_target_margin
362 - `"target_margin"`
363 - 0
364
365 Parameters
366 ----------
367 G : graph
368 A networkx graph
369
370 canvas : Matplotlib Axes object, optional
371 Draw the graph in specified Matplotlib axes
372
373 node_pos : string or function, default "pos"
374 A string naming the node attribute storing the position of nodes as a tuple.
375 Or a function to be called with input `G` which returns the layout as a dict keyed
376 by node to position tuple like the NetworkX layout functions.
377 If no nodes in the graph has the attribute, a spring layout is calculated.
378
379 node_visible : string or bool, default visible
380 A string naming the node attribute which stores if a node should be drawn.
381 If `True`, all nodes will be visible while if `False` no nodes will be visible.
382 If incomplete, nodes missing this attribute will be shown by default.
383
384 node_color : string, default "color"
385 A string naming the node attribute which stores the color of each node.
386 Visible nodes without this attribute will use '#1f78b4' as a default.
387
388 node_size : string or number, default "size"
389 A string naming the node attribute which stores the size of each node.
390 Visible nodes without this attribute will use a default size of 300.
391
392 node_label : string or bool, default "label"
393 A string naming the node attribute which stores the label of each node.
394 The attribute value can be a string, False (no label for that node),
395 True (the node is the label) or a dict keyed by node to the label.
396
397 If a dict is specified, these keys are read to further control the label:
398
399 * label : The text of the label; default: name of the node
400 * size : Font size of the label; default: 12
401 * color : Font color of the label; default: black
402 * family : Font family of the label; default: "sans-serif"
403 * weight : Font weight of the label; default: "normal"
404 * alpha : Alpha value of the label; default: 1.0
405 * h_align : The horizontal alignment of the label.
406 one of "left", "center", "right"; default: "center"
407 * v_align : The vertical alignment of the label.
408 one of "top", "center", "bottom"; default: "center"
409 * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`.
410
411 Visible nodes without this attribute will be treated as if the value was True.
412
413 node_shape : string, default "shape"
414 A string naming the node attribute which stores the label of each node.
415 The values of this attribute are expected to be one of the matplotlib shapes,
416 one of 'so^>v<dph8'. Visible nodes without this attribute will use 'o'.
417
418 node_alpha : string, default "alpha"
419 A string naming the node attribute which stores the alpha of each node.
420 The values of this attribute are expected to be floats between 0.0 and 1.0.
421 Visible nodes without this attribute will be treated as if the value was 1.0.
422
423 node_border_width : string, default "border_width"
424 A string naming the node attribute storing the width of the border of the node.
425 The values of this attribute are expected to be numeric. Visible nodes without
426 this attribute will use the assumed default of 1.0.
427
428 node_border_color : string, default "border_color"
429 A string naming the node attribute which storing the color of the border of the node.
430 Visible nodes missing this attribute will use the final node_color value.
431
432 edge_visible : string or bool, default "visible"
433 A string nameing the edge attribute which stores if an edge should be drawn.
434 If `True`, all edges will be drawn while if `False` no edges will be visible.
435 If incomplete, edges missing this attribute will be shown by default. Values
436 of this attribute are expected to be booleans.
437
438 edge_width : string or int, default "width"
439 A string nameing the edge attribute which stores the width of each edge.
440 Visible edges without this attribute will use a default width of 1.0.
441
442 edge_color : string or color, default "color"
443 A string nameing the edge attribute which stores of color of each edge.
444 Visible edges without this attribute will be drawn black. Each color can be
445 a string or rgb (or rgba) tuple of floats from 0.0 to 1.0.
446
447 edge_label : string, default "label"
448 A string naming the edge attribute which stores the label of each edge.
449 The values of this attribute can be a string, number or False or None. In
450 the latter two cases, no edge label is displayed.
451
452 If a dict is specified, these keys are read to further control the label:
453
454 * label : The text of the label, or the name of an edge attribute holding the label.
455 * size : Font size of the label; default: 12
456 * color : Font color of the label; default: black
457 * family : Font family of the label; default: "sans-serif"
458 * weight : Font weight of the label; default: "normal"
459 * alpha : Alpha value of the label; default: 1.0
460 * h_align : The horizontal alignment of the label.
461 one of "left", "center", "right"; default: "center"
462 * v_align : The vertical alignment of the label.
463 one of "top", "center", "bottom"; default: "center"
464 * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`.
465 * rotate : Whether to rotate labels to lie parallel to the edge, default: True.
466 * pos : A float showing how far along the edge to put the label; default: 0.5.
467
468 edge_style : string, default "style"
469 A string naming the edge attribute which stores the style of each edge.
470 Visible edges without this attribute will be drawn solid. Values of this
471 attribute can be line styles, e.g. '-', '--', '-.' or ':' or words like 'solid'
472 or 'dashed'. If no edge in the graph has this attribute and it is a non-default
473 value, assume that it describes the edge style for all edges in the graph.
474
475 edge_alpha : string or float, default "alpha"
476 A string naming the edge attribute which stores the alpha value of each edge.
477 Visible edges without this attribute will use an alpha value of 1.0.
478
479 edge_arrowstyle : string, default "arrowstyle"
480 A string naming the edge attribute which stores the type of arrowhead to use for
481 each edge. Visible edges without this attribute use ``"-"`` for undirected graphs
482 and ``"-|>"`` for directed graphs.
483
484 See `matplotlib.patches.ArrowStyle` for more options
485
486 edge_arrowsize : string or int, default "arrowsize"
487 A string naming the edge attribute which stores the size of the arrowhead for each
488 edge. Visible edges without this attribute will use a default value of 10.
489
490 edge_curvature : string, default "curvature"
491 A string naming the edge attribute storing the curvature and connection style
492 of each edge. Visible edges without this attribute will use "arc3" as a default
493 value, resulting an a straight line between the two nodes. Curvature can be given
494 as 'arc3,rad=0.2' to specify both the style and radius of curvature.
495
496 Please see `matplotlib.patches.ConnectionStyle` and
497 `matplotlib.patches.FancyArrowPatch` for more information.
498
499 edge_source_margin : string or int, default "source_margin"
500 A string naming the edge attribute which stores the minimum margin (gap) between
501 the source node and the start of the edge. Visible edges without this attribute
502 will use a default value of 0.
503
504 edge_target_margin : string or int, default "target_margin"
505 A string naming the edge attribute which stores the minimumm margin (gap) between
506 the target node and the end of the edge. Visible edges without this attribute
507 will use a default value of 0.
508
509 hide_ticks : bool, default True
510 Weather to remove the ticks from the axes of the matplotlib object.
511
512 Raises
513 ------
514 NetworkXError
515 If a node or edge is missing a required parameter such as `pos` or
516 if `display` receives an argument not listed above.
517
518 ValueError
519 If a node or edge has an invalid color format, i.e. not a color string,
520 rgb tuple or rgba tuple.
521
522 Returns
523 -------
524 The input graph. This is potentially useful for dispatching visualization
525 functions.
526 """
527 from collections import Counter
528
529 import matplotlib as mpl
530 import matplotlib.pyplot as plt
531 import numpy as np
532
533 defaults = {
534 "node_pos": None,
535 "node_visible": True,
536 "node_color": "#1f78b4",
537 "node_size": 300,
538 "node_label": {
539 "size": 12,
540 "color": "#000000",
541 "family": "sans-serif",
542 "weight": "normal",
543 "alpha": 1.0,
544 "h_align": "center",
545 "v_align": "center",
546 "bbox": None,
547 },
548 "node_shape": "o",
549 "node_alpha": 1.0,
550 "node_border_width": 1.0,
551 "node_border_color": "face",
552 "edge_visible": True,
553 "edge_width": 1.0,
554 "edge_color": "#000000",
555 "edge_label": {
556 "size": 12,
557 "color": "#000000",
558 "family": "sans-serif",
559 "weight": "normal",
560 "alpha": 1.0,
561 "bbox": {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)},
562 "h_align": "center",
563 "v_align": "center",
564 "pos": 0.5,
565 "rotate": True,
566 },
567 "edge_style": "-",
568 "edge_alpha": 1.0,
569 "edge_arrowstyle": "-|>" if G.is_directed() else "-",
570 "edge_arrowsize": 10 if G.is_directed() else 0,
571 "edge_curvature": "arc3",
572 "edge_source_margin": 0,
573 "edge_target_margin": 0,
574 "hide_ticks": True,
575 }
576
577 # Check arguments
578 for kwarg in kwargs:
579 if kwarg not in defaults:
580 raise nx.NetworkXError(
581 f"Unrecognized visualization keyword argument: {kwarg}"
582 )
583
584 if canvas is None:
585 canvas = plt.gca()
586
587 if kwargs.get("hide_ticks", defaults["hide_ticks"]):
588 canvas.tick_params(
589 axis="both",
590 which="both",
591 bottom=False,
592 left=False,
593 labelbottom=False,
594 labelleft=False,
595 )
596
597 ### Helper methods and classes
598
599 def node_property_sequence(seq, attr):
600 """Return a list of attribute values for `seq`, using a default if needed"""
601
602 # All node attribute parameters start with "node_"
603 param_name = f"node_{attr}"
604 default = defaults[param_name]
605 attr = kwargs.get(param_name, attr)
606
607 if default is None:
608 # raise instead of using non-existant default value
609 for n in seq:
610 if attr not in node_subgraph.nodes[n]:
611 raise nx.NetworkXError(f"Attribute '{attr}' missing for node {n}")
612
613 # If `attr` is not a graph attr and was explicitly passed as an argument
614 # it must be a user-default value. Allow attr=None to tell draw to skip
615 # attributes which are on the graph
616 if (
617 attr is not None
618 and nx.get_node_attributes(node_subgraph, attr) == {}
619 and any(attr == v for k, v in kwargs.items() if "node" in k)
620 ):
621 return [attr for _ in seq]
622
623 return [node_subgraph.nodes[n].get(attr, default) for n in seq]
624
625 def compute_colors(color, alpha):
626 if isinstance(color, str):
627 rgba = mpl.colors.colorConverter.to_rgba(color)
628 # Using a non-default alpha value overrides any alpha value in the color
629 if alpha != defaults["node_alpha"]:
630 return (rgba[0], rgba[1], rgba[2], alpha)
631 return rgba
632
633 if isinstance(color, tuple) and len(color) == 3:
634 return (color[0], color[1], color[2], alpha)
635
636 if isinstance(color, tuple) and len(color) == 4:
637 return color
638
639 raise ValueError(f"Invalid format for color: {color}")
640
641 # Find which edges can be plotted as a line collection
642 #
643 # Non-default values for these attributes require fancy arrow patches:
644 # - any arrow style (including the default -|> for directed graphs)
645 # - arrow size (by extension of style)
646 # - connection style
647 # - min_source_margin
648 # - min_target_margin
649
650 def collection_compatible(e):
651 return (
652 get_edge_attr(e, "arrowstyle") == "-"
653 and get_edge_attr(e, "curvature") == "arc3"
654 and get_edge_attr(e, "source_margin") == 0
655 and get_edge_attr(e, "target_margin") == 0
656 # Self-loops will use fancy arrow patches
657 and e[0] != e[1]
658 )
659
660 def edge_property_sequence(seq, attr):
661 """Return a list of attribute values for `seq`, using a default if needed"""
662
663 param_name = f"edge_{attr}"
664 default = defaults[param_name]
665 attr = kwargs.get(param_name, attr)
666
667 if default is None:
668 # raise instead of using non-existant default value
669 for e in seq:
670 if attr not in edge_subgraph.edges[e]:
671 raise nx.NetworkXError(f"Attribute '{attr}' missing for edge {e}")
672
673 if (
674 attr is not None
675 and nx.get_edge_attributes(edge_subgraph, attr) == {}
676 and any(attr == v for k, v in kwargs.items() if "edge" in k)
677 ):
678 return [attr for _ in seq]
679
680 return [edge_subgraph.edges[e].get(attr, default) for e in seq]
681
682 def get_edge_attr(e, attr):
683 """Return the final edge attribute value, using default if not None"""
684
685 param_name = f"edge_{attr}"
686 default = defaults[param_name]
687 attr = kwargs.get(param_name, attr)
688
689 if default is None and attr not in edge_subgraph.edges[e]:
690 raise nx.NetworkXError(f"Attribute '{attr}' missing from edge {e}")
691
692 if (
693 attr is not None
694 and nx.get_edge_attributes(edge_subgraph, attr) == {}
695 and attr in kwargs.values()
696 ):
697 return attr
698
699 return edge_subgraph.edges[e].get(attr, default)
700
701 def get_node_attr(n, attr, use_edge_subgraph=True):
702 """Return the final node attribute value, using default if not None"""
703 subgraph = edge_subgraph if use_edge_subgraph else node_subgraph
704
705 param_name = f"node_{attr}"
706 default = defaults[param_name]
707 attr = kwargs.get(param_name, attr)
708
709 if default is None and attr not in subgraph.nodes[n]:
710 raise nx.NetworkXError(f"Attribute '{attr}' missing from node {n}")
711
712 if (
713 attr is not None
714 and nx.get_node_attributes(subgraph, attr) == {}
715 and attr in kwargs.values()
716 ):
717 return attr
718
719 return subgraph.nodes[n].get(attr, default)
720
721 # Taken from ConnectionStyleFactory
722 def self_loop(edge_index, node_size):
723 def self_loop_connection(posA, posB, *args, **kwargs):
724 if not np.all(posA == posB):
725 raise nx.NetworkXError(
726 "`self_loop` connection style method"
727 "is only to be used for self-loops"
728 )
729 # this is called with _screen space_ values
730 # so convert back to data space
731 data_loc = canvas.transData.inverted().transform(posA)
732 # Scale self loop based on the size of the base node
733 # Size of nodes are given in points ** 2 and each point is 1/72 of an inch
734 v_shift = np.sqrt(node_size) / 72
735 h_shift = v_shift * 0.5
736 # put the top of the loop first so arrow is not hidden by node
737 path = np.asarray(
738 [
739 # 1
740 [0, v_shift],
741 # 4 4 4
742 [h_shift, v_shift],
743 [h_shift, 0],
744 [0, 0],
745 # 4 4 4
746 [-h_shift, 0],
747 [-h_shift, v_shift],
748 [0, v_shift],
749 ]
750 )
751 # Rotate self loop 90 deg. if more than 1
752 # This will allow for maximum of 4 visible self loops
753 if edge_index % 4:
754 x, y = path.T
755 for _ in range(edge_index % 4):
756 x, y = y, -x
757 path = np.array([x, y]).T
758 return mpl.path.Path(
759 canvas.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4]
760 )
761
762 return self_loop_connection
763
764 def to_marker_edge(size, marker):
765 if marker in "s^>v<d":
766 return np.sqrt(2 * size) / 2
767 else:
768 return np.sqrt(size) / 2
769
770 def build_fancy_arrow(e):
771 source_margin = to_marker_edge(
772 get_node_attr(e[0], "size"),
773 get_node_attr(e[0], "shape"),
774 )
775 source_margin = max(
776 source_margin,
777 get_edge_attr(e, "source_margin"),
778 )
779
780 target_margin = to_marker_edge(
781 get_node_attr(e[1], "size"),
782 get_node_attr(e[1], "shape"),
783 )
784 target_margin = max(
785 target_margin,
786 get_edge_attr(e, "target_margin"),
787 )
788 return mpl.patches.FancyArrowPatch(
789 edge_subgraph.nodes[e[0]][pos],
790 edge_subgraph.nodes[e[1]][pos],
791 arrowstyle=get_edge_attr(e, "arrowstyle"),
792 connectionstyle=(
793 get_edge_attr(e, "curvature")
794 if e[0] != e[1]
795 else self_loop(
796 0 if len(e) == 2 else e[2] % 4,
797 get_node_attr(e[0], "size"),
798 )
799 ),
800 color=get_edge_attr(e, "color"),
801 linestyle=get_edge_attr(e, "style"),
802 linewidth=get_edge_attr(e, "width"),
803 mutation_scale=get_edge_attr(e, "arrowsize"),
804 shrinkA=source_margin,
805 shrinkB=source_margin,
806 zorder=1,
807 )
808
809 class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text):
810 pass
811
812 ### Draw the nodes first
813 node_visible = kwargs.get("node_visible", "visible")
814 if isinstance(node_visible, bool):
815 if node_visible:
816 visible_nodes = G.nodes()
817 else:
818 visible_nodes = []
819 else:
820 visible_nodes = [
821 n for n, v in nx.get_node_attributes(G, node_visible, True).items() if v
822 ]
823
824 node_subgraph = G.subgraph(visible_nodes)
825
826 # Ignore the default dict value since that's for default values to use, not
827 # default attribute name
828 pos = kwargs.get("node_pos", "pos")
829
830 default_display_pos_attr = "display's position attribute name"
831 if callable(pos):
832 nx.set_node_attributes(
833 node_subgraph, pos(node_subgraph), default_display_pos_attr
834 )
835 pos = default_display_pos_attr
836 kwargs["node_pos"] = default_display_pos_attr
837 elif nx.get_node_attributes(G, pos) == {}:
838 nx.set_node_attributes(
839 node_subgraph, nx.spring_layout(node_subgraph), default_display_pos_attr
840 )
841 pos = default_display_pos_attr
842 kwargs["node_pos"] = default_display_pos_attr
843
844 # Each shape requires a new scatter object since they can't have different
845 # shapes.
846 if len(visible_nodes) > 0:
847 node_shape = kwargs.get("node_shape", "shape")
848 for shape in Counter(
849 nx.get_node_attributes(
850 node_subgraph, node_shape, defaults["node_shape"]
851 ).values()
852 ):
853 # Filter position just on this shape.
854 nodes_with_shape = [
855 n
856 for n, s in node_subgraph.nodes(data=node_shape)
857 if s == shape or (s is None and shape == defaults["node_shape"])
858 ]
859 # There are two property sequences to create before hand.
860 # 1. position, since it is used for x and y parameters to scatter
861 # 2. edgecolor, since the spaeical 'face' parameter value can only be
862 # be passed in as the sole string, not part of a list of strings.
863 position = np.asarray(node_property_sequence(nodes_with_shape, "pos"))
864 color = np.asarray(
865 [
866 compute_colors(c, a)
867 for c, a in zip(
868 node_property_sequence(nodes_with_shape, "color"),
869 node_property_sequence(nodes_with_shape, "alpha"),
870 )
871 ]
872 )
873 border_color = np.asarray(
874 [
875 (
876 c
877 if (
878 c := get_node_attr(
879 n,
880 "border_color",
881 False,
882 )
883 )
884 != "face"
885 else color[i]
886 )
887 for i, n in enumerate(nodes_with_shape)
888 ]
889 )
890 canvas.scatter(
891 position[:, 0],
892 position[:, 1],
893 s=node_property_sequence(nodes_with_shape, "size"),
894 c=color,
895 marker=shape,
896 linewidths=node_property_sequence(nodes_with_shape, "border_width"),
897 edgecolors=border_color,
898 zorder=2,
899 )
900
901 ### Draw node labels
902 node_label = kwargs.get("node_label", "label")
903 # Plot labels if node_label is not None and not False
904 if node_label is not None and node_label is not False:
905 default_dict = {}
906 if isinstance(node_label, dict):
907 default_dict = node_label
908 node_label = None
909
910 for n, lbl in node_subgraph.nodes(data=node_label):
911 if lbl is False:
912 continue
913
914 # We work with label dicts down here...
915 if not isinstance(lbl, dict):
916 lbl = {"label": lbl if lbl is not None else n}
917
918 lbl_text = lbl.get("label", n)
919 if not isinstance(lbl_text, str):
920 lbl_text = str(lbl_text)
921
922 lbl.update(default_dict)
923 x, y = node_subgraph.nodes[n][pos]
924 canvas.text(
925 x,
926 y,
927 lbl_text,
928 size=lbl.get("size", defaults["node_label"]["size"]),
929 color=lbl.get("color", defaults["node_label"]["color"]),
930 family=lbl.get("family", defaults["node_label"]["family"]),
931 weight=lbl.get("weight", defaults["node_label"]["weight"]),
932 horizontalalignment=lbl.get(
933 "h_align", defaults["node_label"]["h_align"]
934 ),
935 verticalalignment=lbl.get("v_align", defaults["node_label"]["v_align"]),
936 transform=canvas.transData,
937 bbox=lbl.get("bbox", defaults["node_label"]["bbox"]),
938 )
939
940 ### Draw edges
941
942 edge_visible = kwargs.get("edge_visible", "visible")
943 if isinstance(edge_visible, bool):
944 if edge_visible:
945 visible_edges = G.edges()
946 else:
947 visible_edges = []
948 else:
949 visible_edges = [
950 e for e, v in nx.get_edge_attributes(G, edge_visible, True).items() if v
951 ]
952
953 edge_subgraph = G.edge_subgraph(visible_edges)
954 nx.set_node_attributes(
955 edge_subgraph, nx.get_node_attributes(node_subgraph, pos), name=pos
956 )
957
958 collection_edges = (
959 [e for e in edge_subgraph.edges(keys=True) if collection_compatible(e)]
960 if edge_subgraph.is_multigraph()
961 else [e for e in edge_subgraph.edges() if collection_compatible(e)]
962 )
963 non_collection_edges = (
964 [e for e in edge_subgraph.edges(keys=True) if not collection_compatible(e)]
965 if edge_subgraph.is_multigraph()
966 else [e for e in edge_subgraph.edges() if not collection_compatible(e)]
967 )
968 edge_position = np.asarray(
969 [
970 (
971 get_node_attr(u, "pos", use_edge_subgraph=True),
972 get_node_attr(v, "pos", use_edge_subgraph=True),
973 )
974 for u, v, *_ in collection_edges
975 ]
976 )
977
978 # Only plot a line collection if needed
979 if len(collection_edges) > 0:
980 edge_collection = mpl.collections.LineCollection(
981 edge_position,
982 colors=edge_property_sequence(collection_edges, "color"),
983 linewidths=edge_property_sequence(collection_edges, "width"),
984 linestyle=edge_property_sequence(collection_edges, "style"),
985 alpha=edge_property_sequence(collection_edges, "alpha"),
986 antialiaseds=(1,),
987 zorder=1,
988 )
989 canvas.add_collection(edge_collection)
990
991 fancy_arrows = {}
992 if len(non_collection_edges) > 0:
993 for e in non_collection_edges:
994 # Cache results for use in edge labels
995 fancy_arrows[e] = build_fancy_arrow(e)
996 canvas.add_patch(fancy_arrows[e])
997
998 ### Draw edge labels
999 edge_label = kwargs.get("edge_label", "label")
1000 default_dict = {}
1001 if isinstance(edge_label, dict):
1002 default_dict = edge_label
1003 # Restore the default label attribute key of 'label'
1004 edge_label = "label"
1005
1006 # Handle multigraphs
1007 edge_label_data = (
1008 edge_subgraph.edges(data=edge_label, keys=True)
1009 if edge_subgraph.is_multigraph()
1010 else edge_subgraph.edges(data=edge_label)
1011 )
1012 if edge_label is not None and edge_label is not False:
1013 for *e, lbl in edge_label_data:
1014 e = tuple(e)
1015 # I'm not sure how I want to handle None here... For now it means no label
1016 if lbl is False or lbl is None:
1017 continue
1018
1019 if not isinstance(lbl, dict):
1020 lbl = {"label": lbl}
1021
1022 lbl.update(default_dict)
1023 lbl_text = lbl.get("label")
1024 if not isinstance(lbl_text, str):
1025 lbl_text = str(lbl_text)
1026
1027 # In the old code, every non-self-loop is placed via a fancy arrow patch
1028 # Only compute a new fancy arrow if needed by caching the results from
1029 # edge placement.
1030 try:
1031 arrow = fancy_arrows[e]
1032 except KeyError:
1033 arrow = build_fancy_arrow(e)
1034
1035 if e[0] == e[1]:
1036 # Taken directly from draw_networkx_edge_labels
1037 connectionstyle_obj = arrow.get_connectionstyle()
1038 posA = canvas.transData.transform(edge_subgraph.nodes[e[0]][pos])
1039 path_disp = connectionstyle_obj(posA, posA)
1040 path_data = canvas.transData.inverted().transform_path(path_disp)
1041 x, y = path_data.vertices[0]
1042 canvas.text(
1043 x,
1044 y,
1045 lbl_text,
1046 size=lbl.get("size", defaults["edge_label"]["size"]),
1047 color=lbl.get("color", defaults["edge_label"]["color"]),
1048 family=lbl.get("family", defaults["edge_label"]["family"]),
1049 weight=lbl.get("weight", defaults["edge_label"]["weight"]),
1050 alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]),
1051 horizontalalignment=lbl.get(
1052 "h_align", defaults["edge_label"]["h_align"]
1053 ),
1054 verticalalignment=lbl.get(
1055 "v_align", defaults["edge_label"]["v_align"]
1056 ),
1057 rotation=0,
1058 transform=canvas.transData,
1059 bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]),
1060 zorder=1,
1061 )
1062 continue
1063
1064 CurvedArrowText(
1065 arrow,
1066 lbl_text,
1067 size=lbl.get("size", defaults["edge_label"]["size"]),
1068 color=lbl.get("color", defaults["edge_label"]["color"]),
1069 family=lbl.get("family", defaults["edge_label"]["family"]),
1070 weight=lbl.get("weight", defaults["edge_label"]["weight"]),
1071 alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]),
1072 bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]),
1073 horizontalalignment=lbl.get(
1074 "h_align", defaults["edge_label"]["h_align"]
1075 ),
1076 verticalalignment=lbl.get("v_align", defaults["edge_label"]["v_align"]),
1077 label_pos=lbl.get("pos", defaults["edge_label"]["pos"]),
1078 labels_horizontal=lbl.get("rotate", defaults["edge_label"]["rotate"]),
1079 transform=canvas.transData,
1080 zorder=1,
1081 ax=canvas,
1082 )
1083
1084 # If we had to add an attribute, remove it here
1085 if pos == default_display_pos_attr:
1086 nx.remove_node_attributes(G, default_display_pos_attr)
1087
1088 return G
1089
1090
def draw(G, pos=None, ax=None, **kwds):
    """Draw the graph G with Matplotlib.

    Produce a plain picture of the graph: no node labels (unless ``labels``
    is passed), no edge labels and no axis decorations. When no Axes is given
    and the current figure has none, a new Axes covering the whole figure is
    created. See draw_networkx() for more full-featured drawing that allows
    title, axis labels etc.

    Parameters
    ----------
    G : graph
       A networkx graph

    pos : dictionary, optional
       A dictionary with nodes as keys and positions as values.
       If not specified a spring layout positioning will be computed.
       See :py:mod:`networkx.drawing.layout` for functions that
       compute node positions.

    ax : Matplotlib Axes object, optional
       Draw the graph in specified Matplotlib axes.

    kwds : optional keywords
       See networkx.draw_networkx() for a description of optional keywords.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    See Also
    --------
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels

    Notes
    -----
    This function has the same name as pylab.draw and pyplot.draw
    so beware when using `from networkx import *`, since you might
    overwrite the pylab.draw function.

    With pyplot use

    >>> import matplotlib.pyplot as plt
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)  # networkx draw()
    >>> plt.draw()  # pyplot draw()

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html
    """

    import matplotlib.pyplot as plt

    figure = plt.gcf() if ax is None else ax.get_figure()
    figure.set_facecolor("w")

    if ax is None:
        # Reuse the figure's current Axes when it has one; otherwise add an
        # Axes that fills the entire figure area.
        ax = figure.gca() if figure.axes else figure.add_axes((0, 0, 1, 1))

    # Node labels are drawn only if the caller passed `labels`, unless
    # `with_labels` was specified explicitly.
    kwds.setdefault("with_labels", "labels" in kwds)

    draw_networkx(G, pos=pos, ax=ax, **kwds)
    ax.set_axis_off()
    plt.draw_if_interactive()
1168
1169
def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
    r"""Draw the graph G using Matplotlib.

    Draw nodes, edges and (optionally) node labels with options for node
    positions, labeling, titles, and many other drawing features. Extra
    keywords are routed to draw_networkx_nodes, draw_networkx_edges and
    draw_networkx_labels according to which of them accepts each keyword.
    See draw() for simple drawing without labels or axes.

    Parameters
    ----------
    G : graph
       A networkx graph

    pos : dictionary, optional
       A dictionary with nodes as keys and positions as values.
       If not specified a spring layout positioning will be computed.
       See :py:mod:`networkx.drawing.layout` for functions that
       compute node positions.

    arrows : bool or None, optional (default=None)
        If `None`, directed graphs draw arrowheads with
        `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
        via `~matplotlib.collections.LineCollection` for speed.
        If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
        If `False`, draw edges using LineCollection (linear and fast).
        For directed graphs, if True draw arrowheads.
        Note: Arrows will be the same color as edges.

    arrowstyle : str (default='-\|>' for directed graphs)
        For directed graphs, choose the style of the arrowheads.
        For undirected graphs default to '-'

        See `matplotlib.patches.ArrowStyle` for more options.

    arrowsize : int or list (default=10)
        For directed graphs, choose the size of the arrow head's length and
        width. A list of values can be passed in to assign a different size for arrow head's length and width.
        See `matplotlib.patches.FancyArrowPatch` for attribute `mutation_scale`
        for more info.

    with_labels : bool (default=True)
        Set to True to draw labels on the nodes.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    nodelist : list (default=list(G))
        Draw only specified nodes

    edgelist : list (default=list(G.edges()))
        Draw only specified edges

    node_size : scalar or array (default=300)
        Size of nodes. If an array is specified it must be the
        same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape : string (default='o')
        The shape of the node. Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'.

    alpha : float or None (default=None)
        The node and edge transparency

    cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of nodes

    vmin,vmax : float, optional
        Minimum and maximum for node colormap scaling

    linewidths : scalar or sequence (default=1.0)
        Line width of symbol border

    width : float or array of floats (default=1.0)
        Line width of edges

    edge_color : color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edgelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    edge_cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of edges

    edge_vmin,edge_vmax : floats, optional
        Minimum and maximum for edge colormap scaling

    style : string (default=solid line)
        Edge line style e.g.: '-', '--', '-.', ':'
        or words like 'solid' or 'dashed'.
        (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)

    labels : dictionary (default=None)
        Node labels in a dictionary of text labels keyed by node

    font_size : int (default=12 for nodes, 10 for edges)
        Font size for text labels

    font_color : color (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    label : string, optional
        Label for graph legend

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    kwds : optional keywords
       See networkx.draw_networkx_nodes(), networkx.draw_networkx_edges(), and
       networkx.draw_networkx_labels() for a description of optional keywords.

    Raises
    ------
    ValueError
        If a keyword is accepted by none of draw_networkx_nodes,
        draw_networkx_edges or draw_networkx_labels.

    Notes
    -----
    For directed graphs, arrows  are drawn at the head end.  Arrows can be
    turned off with keyword arrows=False.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    >>> import matplotlib.pyplot as plt
    >>> limits = plt.axis("off")  # turn off axis

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    from inspect import signature

    import matplotlib.pyplot as plt

    def _param_names(func):
        # Set-like view of the parameter names `func` accepts.
        return signature(func).parameters.keys()

    node_params = _param_names(draw_networkx_nodes)
    edge_params = _param_names(draw_networkx_edges)
    label_params = _param_names(draw_networkx_labels)

    # Anything any of the three helpers accepts is fine, except the
    # arguments this function already takes explicitly.
    own_params = {"G", "pos", "arrows", "with_labels"}
    accepted = (node_params | edge_params | label_params) - own_params

    unknown = [name for name in kwds if name not in accepted]
    if unknown:
        raise ValueError(f"Received invalid argument(s): {', '.join(unknown)}")

    def _routed_to(params):
        # Subset of the caller's keywords that a given helper understands.
        return {name: value for name, value in kwds.items() if name in params}

    if pos is None:
        pos = nx.drawing.spring_layout(G)  # default to spring layout

    draw_networkx_nodes(G, pos, **_routed_to(node_params))
    draw_networkx_edges(G, pos, arrows=arrows, **_routed_to(edge_params))
    if with_labels:
        draw_networkx_labels(G, pos, **_routed_to(label_params))
    plt.draw_if_interactive()
1357
1358
def draw_networkx_nodes(
    G,
    pos,
    nodelist=None,
    node_size=300,
    node_color="#1f78b4",
    node_shape="o",
    alpha=None,
    cmap=None,
    vmin=None,
    vmax=None,
    ax=None,
    linewidths=None,
    edgecolors=None,
    label=None,
    margins=None,
    hide_ticks=True,
):
    """Draw the nodes of the graph G.

    This draws only the nodes of the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    nodelist : list (default list(G))
        Draw only specified nodes

    node_size : scalar or array (default=300)
        Size of nodes. If an array it must be the same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape : string or list/array of strings (default='o')
        The shape of the node. Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'. A list or array assigns one marker to
        each node of `nodelist` and must have the same length. Nodes are then
        drawn with one scatter call per distinct marker. Per-node `node_size`,
        `node_color`, `linewidths` and `edgecolors` sequences are split to
        match, and numeric colors share one color scale across markers.

    alpha : float or array of floats (default=None)
        The node transparency.  This can be a single alpha value,
        in which case it will be applied to all the nodes of color. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    cmap : Matplotlib colormap (default=None)
        Colormap for mapping intensities of nodes

    vmin,vmax : floats or None (default=None)
        Minimum and maximum for node colormap scaling

    linewidths : [None | scalar | sequence] (default=1.0)
        Line width of symbol border

    edgecolors : [None | scalar | sequence] (default = node_color)
        Colors of node borders. Can be a single color or a sequence of colors with the
        same length as nodelist. Color can be string or rgb (or rgba) tuple of floats
        from 0-1. If numeric values are specified they will be mapped to colors
        using the cmap and vmin,vmax parameters. See `~matplotlib.pyplot.scatter` for more details.

    label : [None | string]
        Label for legend

    margins : float or 2-tuple, optional
        Sets the padding for axis autoscaling. Increase margin to prevent
        clipping for nodes that are near the edges of an image. Values should
        be in the range ``[0, 1]``. See :meth:`matplotlib.axes.Axes.margins`
        for details. The default is `None`, which uses the Matplotlib default.

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Returns
    -------
    matplotlib.collections.PathCollection
        `PathCollection` of the nodes. When `node_shape` holds several
        distinct markers, one collection per marker is added to `ax` and the
        collection for the last marker (in sorted order) is returned.

    Raises
    ------
    NetworkXError
        If a node in `nodelist` has no entry in `pos`.
    ValueError
        If `node_shape` is a list/array whose length differs from `nodelist`.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    from collections.abc import Iterable

    import matplotlib as mpl
    import matplotlib.collections  # call as mpl.collections
    import matplotlib.pyplot as plt
    import numpy as np

    if ax is None:
        ax = plt.gca()

    if nodelist is None:
        nodelist = list(G)

    if len(nodelist) == 0:  # empty nodelist, no drawing
        return mpl.collections.PathCollection(None)

    try:
        xy = np.asarray([pos[v] for v in nodelist])
    except KeyError as err:
        raise nx.NetworkXError(f"Node {err} has no position.") from err

    if isinstance(alpha, Iterable):
        node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax)
        alpha = None

    num_nodes = len(nodelist)

    def _take(value, indices):
        """Restrict a per-node sequence `value` to the nodes at `indices`.

        Scalars, strings and sequences whose length differs from the number
        of nodes apply to every node and are returned unchanged.
        """
        if isinstance(value, str) or not np.iterable(value):
            return value
        try:
            if len(value) != num_nodes:
                return value
        except TypeError:  # iterable without a length (e.g. a generator)
            return value
        if isinstance(value, np.ndarray):
            return value[indices]
        return [value[j] for j in indices]

    def _shared_color_limits():
        """vmin/vmax spanning all numeric node colors.

        Without this, each per-marker scatter call would normalise its own
        subset of values, giving inconsistent colors across markers.
        """
        try:
            values = np.asarray(node_color)
        except ValueError:  # ragged mix of color specs; not numeric data
            return vmin, vmax
        if values.ndim != 1 or len(values) != num_nodes or values.dtype.kind not in "iuf":
            return vmin, vmax
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return vmin, vmax
        lo = finite.min() if vmin is None else vmin
        hi = finite.max() if vmax is None else vmax
        return lo, hi

    # Group nodes by marker. A single (non-list) marker is passed straight to
    # scatter so that any matplotlib marker spec (e.g. a tuple) works.
    # indices=None means "all nodes, arguments unchanged".
    if isinstance(node_shape, (list, np.ndarray)):
        shape_array = np.asarray(node_shape)
        if shape_array.shape != (num_nodes,):
            raise ValueError("node_shape should have the same length as nodelist")
        unique_shapes = np.unique(shape_array)
        if len(unique_shapes) == 1:
            groups = [(unique_shapes[0], None)]
        else:
            groups = [
                (shape, np.flatnonzero(shape_array == shape)) for shape in unique_shapes
            ]
    else:
        groups = [(node_shape, None)]

    if len(groups) > 1:
        vmin, vmax = _shared_color_limits()

    for shape, indices in groups:
        if indices is None:
            xs, ys = xy[:, 0], xy[:, 1]
            sizes, colors = node_size, node_color
            widths, borders = linewidths, edgecolors
        else:
            xs, ys = xy[indices, 0], xy[indices, 1]
            sizes = _take(node_size, indices)
            colors = _take(node_color, indices)
            widths = _take(linewidths, indices)
            borders = _take(edgecolors, indices)
        node_collection = ax.scatter(
            xs,
            ys,
            s=sizes,
            c=colors,
            marker=shape,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            alpha=alpha,
            linewidths=widths,
            edgecolors=borders,
            label=label,
        )
        # Every marker group (not just the last) must draw above the edges,
        # which the edge-drawing code places at zorder=1.
        node_collection.set_zorder(2)

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    if margins is not None:
        if isinstance(margins, Iterable):
            ax.margins(*margins)
        else:
            ax.margins(margins)

    return node_collection
1529
1530
class FancyArrowFactory:
    """Draw arrows with `matplotlib.patches.FancyArrowPatch`.

    Calling an instance with an index ``i`` returns a new
    `~matplotlib.patches.FancyArrowPatch` for ``edge_pos[i]``. Per-edge options
    (colors, widths, line styles, arrow styles/sizes, margins) may be given as
    sequences. Per-node options (``node_size``, and ``node_shape`` when it is a
    list/array) are looked up by the edge endpoints' positions in ``nodelist``.
    """

    class ConnectionStyleFactory:
        """Produce connection styles for ordinary edges and for self-loops."""

        def __init__(self, connectionstyles, selfloop_height, ax=None):
            import matplotlib as mpl
            import matplotlib.patches  # call as mpl.patches
            import matplotlib.path  # call as mpl.path
            import numpy as np

            self.ax = ax
            self.mpl = mpl
            self.np = np
            self.base_connection_styles = [
                mpl.patches.ConnectionStyle(cs) for cs in connectionstyles
            ]
            self.n = len(self.base_connection_styles)
            self.selfloop_height = selfloop_height

        def curved(self, edge_index):
            """Return the base style for `edge_index`, cycling through them."""
            return self.base_connection_styles[edge_index % self.n]

        def self_loop(self, edge_index):
            """Return a connection callable that draws a loop beside a node.

            The loop is rotated by 90 degrees for each step of `edge_index`
            (mod 4), so up to four loops on one node are visually distinct.
            """

            def self_loop_connection(posA, posB, *args, **kwargs):
                if not self.np.all(posA == posB):
                    raise nx.NetworkXError(
                        "`self_loop` connection style method "
                        "is only to be used for self-loops"
                    )
                # this is called with _screen space_ values
                # so convert back to data space
                data_loc = self.ax.transData.inverted().transform(posA)
                v_shift = 0.1 * self.selfloop_height
                h_shift = v_shift * 0.5
                # put the top of the loop first so arrow is not hidden by node
                path = self.np.asarray(
                    [
                        # 1
                        [0, v_shift],
                        # 4 4 4
                        [h_shift, v_shift],
                        [h_shift, 0],
                        [0, 0],
                        # 4 4 4
                        [-h_shift, 0],
                        [-h_shift, v_shift],
                        [0, v_shift],
                    ]
                )
                # Rotate self loop 90 deg. if more than 1
                # This will allow for maximum of 4 visible self loops
                if edge_index % 4:
                    x, y = path.T
                    for _ in range(edge_index % 4):
                        x, y = y, -x
                    path = self.np.array([x, y]).T
                # Path codes: 1 = MOVETO, 4 = CURVE4 (cubic Bezier segments)
                return self.mpl.path.Path(
                    self.ax.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4]
                )

            return self_loop_connection

    def __init__(
        self,
        edge_pos,
        edgelist,
        nodelist,
        edge_indices,
        node_size,
        selfloop_height,
        connectionstyle="arc3",
        node_shape="o",
        arrowstyle="-",
        arrowsize=10,
        edge_color="k",
        alpha=None,
        linewidth=1.0,
        style="solid",
        min_source_margin=0,
        min_target_margin=0,
        ax=None,
    ):
        import matplotlib as mpl
        import matplotlib.patches  # call as mpl.patches
        import matplotlib.pyplot as plt
        import numpy as np

        if isinstance(connectionstyle, str):
            connectionstyle = [connectionstyle]
        elif np.iterable(connectionstyle):
            connectionstyle = list(connectionstyle)
        else:
            msg = "ConnectionStyleFactory arg `connectionstyle` must be str or iterable"
            raise nx.NetworkXError(msg)
        self.ax = ax
        self.mpl = mpl
        self.np = np
        self.edge_pos = edge_pos
        self.edgelist = edgelist
        self.nodelist = nodelist
        self.node_shape = node_shape
        self.min_source_margin = min_source_margin
        self.min_target_margin = min_target_margin
        self.edge_indices = edge_indices
        self.node_size = node_size
        self.connectionstyle_factory = self.ConnectionStyleFactory(
            connectionstyle, selfloop_height, ax
        )
        self.arrowstyle = arrowstyle
        self.arrowsize = arrowsize
        self.arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha)
        self.linewidth = linewidth
        self.style = style
        # node -> position in nodelist; built lazily on first per-node lookup
        self._node_index = None
        if isinstance(arrowsize, list) and len(arrowsize) != len(edge_pos):
            raise ValueError("arrowsize should have the same length as edgelist")

    def _node_position(self, node):
        """Return the index of the first occurrence of `node` in `nodelist`.

        The mapping is built once, so each lookup is O(1) instead of a linear
        ``nodelist.index`` scan per edge. A node missing from `nodelist` raises
        the same ValueError that ``list.index`` does.
        """
        index = getattr(self, "_node_index", None)
        if index is None:
            index = {}
            for position, member in enumerate(self.nodelist):
                index.setdefault(member, position)  # keep first occurrence
            self._node_index = index
        position = index.get(node)
        if position is None:
            position = self.nodelist.index(node)  # raises ValueError if absent
        return position

    def _node_value(self, value, node, per_node):
        """Return `value[node's position]` if `per_node`, else `value` itself."""
        return value[self._node_position(node)] if per_node else value

    def __call__(self, i):
        """Build the FancyArrowPatch for the i'th edge."""
        (x1, y1), (x2, y2) = self.edge_pos[i]
        shrink_source = 0  # space from source to tail
        shrink_target = 0  # space from  head to target
        if (
            self.np.iterable(self.min_source_margin)
            and not isinstance(self.min_source_margin, str)
            and not isinstance(self.min_source_margin, tuple)
        ):
            min_source_margin = self.min_source_margin[i]
        else:
            min_source_margin = self.min_source_margin

        if (
            self.np.iterable(self.min_target_margin)
            and not isinstance(self.min_target_margin, str)
            and not isinstance(self.min_target_margin, tuple)
        ):
            min_target_margin = self.min_target_margin[i]
        else:
            min_target_margin = self.min_target_margin

        sizes_per_node = self.np.iterable(self.node_size)
        # A list/array node_shape (as accepted by draw_networkx_nodes) gives
        # one marker per node of nodelist.
        shapes_per_node = isinstance(self.node_shape, (list, self.np.ndarray))
        if sizes_per_node or shapes_per_node:
            source, target = self.edgelist[i][:2]
            shrink_source = self.to_marker_edge(
                self._node_value(self.node_size, source, sizes_per_node),
                self._node_value(self.node_shape, source, shapes_per_node),
            )
            shrink_target = self.to_marker_edge(
                self._node_value(self.node_size, target, sizes_per_node),
                self._node_value(self.node_shape, target, shapes_per_node),
            )
        else:
            shrink_source = self.to_marker_edge(self.node_size, self.node_shape)
            shrink_target = shrink_source
        shrink_source = max(shrink_source, min_source_margin)
        shrink_target = max(shrink_target, min_target_margin)

        # scale factor of arrow head
        if isinstance(self.arrowsize, list):
            mutation_scale = self.arrowsize[i]
        else:
            mutation_scale = self.arrowsize

        if len(self.arrow_colors) > i:
            arrow_color = self.arrow_colors[i]
        elif len(self.arrow_colors) == 1:
            arrow_color = self.arrow_colors[0]
        else:  # Cycle through colors
            arrow_color = self.arrow_colors[i % len(self.arrow_colors)]

        if self.np.iterable(self.linewidth):
            if len(self.linewidth) > i:
                linewidth = self.linewidth[i]
            else:
                linewidth = self.linewidth[i % len(self.linewidth)]
        else:
            linewidth = self.linewidth

        if (
            self.np.iterable(self.style)
            and not isinstance(self.style, str)
            and not isinstance(self.style, tuple)
        ):
            if len(self.style) > i:
                linestyle = self.style[i]
            else:  # Cycle through styles
                linestyle = self.style[i % len(self.style)]
        else:
            linestyle = self.style

        if x1 == x2 and y1 == y2:
            connectionstyle = self.connectionstyle_factory.self_loop(
                self.edge_indices[i]
            )
        else:
            connectionstyle = self.connectionstyle_factory.curved(self.edge_indices[i])

        if (
            self.np.iterable(self.arrowstyle)
            and not isinstance(self.arrowstyle, str)
            and not isinstance(self.arrowstyle, tuple)
        ):
            arrowstyle = self.arrowstyle[i]
        else:
            arrowstyle = self.arrowstyle

        return self.mpl.patches.FancyArrowPatch(
            (x1, y1),
            (x2, y2),
            arrowstyle=arrowstyle,
            shrinkA=shrink_source,
            shrinkB=shrink_target,
            mutation_scale=mutation_scale,
            color=arrow_color,
            linewidth=linewidth,
            connectionstyle=connectionstyle,
            linestyle=linestyle,
            zorder=1,  # arrows go behind nodes
        )

    def to_marker_edge(self, marker_size, marker):
        """Distance to shrink an arrow end so it stops at a node's marker.

        String markers among 's^>v<d' get a factor sqrt(2) more room. Any
        other marker spec (including non-string specs such as tuples) uses
        the plain ``sqrt(marker_size) / 2``.
        """
        if isinstance(marker, str) and marker in "s^>v<d":  # `large` markers need extra space
            return self.np.sqrt(2 * marker_size) / 2
        return self.np.sqrt(marker_size) / 2
1748
1749
1750def draw_networkx_edges(
1751 G,
1752 pos,
1753 edgelist=None,
1754 width=1.0,
1755 edge_color="k",
1756 style="solid",
1757 alpha=None,
1758 arrowstyle=None,
1759 arrowsize=10,
1760 edge_cmap=None,
1761 edge_vmin=None,
1762 edge_vmax=None,
1763 ax=None,
1764 arrows=None,
1765 label=None,
1766 node_size=300,
1767 nodelist=None,
1768 node_shape="o",
1769 connectionstyle="arc3",
1770 min_source_margin=0,
1771 min_target_margin=0,
1772 hide_ticks=True,
1773):
1774 r"""Draw the edges of the graph G.
1775
1776 This draws only the edges of the graph G.
1777
1778 Parameters
1779 ----------
1780 G : graph
1781 A networkx graph
1782
1783 pos : dictionary
1784 A dictionary with nodes as keys and positions as values.
1785 Positions should be sequences of length 2.
1786
1787 edgelist : collection of edge tuples (default=G.edges())
1788 Draw only specified edges
1789
1790 width : float or array of floats (default=1.0)
1791 Line width of edges
1792
1793 edge_color : color or array of colors (default='k')
1794 Edge color. Can be a single color or a sequence of colors with the same
1795 length as edgelist. Color can be string or rgb (or rgba) tuple of
1796 floats from 0-1. If numeric values are specified they will be
1797 mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.
1798
1799 style : string or array of strings (default='solid')
1800 Edge line style e.g.: '-', '--', '-.', ':'
1801 or words like 'solid' or 'dashed'.
1802 Can be a single style or a sequence of styles with the same
1803 length as the edge list.
1804 If less styles than edges are given the styles will cycle.
1805 If more styles than edges are given the styles will be used sequentially
1806 and not be exhausted.
1807 Also, `(offset, onoffseq)` tuples can be used as style instead of a strings.
1808 (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)
1809
1810 alpha : float or array of floats (default=None)
1811 The edge transparency. This can be a single alpha value,
1812 in which case it will be applied to all specified edges. Otherwise,
1813 if it is an array, the elements of alpha will be applied to the colors
1814 in order (cycling through alpha multiple times if necessary).
1815
1816 edge_cmap : Matplotlib colormap, optional
1817 Colormap for mapping intensities of edges
1818
1819 edge_vmin,edge_vmax : floats, optional
1820 Minimum and maximum for edge colormap scaling
1821
1822 ax : Matplotlib Axes object, optional
1823 Draw the graph in the specified Matplotlib axes.
1824
1825 arrows : bool or None, optional (default=None)
1826 If `None`, directed graphs draw arrowheads with
1827 `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
1828 via `~matplotlib.collections.LineCollection` for speed.
1829 If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
1830 If `False`, draw edges using LineCollection (linear and fast).
1831
1832 Note: Arrowheads will be the same color as edges.
1833
1834 arrowstyle : str or list of strs (default='-\|>' for directed graphs)
1835 For directed graphs and `arrows==True` defaults to '-\|>',
1836 For undirected graphs default to '-'.
1837
1838 See `matplotlib.patches.ArrowStyle` for more options.
1839
1840 arrowsize : int or list of ints(default=10)
1841 For directed graphs, choose the size of the arrow head's length and
1842 width. See `matplotlib.patches.FancyArrowPatch` for attribute
1843 `mutation_scale` for more info.
1844
1845 connectionstyle : string or iterable of strings (default="arc3")
1846 Pass the connectionstyle parameter to create curved arc of rounding
1847 radius rad. For example, connectionstyle='arc3,rad=0.2'.
1848 See `matplotlib.patches.ConnectionStyle` and
1849 `matplotlib.patches.FancyArrowPatch` for more info.
1850 If Iterable, index indicates i'th edge key of MultiGraph
1851
1852 node_size : scalar or array (default=300)
1853 Size of nodes. Though the nodes are not drawn with this function, the
1854 node size is used in determining edge positioning.
1855
1856 nodelist : list, optional (default=G.nodes())
1857 This provides the node order for the `node_size` array (if it is an array).
1858
1859 node_shape : string (default='o')
1860 The marker used for nodes, used in determining edge positioning.
1861 Specification is as a `matplotlib.markers` marker, e.g. one of 'so^>v<dph8'.
1862
1863 label : None or string
1864 Label for legend
1865
1866 min_source_margin : int or list of ints (default=0)
1867 The minimum margin (gap) at the beginning of the edge at the source.
1868
1869 min_target_margin : int or list of ints (default=0)
1870 The minimum margin (gap) at the end of the edge at the target.
1871
1872 hide_ticks : bool, optional
1873 Hide ticks of axes. When `True` (the default), ticks and ticklabels
1874 are removed from the axes. To set ticks and tick labels to the pyplot default,
1875 use ``hide_ticks=False``.
1876
1877 Returns
1878 -------
1879 matplotlib.collections.LineCollection or a list of matplotlib.patches.FancyArrowPatch
1880 If ``arrows=True``, a list of FancyArrowPatches is returned.
1881 If ``arrows=False``, a LineCollection is returned.
1882 If ``arrows=None`` (the default), then a LineCollection is returned if
1883 `G` is undirected, otherwise returns a list of FancyArrowPatches.
1884
1885 Notes
1886 -----
1887 For directed graphs, arrows are drawn at the head end. Arrows can be
1888 turned off with keyword arrows=False or by passing an arrowstyle without
1889 an arrow on the end.
1890
1891 Be sure to include `node_size` as a keyword argument; arrows are
1892 drawn considering the size of nodes.
1893
1894 Self-loops are always drawn with `~matplotlib.patches.FancyArrowPatch`
1895 regardless of the value of `arrows` or whether `G` is directed.
1896 When ``arrows=False`` or ``arrows=None`` and `G` is undirected, the
1897 FancyArrowPatches corresponding to the self-loops are not explicitly
1898 returned. They should instead be accessed via the ``Axes.patches``
1899 attribute (see examples).
1900
1901 Examples
1902 --------
1903 >>> G = nx.dodecahedral_graph()
1904 >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
1905
1906 >>> G = nx.DiGraph()
1907 >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)])
1908 >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
1909 >>> alphas = [0.3, 0.4, 0.5]
1910 >>> for i, arc in enumerate(arcs): # change alpha values of arcs
1911 ... arc.set_alpha(alphas[i])
1912
1913 The FancyArrowPatches corresponding to self-loops are not always
1914 returned, but can always be accessed via the ``patches`` attribute of the
1915 `matplotlib.Axes` object.
1916
1917 >>> import matplotlib.pyplot as plt
1918 >>> fig, ax = plt.subplots()
1919 >>> G = nx.Graph([(0, 1), (0, 0)]) # Self-loop at node 0
1920 >>> edge_collection = nx.draw_networkx_edges(G, pos=nx.circular_layout(G), ax=ax)
1921 >>> self_loop_fap = ax.patches[0]
1922
1923 Also see the NetworkX drawing examples at
1924 https://networkx.org/documentation/latest/auto_examples/index.html
1925
1926 See Also
1927 --------
1928 draw
1929 draw_networkx
1930 draw_networkx_nodes
1931 draw_networkx_labels
1932 draw_networkx_edge_labels
1933
1934 """
1935 import warnings
1936
1937 import matplotlib as mpl
1938 import matplotlib.collections # call as mpl.collections
1939 import matplotlib.colors # call as mpl.colors
1940 import matplotlib.pyplot as plt
1941 import numpy as np
1942
1943 # The default behavior is to use LineCollection to draw edges for
1944 # undirected graphs (for performance reasons) and use FancyArrowPatches
1945 # for directed graphs.
1946 # The `arrows` keyword can be used to override the default behavior
1947 if arrows is None:
1948 use_linecollection = not (G.is_directed() or G.is_multigraph())
1949 else:
1950 if not isinstance(arrows, bool):
1951 raise TypeError("Argument `arrows` must be of type bool or None")
1952 use_linecollection = not arrows
1953
1954 if isinstance(connectionstyle, str):
1955 connectionstyle = [connectionstyle]
1956 elif np.iterable(connectionstyle):
1957 connectionstyle = list(connectionstyle)
1958 else:
1959 msg = "draw_networkx_edges arg `connectionstyle` must be str or iterable"
1960 raise nx.NetworkXError(msg)
1961
1962 # Some kwargs only apply to FancyArrowPatches. Warn users when they use
1963 # non-default values for these kwargs when LineCollection is being used
1964 # instead of silently ignoring the specified option
1965 if use_linecollection:
1966 msg = (
1967 "\n\nThe {0} keyword argument is not applicable when drawing edges\n"
1968 "with LineCollection.\n\n"
1969 "To make this warning go away, either specify `arrows=True` to\n"
1970 "force FancyArrowPatches or use the default values.\n"
1971 "Note that using FancyArrowPatches may be slow for large graphs.\n"
1972 )
1973 if arrowstyle is not None:
1974 warnings.warn(msg.format("arrowstyle"), category=UserWarning, stacklevel=2)
1975 if arrowsize != 10:
1976 warnings.warn(msg.format("arrowsize"), category=UserWarning, stacklevel=2)
1977 if min_source_margin != 0:
1978 warnings.warn(
1979 msg.format("min_source_margin"), category=UserWarning, stacklevel=2
1980 )
1981 if min_target_margin != 0:
1982 warnings.warn(
1983 msg.format("min_target_margin"), category=UserWarning, stacklevel=2
1984 )
1985 if any(cs != "arc3" for cs in connectionstyle):
1986 warnings.warn(
1987 msg.format("connectionstyle"), category=UserWarning, stacklevel=2
1988 )
1989
1990 # NOTE: Arrowstyle modification must occur after the warnings section
1991 if arrowstyle is None:
1992 arrowstyle = "-|>" if G.is_directed() else "-"
1993
1994 if ax is None:
1995 ax = plt.gca()
1996
1997 if edgelist is None:
1998 edgelist = list(G.edges) # (u, v, k) for multigraph (u, v) otherwise
1999
2000 if len(edgelist):
2001 if G.is_multigraph():
2002 key_count = collections.defaultdict(lambda: itertools.count(0))
2003 edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist]
2004 else:
2005 edge_indices = [0] * len(edgelist)
2006 else: # no edges!
2007 return []
2008
2009 if nodelist is None:
2010 nodelist = list(G.nodes())
2011
2012 # FancyArrowPatch handles color=None different from LineCollection
2013 if edge_color is None:
2014 edge_color = "k"
2015
2016 # set edge positions
2017 edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])
2018
2019 # Check if edge_color is an array of floats and map to edge_cmap.
2020 # This is the only case handled differently from matplotlib
2021 if (
2022 np.iterable(edge_color)
2023 and (len(edge_color) == len(edge_pos))
2024 and np.all([isinstance(c, Number) for c in edge_color])
2025 ):
2026 if edge_cmap is not None:
2027 assert isinstance(edge_cmap, mpl.colors.Colormap)
2028 else:
2029 edge_cmap = plt.get_cmap()
2030 if edge_vmin is None:
2031 edge_vmin = min(edge_color)
2032 if edge_vmax is None:
2033 edge_vmax = max(edge_color)
2034 color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax)
2035 edge_color = [edge_cmap(color_normal(e)) for e in edge_color]
2036
2037 # compute initial view
2038 minx = np.amin(np.ravel(edge_pos[:, :, 0]))
2039 maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
2040 miny = np.amin(np.ravel(edge_pos[:, :, 1]))
2041 maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
2042 w = maxx - minx
2043 h = maxy - miny
2044
2045 # Self-loops are scaled by view extent, except in cases the extent
2046 # is 0, e.g. for a single node. In this case, fall back to scaling
2047 # by the maximum node size
2048 selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max()
2049 fancy_arrow_factory = FancyArrowFactory(
2050 edge_pos,
2051 edgelist,
2052 nodelist,
2053 edge_indices,
2054 node_size,
2055 selfloop_height,
2056 connectionstyle,
2057 node_shape,
2058 arrowstyle,
2059 arrowsize,
2060 edge_color,
2061 alpha,
2062 width,
2063 style,
2064 min_source_margin,
2065 min_target_margin,
2066 ax=ax,
2067 )
2068
2069 # Draw the edges
2070 if use_linecollection:
2071 edge_collection = mpl.collections.LineCollection(
2072 edge_pos,
2073 colors=edge_color,
2074 linewidths=width,
2075 antialiaseds=(1,),
2076 linestyle=style,
2077 alpha=alpha,
2078 )
2079 edge_collection.set_cmap(edge_cmap)
2080 edge_collection.set_clim(edge_vmin, edge_vmax)
2081 edge_collection.set_zorder(1) # edges go behind nodes
2082 edge_collection.set_label(label)
2083 ax.add_collection(edge_collection)
2084 edge_viz_obj = edge_collection
2085
2086 # Make sure selfloop edges are also drawn
2087 # ---------------------------------------
2088 selfloops_to_draw = [loop for loop in nx.selfloop_edges(G) if loop in edgelist]
2089 if selfloops_to_draw:
2090 edgelist_tuple = list(map(tuple, edgelist))
2091 arrow_collection = []
2092 for loop in selfloops_to_draw:
2093 i = edgelist_tuple.index(loop)
2094 arrow = fancy_arrow_factory(i)
2095 arrow_collection.append(arrow)
2096 ax.add_patch(arrow)
2097 else:
2098 edge_viz_obj = []
2099 for i in range(len(edgelist)):
2100 arrow = fancy_arrow_factory(i)
2101 ax.add_patch(arrow)
2102 edge_viz_obj.append(arrow)
2103
2104 # update view after drawing
2105 padx, pady = 0.05 * w, 0.05 * h
2106 corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
2107 ax.update_datalim(corners)
2108 ax.autoscale_view()
2109
2110 if hide_ticks:
2111 ax.tick_params(
2112 axis="both",
2113 which="both",
2114 bottom=False,
2115 left=False,
2116 labelbottom=False,
2117 labelleft=False,
2118 )
2119
2120 return edge_viz_obj
2121
2122
def draw_networkx_labels(
    G,
    pos,
    labels=None,
    font_size=12,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    clip_on=True,
    hide_ticks=True,
):
    """Draw node labels on the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph. Only used to build the default `labels`.

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    labels : dictionary (default={n: n for n in G})
        Node labels in a dictionary of text labels keyed by node.
        Node-keys in labels should appear as keys in `pos`.
        If needed use: `{n:lab for n,lab in labels.items() if n in pos}`

    font_size : int or dictionary of nodes to ints (default=12)
        Font size for text labels.

    font_color : color or dictionary of nodes to colors (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string or dictionary of nodes to strings (default='normal')
        Font weight.

    font_family : string or dictionary of nodes to strings (default='sans-serif')
        Font family.

    alpha : float or None or dictionary of nodes to floats (default=None)
        The text transparency.

    bbox : Matplotlib bbox, (default is Matplotlib's ax.text default)
        Specify text box properties (e.g. shape, color etc.) for node labels.
        The same box properties are used for every label.

    horizontalalignment : string or dictionary of nodes to strings (default='center')
        Horizontal alignment {'center', 'right', 'left'}.

    verticalalignment : string or dictionary of nodes to strings (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    clip_on : bool (default=True)
        Turn on clipping of node labels at axis boundaries

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Returns
    -------
    dict
        `dict` of labels keyed on the nodes

    Raises
    ------
    ValueError
        If a parameter given as a dictionary does not have the same number
        of entries as `labels`.

    Notes
    -----
    Every parameter documented as accepting a "dictionary of nodes" is
    looked up per label. Such a dictionary must have exactly one entry per
    label and must contain every labeled node as a key.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_edge_labels
    """
    if ax is None:
        # pyplot is only needed to look up the current axes.
        import matplotlib.pyplot as plt

        ax = plt.gca()

    if labels is None:
        labels = {n: n for n in G.nodes()}

    # Names of the parameters that were supplied per node (as dicts).
    individual_params = set()

    def check_individual_params(p_value, p_name):
        if isinstance(p_value, dict):
            if len(p_value) != len(labels):
                raise ValueError(f"{p_name} must have the same length as labels.")
            individual_params.add(p_name)

    def get_param_value(node, p_value, p_name):
        if p_name in individual_params:
            return p_value[node]
        return p_value

    check_individual_params(font_size, "font_size")
    check_individual_params(font_color, "font_color")
    check_individual_params(font_weight, "font_weight")
    check_individual_params(font_family, "font_family")
    check_individual_params(alpha, "alpha")
    check_individual_params(horizontalalignment, "horizontalalignment")
    check_individual_params(verticalalignment, "verticalalignment")

    text_items = {}  # there is no text collection so we'll fake one
    for n, label in labels.items():
        (x, y) = pos[n]
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same
        t = ax.text(
            x,
            y,
            label,
            size=get_param_value(n, font_size, "font_size"),
            color=get_param_value(n, font_color, "font_color"),
            family=get_param_value(n, font_family, "font_family"),
            weight=get_param_value(n, font_weight, "font_weight"),
            alpha=get_param_value(n, alpha, "alpha"),
            horizontalalignment=get_param_value(
                n, horizontalalignment, "horizontalalignment"
            ),
            verticalalignment=get_param_value(
                n, verticalalignment, "verticalalignment"
            ),
            transform=ax.transData,
            bbox=bbox,
            clip_on=clip_on,
        )
        text_items[n] = t

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    return text_items
2274
2275
def draw_networkx_edge_labels(
    G,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
    node_size=300,
    nodelist=None,
    connectionstyle="arc3",
    hide_ticks=True,
):
    """Draw edge labels.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    edge_labels : dictionary (default=None)
        Edge labels in a dictionary of labels keyed by edge tuple
        (``(u, v)``, or ``(u, v, key)`` for multigraphs).
        Only labels for the keys in the dictionary are drawn.
        If None, every edge is labeled with its edge data dictionary.

    label_pos : float or list of floats (default=0.5)
        Position of edge label along edge (0=head, 0.5=center, 1=tail).
        Not used for self-loop labels.

    font_size : int or list of ints (default=10)
        Font size for text labels

    font_color : color or list of colors (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string or list of strings (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family. The same value is used for every label.

    alpha : float or None or list of floats (default=None)
        The text transparency

    bbox : Matplotlib bbox, optional
        Specify text box properties (e.g. shape, color etc.) for edge labels.
        Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.

    horizontalalignment : string or list of strings (default='center')
        Horizontal alignment {'center', 'right', 'left'}

    verticalalignment : string or list of strings (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    rotate : bool or list of bools (default=True)
        Rotate edge labels to lie parallel to edges. Self-loop labels are
        never rotated.

    clip_on : bool (default=True)
        Turn on clipping of edge labels at axis boundaries

    node_size : scalar or array (default=300)
        Size of nodes. If an array it must be the same length as nodelist.

    nodelist : list, optional (default=G.nodes())
        This provides the node order for the `node_size` array (if it is an array).

    connectionstyle : string or iterable of strings (default="arc3")
        Pass the connectionstyle parameter to create curved arc of rounding
        radius rad. For example, connectionstyle='arc3,rad=0.2'.
        See `matplotlib.patches.ConnectionStyle` and
        `matplotlib.patches.FancyArrowPatch` for more info.
        If Iterable, index indicates i'th edge key of MultiGraph

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Returns
    -------
    dict
        `dict` of labels keyed by edge

    Raises
    ------
    NetworkXError
        If `connectionstyle` is neither a string nor an iterable.
    ValueError
        If a parameter given as a list does not have one entry per label.

    Notes
    -----
    Parameters documented as accepting a list (`label_pos`, `font_size`,
    `font_color`, `font_weight`, `alpha`, `horizontalalignment`,
    `verticalalignment`, `rotate`) give one value per label. Values are
    matched to labels in the iteration order of `edge_labels`.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    """
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import numpy as np

    class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text):
        pass

    # use default box of white with white border
    if bbox is None:
        bbox = {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}

    if isinstance(connectionstyle, str):
        connectionstyle = [connectionstyle]
    elif np.iterable(connectionstyle):
        connectionstyle = list(connectionstyle)
    else:
        raise nx.NetworkXError(
            "draw_networkx_edge_labels arg `connectionstyle` must be "
            "string or iterable of strings"
        )

    if ax is None:
        ax = plt.gca()

    if edge_labels is None:
        kwds = {"keys": True} if G.is_multigraph() else {}
        edge_labels = {tuple(edge): d for *edge, d in G.edges(data=True, **kwds)}
    # NOTHING TO PLOT
    if not edge_labels:
        return {}
    edgelist, labels = zip(*edge_labels.items())

    if nodelist is None:
        nodelist = list(G.nodes())

    # set edge positions
    edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])

    if G.is_multigraph():
        key_count = collections.defaultdict(lambda: itertools.count(0))
        edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist]
    else:
        edge_indices = [0] * len(edgelist)

    # Used to determine self loop mid-point.
    # Note, that this will not be accurate,
    # if not drawing edge_labels for all edges drawn.
    # edge_labels is non-empty here, so edge_pos has at least one row.
    miny = np.amin(np.ravel(edge_pos[:, :, 1]))
    maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
    h = maxy - miny
    selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max()
    fancy_arrow_factory = FancyArrowFactory(
        edge_pos,
        edgelist,
        nodelist,
        edge_indices,
        node_size,
        selfloop_height,
        connectionstyle,
        ax=ax,
    )

    # Names of the parameters that were supplied as per-edge lists. Values
    # are looked up by edge index rather than consumed from iterators, so
    # each parameter stays aligned with its edge no matter which drawing
    # branch below reads it.
    individual_params = set()

    def check_individual_params(p_value, p_name):
        # TODO should this be list or array (as in a numpy array)?
        if isinstance(p_value, list):
            if len(p_value) != len(edgelist):
                raise ValueError(f"{p_name} must have the same length as edgelist.")
            individual_params.add(p_name)

    def get_param_value(idx, p_value, p_name):
        if p_name in individual_params:
            return p_value[idx]
        return p_value

    check_individual_params(font_size, "font_size")
    check_individual_params(font_color, "font_color")
    check_individual_params(font_weight, "font_weight")
    check_individual_params(alpha, "alpha")
    check_individual_params(horizontalalignment, "horizontalalignment")
    check_individual_params(verticalalignment, "verticalalignment")
    check_individual_params(rotate, "rotate")
    check_individual_params(label_pos, "label_pos")

    text_items = {}
    for i, (edge, label) in enumerate(zip(edgelist, labels)):
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same

        # Text properties shared by self-loop labels and regular edge labels.
        text_kwargs = {
            "size": get_param_value(i, font_size, "font_size"),
            "color": get_param_value(i, font_color, "font_color"),
            # font_family is passed unchanged for every label (not per edge).
            "family": font_family,
            "weight": get_param_value(i, font_weight, "font_weight"),
            "alpha": get_param_value(i, alpha, "alpha"),
            "horizontalalignment": get_param_value(
                i, horizontalalignment, "horizontalalignment"
            ),
            "verticalalignment": get_param_value(
                i, verticalalignment, "verticalalignment"
            ),
            "transform": ax.transData,
            "bbox": bbox,
            "zorder": 1,
            "clip_on": clip_on,
        }

        n1, n2 = edge[:2]
        arrow = fancy_arrow_factory(i)
        if n1 == n2:
            # Self-loop: put the (unrotated) label at the first vertex of the
            # path produced by the arrow's connection style, converted back
            # from display to data coordinates.
            connectionstyle_obj = arrow.get_connectionstyle()
            posA = ax.transData.transform(pos[n1])
            path_disp = connectionstyle_obj(posA, posA)
            path_data = ax.transData.inverted().transform_path(path_disp)
            x, y = path_data.vertices[0]
            text_items[edge] = ax.text(x, y, label, rotation=0, **text_kwargs)
        else:
            text_items[edge] = CurvedArrowText(
                arrow,
                label,
                label_pos=get_param_value(i, label_pos, "label_pos"),
                labels_horizontal=not get_param_value(i, rotate, "rotate"),
                ax=ax,
                **text_kwargs,
            )

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    return text_items
2546
2547
def draw_bipartite(G, **kwargs):
    """Draw `G` with node positions computed by a bipartite layout.

    Calling this function is the same as calling::

        nx.draw(G, pos=nx.bipartite_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: this function always supplies it.

    Raises
    ------
    NetworkXError :
        If `G` is not bipartite.

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly, compute the
    positions once with `~networkx.drawing.layout.bipartite_layout` and
    reuse them::

        >>> G = nx.complete_bipartite_graph(3, 3)
        >>> pos = nx.bipartite_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.complete_bipartite_graph(2, 5)
    >>> nx.draw_bipartite(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.bipartite_layout`
    """
    bipartite_positions = nx.bipartite_layout(G)
    draw(G, pos=bipartite_positions, **kwargs)
2590
2591
def draw_circular(G, **kwargs):
    """Draw `G` with its nodes placed by a circular layout.

    Calling this function is the same as calling::

        nx.draw(G, pos=nx.circular_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: this function always supplies it.

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly, compute the
    positions once with `~networkx.drawing.layout.circular_layout` and
    reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.circular_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_circular(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.circular_layout`
    """
    circle_positions = nx.circular_layout(G)
    draw(G, pos=circle_positions, **kwargs)
2629
2630
def draw_kamada_kawai(G, **kwargs):
    """Draw `G` using node positions from the Kamada-Kawai force-directed layout.

    Calling this function is the same as calling::

        nx.draw(G, pos=nx.kamada_kawai_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: this function always supplies it.

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly, compute the
    positions once with `~networkx.drawing.layout.kamada_kawai_layout` and
    reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.kamada_kawai_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_kamada_kawai(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.kamada_kawai_layout`
    """
    kk_positions = nx.kamada_kawai_layout(G)
    draw(G, pos=kk_positions, **kwargs)
2669
2670
def draw_random(G, **kwargs):
    """Draw `G` with nodes placed at random positions.

    Calling this function is the same as calling::

        nx.draw(G, pos=nx.random_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: this function always supplies it.

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly, compute the
    positions once with `~networkx.drawing.layout.random_layout` and reuse
    them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.random_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.lollipop_graph(4, 3)
    >>> nx.draw_random(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.random_layout`
    """
    random_positions = nx.random_layout(G)
    draw(G, pos=random_positions, **kwargs)
2708
2709
def draw_spectral(G, **kwargs):
    """Draw `G` using node positions from a 2D spectral layout.

    Calling this function is the same as calling::

        nx.draw(G, pos=nx.spectral_layout(G), **kwargs)

    See `~networkx.drawing.layout.spectral_layout` for how the node
    positions are determined.

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: this function always supplies it.

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly, compute the
    positions once with `~networkx.drawing.layout.spectral_layout` and
    reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.spectral_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_spectral(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.spectral_layout`
    """
    spectral_positions = nx.spectral_layout(G)
    draw(G, pos=spectral_positions, **kwargs)
2750
2751
def draw_spring(G, **kwargs):
    """Draw `G` using node positions from a spring layout.

    Calling this function is the same as calling::

        nx.draw(G, pos=nx.spring_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: this function always supplies it.

    Notes
    -----
    `~networkx.drawing.layout.spring_layout` is also the default layout for
    `draw`, so this function is equivalent to `draw`.

    Every call recomputes the layout. When drawing repeatedly, compute the
    positions once with `~networkx.drawing.layout.spring_layout` and reuse
    them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.spring_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(20)
    >>> nx.draw_spring(G)

    See Also
    --------
    draw
    :func:`~networkx.drawing.layout.spring_layout`
    """
    spring_positions = nx.spring_layout(G)
    draw(G, pos=spring_positions, **kwargs)
2793
2794
def draw_shell(G, nlist=None, **kwargs):
    """Draw networkx graph `G` with nodes arranged in concentric shells.

    Calling this function is the same as calling::

        nx.draw(G, pos=nx.shell_layout(G, nlist=nlist), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    nlist : list of list of nodes, optional
        The shells, each given as a list of nodes.
        The default, `None`, puts every node in a single shell.
        See `~networkx.drawing.layout.shell_layout` for details.

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: this function always supplies it.

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly, compute the
    positions once with `~networkx.drawing.layout.shell_layout` and reuse
    them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.shell_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> shells = [[0], [1, 2, 3]]
    >>> nx.draw_shell(G, nlist=shells)

    See Also
    --------
    :func:`~networkx.drawing.layout.shell_layout`
    """
    shell_positions = nx.shell_layout(G, nlist=nlist)
    draw(G, pos=shell_positions, **kwargs)
2838
2839
def draw_planar(G, **kwargs):
    """Draw a planar networkx graph `G` without edge crossings, via a planar layout.

    Calling this function is the same as calling::

        nx.draw(G, pos=nx.planar_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A planar networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.
        Do not pass ``pos``: this function always supplies it.

    Raises
    ------
    NetworkXException
        When `G` is not planar

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly, compute the
    positions once with `~networkx.drawing.layout.planar_layout` and reuse
    them::

        >>> G = nx.path_graph(5)
        >>> pos = nx.planar_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.draw_planar(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.planar_layout`
    """
    planar_positions = nx.planar_layout(G)
    draw(G, pos=planar_positions, **kwargs)
2882
2883
def draw_forceatlas2(G, **kwargs):
    """Draw a networkx graph with forceatlas2 layout.

    This is a convenience function equivalent to::

        nx.draw(G, pos=nx.forceatlas2_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        See `draw_networkx` for a description of optional keywords.
        ``pos`` must not be passed. Node positions are always computed by
        `~networkx.drawing.layout.forceatlas2_layout`, and supplying ``pos``
        as well raises a `TypeError` (duplicate keyword argument).

    Notes
    -----
    The layout is computed each time this function is called.
    For repeated drawing it is much more efficient to call
    `~networkx.drawing.layout.forceatlas2_layout` directly and reuse the
    result::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.forceatlas2_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_forceatlas2(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.forceatlas2_layout`
    """
    draw(G, pos=nx.forceatlas2_layout(G), **kwargs)
2902
2903
def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None):
    """Apply an alpha (or list of alphas) to the colors provided.

    Parameters
    ----------

    colors : color string or array of floats
        Color of element. Can be a single color format string,
        or a sequence of colors with the same length as nodelist.
        If numeric values are specified they will be mapped to
        colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    alpha : float or array of floats
        Alpha values for elements. This can be a single alpha value, in
        which case it will be applied to all the elements of color. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    elem_list : array of networkx objects
        The list of elements which are being colored. These could be nodes,
        edges or labels.

    cmap : matplotlib colormap
        Color map for use if colors is a list of floats corresponding to points
        on a color mapping.

    vmin, vmax : float
        Minimum and maximum values for normalizing colors if a colormap is used

    Returns
    -------

    rgba_colors : numpy ndarray
        Array containing RGBA format values for each of the node colours.
        Each row is one RGBA color; the alpha values are written into the
        last column.

    Notes
    -----
    `colors` is only treated as numeric (colormapped) data when its length
    equals ``len(elem_list)`` and its first entry is a `numbers.Number`.
    """
    from itertools import cycle, islice

    import matplotlib as mpl
    import matplotlib.cm  # call as mpl.cm
    import matplotlib.colors  # call as mpl.colors
    import numpy as np

    # If we have been provided with a list of numbers as long as elem_list,
    # apply the color mapping.
    # NOTE(review): colors[0] raises IndexError when both colors and
    # elem_list are empty — confirm callers never pass empty inputs.
    if len(colors) == len(elem_list) and isinstance(colors[0], Number):
        mapper = mpl.cm.ScalarMappable(cmap=cmap)
        mapper.set_clim(vmin, vmax)
        rgba_colors = mapper.to_rgba(colors)
    # Otherwise, convert colors to matplotlib's RGB using the colorConverter
    # object. These are converted to numpy ndarrays to be consistent with the
    # to_rgba method of ScalarMappable.
    else:
        try:
            # First try `colors` as a single color spec, giving one row.
            rgba_colors = np.array([mpl.colors.colorConverter.to_rgba(colors)])
        except ValueError:
            # Rejected as a single color: convert each entry separately.
            rgba_colors = np.array(
                [mpl.colors.colorConverter.to_rgba(color) for color in colors]
            )
    # Set the final column of the rgba_colors to have the relevant alpha values
    try:
        # A scalar alpha makes len(alpha) raise TypeError, which falls through
        # to the except branch below (alpha applied to every row).
        # If alpha is longer than the number of colors, resize to the number of
        # elements. Also, if rgba_colors.size (the number of elements of
        # rgba_colors) is the same as the number of elements, resize the array,
        # to avoid it being interpreted as a colormap by scatter()
        if len(alpha) > len(rgba_colors) or rgba_colors.size == len(elem_list):
            rgba_colors = np.resize(rgba_colors, (len(elem_list), 4))
            # Copy the first row's RGB into every other row.
            # NOTE(review): if more than one distinct color was supplied, this
            # discards all but the first — confirm that this is intended.
            rgba_colors[1:, 0] = rgba_colors[0, 0]
            rgba_colors[1:, 1] = rgba_colors[0, 1]
            rgba_colors[1:, 2] = rgba_colors[0, 2]
        # Cycle through the alpha values so every row receives one.
        rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors)))
    except TypeError:
        rgba_colors[:, -1] = alpha
    return rgba_colors