Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/networkx/readwrite/gexf.py: 8%

607 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-20 07:00 +0000

1"""Read and write graphs in GEXF format. 

2 

3.. warning:: 

4 This parser uses the standard xml library present in Python, which is 

5 insecure - see :external+python:mod:`xml` for additional information. 

6 Only parse GEXF files you trust. 

7 

8GEXF (Graph Exchange XML Format) is a language for describing complex 

9network structures, their associated data and dynamics. 

10 

11This implementation does not support mixed graphs (directed and 

12undirected edges together). 

13 

14Format 

15------ 

16GEXF is an XML format. See http://gexf.net/schema.html for the 

17specification and http://gexf.net/basic.html for examples. 

18""" 

19import itertools 

20import time 

21from xml.etree.ElementTree import ( 

22 Element, 

23 ElementTree, 

24 SubElement, 

25 register_namespace, 

26 tostring, 

27) 

28 

29import networkx as nx 

30from networkx.utils import open_file 

31 

32__all__ = ["write_gexf", "read_gexf", "relabel_gexf_graph", "generate_gexf"] 

33 

34 

@open_file(1, mode="wb")
def write_gexf(G, path, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Serialize the graph `G` as a GEXF document and write it to `path`.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Attributes that are not user defined, such as the visualization
    attribute 'viz' [2]_, are written according to the selected version
    of the GEXF schema. See the example below for usage.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    path : file or string
        File or file name to write.
        File names ending in .gz or .bz2 will be compressed.
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (optional, default: '1.2draft')
        The version of GEXF used when writing the document.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.write_gexf(G, "test.gexf")

    # visualization data
    >>> G.nodes[0]["viz"] = {"size": 54}
    >>> G.nodes[0]["viz"]["position"] = {"x": 0, "y": 1}
    >>> G.nodes[0]["viz"]["color"] = {"r": 0, "g": 0, "b": 256}

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute defaults to the string of the node label.
    To choose a different id, store it in the node data, e.g.
    ``G.nodes['a']['id'] = 1`` sets the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, http://gexf.net/
    .. [2] GEXF schema, http://gexf.net/schema.html
    """
    writer_options = {
        "encoding": encoding,
        "prettyprint": prettyprint,
        "version": version,
    }
    gexf_writer = GEXFWriter(**writer_options)
    # The graph is added explicitly (not through the constructor) so that it
    # is always passed to add_graph, whatever its value.
    gexf_writer.add_graph(G)
    gexf_writer.write(path)

88 

89 

def generate_gexf(G, encoding="utf-8", prettyprint=True, version="1.2draft"):
    """Yield the GEXF representation of `G` one line at a time.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    G : graph
        A NetworkX graph
    encoding : string (optional, default: 'utf-8')
        Encoding for text data.
    prettyprint : bool (optional, default: True)
        If True use line breaks and indenting in output XML.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see http://gexf.net/schema.html)
        Supported values: "1.1draft", "1.2draft"

    Yields
    ------
    line : string
        Successive lines of the serialized document, without line endings.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> linefeed = chr(10)  # linefeed=\n
    >>> s = linefeed.join(nx.generate_gexf(G))
    >>> for line in nx.generate_gexf(G):  # doctest: +SKIP
    ...     print(line)

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute defaults to the string of the node label.
    To choose a different id, store it in the node data, e.g.
    ``G.nodes['a']['id'] = 1`` sets the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    gexf_writer = GEXFWriter(
        encoding=encoding, prettyprint=prettyprint, version=version
    )
    gexf_writer.add_graph(G)
    # The whole document is rendered once, then handed out line by line.
    document_text = str(gexf_writer)
    for line in document_text.splitlines():
        yield line

133 

134 

@open_file(0, mode="rb")
@nx._dispatch(graphs=None)
def read_gexf(path, node_type=None, relabel=False, version="1.2draft"):
    """Parse a GEXF document from `path` and return it as a NetworkX graph.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    path : file or string
        File or file name to read.
        File names ending in .gz or .bz2 will be decompressed.
    node_type : Python type (default: None)
        Convert node ids to this type if not None.
    relabel : bool (default: False)
        If True relabel the nodes to use the GEXF node "label" attribute
        instead of the node "id" attribute as the NetworkX node label.
    version : string (default: 1.2draft)
        Version of GEXF File Format (see http://gexf.net/schema.html)
        Supported values: "1.1draft", "1.2draft"

    Returns
    -------
    graph : NetworkX graph
        If no parallel edges are found a Graph or DiGraph is returned.
        Otherwise a MultiGraph or MultiDiGraph is returned.

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    References
    ----------
    .. [1] GEXF File Format, http://gexf.net/
    """
    parse = GEXFReader(node_type=node_type, version=version)
    parsed_graph = parse(path)
    # Relabeling is a post-processing step applied to the parsed graph.
    return relabel_gexf_graph(parsed_graph) if relabel else parsed_graph

178 

179 

class GEXF:
    """State shared by the GEXF reader and writer.

    Holds the per-version namespace constants, the mapping between Python
    types and GEXF attribute type names, and the table used to decode
    boolean attribute values.
    """

    # Namespace and schema constants, keyed by the version string accepted
    # by `set_version`.
    versions = {
        "1.1draft": {
            "NS_GEXF": "http://www.gexf.net/1.1draft",
            "NS_VIZ": "http://www.gexf.net/1.1draft/viz",
            "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
            "SCHEMALOCATION": " ".join(
                [
                    "http://www.gexf.net/1.1draft",
                    "http://www.gexf.net/1.1draft/gexf.xsd",
                ]
            ),
            "VERSION": "1.1",
        },
        "1.2draft": {
            "NS_GEXF": "http://www.gexf.net/1.2draft",
            "NS_VIZ": "http://www.gexf.net/1.2draft/viz",
            "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
            "SCHEMALOCATION": " ".join(
                [
                    "http://www.gexf.net/1.2draft",
                    "http://www.gexf.net/1.2draft/gexf.xsd",
                ]
            ),
            "VERSION": "1.2",
        },
    }

    def construct_types(self):
        """Build `self.xml_type` and `self.python_type`.

        `xml_type` maps a Python (or numpy) type to a GEXF type name and
        `python_type` maps a GEXF type name back to a type. Both are built
        from one ordered list of pairs; when a key occurs more than once the
        later pair wins.
        """
        type_pairs = [
            (int, "integer"),
            (float, "float"),
            (float, "double"),
            (bool, "boolean"),
            (list, "string"),
            (dict, "string"),
            (int, "long"),
            (str, "liststring"),
            (str, "anyURI"),
            (str, "string"),
        ]

        # numpy is optional; when present its scalar types become writable.
        try:
            import numpy as np
        except ImportError:
            np = None

        if np is not None:
            numpy_floats = (np.float64, np.float32, np.float16)
            numpy_ints = (
                np.int_,
                np.int8,
                np.int16,
                np.int32,
                np.int64,
                np.uint8,
                np.uint16,
                np.uint32,
                np.uint64,
                np.intc,
                np.intp,
            )
            numpy_pairs = [(np_type, "float") for np_type in numpy_floats]
            numpy_pairs += [(np_type, "int") for np_type in numpy_ints]
            # The numpy pairs go first so that the plain Python pairs, which
            # come later, win for the type names both groups share.
            type_pairs = numpy_pairs + type_pairs

        self.xml_type = dict(type_pairs)
        self.python_type = {
            xml_name: python_type for python_type, xml_name in type_pairs
        }

    # http://www.w3.org/TR/xmlschema-2/#boolean
    convert_bool = {
        "true": True,
        "false": False,
        "True": True,
        "False": False,
        "0": False,
        0: False,
        "1": True,
        1: True,
    }

    def set_version(self, version):
        """Select the GEXF version and copy its constants onto the instance.

        Sets `NS_GEXF`, `NS_VIZ`, `NS_XSI`, `SCHEMALOCATION`, `VERSION` and
        `version`.

        Raises
        ------
        NetworkXError
            If `version` is not a key of `versions`.
        """
        constants = self.versions.get(version)
        if constants is None:
            raise nx.NetworkXError(f"Unknown GEXF version {version}.")
        for field in ("NS_GEXF", "NS_VIZ", "NS_XSI", "SCHEMALOCATION", "VERSION"):
            setattr(self, field, constants[field])
        self.version = version

272 

273 

class GEXFWriter(GEXF):
    """Build a GEXF XML document from one or more NetworkX graphs.

    This class is meant to be used through ``write_gexf()`` and
    ``generate_gexf()``.

    Parameters
    ----------
    graph : graph, optional
        If given, it is added to the document right away with `add_graph`.
    encoding : string
        Encoding used when serializing (`write`, `__str__`).
    prettyprint : bool
        If True the XML tree is indented before it is serialized.
    version : string
        GEXF version, one of the keys of ``GEXF.versions``.
    """

    def __init__(
        self, graph=None, encoding="utf-8", prettyprint=True, version="1.2draft"
    ):
        self.construct_types()
        self.prettyprint = prettyprint
        self.encoding = encoding
        self.set_version(version)
        self.xml = Element(
            "gexf",
            {
                "xmlns": self.NS_GEXF,
                "xmlns:xsi": self.NS_XSI,
                "xsi:schemaLocation": self.SCHEMALOCATION,
                "version": self.VERSION,
            },
        )

        # Make meta element a non-graph element
        # Also add lastmodifieddate as attribute, not tag
        meta_element = Element("meta")
        subelement_text = f"NetworkX {nx.__version__}"
        SubElement(meta_element, "creator").text = subelement_text
        meta_element.set("lastmodifieddate", time.strftime("%Y-%m-%d"))
        self.xml.append(meta_element)

        # NOTE: this registers the prefix globally in ElementTree.
        register_namespace("viz", self.NS_VIZ)

        # counters for edge and attribute identifiers
        self.edge_id = itertools.count()
        self.attr_id = itertools.count()
        self.all_edge_ids = set()
        # ids of the declared attributes: self.attr[class][mode][title] -> id
        self.attr = {
            "node": {"dynamic": {}, "static": {}},
            "edge": {"dynamic": {}, "static": {}},
        }

        if graph is not None:
            self.add_graph(graph)

    def __str__(self):
        """Return the document as a string (indented first if prettyprint)."""
        if self.prettyprint:
            self.indent(self.xml)
        s = tostring(self.xml).decode(self.encoding)
        return s

    def add_graph(self, G):
        """Append a ``<graph>`` element describing `G` to the document."""
        # first pass through G collecting edge ids, so that generated ids
        # never collide with ids supplied by the user
        for u, v, dd in G.edges(data=True):
            eid = dd.get("id")
            if eid is not None:
                self.all_edge_ids.add(str(eid))
        # set graph attributes
        mode = "dynamic" if G.graph.get("mode") == "dynamic" else "static"
        default = "directed" if G.is_directed() else "undirected"
        name = G.graph.get("name", "")
        # Add a graph element to the XML
        graph_element = Element("graph", defaultedgetype=default, mode=mode, name=name)
        self.graph_element = graph_element
        self.add_nodes(G, graph_element)
        self.add_edges(G, graph_element)
        self.xml.append(graph_element)

    def add_nodes(self, G, graph_element):
        """Append a ``<nodes>`` element with one ``<node>`` per node of `G`."""
        nodes_element = Element("nodes")
        # Attribute defaults are the same for every node: look them up once.
        default = G.graph.get("node_default", {})
        for node, data in G.nodes(data=True):
            # work on a copy: the keys consumed below must stay in the graph
            node_data = data.copy()
            kw = {"id": str(node_data.pop("id", node))}
            kw["label"] = str(node_data.pop("label", node))
            if "pid" in node_data:
                kw["pid"] = str(node_data.pop("pid"))
            for bound in ("start", "end"):
                if bound in node_data:
                    moment = node_data.pop(bound)
                    kw[bound] = str(moment)
                    self.alter_graph_mode_timeformat(moment)
            # add node element with attributes
            node_element = Element("node", **kw)
            # add node element and attr subelements
            node_data = self.add_parents(node_element, node_data)
            if self.VERSION == "1.1":
                node_data = self.add_slices(node_element, node_data)
            else:
                node_data = self.add_spells(node_element, node_data)
            node_data = self.add_viz(node_element, node_data)
            node_data = self.add_attributes("node", node_element, node_data, default)
            nodes_element.append(node_element)
        graph_element.append(nodes_element)

    def _new_edge_id(self):
        # Draw the next counter value whose string form is not already used
        # as an edge id, and reserve it.
        edge_id = next(self.edge_id)
        while str(edge_id) in self.all_edge_ids:
            edge_id = next(self.edge_id)
        self.all_edge_ids.add(str(edge_id))
        return edge_id

    def add_edges(self, G, graph_element):
        """Append an ``<edges>`` element with one ``<edge>`` per edge of `G`."""

        def edge_key_data(G):
            # helper function to unify multigraph and graph edge iterator;
            # yields (u, v, edge id, copy of the edge data)
            multigraph = G.is_multigraph()
            if multigraph:
                edges = G.edges(data=True, keys=True)
            else:
                edges = ((u, v, None, data) for u, v, data in G.edges(data=True))
            for u, v, key, data in edges:
                edge_data = data.copy()
                if multigraph:
                    # keep the multigraph key; add_attributes writes it out
                    # under the name "networkx_key"
                    edge_data.update(key=key)
                edge_id = edge_data.pop("id", None)
                if edge_id is None:
                    edge_id = self._new_edge_id()
                yield u, v, edge_id, edge_data

        edges_element = Element("edges")
        # Attribute defaults are the same for every edge: look them up once.
        default = G.graph.get("edge_default", {})
        for u, v, key, edge_data in edge_key_data(G):
            kw = {"id": str(key)}
            for name in ("label", "weight", "type"):
                if name in edge_data:
                    kw[name] = str(edge_data.pop(name))
            for bound in ("start", "end"):
                if bound in edge_data:
                    moment = edge_data.pop(bound)
                    kw[bound] = str(moment)
                    self.alter_graph_mode_timeformat(moment)
            # endpoints are referenced by the same id that add_nodes wrote
            source_id = str(G.nodes[u].get("id", u))
            target_id = str(G.nodes[v].get("id", v))
            edge_element = Element("edge", source=source_id, target=target_id, **kw)
            if self.VERSION == "1.1":
                edge_data = self.add_slices(edge_element, edge_data)
            else:
                edge_data = self.add_spells(edge_element, edge_data)
            edge_data = self.add_viz(edge_element, edge_data)
            edge_data = self.add_attributes("edge", edge_element, edge_data, default)
            edges_element.append(edge_element)
        graph_element.append(edges_element)

    @staticmethod
    def _format_value(value):
        # Text form of an attribute value: booleans are lower-cased and the
        # special floats are spelled INF / -INF / NaN.
        if isinstance(value, bool):
            return str(value).lower()
        text = str(value)
        if isinstance(value, float):
            return {"inf": "INF", "-inf": "-INF", "nan": "NaN"}.get(text, text)
        return text

    def add_attributes(self, node_or_edge, xml_obj, data, default):
        """Append an ``<attvalues>`` element for `data` to `xml_obj`.

        Parameters
        ----------
        node_or_edge : "node" or "edge"
            Attribute class the values are declared under.
        xml_obj : Element
            The node or edge element receiving the values.
        data : dict
            Remaining attributes of the node or edge. A list value is read
            as dynamic data, i.e. a list of ``(value, start, end)`` triples.
        default : dict
            Default values by attribute title, used when declaring a new
            attribute.

        Returns
        -------
        data : dict
            The `data` argument, unchanged.

        Raises
        ------
        TypeError
            If the type of a value is not a key of ``self.xml_type``.
        """
        # Add attrvalues to node or edge
        attvalues = Element("attvalues")
        if len(data) == 0:
            return data
        for k, v in data.items():
            # rename generic multigraph key to avoid any name conflict
            if k == "key":
                k = "networkx_key"
            val_type = type(v)
            if val_type not in self.xml_type:
                raise TypeError(f"attribute value type is not allowed: {val_type}")
            # The mode is decided per attribute; it must not carry over from
            # the attribute handled in the previous iteration.
            mode = "static"
            if isinstance(v, list):
                # dynamic data: scan up to the first entry that has a time
                # bound; the declared type is that of the last value scanned
                for val, start, end in v:
                    val_type = type(val)
                    if start is not None or end is not None:
                        mode = "dynamic"
                        self.alter_graph_mode_timeformat(start)
                        self.alter_graph_mode_timeformat(end)
                        break
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, mode
                )
                for val, start, end in v:
                    e = Element("attvalue")
                    e.attrib["for"] = attr_id
                    e.attrib["value"] = self._format_value(val)
                    if start is not None:
                        e.attrib["start"] = str(start)
                    if end is not None:
                        e.attrib["end"] = str(end)
                    attvalues.append(e)
            else:
                # static data
                attr_id = self.get_attr_id(
                    str(k), self.xml_type[val_type], node_or_edge, default, mode
                )
                e = Element("attvalue")
                e.attrib["for"] = attr_id
                e.attrib["value"] = self._format_value(v)
                attvalues.append(e)
        xml_obj.append(attvalues)
        return data

    def get_attr_id(self, title, attr_type, edge_or_node, default, mode):
        """Return the id of attribute `title`, declaring it on first use.

        A new attribute is recorded in ``self.attr`` and added as an
        ``<attribute>`` element to the ``<attributes>`` element of the
        current graph that matches `edge_or_node` and `mode`; that element
        is created when it does not exist yet.
        """
        # find the id of the attribute or generate a new id
        declared = self.attr[edge_or_node][mode]
        if title in declared:
            return declared[title]
        # generate new id
        new_id = str(next(self.attr_id))
        declared[title] = new_id
        attribute = Element("attribute", id=new_id, title=title, type=attr_type)
        # add subelement for data default value if present
        default_title = default.get(title)
        if default_title is not None:
            default_element = Element("default")
            default_element.text = str(default_title)
            attribute.append(default_element)
        # now insert it into the XML
        attributes_element = None
        for a in self.graph_element.findall("attributes"):
            # find existing attributes element by class and mode
            if a.get("class") == edge_or_node and a.get("mode", "static") == mode:
                attributes_element = a
        if attributes_element is None:
            # create new attributes element ("class" is a Python keyword,
            # hence the dict)
            attr_kwargs = {"mode": mode, "class": edge_or_node}
            attributes_element = Element("attributes", **attr_kwargs)
            self.graph_element.insert(0, attributes_element)
        attributes_element.append(attribute)
        return new_id

    def add_viz(self, element, node_data):
        """Move the "viz" entry of `node_data` into viz child elements.

        Handles the keys "color", "size", "thickness", "shape" and
        "position". Returns `node_data` with "viz" removed.
        """
        viz = node_data.pop("viz", False)
        if viz:
            color = viz.get("color")
            if color is not None:
                channels = {
                    "r": str(color.get("r")),
                    "g": str(color.get("g")),
                    "b": str(color.get("b")),
                }
                if self.VERSION != "1.1":
                    # the alpha channel is not written for version 1.1
                    channels["a"] = str(color.get("a", 1.0))
                element.append(Element(f"{{{self.NS_VIZ}}}color", **channels))

            size = viz.get("size")
            if size is not None:
                e = Element(f"{{{self.NS_VIZ}}}size", value=str(size))
                element.append(e)

            thickness = viz.get("thickness")
            if thickness is not None:
                e = Element(f"{{{self.NS_VIZ}}}thickness", value=str(thickness))
                element.append(e)

            shape = viz.get("shape")
            if shape is not None:
                if shape.startswith("http"):
                    # a URL is written as an image shape
                    e = Element(
                        f"{{{self.NS_VIZ}}}shape", value="image", uri=str(shape)
                    )
                else:
                    e = Element(f"{{{self.NS_VIZ}}}shape", value=str(shape))
                element.append(e)

            position = viz.get("position")
            if position is not None:

                def coordinate(axis):
                    # A missing coordinate is written as 0.0 instead of the
                    # text "None", which could not be read back as a float.
                    value = position.get(axis)
                    return str(0.0 if value is None else value)

                e = Element(
                    f"{{{self.NS_VIZ}}}position",
                    x=coordinate("x"),
                    y=coordinate("y"),
                    z=coordinate("z"),
                )
                element.append(e)
        return node_data

    def add_parents(self, node_element, node_data):
        """Move the "parents" entry of `node_data` into a ``<parents>`` child."""
        parents = node_data.pop("parents", False)
        if parents:
            parents_element = Element("parents")
            for p in parents:
                e = Element("parent")
                # "for" is a Python keyword, so it is set through attrib
                e.attrib["for"] = str(p)
                parents_element.append(e)
            node_element.append(parents_element)
        return node_data

    def add_slices(self, node_or_edge_element, node_or_edge_data):
        """Move the "slices" entry of the data into a ``<slices>`` child."""
        slices = node_or_edge_data.pop("slices", False)
        if slices:
            slices_element = Element("slices")
            for start, end in slices:
                e = Element("slice", start=str(start), end=str(end))
                slices_element.append(e)
            node_or_edge_element.append(slices_element)
        return node_or_edge_data

    def add_spells(self, node_or_edge_element, node_or_edge_data):
        """Move the "spells" entry of the data into a ``<spells>`` child.

        A bound that is None is left out of the ``<spell>`` element.
        """
        spells = node_or_edge_data.pop("spells", False)
        if spells:
            spells_element = Element("spells")
            for start, end in spells:
                e = Element("spell")
                if start is not None:
                    e.attrib["start"] = str(start)
                    self.alter_graph_mode_timeformat(start)
                if end is not None:
                    e.attrib["end"] = str(end)
                    self.alter_graph_mode_timeformat(end)
                spells_element.append(e)
            node_or_edge_element.append(spells_element)
        return node_or_edge_data

    def alter_graph_mode_timeformat(self, start_or_end):
        """Record on the current graph element that time data is present.

        The first time value seen switches the graph mode to "dynamic" and
        sets "timeformat" from its type: str -> "date", float -> "double",
        int -> "long". A graph that was already declared dynamic also gets
        its "timeformat" set, because the reader in this file needs it to
        decode start/end values.

        Raises
        ------
        NetworkXError
            If a static graph receives a time value that is not an int,
            float or str.
        """
        if start_or_end is None:
            return
        graph_element = self.graph_element
        was_static = graph_element.get("mode") == "static"
        if not was_static and graph_element.get("timeformat") is not None:
            # nothing left to record
            return
        if isinstance(start_or_end, str):
            timeformat = "date"
        elif isinstance(start_or_end, float):
            timeformat = "double"
        elif isinstance(start_or_end, int):
            timeformat = "long"
        elif was_static:
            raise nx.NetworkXError(
                "timeformat should be of the type int, float or str"
            )
        else:
            # A graph declared dynamic by the caller was never validated
            # here before; keep tolerating other time types for it.
            return
        graph_element.set("timeformat", timeformat)
        graph_element.set("mode", "dynamic")

    def write(self, fh):
        """Serialize the document, with an XML declaration, to the open `fh`."""
        if self.prettyprint:
            self.indent(self.xml)
        document = ElementTree(self.xml)
        document.write(fh, encoding=self.encoding, xml_declaration=True)

    def indent(self, elem, level=0):
        """Indent the tree below `elem` in place by rewriting text and tail.

        Text and tails that contain non-whitespace content are left alone.
        """
        # NOTE(review): the indent unit is one space per level as written
        # here; confirm that this is the intended pretty-print width.
        i = "\n" + " " * level
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + " "
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            for child in elem:
                self.indent(child, level + 1)
            # `child` is now the last child; its tail is reset to this
            # element's level so that the closing tag lines up.
            if not child.tail or not child.tail.strip():
                child.tail = i
        else:
            if level and (not elem.tail or not elem.tail.strip()):
                elem.tail = i

684 

685 

class GEXFReader(GEXF):
    """Parse a GEXF document into a NetworkX graph.

    This class is meant to be used through ``read_gexf()``.

    Parameters
    ----------
    node_type : callable, optional
        If not None, applied to every node id (and edge endpoint).
    version : string
        GEXF version tried first, one of the keys of ``GEXF.versions``.
    """

    def __init__(self, node_type=None, version="1.2draft"):
        self.construct_types()
        self.node_type = node_type
        # assume simple graph and test for multigraph on read
        self.simple_graph = True
        self.set_version(version)

    def __call__(self, stream):
        """Parse `stream` and return the graph built from its <graph> element.

        Raises
        ------
        NetworkXError
            If no <graph> element is found under any supported namespace.
        """
        self.xml = ElementTree(file=stream)
        g = self.xml.find(f"{{{self.NS_GEXF}}}graph")
        if g is not None:
            return self.make_graph(g)
        # The requested version did not match: try all the versions.
        for version in self.versions:
            self.set_version(version)
            g = self.xml.find(f"{{{self.NS_GEXF}}}graph")
            if g is not None:
                return self.make_graph(g)
        raise nx.NetworkXError("No <graph> element in GEXF file.")

    def make_graph(self, graph_xml):
        """Build a graph from the ``<graph>`` element `graph_xml`.

        Raises
        ------
        NetworkXError
            If an ``<attributes>`` element has a class other than "node"
            or "edge".
        """
        # start with empty DiGraph or MultiDiGraph
        edgedefault = graph_xml.get("defaultedgetype", None)
        if edgedefault == "directed":
            G = nx.MultiDiGraph()
        else:
            G = nx.MultiGraph()

        # graph attributes
        graph_name = graph_xml.get("name", "")
        if graph_name != "":
            G.graph["name"] = graph_name
        graph_start = graph_xml.get("start")
        if graph_start is not None:
            G.graph["start"] = graph_start
        graph_end = graph_xml.get("end")
        if graph_end is not None:
            G.graph["end"] = graph_end
        graph_mode = graph_xml.get("mode", "")
        if graph_mode == "dynamic":
            G.graph["mode"] = "dynamic"
        else:
            G.graph["mode"] = "static"

        # timeformat: "date" values are kept as strings
        self.timeformat = graph_xml.get("timeformat")
        if self.timeformat == "date":
            self.timeformat = "string"

        # node and edge attributes
        attributes_elements = graph_xml.findall(f"{{{self.NS_GEXF}}}attributes")
        # dictionaries to hold attributes and attribute defaults
        node_attr = {}
        node_default = {}
        edge_attr = {}
        edge_default = {}
        for a in attributes_elements:
            attr_class = a.get("class")
            if attr_class == "node":
                na, nd = self.find_gexf_attributes(a)
                node_attr.update(na)
                node_default.update(nd)
                G.graph["node_default"] = node_default
            elif attr_class == "edge":
                ea, ed = self.find_gexf_attributes(a)
                edge_attr.update(ea)
                edge_default.update(ed)
                G.graph["edge_default"] = edge_default
            else:
                # A bare `raise` here has no active exception to re-raise
                # and would only surface as an unrelated RuntimeError.
                raise nx.NetworkXError(
                    f"Unknown attribute class {attr_class!r} in GEXF file."
                )

        # Hack to handle Gephi0.7beta bug
        # add weight attribute
        ea = {"weight": {"type": "double", "mode": "static", "title": "weight"}}
        ed = {}
        edge_attr.update(ea)
        edge_default.update(ed)
        G.graph["edge_default"] = edge_default

        # add nodes
        nodes_element = graph_xml.find(f"{{{self.NS_GEXF}}}nodes")
        if nodes_element is not None:
            for node_xml in nodes_element.findall(f"{{{self.NS_GEXF}}}node"):
                self.add_node(G, node_xml, node_attr)

        # add edges
        edges_element = graph_xml.find(f"{{{self.NS_GEXF}}}edges")
        if edges_element is not None:
            for edge_xml in edges_element.findall(f"{{{self.NS_GEXF}}}edge"):
                self.add_edge(G, edge_xml, edge_attr)

        # switch to Graph or DiGraph if no parallel edges were found.
        if self.simple_graph:
            if G.is_directed():
                G = nx.DiGraph(G)
            else:
                G = nx.Graph(G)
        return G

    def _parse_time(self, raw):
        # Convert a start/end attribute with the type named by the graph's
        # timeformat. An absent attribute (None) stays None instead of
        # being passed to the converter.
        if raw is None:
            return None
        return self.python_type[self.timeformat](raw)

    def add_node(self, G, node_xml, node_attr, node_pid=None):
        """Add the node described by `node_xml`, and its nested nodes, to `G`.

        `node_pid` is the id of the enclosing node when this node is nested;
        a "pid" attribute on the element takes precedence over it.
        """
        # get attributes and subattributes for node
        data = self.decode_attr_elements(node_attr, node_xml)
        data = self.add_parents(data, node_xml)  # add any parents
        if self.VERSION == "1.1":
            data = self.add_slices(data, node_xml)  # add slices
        else:
            data = self.add_spells(data, node_xml)  # add spells
        data = self.add_viz(data, node_xml)  # add viz
        data = self.add_start_end(data, node_xml)  # add start/end

        # find the node id and cast it to the appropriate type
        node_id = node_xml.get("id")
        if self.node_type is not None:
            node_id = self.node_type(node_id)

        # every node should have a label (None when the attribute is absent)
        node_label = node_xml.get("label")
        data["label"] = node_label

        # parent node id
        node_pid = node_xml.get("pid", node_pid)
        if node_pid is not None:
            data["pid"] = node_pid

        # check for subnodes, recursive
        subnodes = node_xml.find(f"{{{self.NS_GEXF}}}nodes")
        if subnodes is not None:
            for child_xml in subnodes.findall(f"{{{self.NS_GEXF}}}node"):
                self.add_node(G, child_xml, node_attr, node_pid=node_id)

        G.add_node(node_id, **data)

    def add_start_end(self, data, xml):
        """Copy the "start"/"end" attributes of `xml` into `data`, converted
        with the type named by the graph's timeformat."""
        for bound in ("start", "end"):
            raw = xml.get(bound)
            if raw is not None:
                data[bound] = self._parse_time(raw)
        return data

    def add_viz(self, data, node_xml):
        """Collect the viz child elements of `node_xml` under ``data["viz"]``.

        Reads color, size, thickness, shape and position. "viz" is only set
        when at least one of them is present.
        """
        viz = {}
        color = node_xml.find(f"{{{self.NS_VIZ}}}color")
        if color is not None:
            viz["color"] = {
                "r": int(color.get("r")),
                "g": int(color.get("g")),
                "b": int(color.get("b")),
            }
            if self.VERSION != "1.1":
                # the alpha channel is not read for version 1.1
                viz["color"]["a"] = float(color.get("a", 1))

        size = node_xml.find(f"{{{self.NS_VIZ}}}size")
        if size is not None:
            viz["size"] = float(size.get("value"))

        thickness = node_xml.find(f"{{{self.NS_VIZ}}}thickness")
        if thickness is not None:
            viz["thickness"] = float(thickness.get("value"))

        shape = node_xml.find(f"{{{self.NS_VIZ}}}shape")
        if shape is not None:
            # The writer in this file stores the shape in "value"; the
            # "shape" attribute is still accepted as a fallback.
            viz["shape"] = shape.get("value", shape.get("shape"))
            if viz["shape"] == "image":
                viz["shape"] = shape.get("uri")

        position = node_xml.find(f"{{{self.NS_VIZ}}}position")
        if position is not None:
            viz["position"] = {
                "x": float(position.get("x", 0)),
                "y": float(position.get("y", 0)),
                "z": float(position.get("z", 0)),
            }

        if len(viz) > 0:
            data["viz"] = viz
        return data

    def add_parents(self, data, node_xml):
        """Store the ``for`` ids of the <parent> children as ``data["parents"]``."""
        parents_element = node_xml.find(f"{{{self.NS_GEXF}}}parents")
        if parents_element is not None:
            data["parents"] = [
                p.get("for")
                for p in parents_element.findall(f"{{{self.NS_GEXF}}}parent")
            ]
        return data

    def add_slices(self, data, node_or_edge_xml):
        """Store the <slice> children as ``data["slices"]``, a list of
        ``(start, end)`` pairs of unconverted attribute strings."""
        slices_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}slices")
        if slices_element is not None:
            data["slices"] = [
                (s.get("start"), s.get("end"))
                for s in slices_element.findall(f"{{{self.NS_GEXF}}}slice")
            ]
        return data

    def add_spells(self, data, node_or_edge_xml):
        """Store the <spell> children as ``data["spells"]``, a list of
        ``(start, end)`` pairs; an absent bound is None."""
        spells_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}spells")
        if spells_element is not None:
            data["spells"] = [
                (self._parse_time(s.get("start")), self._parse_time(s.get("end")))
                for s in spells_element.findall(f"{{{self.NS_GEXF}}}spell")
            ]
        return data

    def add_edge(self, G, edge_element, edge_attr):
        """Add the edge described by `edge_element` to `G`.

        Raises
        ------
        NetworkXError
            If the edge type contradicts the directedness of `G`.
        """
        # raise error if we find mixed directed and undirected edges
        edge_direction = edge_element.get("type")
        if G.is_directed() and edge_direction == "undirected":
            raise nx.NetworkXError("Undirected edge found in directed graph.")
        if (not G.is_directed()) and edge_direction == "directed":
            raise nx.NetworkXError("Directed edge found in undirected graph.")

        # Get source and target and recast type if required
        source = edge_element.get("source")
        target = edge_element.get("target")
        if self.node_type is not None:
            source = self.node_type(source)
            target = self.node_type(target)

        data = self.decode_attr_elements(edge_attr, edge_element)
        data = self.add_start_end(data, edge_element)

        if self.VERSION == "1.1":
            data = self.add_slices(data, edge_element)  # add slices
        else:
            data = self.add_spells(data, edge_element)  # add spells

        # GEXF stores edge ids as an attribute
        # NetworkX uses them as keys in multigraphs
        # if networkx_key is not specified as an attribute
        edge_id = edge_element.get("id")
        if edge_id is not None:
            data["id"] = edge_id

        # check if there is a 'multigraph_key' and use that as edge_id
        multigraph_key = data.pop("networkx_key", None)
        if multigraph_key is not None:
            edge_id = multigraph_key

        # the weight and label attributes of the element override any
        # attvalue of the same title
        weight = edge_element.get("weight")
        if weight is not None:
            data["weight"] = float(weight)

        edge_label = edge_element.get("label")
        if edge_label is not None:
            data["label"] = edge_label

        if G.has_edge(source, target):
            # seen this edge before - this is a multigraph
            self.simple_graph = False
        G.add_edge(source, target, key=edge_id, **data)
        if edge_direction == "mutual":
            G.add_edge(target, source, key=edge_id, **data)

    def decode_attr_elements(self, gexf_keys, obj_xml):
        """Decode the <attvalues> of `obj_xml` into a dict keyed by title.

        `gexf_keys` maps attribute id to its declaration (title, type,
        mode) as returned by `find_gexf_attributes`. Static attributes map
        to a single value, dynamic ones to a list of
        ``(value, start, end)`` triples in which an absent bound is None.

        Raises
        ------
        NetworkXError
            If an <attvalue> refers to an attribute id that was not declared.
        """
        attr = {}
        # look for outer '<attvalues>' element
        attr_element = obj_xml.find(f"{{{self.NS_GEXF}}}attvalues")
        if attr_element is None:
            return attr
        # loop over <attvalue> elements
        for a in attr_element.findall(f"{{{self.NS_GEXF}}}attvalue"):
            key = a.get("for")  # for is required
            try:  # should be in our gexf_keys dictionary
                title = gexf_keys[key]["title"]
            except KeyError as err:
                raise nx.NetworkXError(f"No attribute defined for={key}.") from err
            atype = gexf_keys[key]["type"]
            value = a.get("value")
            if atype == "boolean":
                value = self.convert_bool[value]
            else:
                value = self.python_type[atype](value)
            if gexf_keys[key]["mode"] == "dynamic":
                # for dynamic graphs use list of three-tuples
                # [(value1,start1,end1), (value2,start2,end2), etc]
                start = self._parse_time(a.get("start"))
                end = self._parse_time(a.get("end"))
                attr.setdefault(title, []).append((value, start, end))
            else:
                # for static graphs just assign the value
                attr[title] = value
        return attr

    def find_gexf_attributes(self, attributes_element):
        """Extract the declarations of an <attributes> element.

        Returns
        -------
        attrs : dict
            ``{id: {"title": ..., "type": ..., "mode": ...}}`` where the
            mode is that of the enclosing <attributes> element.
        defaults : dict
            ``{title: default value}`` for declarations with a <default>
            child, converted to the declared type.
        """
        attrs = {}
        defaults = {}
        mode = attributes_element.get("mode")
        for k in attributes_element.findall(f"{{{self.NS_GEXF}}}attribute"):
            attr_id = k.get("id")
            title = k.get("title")
            atype = k.get("type")
            attrs[attr_id] = {"title": title, "type": atype, "mode": mode}
            # check for the 'default' subelement of key element and add
            default = k.find(f"{{{self.NS_GEXF}}}default")
            if default is not None:
                if atype == "boolean":
                    value = self.convert_bool[default.text]
                else:
                    value = self.python_type[atype](default.text)
                defaults[title] = value
        return attrs, defaults

1014 

1015 

def relabel_gexf_graph(G):
    """Relabel graph using "label" node keyword for node label.

    Parameters
    ----------
    G : graph
        A NetworkX graph read from GEXF data

    Returns
    -------
    H : graph
        A NetworkX graph with relabeled nodes

    Raises
    ------
    NetworkXError
        If node labels are missing or not unique while relabel=True.

    Notes
    -----
    This function relabels the nodes in a NetworkX graph with the
    "label" attribute. The former node name is kept in the "id" node
    attribute and the "label" attribute is removed. The GEXF-specific
    node attributes "parents" and "pid" are translated to the new labels
    as well. A graph without nodes is relabeled to an empty graph.
    """
    # build mapping of node labels, do some error checking
    try:
        mapping = {u: G.nodes[u]["label"] for u in G}
    except KeyError as err:
        raise nx.NetworkXError(
            "Failed to relabel nodes: missing node labels found. Use relabel=False."
        ) from err
    # Comparing sizes (rather than unpacking zip(*pairs)) also works when
    # the graph has no nodes.
    if len(set(mapping.values())) != len(G):
        raise nx.NetworkXError(
            "Failed to relabel nodes: "
            "duplicate node labels found. "
            "Use relabel=False."
        )
    H = nx.relabel_nodes(G, mapping)
    # relabel attributes
    for n in G:
        m = mapping[n]
        node_data = H.nodes[m]
        node_data["id"] = n
        node_data.pop("label")
        if "pid" in node_data:
            node_data["pid"] = mapping[G.nodes[n]["pid"]]
        if "parents" in node_data:
            node_data["parents"] = [mapping[p] for p in G.nodes[n]["parents"]]
    return H