1"""Read and write graphs in GEXF format.
2
3.. warning::
4 This parser uses the standard xml library present in Python, which is
5 insecure - see :external+python:mod:`xml` for additional information.
    Only parse GEXF files you trust.
7
8GEXF (Graph Exchange XML Format) is a language for describing complex
9network structures, their associated data and dynamics.
10
11This implementation does not support mixed graphs (directed and
12undirected edges together).
13
14Format
15------
16GEXF is an XML format. See http://gexf.net/schema.html for the
17specification and http://gexf.net/basic.html for examples.
18"""
19
20import itertools
21import time
22from xml.etree.ElementTree import (
23 Element,
24 ElementTree,
25 SubElement,
26 register_namespace,
27 tostring,
28)
29
30import networkx as nx
31from networkx.utils import open_file
32
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["write_gexf", "read_gexf", "relabel_gexf_graph", "generate_gexf"]
34
35
@open_file(1, mode="wb")
def write_gexf(G, path, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Write G in GEXF format to path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Node attributes are checked according to the version of the GEXF
    schemas used for parameters which are not user defined,
    e.g. visualization 'viz' [2]_. See example for usage.

    .. warning::

        The `GEXF specification <https://gexf.net/schema.html>`_ reserves some
        keywords (e.g. ``id``, ``pid``, ``label``, etc.) for specifying node/edge
        metadata in the file format. Ensure NetworkX node/edge attribute names
        do not use these special keywords to guarantee all attributes are preserved
        as expected when roundtripping to/from GEXF format.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    path : file or string
        File or file name to write.
        File names ending in .gz or .bz2 will be compressed.
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (optional, default: '1.2draft')
        The version of GEXF to be used for node attribute checking.
        Supported values: "1.1draft", "1.2draft", "1.3".

    Raises
    ------
    NetworkXError
        If `version` is not a supported GEXF version.
    TypeError
        If a node or edge attribute value has a type that cannot be
        written to GEXF.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.write_gexf(G, "test.gexf")

    # visualization data (color channels are integers in 0-255)
    >>> G.nodes[0]["viz"] = {"size": 54}
    >>> G.nodes[0]["viz"]["position"] = {"x": 0, "y": 1}
    >>> G.nodes[0]["viz"]["color"] = {"r": 0, "g": 0, "b": 255}


    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    If you want to specify an id, set it as node data, e.g.
    ``G.nodes["a"]["id"] = 1`` to set the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, http://gexf.net/
    .. [2] GEXF schema, http://gexf.net/schema.html
    """
    writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint, version=version)
    writer.add_graph(G)
    writer.write(path)
97
98
def generate_gexf(G, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Generate lines of GEXF format representation of G.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    The whole document is built when the first line is requested; lines are
    then yielded one at a time without trailing newlines.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see http://gexf.net/schema.html)
        Supported values: "1.1draft", "1.2draft", "1.3"


    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> linefeed = chr(10)  # linefeed=\n
    >>> s = linefeed.join(nx.generate_gexf(G))
    >>> for line in nx.generate_gexf(G):  # doctest: +SKIP
    ...     print(line)

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    If you want to specify an id, set it as node data, e.g.
    ``G.nodes["a"]["id"] = 1`` to set the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    # Passing the graph to the constructor runs add_graph() as the final
    # step of initialization.
    gexf_writer = GEXFWriter(
        graph=G, encoding=encoding, prettyprint=prettyprint, version=version
    )
    for line in str(gexf_writer).splitlines():
        yield line
142
143
@open_file(0, mode="rb")
@nx._dispatchable(graphs=None, returns_graph=True)
def read_gexf(path, node_type=None, relabel=False, version="1.2draft"):
    """Read graph in GEXF format from path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    path : file or string
        Filename or file handle to read.
        Filenames ending in .gz or .bz2 will be decompressed.
    node_type : Python type (default: None)
        Convert node ids to this type if not None.
    relabel : bool (default: False)
        If True relabel the nodes to use the GEXF node "label" attribute
        instead of the node "id" attribute as the NetworkX node label.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see http://gexf.net/schema.html)
        Supported values: "1.1draft", "1.2draft", "1.3"

    Returns
    -------
    graph: NetworkX graph
        If no parallel edges are found a Graph or DiGraph is returned.
        Otherwise a MultiGraph or MultiDiGraph is returned.

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    References
    ----------
    .. [1] GEXF File Format, http://gexf.net/
    """
    parsed = GEXFReader(node_type=node_type, version=version)(path)
    return relabel_gexf_graph(parsed) if relabel else parsed
187
188
class GEXF:
    """Version tables and type conversion maps shared by the GEXF reader/writer.

    ``GEXFWriter`` and ``GEXFReader`` call :meth:`construct_types` and
    :meth:`set_version` from their ``__init__`` to populate the instance
    attributes used while serializing or parsing.
    """

    # Namespace URIs, schema location and ``version`` attribute value for each
    # supported GEXF version; the keys are the values accepted by set_version().
    versions = {
        "1.1draft": {
            "NS_GEXF": "http://www.gexf.net/1.1draft",
            "NS_VIZ": "http://www.gexf.net/1.1draft/viz",
            "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
            "SCHEMALOCATION": " ".join(
                [
                    "http://www.gexf.net/1.1draft",
                    "http://www.gexf.net/1.1draft/gexf.xsd",
                ]
            ),
            "VERSION": "1.1",
        },
        "1.2draft": {
            "NS_GEXF": "http://www.gexf.net/1.2draft",
            "NS_VIZ": "http://www.gexf.net/1.2draft/viz",
            "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
            "SCHEMALOCATION": " ".join(
                [
                    "http://www.gexf.net/1.2draft",
                    "http://www.gexf.net/1.2draft/gexf.xsd",
                ]
            ),
            "VERSION": "1.2",
        },
        "1.3": {
            "NS_GEXF": "http://gexf.net/1.3",
            "NS_VIZ": "http://gexf.net/1.3/viz",
            # The XML Schema instance namespace is a fixed W3C URI compared
            # verbatim by XML tools, so it must keep the "www." host (as in
            # the other versions) for xsi:schemaLocation to be recognized.
            "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
            "SCHEMALOCATION": " ".join(
                [
                    "http://gexf.net/1.3",
                    "http://gexf.net/1.3/gexf.xsd",
                ]
            ),
            "VERSION": "1.3",
        },
    }

    def construct_types(self):
        """Build the Python <-> GEXF attribute type lookup tables.

        Sets ``self.xml_type`` (Python type -> GEXF type name, used when
        writing) and ``self.python_type`` (GEXF type name -> Python callable,
        used when reading). In both dicts, a later entry of ``types``
        overrides an earlier one with the same key.
        """
        types = [
            (int, "integer"),
            (float, "float"),
            (float, "double"),
            (bool, "boolean"),
            (list, "string"),
            (dict, "string"),
            (int, "long"),
            (str, "liststring"),
            (str, "anyURI"),
            (str, "string"),
        ]

        # These additions to types allow writing numpy types
        try:
            import numpy as np
        except ImportError:
            pass
        else:
            # Prepend so that python types are created upon read (last entry
            # wins in python_type), while numpy values can still be written.
            # NOTE(review): "int" does not appear to be one of the GEXF
            # attribute type names ("integer", "long", ...); kept unchanged
            # for compatibility with files already written this way.
            types = [
                (np.float64, "float"),
                (np.float32, "float"),
                (np.float16, "float"),
                (np.int_, "int"),
                (np.int8, "int"),
                (np.int16, "int"),
                (np.int32, "int"),
                (np.int64, "int"),
                (np.uint8, "int"),
                (np.uint16, "int"),
                (np.uint32, "int"),
                (np.uint64, "int"),
                (np.intc, "int"),
                (np.intp, "int"),
            ] + types

        self.xml_type = dict(types)
        self.python_type = dict(reversed(a) for a in types)

    # Accepted boolean spellings when reading: the XML Schema lexical forms
    # (http://www.w3.org/TR/xmlschema-2/#boolean), capitalized Python
    # spellings, and the integers 0/1.
    convert_bool = {
        "true": True,
        "false": False,
        "True": True,
        "False": False,
        "0": False,
        0: False,
        "1": True,
        1: True,
    }

    def set_version(self, version):
        """Select the GEXF version whose namespaces are used.

        Sets ``NS_GEXF``, ``NS_VIZ``, ``NS_XSI``, ``SCHEMALOCATION``,
        ``VERSION`` and ``version`` on the instance.

        Raises
        ------
        NetworkXError
            If `version` is not a key of ``versions``.
        """
        d = self.versions.get(version)
        if d is None:
            raise nx.NetworkXError(f"Unknown GEXF version {version}.")
        self.NS_GEXF = d["NS_GEXF"]
        self.NS_VIZ = d["NS_VIZ"]
        self.NS_XSI = d["NS_XSI"]
        self.SCHEMALOCATION = d["SCHEMALOCATION"]
        self.VERSION = d["VERSION"]
        self.version = version
293
294
class GEXFWriter(GEXF):
    """Build a GEXF XML document from NetworkX graphs.

    Use :func:`write_gexf` or :func:`generate_gexf` rather than this class
    directly.
    """

    def __init__(
        self, graph=None, encoding="utf-8", prettyprint=True, version="1.2draft"
    ):
        """Create the root ``<gexf>`` element and, optionally, add `graph`.

        Parameters
        ----------
        graph : NetworkX graph or None
            If given, ``add_graph(graph)`` is called at the end of setup.
        encoding : str
            Encoding used by :meth:`write`.
        prettyprint : bool
            If True, indent the XML before serializing.
        version : str
            GEXF version key (see ``GEXF.versions``).

        Raises
        ------
        NetworkXError
            If `version` is unknown.
        """
        self.construct_types()
        self.prettyprint = prettyprint
        self.encoding = encoding
        self.set_version(version)
        self.xml = Element(
            "gexf",
            {
                "xmlns": self.NS_GEXF,
                "xmlns:xsi": self.NS_XSI,
                "xsi:schemaLocation": self.SCHEMALOCATION,
                "version": self.VERSION,
            },
        )

        # Make meta element a non-graph element
        # Also add lastmodifieddate as attribute, not tag
        meta_element = Element("meta")
        subelement_text = f"NetworkX {nx.__version__}"
        SubElement(meta_element, "creator").text = subelement_text
        meta_element.set("lastmodifieddate", time.strftime("%Y-%m-%d"))
        self.xml.append(meta_element)

        register_namespace("viz", self.NS_VIZ)

        # counters for edge and attribute identifiers
        self.edge_id = itertools.count()
        self.attr_id = itertools.count()
        # string forms of every edge id used so far (user-supplied or generated)
        self.all_edge_ids = set()
        # attribute ids, indexed as self.attr[class][mode][title]
        self.attr = {
            "node": {"dynamic": {}, "static": {}},
            "edge": {"dynamic": {}, "static": {}},
        }

        if graph is not None:
            self.add_graph(graph)

    def __str__(self):
        """Return the document as a string (indented first if prettyprint)."""
        if self.prettyprint:
            self.indent(self.xml)
        # tostring() defaults to the "us-ascii" encoding and emits non-ASCII
        # characters as character references, so the bytes are pure ASCII.
        # Decoding them with self.encoding would garble the text for
        # encodings that are not ASCII-compatible (e.g. UTF-16).
        return tostring(self.xml).decode("ascii")

    def add_graph(self, G):
        """Append a ``<graph>`` element holding the nodes and edges of G."""
        # First pass: reserve user-supplied edge ids so that generated ids
        # never collide with them.
        for _, _, eid in G.edges(data="id"):
            if eid is not None:
                self.all_edge_ids.add(str(eid))
        mode = "dynamic" if G.graph.get("mode") == "dynamic" else "static"
        default = "directed" if G.is_directed() else "undirected"
        # XML attribute values must be strings to be serialized.
        name = str(G.graph.get("name", ""))
        graph_element = Element("graph", defaultedgetype=default, mode=mode, name=name)
        self.graph_element = graph_element
        self.add_nodes(G, graph_element)
        self.add_edges(G, graph_element)
        self.xml.append(graph_element)

    def add_nodes(self, G, graph_element):
        """Append a ``<nodes>`` element with one ``<node>`` per node of G."""
        nodes_element = Element("nodes")
        # Attribute defaults are graph-level, so look them up once.
        default = G.graph.get("node_default", {})
        for node, data in G.nodes(data=True):
            node_data = data.copy()
            kw = {
                "id": str(node_data.pop("id", node)),
                "label": str(node_data.pop("label", node)),
            }
            if "pid" in node_data:
                kw["pid"] = str(node_data.pop("pid"))
            for time_key in ("start", "end"):
                if time_key in node_data:
                    value = node_data.pop(time_key)
                    kw[time_key] = str(value)
                    self.alter_graph_mode_timeformat(value)
            # add node element with attributes
            node_element = Element("node", **kw)
            # add node element and attr subelements
            node_data = self.add_parents(node_element, node_data)
            if self.VERSION == "1.1":
                node_data = self.add_slices(node_element, node_data)
            else:
                node_data = self.add_spells(node_element, node_data)
            node_data = self.add_viz(node_element, node_data)
            node_data = self.add_attributes("node", node_element, node_data, default)
            nodes_element.append(node_element)
        graph_element.append(nodes_element)

    def add_edges(self, G, graph_element):
        """Append an ``<edges>`` element with one ``<edge>`` per edge of G."""

        def edge_key_data(G):
            # Unify multigraph and graph edge iteration: yield
            # (u, v, edge_id, edge_data) where edge_data is a private copy.
            if G.is_multigraph():
                # Keep the multigraph key so it is written as the
                # "networkx_key" attribute by add_attributes().
                edges = (
                    (u, v, {**data, "key": key})
                    for u, v, key, data in G.edges(data=True, keys=True)
                )
            else:
                edges = ((u, v, data.copy()) for u, v, data in G.edges(data=True))
            for u, v, edge_data in edges:
                edge_id = edge_data.pop("id", None)
                if edge_id is None:
                    # Generate the next id not already taken.
                    edge_id = next(self.edge_id)
                    while str(edge_id) in self.all_edge_ids:
                        edge_id = next(self.edge_id)
                    self.all_edge_ids.add(str(edge_id))
                yield u, v, edge_id, edge_data

        edges_element = Element("edges")
        # Attribute defaults are graph-level, so look them up once.
        default = G.graph.get("edge_default", {})
        for u, v, key, edge_data in edge_key_data(G):
            kw = {"id": str(key)}
            # Reserved GEXF edge attributes are written as XML attributes.
            for reserved in ("label", "weight", "type"):
                if reserved in edge_data:
                    kw[reserved] = str(edge_data.pop(reserved))
            for time_key in ("start", "end"):
                if time_key in edge_data:
                    value = edge_data.pop(time_key)
                    kw[time_key] = str(value)
                    self.alter_graph_mode_timeformat(value)
            source_id = str(G.nodes[u].get("id", u))
            target_id = str(G.nodes[v].get("id", v))
            edge_element = Element("edge", source=source_id, target=target_id, **kw)
            if self.VERSION == "1.1":
                edge_data = self.add_slices(edge_element, edge_data)
            else:
                edge_data = self.add_spells(edge_element, edge_data)
            edge_data = self.add_viz(edge_element, edge_data)
            edge_data = self.add_attributes("edge", edge_element, edge_data, default)
            edges_element.append(edge_element)
        graph_element.append(edges_element)

    @staticmethod
    def _format_value(value):
        # Serialize an attribute value: booleans in XML Schema form
        # ("true"/"false"), float specials as "INF"/"-INF"/"NaN".
        if isinstance(value, bool):
            return str(value).lower()
        text = str(value)
        if isinstance(value, float):
            text = {"inf": "INF", "nan": "NaN", "-inf": "-INF"}.get(text, text)
        return text

    def add_attributes(self, node_or_edge, xml_obj, data, default):
        """Append an ``<attvalues>`` element for the entries of `data`.

        List values are treated as dynamic data: a list of
        ``(value, start, end)`` triples. Other values are static. Returns
        `data` unchanged.

        Raises
        ------
        TypeError
            If a value (or a value inside a dynamic list) has a type with no
            GEXF equivalent in ``self.xml_type``.
        """
        if len(data) == 0:
            return data
        attvalues = Element("attvalues")
        for k, v in data.items():
            # rename generic multigraph key to avoid any name conflict
            if k == "key":
                k = "networkx_key"
            val_type = type(v)
            if val_type not in self.xml_type:
                raise TypeError(f"attribute value type is not allowed: {val_type}")
            if isinstance(v, list):
                # dynamic data; mode is decided per attribute so a previous
                # dynamic attribute cannot leak its mode into this one.
                mode = "static"
                for val, start, end in v:
                    val_type = type(val)
                    if start is not None or end is not None:
                        mode = "dynamic"
                        self.alter_graph_mode_timeformat(start)
                        self.alter_graph_mode_timeformat(end)
                        break
                if val_type not in self.xml_type:
                    raise TypeError(f"attribute value type is not allowed: {val_type}")
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, mode
                )
                for val, start, end in v:
                    e = Element("attvalue")
                    e.attrib["for"] = attr_id
                    e.attrib["value"] = self._format_value(val)
                    if start is not None:
                        e.attrib["start"] = str(start)
                    if end is not None:
                        e.attrib["end"] = str(end)
                    attvalues.append(e)
            else:
                # static data
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, "static"
                )
                e = Element("attvalue")
                e.attrib["for"] = attr_id
                e.attrib["value"] = self._format_value(v)
                attvalues.append(e)
        xml_obj.append(attvalues)
        return data

    def get_attr_id(self, title, attr_type, edge_or_node, default, mode):
        """Return the attribute id for `title`, declaring it on first use.

        New ids come from ``self.attr_id``; the ``<attribute>`` declaration
        (with a ``<default>`` child if `default` has a non-None entry for
        `title`) is added to the matching ``<attributes>`` element, which is
        created at the front of the graph element when missing.
        """
        # find the id of the attribute or generate a new id
        try:
            return self.attr[edge_or_node][mode][title]
        except KeyError:
            pass
        new_id = str(next(self.attr_id))
        self.attr[edge_or_node][mode][title] = new_id
        attribute = Element("attribute", id=new_id, title=title, type=attr_type)
        # add subelement for data default value if present
        default_title = default.get(title)
        if default_title is not None:
            default_element = Element("default")
            default_element.text = str(default_title)
            attribute.append(default_element)
        # now insert it into the XML
        attributes_element = None
        for a in self.graph_element.findall("attributes"):
            # find existing attributes element by class and mode
            a_class = a.get("class")
            a_mode = a.get("mode", "static")
            if a_class == edge_or_node and a_mode == mode:
                attributes_element = a
        if attributes_element is None:
            # create new attributes element
            attr_kwargs = {"mode": mode, "class": edge_or_node}
            attributes_element = Element("attributes", **attr_kwargs)
            self.graph_element.insert(0, attributes_element)
        attributes_element.append(attribute)
        return new_id

    @staticmethod
    def _str_attrs(values):
        # Stringify XML attribute values, omitting None entries so missing
        # components are left out instead of written as the text "None"
        # (which GEXFReader cannot convert back to a number).
        return {key: str(val) for key, val in values.items() if val is not None}

    def add_viz(self, element, node_data):
        """Pop ``node_data["viz"]`` and append the matching viz elements."""
        viz = node_data.pop("viz", False)
        if viz:
            color = viz.get("color")
            if color is not None:
                channels = {c: color.get(c) for c in ("r", "g", "b")}
                if self.VERSION != "1.1":
                    # Versions other than 1.1 also write alpha (default 1.0).
                    channels["a"] = color.get("a", 1.0)
                e = Element(f"{{{self.NS_VIZ}}}color", self._str_attrs(channels))
                element.append(e)

            size = viz.get("size")
            if size is not None:
                e = Element(f"{{{self.NS_VIZ}}}size", value=str(size))
                element.append(e)

            thickness = viz.get("thickness")
            if thickness is not None:
                e = Element(f"{{{self.NS_VIZ}}}thickness", value=str(thickness))
                element.append(e)

            shape = viz.get("shape")
            if shape is not None:
                if shape.startswith("http"):
                    # A URL is written as an image shape with its uri.
                    e = Element(
                        f"{{{self.NS_VIZ}}}shape", value="image", uri=str(shape)
                    )
                else:
                    e = Element(f"{{{self.NS_VIZ}}}shape", value=str(shape))
                element.append(e)

            position = viz.get("position")
            if position is not None:
                coords = {axis: position.get(axis) for axis in ("x", "y", "z")}
                e = Element(f"{{{self.NS_VIZ}}}position", self._str_attrs(coords))
                element.append(e)
        return node_data

    def add_parents(self, node_element, node_data):
        """Pop ``node_data["parents"]`` and append a ``<parents>`` element."""
        parents = node_data.pop("parents", False)
        if parents:
            parents_element = Element("parents")
            for p in parents:
                e = Element("parent")
                e.attrib["for"] = str(p)
                parents_element.append(e)
            node_element.append(parents_element)
        return node_data

    def add_slices(self, node_or_edge_element, node_or_edge_data):
        """Pop ``"slices"`` (``(start, end)`` pairs) and append ``<slices>``."""
        slices = node_or_edge_data.pop("slices", False)
        if slices:
            slices_element = Element("slices")
            for start, end in slices:
                e = Element("slice", start=str(start), end=str(end))
                slices_element.append(e)
            node_or_edge_element.append(slices_element)
        return node_or_edge_data

    def add_spells(self, node_or_edge_element, node_or_edge_data):
        """Pop ``"spells"`` (``(start, end)`` pairs) and append ``<spells>``.

        None bounds are omitted from the ``<spell>`` element.
        """
        spells = node_or_edge_data.pop("spells", False)
        if spells:
            spells_element = Element("spells")
            for start, end in spells:
                e = Element("spell")
                if start is not None:
                    e.attrib["start"] = str(start)
                    self.alter_graph_mode_timeformat(start)
                if end is not None:
                    e.attrib["end"] = str(end)
                    self.alter_graph_mode_timeformat(end)
                spells_element.append(e)
            node_or_edge_element.append(spells_element)
        return node_or_edge_data

    def alter_graph_mode_timeformat(self, start_or_end):
        """Set the graph ``timeformat`` from a time value and mark it dynamic.

        str -> "date", float -> "double", int -> "long"; None is ignored.

        Raises
        ------
        NetworkXError
            If the value is not an int, float or str.
        """
        if start_or_end is not None:
            if isinstance(start_or_end, str):
                timeformat = "date"
            elif isinstance(start_or_end, float):
                timeformat = "double"
            elif isinstance(start_or_end, int):
                timeformat = "long"
            else:
                raise nx.NetworkXError(
                    "timeformat should be of the type int, float or str"
                )
            self.graph_element.set("timeformat", timeformat)
            # If Graph mode is static, alter to dynamic
            if self.graph_element.get("mode") == "static":
                self.graph_element.set("mode", "dynamic")

    def write(self, fh):
        """Serialize the document to the open binary file handle `fh`."""
        if self.prettyprint:
            self.indent(self.xml)
        document = ElementTree(self.xml)
        document.write(fh, encoding=self.encoding, xml_declaration=True)

    def indent(self, elem, level=0):
        """Indent `elem` and its descendants in place (prettyprint)."""
        i = "\n" + " " * level
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + " "
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            for child in elem:
                self.indent(child, level + 1)
            # The last child's tail precedes the parent's closing tag, so it
            # is dedented back to this element's level.
            if not child.tail or not child.tail.strip():
                child.tail = i
        else:
            if level and (not elem.tail or not elem.tail.strip()):
                elem.tail = i
705
706
class GEXFReader(GEXF):
    """Parse GEXF XML into a NetworkX graph.

    Use :func:`read_gexf` rather than this class directly. Instances are
    callable: ``GEXFReader(...)(stream)`` returns the parsed graph.
    """

    def __init__(self, node_type=None, version="1.2draft"):
        """Prepare a reader.

        Parameters
        ----------
        node_type : callable or None
            If not None, applied to node ids, edge endpoints and the ``pid``
            / ``parents`` node references read from the file.
        version : str
            GEXF version tried first; all known versions are tried if no
            ``<graph>`` element is found in its namespace.

        Raises
        ------
        NetworkXError
            If `version` is unknown.
        """
        self.construct_types()
        self.node_type = node_type
        # assume simple graph and test for multigraph on read
        self.simple_graph = True
        self.set_version(version)

    def __call__(self, stream):
        """Parse `stream` and return the graph built from its ``<graph>``.

        Raises
        ------
        NetworkXError
            If no ``<graph>`` element is found for any known version.
        """
        self.xml = ElementTree(file=stream)
        g = self.xml.find(f"{{{self.NS_GEXF}}}graph")
        if g is not None:
            return self.make_graph(g)
        # try all the versions
        for version in self.versions:
            self.set_version(version)
            g = self.xml.find(f"{{{self.NS_GEXF}}}graph")
            if g is not None:
                return self.make_graph(g)
        raise nx.NetworkXError("No <graph> element in GEXF file.")

    def _cast_node_id(self, value):
        # Apply the user-requested node_type conversion, if any.
        return value if self.node_type is None else self.node_type(value)

    def _convert_time(self, value):
        # Convert a start/end string using the graph timeformat; a missing
        # bound (open interval, omitted by GEXFWriter) stays None.
        if value is None:
            return None
        return self.python_type[self.timeformat](value)

    def make_graph(self, graph_xml):
        """Build a graph from a ``<graph>`` element.

        A MultiGraph/MultiDiGraph is built first and converted to a
        Graph/DiGraph if no parallel edges were found.

        Raises
        ------
        NetworkXError
            On an ``<attributes>`` element with an unknown class, or on
            edges whose direction conflicts with the graph default.
        """
        # Reset per parse so a reused reader does not inherit the flag.
        self.simple_graph = True

        # start with empty DiGraph or MultiDiGraph
        edgedefault = graph_xml.get("defaultedgetype", None)
        if edgedefault == "directed":
            G = nx.MultiDiGraph()
        else:
            G = nx.MultiGraph()

        # graph attributes
        graph_name = graph_xml.get("name", "")
        if graph_name != "":
            G.graph["name"] = graph_name
        graph_start = graph_xml.get("start")
        if graph_start is not None:
            G.graph["start"] = graph_start
        graph_end = graph_xml.get("end")
        if graph_end is not None:
            G.graph["end"] = graph_end
        graph_mode = graph_xml.get("mode", "")
        if graph_mode == "dynamic":
            G.graph["mode"] = "dynamic"
        else:
            G.graph["mode"] = "static"

        # timeformat: date-like formats are kept as strings
        self.timeformat = graph_xml.get("timeformat")
        if self.timeformat in ("date", "dateTime"):
            self.timeformat = "string"

        # node and edge attributes
        attributes_elements = graph_xml.findall(f"{{{self.NS_GEXF}}}attributes")
        # dictionaries to hold attributes and attribute defaults
        node_attr = {}
        node_default = {}
        edge_attr = {}
        edge_default = {}
        for a in attributes_elements:
            attr_class = a.get("class")
            if attr_class == "node":
                na, nd = self.find_gexf_attributes(a)
                node_attr.update(na)
                node_default.update(nd)
                G.graph["node_default"] = node_default
            elif attr_class == "edge":
                ea, ed = self.find_gexf_attributes(a)
                edge_attr.update(ea)
                edge_default.update(ed)
                G.graph["edge_default"] = edge_default
            else:
                # A bare `raise` here has no active exception to re-raise.
                raise nx.NetworkXError(f"Unknown attribute class: {attr_class!r}")

        # Hack to handle Gephi0.7beta bug
        # add weight attribute
        ea = {"weight": {"type": "double", "mode": "static", "title": "weight"}}
        ed = {}
        edge_attr.update(ea)
        edge_default.update(ed)
        G.graph["edge_default"] = edge_default

        # add nodes
        nodes_element = graph_xml.find(f"{{{self.NS_GEXF}}}nodes")
        if nodes_element is not None:
            for node_xml in nodes_element.findall(f"{{{self.NS_GEXF}}}node"):
                self.add_node(G, node_xml, node_attr)

        # add edges
        edges_element = graph_xml.find(f"{{{self.NS_GEXF}}}edges")
        if edges_element is not None:
            for edge_xml in edges_element.findall(f"{{{self.NS_GEXF}}}edge"):
                self.add_edge(G, edge_xml, edge_attr)

        # switch to Graph or DiGraph if no parallel edges were found.
        if self.simple_graph:
            if G.is_directed():
                G = nx.DiGraph(G)
            else:
                G = nx.Graph(G)
        return G

    def add_node(self, G, node_xml, node_attr, node_pid=None):
        """Add the node in `node_xml` (and, recursively, its subnodes) to G.

        `node_pid` is the id of the enclosing ``<node>`` for subnodes; an
        explicit ``pid`` attribute takes precedence over it.
        """
        # get attributes and subattributes for node
        data = self.decode_attr_elements(node_attr, node_xml)
        data = self.add_parents(data, node_xml)  # add any parents
        if self.VERSION == "1.1":
            data = self.add_slices(data, node_xml)  # add slices
        else:
            data = self.add_spells(data, node_xml)  # add spells
        data = self.add_viz(data, node_xml)  # add viz
        data = self.add_start_end(data, node_xml)  # add start/end

        # find the node id and cast it to the appropriate type
        node_id = self._cast_node_id(node_xml.get("id"))

        # every node should have a label (None if the attribute is absent)
        data["label"] = node_xml.get("label")

        # parent node id; cast like node ids so it refers to an actual node
        xml_pid = node_xml.get("pid")
        if xml_pid is not None:
            node_pid = self._cast_node_id(xml_pid)
        if node_pid is not None:
            data["pid"] = node_pid

        # check for subnodes, recursive
        subnodes = node_xml.find(f"{{{self.NS_GEXF}}}nodes")
        if subnodes is not None:
            for subnode_xml in subnodes.findall(f"{{{self.NS_GEXF}}}node"):
                self.add_node(G, subnode_xml, node_attr, node_pid=node_id)

        G.add_node(node_id, **data)

    def add_start_end(self, data, xml):
        """Copy converted ``start``/``end`` XML attributes into `data`."""
        node_start = xml.get("start")
        if node_start is not None:
            data["start"] = self._convert_time(node_start)
        node_end = xml.get("end")
        if node_end is not None:
            data["end"] = self._convert_time(node_end)
        return data

    def add_viz(self, data, node_xml):
        """Collect viz color/size/thickness/shape/position into ``data["viz"]``."""
        viz = {}
        color = node_xml.find(f"{{{self.NS_VIZ}}}color")
        if color is not None:
            if self.VERSION == "1.1":
                viz["color"] = {
                    "r": int(color.get("r")),
                    "g": int(color.get("g")),
                    "b": int(color.get("b")),
                }
            else:
                viz["color"] = {
                    "r": int(color.get("r")),
                    "g": int(color.get("g")),
                    "b": int(color.get("b")),
                    "a": float(color.get("a", 1)),
                }

        size = node_xml.find(f"{{{self.NS_VIZ}}}size")
        if size is not None:
            viz["size"] = float(size.get("value"))

        thickness = node_xml.find(f"{{{self.NS_VIZ}}}thickness")
        if thickness is not None:
            viz["thickness"] = float(thickness.get("value"))

        shape = node_xml.find(f"{{{self.NS_VIZ}}}shape")
        if shape is not None:
            # The shape name is stored in the "value" attribute (as written
            # by GEXFWriter); fall back to a legacy "shape" attribute.
            viz["shape"] = shape.get("value", shape.get("shape"))
            if viz["shape"] == "image":
                viz["shape"] = shape.get("uri")

        position = node_xml.find(f"{{{self.NS_VIZ}}}position")
        if position is not None:
            viz["position"] = {
                "x": float(position.get("x", 0)),
                "y": float(position.get("y", 0)),
                "z": float(position.get("z", 0)),
            }

        if len(viz) > 0:
            data["viz"] = viz
        return data

    def add_parents(self, data, node_xml):
        """Store the ``<parents>`` references as ``data["parents"]``."""
        parents_element = node_xml.find(f"{{{self.NS_GEXF}}}parents")
        if parents_element is not None:
            data["parents"] = [
                self._cast_node_id(p.get("for"))
                for p in parents_element.findall(f"{{{self.NS_GEXF}}}parent")
            ]
        return data

    def add_slices(self, data, node_or_edge_xml):
        """Store ``<slices>`` as a list of raw ``(start, end)`` strings."""
        slices_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}slices")
        if slices_element is not None:
            data["slices"] = []
            for s in slices_element.findall(f"{{{self.NS_GEXF}}}slice"):
                start = s.get("start")
                end = s.get("end")
                data["slices"].append((start, end))
        return data

    def add_spells(self, data, node_or_edge_xml):
        """Store ``<spells>`` as ``(start, end)`` pairs (None if omitted)."""
        spells_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}spells")
        if spells_element is not None:
            data["spells"] = []
            for s in spells_element.findall(f"{{{self.NS_GEXF}}}spell"):
                start = self._convert_time(s.get("start"))
                end = self._convert_time(s.get("end"))
                data["spells"].append((start, end))
        return data

    def add_edge(self, G, edge_element, edge_attr):
        """Add the edge in `edge_element` to G.

        Raises
        ------
        NetworkXError
            If the edge's ``type`` conflicts with the graph's directedness.
        """
        # raise error if we find mixed directed and undirected edges
        edge_direction = edge_element.get("type")
        if G.is_directed() and edge_direction == "undirected":
            raise nx.NetworkXError("Undirected edge found in directed graph.")
        if (not G.is_directed()) and edge_direction == "directed":
            raise nx.NetworkXError("Directed edge found in undirected graph.")

        # Get source and target and recast type if required
        source = self._cast_node_id(edge_element.get("source"))
        target = self._cast_node_id(edge_element.get("target"))

        data = self.decode_attr_elements(edge_attr, edge_element)
        data = self.add_start_end(data, edge_element)

        if self.VERSION == "1.1":
            data = self.add_slices(data, edge_element)  # add slices
        else:
            data = self.add_spells(data, edge_element)  # add spells

        # GEXF stores edge ids as an attribute
        # NetworkX uses them as keys in multigraphs
        # if networkx_key is not specified as an attribute
        edge_id = edge_element.get("id")
        if edge_id is not None:
            data["id"] = edge_id

        # check if there is a 'multigraph_key' and use that as edge_id
        multigraph_key = data.pop("networkx_key", None)
        if multigraph_key is not None:
            edge_id = multigraph_key

        weight = edge_element.get("weight")
        if weight is not None:
            data["weight"] = float(weight)

        edge_label = edge_element.get("label")
        if edge_label is not None:
            data["label"] = edge_label

        if G.has_edge(source, target):
            # seen this edge before - this is a multigraph
            self.simple_graph = False
        G.add_edge(source, target, key=edge_id, **data)
        if edge_direction == "mutual":
            G.add_edge(target, source, key=edge_id, **data)

    def decode_attr_elements(self, gexf_keys, obj_xml):
        """Decode ``<attvalues>`` of `obj_xml` using the `gexf_keys` table.

        Static attributes map title -> value; dynamic ones map title -> list
        of ``(value, start, end)`` triples.

        Raises
        ------
        NetworkXError
            If an ``<attvalue>`` refers to an undeclared attribute id.
        """
        attr = {}
        # look for outer '<attvalues>' element
        attr_element = obj_xml.find(f"{{{self.NS_GEXF}}}attvalues")
        if attr_element is not None:
            # loop over <attvalue> elements
            for a in attr_element.findall(f"{{{self.NS_GEXF}}}attvalue"):
                key = a.get("for")  # for is required
                try:  # should be in our gexf_keys dictionary
                    title = gexf_keys[key]["title"]
                except KeyError as err:
                    raise nx.NetworkXError(f"No attribute defined for={key}.") from err
                atype = gexf_keys[key]["type"]
                value = a.get("value")
                if atype == "boolean":
                    value = self.convert_bool[value]
                else:
                    value = self.python_type[atype](value)
                if gexf_keys[key]["mode"] == "dynamic":
                    # for dynamic graphs use list of three-tuples
                    # [(value1,start1,end1), (value2,start2,end2), etc]
                    start = self._convert_time(a.get("start"))
                    end = self._convert_time(a.get("end"))
                    attr.setdefault(title, []).append((value, start, end))
                else:
                    # for static graphs just assign the value
                    attr[title] = value
        return attr

    def find_gexf_attributes(self, attributes_element):
        """Return ``(attrs, defaults)`` declared by an ``<attributes>`` element.

        `attrs` maps attribute id -> {"title", "type", "mode"}; `defaults`
        maps title -> converted ``<default>`` value.
        """
        attrs = {}
        defaults = {}
        mode = attributes_element.get("mode")
        for k in attributes_element.findall(f"{{{self.NS_GEXF}}}attribute"):
            attr_id = k.get("id")
            title = k.get("title")
            atype = k.get("type")
            attrs[attr_id] = {"title": title, "type": atype, "mode": mode}
            # check for the 'default' subelement of key element and add
            default = k.find(f"{{{self.NS_GEXF}}}default")
            if default is not None:
                if atype == "boolean":
                    value = self.convert_bool[default.text]
                else:
                    value = self.python_type[atype](default.text)
                defaults[title] = value
        return attrs, defaults
1035
1036
def relabel_gexf_graph(G):
    """Relabel graph using "label" node keyword for node label.

    Parameters
    ----------
    G : graph
        A NetworkX graph read from GEXF data

    Returns
    -------
    H : graph
        A NetworkX graph with relabeled nodes

    Raises
    ------
    NetworkXError
        If node labels are missing (absent or None) or not unique while
        relabel=True.

    Notes
    -----
    This function relabels the nodes in a NetworkX graph with the
    "label" attribute. It also handles relabeling the specific GEXF
    node attributes "parents", and "pid". The original node id is kept
    in the "id" attribute of each relabeled node.
    """
    missing_msg = (
        "Failed to relabel nodes: missing node labels found. Use relabel=False."
    )
    # build mapping of node labels, do some error checking
    try:
        mapping = {u: G.nodes[u]["label"] for u in G}
    except KeyError as err:
        raise nx.NetworkXError(missing_msg) from err
    # GEXFReader stores label=None for <node> elements without a label
    # attribute, so None counts as a missing label too.
    if any(label is None for label in mapping.values()):
        raise nx.NetworkXError(missing_msg)
    # A set comparison (rather than zip(*pairs)) also works for empty graphs.
    if len(set(mapping.values())) != len(mapping):
        raise nx.NetworkXError(
            "Failed to relabel nodes: duplicate node labels found. Use relabel=False."
        )
    H = nx.relabel_nodes(G, mapping)
    # relabel attributes
    for n in G:
        m = mapping[n]
        H.nodes[m]["id"] = n
        H.nodes[m].pop("label")
        if "pid" in H.nodes[m]:
            H.nodes[m]["pid"] = mapping[G.nodes[n]["pid"]]
        if "parents" in H.nodes[m]:
            H.nodes[m]["parents"] = [mapping[p] for p in G.nodes[n]["parents"]]
    return H