1# Copyright (c) 2010-2024 openpyxl
2
3from warnings import warn
4
5from openpyxl.descriptors.serialisable import Serialisable
6from openpyxl.descriptors import (
7 Typed,
8)
9from openpyxl.descriptors.sequence import NestedSequence
10from openpyxl.descriptors.excel import ExtensionList
11from openpyxl.utils.indexed_list import IndexedList
12from openpyxl.xml.constants import ARC_STYLE, SHEET_MAIN_NS
13from openpyxl.xml.functions import fromstring
14
15from .builtins import styles
16from .colors import ColorList
17from .differential import DifferentialStyle
18from .table import TableStyleList
19from .borders import Border
20from .fills import Fill
21from .fonts import Font
22from .numbers import (
23 NumberFormatList,
24 BUILTIN_FORMATS,
25 BUILTIN_FORMATS_MAX_SIZE,
26 BUILTIN_FORMATS_REVERSE,
27 is_date_format,
28 is_timedelta_format,
29 builtin_format_code
30)
31from .named_styles import (
32 _NamedCellStyleList,
33 NamedStyleList,
34 NamedStyle,
35)
36from .cell_style import CellStyle, CellStyleList
37
38
class Stylesheet(Serialisable):
    """
    Object model of the SpreadsheetML ``styleSheet`` element.

    Besides the serialised children listed in ``__elements__``, construction
    derives several read-side helper attributes used by ``apply_stylesheet``:

    * ``cell_styles`` -- ``cellXfs`` converted via ``CellStyleList._to_array()``
    * ``alignments`` / ``protections`` -- taken from ``cellXfs``
    * ``number_formats`` -- ``IndexedList`` of custom format codes that do not
      match a builtin (filled by ``_normalise_numbers``)
    * ``date_formats`` / ``timedelta_formats`` -- sets of indices into
      ``cell_styles`` whose number format is a date / timedelta format
    * ``named_styles`` -- ``NamedStyleList`` built by ``_merge_named_styles``
    """

    tagname = "styleSheet"

    numFmts = Typed(expected_type=NumberFormatList)
    fonts = NestedSequence(expected_type=Font, count=True)
    fills = NestedSequence(expected_type=Fill, count=True)
    borders = NestedSequence(expected_type=Border, count=True)
    cellStyleXfs = Typed(expected_type=CellStyleList)
    cellXfs = Typed(expected_type=CellStyleList)
    cellStyles = Typed(expected_type=_NamedCellStyleList)
    dxfs = NestedSequence(expected_type=DifferentialStyle, count=True)
    tableStyles = Typed(expected_type=TableStyleList, allow_none=True)
    colors = Typed(expected_type=ColorList, allow_none=True)
    extLst = Typed(expected_type=ExtensionList, allow_none=True)

    # extLst is declared but absent from __elements__ (and never assigned in
    # __init__), so extension lists are neither kept nor written back.
    __elements__ = ('numFmts', 'fonts', 'fills', 'borders', 'cellStyleXfs',
                    'cellXfs', 'cellStyles', 'dxfs', 'tableStyles', 'colors')

    def __init__(self,
                 numFmts=None,
                 fonts=(),
                 fills=(),
                 borders=(),
                 cellStyleXfs=None,
                 cellXfs=None,
                 cellStyles=None,
                 dxfs=(),
                 tableStyles=None,
                 colors=None,
                 extLst=None,
                 ):
        # Container children default to fresh, empty instances per stylesheet.
        if numFmts is None:
            numFmts = NumberFormatList()
        self.numFmts = numFmts
        # Custom format codes collected by _normalise_numbers.
        self.number_formats = IndexedList()
        self.fonts = fonts
        self.fills = fills
        self.borders = borders
        if cellStyleXfs is None:
            cellStyleXfs = CellStyleList()
        self.cellStyleXfs = cellStyleXfs
        if cellXfs is None:
            cellXfs = CellStyleList()
        self.cellXfs = cellXfs
        if cellStyles is None:
            cellStyles = _NamedCellStyleList()
        self.cellStyles = cellStyles

        self.dxfs = dxfs
        self.tableStyles = tableStyles
        self.colors = colors
        # extLst is accepted but intentionally not stored (see __elements__).

        # Derived read-side views. cell_styles must exist before
        # _normalise_numbers, which iterates and rewrites it in place.
        self.cell_styles = self.cellXfs._to_array()
        self.alignments = self.cellXfs.alignments
        self.protections = self.cellXfs.prots
        self._normalise_numbers()
        self.named_styles = self._merge_named_styles()


    @classmethod
    def from_tree(cls, node):
        """
        Build a Stylesheet from a parsed ``styleSheet`` element.

        Every attribute on *node* is removed first (this mutates *node*), so
        only child elements contribute to the result.
        NOTE(review): presumably this stops root-level namespace /
        markup-compatibility attributes from reaching ``__init__`` -- confirm.
        """
        # Iterate over a snapshot so deleting from node.attrib is safe.
        attrs = dict(node.attrib)
        for k in attrs:
            del node.attrib[k]
        return super().from_tree(node)


    def _merge_named_styles(self):
        """
        Merge named style names ("cellStyles") with their associated
        formatting records ("cellStyleXfs").

        Duplicate ``cellStyles`` entries are dropped via
        ``remove_duplicates()`` before expansion.

        :returns: a ``NamedStyleList`` of fully bound ``NamedStyle`` objects
        """
        style_refs = self.cellStyles.remove_duplicates()
        from_ref = [self._expand_named_style(style_ref) for style_ref in style_refs]

        return NamedStyleList(from_ref)


    def _expand_named_style(self, style_ref):
        """
        Expand a named style reference element into a ``NamedStyle``.

        The font, fill, border, number format, alignment and protection are
        bound from the ``cellStyleXfs`` record at ``style_ref.xfId``.

        Number formats declared in ``<numFmts>`` take precedence over
        builtins, even for ids below ``BUILTIN_FORMATS_MAX_SIZE``. This
        matches the precedence used for cell styles in
        ``_normalise_numbers``. Previously the lookup table was chosen purely
        by id, so a custom id below 164 that is not a builtin was silently
        dropped.
        """
        xf = self.cellStyleXfs[style_ref.xfId]
        named_style = NamedStyle(
            name=style_ref.name,
            hidden=style_ref.hidden,
            builtinId=style_ref.builtinId,
        )

        named_style.font = self.fonts[xf.fontId]
        named_style.fill = self.fills[xf.fillId]
        named_style.border = self.borders[xf.borderId]

        fmt_id = xf.numFmtId
        custom = self.custom_formats
        if fmt_id in custom:
            named_style.number_format = custom[fmt_id]
        elif fmt_id in BUILTIN_FORMATS:
            named_style.number_format = BUILTIN_FORMATS[fmt_id]
        # Otherwise the id is unknown: keep NamedStyle's default format.

        if xf.alignment:
            named_style.alignment = xf.alignment
        if xf.protection:
            named_style.protection = xf.protection

        return named_style


    def _split_named_styles(self, wb):
        """
        Convert NamedStyle objects into separate CellStyle and Xf objects.

        Used when writing: for each style in ``wb._named_styles``, append its
        name record (``as_name()``) to ``cellStyles`` and its formatting
        record (``as_xf()``) to ``cellStyleXfs``. Mutates this stylesheet.
        """
        for style in wb._named_styles:
            self.cellStyles.cellStyle.append(style.as_name())
            self.cellStyleXfs.xf.append(style.as_xf())


    @property
    def custom_formats(self):
        """Mapping of ``numFmtId`` -> format code for every ``<numFmt>`` entry."""
        return {n.numFmtId: n.formatCode for n in self.numFmts.numFmt}


    def _normalise_numbers(self):
        """
        Rebase custom numFmtIds when reading the stylesheet, and index
        date and timedelta formats.

        For each entry in ``cell_styles`` (mutated in place):

        * a custom format whose code equals a builtin code is mapped back to
          that builtin id;
        * any other custom code is added to ``self.number_formats`` and gets
          id ``BUILTIN_FORMATS_MAX_SIZE + <index in number_formats>``;
        * otherwise the id is resolved with ``builtin_format_code``.

        Sets ``self.date_formats`` and ``self.timedelta_formats`` to the
        indices (into ``cell_styles``) whose format is a date / timedelta.
        """
        date_formats = set()
        timedelta_formats = set()
        custom = self.custom_formats
        formats = self.number_formats
        for idx, style in enumerate(self.cell_styles):
            if style.numFmtId in custom:
                fmt = custom[style.numFmtId]
                if fmt in BUILTIN_FORMATS_REVERSE:  # remove builtins
                    style.numFmtId = BUILTIN_FORMATS_REVERSE[fmt]
                else:
                    style.numFmtId = formats.add(fmt) + BUILTIN_FORMATS_MAX_SIZE
            else:
                fmt = builtin_format_code(style.numFmtId)
            if is_date_format(fmt):
                # Create an index of which styles refer to datetimes
                date_formats.add(idx)
            if is_timedelta_format(fmt):
                # Create an index of which styles refer to timedeltas
                timedelta_formats.add(idx)
        self.date_formats = date_formats
        self.timedelta_formats = timedelta_formats


    def to_tree(self, tagname=None, idx=None, namespace=None):
        """Serialise to XML and declare SHEET_MAIN_NS as the default namespace."""
        tree = super().to_tree(tagname, idx, namespace)
        tree.set("xmlns", SHEET_MAIN_NS)
        return tree
197
198
def apply_stylesheet(archive, wb):
    """
    Add styles to a workbook if the archive contains a stylesheet.

    :param archive: object with a ``read(name)`` method that raises
        ``KeyError`` when the stylesheet member is missing (e.g. a
        ``zipfile.ZipFile``)
    :param wb: the workbook being loaded; its private style tables
        (fonts, fills, borders, cell/named styles, formats, ...) are replaced
        with those read from the stylesheet
    :returns: *wb* on every path. Previously only the "no stylesheet" path
        returned it and the normal path returned ``None``.
    """
    try:
        src = archive.read(ARC_STYLE)
    except KeyError:
        return wb

    node = fromstring(src)
    stylesheet = Stylesheet.from_tree(node)

    if stylesheet.cell_styles:

        wb._borders = IndexedList(stylesheet.borders)
        wb._fonts = IndexedList(stylesheet.fonts)
        wb._fills = IndexedList(stylesheet.fills)
        wb._differential_styles.styles = stylesheet.dxfs
        wb._number_formats = stylesheet.number_formats
        wb._protections = stylesheet.protections
        wb._alignments = stylesheet.alignments
        wb._table_styles = stylesheet.tableStyles

        # need to overwrite openpyxl defaults in case workbook has different ones
        wb._cell_styles = stylesheet.cell_styles
        wb._named_styles = stylesheet.named_styles
        wb._date_formats = stylesheet.date_formats
        wb._timedelta_formats = stylesheet.timedelta_formats

        for ns in wb._named_styles:
            ns.bind(wb)

    else:
        warn("Workbook contains no stylesheet, using openpyxl's defaults")

    if not wb._named_styles:
        # NOTE(review): styles['Normal'] is a module-level object shared by
        # every caller; binding it here ties that shared instance to this
        # workbook. Consider adding a copy instead -- confirm NamedStyle
        # copy semantics first.
        normal = styles['Normal']
        wb.add_named_style(normal)
        warn("Workbook contains no default style, apply openpyxl's default")

    if stylesheet.colors is not None:
        wb._colors = stylesheet.colors.index

    return wb
241
242
def write_stylesheet(wb):
    """
    Build the ``styleSheet`` XML element for a workbook.

    :param wb: the workbook whose private style tables (fonts, fills,
        borders, differential styles, colours, number formats, cell styles,
        named styles, table styles) are serialised
    :returns: the element produced by ``Stylesheet.to_tree()``
    """
    from .numbers import NumberFormat

    stylesheet = Stylesheet()
    stylesheet.fonts = wb._fonts
    stylesheet.fills = wb._fills
    stylesheet.borders = wb._borders
    stylesheet.dxfs = wb._differential_styles.styles
    stylesheet.colors = ColorList(indexedColors=wb._colors)

    # Custom format codes are numbered from BUILTIN_FORMATS_MAX_SIZE upwards,
    # mirroring how Stylesheet._normalise_numbers assigns ids on read.
    stylesheet.numFmts.numFmt = [
        NumberFormat(fmt_id, code)
        for fmt_id, code in enumerate(wb._number_formats, BUILTIN_FORMATS_MAX_SIZE)
    ]

    xfs = []
    for style in wb._cell_styles:
        xf = CellStyle.from_array(style)

        # Index 0 is skipped: only non-zero alignment/protection ids are
        # attached to the written xf.
        if style.alignmentId:
            xf.alignment = wb._alignments[style.alignmentId]

        if style.protectionId:
            xf.protection = wb._protections[style.protectionId]
        xfs.append(xf)
    stylesheet.cellXfs = CellStyleList(xf=xfs)

    # Named styles are split back into cellStyles + cellStyleXfs records.
    stylesheet._split_named_styles(wb)
    stylesheet.tableStyles = wb._table_styles

    return stylesheet.to_tree()