1###############################################################################
2#
# Shape - A class to represent Excel XLSX shape objects.
4#
5# SPDX-License-Identifier: BSD-2-Clause
6#
7# Copyright (c) 2013-2025, John McNamara, jmcnamara@cpan.org
8#
9import copy
10from warnings import warn
11
12from xlsxwriter.color import Color
13
14
class Shape:
    """
    A class to represent Excel XLSX shape objects.

    Holds the position/connection state of a shape plus its formatting.
    User-supplied formatting options (align, fill, font, gradient,
    line/border, text_rotation, textlink) are normalized into internal
    dicts by the static ``_get_*_properties`` helpers below.

    """

    ###########################################################################
    #
    # Public API.
    #
    ###########################################################################

    def __init__(self, shape_type, name, options):
        """
        Constructor.

        Args:
            shape_type: Shape type, stored unchanged in ``self.shape_type``.
            name:       Shape name, stored unchanged in ``self.name``.
            options:    Dict of user formatting options, processed by
                        ``_set_options()``.

        """
        super().__init__()
        self.name = name
        self.shape_type = shape_type
        self.connect = 0
        self.drawing = 0
        self.edit_as = ""
        self.id = 0
        self.text = ""
        self.textlink = ""
        self.stencil = 1
        self.element = -1
        self.start = None
        self.start_index = None
        self.end = None
        self.end_index = None
        self.adjustments = []
        self.start_side = ""
        self.end_side = ""
        self.flip_h = 0
        self.flip_v = 0
        self.rotation = 0
        self.text_rotation = 0
        self.textbox = False

        # Formatting properties, populated by _set_options().
        self.align = None
        self.fill = None
        self.font = None
        self.format = None
        self.gradient = None
        self.line = None

        self._set_options(options)

    ###########################################################################
    #
    # Private API.
    #
    ###########################################################################

    def _set_options(self, options):
        # Normalize the user formatting options and store them on the shape.
        self.align = self._get_align_properties(options.get("align"))
        self.fill = self._get_fill_properties(options.get("fill"))
        self.font = self._get_font_properties(options.get("font"))
        self.gradient = self._get_gradient_properties(options.get("gradient"))
        self.line = self._get_line_properties(options.get("line"))

        self.text_rotation = options.get("text_rotation", 0)

        # Store the textlink without any leading "=" (e.g. "=Sheet1!A1" is
        # stored as "Sheet1!A1"). An explicit None is treated like an
        # omitted textlink instead of failing on .startswith().
        self.textlink = options.get("textlink") or ""
        if self.textlink.startswith("="):
            self.textlink = self.textlink.lstrip("=")

        # A "border" option, when given, replaces the "line" properties.
        if options.get("border"):
            self.line = self._get_line_properties(options["border"])

        # Gradient fill overrides solid fill.
        if self.gradient:
            self.fill = None

    ###########################################################################
    #
    # Static methods for processing chart/shape style properties.
    #
    ###########################################################################

    @staticmethod
    def _get_line_properties(line):
        # Convert user line properties to the structure required internally.
        #
        # Returns {"defined": False} when no line is given, the (copied) user
        # dict with "defined": True on success, and an empty dict after
        # warning about an unknown dash_type.
        # NOTE(review): the empty-dict error result differs from the
        # {"defined": False} "unset" result; confirm callers handle both.

        if not line:
            return {"defined": False}

        # Copy the user defined properties since they will be modified.
        line = copy.deepcopy(line)

        # Map user dash type names to their DrawingML preset names.
        dash_types = {
            "solid": "solid",
            "round_dot": "sysDot",
            "square_dot": "sysDash",
            "dash": "dash",
            "dash_dot": "dashDot",
            "long_dash": "lgDash",
            "long_dash_dot": "lgDashDot",
            "long_dash_dot_dot": "lgDashDotDot",
            "dot": "dot",
            "system_dash_dot": "sysDashDot",
            "system_dash_dot_dot": "sysDashDotDot",
        }

        # Check the dash type.
        dash_type = line.get("dash_type")

        if dash_type is not None:
            if dash_type in dash_types:
                line["dash_type"] = dash_types[dash_type]
            else:
                warn(f"Unknown dash type '{dash_type}'")
                return {}

        if line.get("color"):
            line["color"] = Color._from_value(line["color"])

        line["defined"] = True

        return line

    @staticmethod
    def _get_fill_properties(fill):
        # Convert user fill properties to the structure required internally.
        #
        # Returns {"defined": False} when no fill is given, otherwise a copy
        # of the user dict with "color" converted and "defined": True.

        if not fill:
            return {"defined": False}

        # Copy the user defined properties since they will be modified.
        fill = copy.deepcopy(fill)

        if fill.get("color"):
            fill["color"] = Color._from_value(fill["color"])

        fill["defined"] = True

        return fill

    @staticmethod
    def _get_pattern_properties(pattern):
        # Convert user defined pattern to the structure required internally.
        #
        # Requires "pattern" and "fg_color" keys; "bg_color" defaults to
        # white. Returns an empty dict (after a warning) on invalid input.

        if not pattern:
            return {}

        # Copy the user defined properties since they will be modified.
        pattern = copy.deepcopy(pattern)

        if not pattern.get("pattern"):
            warn("Pattern must include 'pattern'")
            return {}

        if not pattern.get("fg_color"):
            warn("Pattern must include 'fg_color'")
            return {}

        # Map user pattern names to their DrawingML preset names.
        types = {
            "percent_5": "pct5",
            "percent_10": "pct10",
            "percent_20": "pct20",
            "percent_25": "pct25",
            "percent_30": "pct30",
            "percent_40": "pct40",
            "percent_50": "pct50",
            "percent_60": "pct60",
            "percent_70": "pct70",
            "percent_75": "pct75",
            "percent_80": "pct80",
            "percent_90": "pct90",
            "light_downward_diagonal": "ltDnDiag",
            "light_upward_diagonal": "ltUpDiag",
            "dark_downward_diagonal": "dkDnDiag",
            "dark_upward_diagonal": "dkUpDiag",
            "wide_downward_diagonal": "wdDnDiag",
            "wide_upward_diagonal": "wdUpDiag",
            "light_vertical": "ltVert",
            "light_horizontal": "ltHorz",
            "narrow_vertical": "narVert",
            "narrow_horizontal": "narHorz",
            "dark_vertical": "dkVert",
            "dark_horizontal": "dkHorz",
            "dashed_downward_diagonal": "dashDnDiag",
            "dashed_upward_diagonal": "dashUpDiag",
            "dashed_horizontal": "dashHorz",
            "dashed_vertical": "dashVert",
            "small_confetti": "smConfetti",
            "large_confetti": "lgConfetti",
            "zigzag": "zigZag",
            "wave": "wave",
            "diagonal_brick": "diagBrick",
            "horizontal_brick": "horzBrick",
            "weave": "weave",
            "plaid": "plaid",
            "divot": "divot",
            "dotted_grid": "dotGrid",
            "dotted_diamond": "dotDmnd",
            "shingle": "shingle",
            "trellis": "trellis",
            "sphere": "sphere",
            "small_grid": "smGrid",
            "large_grid": "lgGrid",
            "small_check": "smCheck",
            "large_check": "lgCheck",
            "outlined_diamond": "openDmnd",
            "solid_diamond": "solidDmnd",
        }

        # Check for valid types.
        if pattern["pattern"] not in types:
            warn(f"unknown pattern type '{pattern['pattern']}'")
            return {}

        pattern["pattern"] = types[pattern["pattern"]]

        # fg_color was validated as present and truthy above.
        pattern["fg_color"] = Color._from_value(pattern["fg_color"])

        if pattern.get("bg_color"):
            pattern["bg_color"] = Color._from_value(pattern["bg_color"])
        else:
            pattern["bg_color"] = Color("#FFFFFF")

        return pattern

    @staticmethod
    def _get_gradient_properties(gradient):
        # pylint: disable=too-many-return-statements
        # Convert user defined gradient to the structure required internally.
        #
        # Validates "colors" (a list of 2-10 entries), "positions" (one per
        # color, each 0-100; defaulted for 2-4 colors), "angle" (0 <= angle
        # < 360, default 90) and "type" (default "linear"). Returns an empty
        # dict (after a warning) on invalid input.

        if not gradient:
            return {}

        # Copy the user defined properties since they will be modified.
        gradient = copy.deepcopy(gradient)

        # Map user gradient type names to their DrawingML names.
        types = {
            "linear": "linear",
            "radial": "circle",
            "rectangular": "rect",
            "path": "shape",
        }

        # Check the colors array exists and is valid.
        if "colors" not in gradient or not isinstance(gradient["colors"], list):
            warn("Gradient must include colors list")
            return {}

        # Check the colors array has the required number of entries.
        if not 2 <= len(gradient["colors"]) <= 10:
            warn("Gradient colors list must have at least 2 values and not more than 10")
            return {}

        if "positions" in gradient:
            # Check the positions array has the right number of entries.
            if len(gradient["positions"]) != len(gradient["colors"]):
                warn("Gradient positions not equal to number of colors")
                return {}

            # Check the positions are in the correct range.
            for pos in gradient["positions"]:
                if not 0 <= pos <= 100:
                    warn("Gradient position must be in the range 0 <= position <= 100")
                    return {}
        else:
            # Use the default gradient positions, which only exist for
            # 2, 3 or 4 colors.
            default_positions = {
                2: [0, 100],
                3: [0, 50, 100],
                4: [0, 33, 66, 100],
            }
            positions = default_positions.get(len(gradient["colors"]))

            if positions is None:
                warn("Must specify gradient positions")
                return {}

            gradient["positions"] = positions

        # An explicit angle of 0 is valid and must not be replaced by the
        # default, so test for None rather than truthiness.
        angle = gradient.get("angle")
        if angle is not None:
            if not 0 <= angle < 360:
                warn("Gradient angle must be in the range 0 <= angle < 360")
                return {}
        else:
            gradient["angle"] = 90

        # Check for valid types.
        gradient_type = gradient.get("type")

        if gradient_type is not None:
            if gradient_type in types:
                gradient["type"] = types[gradient_type]
            else:
                warn(f"Unknown gradient type '{gradient_type}'")
                return {}
        else:
            gradient["type"] = "linear"

        gradient["colors"] = [Color._from_value(color) for color in gradient["colors"]]

        return gradient

    @staticmethod
    def _get_font_properties(options):
        # Convert user defined font values into private dict values.
        #
        # Every key is always present in the result. "size" defaults to 11
        # and is scaled by 100 (so 11 becomes 1100), "baseline" defaults to
        # -1 (meaning "not set") and "lang" defaults to "en-US".
        if options is None:
            options = {}

        font = {
            "name": options.get("name"),
            "color": options.get("color"),
            "size": options.get("size", 11),
            "bold": options.get("bold"),
            "italic": options.get("italic"),
            "underline": options.get("underline"),
            "pitch_family": options.get("pitch_family"),
            "charset": options.get("charset"),
            "baseline": options.get("baseline", -1),
            "lang": options.get("lang", "en-US"),
        }

        # Convert font size units.
        if font["size"]:
            font["size"] = int(font["size"] * 100)

        if font.get("color"):
            font["color"] = Color._from_value(font["color"])

        return font

    @staticmethod
    def _get_font_style_attributes(font):
        # Build the (name, value) style attribute pairs for a font dict:
        # "sz", "b", "i", "u" and "baseline", in that order, each only when
        # set.
        attributes = []

        if not font:
            return attributes

        if font.get("size"):
            attributes.append(("sz", font["size"]))

        if font.get("bold") is not None:
            attributes.append(("b", 0 + font["bold"]))

        if font.get("italic") is not None:
            attributes.append(("i", 0 + font["italic"]))

        # Only a truthy underline value requests an underline; an explicit
        # False/0 must not produce u="sng".
        if font.get("underline"):
            attributes.append(("u", "sng"))

        # -1 is the "not set" sentinel; a missing key is treated the same.
        if font.get("baseline", -1) != -1:
            attributes.append(("baseline", font["baseline"]))

        return attributes

    @staticmethod
    def _get_font_latin_attributes(font):
        # Build the (name, value) latin-font attribute pairs ("typeface",
        # "pitchFamily", "charset") for whichever of them are set.
        attributes = []

        if not font:
            return attributes

        if font.get("name") is not None:
            attributes.append(("typeface", font["name"]))

        if font.get("pitch_family") is not None:
            attributes.append(("pitchFamily", font["pitch_family"]))

        if font.get("charset") is not None:
            attributes.append(("charset", font["charset"]))

        return attributes

    @staticmethod
    def _get_align_properties(align):
        # Convert user defined align to the structure required internally.
        #
        # Accepts optional "vertical" (top/middle/bottom) and "horizontal"
        # (left/center/right) keys. Returns {"defined": False} when no align
        # is given or when a value is unknown (after a warning).
        if not align:
            return {"defined": False}

        # Copy the user defined properties since they will be modified.
        align = copy.deepcopy(align)

        if "vertical" in align:
            align_type = align["vertical"]

            align_types = {
                "top": "top",
                "middle": "middle",
                "bottom": "bottom",
            }

            if align_type in align_types:
                align["vertical"] = align_types[align_type]
            else:
                warn(f"Unknown alignment type '{align_type}'")
                return {"defined": False}

        if "horizontal" in align:
            align_type = align["horizontal"]

            align_types = {
                "left": "left",
                "center": "center",
                "right": "right",
            }

            if align_type in align_types:
                align["horizontal"] = align_types[align_type]
            else:
                warn(f"Unknown alignment type '{align_type}'")
                return {"defined": False}

        align["defined"] = True

        return align