1from __future__ import absolute_import
2# Copyright (c) 2010-2015 openpyxl
3
4"""Implements the lxml.etree.xmlfile API using the standard library xml.etree"""
5
6
7from contextlib import contextmanager
8
9from xml.etree.ElementTree import (
10 Element,
11 _escape_cdata,
12)
13
14from . import incremental_tree
15
16
class LxmlSyntaxError(Exception):
    """Raised by the incremental writer when calls would produce an invalid document.

    Within this module it is raised for text written outside any open
    element, for a second top-level element, and for a writer context that
    finishes without a root element.
    """
19
20
class _IncrementalFileWriter(object):
    """Stand-in for lxml's ``_IncrementalFileWriter`` built on ``xml.etree``.

    Output is produced by calling ``output_file`` with one string fragment at
    a time. Each open element pushes a ``(nsmap_scope,
    default_ns_attr_prefix, uri_to_prefix)`` tuple onto ``_element_stack`` so
    that nested elements and written subtrees start from the namespace state
    of their parent.
    """

    def __init__(self, output_file):
        # ``output_file`` is used as a callable that receives each fragment.
        self._file = output_file
        self._element_stack = []
        self._have_root = False
        self.is_html = False
        self.global_nsmap = incremental_tree.current_global_nsmap()

    @staticmethod
    def _normalise_default_prefix(nsmap):
        """Return *nsmap* with lxml's ``None`` default prefix renamed to ``""``.

        A falsy mapping, or one without a ``None`` key, is returned unchanged.
        Otherwise a copy is modified, so the caller's mapping is left alone.

        Raises:
            ValueError: if both ``None`` and ``""`` are present and map to
                different URIs.
        """
        if not nsmap or None not in nsmap:
            return nsmap
        if "" in nsmap and nsmap[""] != nsmap[None]:
            raise ValueError(
                'Found None and "" as default nsmap prefixes with different URIs'
            )
        normalised = nsmap.copy()
        normalised[""] = normalised.pop(None)
        return normalised

    def _enclosing_scope(self):
        """Return ``(is_root, nsmap_scope, default_ns_attr_prefix, uri_to_prefix)``.

        With no open element the position is the document root and fresh,
        empty namespace state is returned. Otherwise the state recorded for
        the innermost open element is returned.
        """
        if not self._element_stack:
            return True, {}, None, {}
        nsmap_scope, default_ns_attr_prefix, uri_to_prefix = self._element_stack[-1]
        return False, nsmap_scope, default_ns_attr_prefix, uri_to_prefix

    @contextmanager
    def element(self, tag, attrib=None, nsmap=None, **_extra):
        """Open an XML element for the duration of a ``with`` block.

        The start tag is written on entry and the end tag when the block
        finishes normally. Nothing is yielded to the ``with`` target.

        Args:
            tag: element tag, passed to ``xml.etree.ElementTree.Element``.
            attrib: optional attribute mapping; ``None`` means no attributes.
            nsmap: optional prefix -> URI mapping; a ``None`` prefix is
                treated as the default namespace prefix ``""``.
            **_extra: further attributes forwarded to ``Element``.

        Raises:
            ValueError: if *nsmap* has both ``None`` and ``""`` prefixes with
                different URIs.
        """
        nsmap = self._normalise_default_prefix(nsmap)

        # --- entry: emit the start tag and record the new namespace state ---
        self._have_root = True
        node = Element(tag, attrib={} if attrib is None else attrib, **_extra)
        node.text = ''
        node.tail = ''
        is_root, parent_scope, parent_attr_prefix, parent_prefixes = (
            self._enclosing_scope()
        )
        (
            written_tag,
            child_scope,
            child_attr_prefix,
            child_prefixes,
            _next_remains_root,
        ) = incremental_tree.write_elem_start(
            self._file,
            node,
            nsmap_scope=parent_scope,
            global_nsmap=self.global_nsmap,
            short_empty_elements=False,
            is_html=self.is_html,
            is_root=is_root,
            uri_to_prefix=parent_prefixes,
            default_ns_attr_prefix=parent_attr_prefix,
            new_nsmap=nsmap,
        )
        self._element_stack.append((child_scope, child_attr_prefix, child_prefixes))
        yield

        # --- exit: close the element using the tag name that was written ---
        self._element_stack.pop()
        self._file(f"</{written_tag}>")
        tail = node.tail
        if tail:
            self._file(_escape_cdata(tail))

    def write(self, arg):
        """Write escaped text, or serialise a subelement, at the current position.

        Args:
            arg: a ``str`` written as escaped character data, or any other
                object, which is handed to
                ``incremental_tree._serialize_ns_xml`` as an element.

        Raises:
            LxmlSyntaxError: if a string is written while no element is open,
                or a non-string is written at top level after a root element
                has already been started.
        """
        inside_element = bool(self._element_stack)

        if isinstance(arg, str):
            # it is not allowed to write a string outside of an element
            if not inside_element:
                raise LxmlSyntaxError()
            self._file(_escape_cdata(arg))
            return

        # A document has a single root: refuse a second top-level element.
        if self._have_root and not inside_element:
            raise LxmlSyntaxError()

        is_root, scope, attr_prefix, prefixes = self._enclosing_scope()
        incremental_tree._serialize_ns_xml(
            self._file,
            arg,
            nsmap_scope=scope,
            global_nsmap=self.global_nsmap,
            short_empty_elements=True,
            is_html=self.is_html,
            is_root=is_root,
            uri_to_prefix=prefixes,
            default_ns_attr_prefix=attr_prefix,
        )

    def __enter__(self):
        """Do nothing on entry; returns ``None``."""
        return None

    def __exit__(self, type, value, traceback):
        """Raise ``LxmlSyntaxError`` if no root element was ever started."""
        # without root the xml document is incomplete
        if not self._have_root:
            raise LxmlSyntaxError()
139
140
class xmlfile(object):
    """Context manager that can replace lxml.etree.xmlfile.

    Entering the context returns an ``_IncrementalFileWriter`` bound to the
    writer obtained from ``incremental_tree._get_writer``.

    Args:
        output_file: destination handed to ``incremental_tree._get_writer``.
        buffered: accepted for signature compatibility; not used by this
            implementation.
        encoding: encoding name forwarded to ``_get_writer``.
        close: when true, ``output_file.close()`` is called on exit.
    """

    def __init__(self, output_file, buffered=False, encoding="utf-8", close=False):
        self._file = output_file
        self._close = close
        self.encoding = encoding
        # Context manager returned by ``_get_writer``; set in ``__enter__``.
        self.writer_cm = None

    def __enter__(self):
        """Open the underlying writer and return the incremental file writer."""
        self.writer_cm = incremental_tree._get_writer(self._file, encoding=self.encoding)
        # The writer context yields a (write callable, declared encoding)
        # pair; only the callable is needed here.
        writer, _declared_encoding = self.writer_cm.__enter__()
        return _IncrementalFileWriter(writer)

    def __exit__(self, type, value, traceback):
        """Leave the writer context and, if requested, close the output file.

        The file is closed even when leaving the writer context raises, so a
        failing flush cannot leak the handle the caller asked us to close.
        """
        try:
            if self.writer_cm:
                self.writer_cm.__exit__(type, value, traceback)
        finally:
            if self._close:
                self._file.close()