Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/networkx/readwrite/gexf.py: 8%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

608 statements  

1"""Read and write graphs in GEXF format. 

2 

3.. warning:: 

4 This parser uses the standard xml library present in Python, which is 

5 insecure - see :external+python:mod:`xml` for additional information. 

6 Only parse GEXF files you trust.

7 

8GEXF (Graph Exchange XML Format) is a language for describing complex 

9network structures, their associated data and dynamics. 

10 

11This implementation does not support mixed graphs (directed and 

12undirected edges together). 

13 

14Format 

15------ 

16GEXF is an XML format. See http://gexf.net/schema.html for the 

17specification and http://gexf.net/basic.html for examples. 

18""" 

19 

20import itertools 

21import time 

22from xml.etree.ElementTree import ( 

23 Element, 

24 ElementTree, 

25 SubElement, 

26 register_namespace, 

27 tostring, 

28) 

29 

30import networkx as nx 

31from networkx.utils import open_file 

32 

# Public names exported by ``from networkx.readwrite.gexf import *``.
__all__ = ["write_gexf", "read_gexf", "relabel_gexf_graph", "generate_gexf"]

34 

35 

@open_file(1, mode="wb")
def write_gexf(G, path, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Write G in GEXF format to path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Node attributes are checked according to the version of the GEXF
    schemas used for parameters which are not user defined,
    e.g. visualization 'viz' [2]_. See example for usage.

    .. warning::

        The `GEXF specification <https://gexf.net/schema.html>`_ reserves some
        keywords (e.g. ``id``, ``pid``, ``label``, etc.) for specifying node/edge
        metadata in the file format. Ensure NetworkX node/edge attribute names
        do not use these special keywords to guarantee all attributes are preserved
        as expected when roundtripping to/from GEXF format.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    path : file or string
        File or file name to write.
        File names ending in .gz or .bz2 will be compressed.
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (optional, default: '1.2draft')
        The version of GEXF used for node attribute checking.
        Supported values: "1.1draft", "1.2draft", "1.3".

    Raises
    ------
    NetworkXError
        If `version` is not one of the supported values.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.write_gexf(G, "test.gexf")

    # visualization data (color components are in the range 0-255)
    >>> G.nodes[0]["viz"] = {"size": 54}
    >>> G.nodes[0]["viz"]["position"] = {"x": 0, "y": 1}
    >>> G.nodes[0]["viz"]["color"] = {"r": 0, "g": 0, "b": 255}


    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    If you want to specify an id, set it as node data, e.g.
    ``G.nodes['a']['id'] = 1`` to set the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, http://gexf.net/
    .. [2] GEXF schema, http://gexf.net/schema.html
    """
    writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint, version=version)
    writer.add_graph(G)
    writer.write(path)

97 

98 

def generate_gexf(G, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Generate lines of GEXF format representation of G.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see http://gexf.net/schema.html)
        Supported values: "1.1draft", "1.2draft", "1.3"

    Yields
    ------
    str
        One line of the serialized XML document at a time (no trailing
        newline characters).

    Raises
    ------
    NetworkXError
        If `version` is not supported. Because this function is a generator,
        the error is raised when iteration starts, not when it is called.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> linefeed = chr(10)  # linefeed=\n
    >>> s = linefeed.join(nx.generate_gexf(G))
    >>> for line in nx.generate_gexf(G):  # doctest: +SKIP
    ...     print(line)

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    If you want to specify an id, set it as node data, e.g.
    ``G.nodes['a']['id'] = 1`` to set the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint, version=version)
    writer.add_graph(G)
    # The whole document is built in memory; only the splitting is lazy.
    yield from str(writer).splitlines()

142 

143 

@open_file(0, mode="rb")
@nx._dispatchable(graphs=None, returns_graph=True)
def read_gexf(path, node_type=None, relabel=False, version="1.2draft"):
    """Read graph in GEXF format from path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    path : file or string
        Filename or file handle to read.
        Filenames ending in .gz or .bz2 will be decompressed.
    node_type: Python type (default: None)
        Convert node ids to this type if not None.
    relabel : bool (default: False)
        If True relabel the nodes to use the GEXF node "label" attribute
        instead of the node "id" attribute as the NetworkX node label.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see http://gexf.net/schema.html)
        Supported values: "1.1draft", "1.2draft", "1.3".
        If the file has no ``<graph>`` element in this version's namespace,
        the other supported versions are tried in turn.

    Returns
    -------
    graph: NetworkX graph
        If no parallel edges are found a Graph or DiGraph is returned.
        Otherwise a MultiGraph or MultiDiGraph is returned.

    Raises
    ------
    NetworkXError
        If `version` is not supported, or if no ``<graph>`` element is found
        for any supported version.

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    References
    ----------
    .. [1] GEXF File Format, http://gexf.net/
    """
    reader = GEXFReader(node_type=node_type, version=version)
    G = reader(path)
    if relabel:
        G = relabel_gexf_graph(G)
    return G

187 

188 

189class GEXF: 

190 versions = { 

191 "1.1draft": { 

192 "NS_GEXF": "http://www.gexf.net/1.1draft", 

193 "NS_VIZ": "http://www.gexf.net/1.1draft/viz", 

194 "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance", 

195 "SCHEMALOCATION": " ".join( 

196 [ 

197 "http://www.gexf.net/1.1draft", 

198 "http://www.gexf.net/1.1draft/gexf.xsd", 

199 ] 

200 ), 

201 "VERSION": "1.1", 

202 }, 

203 "1.2draft": { 

204 "NS_GEXF": "http://www.gexf.net/1.2draft", 

205 "NS_VIZ": "http://www.gexf.net/1.2draft/viz", 

206 "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance", 

207 "SCHEMALOCATION": " ".join( 

208 [ 

209 "http://www.gexf.net/1.2draft", 

210 "http://www.gexf.net/1.2draft/gexf.xsd", 

211 ] 

212 ), 

213 "VERSION": "1.2", 

214 }, 

215 "1.3": { 

216 "NS_GEXF": "http://gexf.net/1.3", 

217 "NS_VIZ": "http://gexf.net/1.3/viz", 

218 "NS_XSI": "http://w3.org/2001/XMLSchema-instance", 

219 "SCHEMALOCATION": " ".join( 

220 [ 

221 "http://gexf.net/1.3", 

222 "http://gexf.net/1.3/gexf.xsd", 

223 ] 

224 ), 

225 "VERSION": "1.3", 

226 }, 

227 } 

228 

229 def construct_types(self): 

230 types = [ 

231 (int, "integer"), 

232 (float, "float"), 

233 (float, "double"), 

234 (bool, "boolean"), 

235 (list, "string"), 

236 (dict, "string"), 

237 (int, "long"), 

238 (str, "liststring"), 

239 (str, "anyURI"), 

240 (str, "string"), 

241 ] 

242 

243 # These additions to types allow writing numpy types 

244 try: 

245 import numpy as np 

246 except ImportError: 

247 pass 

248 else: 

249 # prepend so that python types are created upon read (last entry wins) 

250 types = [ 

251 (np.float64, "float"), 

252 (np.float32, "float"), 

253 (np.float16, "float"), 

254 (np.int_, "int"), 

255 (np.int8, "int"), 

256 (np.int16, "int"), 

257 (np.int32, "int"), 

258 (np.int64, "int"), 

259 (np.uint8, "int"), 

260 (np.uint16, "int"), 

261 (np.uint32, "int"), 

262 (np.uint64, "int"), 

263 (np.int_, "int"), 

264 (np.intc, "int"), 

265 (np.intp, "int"), 

266 ] + types 

267 

268 self.xml_type = dict(types) 

269 self.python_type = dict(reversed(a) for a in types) 

270 

271 # http://www.w3.org/TR/xmlschema-2/#boolean 

272 convert_bool = { 

273 "true": True, 

274 "false": False, 

275 "True": True, 

276 "False": False, 

277 "0": False, 

278 0: False, 

279 "1": True, 

280 1: True, 

281 } 

282 

283 def set_version(self, version): 

284 d = self.versions.get(version) 

285 if d is None: 

286 raise nx.NetworkXError(f"Unknown GEXF version {version}.") 

287 self.NS_GEXF = d["NS_GEXF"] 

288 self.NS_VIZ = d["NS_VIZ"] 

289 self.NS_XSI = d["NS_XSI"] 

290 self.SCHEMALOCATION = d["SCHEMALOCATION"] 

291 self.VERSION = d["VERSION"] 

292 self.version = version 

293 

294 

class GEXFWriter(GEXF):
    """Serialize NetworkX graphs to GEXF XML.

    Normally used through ``write_gexf()`` or ``generate_gexf()``. The XML
    tree is built in ``self.xml`` by :meth:`add_graph` and emitted with
    :meth:`write` (to a file handle) or ``str(writer)``.
    """

    def __init__(
        self, graph=None, encoding="utf-8", prettyprint=True, version="1.2draft"
    ):
        self.construct_types()
        self.prettyprint = prettyprint
        self.encoding = encoding
        # Raises NetworkXError for an unknown version.
        self.set_version(version)
        self.xml = Element(
            "gexf",
            {
                "xmlns": self.NS_GEXF,
                "xmlns:xsi": self.NS_XSI,
                "xsi:schemaLocation": self.SCHEMALOCATION,
                "version": self.VERSION,
            },
        )

        # Make meta element a non-graph element
        # Also add lastmodifieddate as attribute, not tag
        meta_element = Element("meta")
        SubElement(meta_element, "creator").text = f"NetworkX {nx.__version__}"
        meta_element.set("lastmodifieddate", time.strftime("%Y-%m-%d"))
        self.xml.append(meta_element)

        # NOTE: ElementTree's namespace prefix registry is module-global.
        register_namespace("viz", self.NS_VIZ)

        # counters for edge and attribute identifiers
        self.edge_id = itertools.count()
        self.attr_id = itertools.count()
        # String forms of every edge id already used (user-supplied or generated).
        self.all_edge_ids = set()
        # Declared attribute ids, indexed as attr[node|edge][dynamic|static][title].
        self.attr = {
            "node": {"dynamic": {}, "static": {}},
            "edge": {"dynamic": {}, "static": {}},
        }

        if graph is not None:
            self.add_graph(graph)

    def __str__(self):
        """Return the document as text (pretty-printed first if enabled)."""
        if self.prettyprint:
            self.indent(self.xml)
        return tostring(self.xml).decode(self.encoding)

    def add_graph(self, G):
        """Append a ``<graph>`` element describing *G* to the document."""
        # First pass: reserve user-supplied edge ids so generated ids avoid them.
        for _, _, dd in G.edges(data=True):
            eid = dd.get("id")
            if eid is not None:
                self.all_edge_ids.add(str(eid))
        mode = "dynamic" if G.graph.get("mode") == "dynamic" else "static"
        default = "directed" if G.is_directed() else "undirected"
        name = G.graph.get("name", "")
        graph_element = Element("graph", defaultedgetype=default, mode=mode, name=name)
        # Kept on self so attribute declarations and timeformat can be
        # attached to it while nodes and edges are being written.
        self.graph_element = graph_element
        self.add_nodes(G, graph_element)
        self.add_edges(G, graph_element)
        self.xml.append(graph_element)

    def _add_time_intervals(self, element, data):
        """Emit ``<slices>`` for GEXF 1.1 and ``<spells>`` for later versions."""
        if self.VERSION == "1.1":
            return self.add_slices(element, data)
        return self.add_spells(element, data)

    def add_nodes(self, G, graph_element):
        """Append a ``<nodes>`` element holding every node of *G*.

        The reserved keys ``id``, ``label``, ``pid``, ``start`` and ``end``
        become XML attributes of ``<node>``. ``parents``, ``slices``/``spells``
        and ``viz`` become child elements. Everything else is written as
        ``<attvalues>``.
        """
        nodes_element = Element("nodes")
        # Defaults are per graph, so look them up once rather than per node.
        default = G.graph.get("node_default", {})
        for node, data in G.nodes(data=True):
            node_data = data.copy()
            kw = {
                "id": str(node_data.pop("id", node)),
                "label": str(node_data.pop("label", node)),
            }
            for reserved in ("pid", "start", "end"):
                if reserved in node_data:
                    value = node_data.pop(reserved)
                    kw[reserved] = str(value)
                    if reserved != "pid":
                        self.alter_graph_mode_timeformat(value)
            node_element = Element("node", **kw)
            node_data = self.add_parents(node_element, node_data)
            node_data = self._add_time_intervals(node_element, node_data)
            node_data = self.add_viz(node_element, node_data)
            self.add_attributes("node", node_element, node_data, default)
            nodes_element.append(node_element)
        graph_element.append(nodes_element)

    def _iter_edges_with_ids(self, G):
        """Yield ``(u, v, edge_id, edge_data)`` for each edge of *G*.

        *edge_data* is a copy of the edge's data without ``"id"``. For
        multigraphs it also carries the edge key under ``"key"``. When no
        ``"id"`` is given, the next counter value not already used as an
        edge id is assigned.
        """
        if G.is_multigraph():
            edges = (
                (u, v, {**data, "key": key})
                for u, v, key, data in G.edges(data=True, keys=True)
            )
        else:
            edges = ((u, v, data.copy()) for u, v, data in G.edges(data=True))
        for u, v, edge_data in edges:
            edge_id = edge_data.pop("id", None)
            if edge_id is None:
                edge_id = next(self.edge_id)
                while str(edge_id) in self.all_edge_ids:
                    edge_id = next(self.edge_id)
                self.all_edge_ids.add(str(edge_id))
            yield u, v, edge_id, edge_data

    def add_edges(self, G, graph_element):
        """Append an ``<edges>`` element holding every edge of *G*.

        The reserved keys ``label``, ``weight``, ``type``, ``start`` and
        ``end`` become XML attributes of ``<edge>``. Source and target use
        the endpoint node's ``id`` data if present, otherwise the node itself.
        """
        edges_element = Element("edges")
        # Defaults are per graph, so look them up once rather than per edge.
        default = G.graph.get("edge_default", {})
        for u, v, key, edge_data in self._iter_edges_with_ids(G):
            kw = {"id": str(key)}
            for reserved in ("label", "weight", "type", "start", "end"):
                if reserved in edge_data:
                    value = edge_data.pop(reserved)
                    kw[reserved] = str(value)
                    if reserved in ("start", "end"):
                        self.alter_graph_mode_timeformat(value)
            source_id = str(G.nodes[u].get("id", u))
            target_id = str(G.nodes[v].get("id", v))
            edge_element = Element("edge", source=source_id, target=target_id, **kw)
            edge_data = self._add_time_intervals(edge_element, edge_data)
            edge_data = self.add_viz(edge_element, edge_data)
            self.add_attributes("edge", edge_element, edge_data, default)
            edges_element.append(edge_element)
        graph_element.append(edges_element)

    @staticmethod
    def _format_value(value):
        """Render an attribute value the way GEXF/XML Schema expects it.

        Booleans use the xsd lexical forms "true"/"false". Python floats map
        inf/nan to "INF", "NaN" and "-INF". Everything else uses ``str``.
        """
        if isinstance(value, bool):
            return str(value).lower()
        text = str(value)
        if type(value) is float:
            text = {"inf": "INF", "nan": "NaN", "-inf": "-INF"}.get(text, text)
        return text

    def add_attributes(self, node_or_edge, xml_obj, data, default):
        """Append an ``<attvalues>`` element for the entries of *data*.

        Parameters
        ----------
        node_or_edge : str
            "node" or "edge"; selects which attribute declarations to use.
        xml_obj : Element
            The ``<node>`` or ``<edge>`` element to append to.
        data : dict
            Attribute name -> value. A ``list`` value is treated as dynamic
            data, i.e. a sequence of ``(value, start, end)`` triples. The key
            ``"key"`` (a multigraph edge key) is written as ``networkx_key``.
        default : dict
            Attribute name -> default value, used for new declarations.

        Returns
        -------
        dict
            *data*, unchanged.

        Raises
        ------
        TypeError
            If a value's type has no GEXF type mapping.
        """
        if len(data) == 0:
            return data
        attvalues = Element("attvalues")
        for k, v in data.items():
            # rename generic multigraph key to avoid any name conflict
            if k == "key":
                k = "networkx_key"
            val_type = type(v)
            if val_type not in self.xml_type:
                raise TypeError(f"attribute value type is not allowed: {val_type}")
            if isinstance(v, list):
                # Dynamic data. The mode is decided per attribute: "dynamic"
                # only if some triple carries a start or end time. It is reset
                # here so one attribute's mode cannot leak into the next.
                mode = "static"
                for val, start, end in v:
                    val_type = type(val)
                    if start is not None or end is not None:
                        mode = "dynamic"
                        self.alter_graph_mode_timeformat(start)
                        self.alter_graph_mode_timeformat(end)
                        break
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, mode
                )
                for val, start, end in v:
                    e = Element("attvalue")
                    e.attrib["for"] = attr_id
                    e.attrib["value"] = self._format_value(val)
                    if start is not None:
                        e.attrib["start"] = str(start)
                    if end is not None:
                        e.attrib["end"] = str(end)
                    attvalues.append(e)
            else:
                # static data
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, "static"
                )
                e = Element("attvalue")
                e.attrib["for"] = attr_id
                e.attrib["value"] = self._format_value(v)
                attvalues.append(e)
        xml_obj.append(attvalues)
        return data

    def get_attr_id(self, title, attr_type, edge_or_node, default, mode):
        """Return the id declared for attribute *title*, declaring it if new.

        A new declaration is an ``<attribute>`` element (with a ``<default>``
        child when *default* has a non-None value for *title*). It is added
        to the ``<attributes>`` element of the matching class and mode, which
        is created at the front of the ``<graph>`` element if needed.
        """
        known = self.attr[edge_or_node][mode]
        if title in known:
            return known[title]
        # generate new id
        new_id = str(next(self.attr_id))
        known[title] = new_id
        attribute = Element("attribute", id=new_id, title=title, type=attr_type)
        # add subelement for data default value if present
        default_value = default.get(title)
        if default_value is not None:
            default_element = Element("default")
            default_element.text = str(default_value)
            attribute.append(default_element)
        # find existing attributes element by class and mode (last match wins)
        attributes_element = None
        for a in self.graph_element.findall("attributes"):
            if a.get("class") == edge_or_node and a.get("mode", "static") == mode:
                attributes_element = a
        if attributes_element is None:
            attr_kwargs = {"mode": mode, "class": edge_or_node}
            attributes_element = Element("attributes", **attr_kwargs)
            self.graph_element.insert(0, attributes_element)
        attributes_element.append(attribute)
        return new_id

    def add_viz(self, element, node_data):
        """Pop ``"viz"`` from *node_data* and append matching ``viz:`` children.

        Recognized keys of the viz dict are:

        - ``color``: dict with r, g, b, plus a (default 1.0) for versions
          other than 1.1.
        - ``size`` and ``thickness``.
        - ``shape``: a string starting with "http" is written as an image URI.
        - ``position``: dict with x, y, z.

        A position coordinate that is missing or None is omitted rather than
        written as the string "None", which is not a valid number.

        Returns *node_data* without ``"viz"``.
        """
        viz = node_data.pop("viz", False)
        if viz:
            color = viz.get("color")
            if color is not None:
                if self.VERSION == "1.1":
                    e = Element(
                        f"{{{self.NS_VIZ}}}color",
                        r=str(color.get("r")),
                        g=str(color.get("g")),
                        b=str(color.get("b")),
                    )
                else:
                    e = Element(
                        f"{{{self.NS_VIZ}}}color",
                        r=str(color.get("r")),
                        g=str(color.get("g")),
                        b=str(color.get("b")),
                        a=str(color.get("a", 1.0)),
                    )
                element.append(e)

            size = viz.get("size")
            if size is not None:
                element.append(Element(f"{{{self.NS_VIZ}}}size", value=str(size)))

            thickness = viz.get("thickness")
            if thickness is not None:
                element.append(
                    Element(f"{{{self.NS_VIZ}}}thickness", value=str(thickness))
                )

            shape = viz.get("shape")
            if shape is not None:
                if shape.startswith("http"):
                    e = Element(
                        f"{{{self.NS_VIZ}}}shape", value="image", uri=str(shape)
                    )
                else:
                    e = Element(f"{{{self.NS_VIZ}}}shape", value=str(shape))
                element.append(e)

            position = viz.get("position")
            if position is not None:
                coords = {
                    axis: str(position.get(axis))
                    for axis in ("x", "y", "z")
                    if position.get(axis) is not None
                }
                element.append(Element(f"{{{self.NS_VIZ}}}position", **coords))
        return node_data

    def add_parents(self, node_element, node_data):
        """Pop ``"parents"`` and append a ``<parents>`` element if non-empty."""
        parents = node_data.pop("parents", False)
        if parents:
            parents_element = Element("parents")
            for p in parents:
                e = Element("parent")
                e.attrib["for"] = str(p)
                parents_element.append(e)
            node_element.append(parents_element)
        return node_data

    def add_slices(self, node_or_edge_element, node_or_edge_data):
        """Pop ``"slices"`` (``(start, end)`` pairs) into a ``<slices>`` element."""
        slices = node_or_edge_data.pop("slices", False)
        if slices:
            slices_element = Element("slices")
            for start, end in slices:
                e = Element("slice", start=str(start), end=str(end))
                slices_element.append(e)
            node_or_edge_element.append(slices_element)
        return node_or_edge_data

    def add_spells(self, node_or_edge_element, node_or_edge_data):
        """Pop ``"spells"`` (``(start, end)`` pairs) into a ``<spells>`` element.

        A None bound is left out of the ``<spell>`` element (open interval).
        """
        spells = node_or_edge_data.pop("spells", False)
        if spells:
            spells_element = Element("spells")
            for start, end in spells:
                e = Element("spell")
                if start is not None:
                    e.attrib["start"] = str(start)
                    self.alter_graph_mode_timeformat(start)
                if end is not None:
                    e.attrib["end"] = str(end)
                    self.alter_graph_mode_timeformat(end)
                spells_element.append(e)
            node_or_edge_element.append(spells_element)
        return node_or_edge_data

    def alter_graph_mode_timeformat(self, start_or_end):
        """Record a time value: set ``timeformat`` and switch to dynamic mode.

        str -> "date", float -> "double", int -> "long"; None is ignored.

        Raises
        ------
        NetworkXError
            For any other value type.
        """
        if start_or_end is not None:
            if isinstance(start_or_end, str):
                timeformat = "date"
            elif isinstance(start_or_end, float):
                timeformat = "double"
            elif isinstance(start_or_end, int):
                timeformat = "long"
            else:
                raise nx.NetworkXError(
                    "timeformat should be of the type int, float or str"
                )
            self.graph_element.set("timeformat", timeformat)
            # If Graph mode is static, alter to dynamic
            if self.graph_element.get("mode") == "static":
                self.graph_element.set("mode", "dynamic")

    def write(self, fh):
        """Serialize the document, with an XML declaration, to the open *fh*."""
        if self.prettyprint:
            self.indent(self.xml)
        document = ElementTree(self.xml)
        document.write(fh, encoding=self.encoding, xml_declaration=True)

    def indent(self, elem, level=0):
        """Pretty-print *elem* in place by adjusting ``text``/``tail`` whitespace."""
        i = "\n" + " " * level
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + " "
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            # Deliberately rebinds elem: after the loop it is the last child,
            # whose tail is then dedented to close the parent.
            for elem in elem:
                self.indent(elem, level + 1)
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
        else:
            if level and (not elem.tail or not elem.tail.strip()):
                elem.tail = i

705 

706 

707class GEXFReader(GEXF): 

708 # Class to read GEXF format files 

709 # use read_gexf() function 

710 def __init__(self, node_type=None, version="1.2draft"): 

711 self.construct_types() 

712 self.node_type = node_type 

713 # assume simple graph and test for multigraph on read 

714 self.simple_graph = True 

715 self.set_version(version) 

716 

717 def __call__(self, stream): 

718 self.xml = ElementTree(file=stream) 

719 g = self.xml.find(f"{{{self.NS_GEXF}}}graph") 

720 if g is not None: 

721 return self.make_graph(g) 

722 # try all the versions 

723 for version in self.versions: 

724 self.set_version(version) 

725 g = self.xml.find(f"{{{self.NS_GEXF}}}graph") 

726 if g is not None: 

727 return self.make_graph(g) 

728 raise nx.NetworkXError("No <graph> element in GEXF file.") 

729 

730 def make_graph(self, graph_xml): 

731 # start with empty DiGraph or MultiDiGraph 

732 edgedefault = graph_xml.get("defaultedgetype", None) 

733 if edgedefault == "directed": 

734 G = nx.MultiDiGraph() 

735 else: 

736 G = nx.MultiGraph() 

737 

738 # graph attributes 

739 graph_name = graph_xml.get("name", "") 

740 if graph_name != "": 

741 G.graph["name"] = graph_name 

742 graph_start = graph_xml.get("start") 

743 if graph_start is not None: 

744 G.graph["start"] = graph_start 

745 graph_end = graph_xml.get("end") 

746 if graph_end is not None: 

747 G.graph["end"] = graph_end 

748 graph_mode = graph_xml.get("mode", "") 

749 if graph_mode == "dynamic": 

750 G.graph["mode"] = "dynamic" 

751 else: 

752 G.graph["mode"] = "static" 

753 

754 # timeformat 

755 self.timeformat = graph_xml.get("timeformat") 

756 if self.timeformat == "date": 

757 self.timeformat = "string" 

758 

759 # node and edge attributes 

760 attributes_elements = graph_xml.findall(f"{{{self.NS_GEXF}}}attributes") 

761 # dictionaries to hold attributes and attribute defaults 

762 node_attr = {} 

763 node_default = {} 

764 edge_attr = {} 

765 edge_default = {} 

766 for a in attributes_elements: 

767 attr_class = a.get("class") 

768 if attr_class == "node": 

769 na, nd = self.find_gexf_attributes(a) 

770 node_attr.update(na) 

771 node_default.update(nd) 

772 G.graph["node_default"] = node_default 

773 elif attr_class == "edge": 

774 ea, ed = self.find_gexf_attributes(a) 

775 edge_attr.update(ea) 

776 edge_default.update(ed) 

777 G.graph["edge_default"] = edge_default 

778 else: 

779 raise # unknown attribute class 

780 

781 # Hack to handle Gephi0.7beta bug 

782 # add weight attribute 

783 ea = {"weight": {"type": "double", "mode": "static", "title": "weight"}} 

784 ed = {} 

785 edge_attr.update(ea) 

786 edge_default.update(ed) 

787 G.graph["edge_default"] = edge_default 

788 

789 # add nodes 

790 nodes_element = graph_xml.find(f"{{{self.NS_GEXF}}}nodes") 

791 if nodes_element is not None: 

792 for node_xml in nodes_element.findall(f"{{{self.NS_GEXF}}}node"): 

793 self.add_node(G, node_xml, node_attr) 

794 

795 # add edges 

796 edges_element = graph_xml.find(f"{{{self.NS_GEXF}}}edges") 

797 if edges_element is not None: 

798 for edge_xml in edges_element.findall(f"{{{self.NS_GEXF}}}edge"): 

799 self.add_edge(G, edge_xml, edge_attr) 

800 

801 # switch to Graph or DiGraph if no parallel edges were found. 

802 if self.simple_graph: 

803 if G.is_directed(): 

804 G = nx.DiGraph(G) 

805 else: 

806 G = nx.Graph(G) 

807 return G 

808 

809 def add_node(self, G, node_xml, node_attr, node_pid=None): 

810 # add a single node with attributes to the graph 

811 

812 # get attributes and subattributues for node 

813 data = self.decode_attr_elements(node_attr, node_xml) 

814 data = self.add_parents(data, node_xml) # add any parents 

815 if self.VERSION == "1.1": 

816 data = self.add_slices(data, node_xml) # add slices 

817 else: 

818 data = self.add_spells(data, node_xml) # add spells 

819 data = self.add_viz(data, node_xml) # add viz 

820 data = self.add_start_end(data, node_xml) # add start/end 

821 

822 # find the node id and cast it to the appropriate type 

823 node_id = node_xml.get("id") 

824 if self.node_type is not None: 

825 node_id = self.node_type(node_id) 

826 

827 # every node should have a label 

828 node_label = node_xml.get("label") 

829 data["label"] = node_label 

830 

831 # parent node id 

832 node_pid = node_xml.get("pid", node_pid) 

833 if node_pid is not None: 

834 data["pid"] = node_pid 

835 

836 # check for subnodes, recursive 

837 subnodes = node_xml.find(f"{{{self.NS_GEXF}}}nodes") 

838 if subnodes is not None: 

839 for node_xml in subnodes.findall(f"{{{self.NS_GEXF}}}node"): 

840 self.add_node(G, node_xml, node_attr, node_pid=node_id) 

841 

842 G.add_node(node_id, **data) 

843 

844 def add_start_end(self, data, xml): 

845 # start and end times 

846 ttype = self.timeformat 

847 node_start = xml.get("start") 

848 if node_start is not None: 

849 data["start"] = self.python_type[ttype](node_start) 

850 node_end = xml.get("end") 

851 if node_end is not None: 

852 data["end"] = self.python_type[ttype](node_end) 

853 return data 

854 

855 def add_viz(self, data, node_xml): 

856 # add viz element for node 

857 viz = {} 

858 color = node_xml.find(f"{{{self.NS_VIZ}}}color") 

859 if color is not None: 

860 if self.VERSION == "1.1": 

861 viz["color"] = { 

862 "r": int(color.get("r")), 

863 "g": int(color.get("g")), 

864 "b": int(color.get("b")), 

865 } 

866 else: 

867 viz["color"] = { 

868 "r": int(color.get("r")), 

869 "g": int(color.get("g")), 

870 "b": int(color.get("b")), 

871 "a": float(color.get("a", 1)), 

872 } 

873 

874 size = node_xml.find(f"{{{self.NS_VIZ}}}size") 

875 if size is not None: 

876 viz["size"] = float(size.get("value")) 

877 

878 thickness = node_xml.find(f"{{{self.NS_VIZ}}}thickness") 

879 if thickness is not None: 

880 viz["thickness"] = float(thickness.get("value")) 

881 

882 shape = node_xml.find(f"{{{self.NS_VIZ}}}shape") 

883 if shape is not None: 

884 viz["shape"] = shape.get("shape") 

885 if viz["shape"] == "image": 

886 viz["shape"] = shape.get("uri") 

887 

888 position = node_xml.find(f"{{{self.NS_VIZ}}}position") 

889 if position is not None: 

890 viz["position"] = { 

891 "x": float(position.get("x", 0)), 

892 "y": float(position.get("y", 0)), 

893 "z": float(position.get("z", 0)), 

894 } 

895 

896 if len(viz) > 0: 

897 data["viz"] = viz 

898 return data 

899 

900 def add_parents(self, data, node_xml): 

901 parents_element = node_xml.find(f"{{{self.NS_GEXF}}}parents") 

902 if parents_element is not None: 

903 data["parents"] = [] 

904 for p in parents_element.findall(f"{{{self.NS_GEXF}}}parent"): 

905 parent = p.get("for") 

906 data["parents"].append(parent) 

907 return data 

908 

909 def add_slices(self, data, node_or_edge_xml): 

910 slices_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}slices") 

911 if slices_element is not None: 

912 data["slices"] = [] 

913 for s in slices_element.findall(f"{{{self.NS_GEXF}}}slice"): 

914 start = s.get("start") 

915 end = s.get("end") 

916 data["slices"].append((start, end)) 

917 return data 

918 

919 def add_spells(self, data, node_or_edge_xml): 

920 spells_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}spells") 

921 if spells_element is not None: 

922 data["spells"] = [] 

923 ttype = self.timeformat 

924 for s in spells_element.findall(f"{{{self.NS_GEXF}}}spell"): 

925 start = self.python_type[ttype](s.get("start")) 

926 end = self.python_type[ttype](s.get("end")) 

927 data["spells"].append((start, end)) 

928 return data 

929 

def add_edge(self, G, edge_element, edge_attr):
    """Add the edge described by an ``<edge>`` element to ``G``.

    Parameters
    ----------
    G : graph
        Graph being built. Edges are added with a ``key`` argument, so this
        is presumably a multigraph (TODO confirm against the caller).
    edge_element : Element
        The ``<edge>`` XML element.
    edge_attr : dict
        Edge attribute declarations keyed by attribute id, forwarded to
        ``self.decode_attr_elements``.

    Raises
    ------
    NetworkXError
        If the edge ``type`` is "undirected" in a directed graph or
        "directed" in an undirected graph.

    Notes
    -----
    The GEXF ``id`` is stored as the ``"id"`` edge attribute and used as the
    edge key, unless a ``networkx_key`` attribute is present, which takes
    precedence as the key. ``weight`` is converted to float. A ``"mutual"``
    edge is added in both directions. ``self.simple_graph`` is set to False
    whenever an edge being added already exists.
    """
    # raise error if we find mixed directed and undirected edges
    edge_direction = edge_element.get("type")
    if G.is_directed() and edge_direction == "undirected":
        raise nx.NetworkXError("Undirected edge found in directed graph.")
    if (not G.is_directed()) and edge_direction == "directed":
        raise nx.NetworkXError("Directed edge found in undirected graph.")

    # Get source and target and recast type if required
    source = edge_element.get("source")
    target = edge_element.get("target")
    if self.node_type is not None:
        source = self.node_type(source)
        target = self.node_type(target)

    data = self.decode_attr_elements(edge_attr, edge_element)
    data = self.add_start_end(data, edge_element)

    if self.VERSION == "1.1":
        data = self.add_slices(data, edge_element)  # add slices
    else:
        data = self.add_spells(data, edge_element)  # add spells

    # GEXF stores edge ids as an attribute
    # NetworkX uses them as keys in multigraphs
    # if networkx_key is not specified as an attribute
    edge_id = edge_element.get("id")
    if edge_id is not None:
        data["id"] = edge_id

    # check if there is a 'multigraph_key' and use that as edge_id
    multigraph_key = data.pop("networkx_key", None)
    if multigraph_key is not None:
        edge_id = multigraph_key

    weight = edge_element.get("weight")
    if weight is not None:
        data["weight"] = float(weight)

    edge_label = edge_element.get("label")
    if edge_label is not None:
        data["label"] = edge_label

    mutual = edge_direction == "mutual"
    # A "mutual" edge is also added as (target, source), so a pre-existing
    # reverse edge must count too. Checking only (source, target) left
    # simple_graph True even though a parallel reverse edge was created.
    if G.has_edge(source, target) or (mutual and G.has_edge(target, source)):
        # seen this edge before - this is a multigraph
        self.simple_graph = False
    G.add_edge(source, target, key=edge_id, **data)
    if mutual:
        G.add_edge(target, source, key=edge_id, **data)

981 

def decode_attr_elements(self, gexf_keys, obj_xml):
    """Decode the ``<attvalues>`` of a node or edge into a dict.

    Parameters
    ----------
    gexf_keys : dict
        Attribute declarations keyed by attribute id. Each value is a dict
        with ``"title"``, ``"type"`` and ``"mode"`` entries (the shape that
        ``find_gexf_attributes`` builds).
    obj_xml : Element
        The ``<node>`` or ``<edge>`` XML element.

    Returns
    -------
    dict
        Maps attribute titles to converted values. Booleans go through
        ``self.convert_bool``; other types go through
        ``self.python_type[type]``. For attributes declared with mode
        ``"dynamic"`` the value is a list of ``(value, start, end)`` tuples,
        one per ``<attvalue>``. A missing ``start`` or ``end`` is stored as
        ``None``.

    Raises
    ------
    NetworkXError
        If an ``<attvalue>`` refers (via ``for``) to an undeclared attribute.
    """
    # Use the key information to decode the attr XML
    attr = {}
    # look for outer '<attvalues>' element
    attr_element = obj_xml.find(f"{{{self.NS_GEXF}}}attvalues")
    if attr_element is None:
        return attr

    def to_time(value):
        # "start"/"end" are optional on a dynamic attvalue. Feeding None to
        # the time converter fails for numeric converters such as float
        # (TypeError), so a missing bound is kept as None.
        if value is None:
            return None
        return self.python_type[self.timeformat](value)

    # loop over <attvalue> elements
    for a in attr_element.findall(f"{{{self.NS_GEXF}}}attvalue"):
        key = a.get("for")  # for is required
        try:  # should be in our gexf_keys dictionary
            key_info = gexf_keys[key]
            title = key_info["title"]
        except KeyError as err:
            raise nx.NetworkXError(f"No attribute defined for={key}.") from err
        atype = key_info["type"]
        value = a.get("value")
        if atype == "boolean":
            value = self.convert_bool[value]
        else:
            value = self.python_type[atype](value)
        if key_info["mode"] == "dynamic":
            # for dynamic graphs use list of three-tuples
            # [(value1,start1,end1), (value2,start2,end2), etc]
            start = to_time(a.get("start"))
            end = to_time(a.get("end"))
            attr.setdefault(title, []).append((value, start, end))
        else:
            # for static graphs just assign the value
            attr[title] = value
    return attr

1015 

def find_gexf_attributes(self, attributes_element):
    """Collect attribute declarations and default values from ``<attributes>``.

    Parameters
    ----------
    attributes_element : Element
        An ``<attributes>`` XML element whose ``mode`` attribute (possibly
        absent) is shared by every declaration it contains.

    Returns
    -------
    (dict, dict)
        The first dict maps each attribute id to
        ``{"title": ..., "type": ..., "mode": ...}``. The second maps
        attribute titles to their ``<default>`` value, converted with
        ``self.convert_bool`` for booleans and ``self.python_type`` otherwise.
        Attributes without a ``<default>`` child get no entry.
    """
    ns = self.NS_GEXF
    shared_mode = attributes_element.get("mode")
    declarations = {}
    default_values = {}
    for attr_xml in attributes_element.findall(f"{{{ns}}}attribute"):
        title = attr_xml.get("title")
        atype = attr_xml.get("type")
        declarations[attr_xml.get("id")] = {
            "title": title,
            "type": atype,
            "mode": shared_mode,
        }
        default_xml = attr_xml.find(f"{{{ns}}}default")
        if default_xml is None:
            continue
        raw = default_xml.text
        default_values[title] = (
            self.convert_bool[raw]
            if atype == "boolean"
            else self.python_type[atype](raw)
        )
    return declarations, default_values

1035 

1036 

def relabel_gexf_graph(G):
    """Relabel graph using "label" node keyword for node label.

    Parameters
    ----------
    G : graph
       A NetworkX graph read from GEXF data

    Returns
    -------
    H : graph
      A NetworkX graph with relabeled nodes

    Raises
    ------
    NetworkXError
        If any node has no "label" attribute, or if node labels are not
        unique.

    Notes
    -----
    This function relabels the nodes in a NetworkX graph with the
    "label" attribute. It also handles relabeling the specific GEXF
    node attributes "parents", and "pid". On each relabeled node the
    original node is stored as the "id" attribute and "label" is removed.
    A graph with no nodes is returned as a relabeled (empty) copy.
    """
    # build mapping of node labels, do some error checking
    try:
        mapping = {u: G.nodes[u]["label"] for u in G}
    except KeyError as err:
        raise nx.NetworkXError(
            "Failed to relabel nodes: missing node labels found. Use relabel=False."
        ) from err
    # Check uniqueness on the labels directly; the previous
    # ``x, y = zip(*mapping)`` raised ValueError when G had no nodes.
    if len(set(mapping.values())) != len(G):
        raise nx.NetworkXError(
            "Failed to relabel nodes: duplicate node labels found. Use relabel=False."
        )
    H = nx.relabel_nodes(G, mapping)
    # relabel attributes
    for n in G:
        m = mapping[n]
        H.nodes[m]["id"] = n
        H.nodes[m].pop("label")
        if "pid" in H.nodes[m]:
            H.nodes[m]["pid"] = mapping[G.nodes[n]["pid"]]
        if "parents" in H.nodes[m]:
            H.nodes[m]["parents"] = [mapping[p] for p in G.nodes[n]["parents"]]
    return H