1# axis3d.py, original mplot3d version by John Porter
2# Created: 23 Sep 2005
3# Parts rewritten by Reinier Heeres <reinier@heeres.eu>
4
5import inspect
6
7import numpy as np
8
9import matplotlib as mpl
10from matplotlib import (
11 _api, artist, lines as mlines, axis as maxis, patches as mpatches,
12 transforms as mtransforms, colors as mcolors)
13from . import art3d, proj3d
14
15
def _move_from_center(coord, centers, deltas, axmask=(True, True, True)):
    """
    Push *coord* outward, away from *centers*, by *deltas*.

    Only the components whose entry in *axmask* is True are moved; masked-out
    components are returned unchanged.  The direction of each move is the
    sign of ``coord - centers`` (a zero difference moves in the positive
    direction, following `numpy.copysign`).

    Returns
    -------
    ndarray
        The displaced coordinate.
    """
    coord = np.asarray(coord)
    # +1 / -1 per component, depending on which side of the center we are.
    direction = np.copysign(1, coord - centers)
    offset = np.multiply(axmask, direction) * deltas
    return coord + offset
23
24
def _tick_update_position(tick, tickxs, tickys, labelpos):
    """
    Place a tick's line and labels for 3D drawing and apply the 3D style.

    Both labels are moved to *labelpos*.  Only ``tick1line`` is shown; it is
    drawn as a solid line without markers from ``(tickxs[0], tickys[0])`` to
    ``(tickxs[1], tickys[1])``.  The tick's own gridline is collapsed to a
    single point, since the 3D axis draws its grid separately.
    """
    for label in (tick.label1, tick.label2):
        label.set_position(labelpos)

    visible_line, hidden_line = tick.tick1line, tick.tick2line
    visible_line.set_visible(True)
    hidden_line.set_visible(False)
    visible_line.set_linestyle('-')
    visible_line.set_marker('')
    visible_line.set_data(tickxs, tickys)

    tick.gridline.set_data([0], [0])
36
37
class Axis(maxis.XAxis):
    """An Axis class for the 3D plots."""
    # These points from the unit cube make up the x, y and z-planes
    _PLANES = (
        (0, 3, 7, 4), (1, 2, 6, 5),  # yz planes
        (0, 1, 5, 4), (3, 2, 6, 7),  # xz planes
        (0, 1, 2, 3), (4, 5, 6, 7),  # xy planes
    )

    # Some properties for the axes.
    # 'i' is this axis' index into (x, y, z); 'tickdir' is the default
    # coordinate index the tick lines extend along (see _get_tickdir);
    # 'juggled' is a permutation of coordinate indices used when locating the
    # axis line (see _get_axis_line_edge_points) and the offset text.
    _AXINFO = {
        'x': {'i': 0, 'tickdir': 1, 'juggled': (1, 0, 2)},
        'y': {'i': 1, 'tickdir': 0, 'juggled': (0, 1, 2)},
        'z': {'i': 2, 'tickdir': 0, 'juggled': (0, 2, 1)},
    }

    def _old_init(self, adir, v_intervalx, d_intervalx, axes, *args,
                  rotate_label=None, **kwargs):
        return locals()

    def _new_init(self, axes, *, rotate_label=None, **kwargs):
        return locals()

    def __init__(self, *args, **kwargs):
        # Accept both the deprecated (adir, v_intervalx, d_intervalx, axes)
        # signature and the new (axes, *, rotate_label) one.
        params = _api.select_matching_signature(
            [self._old_init, self._new_init], *args, **kwargs)
        if "adir" in params:
            _api.warn_deprecated(
                "3.6", message=f"The signature of 3D Axis constructors has "
                f"changed in %(since)s; the new signature is "
                f"{inspect.signature(type(self).__init__)}", pending=True)
            if params["adir"] != self.axis_name:
                raise ValueError(f"Cannot instantiate {type(self).__name__} "
                                 f"with adir={params['adir']!r}")
        axes = params["axes"]
        rotate_label = params["rotate_label"]
        args = params.get("args", ())
        kwargs = params["kwargs"]

        name = self.axis_name

        self._label_position = 'default'
        self._tick_position = 'default'

        # This is a temporary member variable.
        # Do not depend on this existing in future releases!
        self._axinfo = self._AXINFO[name].copy()
        # Common parts
        self._axinfo.update({
            'label': {'va': 'center', 'ha': 'center',
                      'rotation_mode': 'anchor'},
            'color': mpl.rcParams[f'axes3d.{name}axis.panecolor'],
            'tick': {
                'inward_factor': 0.2,
                'outward_factor': 0.1,
            },
        })

        if mpl.rcParams['_internal.classic_mode']:
            self._axinfo.update({
                'axisline': {'linewidth': 0.75, 'color': (0, 0, 0, 1)},
                'grid': {
                    'color': (0.9, 0.9, 0.9, 1),
                    'linewidth': 1.0,
                    'linestyle': '-',
                },
            })
            self._axinfo['tick'].update({
                'linewidth': {
                    True: mpl.rcParams['lines.linewidth'],  # major
                    False: mpl.rcParams['lines.linewidth'],  # minor
                }
            })
        else:
            self._axinfo.update({
                'axisline': {
                    'linewidth': mpl.rcParams['axes.linewidth'],
                    'color': mpl.rcParams['axes.edgecolor'],
                },
                'grid': {
                    'color': mpl.rcParams['grid.color'],
                    'linewidth': mpl.rcParams['grid.linewidth'],
                    'linestyle': mpl.rcParams['grid.linestyle'],
                },
            })
            self._axinfo['tick'].update({
                'linewidth': {
                    True: (  # major
                        mpl.rcParams['xtick.major.width'] if name in 'xz'
                        else mpl.rcParams['ytick.major.width']),
                    False: (  # minor
                        mpl.rcParams['xtick.minor.width'] if name in 'xz'
                        else mpl.rcParams['ytick.minor.width']),
                }
            })

        super().__init__(axes, *args, **kwargs)

        # data and viewing intervals for this direction
        if "d_intervalx" in params:
            self.set_data_interval(*params["d_intervalx"])
        if "v_intervalx" in params:
            self.set_view_interval(*params["v_intervalx"])
        self.set_rotate_label(rotate_label)
        self._init3d()  # Inline after init3d deprecation elapses.

    __init__.__signature__ = inspect.signature(_new_init)
    adir = _api.deprecated("3.6", pending=True)(
        property(lambda self: self.axis_name))

    def _init3d(self):
        """Create the axis line, pane and gridline artists."""
        self.line = mlines.Line2D(
            xdata=(0, 0), ydata=(0, 0),
            linewidth=self._axinfo['axisline']['linewidth'],
            color=self._axinfo['axisline']['color'],
            antialiased=True)

        # Store dummy data in Polygon object
        self.pane = mpatches.Polygon([[0, 0], [0, 1]], closed=False)
        self.set_pane_color(self._axinfo['color'])

        self.axes._set_artist_props(self.line)
        self.axes._set_artist_props(self.pane)
        self.gridlines = art3d.Line3DCollection([])
        self.axes._set_artist_props(self.gridlines)
        self.axes._set_artist_props(self.label)
        self.axes._set_artist_props(self.offsetText)
        # Need to be able to place the label at the correct location
        self.label._transform = self.axes.transData
        self.offsetText._transform = self.axes.transData

    @_api.deprecated("3.6", pending=True)
    def init3d(self):  # After deprecation elapses, inline _init3d to __init__.
        self._init3d()

    def _set_ticks_transform(self, ticks):
        """
        Make every artist of each tick in *ticks* use ``axes.transData``.

        The 3D axis positions tick artists itself at draw time in projected
        data coordinates, so they must not use the 2D axis transforms.
        Returns *ticks* unchanged for convenient chaining.
        """
        for t in ticks:
            for obj in [
                    t.tick1line, t.tick2line, t.gridline, t.label1, t.label2]:
                obj.set_transform(self.axes.transData)
        return ticks

    def get_major_ticks(self, numticks=None):
        # docstring inherited
        return self._set_ticks_transform(super().get_major_ticks(numticks))

    def get_minor_ticks(self, numticks=None):
        # docstring inherited
        return self._set_ticks_transform(super().get_minor_ticks(numticks))

    def set_ticks_position(self, position):
        """
        Set the ticks position.

        Parameters
        ----------
        position : {'lower', 'upper', 'both', 'default', 'none'}
            The position of the bolded axis lines, ticks, and tick labels.
        """
        if position in ['top', 'bottom']:
            _api.warn_deprecated('3.8', name=f'{position=}',
                                 obj_type='argument value',
                                 alternative="'upper' or 'lower'")
            return
        _api.check_in_list(['lower', 'upper', 'both', 'default', 'none'],
                           position=position)
        self._tick_position = position

    def get_ticks_position(self):
        """
        Get the ticks position.

        Returns
        -------
        str : {'lower', 'upper', 'both', 'default', 'none'}
            The position of the bolded axis lines, ticks, and tick labels.
        """
        return self._tick_position

    def set_label_position(self, position):
        """
        Set the label position.

        Parameters
        ----------
        position : {'lower', 'upper', 'both', 'default', 'none'}
            The position of the axis label.
        """
        if position in ['top', 'bottom']:
            _api.warn_deprecated('3.8', name=f'{position=}',
                                 obj_type='argument value',
                                 alternative="'upper' or 'lower'")
            return
        _api.check_in_list(['lower', 'upper', 'both', 'default', 'none'],
                           position=position)
        self._label_position = position

    def get_label_position(self):
        """
        Get the label position.

        Returns
        -------
        str : {'lower', 'upper', 'both', 'default', 'none'}
            The position of the axis label.
        """
        return self._label_position

    def set_pane_color(self, color, alpha=None):
        """
        Set pane color.

        Parameters
        ----------
        color : :mpltype:`color`
            Color for axis pane.
        alpha : float, optional
            Alpha value for axis pane. If None, base it on *color*.
        """
        color = mcolors.to_rgba(color, alpha)
        self._axinfo['color'] = color
        self.pane.set_edgecolor(color)
        self.pane.set_facecolor(color)
        self.pane.set_alpha(color[-1])
        self.stale = True

    def set_rotate_label(self, val):
        """
        Whether to rotate the axis label: True, False or None.
        If set to None the label will be rotated if longer than 4 chars.
        """
        self._rotate_label = val
        self.stale = True

    def get_rotate_label(self, text):
        """
        Return whether an axis label with text *text* should be rotated.

        Returns the value given to `set_rotate_label` if it is not None;
        otherwise rotates only labels longer than 4 characters.
        """
        if self._rotate_label is not None:
            return self._rotate_label
        else:
            return len(text) > 4

    def _get_coord_info(self):
        """
        Return the axes bounds and their projection onto the current view.

        Returns
        -------
        mins, maxs : ndarray
            Lower and upper bounds of the x, y and z axes, from
            ``get_xbound``/``get_ybound``/``get_zbound``.
        bounds_proj : ndarray
            The bounding cube as returned by ``axes._transformed_cube``;
            column 2 (presumably the projected depth) orders the planes.
        highs : ndarray of bool
            For each axis *i*, whether the second plane of the pair
            ``_PLANES[2*i]``, ``_PLANES[2*i+1]`` has the larger mean value in
            column 2 of *bounds_proj*, with special handling for edge-on
            views.
        """
        mins, maxs = np.array([
            self.axes.get_xbound(),
            self.axes.get_ybound(),
            self.axes.get_zbound(),
        ]).T

        # Project the bounds along the current position of the cube:
        bounds = mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2]
        bounds_proj = self.axes._transformed_cube(bounds)

        # Determine which one of the parallel planes are higher up:
        means_z0 = np.zeros(3)
        means_z1 = np.zeros(3)
        for i in range(3):
            means_z0[i] = np.mean(bounds_proj[self._PLANES[2 * i], 2])
            means_z1[i] = np.mean(bounds_proj[self._PLANES[2 * i + 1], 2])
        highs = means_z0 < means_z1

        # Special handling for edge-on views
        equals = np.abs(means_z0 - means_z1) <= np.finfo(float).eps
        if np.sum(equals) == 2:
            vertical = np.where(~equals)[0][0]
            if vertical == 2:  # looking at XY plane
                highs = np.array([True, True, highs[2]])
            elif vertical == 1:  # looking at XZ plane
                highs = np.array([True, highs[1], False])
            elif vertical == 0:  # looking at YZ plane
                highs = np.array([highs[0], False, False])

        return mins, maxs, bounds_proj, highs

    def _calc_centers_deltas(self, maxs, mins):
        """
        Return the midpoints of the bounds and the per-axis outward offsets.

        The offsets (``(maxs - mins) * 0.08``) are the unit by which ticks,
        tick labels and axis labels are pushed away from the centers.
        """
        centers = 0.5 * (maxs + mins)
        # In mpl3.8, the scale factor was 1/12. mpl3.9 changes this to
        # 1/12 * 24/25 = 0.08 to compensate for the change in automargin
        # behavior and keep appearance the same. The 24/25 factor is from the
        # 1/48 padding added to each side of the axis in mpl3.8.
        scale = 0.08
        deltas = (maxs - mins) * scale
        return centers, deltas

    def _get_axis_line_edge_points(self, minmax, maxmin, position=None):
        """Get the edge points for the black bolded axis line."""
        # When changing vertical axis some of the axes has to be
        # moved to the other plane so it looks the same as if the z-axis
        # was the vertical axis.
        mb = [minmax, maxmin]  # line from origin to nearest corner to camera
        mb_rev = mb[::-1]
        mm = [[mb, mb_rev, mb_rev], [mb_rev, mb_rev, mb], [mb, mb, mb]]
        mm = mm[self.axes._vertical_axis][self._axinfo["i"]]

        juggled = self._axinfo["juggled"]
        edge_point_0 = mm[0].copy()  # origin point

        if ((position == 'lower' and mm[1][juggled[-1]] < mm[0][juggled[-1]]) or
                (position == 'upper' and mm[1][juggled[-1]] > mm[0][juggled[-1]])):
            edge_point_0[juggled[-1]] = mm[1][juggled[-1]]
        else:
            edge_point_0[juggled[0]] = mm[1][juggled[0]]

        edge_point_1 = edge_point_0.copy()
        edge_point_1[juggled[1]] = mm[1][juggled[1]]

        return edge_point_0, edge_point_1

    def _get_all_axis_line_edge_points(self, minmax, maxmin, axis_position=None):
        """
        Return the axis-line edge points for every requested position.

        Returns three parallel lists: the first edge points, the second edge
        points, and the position name ('default', 'lower' or 'upper') each
        pair belongs to.  An *axis_position* of 'none' yields empty lists.
        """
        # Determine edge points for the axis lines
        edgep1s = []
        edgep2s = []
        position = []
        if axis_position in (None, 'default'):
            edgep1, edgep2 = self._get_axis_line_edge_points(minmax, maxmin)
            edgep1s = [edgep1]
            edgep2s = [edgep2]
            position = ['default']
        else:
            edgep1_l, edgep2_l = self._get_axis_line_edge_points(minmax, maxmin,
                                                                 position='lower')
            edgep1_u, edgep2_u = self._get_axis_line_edge_points(minmax, maxmin,
                                                                 position='upper')
            if axis_position in ('lower', 'both'):
                edgep1s.append(edgep1_l)
                edgep2s.append(edgep2_l)
                position.append('lower')
            if axis_position in ('upper', 'both'):
                edgep1s.append(edgep1_u)
                edgep2s.append(edgep2_u)
                position.append('upper')
        return edgep1s, edgep2s, position

    def _get_tickdir(self, position):
        """
        Get the direction of the tick.

        Parameters
        ----------
        position : {'upper', 'lower', 'default'}
            The position of the axis.

        Returns
        -------
        tickdir : int
            Index which indicates which coordinate the tick line will
            align with.
        """
        _api.check_in_list(('upper', 'lower', 'default'), position=position)

        # TODO: Move somewhere else where it's triggered less:
        tickdirs_base = [v["tickdir"] for v in self._AXINFO.values()]  # default
        elev_mod = np.mod(self.axes.elev + 180, 360) - 180
        azim_mod = np.mod(self.axes.azim, 360)
        if position == 'upper':
            if elev_mod >= 0:
                tickdirs_base = [2, 2, 0]
            else:
                tickdirs_base = [1, 0, 0]
            if 0 <= azim_mod < 180:
                tickdirs_base[2] = 1
        elif position == 'lower':
            if elev_mod >= 0:
                tickdirs_base = [1, 0, 1]
            else:
                tickdirs_base = [2, 2, 1]
            if 0 <= azim_mod < 180:
                tickdirs_base[2] = 0
        info_i = [v["i"] for v in self._AXINFO.values()]

        i = self._axinfo["i"]
        vert_ax = self.axes._vertical_axis
        j = vert_ax - 2
        # default: tickdir = [[1, 2, 1], [2, 2, 0], [1, 0, 0]][vert_ax][i]
        tickdir = np.roll(info_i, -j)[np.roll(tickdirs_base, j)][i]
        return tickdir

    def active_pane(self):
        """
        Return the pane of the bounding cube to draw for this axis.

        Returns
        -------
        xys : ndarray
            The four corners of the selected plane taken from the projected
            bounding cube (see `_get_coord_info`).
        loc : float
            The bound (``mins`` or ``maxs`` entry for this axis) at which the
            selected plane lies.
        """
        mins, maxs, tc, highs = self._get_coord_info()
        info = self._axinfo
        index = info['i']
        if not highs[index]:
            loc = mins[index]
            plane = self._PLANES[2 * index]
        else:
            loc = maxs[index]
            plane = self._PLANES[2 * index + 1]
        xys = np.array([tc[p] for p in plane])
        return xys, loc

    def draw_pane(self, renderer):
        """
        Draw pane.

        Parameters
        ----------
        renderer : `~matplotlib.backend_bases.RendererBase` subclass
        """
        renderer.open_group('pane3d', gid=self.get_gid())
        xys, _ = self.active_pane()
        self.pane.xy = xys[:, :2]
        self.pane.draw(renderer)
        renderer.close_group('pane3d')

    def _axmask(self):
        """
        Return a mask that is False only at this axis' own coordinate index.

        Passed to `_move_from_center` so labels are moved away from the
        center along the other two directions only.
        """
        axmask = [True, True, True]
        axmask[self._axinfo["i"]] = False
        return axmask

    def _project_edge(self, edgep1, edgep2):
        """
        Project an axis-line segment and measure its on-screen direction.

        Returns
        -------
        pep : ndarray
            The projected edge points from ``proj3d._proj_trans_points``;
            ``pep[k, m]`` is projected coordinate *k* of edge point *m*.
        dx, dy : float
            Displacement from the first to the second projected point in
            display coordinates, used for text rotation angles.
        """
        pep = np.asarray(
            proj3d._proj_trans_points([edgep1, edgep2], self.axes.M))
        # The transAxes transform is used because the Text object
        # rotates the text relative to the display coordinate system.
        # Therefore, if we want the labels to remain parallel to the
        # axis regardless of the aspect ratio, we need to convert the
        # edge points of the plane to display coordinates and calculate
        # an angle from that.
        # TODO: Maybe Text objects should handle this themselves?
        dx, dy = (self.axes.transAxes.transform([pep[0:2, 1]]) -
                  self.axes.transAxes.transform([pep[0:2, 0]]))[0]
        return pep, dx, dy

    def _draw_ticks(self, renderer, edgep1, centers, deltas, highs,
                    deltas_per_point, pos):
        """
        Draw the ticks and tick labels along one axis line.

        Parameters
        ----------
        renderer : `~matplotlib.backend_bases.RendererBase` subclass
        edgep1 : ndarray
            First data-space edge point of the axis line the ticks sit on.
        centers, deltas : ndarray
            Bound midpoints and outward offsets, as returned by
            `_calc_centers_deltas` for the current bounds.
        highs : ndarray of bool
            Plane ordering, as returned by `_get_coord_info`.
        deltas_per_point : float
            Rough conversion factor from points to multiples of *deltas*.
        pos : {'upper', 'lower', 'default'}
            The axis-line position being drawn; selects the tick direction.
        """
        ticks = self._update_ticks()
        info = self._axinfo
        index = info["i"]

        # Draw ticks:
        tickdir = self._get_tickdir(pos)
        tickdelta = deltas[tickdir] if highs[tickdir] else -deltas[tickdir]

        tick_info = info['tick']
        tick_out = tick_info['outward_factor'] * tickdelta
        tick_in = tick_info['inward_factor'] * tickdelta
        tick_lw = tick_info['linewidth']
        edgep1_tickdir = edgep1[tickdir]
        out_tickdir = edgep1_tickdir + tick_out
        in_tickdir = edgep1_tickdir - tick_in

        default_label_offset = 8.  # A rough estimate
        points = deltas_per_point * deltas
        axmask = self._axmask()  # loop invariant
        for tick in ticks:
            # Get tick line positions
            tick_pos = edgep1.copy()
            tick_pos[index] = tick.get_loc()
            tick_pos[tickdir] = out_tickdir
            x1, y1, _ = proj3d.proj_transform(*tick_pos, self.axes.M)
            tick_pos[tickdir] = in_tickdir
            x2, y2, _ = proj3d.proj_transform(*tick_pos, self.axes.M)

            # Get position of label
            labeldeltas = (tick.get_pad() + default_label_offset) * points

            tick_pos[tickdir] = edgep1_tickdir
            tick_pos = _move_from_center(tick_pos, centers, labeldeltas, axmask)
            lx, ly, _ = proj3d.proj_transform(*tick_pos, self.axes.M)

            _tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly))
            tick.tick1line.set_linewidth(tick_lw[tick._major])
            tick.draw(renderer)

    def _draw_offset_text(self, renderer, edgep1, edgep2, labeldeltas, centers,
                          highs, pep, dx, dy):
        """
        Position, align and draw the offset text of the major formatter.

        *pep* holds the projected edge points and *dx*, *dy* the on-screen
        direction of the axis line (see `_project_edge`).
        """
        # Get general axis information:
        info = self._axinfo
        index = info["i"]
        juggled = info["juggled"]
        tickdir = info["tickdir"]

        # Which of the two edge points do we want to
        # use for locating the offset text?
        if juggled[2] == 2:
            outeredgep = edgep1
            outerindex = 0
        else:
            outeredgep = edgep2
            outerindex = 1

        pos = _move_from_center(outeredgep, centers, labeldeltas,
                                self._axmask())
        olx, oly, olz = proj3d.proj_transform(*pos, self.axes.M)
        self.offsetText.set_text(self.major.formatter.get_offset())
        self.offsetText.set_position((olx, oly))
        angle = art3d._norm_text_angle(np.rad2deg(np.arctan2(dy, dx)))
        self.offsetText.set_rotation(angle)
        # Must set rotation mode to "anchor" so that
        # the alignment point is used as the "fulcrum" for rotation.
        self.offsetText.set_rotation_mode('anchor')

        # ----------------------------------------------------------------------
        # Note: the following statement determines the proper alignment of
        # the offset text. It was determined entirely by trial-and-error
        # and should not be in any way considered as "the way". There are
        # still some edge cases where alignment is not quite right, but this
        # seems to be more of a geometry issue (in other words, I might be
        # using the wrong reference points).
        #
        # (TT, FF, TF, FT) are the shorthand for the tuple of
        #   (centpt[tickdir] <= pep[tickdir, outerindex],
        #    centpt[index] <= pep[index, outerindex])
        #
        # Three-letters (e.g., TFT, FTT) are short-hand for the array of bools
        # from the variable 'highs'.
        # ---------------------------------------------------------------------
        centpt = proj3d.proj_transform(*centers, self.axes.M)
        if centpt[tickdir] > pep[tickdir, outerindex]:
            # if FT and if highs has an even number of Trues
            if (centpt[index] <= pep[index, outerindex]
                    and np.count_nonzero(highs) % 2 == 0):
                # Usually, this means align right, except for the FTT case,
                # in which offset for axis 1 and 2 are aligned left.
                if highs.tolist() == [False, True, True] and index in (1, 2):
                    align = 'left'
                else:
                    align = 'right'
            else:
                # The FF case
                align = 'left'
        else:
            # if TF and if highs has an even number of Trues
            if (centpt[index] > pep[index, outerindex]
                    and np.count_nonzero(highs) % 2 == 0):
                # Usually mean align left, except if it is axis 2
                align = 'right' if index == 2 else 'left'
            else:
                # The TT case
                align = 'right'

        self.offsetText.set_va('center')
        self.offsetText.set_ha(align)
        self.offsetText.draw(renderer)

    def _draw_labels(self, renderer, edgep1, edgep2, labeldeltas, centers, dx, dy):
        """
        Draw the axis label at the midpoint of one axis line.

        The label is moved outward by *labeldeltas* and, if
        `get_rotate_label` says so, rotated to follow the on-screen direction
        (*dx*, *dy*) of the axis line.
        """
        label = self._axinfo["label"]

        # Draw labels
        lxyz = 0.5 * (edgep1 + edgep2)
        lxyz = _move_from_center(lxyz, centers, labeldeltas, self._axmask())
        tlx, tly, tlz = proj3d.proj_transform(*lxyz, self.axes.M)
        self.label.set_position((tlx, tly))
        if self.get_rotate_label(self.label.get_text()):
            angle = art3d._norm_text_angle(np.rad2deg(np.arctan2(dy, dx)))
            self.label.set_rotation(angle)
        self.label.set_va(label['va'])
        self.label.set_ha(label['ha'])
        self.label.set_rotation_mode(label['rotation_mode'])
        self.label.draw(renderer)

    @artist.allow_rasterization
    def draw(self, renderer):
        self.label._transform = self.axes.transData
        self.offsetText._transform = self.axes.transData
        renderer.open_group("axis3d", gid=self.get_gid())

        # Get general axis information:
        mins, maxs, _, highs = self._get_coord_info()
        centers, deltas = self._calc_centers_deltas(maxs, mins)

        # Calculate offset distances
        # A rough estimate; points are ambiguous since 3D plots rotate
        reltoinches = self.figure.dpi_scale_trans.inverted()
        ax_inches = reltoinches.transform(self.axes.bbox.size)
        ax_points_estimate = sum(72. * ax_inches)
        deltas_per_point = 48 / ax_points_estimate
        default_offset = 21.
        labeldeltas = (self.labelpad + default_offset) * deltas_per_point * deltas

        # Determine edge points for the axis lines
        minmax = np.where(highs, maxs, mins)  # "origin" point
        maxmin = np.where(~highs, maxs, mins)  # "opposite" corner near camera

        for edgep1, edgep2, pos in zip(*self._get_all_axis_line_edge_points(
                minmax, maxmin, self._tick_position)):
            # Project the edge points along the current position
            pep, dx, dy = self._project_edge(edgep1, edgep2)

            # Draw the lines
            self.line.set_data(pep[0], pep[1])
            self.line.draw(renderer)

            # Draw ticks
            self._draw_ticks(renderer, edgep1, centers, deltas, highs,
                             deltas_per_point, pos)

            # Draw Offset text
            self._draw_offset_text(renderer, edgep1, edgep2, labeldeltas,
                                   centers, highs, pep, dx, dy)

        for edgep1, edgep2, pos in zip(*self._get_all_axis_line_edge_points(
                minmax, maxmin, self._label_position)):
            _, dx, dy = self._project_edge(edgep1, edgep2)

            # Draw labels
            self._draw_labels(renderer, edgep1, edgep2, labeldeltas, centers, dx, dy)

        renderer.close_group('axis3d')
        self.stale = False

    @artist.allow_rasterization
    def draw_grid(self, renderer):
        if not self.axes._draw_grid:
            return

        renderer.open_group("grid3d", gid=self.get_gid())

        ticks = self._update_ticks()
        if len(ticks):
            # Get general axis information:
            info = self._axinfo
            index = info["i"]

            mins, maxs, tc, highs = self._get_coord_info()

            minmax = np.where(highs, maxs, mins)
            maxmin = np.where(~highs, maxs, mins)

            # Grid points where the planes meet
            xyz0 = np.tile(minmax, (len(ticks), 1))
            xyz0[:, index] = [tick.get_loc() for tick in ticks]

            # Grid lines go from the end of one plane through the plane
            # intersection (at xyz0) to the end of the other plane. The first
            # point (0) differs along dimension index-2 and the last (2) along
            # dimension index-1.
            lines = np.stack([xyz0, xyz0, xyz0], axis=1)
            lines[:, 0, index - 2] = maxmin[index - 2]
            lines[:, 2, index - 1] = maxmin[index - 1]
            self.gridlines.set_segments(lines)
            gridinfo = info['grid']
            self.gridlines.set_color(gridinfo['color'])
            self.gridlines.set_linewidth(gridinfo['linewidth'])
            self.gridlines.set_linestyle(gridinfo['linestyle'])
            self.gridlines.do_3d_projection()
            self.gridlines.draw(renderer)

        renderer.close_group('grid3d')

    # TODO: Get this to work (more) properly when mplot3d supports the
    #       transforms framework.
    def get_tightbbox(self, renderer=None, *, for_layout_only=False):
        # docstring inherited
        if not self.get_visible():
            return
        # We have to directly access the internal data structures
        # (and hope they are up to date) because at draw time we
        # shift the ticks and their labels around in (x, y) space
        # based on the projection, the current view port, and their
        # position in 3D space.  If we extend the transforms framework
        # into 3D we would not need to do this different book keeping
        # than we do in the normal axis
        major_locs = self.get_majorticklocs()
        minor_locs = self.get_minorticklocs()

        ticks = [*self.get_minor_ticks(len(minor_locs)),
                 *self.get_major_ticks(len(major_locs))]
        view_low, view_high = self.get_view_interval()
        if view_low > view_high:
            view_low, view_high = view_high, view_low
        interval_t = self.get_transform().transform([view_low, view_high])

        ticks_to_draw = []
        for tick in ticks:
            try:
                loc_t = self.get_transform().transform(tick.get_loc())
            except AssertionError:
                # Transform.transform doesn't allow masked values but
                # some scales might make them, so we need this try/except.
                pass
            else:
                if mtransforms._interval_contains_close(interval_t, loc_t):
                    ticks_to_draw.append(tick)

        ticks = ticks_to_draw

        bb_1, bb_2 = self._get_ticklabel_bboxes(ticks, renderer)
        other = []

        if self.line.get_visible():
            other.append(self.line.get_window_extent(renderer))
        if (self.label.get_visible() and not for_layout_only and
                self.label.get_text()):
            other.append(self.label.get_window_extent(renderer))

        return mtransforms.Bbox.union([*bb_1, *bb_2, *other])

    d_interval = _api.deprecated(
        "3.6", alternative="get_data_interval", pending=True)(
            property(lambda self: self.get_data_interval(),
                     lambda self, minmax: self.set_data_interval(*minmax)))
    v_interval = _api.deprecated(
        "3.6", alternative="get_view_interval", pending=True)(
            property(lambda self: self.get_view_interval(),
                     lambda self, minmax: self.set_view_interval(*minmax)))
737
738
class XAxis(Axis):
    """The x axis of a 3D Axes."""
    axis_name = "x"
    # Interval accessors generated by maxis._make_getset_interval from the
    # Axes' "xy_viewLim" / "xy_dataLim" attributes and their "intervalx"
    # (presumably the x interval of those bboxes -- see maxis for details).
    get_view_interval, set_view_interval = maxis._make_getset_interval(
        "view", "xy_viewLim", "intervalx")
    get_data_interval, set_data_interval = maxis._make_getset_interval(
        "data", "xy_dataLim", "intervalx")
745
746
class YAxis(Axis):
    """The y axis of a 3D Axes."""
    axis_name = "y"
    # Interval accessors generated by maxis._make_getset_interval from the
    # Axes' "xy_viewLim" / "xy_dataLim" attributes and their "intervaly"
    # (presumably the y interval of those bboxes -- see maxis for details).
    get_view_interval, set_view_interval = maxis._make_getset_interval(
        "view", "xy_viewLim", "intervaly")
    get_data_interval, set_data_interval = maxis._make_getset_interval(
        "data", "xy_dataLim", "intervaly")
753
754
class ZAxis(Axis):
    """The z axis of a 3D Axes."""
    axis_name = "z"
    # Interval accessors generated by maxis._make_getset_interval from the
    # Axes' "zz_viewLim" / "zz_dataLim" attributes.  Note that "intervalx" is
    # used here: presumably the z limits are stored in the x interval of
    # those bboxes -- confirm against Axes3D before relying on this.
    get_view_interval, set_view_interval = maxis._make_getset_interval(
        "view", "zz_viewLim", "intervalx")
    get_data_interval, set_data_interval = maxis._make_getset_interval(
        "data", "zz_dataLim", "intervalx")