1"""
A layoutgrid is an nrows by ncols set of boxes, meant to be used by
`._constrained_layout`; each box is analogous to a subplotspec element of
a gridspec.
5
Each box is defined by left[ncols], right[ncols], bottom[nrows] and top[nrows],
and by two editable margins for each side. The main margin gets its value
set by the size of ticklabels, titles, etc. on each Axes that is in the figure.
The outer margin is the padding around the Axes, and space for any
colorbars.
11
The "inner" widths and heights of these boxes are then constrained to be the
same (relative to the values of `width_ratios[ncols]` and `height_ratios[nrows]`).
14
15The layoutgrid is then constrained to be contained within a parent layoutgrid,
16its column(s) and row(s) specified when it is created.
17"""
18
19import itertools
20import kiwisolver as kiwi
21import logging
22import numpy as np
23
24import matplotlib as mpl
25import matplotlib.patches as mpatches
26from matplotlib.transforms import Bbox
27
28_log = logging.getLogger(__name__)
29
30
class LayoutGrid:
    """
    Analogous to a gridspec, and contained in another LayoutGrid.

    Each column has ``lefts[j]``/``rights[j]`` solver variables and each
    row has ``bottoms[i]``/``tops[i]`` solver variables.  Every column also
    has 'left', 'right', 'leftcb' and 'rightcb' margin variables, and every
    row has 'bottom', 'top', 'bottomcb' and 'topcb' margin variables.  The
    margin variables are kiwisolver edit variables; the last value suggested
    for each is mirrored in ``margin_vals``.
    """

    def __init__(self, parent=None, parent_pos=(0, 0),
                 parent_inner=False, name='', ncols=1, nrows=1,
                 h_pad=None, w_pad=None, width_ratios=None,
                 height_ratios=None):
        """
        Parameters
        ----------
        parent : LayoutGrid or (left, bottom, width, height), optional
            Either the parent grid, in which case this grid shares the
            parent's solver and is registered as its child at *parent_pos*,
            or a rectangle in figure coordinates that contains this grid, in
            which case a new solver is created.  If None, the unit rectangle
            ``(0, 0, 1, 1)`` is used.
        parent_pos : (rows, cols)
            Row(s) and column(s) of *parent* occupied by this grid; each may
            be an int or a sequence of indices (a span).  Only used when
            *parent* is a LayoutGrid.
        parent_inner : bool
            If True, this grid is fitted inside the margins of the parent
            cells rather than to their outer edges.
        name : str
            Name prefix.  A sequential id is appended and, if *parent* is a
            LayoutGrid, the parent's name is prepended.
        ncols, nrows : int
            Number of columns and rows.
        h_pad, w_pad : optional
            Stored as attributes; not used by this class itself.
        width_ratios, height_ratios : array-like, optional
            Relative inner widths of the columns / heights of the rows.
            Default to equal ratios.
        """
        Variable = kiwi.Variable
        self.parent_pos = parent_pos
        self.parent_inner = parent_inner
        self.name = name + seq_id()
        if isinstance(parent, LayoutGrid):
            self.name = f'{parent.name}.{self.name}'
        self.nrows = nrows
        self.ncols = ncols
        self.height_ratios = (np.ones(nrows) if height_ratios is None
                              else np.atleast_1d(height_ratios))
        self.width_ratios = (np.ones(ncols) if width_ratios is None
                             else np.atleast_1d(width_ratios))

        sn = self.name + '_'
        if not isinstance(parent, LayoutGrid):
            # parent can be a rect if not a LayoutGrid; this allows
            # specifying a rectangle to contain the layout, and the grid
            # owns its own solver.
            self.solver = kiwi.Solver()
        else:
            parent.add_child(self, *parent_pos)
            self.solver = parent.solver
        # keep track of artist associated w/ this layout.  Can be none
        self.artists = np.empty((nrows, ncols), dtype=object)
        self.children = np.empty((nrows, ncols), dtype=object)

        self.margins = {}
        self.margin_vals = {}
        sol = self.solver

        # All the boxes in each column share the same left/right margins.
        self.lefts = [Variable(f'{sn}lefts[{i}]') for i in range(ncols)]
        self.rights = [Variable(f'{sn}rights[{i}]') for i in range(ncols)]
        for todo in ['left', 'right', 'leftcb', 'rightcb']:
            # track the value so we can change only if a margin is larger
            # than the current value
            self.margin_vals[todo] = np.zeros(ncols)
            self.margins[todo] = [Variable(f'{sn}margins[{todo}][{i}]')
                                  for i in range(ncols)]
            for var in self.margins[todo]:
                sol.addEditVariable(var, 'strong')

        # All the boxes in each row share the same bottom/top margins.
        self.bottoms = [Variable(f'{sn}bottoms[{i}]') for i in range(nrows)]
        self.tops = [Variable(f'{sn}tops[{i}]') for i in range(nrows)]
        for todo in ['bottom', 'top', 'bottomcb', 'topcb']:
            self.margin_vals[todo] = np.zeros(nrows)
            self.margins[todo] = [Variable(f'{sn}margins[{todo}][{i}]')
                                  for i in range(nrows)]
            for var in self.margins[todo]:
                sol.addEditVariable(var, 'strong')

        # set these margins to zero by default.  They will be edited as
        # children are filled.
        self.reset_margins()
        self.add_constraints(parent)

        self.h_pad = h_pad
        self.w_pad = w_pad

    def __repr__(self):
        text = f'LayoutBox: {self.name:25s} {self.nrows}x{self.ncols},\n'
        for i in range(self.nrows):
            for j in range(self.ncols):
                text += (f'{i}, {j}: '
                         f'L{self.lefts[j].value():1.3f}, '
                         f'B{self.bottoms[i].value():1.3f}, '
                         f'R{self.rights[j].value():1.3f}, '
                         f'T{self.tops[i].value():1.3f}, '
                         f'ML{self.margins["left"][j].value():1.3f}, '
                         f'MR{self.margins["right"][j].value():1.3f}, '
                         f'MB{self.margins["bottom"][i].value():1.3f}, '
                         f'MT{self.margins["top"][i].value():1.3f}, \n')
        return text

    def reset_margins(self):
        """
        Reset all the margins to zero.  Must do this after changing
        figure size, for instance, because the relative size of the
        axes labels etc changes.
        """
        for todo in ['left', 'right', 'bottom', 'top',
                     'leftcb', 'rightcb', 'bottomcb', 'topcb']:
            self.edit_margins(todo, 0.0)

    def add_constraints(self, parent):
        """Add the self, parent and grid constraints to the solver."""
        # define self-consistent constraints
        self.hard_constraints()
        # define relationship with parent layoutgrid:
        self.parent_constraints(parent)
        # define relative widths of the grid cells to each other
        # and stack horizontally and vertically.
        self.grid_constraints()

    def hard_constraints(self):
        """
        These are the redundant constraints, plus ones that make the
        rest of the code easier.  All are added as 'required'.
        """
        # NOTE(review): the second inequality compares the inner right edge
        # with ``left - margins`` rather than ``left + margins`` (the inner
        # left edge used by `grid_constraints` and `get_inner_bbox`), so it
        # does not by itself guarantee a non-negative inner width.  Left as
        # is because tightening a 'required' constraint changes how the
        # solver trades off margins; confirm intent before changing.
        for i in range(self.ncols):
            hc = [self.rights[i] >= self.lefts[i],
                  (self.rights[i] - self.margins['right'][i] -
                   self.margins['rightcb'][i] >=
                   self.lefts[i] - self.margins['left'][i] -
                   self.margins['leftcb'][i])
                  ]
            for c in hc:
                self.solver.addConstraint(c | 'required')

        for i in range(self.nrows):
            hc = [self.tops[i] >= self.bottoms[i],
                  (self.tops[i] - self.margins['top'][i] -
                   self.margins['topcb'][i] >=
                   self.bottoms[i] - self.margins['bottom'][i] -
                   self.margins['bottomcb'][i])
                  ]
            for c in hc:
                self.solver.addConstraint(c | 'required')

    def add_child(self, child, i=0, j=0):
        """
        Record *child* in every cell of the row(s) *i* and column(s) *j*.

        *i* and *j* may be ints or sequences of indices.
        """
        # np.ix_ returns the cross product of i and j indices
        self.children[np.ix_(np.atleast_1d(i), np.atleast_1d(j))] = child

    def parent_constraints(self, parent):
        """
        Pin the outer edges of this grid to *parent* ('required').

        The first column's left, the last column's right, the first row's
        top and the last row's bottom are set equal to the edges of either
        the rectangle *parent* ``(left, bottom, width, height)`` or the
        cells of the LayoutGrid *parent* given by ``self.parent_pos``
        (inside those cells' margins if ``self.parent_inner``).  ``None``
        is treated as the unit rectangle ``(0, 0, 1, 1)``.
        """
        if parent is None:
            # Default to the whole figure instead of failing on None[0].
            parent = (0, 0, 1, 1)
        if not isinstance(parent, LayoutGrid):
            # specify a rectangle in figure coordinates
            hc = [self.lefts[0] == parent[0],
                  self.rights[-1] == parent[0] + parent[2],
                  # top and bottom reversed order...
                  self.tops[0] == parent[1] + parent[3],
                  self.bottoms[-1] == parent[1]]
        else:
            rows, cols = self.parent_pos
            rows = np.atleast_1d(rows)
            cols = np.atleast_1d(cols)

            left = parent.lefts[cols[0]]
            right = parent.rights[cols[-1]]
            top = parent.tops[rows[0]]
            bottom = parent.bottoms[rows[-1]]
            if self.parent_inner:
                # the layout grid is contained inside the inner
                # grid of the parent.
                left += parent.margins['left'][cols[0]]
                left += parent.margins['leftcb'][cols[0]]
                right -= parent.margins['right'][cols[-1]]
                right -= parent.margins['rightcb'][cols[-1]]
                top -= parent.margins['top'][rows[0]]
                top -= parent.margins['topcb'][rows[0]]
                bottom += parent.margins['bottom'][rows[-1]]
                bottom += parent.margins['bottomcb'][rows[-1]]
            hc = [self.lefts[0] == left,
                  self.rights[-1] == right,
                  # from top to bottom
                  self.tops[0] == top,
                  self.bottoms[-1] == bottom]
        for c in hc:
            self.solver.addConstraint(c | 'required')

    def grid_constraints(self):
        """
        Constrain the inner (inside-margin) widths and heights of the cells
        to follow ``width_ratios`` / ``height_ratios``, and make adjacent
        cells abut.  All are added as 'strong' constraints.
        """
        # constrain widths:
        w = (self.rights[0] - self.margins['right'][0] -
             self.margins['rightcb'][0])
        w = (w - self.lefts[0] - self.margins['left'][0] -
             self.margins['leftcb'][0])
        w0 = w / self.width_ratios[0]
        # from left to right
        for i in range(1, self.ncols):
            w = (self.rights[i] - self.margins['right'][i] -
                 self.margins['rightcb'][i])
            w = (w - self.lefts[i] - self.margins['left'][i] -
                 self.margins['leftcb'][i])
            c = (w == w0 * self.width_ratios[i])
            self.solver.addConstraint(c | 'strong')
            # constrain the grid cells to be directly next to each other.
            c = (self.rights[i - 1] == self.lefts[i])
            self.solver.addConstraint(c | 'strong')

        # constrain heights:
        h = self.tops[0] - self.margins['top'][0] - self.margins['topcb'][0]
        h = (h - self.bottoms[0] - self.margins['bottom'][0] -
             self.margins['bottomcb'][0])
        h0 = h / self.height_ratios[0]
        # from top to bottom:
        for i in range(1, self.nrows):
            h = (self.tops[i] - self.margins['top'][i] -
                 self.margins['topcb'][i])
            h = (h - self.bottoms[i] - self.margins['bottom'][i] -
                 self.margins['bottomcb'][i])
            c = (h == h0 * self.height_ratios[i])
            self.solver.addConstraint(c | 'strong')
            # constrain the grid cells to be directly above each other.
            c = (self.bottoms[i - 1] == self.tops[i])
            self.solver.addConstraint(c | 'strong')

    # Margin editing: The margins are variable and meant to
    # contain things of a fixed size like axes labels, tick labels, titles
    # etc
    def edit_margin(self, todo, size, cell):
        """
        Set the size of the margin for one cell, unconditionally.

        Parameters
        ----------
        todo : {'left', 'right', 'bottom', 'top', \
'leftcb', 'rightcb', 'bottomcb', 'topcb'}
            Margin to alter.

        size : float
            New size of the margin.  Fraction of figure size.

        cell : int
            Cell column (for left/right margins) or row (for bottom/top
            margins) to edit.
        """
        self.solver.suggestValue(self.margins[todo][cell], size)
        self.margin_vals[todo][cell] = size

    def edit_margin_min(self, todo, size, cell=0):
        """
        Change the minimum size of the margin for one cell.

        Parameters
        ----------
        todo : {'left', 'right', 'bottom', 'top', \
'leftcb', 'rightcb', 'bottomcb', 'topcb'}
            Margin to alter.

        size : float
            Minimum size of the margin.  If it is larger than the
            currently stored value it updates the margin size; otherwise
            nothing changes.  Fraction of figure size.

        cell : int
            Cell column or row to edit.
        """
        if size > self.margin_vals[todo][cell]:
            self.edit_margin(todo, size, cell)

    def edit_margins(self, todo, size):
        """
        Set the margin *todo* of every cell in the layout grid to *size*.

        Parameters
        ----------
        todo : {'left', 'right', 'bottom', 'top', \
'leftcb', 'rightcb', 'bottomcb', 'topcb'}
            Margin to alter.

        size : float
            Size to set the margins.  Fraction of figure size.
        """
        for i in range(len(self.margin_vals[todo])):
            self.edit_margin(todo, size, i)

    def edit_all_margins_min(self, todo, size):
        """
        Change the minimum size of the margin *todo* of all the cells in
        the layout grid.

        Parameters
        ----------
        todo : {'left', 'right', 'bottom', 'top', \
'leftcb', 'rightcb', 'bottomcb', 'topcb'}
            The margin to alter.

        size : float
            Minimum size of the margin.  Cells whose stored margin is
            smaller are increased to *size*.  Fraction of figure size.
        """
        for i in range(len(self.margin_vals[todo])):
            self.edit_margin_min(todo, size, i)

    def edit_outer_margin_mins(self, margin, ss):
        """
        Edit all eight outer margin minimums in one statement.

        Parameters
        ----------
        margin : dict
            Size of margins in a dict with keys 'left', 'right', 'bottom',
            'top', 'leftcb', 'rightcb', 'bottomcb' and 'topcb'.

        ss : SubplotSpec
            Defines the subplotspec these margins should be applied to:
            left margins go to its first column, right margins to its last
            column, top margins to its first row and bottom margins to its
            last row.
        """
        self.edit_margin_min('left', margin['left'], ss.colspan.start)
        self.edit_margin_min('leftcb', margin['leftcb'], ss.colspan.start)
        self.edit_margin_min('right', margin['right'], ss.colspan.stop - 1)
        self.edit_margin_min('rightcb', margin['rightcb'], ss.colspan.stop - 1)
        # rows are from the top down:
        self.edit_margin_min('top', margin['top'], ss.rowspan.start)
        self.edit_margin_min('topcb', margin['topcb'], ss.rowspan.start)
        self.edit_margin_min('bottom', margin['bottom'], ss.rowspan.stop - 1)
        self.edit_margin_min('bottomcb', margin['bottomcb'],
                             ss.rowspan.stop - 1)

    def get_margins(self, todo, col):
        """Return the stored value of margin *todo* at column/row *col*."""
        return self.margin_vals[todo][col]

    def get_outer_bbox(self, rows=0, cols=0):
        """
        Return the outer bounding box of the subplot specs
        given by rows and cols.  rows and cols can be spans.
        """
        rows = np.atleast_1d(rows)
        cols = np.atleast_1d(cols)

        bbox = Bbox.from_extents(
            self.lefts[cols[0]].value(),
            self.bottoms[rows[-1]].value(),
            self.rights[cols[-1]].value(),
            self.tops[rows[0]].value())
        return bbox

    def get_inner_bbox(self, rows=0, cols=0):
        """
        Return the inner bounding box (inside all margins) of the subplot
        specs given by rows and cols.  rows and cols can be spans.
        """
        rows = np.atleast_1d(rows)
        cols = np.atleast_1d(cols)

        bbox = Bbox.from_extents(
            (self.lefts[cols[0]].value() +
             self.margins['left'][cols[0]].value() +
             self.margins['leftcb'][cols[0]].value()),
            (self.bottoms[rows[-1]].value() +
             self.margins['bottom'][rows[-1]].value() +
             self.margins['bottomcb'][rows[-1]].value()),
            (self.rights[cols[-1]].value() -
             self.margins['right'][cols[-1]].value() -
             self.margins['rightcb'][cols[-1]].value()),
            (self.tops[rows[0]].value() -
             self.margins['top'][rows[0]].value() -
             self.margins['topcb'][rows[0]].value())
        )
        return bbox

    def get_bbox_for_cb(self, rows=0, cols=0):
        """
        Return the bounding box that includes the decorations but *not*
        the colorbar, i.e. the outer box shrunk by the colorbar margins.
        """
        rows = np.atleast_1d(rows)
        cols = np.atleast_1d(cols)

        bbox = Bbox.from_extents(
            (self.lefts[cols[0]].value() +
             self.margins['leftcb'][cols[0]].value()),
            (self.bottoms[rows[-1]].value() +
             self.margins['bottomcb'][rows[-1]].value()),
            (self.rights[cols[-1]].value() -
             self.margins['rightcb'][cols[-1]].value()),
            (self.tops[rows[0]].value() -
             self.margins['topcb'][rows[0]].value())
        )
        return bbox

    def get_left_margin_bbox(self, rows=0, cols=0):
        """
        Return the left margin bounding box of the subplot specs
        given by rows and cols.  rows and cols can be spans.
        """
        rows = np.atleast_1d(rows)
        cols = np.atleast_1d(cols)

        bbox = Bbox.from_extents(
            (self.lefts[cols[0]].value() +
             self.margins['leftcb'][cols[0]].value()),
            (self.bottoms[rows[-1]].value()),
            (self.lefts[cols[0]].value() +
             self.margins['leftcb'][cols[0]].value() +
             self.margins['left'][cols[0]].value()),
            (self.tops[rows[0]].value()))
        return bbox

    def get_bottom_margin_bbox(self, rows=0, cols=0):
        """
        Return the bottom margin bounding box of the subplot specs
        given by rows and cols.  rows and cols can be spans.
        """
        rows = np.atleast_1d(rows)
        cols = np.atleast_1d(cols)

        bbox = Bbox.from_extents(
            (self.lefts[cols[0]].value()),
            (self.bottoms[rows[-1]].value() +
             self.margins['bottomcb'][rows[-1]].value()),
            (self.rights[cols[-1]].value()),
            (self.bottoms[rows[-1]].value() +
             self.margins['bottom'][rows[-1]].value() +
             self.margins['bottomcb'][rows[-1]].value()))
        return bbox

    def get_right_margin_bbox(self, rows=0, cols=0):
        """
        Return the right margin bounding box of the subplot specs
        given by rows and cols.  rows and cols can be spans.
        """
        rows = np.atleast_1d(rows)
        cols = np.atleast_1d(cols)

        bbox = Bbox.from_extents(
            (self.rights[cols[-1]].value() -
             self.margins['right'][cols[-1]].value() -
             self.margins['rightcb'][cols[-1]].value()),
            (self.bottoms[rows[-1]].value()),
            (self.rights[cols[-1]].value() -
             self.margins['rightcb'][cols[-1]].value()),
            (self.tops[rows[0]].value()))
        return bbox

    def get_top_margin_bbox(self, rows=0, cols=0):
        """
        Return the top margin bounding box of the subplot specs
        given by rows and cols.  rows and cols can be spans.
        """
        rows = np.atleast_1d(rows)
        cols = np.atleast_1d(cols)

        # y0 is the lower edge and y1 the upper edge, so the returned Bbox
        # has non-negative height like the other margin boxes.
        bbox = Bbox.from_extents(
            (self.lefts[cols[0]].value()),
            (self.tops[rows[0]].value() -
             self.margins['topcb'][rows[0]].value() -
             self.margins['top'][rows[0]].value()),
            (self.rights[cols[-1]].value()),
            (self.tops[rows[0]].value() -
             self.margins['topcb'][rows[0]].value()))
        return bbox

    def update_variables(self):
        """
        Update the variables for the solver attached to this layoutgrid.
        """
        self.solver.updateVariables()
_layoutboxobjnum = itertools.count()


def seq_id():
    """
    Return the next value of the module-wide counter as a string,
    zero-padded to at least six digits (e.g. ``'000042'``).
    """
    return f'{next(_layoutboxobjnum):06d}'
498
499
def plot_children(fig, lg=None, level=0):
    """
    Simple plotting to show where boxes are.

    For every cell of *lg* this adds to *fig* (in figure coordinates):
    a translucent gray patch for the outer box, an outline of the inner
    box, and translucent patches for the left, right, bottom and top
    margins.  It then recurses into the children of *lg*.

    Parameters
    ----------
    fig : Figure
        Figure to draw on.
    lg : LayoutGrid, optional
        Grid to draw.  If None, the figure's layout engine is executed and
        the grid it returns for *fig* is used.
    level : int
        Nesting depth.  Selects the inner-box edge color from the
        ``axes.prop_cycle`` colors, wrapping around when the nesting is
        deeper than the cycle.
    """
    if lg is None:
        _layoutgrids = fig.get_layout_engine().execute(fig)
        lg = _layoutgrids[fig]
    colors = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
    # Wrap around so deeply nested grids don't raise IndexError.
    col = colors[level % len(colors)]
    # (bbox getter, facecolor) for each margin region, in drawing order.
    margin_styles = [
        (lg.get_left_margin_bbox, (0.5, 0.7, 0.5)),
        (lg.get_right_margin_bbox, (0.7, 0.5, 0.5)),
        (lg.get_bottom_margin_bbox, (0.5, 0.5, 0.7)),
        (lg.get_top_margin_bbox, (0.7, 0.2, 0.7)),
    ]
    for i in range(lg.nrows):
        for j in range(lg.ncols):
            bb = lg.get_outer_bbox(rows=i, cols=j)
            fig.add_artist(
                mpatches.Rectangle(bb.p0, bb.width, bb.height, linewidth=1,
                                   edgecolor='0.7', facecolor='0.7',
                                   alpha=0.2, transform=fig.transFigure,
                                   zorder=-3))
            bbi = lg.get_inner_bbox(rows=i, cols=j)
            fig.add_artist(
                mpatches.Rectangle(bbi.p0, bbi.width, bbi.height, linewidth=2,
                                   edgecolor=col, facecolor='none',
                                   transform=fig.transFigure, zorder=-2))
            for get_bbox, facecolor in margin_styles:
                bbm = get_bbox(rows=i, cols=j)
                fig.add_artist(
                    mpatches.Rectangle(bbm.p0, bbm.width, bbm.height,
                                       linewidth=0, edgecolor='none',
                                       alpha=0.2, facecolor=facecolor,
                                       transform=fig.transFigure, zorder=-2))
    for ch in lg.children.flat:
        if ch is not None:
            plot_children(fig, ch, level=level + 1)