1# Copyright (c) 2010-2024 openpyxl
2
3from collections import OrderedDict
4from operator import attrgetter
5
6from openpyxl.descriptors import (
7 Typed,
8 Integer,
9 Alias,
10 MinMax,
11 Bool,
12 Set,
13)
14from openpyxl.descriptors.sequence import ValueSequence
15from openpyxl.descriptors.serialisable import Serialisable
16
17from ._3d import _3DBase
18from .data_source import AxDataSource, NumRef
19from .layout import Layout
20from .legend import Legend
21from .reference import Reference
22from .series_factory import SeriesFactory
23from .series import attribute_mapping
24from .shapes import GraphicalProperties
25from .title import TitleDescriptor
26
class AxId(Serialisable):
    """
    Serialisable holder for a single integer axis identifier.

    The identifier lives in ``val``, which is stored through an
    ``Integer`` descriptor.
    """

    val = Integer()

    def __init__(self, val):
        # Assign via the descriptor so it handles the value.
        self.val = val
33
34
def PlotArea():
    """
    Return a new, default-constructed ``chartspace.PlotArea``.

    The import happens at call time rather than at module import time.
    NOTE(review): presumably this avoids a circular import between this
    module and ``.chartspace`` -- confirm.
    """
    from .chartspace import PlotArea as _ChartspacePlotArea
    plot_area = _ChartspacePlotArea()
    return plot_area
38
39
class ChartBase(Serialisable):

    """
    Base class for all charts
    """

    legend = Typed(expected_type=Legend, allow_none=True)
    layout = Typed(expected_type=Layout, allow_none=True)
    roundedCorners = Bool(allow_none=True)
    axId = ValueSequence(expected_type=int)
    visible_cells_only = Bool(allow_none=True)
    display_blanks = Set(values=['span', 'gap', 'zero'])
    # Declared once; written to the chartspace ``spPr`` in ``_write``.
    graphical_properties = Typed(expected_type=GraphicalProperties, allow_none=True)

    _series_type = ""  # key into ``attribute_mapping`` used by ``to_tree``
    ser = ()
    series = Alias('ser')
    title = TitleDescriptor()
    anchor = "E15" # default anchor position
    width = 15 # in cm, approx 5 rows
    height = 7.5 # in cm, approx 14 rows
    _id = 1
    _path = "/xl/charts/chart{0}.xml"
    style = MinMax(allow_none=True, min=1, max=48)
    mime_type = "application/vnd.openxmlformats-officedocument.drawingml.chart+xml"

    __elements__ = ()


    def __init__(self, axId=(), **kw):
        """
        Initialise chart state with defaults.

        :param axId: initial axis ids (recomputed from the chart's axes in
            ``to_tree``).
        :param kw: accepted but ignored here.
            NOTE(review): presumably kept for subclass/constructor
            compatibility -- confirm.
        """
        # A chart starts out "combined" only with itself; ``+=`` appends more.
        self._charts = [self]
        self.title = None
        self.layout = None
        self.roundedCorners = None
        self.legend = Legend()
        self.graphical_properties = None
        self.style = None
        self.plot_area = PlotArea()
        self.axId = axId
        self.display_blanks = 'gap'
        self.pivotSource = None
        self.pivotFormats = ()
        self.visible_cells_only = True
        self.idx_base = 0
        super().__init__()


    def __hash__(self):
        """
        Just need to check for identity
        """
        return id(self)

    def __iadd__(self, other):
        """
        Combine the chart with another one.

        :param other: another ``ChartBase`` instance to draw in the same plot area.
        :raises TypeError: if ``other`` is not a chart.
        :returns: ``self``, so ``chart += other`` keeps the original object.
        """
        if not isinstance(other, ChartBase):
            raise TypeError("Only other charts can be added")
        self._charts.append(other)
        return self


    def to_tree(self, namespace=None, tagname=None, idx=None):
        """
        Serialise this chart to an XML element.

        Before delegating to ``Serialisable.to_tree``, this method does two things:

        - it refreshes ``axId`` from the axes currently attached to the chart;
        - it sets each series' element order according to this chart type.

        NOTE(review): ``namespace`` is accepted but not forwarded to the base
        implementation -- confirm this is intentional.
        """
        # Keys of the ``_axes`` mapping are the axis ids, in x/y/z order.
        self.axId = list(self._axes)
        if self.ser is not None:
            for s in self.ser:
                # Looked up per series so an empty chart never needs a mapping entry.
                s.__elements__ = attribute_mapping[self._series_type]
        return super().to_tree(tagname, idx)


    def _reindex(self):
        """
        Normalise and rebase series: sort by order and then rebase order

        """
        # sort data series in order and rebase
        ds = sorted(self.series, key=attrgetter("order"))
        for idx, s in enumerate(ds):
            s.order = idx
        self.series = ds


    def _write(self):
        """
        Build the enclosing ``ChartSpace`` for this chart and return its XML tree.
        """
        from .chartspace import ChartSpace, ChartContainer
        self.plot_area.layout = self.layout

        # Charts not yet in the plot area receive cumulative series-index
        # offsets. NOTE(review): presumably this keeps series indices distinct
        # across combined charts.
        idx_base = self.idx_base
        for chart in self._charts:
            if chart not in self.plot_area._charts:
                chart.idx_base = idx_base
                idx_base += len(chart.series)
        self.plot_area._charts = self._charts

        container = ChartContainer(plotArea=self.plot_area, legend=self.legend, title=self.title)
        # 3D settings are taken from the LAST chart in ``self._charts``; for a
        # single chart that is ``self``. This was previously implicit, via the
        # loop variable left over from the loop above.
        # NOTE(review): when a 2D chart is combined after a 3D one, the 3D
        # settings are not written -- confirm this is intended.
        last_chart = self._charts[-1]
        if isinstance(last_chart, _3DBase):
            container.view3D = last_chart.view3D
            container.floor = last_chart.floor
            container.sideWall = last_chart.sideWall
            container.backWall = last_chart.backWall
        container.plotVisOnly = self.visible_cells_only
        container.dispBlanksAs = self.display_blanks
        container.pivotFmts = self.pivotFormats
        cs = ChartSpace(chart=container)
        cs.style = self.style
        cs.roundedCorners = self.roundedCorners
        cs.pivotSource = self.pivotSource
        cs.spPr = self.graphical_properties
        return cs.to_tree()


    @property
    def _axes(self):
        """
        Map each defined axis id to its axis, in x, y, z order.

        Subclasses may lack some or all of ``x_axis``/``y_axis``/``z_axis``;
        missing or falsy axes are skipped.
        """
        x = getattr(self, "x_axis", None)
        y = getattr(self, "y_axis", None)
        z = getattr(self, "z_axis", None)
        return OrderedDict([(axis.axId, axis) for axis in (x, y, z) if axis])


    def set_categories(self, labels):
        """
        Set the categories / x-axis values.

        :param labels: a ``Reference`` or a range string; a string is wrapped
            in ``Reference(range_string=...)``.

        The same numeric reference is assigned to every current series.
        """
        if not isinstance(labels, Reference):
            labels = Reference(range_string=labels)
        for s in self.ser:
            s.cat = AxDataSource(numRef=NumRef(f=labels))


    def add_data(self, data, from_rows=False, titles_from_data=False):
        """
        Add a range of data in a single pass.
        The default is to treat each column as a data series.

        :param data: a ``Reference`` or a range string.
        :param from_rows: if true, each row becomes a series instead of each column.
        :param titles_from_data: passed to ``SeriesFactory`` as ``title_from_data``.
        """
        if not isinstance(data, Reference):
            data = Reference(range_string=data)

        values = data.rows if from_rows else data.cols

        for ref in values:
            series = SeriesFactory(ref, title_from_data=titles_from_data)
            self.series.append(series)


    def append(self, value):
        """Append a data series to the chart"""
        # Copy-and-reassign so the new list goes through the ``ser`` descriptor.
        l = self.series[:]
        l.append(value)
        self.series = l


    @property
    def path(self):
        """Package path of this chart part, formatted with ``_id``."""
        return self._path.format(self._id)