Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/networkx/readwrite/gexf.py: 8%
607 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-20 07:00 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-20 07:00 +0000
1"""Read and write graphs in GEXF format.
3.. warning::
4 This parser uses the standard xml library present in Python, which is
5 insecure - see :external+python:mod:`xml` for additional information.
6 Only parse GEXF files you trust.
8GEXF (Graph Exchange XML Format) is a language for describing complex
9network structures, their associated data and dynamics.
11This implementation does not support mixed graphs (directed and
12undirected edges together).
14Format
15------
16GEXF is an XML format. See http://gexf.net/schema.html for the
17specification and http://gexf.net/basic.html for examples.
18"""
19import itertools
20import time
21from xml.etree.ElementTree import (
22 Element,
23 ElementTree,
24 SubElement,
25 register_namespace,
26 tostring,
27)
29import networkx as nx
30from networkx.utils import open_file
32__all__ = ["write_gexf", "read_gexf", "relabel_gexf_graph", "generate_gexf"]
@open_file(1, mode="wb")
def write_gexf(G, path, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Serialize graph `G` to `path` as a GEXF document.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Attributes that are not user defined (for example the visualization
    data stored under ``'viz'`` [2]_) are written according to the GEXF
    schema selected with `version`. See the example below.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    path : file or string
        File or file name to write.
        File names ending in .gz or .bz2 will be compressed.
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (optional, default: '1.2draft')
        The GEXF version whose schema is used when writing the graph.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.write_gexf(G, "test.gexf")

    # visualization data
    >>> G.nodes[0]["viz"] = {"size": 54}
    >>> G.nodes[0]["viz"]["position"] = {"x": 0, "y": 1}
    >>> G.nodes[0]["viz"]["color"] = {"r": 0, "g": 0, "b": 255}

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    To choose a different id, store it as node data, e.g.
    node['a']['id']=1 to set the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, http://gexf.net/
    .. [2] GEXF schema, http://gexf.net/schema.html
    """
    gexf = GEXFWriter(version=version, prettyprint=prettyprint, encoding=encoding)
    gexf.add_graph(G)
    gexf.write(path)
def generate_gexf(G, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Yield the GEXF representation of `G` one line at a time.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see http://gexf.net/schema.html)
        Supported values: "1.1draft", "1.2draft"

    Yields
    ------
    str
        Successive lines of the serialized document, without line endings.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> linefeed = chr(10)  # linefeed=\n
    >>> s = linefeed.join(nx.generate_gexf(G))
    >>> for line in nx.generate_gexf(G):  # doctest: +SKIP
    ...     print(line)

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    To choose a different id, store it as node data, e.g.
    node['a']['id']=1 to set the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    gexf = GEXFWriter(version=version, prettyprint=prettyprint, encoding=encoding)
    gexf.add_graph(G)
    document = str(gexf)
    for line in document.splitlines():
        yield line
@open_file(0, mode="rb")
@nx._dispatch(graphs=None)
def read_gexf(path, node_type=None, relabel=False, version="1.2draft"):
    """Parse a GEXF document from `path` into a NetworkX graph.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    path : file or string
        File or file name to read.
        File names ending in .gz or .bz2 will be decompressed.
    node_type : Python type (default: None)
        Convert node ids to this type if not None.
    relabel : bool (default: False)
        If True relabel the nodes to use the GEXF node "label" attribute
        instead of the node "id" attribute as the NetworkX node label.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see http://gexf.net/schema.html)
        Supported values: "1.1draft", "1.2draft"

    Returns
    -------
    graph : NetworkX graph
        If no parallel edges are found a Graph or DiGraph is returned.
        Otherwise a MultiGraph or MultiDiGraph is returned.

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    References
    ----------
    .. [1] GEXF File Format, http://gexf.net/
    """
    parse = GEXFReader(node_type=node_type, version=version)
    graph = parse(path)
    return relabel_gexf_graph(graph) if relabel else graph
class GEXF:
    """State shared by the GEXF reader and writer.

    Holds the per-version namespace table, the Python <-> GEXF type-name
    mappings and the table used to decode boolean attribute values.
    """

    # Namespace URIs, schema location and "version" attribute value for each
    # supported GEXF draft, keyed by the draft name ("1.1draft", "1.2draft").
    versions = {
        f"{number}draft": {
            "NS_GEXF": f"http://www.gexf.net/{number}draft",
            "NS_VIZ": f"http://www.gexf.net/{number}draft/viz",
            "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
            "SCHEMALOCATION": (
                f"http://www.gexf.net/{number}draft "
                f"http://www.gexf.net/{number}draft/gexf.xsd"
            ),
            "VERSION": number,
        }
        for number in ("1.1", "1.2")
    }

    def construct_types(self):
        """Build ``self.xml_type`` and ``self.python_type``.

        ``xml_type`` maps a Python (or numpy) type to a GEXF type name and
        ``python_type`` maps a GEXF type name back to a type. Both are built
        from one ordered list of pairs, so when a key occurs more than once
        the pair appearing later in the list wins.
        """
        pairs = [
            (int, "integer"),
            (float, "float"),
            (float, "double"),
            (bool, "boolean"),
            (list, "string"),
            (dict, "string"),
            (int, "long"),
            (str, "liststring"),
            (str, "anyURI"),
            (str, "string"),
        ]

        # numpy is optional; when it is importable its scalar types become
        # writable as well.
        try:
            import numpy as np
        except ImportError:
            pass
        else:
            numpy_floats = (np.float64, np.float32, np.float16)
            numpy_ints = (
                np.int_,
                np.int8,
                np.int16,
                np.int32,
                np.int64,
                np.uint8,
                np.uint16,
                np.uint32,
                np.uint64,
                np.intc,
                np.intp,
            )
            # The numpy pairs go first so that, for a type name shared with a
            # builtin (e.g. "float"), the builtin listed later wins on read.
            pairs = (
                [(np_type, "float") for np_type in numpy_floats]
                + [(np_type, "int") for np_type in numpy_ints]
                + pairs
            )

        self.xml_type = dict(pairs)
        self.python_type = {xml_name: py_type for py_type, xml_name in pairs}

    # Accepted spellings of boolean values.
    # http://www.w3.org/TR/xmlschema-2/#boolean
    convert_bool = {
        "true": True,
        "false": False,
        "True": True,
        "False": False,
        "0": False,
        0: False,
        "1": True,
        1: True,
    }

    def set_version(self, version):
        """Select the GEXF draft to use and copy its settings onto ``self``.

        Raises
        ------
        NetworkXError
            If `version` is not a key of ``versions``.
        """
        if version not in self.versions:
            raise nx.NetworkXError(f"Unknown GEXF version {version}.")
        settings = self.versions[version]
        for field in ("NS_GEXF", "NS_VIZ", "NS_XSI", "SCHEMALOCATION", "VERSION"):
            setattr(self, field, settings[field])
        self.version = version
class GEXFWriter(GEXF):
    """Build a GEXF XML document from one or more NetworkX graphs.

    This is the implementation behind ``write_gexf()`` and
    ``generate_gexf()``; prefer those functions.

    Parameters
    ----------
    graph : graph, optional
        If given, it is added to the document with ``add_graph``.
    encoding : string (default: 'utf-8')
        Encoding used by ``write`` and when decoding in ``__str__``.
    prettyprint : bool (default: True)
        If True the tree is indented before being serialized.
    version : string (default: '1.2draft')
        Key of ``GEXF.versions`` selecting namespaces and schema.
    """

    # GEXF spellings of the special float values, keyed by what str() yields.
    _SPECIAL_FLOATS = {"inf": "INF", "nan": "NaN", "-inf": "-INF"}

    def __init__(
        self, graph=None, encoding="utf-8", prettyprint=True, version="1.2draft"
    ):
        self.construct_types()
        self.prettyprint = prettyprint
        self.encoding = encoding
        self.set_version(version)
        self.xml = Element(
            "gexf",
            {
                "xmlns": self.NS_GEXF,
                "xmlns:xsi": self.NS_XSI,
                "xsi:schemaLocation": self.SCHEMALOCATION,
                "version": self.VERSION,
            },
        )

        # <meta> sits next to <graph>, not inside it; lastmodifieddate is an
        # attribute of <meta>, not a child tag.
        meta_element = Element("meta")
        SubElement(meta_element, "creator").text = f"NetworkX {nx.__version__}"
        meta_element.set("lastmodifieddate", time.strftime("%Y-%m-%d"))
        self.xml.append(meta_element)

        # Registers the prefix with ElementTree (a module-wide registry).
        register_namespace("viz", self.NS_VIZ)

        # counters for generated edge and attribute identifiers
        self.edge_id = itertools.count()
        self.attr_id = itertools.count()
        self.all_edge_ids = set()
        # attribute title -> attribute id, per element class and mode
        self.attr = {
            "node": {"dynamic": {}, "static": {}},
            "edge": {"dynamic": {}, "static": {}},
        }

        if graph is not None:
            self.add_graph(graph)

    def __str__(self):
        """Return the document as a string (indented first if prettyprint)."""
        if self.prettyprint:
            self.indent(self.xml)
        return tostring(self.xml).decode(self.encoding)

    def add_graph(self, G):
        """Append a <graph> element describing `G` to the document."""
        # First pass: remember every user-supplied edge id so that generated
        # ids never collide with them.
        for _, _, edge_data in G.edges(data=True):
            eid = edge_data.get("id")
            if eid is not None:
                self.all_edge_ids.add(str(eid))

        mode = "dynamic" if G.graph.get("mode") == "dynamic" else "static"
        default = "directed" if G.is_directed() else "undirected"
        name = G.graph.get("name", "")
        graph_element = Element("graph", defaultedgetype=default, mode=mode, name=name)
        # Helpers below (get_attr_id, alter_graph_mode_timeformat) reach the
        # element currently being built through this attribute.
        self.graph_element = graph_element
        self.add_nodes(G, graph_element)
        self.add_edges(G, graph_element)
        self.xml.append(graph_element)

    def _pop_time_bounds(self, data, kw):
        """Move 'start'/'end' from `data` into `kw` as strings.

        Each bound found is also passed to ``alter_graph_mode_timeformat``.
        """
        for bound in ("start", "end"):
            if bound in data:
                value = data.pop(bound)
                kw[bound] = str(value)
                self.alter_graph_mode_timeformat(value)

    def add_nodes(self, G, graph_element):
        """Append a <nodes> element with one <node> per node of `G`."""
        nodes_element = Element("nodes")
        default = G.graph.get("node_default", {})
        for node, data in G.nodes(data=True):
            node_data = data.copy()
            # 'id' and 'label' stored as node data override the node itself.
            kw = {
                "id": str(node_data.pop("id", node)),
                "label": str(node_data.pop("label", node)),
            }
            if "pid" in node_data:
                kw["pid"] = str(node_data.pop("pid"))
            self._pop_time_bounds(node_data, kw)

            node_element = Element("node", **kw)
            # Each helper consumes the keys it understands; whatever is left
            # is written as generic attribute values.
            node_data = self.add_parents(node_element, node_data)
            if self.VERSION == "1.1":
                node_data = self.add_slices(node_element, node_data)
            else:
                node_data = self.add_spells(node_element, node_data)
            node_data = self.add_viz(node_element, node_data)
            node_data = self.add_attributes("node", node_element, node_data, default)
            nodes_element.append(node_element)
        graph_element.append(nodes_element)

    def _claim_edge_id(self, edge_data):
        """Pop and return the edge's 'id', generating an unused one if absent."""
        edge_id = edge_data.pop("id", None)
        if edge_id is None:
            edge_id = next(self.edge_id)
            while str(edge_id) in self.all_edge_ids:
                edge_id = next(self.edge_id)
            self.all_edge_ids.add(str(edge_id))
        return edge_id

    def _iter_edges(self, G):
        """Yield ``(u, v, edge_id, data_copy)`` for graphs and multigraphs alike.

        For multigraphs the edge key is stored in the copied data under 'key'.
        """
        if G.is_multigraph():
            for u, v, key, data in G.edges(data=True, keys=True):
                edge_data = data.copy()
                edge_data.update(key=key)
                yield u, v, self._claim_edge_id(edge_data), edge_data
        else:
            for u, v, data in G.edges(data=True):
                edge_data = data.copy()
                yield u, v, self._claim_edge_id(edge_data), edge_data

    def add_edges(self, G, graph_element):
        """Append an <edges> element with one <edge> per edge of `G`."""
        edges_element = Element("edges")
        default = G.graph.get("edge_default", {})
        for u, v, edge_id, edge_data in self._iter_edges(G):
            kw = {"id": str(edge_id)}
            for name in ("label", "weight", "type"):
                if name in edge_data:
                    kw[name] = str(edge_data.pop(name))
            self._pop_time_bounds(edge_data, kw)

            # Endpoints use the same id that add_nodes wrote for the node.
            source_id = str(G.nodes[u].get("id", u))
            target_id = str(G.nodes[v].get("id", v))
            edge_element = Element("edge", source=source_id, target=target_id, **kw)
            if self.VERSION == "1.1":
                edge_data = self.add_slices(edge_element, edge_data)
            else:
                edge_data = self.add_spells(edge_element, edge_data)
            edge_data = self.add_viz(edge_element, edge_data)
            edge_data = self.add_attributes("edge", edge_element, edge_data, default)
            edges_element.append(edge_element)
        graph_element.append(edges_element)

    def _format_value(self, value, val_type):
        """Return ``str(value)``, using GEXF spellings for special floats."""
        text = str(value)
        if val_type is float:
            text = self._SPECIAL_FLOATS.get(text, text)
        return text

    def add_attributes(self, node_or_edge, xml_obj, data, default):
        """Write the remaining `data` items as <attvalues> under `xml_obj`.

        A list value is treated as dynamic data: a sequence of
        ``(value, start, end)`` triples. Any other value is static.

        Returns `data` unchanged.

        Raises
        ------
        TypeError
            If a value's type has no entry in ``self.xml_type``.
        """
        if len(data) == 0:
            return data
        attvalues = Element("attvalues")
        mode = "static"
        for k, v in data.items():
            # rename generic multigraph key to avoid any name conflict
            if k == "key":
                k = "networkx_key"
            val_type = type(v)
            if val_type not in self.xml_type:
                raise TypeError(f"attribute value type is not allowed: {val_type}")
            if isinstance(v, list):
                # Dynamic data. Scan up to the first triple that carries a
                # time bound; the declared type is that triple's value type.
                # NOTE(review): `mode` is not reset here, so it can still be
                # "dynamic" from an earlier attribute even when this list has
                # no time bounds - confirm whether that is intended.
                for val, start, end in v:
                    val_type = type(val)
                    if start is not None or end is not None:
                        mode = "dynamic"
                        self.alter_graph_mode_timeformat(start)
                        self.alter_graph_mode_timeformat(end)
                        break
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, mode
                )
                for val, start, end in v:
                    e = Element("attvalue")
                    e.attrib["for"] = attr_id
                    e.attrib["value"] = self._format_value(val, val_type)
                    if start is not None:
                        e.attrib["start"] = str(start)
                    if end is not None:
                        e.attrib["end"] = str(end)
                    attvalues.append(e)
            else:
                # static data
                mode = "static"
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, mode
                )
                e = Element("attvalue")
                e.attrib["for"] = attr_id
                if isinstance(v, bool):
                    # lower-case "true"/"false"
                    e.attrib["value"] = str(v).lower()
                else:
                    e.attrib["value"] = self._format_value(v, val_type)
                attvalues.append(e)
        xml_obj.append(attvalues)
        return data

    def get_attr_id(self, title, attr_type, edge_or_node, default, mode):
        """Return the id of attribute `title`, declaring it if it is new.

        A new attribute is recorded in ``self.attr`` and declared with an
        <attribute> element inside the <attributes> element of the current
        graph that matches `edge_or_node` and `mode` (created if missing).
        """
        try:
            return self.attr[edge_or_node][mode][title]
        except KeyError:
            pass

        new_id = str(next(self.attr_id))
        self.attr[edge_or_node][mode][title] = new_id
        attribute = Element(
            "attribute", **{"id": new_id, "title": title, "type": attr_type}
        )
        # add a <default> child when the graph supplies a default value
        default_value = default.get(title)
        if default_value is not None:
            default_element = Element("default")
            default_element.text = str(default_value)
            attribute.append(default_element)

        # find the existing <attributes> element for this class and mode
        attributes_element = None
        for candidate in self.graph_element.findall("attributes"):
            if (
                candidate.get("class") == edge_or_node
                and candidate.get("mode", "static") == mode
            ):
                attributes_element = candidate
        if attributes_element is None:
            attributes_element = Element(
                "attributes", **{"mode": mode, "class": edge_or_node}
            )
            self.graph_element.insert(0, attributes_element)
        attributes_element.append(attribute)
        return new_id

    def _viz_tag(self, name):
        """Return `name` qualified with the viz namespace."""
        return f"{{{self.NS_VIZ}}}{name}"

    def add_viz(self, element, node_data):
        """Pop 'viz' from `node_data` and write it as viz:* children.

        Recognized keys are 'color', 'size', 'thickness', 'shape' and
        'position'. Returns `node_data` without the 'viz' entry.
        """
        viz = node_data.pop("viz", False)
        if not viz:
            return node_data

        color = viz.get("color")
        if color is not None:
            rgba = {channel: str(color.get(channel)) for channel in ("r", "g", "b")}
            if self.VERSION != "1.1":
                # Alpha is only written for versions other than 1.1.
                rgba["a"] = str(color.get("a", 1.0))
            element.append(Element(self._viz_tag("color"), rgba))

        size = viz.get("size")
        if size is not None:
            element.append(Element(self._viz_tag("size"), value=str(size)))

        thickness = viz.get("thickness")
        if thickness is not None:
            element.append(Element(self._viz_tag("thickness"), value=str(thickness)))

        shape = viz.get("shape")
        if shape is not None:
            if shape.startswith("http"):
                # A URL is written as an image shape with the URL in "uri".
                e = Element(self._viz_tag("shape"), value="image", uri=str(shape))
            else:
                e = Element(self._viz_tag("shape"), value=str(shape))
            element.append(e)

        position = viz.get("position")
        if position is not None:
            # A missing coordinate is written as 0 instead of the literal
            # string "None", which GEXFReader.add_viz cannot convert with
            # float(); 0 is also the reader's default for an absent axis.
            coords = {}
            for axis in ("x", "y", "z"):
                value = position.get(axis)
                coords[axis] = str(0 if value is None else value)
            element.append(Element(self._viz_tag("position"), coords))
        return node_data

    def add_parents(self, node_element, node_data):
        """Pop 'parents' from `node_data` and write it as a <parents> child."""
        parents = node_data.pop("parents", False)
        if parents:
            parents_element = Element("parents")
            for p in parents:
                e = Element("parent")
                e.attrib["for"] = str(p)
                parents_element.append(e)
            node_element.append(parents_element)
        return node_data

    def add_slices(self, node_or_edge_element, node_or_edge_data):
        """Pop 'slices' (``(start, end)`` pairs) and write a <slices> child."""
        slices = node_or_edge_data.pop("slices", False)
        if slices:
            slices_element = Element("slices")
            for start, end in slices:
                slices_element.append(Element("slice", start=str(start), end=str(end)))
            node_or_edge_element.append(slices_element)
        return node_or_edge_data

    def add_spells(self, node_or_edge_element, node_or_edge_data):
        """Pop 'spells' (``(start, end)`` pairs) and write a <spells> child.

        A bound that is None is omitted from the <spell> element.
        """
        spells = node_or_edge_data.pop("spells", False)
        if spells:
            spells_element = Element("spells")
            for start, end in spells:
                e = Element("spell")
                if start is not None:
                    e.attrib["start"] = str(start)
                    self.alter_graph_mode_timeformat(start)
                if end is not None:
                    e.attrib["end"] = str(end)
                    self.alter_graph_mode_timeformat(end)
                spells_element.append(e)
            node_or_edge_element.append(spells_element)
        return node_or_edge_data

    def alter_graph_mode_timeformat(self, start_or_end):
        """Switch the current graph to dynamic mode on the first time value.

        Does nothing if the graph is no longer static or the value is None.
        Otherwise sets the graph's ``timeformat`` from the value's type
        (str -> "date", float -> "double", int -> "long") and its ``mode``
        to "dynamic".

        Raises
        ------
        NetworkXError
            If the value is not a str, float or int.
        """
        if self.graph_element.get("mode") != "static" or start_or_end is None:
            return
        if isinstance(start_or_end, str):
            timeformat = "date"
        elif isinstance(start_or_end, float):
            timeformat = "double"
        elif isinstance(start_or_end, int):
            timeformat = "long"
        else:
            raise nx.NetworkXError(
                "timeformat should be of the type int, float or str"
            )
        self.graph_element.set("timeformat", timeformat)
        self.graph_element.set("mode", "dynamic")

    def write(self, fh):
        """Serialize the document, with an XML declaration, to open file `fh`."""
        if self.prettyprint:
            self.indent(self.xml)
        document = ElementTree(self.xml)
        document.write(fh, encoding=self.encoding, xml_declaration=True)

    def indent(self, elem, level=0):
        """Indent `elem` and its descendants in place for pretty printing."""
        # NOTE(review): the indent unit below is reproduced exactly as it
        # appears in the reviewed text, where runs of whitespace may have
        # been collapsed - confirm the intended width.
        i = "\n" + " " * level
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + " "
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            child = None
            for child in elem:
                self.indent(child, level + 1)
            # The last child's tail precedes the parent's closing tag, so it
            # is pulled back to the parent's indentation level.
            if not child.tail or not child.tail.strip():
                child.tail = i
        else:
            if level and (not elem.tail or not elem.tail.strip()):
                elem.tail = i
class GEXFReader(GEXF):
    """Parse a GEXF document into a NetworkX graph.

    This is the implementation behind ``read_gexf()``; prefer that function.
    Instances are callable: ``reader(stream)`` returns the graph.

    Parameters
    ----------
    node_type : callable, optional
        If not None, node ids (and edge endpoints) are converted with it.
    version : string (default: '1.2draft')
        GEXF version tried first when looking for the <graph> element.
    """

    def __init__(self, node_type=None, version="1.2draft"):
        self.construct_types()
        self.node_type = node_type
        # assume simple graph and test for multigraph on read
        self.simple_graph = True
        self.set_version(version)

    def _tag(self, name):
        """Return `name` qualified with the current GEXF namespace."""
        return f"{{{self.NS_GEXF}}}{name}"

    def _viz_tag(self, name):
        """Return `name` qualified with the current viz namespace."""
        return f"{{{self.NS_VIZ}}}{name}"

    def __call__(self, stream):
        """Parse `stream` and return the graph it describes.

        The requested version's namespace is tried first, then every known
        version in turn.

        Raises
        ------
        NetworkXError
            If no <graph> element is found under any known namespace.
        """
        self.xml = ElementTree(file=stream)
        g = self.xml.find(self._tag("graph"))
        if g is not None:
            return self.make_graph(g)
        # try all the versions
        for version in self.versions:
            self.set_version(version)
            g = self.xml.find(self._tag("graph"))
            if g is not None:
                return self.make_graph(g)
        raise nx.NetworkXError("No <graph> element in GEXF file.")

    def make_graph(self, graph_xml):
        """Build a graph from a <graph> element.

        The graph is assembled as a multigraph and converted to a Graph or
        DiGraph at the end if no parallel edges were seen.

        Raises
        ------
        NetworkXError
            If an <attributes> element has a class other than "node" or
            "edge".
        """
        # start with an empty multigraph of the right directedness
        edgedefault = graph_xml.get("defaultedgetype", None)
        G = nx.MultiDiGraph() if edgedefault == "directed" else nx.MultiGraph()

        # graph attributes
        graph_name = graph_xml.get("name", "")
        if graph_name != "":
            G.graph["name"] = graph_name
        for bound in ("start", "end"):
            value = graph_xml.get(bound)
            if value is not None:
                G.graph[bound] = value
        is_dynamic = graph_xml.get("mode", "") == "dynamic"
        G.graph["mode"] = "dynamic" if is_dynamic else "static"

        # timeformat: "date" values are kept as strings
        self.timeformat = graph_xml.get("timeformat")
        if self.timeformat == "date":
            self.timeformat = "string"

        # dictionaries to hold attribute declarations and their defaults
        node_attr = {}
        node_default = {}
        edge_attr = {}
        edge_default = {}
        for a in graph_xml.findall(self._tag("attributes")):
            attr_class = a.get("class")
            if attr_class == "node":
                na, nd = self.find_gexf_attributes(a)
                node_attr.update(na)
                node_default.update(nd)
                G.graph["node_default"] = node_default
            elif attr_class == "edge":
                ea, ed = self.find_gexf_attributes(a)
                edge_attr.update(ea)
                edge_default.update(ed)
                G.graph["edge_default"] = edge_default
            else:
                # A bare ``raise`` here would itself fail with an unrelated
                # RuntimeError because no exception is being handled.
                raise nx.NetworkXError(
                    f"Unknown attribute class {attr_class!r} in GEXF file."
                )

        # Hack to handle Gephi0.7beta bug: always declare a "weight"
        # attribute for edges.
        edge_attr["weight"] = {"type": "double", "mode": "static", "title": "weight"}
        G.graph["edge_default"] = edge_default

        # add nodes
        nodes_element = graph_xml.find(self._tag("nodes"))
        if nodes_element is not None:
            for node_xml in nodes_element.findall(self._tag("node")):
                self.add_node(G, node_xml, node_attr)

        # add edges
        edges_element = graph_xml.find(self._tag("edges"))
        if edges_element is not None:
            for edge_xml in edges_element.findall(self._tag("edge")):
                self.add_edge(G, edge_xml, edge_attr)

        # switch to Graph or DiGraph if no parallel edges were found.
        if self.simple_graph:
            G = nx.DiGraph(G) if G.is_directed() else nx.Graph(G)
        return G

    def add_node(self, G, node_xml, node_attr, node_pid=None):
        """Add the node described by `node_xml`, and its sub-nodes, to `G`.

        `node_pid` is the id of the enclosing node when called recursively;
        an explicit "pid" attribute on the element takes precedence.
        """
        # get attributes and subattributes for node
        data = self.decode_attr_elements(node_attr, node_xml)
        data = self.add_parents(data, node_xml)
        if self.VERSION == "1.1":
            data = self.add_slices(data, node_xml)
        else:
            data = self.add_spells(data, node_xml)
        data = self.add_viz(data, node_xml)
        data = self.add_start_end(data, node_xml)

        # find the node id and cast it to the appropriate type
        node_id = node_xml.get("id")
        if self.node_type is not None:
            node_id = self.node_type(node_id)

        # every node should have a label (None is stored if it is absent)
        data["label"] = node_xml.get("label")

        # parent node id
        node_pid = node_xml.get("pid", node_pid)
        if node_pid is not None:
            data["pid"] = node_pid

        # nested nodes are added recursively, before this node itself
        subnodes = node_xml.find(self._tag("nodes"))
        if subnodes is not None:
            for child_xml in subnodes.findall(self._tag("node")):
                self.add_node(G, child_xml, node_attr, node_pid=node_id)

        G.add_node(node_id, **data)

    def add_start_end(self, data, xml):
        """Copy "start"/"end" from `xml` into `data`, typed per timeformat."""
        ttype = self.timeformat
        for bound in ("start", "end"):
            raw = xml.get(bound)
            if raw is not None:
                data[bound] = self.python_type[ttype](raw)
        return data

    def add_viz(self, data, node_xml):
        """Collect viz:* children of `node_xml` into ``data["viz"]``.

        Nothing is stored if the element has no viz children.
        """
        viz = {}
        color = node_xml.find(self._viz_tag("color"))
        if color is not None:
            rgba = {channel: int(color.get(channel)) for channel in ("r", "g", "b")}
            if self.VERSION != "1.1":
                rgba["a"] = float(color.get("a", 1))
            viz["color"] = rgba

        size = node_xml.find(self._viz_tag("size"))
        if size is not None:
            viz["size"] = float(size.get("value"))

        thickness = node_xml.find(self._viz_tag("thickness"))
        if thickness is not None:
            viz["thickness"] = float(thickness.get("value"))

        shape = node_xml.find(self._viz_tag("shape"))
        if shape is not None:
            # The shape name lives in the "value" attribute (that is what
            # GEXFWriter.add_viz emits). The "shape" attribute that was read
            # here before is kept as a fallback.
            shape_name = shape.get("value")
            if shape_name is None:
                shape_name = shape.get("shape")
            # an image shape is represented by its URI
            viz["shape"] = shape.get("uri") if shape_name == "image" else shape_name

        position = node_xml.find(self._viz_tag("position"))
        if position is not None:
            viz["position"] = {
                axis: float(position.get(axis, 0)) for axis in ("x", "y", "z")
            }

        if len(viz) > 0:
            data["viz"] = viz
        return data

    def add_parents(self, data, node_xml):
        """Store the "for" ids of <parent> children in ``data["parents"]``."""
        parents_element = node_xml.find(self._tag("parents"))
        if parents_element is not None:
            data["parents"] = [
                p.get("for") for p in parents_element.findall(self._tag("parent"))
            ]
        return data

    def add_slices(self, data, node_or_edge_xml):
        """Store <slice> children as ``(start, end)`` string pairs."""
        slices_element = node_or_edge_xml.find(self._tag("slices"))
        if slices_element is not None:
            data["slices"] = [
                (s.get("start"), s.get("end"))
                for s in slices_element.findall(self._tag("slice"))
            ]
        return data

    def add_spells(self, data, node_or_edge_xml):
        """Store <spell> children as ``(start, end)`` pairs typed per timeformat."""
        spells_element = node_or_edge_xml.find(self._tag("spells"))
        if spells_element is not None:
            data["spells"] = []
            convert = self.python_type[self.timeformat]
            for s in spells_element.findall(self._tag("spell")):
                start = convert(s.get("start"))
                end = convert(s.get("end"))
                data["spells"].append((start, end))
        return data

    def add_edge(self, G, edge_element, edge_attr):
        """Add the edge described by `edge_element` to `G`.

        Raises
        ------
        NetworkXError
            If the edge's type contradicts the directedness of `G`.
        """
        # raise error if we find mixed directed and undirected edges
        edge_direction = edge_element.get("type")
        if G.is_directed() and edge_direction == "undirected":
            raise nx.NetworkXError("Undirected edge found in directed graph.")
        if (not G.is_directed()) and edge_direction == "directed":
            raise nx.NetworkXError("Directed edge found in undirected graph.")

        # Get source and target and recast type if required
        source = edge_element.get("source")
        target = edge_element.get("target")
        if self.node_type is not None:
            source = self.node_type(source)
            target = self.node_type(target)

        data = self.decode_attr_elements(edge_attr, edge_element)
        data = self.add_start_end(data, edge_element)
        if self.VERSION == "1.1":
            data = self.add_slices(data, edge_element)
        else:
            data = self.add_spells(data, edge_element)

        # GEXF stores edge ids as an attribute; NetworkX uses them as keys in
        # multigraphs unless a 'networkx_key' attribute says otherwise.
        edge_id = edge_element.get("id")
        if edge_id is not None:
            data["id"] = edge_id
        multigraph_key = data.pop("networkx_key", None)
        if multigraph_key is not None:
            edge_id = multigraph_key

        weight = edge_element.get("weight")
        if weight is not None:
            data["weight"] = float(weight)

        edge_label = edge_element.get("label")
        if edge_label is not None:
            data["label"] = edge_label

        if G.has_edge(source, target):
            # seen this edge before - this is a multigraph
            self.simple_graph = False
        G.add_edge(source, target, key=edge_id, **data)
        if edge_direction == "mutual":
            G.add_edge(target, source, key=edge_id, **data)

    def decode_attr_elements(self, gexf_keys, obj_xml):
        """Decode the <attvalues> of `obj_xml` using declarations `gexf_keys`.

        Static attributes map title -> value. Dynamic attributes map
        title -> list of ``(value, start, end)`` triples.

        Raises
        ------
        NetworkXError
            If an <attvalue> refers to an undeclared attribute id.
        """
        attr = {}
        # look for outer '<attvalues>' element
        attr_element = obj_xml.find(self._tag("attvalues"))
        if attr_element is None:
            return attr
        # loop over <attvalue> elements
        for a in attr_element.findall(self._tag("attvalue")):
            key = a.get("for")  # for is required
            try:  # should be in our gexf_keys dictionary
                declaration = gexf_keys[key]
                title = declaration["title"]
            except KeyError as err:
                raise nx.NetworkXError(f"No attribute defined for={key}.") from err
            atype = declaration["type"]
            value = a.get("value")
            if atype == "boolean":
                value = self.convert_bool[value]
            else:
                value = self.python_type[atype](value)
            if declaration["mode"] == "dynamic":
                # for dynamic graphs use list of three-tuples
                # [(value1,start1,end1), (value2,start2,end2), etc]
                convert = self.python_type[self.timeformat]
                start = convert(a.get("start"))
                end = convert(a.get("end"))
                attr.setdefault(title, []).append((value, start, end))
            else:
                # for static graphs just assign the value
                attr[title] = value
        return attr

    def find_gexf_attributes(self, attributes_element):
        """Extract attribute declarations and defaults from <attributes>.

        Returns
        -------
        attrs : dict
            Attribute id -> ``{"title", "type", "mode"}``.
        defaults : dict
            Attribute title -> converted default value, for attributes that
            have a <default> child.
        """
        attrs = {}
        defaults = {}
        mode = attributes_element.get("mode")
        for k in attributes_element.findall(self._tag("attribute")):
            title = k.get("title")
            atype = k.get("type")
            attrs[k.get("id")] = {"title": title, "type": atype, "mode": mode}
            # check for the 'default' subelement of key element and add
            default = k.find(self._tag("default"))
            if default is not None:
                if atype == "boolean":
                    value = self.convert_bool[default.text]
                else:
                    value = self.python_type[atype](default.text)
                defaults[title] = value
        return attrs, defaults
def relabel_gexf_graph(G):
    """Relabel graph using "label" node keyword for node label.

    Parameters
    ----------
    G : graph
        A NetworkX graph read from GEXF data

    Returns
    -------
    H : graph
        A NetworkX graph with relabeled nodes

    Raises
    ------
    NetworkXError
        If node labels are missing or not unique while relabel=True.

    Notes
    -----
    This function relabels the nodes in a NetworkX graph with the
    "label" attribute. It also handles relabeling the specific GEXF
    node attributes "parents", and "pid". Each relabeled node keeps its
    previous name under the "id" attribute and loses its "label" attribute.

    A graph without nodes is relabeled trivially instead of failing.
    """
    # build mapping of node -> label, do some error checking
    try:
        mapping = {u: G.nodes[u]["label"] for u in G}
    except KeyError as err:
        raise nx.NetworkXError(
            "Failed to relabel nodes: missing node labels found. Use relabel=False."
        ) from err
    # Comparing against the number of nodes also works for an empty graph,
    # where unpacking zip(*pairs) would raise ValueError.
    if len(set(mapping.values())) != len(G):
        raise nx.NetworkXError(
            "Failed to relabel nodes: "
            "duplicate node labels found. "
            "Use relabel=False."
        )
    H = nx.relabel_nodes(G, mapping)
    # relabel attributes that refer to other nodes
    for old, new in mapping.items():
        attrs = H.nodes[new]
        attrs["id"] = old
        attrs.pop("label")
        if "pid" in attrs:
            attrs["pid"] = mapping[G.nodes[old]["pid"]]
        if "parents" in attrs:
            attrs["parents"] = [mapping[p] for p in G.nodes[old]["parents"]]
    return H