1"""Read and write graphs in GEXF format.
2
.. warning::
    This parser uses the standard xml library present in Python, which is
    insecure - see :external+python:mod:`xml` for additional information.
    Only parse GEXF files you trust.
7
8GEXF (Graph Exchange XML Format) is a language for describing complex
9network structures, their associated data and dynamics.
10
11This implementation does not support mixed graphs (directed and
12undirected edges together).
13
14Format
15------
16GEXF is an XML format. See http://gexf.net/schema.html for the
17specification and http://gexf.net/basic.html for examples.
18"""
19
20import itertools
21import time
22from xml.etree.ElementTree import (
23 Element,
24 ElementTree,
25 SubElement,
26 register_namespace,
27 tostring,
28)
29
30import networkx as nx
31from networkx.utils import open_file
32
33__all__ = ["write_gexf", "read_gexf", "relabel_gexf_graph", "generate_gexf"]
34
35
@open_file(1, mode="wb")
def write_gexf(G, path, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Serialize graph G in GEXF format and write it to path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Attributes that are not user defined, such as the visualization data
    stored under the 'viz' key [2]_, are written according to the selected
    version of the GEXF schema. See the example below for usage.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    path : file or string
        File or file name to write.
        File names ending in .gz or .bz2 will be compressed.
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (optional, default: '1.2draft')
        The version of GEXF to write; it also selects how the
        schema-defined node and edge attributes are emitted.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.write_gexf(G, "test.gexf")

    Visualization data is attached to a node through its "viz" attribute:

    >>> G.nodes[0]["viz"] = {"size": 54}
    >>> G.nodes[0]["viz"]["position"] = {"x": 0, "y": 1}
    >>> G.nodes[0]["viz"]["color"] = {"r": 0, "g": 0, "b": 255}

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    If you want to specify an id, set it as node data, e.g.
    node['a']['id']=1 to set the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, http://gexf.net/
    .. [2] GEXF schema, http://gexf.net/schema.html
    """
    gexf_writer = GEXFWriter(
        version=version, prettyprint=prettyprint, encoding=encoding
    )
    gexf_writer.add_graph(G)
    # ``path`` has already been turned into an open binary handle by the
    # ``open_file`` decorator.
    gexf_writer.write(path)
89
90
def generate_gexf(G, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Yield the GEXF representation of G one line at a time.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see http://gexf.net/schema.html)
        Supported values: "1.1draft", "1.2draft"

    Yields
    ------
    line : string
        One line of the serialized document, without its line terminator.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> linefeed = chr(10)  # the newline character
    >>> s = linefeed.join(nx.generate_gexf(G))
    >>> for line in nx.generate_gexf(G):  # doctest: +SKIP
    ...     print(line)

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    If you want to specify an id, set it as node data, e.g.
    node['a']['id']=1 to set the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    gexf_writer = GEXFWriter(
        version=version, prettyprint=prettyprint, encoding=encoding
    )
    gexf_writer.add_graph(G)
    document = str(gexf_writer)
    for line in document.splitlines():
        yield line
134
135
@open_file(0, mode="rb")
@nx._dispatchable(graphs=None, returns_graph=True)
def read_gexf(path, node_type=None, relabel=False, version="1.2draft"):
    """Parse the GEXF document at path and return it as a NetworkX graph.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    path : file or string
        Filename or file handle to read.
        Filenames ending in .gz or .bz2 will be decompressed.
    node_type : Python type (default: None)
        Convert node ids to this type if not None.
    relabel : bool (default: False)
        If True relabel the nodes to use the GEXF node "label" attribute
        instead of the node "id" attribute as the NetworkX node label.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see http://gexf.net/schema.html)
        Supported values: "1.1draft", "1.2draft"

    Returns
    -------
    graph : NetworkX graph
        If no parallel edges are found a Graph or DiGraph is returned.
        Otherwise a MultiGraph or MultiDiGraph is returned.

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    References
    ----------
    .. [1] GEXF File Format, http://gexf.net/
    """
    parse = GEXFReader(node_type=node_type, version=version)
    graph = parse(path)
    return relabel_gexf_graph(graph) if relabel else graph
179
180
class GEXF:
    """Lookup tables and version handling shared by the GEXF reader and writer.

    ``versions`` maps each supported version key to the XML namespaces, the
    schema location and the value of the ``version`` attribute used for that
    flavour of GEXF. ``set_version`` copies one of those entries onto the
    instance, and ``construct_types`` builds the Python <-> GEXF type tables.
    """

    versions = {
        "1.1draft": {
            "NS_GEXF": "http://www.gexf.net/1.1draft",
            "NS_VIZ": "http://www.gexf.net/1.1draft/viz",
            "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
            "SCHEMALOCATION": " ".join(
                [
                    "http://www.gexf.net/1.1draft",
                    "http://www.gexf.net/1.1draft/gexf.xsd",
                ]
            ),
            "VERSION": "1.1",
        },
        "1.2draft": {
            "NS_GEXF": "http://www.gexf.net/1.2draft",
            "NS_VIZ": "http://www.gexf.net/1.2draft/viz",
            "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
            "SCHEMALOCATION": " ".join(
                [
                    "http://www.gexf.net/1.2draft",
                    "http://www.gexf.net/1.2draft/gexf.xsd",
                ]
            ),
            "VERSION": "1.2",
        },
        "1.3": {
            "NS_GEXF": "http://gexf.net/1.3",
            "NS_VIZ": "http://gexf.net/1.3/viz",
            # The XML Schema instance namespace is a fixed W3C identifier and
            # does not change with the GEXF version; it must keep the "www."
            # host or "xsi:schemaLocation" is not recognised as such.
            "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
            "SCHEMALOCATION": " ".join(
                [
                    "http://gexf.net/1.3",
                    "http://gexf.net/1.3/gexf.xsd",
                ]
            ),
            "VERSION": "1.3",
        },
    }

    def construct_types(self):
        """Build ``self.xml_type`` and ``self.python_type``.

        ``xml_type`` maps a Python (or numpy) type to the GEXF type name used
        when writing; ``python_type`` maps a GEXF type name to the callable
        used to convert attribute text when reading. Both are built from the
        same ordered list of pairs, and for duplicated keys the last pair in
        the list wins.
        """
        types = [
            (int, "integer"),
            (float, "float"),
            (float, "double"),
            (bool, "boolean"),
            (list, "string"),
            (dict, "string"),
            (int, "long"),
            (str, "liststring"),
            (str, "anyURI"),
            (str, "string"),
        ]

        # These additions to types allow writing numpy types
        try:
            import numpy as np
        except ImportError:
            pass
        else:
            # prepend so that python types are created upon read (last entry wins)
            # NOTE(review): "int" does not look like a GEXF attribute type name
            # (the list above uses "integer"/"long") - confirm against the schema
            # before changing, since it affects the files that get written.
            types = [
                (np.float64, "float"),
                (np.float32, "float"),
                (np.float16, "float"),
                (np.int_, "int"),
                (np.int8, "int"),
                (np.int16, "int"),
                (np.int32, "int"),
                (np.int64, "int"),
                (np.uint8, "int"),
                (np.uint16, "int"),
                (np.uint32, "int"),
                (np.uint64, "int"),
                (np.intc, "int"),
                (np.intp, "int"),
            ] + types

        self.xml_type = dict(types)
        self.python_type = {xml_name: py_type for py_type, xml_name in types}

    # Accepted spellings of boolean attribute values, including the Python
    # style "True"/"False" and bare integers.
    # http://www.w3.org/TR/xmlschema-2/#boolean
    convert_bool = {
        "true": True,
        "false": False,
        "True": True,
        "False": False,
        "0": False,
        0: False,
        "1": True,
        1: True,
    }

    def set_version(self, version):
        """Select the GEXF version used by this instance.

        Copies the namespace, schema location and version strings registered
        under ``version`` in ``versions`` onto the instance.

        Raises
        ------
        NetworkXError
            If ``version`` is not a key of ``versions``.
        """
        d = self.versions.get(version)
        if d is None:
            raise nx.NetworkXError(f"Unknown GEXF version {version}.")
        self.NS_GEXF = d["NS_GEXF"]
        self.NS_VIZ = d["NS_VIZ"]
        self.NS_XSI = d["NS_XSI"]
        self.SCHEMALOCATION = d["SCHEMALOCATION"]
        self.VERSION = d["VERSION"]
        self.version = version
285
286
class GEXFWriter(GEXF):
    """Build a GEXF XML document from a NetworkX graph.

    Intended to be used through ``write_gexf()`` / ``generate_gexf()``.
    The document is held in ``self.xml``; ``add_graph`` appends a ``<graph>``
    element to it, ``write`` serializes it to a file handle and ``str()``
    returns it as text.
    """

    def __init__(
        self, graph=None, encoding="utf-8", prettyprint=True, version="1.2draft"
    ):
        """Create the ``<gexf>`` root and, if ``graph`` is given, add it."""
        self.construct_types()
        self.prettyprint = prettyprint
        self.encoding = encoding
        self.set_version(version)
        self.xml = Element(
            "gexf",
            {
                "xmlns": self.NS_GEXF,
                "xmlns:xsi": self.NS_XSI,
                "xsi:schemaLocation": self.SCHEMALOCATION,
                "version": self.VERSION,
            },
        )

        # Make meta element a non-graph element
        # Also add lastmodifieddate as attribute, not tag
        meta_element = Element("meta")
        subelement_text = f"NetworkX {nx.__version__}"
        SubElement(meta_element, "creator").text = subelement_text
        meta_element.set("lastmodifieddate", time.strftime("%Y-%m-%d"))
        self.xml.append(meta_element)

        register_namespace("viz", self.NS_VIZ)

        # counters for edge and attribute identifiers
        self.edge_id = itertools.count()
        self.attr_id = itertools.count()
        self.all_edge_ids = set()
        # attribute ids already declared, keyed by class, mode and title
        self.attr = {
            "node": {"dynamic": {}, "static": {}},
            "edge": {"dynamic": {}, "static": {}},
        }

        if graph is not None:
            self.add_graph(graph)

    def __str__(self):
        """Return the document as text (pretty-printed if requested)."""
        if self.prettyprint:
            self.indent(self.xml)
        # tostring() without an encoding argument emits pure ASCII bytes
        # (non-ASCII characters become character references), so they must be
        # decoded as ASCII. Decoding with self.encoding only worked for
        # ASCII-compatible encodings and broke for e.g. utf-16.
        return tostring(self.xml).decode("ascii")

    def add_graph(self, G):
        """Append a ``<graph>`` element describing ``G`` to the document."""
        # first pass through G collecting edge ids
        for u, v, dd in G.edges(data=True):
            eid = dd.get("id")
            if eid is not None:
                self.all_edge_ids.add(str(eid))
        # set graph attributes
        mode = "dynamic" if G.graph.get("mode") == "dynamic" else "static"
        default = "directed" if G.is_directed() else "undirected"
        name = G.graph.get("name", "")
        # Add a graph element to the XML
        graph_element = Element("graph", defaultedgetype=default, mode=mode, name=name)
        # add_nodes/add_edges update the mode, timeformat and attribute
        # declarations of this element while they run.
        self.graph_element = graph_element
        self.add_nodes(G, graph_element)
        self.add_edges(G, graph_element)
        self.xml.append(graph_element)

    def add_nodes(self, G, graph_element):
        """Append a ``<nodes>`` element with one ``<node>`` per node of ``G``."""
        nodes_element = Element("nodes")
        default = G.graph.get("node_default", {})
        for node, data in G.nodes(data=True):
            # work on a copy: the recognised keys are popped as they are used
            node_data = data.copy()
            kw = {"id": str(node_data.pop("id", node))}
            kw["label"] = str(node_data.pop("label", node))
            if "pid" in node_data:
                kw["pid"] = str(node_data.pop("pid"))
            for bound in ("start", "end"):
                if bound in node_data:
                    moment = node_data.pop(bound)
                    kw[bound] = str(moment)
                    self.alter_graph_mode_timeformat(moment)
            # add node element with attributes
            node_element = Element("node", **kw)
            # add node element and attr subelements
            node_data = self.add_parents(node_element, node_data)
            if self.VERSION == "1.1":
                node_data = self.add_slices(node_element, node_data)
            else:
                node_data = self.add_spells(node_element, node_data)
            node_data = self.add_viz(node_element, node_data)
            node_data = self.add_attributes("node", node_element, node_data, default)
            nodes_element.append(node_element)
        graph_element.append(nodes_element)

    def _pop_edge_id(self, edge_data):
        """Pop the "id" of an edge, allocating an unused one if it has none."""
        edge_id = edge_data.pop("id", None)
        if edge_id is None:
            edge_id = next(self.edge_id)
            # skip counter values that collide with ids supplied by the user
            while str(edge_id) in self.all_edge_ids:
                edge_id = next(self.edge_id)
            self.all_edge_ids.add(str(edge_id))
        return edge_id

    def add_edges(self, G, graph_element):
        """Append an ``<edges>`` element with one ``<edge>`` per edge of ``G``."""

        def edge_key_data(G):
            # helper function to unify multigraph and graph edge iterator
            if G.is_multigraph():
                for u, v, key, data in G.edges(data=True, keys=True):
                    edge_data = data.copy()
                    # the multigraph key travels as an ordinary attribute
                    edge_data.update(key=key)
                    yield u, v, self._pop_edge_id(edge_data), edge_data
            else:
                for u, v, data in G.edges(data=True):
                    edge_data = data.copy()
                    yield u, v, self._pop_edge_id(edge_data), edge_data

        edges_element = Element("edges")
        default = G.graph.get("edge_default", {})
        for u, v, key, edge_data in edge_key_data(G):
            kw = {"id": str(key)}
            for field in ("label", "weight", "type"):
                if field in edge_data:
                    kw[field] = str(edge_data.pop(field))
            for bound in ("start", "end"):
                if bound in edge_data:
                    moment = edge_data.pop(bound)
                    kw[bound] = str(moment)
                    self.alter_graph_mode_timeformat(moment)
            # endpoints are referenced by the same id add_nodes() wrote
            source_id = str(G.nodes[u].get("id", u))
            target_id = str(G.nodes[v].get("id", v))
            edge_element = Element("edge", source=source_id, target=target_id, **kw)
            if self.VERSION == "1.1":
                edge_data = self.add_slices(edge_element, edge_data)
            else:
                edge_data = self.add_spells(edge_element, edge_data)
            edge_data = self.add_viz(edge_element, edge_data)
            edge_data = self.add_attributes("edge", edge_element, edge_data, default)
            edges_element.append(edge_element)
        graph_element.append(edges_element)

    @staticmethod
    def _format_value(value):
        """Return the text written in the "value" attribute of an attvalue."""
        if isinstance(value, bool):
            return str(value).lower()
        text = str(value)
        # Handle float nan, inf, -inf differently
        if type(value) is float:
            return {"inf": "INF", "-inf": "-INF", "nan": "NaN"}.get(text, text)
        return text

    def add_attributes(self, node_or_edge, xml_obj, data, default):
        """Write the remaining ``data`` of a node or edge as ``<attvalues>``.

        Parameters
        ----------
        node_or_edge : str
            "node" or "edge"; the class the attributes are declared under.
        xml_obj : Element
            The node or edge element that receives the ``<attvalues>`` child.
        data : dict
            Attribute name -> value. A list value is treated as dynamic data,
            i.e. an iterable of ``(value, start, end)`` triples.
        default : dict
            Attribute name -> default value, used when an attribute is first
            declared.

        Returns
        -------
        dict
            ``data``, unchanged.

        Raises
        ------
        TypeError
            If the type of a value has no GEXF equivalent in ``xml_type``.
        """
        if len(data) == 0:
            return data
        attvalues = Element("attvalues")
        for k, v in data.items():
            # rename generic multigraph key to avoid any name conflict
            if k == "key":
                k = "networkx_key"
            val_type = type(v)
            if val_type not in self.xml_type:
                raise TypeError(f"attribute value type is not allowed: {val_type}")
            # Decide the mode per attribute. Without this reset a list that has
            # no time bounds inherited the mode of whichever attribute was
            # processed before it.
            mode = "static"
            if isinstance(v, list):
                # dynamic data: it is dynamic as soon as one entry has a bound
                for val, start, end in v:
                    val_type = type(val)
                    if start is not None or end is not None:
                        mode = "dynamic"
                        self.alter_graph_mode_timeformat(start)
                        self.alter_graph_mode_timeformat(end)
                        break
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, mode
                )
                for val, start, end in v:
                    e = Element("attvalue")
                    e.attrib["for"] = attr_id
                    e.attrib["value"] = self._format_value(val)
                    if start is not None:
                        e.attrib["start"] = str(start)
                    if end is not None:
                        e.attrib["end"] = str(end)
                    attvalues.append(e)
            else:
                # static data
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, mode
                )
                e = Element("attvalue")
                e.attrib["for"] = attr_id
                e.attrib["value"] = self._format_value(v)
                attvalues.append(e)
        xml_obj.append(attvalues)
        return data

    def get_attr_id(self, title, attr_type, edge_or_node, default, mode):
        """Return the id of attribute ``title``, declaring it on first use.

        A new ``<attribute>`` element (with a ``<default>`` child when
        ``default`` has an entry for ``title``) is added to the
        ``<attributes>`` element of the current graph matching
        ``edge_or_node`` and ``mode``; that element is created if missing.
        """
        known = self.attr[edge_or_node][mode]
        if title in known:
            return known[title]

        # generate new id
        new_id = str(next(self.attr_id))
        known[title] = new_id
        attr_kwargs = {"id": new_id, "title": title, "type": attr_type}
        attribute = Element("attribute", **attr_kwargs)
        # add subelement for data default value if present
        default_title = default.get(title)
        if default_title is not None:
            default_element = Element("default")
            default_element.text = str(default_title)
            attribute.append(default_element)
        # now insert it into the XML
        attributes_element = None
        for a in self.graph_element.findall("attributes"):
            # find existing attributes element by class and mode
            a_class = a.get("class")
            a_mode = a.get("mode", "static")
            if a_class == edge_or_node and a_mode == mode:
                attributes_element = a
        if attributes_element is None:
            # create new attributes element
            attr_kwargs = {"mode": mode, "class": edge_or_node}
            attributes_element = Element("attributes", **attr_kwargs)
            self.graph_element.insert(0, attributes_element)
        attributes_element.append(attribute)
        return new_id

    def add_viz(self, element, node_data):
        """Pop "viz" from ``node_data`` and write it as ``viz:*`` children.

        Recognised keys are "color", "size", "thickness", "shape" and
        "position". Returns ``node_data`` without the "viz" entry.
        """
        viz = node_data.pop("viz", False)
        if viz:
            color = viz.get("color")
            if color is not None:
                channels = {
                    "r": str(color.get("r")),
                    "g": str(color.get("g")),
                    "b": str(color.get("b")),
                }
                # the alpha channel is not written for GEXF 1.1
                if self.VERSION != "1.1":
                    channels["a"] = str(color.get("a", 1.0))
                element.append(Element(f"{{{self.NS_VIZ}}}color", **channels))

            size = viz.get("size")
            if size is not None:
                e = Element(f"{{{self.NS_VIZ}}}size", value=str(size))
                element.append(e)

            thickness = viz.get("thickness")
            if thickness is not None:
                e = Element(f"{{{self.NS_VIZ}}}thickness", value=str(thickness))
                element.append(e)

            shape = viz.get("shape")
            if shape is not None:
                if shape.startswith("http"):
                    # a URL is written as an image shape pointing at it
                    e = Element(
                        f"{{{self.NS_VIZ}}}shape", value="image", uri=str(shape)
                    )
                else:
                    e = Element(f"{{{self.NS_VIZ}}}shape", value=str(shape))
                element.append(e)

            position = viz.get("position")
            if position is not None:
                # Only write the coordinates that were given: a missing one
                # used to be written as the text "None", which the reader
                # cannot convert back to a float.
                coordinates = {
                    axis: str(position[axis])
                    for axis in ("x", "y", "z")
                    if position.get(axis) is not None
                }
                element.append(Element(f"{{{self.NS_VIZ}}}position", **coordinates))
        return node_data

    def add_parents(self, node_element, node_data):
        """Pop "parents" from ``node_data`` and write a ``<parents>`` child."""
        parents = node_data.pop("parents", False)
        if parents:
            parents_element = Element("parents")
            for p in parents:
                e = Element("parent")
                e.attrib["for"] = str(p)
                parents_element.append(e)
            node_element.append(parents_element)
        return node_data

    def add_slices(self, node_or_edge_element, node_or_edge_data):
        """Pop "slices" (``(start, end)`` pairs) and write a ``<slices>`` child."""
        slices = node_or_edge_data.pop("slices", False)
        if slices:
            slices_element = Element("slices")
            for start, end in slices:
                e = Element("slice", start=str(start), end=str(end))
                slices_element.append(e)
            node_or_edge_element.append(slices_element)
        return node_or_edge_data

    def add_spells(self, node_or_edge_element, node_or_edge_data):
        """Pop "spells" (``(start, end)`` pairs) and write a ``<spells>`` child.

        A bound that is None is left out of the ``<spell>`` element.
        """
        spells = node_or_edge_data.pop("spells", False)
        if spells:
            spells_element = Element("spells")
            for start, end in spells:
                e = Element("spell")
                if start is not None:
                    e.attrib["start"] = str(start)
                    self.alter_graph_mode_timeformat(start)
                if end is not None:
                    e.attrib["end"] = str(end)
                    self.alter_graph_mode_timeformat(end)
                spells_element.append(e)
            node_or_edge_element.append(spells_element)
        return node_or_edge_data

    def alter_graph_mode_timeformat(self, start_or_end):
        """Record on the graph element that a time bound has been seen.

        Sets the "timeformat" of the current graph element from the Python
        type of ``start_or_end`` (str -> "date", float -> "double",
        int -> "long") and switches a static graph to dynamic mode.
        Does nothing when ``start_or_end`` is None.

        Raises
        ------
        NetworkXError
            If ``start_or_end`` is neither None nor a str, float or int.
        """
        # If 'start' or 'end' appears, set timeformat
        if start_or_end is None:
            return
        if isinstance(start_or_end, str):
            timeformat = "date"
        elif isinstance(start_or_end, float):
            timeformat = "double"
        elif isinstance(start_or_end, int):
            timeformat = "long"
        else:
            raise nx.NetworkXError(
                "timeformat should be of the type int, float or str"
            )
        self.graph_element.set("timeformat", timeformat)
        # If Graph mode is static, alter to dynamic
        if self.graph_element.get("mode") == "static":
            self.graph_element.set("mode", "dynamic")

    def write(self, fh):
        """Serialize the document, with an XML declaration, to the open ``fh``."""
        if self.prettyprint:
            self.indent(self.xml)
        document = ElementTree(self.xml)
        document.write(fh, encoding=self.encoding, xml_declaration=True)

    def indent(self, elem, level=0):
        """Pretty-print ``elem`` in place by filling in whitespace text/tails."""
        i = "\n" + " " * level
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + " "
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            last_child = None
            for last_child in elem:
                self.indent(last_child, level + 1)
            # the last child is followed by the closing tag of ``elem``, so its
            # tail goes back to the indentation of ``elem`` itself
            if not last_child.tail or not last_child.tail.strip():
                last_child.tail = i
        else:
            if level and (not elem.tail or not elem.tail.strip()):
                elem.tail = i
697
698
class GEXFReader(GEXF):
    """Parse a GEXF document into a NetworkX graph.

    Intended to be used through ``read_gexf()``. Calling the instance with a
    file or file name returns the graph.
    """

    def __init__(self, node_type=None, version="1.2draft"):
        """Store the node id converter and select the expected GEXF version."""
        self.construct_types()
        self.node_type = node_type
        # assume simple graph and test for multigraph on read
        self.simple_graph = True
        self.set_version(version)

    def __call__(self, stream):
        """Parse ``stream`` and return the graph it describes.

        The ``<graph>`` element is looked up with the namespace of the
        configured version first and then with every known version.

        Raises
        ------
        NetworkXError
            If no ``<graph>`` element is found under any known namespace.
        """
        self.xml = ElementTree(file=stream)
        g = self.xml.find(f"{{{self.NS_GEXF}}}graph")
        if g is not None:
            return self.make_graph(g)
        # try all the versions
        for version in self.versions:
            self.set_version(version)
            g = self.xml.find(f"{{{self.NS_GEXF}}}graph")
            if g is not None:
                return self.make_graph(g)
        raise nx.NetworkXError("No <graph> element in GEXF file.")

    def make_graph(self, graph_xml):
        """Build the NetworkX graph described by a ``<graph>`` element.

        Raises
        ------
        NetworkXError
            If an ``<attributes>`` element has a class other than "node" or
            "edge".
        """
        # start with empty DiGraph or MultiDiGraph
        edgedefault = graph_xml.get("defaultedgetype", None)
        if edgedefault == "directed":
            G = nx.MultiDiGraph()
        else:
            G = nx.MultiGraph()

        # graph attributes
        graph_name = graph_xml.get("name", "")
        if graph_name != "":
            G.graph["name"] = graph_name
        graph_start = graph_xml.get("start")
        if graph_start is not None:
            G.graph["start"] = graph_start
        graph_end = graph_xml.get("end")
        if graph_end is not None:
            G.graph["end"] = graph_end
        graph_mode = graph_xml.get("mode", "")
        G.graph["mode"] = "dynamic" if graph_mode == "dynamic" else "static"

        # timeformat; "date" values are kept as strings
        self.timeformat = graph_xml.get("timeformat")
        if self.timeformat == "date":
            self.timeformat = "string"

        # node and edge attributes
        attributes_elements = graph_xml.findall(f"{{{self.NS_GEXF}}}attributes")
        # dictionaries to hold attributes and attribute defaults
        node_attr = {}
        node_default = {}
        edge_attr = {}
        edge_default = {}
        for a in attributes_elements:
            attr_class = a.get("class")
            if attr_class == "node":
                na, nd = self.find_gexf_attributes(a)
                node_attr.update(na)
                node_default.update(nd)
                G.graph["node_default"] = node_default
            elif attr_class == "edge":
                ea, ed = self.find_gexf_attributes(a)
                edge_attr.update(ea)
                edge_default.update(ed)
                G.graph["edge_default"] = edge_default
            else:
                # A bare ``raise`` here had no active exception to re-raise and
                # surfaced as an unrelated RuntimeError.
                raise nx.NetworkXError(
                    f"Unknown attribute class {attr_class!r} in GEXF file."
                )

        # Hack to handle Gephi0.7beta bug
        # add weight attribute
        ea = {"weight": {"type": "double", "mode": "static", "title": "weight"}}
        ed = {}
        edge_attr.update(ea)
        edge_default.update(ed)
        G.graph["edge_default"] = edge_default

        # add nodes
        nodes_element = graph_xml.find(f"{{{self.NS_GEXF}}}nodes")
        if nodes_element is not None:
            for node_xml in nodes_element.findall(f"{{{self.NS_GEXF}}}node"):
                self.add_node(G, node_xml, node_attr)

        # add edges
        edges_element = graph_xml.find(f"{{{self.NS_GEXF}}}edges")
        if edges_element is not None:
            for edge_xml in edges_element.findall(f"{{{self.NS_GEXF}}}edge"):
                self.add_edge(G, edge_xml, edge_attr)

        # switch to Graph or DiGraph if no parallel edges were found.
        if self.simple_graph:
            if G.is_directed():
                G = nx.DiGraph(G)
            else:
                G = nx.Graph(G)
        return G

    def add_node(self, G, node_xml, node_attr, node_pid=None):
        """Add the node described by ``node_xml``, and its sub-nodes, to ``G``.

        ``node_pid`` is the id of the enclosing node when called recursively;
        an explicit "pid" attribute on the element takes precedence over it.
        """
        # get attributes and subattributes for node
        data = self.decode_attr_elements(node_attr, node_xml)
        data = self.add_parents(data, node_xml)  # add any parents
        if self.VERSION == "1.1":
            data = self.add_slices(data, node_xml)  # add slices
        else:
            data = self.add_spells(data, node_xml)  # add spells
        data = self.add_viz(data, node_xml)  # add viz
        data = self.add_start_end(data, node_xml)  # add start/end

        # find the node id and cast it to the appropriate type
        node_id = node_xml.get("id")
        if self.node_type is not None:
            node_id = self.node_type(node_id)

        # every node should have a label (None is stored when it is absent)
        node_label = node_xml.get("label")
        data["label"] = node_label

        # parent node id
        node_pid = node_xml.get("pid", node_pid)
        if node_pid is not None:
            data["pid"] = node_pid

        # check for subnodes, recursive
        subnodes = node_xml.find(f"{{{self.NS_GEXF}}}nodes")
        if subnodes is not None:
            for child_xml in subnodes.findall(f"{{{self.NS_GEXF}}}node"):
                self.add_node(G, child_xml, node_attr, node_pid=node_id)

        G.add_node(node_id, **data)

    def _parse_time(self, raw):
        """Convert the text of a start/end attribute using the graph timeformat.

        Returns None when the attribute is absent, so open-ended intervals
        survive instead of failing in ``int(None)`` / ``float(None)``. When the
        graph declares no timeformat, or one with no registered converter, the
        raw text is returned unchanged rather than raising ``KeyError``.
        """
        if raw is None:
            return None
        convert = self.python_type.get(self.timeformat)
        if convert is None:
            return raw
        return convert(raw)

    def add_start_end(self, data, xml):
        """Copy the "start"/"end" attributes of ``xml``, if any, into ``data``."""
        node_start = self._parse_time(xml.get("start"))
        if node_start is not None:
            data["start"] = node_start
        node_end = self._parse_time(xml.get("end"))
        if node_end is not None:
            data["end"] = node_end
        return data

    def add_viz(self, data, node_xml):
        """Collect the ``viz:*`` children of ``node_xml`` under ``data["viz"]``.

        Nothing is added to ``data`` when no viz element is present.
        """
        viz = {}
        color = node_xml.find(f"{{{self.NS_VIZ}}}color")
        if color is not None:
            viz["color"] = {
                "r": int(color.get("r")),
                "g": int(color.get("g")),
                "b": int(color.get("b")),
            }
            # the alpha channel is not read for GEXF 1.1
            if self.VERSION != "1.1":
                viz["color"]["a"] = float(color.get("a", 1))

        size = node_xml.find(f"{{{self.NS_VIZ}}}size")
        if size is not None:
            viz["size"] = float(size.get("value"))

        thickness = node_xml.find(f"{{{self.NS_VIZ}}}thickness")
        if thickness is not None:
            viz["thickness"] = float(thickness.get("value"))

        shape = node_xml.find(f"{{{self.NS_VIZ}}}shape")
        if shape is not None:
            # The shape name lives in the "value" attribute (that is what the
            # writer in this module emits); the "shape" attribute that used to
            # be read here is kept only as a fallback.
            viz["shape"] = shape.get("value", shape.get("shape"))
            if viz["shape"] == "image":
                viz["shape"] = shape.get("uri")

        position = node_xml.find(f"{{{self.NS_VIZ}}}position")
        if position is not None:
            viz["position"] = {
                "x": float(position.get("x", 0)),
                "y": float(position.get("y", 0)),
                "z": float(position.get("z", 0)),
            }

        if len(viz) > 0:
            data["viz"] = viz
        return data

    def add_parents(self, data, node_xml):
        """Store the ids listed under ``<parents>`` as ``data["parents"]``."""
        parents_element = node_xml.find(f"{{{self.NS_GEXF}}}parents")
        if parents_element is not None:
            data["parents"] = [
                p.get("for")
                for p in parents_element.findall(f"{{{self.NS_GEXF}}}parent")
            ]
        return data

    def add_slices(self, data, node_or_edge_xml):
        """Store ``<slices>`` as a list of raw ``(start, end)`` string pairs."""
        slices_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}slices")
        if slices_element is not None:
            data["slices"] = [
                (s.get("start"), s.get("end"))
                for s in slices_element.findall(f"{{{self.NS_GEXF}}}slice")
            ]
        return data

    def add_spells(self, data, node_or_edge_xml):
        """Store ``<spells>`` as a list of converted ``(start, end)`` pairs.

        A bound that is missing from a ``<spell>`` is stored as None.
        """
        spells_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}spells")
        if spells_element is not None:
            data["spells"] = []
            for s in spells_element.findall(f"{{{self.NS_GEXF}}}spell"):
                start = self._parse_time(s.get("start"))
                end = self._parse_time(s.get("end"))
                data["spells"].append((start, end))
        return data

    def add_edge(self, G, edge_element, edge_attr):
        """Add the edge described by ``edge_element`` to ``G``.

        Raises
        ------
        NetworkXError
            If the edge type contradicts the directedness of ``G``.
        """
        # raise error if we find mixed directed and undirected edges
        edge_direction = edge_element.get("type")
        if G.is_directed() and edge_direction == "undirected":
            raise nx.NetworkXError("Undirected edge found in directed graph.")
        if (not G.is_directed()) and edge_direction == "directed":
            raise nx.NetworkXError("Directed edge found in undirected graph.")

        # Get source and target and recast type if required
        source = edge_element.get("source")
        target = edge_element.get("target")
        if self.node_type is not None:
            source = self.node_type(source)
            target = self.node_type(target)

        data = self.decode_attr_elements(edge_attr, edge_element)
        data = self.add_start_end(data, edge_element)

        if self.VERSION == "1.1":
            data = self.add_slices(data, edge_element)  # add slices
        else:
            data = self.add_spells(data, edge_element)  # add spells

        # GEXF stores edge ids as an attribute
        # NetworkX uses them as keys in multigraphs
        # if networkx_key is not specified as an attribute
        edge_id = edge_element.get("id")
        if edge_id is not None:
            data["id"] = edge_id

        # check if there is a 'multigraph_key' and use that as edge_id
        multigraph_key = data.pop("networkx_key", None)
        if multigraph_key is not None:
            edge_id = multigraph_key

        weight = edge_element.get("weight")
        if weight is not None:
            data["weight"] = float(weight)

        edge_label = edge_element.get("label")
        if edge_label is not None:
            data["label"] = edge_label

        if G.has_edge(source, target):
            # seen this edge before - this is a multigraph
            self.simple_graph = False
        G.add_edge(source, target, key=edge_id, **data)
        if edge_direction == "mutual":
            # a mutual edge is stored once in each direction
            G.add_edge(target, source, key=edge_id, **data)

    def decode_attr_elements(self, gexf_keys, obj_xml):
        """Decode the ``<attvalues>`` of a node or edge into a dict.

        ``gexf_keys`` maps attribute ids to their declaration ("title",
        "type", "mode"). Static attributes map to a single converted value;
        dynamic ones map to a list of ``(value, start, end)`` triples.

        Raises
        ------
        NetworkXError
            If an ``<attvalue>`` refers to an id that is not in ``gexf_keys``.
        """
        attr = {}
        # look for outer '<attvalues>' element
        attr_element = obj_xml.find(f"{{{self.NS_GEXF}}}attvalues")
        if attr_element is None:
            return attr
        # loop over <attvalue> elements
        for a in attr_element.findall(f"{{{self.NS_GEXF}}}attvalue"):
            key = a.get("for")  # for is required
            try:  # should be in our gexf_keys dictionary
                title = gexf_keys[key]["title"]
            except KeyError as err:
                raise nx.NetworkXError(f"No attribute defined for={key}.") from err
            atype = gexf_keys[key]["type"]
            value = a.get("value")
            if atype == "boolean":
                value = self.convert_bool[value]
            else:
                value = self.python_type[atype](value)
            if gexf_keys[key]["mode"] == "dynamic":
                # for dynamic graphs use list of three-tuples
                # [(value1,start1,end1), (value2,start2,end2), etc]
                start = self._parse_time(a.get("start"))
                end = self._parse_time(a.get("end"))
                attr.setdefault(title, []).append((value, start, end))
            else:
                # for static graphs just assign the value
                attr[title] = value
        return attr

    def find_gexf_attributes(self, attributes_element):
        """Read the declarations inside one ``<attributes>`` element.

        Returns
        -------
        attrs : dict
            Attribute id -> {"title", "type", "mode"}.
        defaults : dict
            Attribute title -> converted default value, for the attributes
            that declare a ``<default>``.
        """
        attrs = {}
        defaults = {}
        mode = attributes_element.get("mode")
        for k in attributes_element.findall(f"{{{self.NS_GEXF}}}attribute"):
            attr_id = k.get("id")
            title = k.get("title")
            atype = k.get("type")
            attrs[attr_id] = {"title": title, "type": atype, "mode": mode}
            # check for the 'default' subelement of key element and add
            default = k.find(f"{{{self.NS_GEXF}}}default")
            if default is not None:
                if atype == "boolean":
                    value = self.convert_bool[default.text]
                else:
                    value = self.python_type[atype](default.text)
                defaults[title] = value
        return attrs, defaults
1027
1028
def relabel_gexf_graph(G):
    """Relabel graph using "label" node keyword for node label.

    Parameters
    ----------
    G : graph
        A NetworkX graph read from GEXF data

    Returns
    -------
    H : graph
        A NetworkX graph with relabeled nodes. Each node keeps its former
        name under the "id" attribute and loses its "label" attribute.

    Raises
    ------
    NetworkXError
        If node labels are missing or not unique while relabel=True.

    Notes
    -----
    This function relabels the nodes in a NetworkX graph with the
    "label" attribute. It also handles relabeling the specific GEXF
    node attributes "parents", and "pid".

    A graph without nodes is returned as an (empty) relabeled copy.
    """
    # build mapping of node labels, do some error checking
    try:
        mapping = [(u, G.nodes[u]["label"]) for u in G]
    except KeyError as err:
        raise nx.NetworkXError(
            "Failed to relabel nodes: missing node labels found. Use relabel=False."
        ) from err
    # Collect the labels directly: unpacking ``zip(*mapping)`` raised a
    # ValueError for a graph without nodes.
    labels = [label for _, label in mapping]
    if len(set(labels)) != len(G):
        raise nx.NetworkXError(
            "Failed to relabel nodes: duplicate node labels found. Use relabel=False."
        )
    mapping = dict(mapping)
    H = nx.relabel_nodes(G, mapping)
    # relabel attributes
    for n in G:
        m = mapping[n]
        H.nodes[m]["id"] = n
        H.nodes[m].pop("label")
        # "pid" and "parents" refer to other nodes by their old name
        if "pid" in H.nodes[m]:
            H.nodes[m]["pid"] = mapping[G.nodes[n]["pid"]]
        if "parents" in H.nodes[m]:
            H.nodes[m]["parents"] = [mapping[p] for p in G.nodes[n]["parents"]]
    return H