Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/matplotlib/cbook.py: 20%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2A collection of utility functions and classes. Originally, many
3(but not all) were from the Python Cookbook -- hence the name cbook.
4"""
6import collections
7import collections.abc
8import contextlib
9import functools
10import gzip
11import itertools
12import math
13import operator
14import os
15from pathlib import Path
16import shlex
17import subprocess
18import sys
19import time
20import traceback
21import types
22import weakref
24import numpy as np
26try:
27 from numpy.exceptions import VisibleDeprecationWarning # numpy >= 1.25
28except ImportError:
29 from numpy import VisibleDeprecationWarning
31import matplotlib
32from matplotlib import _api, _c_internal_utils
def _get_running_interactive_framework():
    """
    Report which interactive framework's event loop is currently running.

    Returns
    -------
    Optional[str]
        One of the following values: "qt", "gtk3", "gtk4", "wx", "tk",
        "macosx", "headless", ``None``.  "headless" is returned when no event
        loop can be started (no valid display); ``None`` is returned when none
        of the checks below matched.
    """
    # Query ``sys.modules`` with ``.get(name)`` instead of ``name in
    # sys.modules``: an entry may exist but be explicitly set to None.
    loaded = sys.modules.get

    # Qt: take the first binding that has been imported.
    QtWidgets = None
    for qt_binding in ("PyQt6.QtWidgets", "PySide6.QtWidgets",
                       "PyQt5.QtWidgets", "PySide2.QtWidgets"):
        QtWidgets = loaded(qt_binding)
        if QtWidgets:
            break
    if QtWidgets and QtWidgets.QApplication.instance():
        return "qt"

    Gtk = loaded("gi.repository.Gtk")
    if Gtk:
        if Gtk.MAJOR_VERSION == 4:
            from gi.repository import GLib
            if GLib.main_depth():
                return "gtk4"
        if Gtk.MAJOR_VERSION == 3 and Gtk.main_level():
            return "gtk3"

    wx = loaded("wx")
    if wx and wx.GetApp():
        return "wx"

    tkinter = loaded("tkinter")
    if tkinter:
        # Tk is considered running if any thread currently has one of the
        # mainloop functions somewhere on its call stack.
        mainloop_codes = {tkinter.mainloop.__code__,
                          tkinter.Misc.mainloop.__code__}
        for frame in sys._current_frames().values():
            while frame:
                if frame.f_code in mainloop_codes:
                    return "tk"
                frame = frame.f_back
        # Preemptively break reference cycle between locals and the frame.
        del frame

    macosx = loaded("matplotlib.backends._macosx")
    if macosx and macosx.event_loop_is_running():
        return "macosx"

    if not _c_internal_utils.display_is_valid():
        return "headless"
    return None
def _exception_printer(exc):
    """
    Default exception handler used by `CallbackRegistry`.

    Re-raise *exc* when no interactive event loop is running (the detected
    framework is "headless" or None); otherwise print the traceback of the
    exception currently being handled and return.
    """
    framework = _get_running_interactive_framework()
    if framework is not None and framework != "headless":
        traceback.print_exc()
        return
    raise exc
92class _StrongRef:
93 """
94 Wrapper similar to a weakref, but keeping a strong reference to the object.
95 """
97 def __init__(self, obj):
98 self._obj = obj
100 def __call__(self):
101 return self._obj
103 def __eq__(self, other):
104 return isinstance(other, _StrongRef) and self._obj == other._obj
106 def __hash__(self):
107 return hash(self._obj)
110def _weak_or_strong_ref(func, callback):
111 """
112 Return a `WeakMethod` wrapping *func* if possible, else a `_StrongRef`.
113 """
114 try:
115 return weakref.WeakMethod(func, callback)
116 except TypeError:
117 return _StrongRef(func)
class CallbackRegistry:
    """
    Registry that connects, processes, blocks and disconnects callbacks
    attached to named signals:

        >>> def oneat(x):
        ...     print('eat', x)
        >>> def ondrink(x):
        ...     print('drink', x)

        >>> from matplotlib.cbook import CallbackRegistry
        >>> callbacks = CallbackRegistry()

        >>> id_eat = callbacks.connect('eat', oneat)
        >>> id_drink = callbacks.connect('drink', ondrink)

        >>> callbacks.process('drink', 123)
        drink 123
        >>> callbacks.process('eat', 456)
        eat 456
        >>> callbacks.process('be merry', 456)   # nothing will be called

        >>> callbacks.disconnect(id_eat)
        >>> callbacks.process('eat', 456)        # nothing will be called

        >>> with callbacks.blocked(signal='drink'):
        ...     callbacks.process('drink', 123)  # nothing will be called
        >>> callbacks.process('drink', 123)
        drink 123

    Ideally every callback would be disconnected once it is no longer needed,
    so that no dangling references (and hence memory leaks) remain.  Real
    Matplotlib code rarely does so, and its design makes it hard to find a
    place for such cleanup.  To prevent this class of leaks, bound methods are
    stored through weak references only, so the registry does not keep the
    destination object alive.

    Parameters
    ----------
    exception_handler : callable, optional
       If not None, *exception_handler* must be a function that takes an
       `Exception` as single parameter.  It gets called with any `Exception`
       raised by the callbacks during `CallbackRegistry.process`, and may
       either re-raise the exception or handle it in another manner.

       The default handler prints the exception (with `traceback.print_exc`)
       if an interactive event loop is running; it re-raises the exception if
       no interactive event loop is running.

    signals : list, optional
        If not None, *signals* is a list of signals that this registry
        handles: attempting to `process` or to `connect` to a signal not in
        the list throws a `ValueError`.  The default, None, does not restrict
        the handled signals.
    """

    # Two mappings are kept in sync:
    #   callbacks:     signal -> {cid -> weakref-to-callback}
    #   _func_cid_map: signal -> {weakref-to-callback -> cid}

    def __init__(self, exception_handler=_exception_printer, *, signals=None):
        # Keep a private copy of the allowed signal names.
        self._signals = list(signals) if signals is not None else None
        self.exception_handler = exception_handler
        self.callbacks = {}
        self._cid_gen = itertools.count()
        self._func_cid_map = {}
        # Hidden set of the cids whose callbacks must be kept when pickling.
        self._pickled_cids = set()

    def __getstate__(self):
        state = dict(vars(self))
        # Callbacks are in general not picklable, so they are dropped unless
        # their cid was recorded in self._pickled_cids.  The proxies are
        # dereferenced so that the callables themselves get pickled.
        state["callbacks"] = {
            signal: {cid: proxy()
                     for cid, proxy in cid_to_proxy.items()
                     if cid in self._pickled_cids}
            for signal, cid_to_proxy in self.callbacks.items()}
        # Rebuilding the reverse map in __setstate__ is simpler than
        # pickling it.
        state["_func_cid_map"] = None
        # The counter is stored as the next integer it would hand out.
        state["_cid_gen"] = next(self._cid_gen)
        return state

    def __setstate__(self, state):
        next_cid = state.pop('_cid_gen')
        vars(self).update(state)
        # Wrap the unpickled callables in proxies again.
        rewrapped = {}
        for signal, cid_to_func in self.callbacks.items():
            rewrapped[signal] = {
                cid: _weak_or_strong_ref(func, self._remove_proxy)
                for cid, func in cid_to_func.items()}
        self.callbacks = rewrapped
        # The reverse map is simply the inverse of each per-signal dict.
        self._func_cid_map = {
            signal: {proxy: cid for cid, proxy in cid_to_proxy.items()}
            for signal, cid_to_proxy in rewrapped.items()}
        self._cid_gen = itertools.count(next_cid)

    def connect(self, signal, func):
        """
        Register *func* to be called when signal *signal* is generated.

        Returns the callback id.  Connecting a callback that is already
        registered for *signal* returns its existing id.
        """
        if self._signals is not None:
            _api.check_in_list(self._signals, signal=signal)
        self._func_cid_map.setdefault(signal, {})
        proxy = _weak_or_strong_ref(func, self._remove_proxy)
        proxy_to_cid = self._func_cid_map[signal]
        if proxy in proxy_to_cid:
            return proxy_to_cid[proxy]
        cid = next(self._cid_gen)
        proxy_to_cid[proxy] = cid
        self.callbacks.setdefault(signal, {})[cid] = proxy
        return cid

    def _connect_picklable(self, signal, func):
        """
        Like `.connect`, but the callback is kept when pickling/unpickling.

        Currently internal-use only.
        """
        cid = self.connect(signal, func)
        self._pickled_cids.add(cid)
        return cid

    # sys.is_finalizing is bound as a default argument because sys may have
    # been cleared out by the time this runs.
    def _remove_proxy(self, proxy, *, _is_finalizing=sys.is_finalizing):
        if _is_finalizing():
            # Weakrefs can't be properly torn down at that point anymore.
            return
        for signal, proxy_to_cid in list(self._func_cid_map.items()):
            cid = proxy_to_cid.pop(proxy, None)
            if cid is None:
                continue
            cid_to_proxy = self.callbacks[signal]
            del cid_to_proxy[cid]
            self._pickled_cids.discard(cid)
            # Drop the per-signal dicts once they are empty.
            if not cid_to_proxy:
                del self.callbacks[signal]
                del self._func_cid_map[signal]
            return

    def disconnect(self, cid):
        """
        Disconnect the callback registered with callback id *cid*.

        No error is raised if such a callback does not exist.
        """
        self._pickled_cids.discard(cid)
        # Remove the cid from the forward mapping first.
        for signal, cid_to_proxy in list(self.callbacks.items()):
            proxy = cid_to_proxy.pop(cid, None)
            if proxy is not None:
                break
        else:
            # No signal knows this cid.
            return

        # Then remove the matching entry from the reverse mapping.
        proxy_to_cid = self._func_cid_map[signal]
        matching = [candidate for candidate, candidate_cid
                    in proxy_to_cid.items() if candidate_cid == cid]
        for current_proxy in matching:
            assert proxy is current_proxy
            del proxy_to_cid[current_proxy]
        # Drop the per-signal dicts once they are empty.
        if not self.callbacks[signal]:
            del self.callbacks[signal]
            del self._func_cid_map[signal]

    def process(self, s, *args, **kwargs):
        """
        Process signal *s*.

        All of the functions registered to receive callbacks on *s* will be
        called with ``*args`` and ``**kwargs``.
        """
        if self._signals is not None:
            _api.check_in_list(self._signals, signal=s)
        # Iterate over a snapshot so that callbacks may connect/disconnect.
        for ref in list(self.callbacks.get(s, {}).values()):
            func = ref()
            if func is None:
                continue
            try:
                func(*args, **kwargs)
            # ``Exception`` deliberately excludes KeyboardInterrupt,
            # SystemExit, and GeneratorExit.
            except Exception as exc:
                if self.exception_handler is None:
                    raise
                self.exception_handler(exc)

    @contextlib.contextmanager
    def blocked(self, *, signal=None):
        """
        Block callback signals from being processed.

        A context manager to temporarily block/disable callback signals
        from being processed by the registered listeners.

        Parameters
        ----------
        signal : str, optional
            The callback signal to block. The default is to block all signals.
        """
        saved = self.callbacks
        try:
            if signal is None:
                # Block everything by swapping in an empty mapping.
                self.callbacks = {}
            else:
                # Hide only the requested signal.
                self.callbacks = {name: cbs for name, cbs in saved.items()
                                  if name != signal}
            yield
        finally:
            self.callbacks = saved
class silent_list(list):
    """
    A list with a short ``repr()``.

    Intended for homogeneous lists of artists, whose default representation
    would be long and uninformative.  Instead of ::

        [<matplotlib.lines.Line2D object at 0x7f5749fed3c8>,
         <matplotlib.lines.Line2D object at 0x7f5749fed4e0>,
         <matplotlib.lines.Line2D object at 0x7f5758016550>]

    the representation is ::

        <a list of 3 Line2D objects>

    When ``self.type`` is None, the type name is taken from the first item of
    the list (if any).
    """

    def __init__(self, type, seq=None):
        self.type = type
        if seq is not None:
            self.extend(seq)

    def __repr__(self):
        if self.type is None and not self:
            # Neither an explicit type name nor an item to infer it from.
            return "<an empty list>"
        tp = self.type
        if tp is None:
            tp = type(self[0]).__name__
        return f"<a list of {len(self)} {tp} objects>"
def _local_over_kwdict(
        local_var, kwargs, *keys,
        warning_cls=_api.MatplotlibDeprecationWarning):
    """
    Resolve a value given both as a local variable and through *kwargs*.

    Every key in *keys* is popped from *kwargs*.  The result is *local_var*
    if it is not None, else the first popped value that is not None.  Each
    further non-None popped value is ignored, and a warning of class
    *warning_cls* is emitted for it.
    """
    out = local_var
    for key in keys:
        kwarg_val = kwargs.pop(key, None)
        if kwarg_val is None:
            continue
        if out is not None:
            _api.warn_external(f'"{key}" keyword argument will be ignored',
                               warning_cls)
            continue
        out = kwarg_val
    return out
def strip_math(s):
    """
    Remove latex formatting from mathtext.

    Only handles fully math and fully non-math strings: a string that is not
    enclosed in a pair of ``$`` is returned unchanged.
    """
    is_math = len(s) >= 2 and s[0] == s[-1] == "$"
    if not is_math:
        return s
    stripped = s[1:-1]
    # Applied in order; the bare backslash and the braces go last.
    replacements = (
        (r"\times", "x"),  # Specifically for Formatter support.
        (r"\mathdefault", ""),
        (r"\rm", ""),
        (r"\cal", ""),
        (r"\tt", ""),
        (r"\it", ""),
        ("\\", ""),
        ("{", ""),
        ("}", ""),
    )
    for tex, plain in replacements:
        stripped = stripped.replace(tex, plain)
    return stripped
405def _strip_comment(s):
406 """Strip everything from the first unquoted #."""
407 pos = 0
408 while True:
409 quote_pos = s.find('"', pos)
410 hash_pos = s.find('#', pos)
411 if quote_pos < 0:
412 without_comment = s if hash_pos < 0 else s[:hash_pos]
413 return without_comment.strip()
414 elif 0 <= hash_pos < quote_pos:
415 return s[:hash_pos].strip()
416 else:
417 closing_quote_pos = s.find('"', quote_pos + 1)
418 if closing_quote_pos < 0:
419 raise ValueError(
420 f"Missing closing quote in: {s!r}. If you need a double-"
421 'quote inside a string, use escaping: e.g. "the \" char"')
422 pos = closing_quote_pos + 1 # behind closing quote
def is_writable_file_like(obj):
    """Return whether *obj* has a callable ``write`` attribute."""
    writer = getattr(obj, 'write', None)
    return callable(writer)
def file_requires_unicode(x):
    """
    Return whether the writable file-like object *x* requires Unicode.

    This is probed by writing an empty bytes object to *x*: a `TypeError`
    means that bytes are rejected, i.e. that Unicode is required.
    """
    try:
        x.write(b'')
    except TypeError:
        return True
    return False
def to_filehandle(fname, flag='r', return_opened=False, encoding=None):
    """
    Convert a path to an open file handle or pass-through a file-like object.

    Consider using `open_file_cm` instead, as it allows one to properly close
    newly created file objects more easily.

    Parameters
    ----------
    fname : str or path-like or file-like
        If `str` or `os.PathLike`, the file is opened using the flags specified
        by *flag* and *encoding*.  Names ending in ``.gz`` are opened with
        `gzip.open` and names ending in ``.bz2`` with `bz2.BZ2File`.  If a
        file-like object (anything with a ``seek`` attribute), it is passed
        through.
    flag : str, default: 'r'
        Passed as the *mode* argument to `open` when *fname* is `str` or
        `os.PathLike`; ignored if *fname* is file-like.
    return_opened : bool, default: False
        If True, return both the file object and a boolean indicating whether
        this was a new file (that the caller needs to close).  If False, return
        only the new file.
    encoding : str or None, default: None
        Passed as the *encoding* argument to `open` when *fname* is `str` or
        `os.PathLike`; ignored if *fname* is file-like.  For ``.gz`` files it
        is only used when *flag* selects text mode (contains ``'t'``); it is
        not used for ``.bz2`` files.

    Returns
    -------
    fh : file-like
    opened : bool
        *opened* is only returned if *return_opened* is True.

    Raises
    ------
    ValueError
        If *fname* is neither a str, an `os.PathLike`, nor an object with a
        ``seek`` attribute.
    """
    if isinstance(fname, os.PathLike):
        fname = os.fspath(fname)
    if isinstance(fname, str):
        if fname.endswith('.gz'):
            if 't' in flag:
                # Honor *encoding* in text mode instead of silently falling
                # back to the locale's default encoding.
                fh = gzip.open(fname, flag, encoding=encoding)
            else:
                # gzip.open rejects an encoding in binary mode.
                fh = gzip.open(fname, flag)
        elif fname.endswith('.bz2'):
            # python may not be compiled with bz2 support,
            # bury import until we need it
            import bz2
            fh = bz2.BZ2File(fname, flag)
        else:
            fh = open(fname, flag, encoding=encoding)
        opened = True
    elif hasattr(fname, 'seek'):
        fh = fname
        opened = False
    else:
        raise ValueError('fname must be a PathLike or file handle')
    if return_opened:
        return fh, opened
    return fh
def open_file_cm(path_or_file, mode="r", encoding=None):
    r"""
    Pass through file objects and context-manage path-likes.

    A file newly opened by `to_filehandle` is returned as is; a passed-through
    file object is wrapped in `contextlib.nullcontext`.
    """
    fh, opened = to_filehandle(path_or_file, mode, True, encoding)
    if opened:
        return fh
    return contextlib.nullcontext(fh)
def is_scalar_or_string(val):
    """Return whether *val* is a string or is not iterable."""
    if isinstance(val, str):
        return True
    return not np.iterable(val)
@_api.delete_parameter(
    "3.8", "np_load", alternative="open(get_sample_data(..., asfileobj=False))")
def get_sample_data(fname, asfileobj=True, *, np_load=True):
    """
    Return a sample data file.

    *fname* is a path relative to the :file:`mpl-data/sample_data` directory
    within the Matplotlib package.  A file object is returned if *asfileobj*
    is `True`, otherwise just the file path as a string.

    When a file object is requested: a filename ending in .gz is implicitly
    ungzipped; a filename ending in .npy or .npz is loaded with `numpy.load`
    (or opened in binary mode if *np_load* is false); .csv, .xrc and .txt
    files are opened in text mode; anything else is opened in binary mode.
    """
    path = _get_data_path('sample_data', fname)
    if not asfileobj:
        return str(path)

    suffix = path.suffix.lower()
    if suffix == '.gz':
        return gzip.open(path)
    if suffix in ['.npy', '.npz']:
        if np_load:
            return np.load(path)
        return path.open('rb')
    if suffix in ['.csv', '.xrc', '.txt']:
        return path.open('r')
    return path.open('rb')
def _get_data_path(*args):
    """
    Return the `pathlib.Path` to a resource file provided by Matplotlib.

    ``*args`` are path components appended to the base data path returned by
    `matplotlib.get_data_path`.
    """
    base = matplotlib.get_data_path()
    return Path(base, *args)
def flatten(seq, scalarp=is_scalar_or_string):
    """
    Return a generator of flattened nested containers.

    Items for which *scalarp* returns true, and None, are yielded as they
    are; every other item is flattened recursively.

    For example:

        >>> from matplotlib.cbook import flatten
        >>> l = (('John', ['Hunter']), (1, 23), [[([42, (5, 23)], )]])
        >>> print(list(flatten(l)))
        ['John', 'Hunter', 1, 23, 42, 5, 23]

    By: Composite of Holger Krekel and Luther Blissett
    From: https://code.activestate.com/recipes/121294/
    and Recipe 1.12 in cookbook
    """
    for element in seq:
        is_leaf = scalarp(element) or element is None
        if not is_leaf:
            yield from flatten(element, scalarp)
            continue
        yield element
@_api.deprecated("3.8")
class Stack:
    """
    Stack of elements with a movable cursor.

    Mimics home/back/forward in a web browser.
    """

    def __init__(self, default=None):
        self.clear()
        # Value returned by ``self()`` while the stack is empty.
        self._default = default

    def __call__(self):
        """Return the current element, or the default if the stack is empty."""
        if self._elements:
            return self._elements[self._pos]
        return self._default

    def __len__(self):
        return len(self._elements)

    def __getitem__(self, ind):
        return self._elements[ind]

    def forward(self):
        """Move the position forward and return the current element."""
        last = len(self._elements) - 1
        self._pos = min(self._pos + 1, last)
        return self()

    def back(self):
        """Move the position back and return the current element."""
        if self._pos >= 1:
            self._pos -= 1
        return self()

    def push(self, o):
        """
        Push *o* to the stack at current position.  Discard all later elements.

        *o* is returned.
        """
        kept = self._elements[:self._pos + 1]
        self._elements = kept + [o]
        self._pos = len(self._elements) - 1
        return self()

    def home(self):
        """
        Push the first element onto the top of the stack.

        The first element is returned (None if the stack is empty).
        """
        if not self._elements:
            return
        first = self._elements[0]
        self.push(first)
        return self()

    def empty(self):
        """Return whether the stack is empty."""
        return len(self._elements) == 0

    def clear(self):
        """Empty the stack."""
        self._pos = -1
        self._elements = []

    def bubble(self, o):
        """
        Raise all references of *o* to the top of the stack, and return it.

        Raises
        ------
        ValueError
            If *o* is not in the stack.
        """
        if o not in self._elements:
            raise ValueError('Given element not contained in the stack')
        previous = self._elements.copy()
        self.clear()
        n_matches = 0
        # Re-push everything else first, counting occurrences of *o* ...
        for elem in previous:
            if elem == o:
                n_matches += 1
            else:
                self.push(elem)
        # ... then put *o* back on top, once per occurrence.
        for _ in range(n_matches):
            self.push(o)
        return o

    def remove(self, o):
        """
        Remove *o* from the stack.

        Raises
        ------
        ValueError
            If *o* is not in the stack.
        """
        if o not in self._elements:
            raise ValueError('Given element not contained in the stack')
        previous = self._elements.copy()
        self.clear()
        # Rebuild the stack from the elements that differ from *o*.
        for elem in previous:
            if elem != o:
                self.push(elem)
677class _Stack:
678 """
679 Stack of elements with a movable cursor.
681 Mimics home/back/forward in a web browser.
682 """
684 def __init__(self):
685 self._pos = -1
686 self._elements = []
688 def clear(self):
689 """Empty the stack."""
690 self._pos = -1
691 self._elements = []
693 def __call__(self):
694 """Return the current element, or None."""
695 return self._elements[self._pos] if self._elements else None
697 def __len__(self):
698 return len(self._elements)
700 def __getitem__(self, ind):
701 return self._elements[ind]
703 def forward(self):
704 """Move the position forward and return the current element."""
705 self._pos = min(self._pos + 1, len(self._elements) - 1)
706 return self()
708 def back(self):
709 """Move the position back and return the current element."""
710 self._pos = max(self._pos - 1, 0)
711 return self()
713 def push(self, o):
714 """
715 Push *o* to the stack after the current position, and return *o*.
717 Discard all later elements.
718 """
719 self._elements[self._pos + 1:] = [o]
720 self._pos = len(self._elements) - 1
721 return o
723 def home(self):
724 """
725 Push the first element onto the top of the stack.
727 The first element is returned.
728 """
729 return self.push(self._elements[0]) if self._elements else None
def safe_masked_invalid(x, copy=False):
    """
    Return *x* as an array in native byte order with non-finite values masked.

    Parameters
    ----------
    x : array-like
        The input data; ndarray subclasses are preserved.
    copy : bool, default: False
        If True, always work on a copy of *x*.  If False, a copy is made only
        when the conversion to an array (or the byte swap) requires one.

    Returns
    -------
    array
        A masked array with the non-finite entries masked, or the converted
        array unchanged when `numpy.isfinite` raises a `TypeError` for it.
    """
    if copy:
        x = np.array(x, subok=True, copy=True)
    else:
        # ``np.array(..., copy=False)`` raises a ValueError under NumPy >= 2
        # whenever a copy is unavoidable (e.g. for list input);
        # ``np.asanyarray`` copies only if needed on every NumPy version.
        x = np.asanyarray(x)
    if not x.dtype.isnative:
        # If we have already made a copy, do the byteswap in place, else make a
        # copy with the byte order swapped.
        # Swap to native order.
        x = x.byteswap(inplace=copy).view(x.dtype.newbyteorder('N'))
    try:
        xm = np.ma.masked_where(~(np.isfinite(x)), x, copy=False)
    except TypeError:
        return x
    return xm
def print_cycles(objects, outstream=sys.stdout, show_progress=False):
    """
    Print loops of cyclic references in the given *objects*.

    It is often useful to pass in ``gc.garbage`` to find the cycles that are
    preventing some objects from being garbage collected.

    Parameters
    ----------
    objects
        A list of objects to find cycles in.
    outstream
        The stream for output.
    show_progress : bool
        If True, print the number of objects reached as they are found.
    """
    import gc

    def print_path(path):
        # Write one line per step, describing how it refers to the next one.
        n_steps = len(path)
        for i, step in enumerate(path):
            # The successor of the last step wraps around to the first.
            successor = path[(i + 1) % n_steps]

            outstream.write(" %s -- " % type(step))
            if isinstance(step, dict):
                for key, val in step.items():
                    if val is successor:
                        outstream.write(f"[{key!r}]")
                        break
                    if key is successor:
                        outstream.write(f"[key] = {val!r}")
                        break
            elif isinstance(step, list):
                outstream.write("[%d]" % step.index(successor))
            elif isinstance(step, tuple):
                outstream.write("( tuple )")
            else:
                outstream.write(repr(step))
            outstream.write(" ->\n")
        outstream.write("\n")

    def recurse(obj, start, seen, current_path):
        if show_progress:
            outstream.write("%d\r" % len(seen))

        seen[id(obj)] = None

        for referent in gc.get_referents(obj):
            # Getting back to the start means a cycle was found: print it.
            if referent is start:
                print_path(current_path)
                continue
            # Skip the original list of objects and temporary references to
            # the object (frames): those are just an artifact of the cycle
            # detector itself.
            if referent is objects or isinstance(referent, types.FrameType):
                continue
            # Only descend into objects that have not been visited yet.
            if id(referent) not in seen:
                recurse(referent, start, seen, current_path + [obj])

    for obj in objects:
        outstream.write(f"Examining: {obj!r}\n")
        recurse(obj, obj, {}, [])
class Grouper:
    """
    A disjoint-set data structure.

    Objects can be joined using :meth:`join`, tested for connectedness
    using :meth:`joined`, and all disjoint sets can be retrieved by
    using the object as an iterator.

    The objects being joined must be hashable and weak-referenceable.

    Examples
    --------
    >>> from matplotlib.cbook import Grouper
    >>> class Foo:
    ...     def __init__(self, s):
    ...         self.s = s
    ...     def __repr__(self):
    ...         return self.s
    ...
    >>> a, b, c, d, e, f = [Foo(x) for x in 'abcdef']
    >>> grp = Grouper()
    >>> grp.join(a, b)
    >>> grp.join(b, c)
    >>> grp.join(d, e)
    >>> list(grp)
    [[a, b, c], [d, e]]
    >>> grp.joined(a, b)
    True
    >>> grp.joined(a, c)
    True
    >>> grp.joined(a, d)
    False
    """

    def __init__(self, init=()):
        # _mapping: item -> the (shared) weak set of all items joined with it.
        singletons = {x: weakref.WeakSet([x]) for x in init}
        self._mapping = weakref.WeakKeyDictionary(singletons)
        # _ordering: item -> insertion rank, used to sort groups on output.
        self._ordering = weakref.WeakKeyDictionary()
        for x in init:
            if x not in self._ordering:
                self._ordering[x] = len(self._ordering)
        # Plain int to simplify pickling.
        self._next_order = len(self._ordering)

    def __getstate__(self):
        state = dict(vars(self))
        # Convert weak refs to strong ones.
        state["_mapping"] = {k: set(v) for k, v in self._mapping.items()}
        state["_ordering"] = {**self._ordering}
        return state

    def __setstate__(self, state):
        vars(self).update(state)
        # Convert strong refs to weak ones.
        weak_groups = {k: weakref.WeakSet(v)
                       for k, v in self._mapping.items()}
        self._mapping = weakref.WeakKeyDictionary(weak_groups)
        self._ordering = weakref.WeakKeyDictionary(self._ordering)

    def __contains__(self, item):
        return item in self._mapping

    @_api.deprecated("3.8", alternative="none, you no longer need to clean a Grouper")
    def clean(self):
        """Clean dead weak references from the dictionary."""

    def join(self, a, *args):
        """
        Join given arguments into the same set.  Accepts one or more arguments.
        """
        mapping = self._mapping

        def group_of(item):
            # Return the set containing *item*, registering *item* as a new
            # singleton (with the next insertion rank) if it is unknown.
            group = mapping.get(item)
            if group is None:
                group = mapping[item] = weakref.WeakSet([item])
                self._ordering[item] = self._next_order
                self._next_order += 1
            return group

        set_a = group_of(a)
        for arg in args:
            set_b = group_of(arg)
            if set_b is set_a:
                continue
            # Always merge the smaller set into the larger one.
            if len(set_b) > len(set_a):
                set_a, set_b = set_b, set_a
            set_a.update(set_b)
            for elem in set_b:
                mapping[elem] = set_a

    def joined(self, a, b):
        """Return whether *a* and *b* are members of the same set."""
        # A fresh sentinel for a missing *a* can never be identical to the
        # lookup result for *b*.
        group_a = self._mapping.get(a, object())
        group_b = self._mapping.get(b)
        return group_a is group_b

    def remove(self, a):
        """Remove *a* from the grouper, doing nothing if it is not there."""
        group = self._mapping.pop(a, {a})
        group.remove(a)
        self._ordering.pop(a, None)

    def __iter__(self):
        """
        Iterate over each of the disjoint sets as a list.

        The iterator is invalid if interleaved with calls to join().
        """
        # Several keys share the same set object; deduplicate by identity.
        unique_groups = {}
        for group in self._mapping.values():
            unique_groups[id(group)] = group
        rank = self._ordering.__getitem__
        for group in unique_groups.values():
            yield sorted(group, key=rank)

    def get_siblings(self, a):
        """Return all of the items joined with *a*, including itself."""
        siblings = self._mapping.get(a, [a])
        return sorted(siblings, key=self._ordering.get)
class GrouperView:
    """Immutable view over a `.Grouper`."""

    def __init__(self, grouper):
        self._grouper = grouper

    def __contains__(self, item):
        return item in self._grouper

    def __iter__(self):
        return iter(self._grouper)

    def joined(self, a, b):
        # Delegates to the wrapped grouper.
        return self._grouper.joined(a, b)

    def get_siblings(self, a):
        # Delegates to the wrapped grouper.
        return self._grouper.get_siblings(a)
def simple_linear_interpolation(a, steps):
    """
    Resample an array with ``steps - 1`` points between original point pairs.

    Along each column of *a*, ``(steps - 1)`` points are introduced between
    each pair of consecutive original values; the new values are linearly
    interpolated.

    Parameters
    ----------
    a : array, shape (n, ...)
    steps : int

    Returns
    -------
    array
        shape ``((n - 1) * steps + 1, ...)``
    """
    n = len(a)
    # Flatten all trailing dimensions so that each column is one 1D series.
    columns = a.reshape((n, -1)).T
    # Positions of the original samples on the refined grid ...
    xp = np.arange(n) * steps
    # ... and the refined grid itself.
    x = np.arange((n - 1) * steps + 1)
    resampled = np.column_stack([np.interp(x, xp, fp) for fp in columns])
    return resampled.reshape((len(x),) + a.shape[1:])
def delete_masked_points(*args):
    """
    Find all masked and/or non-finite points in a set of arguments,
    and return the arguments with only the unmasked points remaining.

    Arguments can be in any of 5 categories:

    1) 1-D masked arrays
    2) 1-D ndarrays
    3) ndarrays with more than one dimension
    4) other non-string iterables
    5) anything else

    The first argument must be in one of the first four categories;
    any argument with a length differing from that of the first
    argument (and hence anything in category 5) then will be
    passed through unchanged.

    Masks are obtained from all arguments of the correct length
    in categories 1, 2, and 4; a point is bad if masked in a masked
    array or if it is a nan or inf.  No attempt is made to
    extract a mask from categories 2, 3, and 4 if `numpy.isfinite`
    does not yield a Boolean array.

    All input arguments that are not passed unchanged are returned
    as ndarrays after removing the points or rows corresponding to
    masks in any of the arguments.

    A vastly simpler version of this function was originally
    written as a helper for Axes.scatter().

    """
    if len(args) == 0:
        return ()
    first = args[0]
    # The first argument must be a non-string iterable.
    if isinstance(first, str) or not np.iterable(first):
        raise ValueError("First argument must be a sequence")
    nrecs = len(first)

    # First pass: convert every argument of matching length to an array and
    # remember which arguments take part in the filtering.
    margs = []
    seqlist = []
    for x in args:
        is_seq = bool(
            not isinstance(x, str) and np.iterable(x) and len(x) == nrecs)
        if is_seq:
            if isinstance(x, np.ma.MaskedArray):
                if x.ndim > 1:
                    raise ValueError("Masked arrays must be 1-D")
            else:
                x = np.asarray(x)
        seqlist.append(is_seq)
        margs.append(x)

    # Second pass: collect the masks, which are True where good.
    masks = []
    for is_seq, x in zip(seqlist, margs):
        if not is_seq:
            continue
        if x.ndim > 1:
            continue  # Don't try to get nan locations unless 1-D.
        if isinstance(x, np.ma.MaskedArray):
            masks.append(~np.ma.getmaskarray(x))  # invert the mask
            xd = x.data
        else:
            xd = x
        try:
            finite = np.isfinite(xd)
            if isinstance(finite, np.ndarray):
                masks.append(finite)
        except Exception:  # Fixme: put in tuple of possible exceptions?
            pass

    # Keep only the points that are good in every mask.
    if masks:
        good = np.logical_and.reduce(masks)
        igood = good.nonzero()[0]
        if len(igood) < nrecs:
            for i, x in enumerate(margs):
                if seqlist[i]:
                    margs[i] = x[igood]

    # Masked arrays are returned as plain ndarrays.
    for i, x in enumerate(margs):
        if seqlist[i] and isinstance(x, np.ma.MaskedArray):
            margs[i] = x.filled()
    return margs
def _combine_masks(*args):
    """
    Find all masked and/or non-finite points in a set of arguments,
    and return the arguments as masked arrays with a common mask.

    Arguments can be in any of 5 categories:

    1) 1-D masked arrays
    2) 1-D ndarrays
    3) ndarrays with more than one dimension
    4) other non-string iterables
    5) anything else

    The first argument must be in one of the first four categories;
    any argument with a length differing from that of the first
    argument (and hence anything in category 5) then will be
    passed through unchanged.

    Masks are obtained from all arguments of the correct length
    in categories 1, 2, and 4; a point is bad if masked in a masked
    array or if it is a nan or inf.  No attempt is made to
    extract a mask from categories 2 and 4 if `numpy.isfinite`
    does not yield a Boolean array.  Category 3 is included to
    support RGB or RGBA ndarrays, which are assumed to have only
    valid values and which are passed through unchanged.

    All input arguments that are not passed unchanged are returned
    as masked arrays if any masked points are found, otherwise as
    ndarrays.

    """
    if len(args) == 0:
        return ()
    if is_scalar_or_string(args[0]):
        raise ValueError("First argument must be a sequence")
    nrecs = len(args[0])

    margs = []  # Output args; some may be modified.
    seqlist = [False] * len(args)  # Flags: True if output will be masked.
    masks = []  # List of masks.
    for i, x in enumerate(args):
        if is_scalar_or_string(x) or len(x) != nrecs:
            margs.append(x)  # Leave it unmodified.
            continue
        if isinstance(x, np.ma.MaskedArray) and x.ndim > 1:
            raise ValueError("Masked arrays must be 1-D")
        try:
            x = np.asanyarray(x)
        except (VisibleDeprecationWarning, ValueError):
            # NumPy 1.19 raises a warning about ragged arrays, but we want
            # to accept basically anything here.
            x = np.asanyarray(x, dtype=object)
        if x.ndim == 1:
            x = safe_masked_invalid(x)
            seqlist[i] = True
            if np.ma.is_masked(x):
                masks.append(np.ma.getmaskarray(x))
        margs.append(x)  # Possibly modified.

    if masks:
        # A point is bad if it is bad in any of the collected masks.
        common = np.logical_or.reduce(masks)
        for i, x in enumerate(margs):
            if seqlist[i]:
                margs[i] = np.ma.array(x, mask=common)
    return margs
def _broadcast_with_masks(*args, compress=False):
    """
    Broadcast inputs against each other and apply one combined mask.

    Parameters
    ----------
    *args : array-like
        The inputs to broadcast.  The masks of any masked arrays among them
        are broadcast together with the data and OR-ed into a single mask.
    compress : bool, default: False
        If True, masked entries are dropped from every output.  If False,
        every output is converted to float and masked entries become NaN.
        Ignored when none of the inputs is a masked array.

    Returns
    -------
    list of array-like
        One flattened array per input.  When no input is a masked array the
        broadcast inputs are simply raveled.
    """
    n_inputs = len(args)
    arg_masks = [arg.mask for arg in args
                 if isinstance(arg, np.ma.MaskedArray)]
    # Broadcast data and masks in one call so they all share a single shape.
    broadcasted = np.broadcast_arrays(*args, *arg_masks)
    data = broadcasted[:n_inputs]
    bcast_masks = broadcasted[n_inputs:]

    if not bcast_masks:
        return [np.ravel(values) for values in data]

    # A point is masked in the output if it is masked in any input.
    combined = np.logical_or.reduce(bcast_masks)
    if compress:
        return [np.ma.array(values, mask=combined).compressed()
                for values in data]
    return [
        np.ma.array(values, mask=combined, dtype=float).filled(np.nan).ravel()
        for values in data
    ]
def boxplot_stats(X, whis=1.5, bootstrap=None, labels=None, autorange=False):
    r"""
    Compute the statistics needed to draw box and whisker plots.

    The result is a list of dictionaries suitable for `~.Axes.bxp`.

    Parameters
    ----------
    X : array-like
        Data that will be represented in the boxplots. Should have 2 or
        fewer dimensions.

    whis : float or (float, float), default: 1.5
        The position of the whiskers.

        If a float, the lower whisker is at the lowest datum above
        ``Q1 - whis*(Q3-Q1)``, and the upper whisker at the highest datum
        below ``Q3 + whis*(Q3-Q1)``, where Q1 and Q3 are the first and third
        quartiles.  The default value of ``whis = 1.5`` corresponds to
        Tukey's original definition of boxplots.

        If a pair of floats, they indicate the percentiles at which to draw
        the whiskers (e.g., (5, 95)).  In particular, setting this to
        (0, 100) results in whiskers covering the whole range of the data.

        In the edge case where ``Q1 == Q3``, *whis* is automatically set to
        (0, 100) (cover the whole range of the data) if *autorange* is True.

        Beyond the whiskers, data are considered outliers and are plotted as
        individual points.

    bootstrap : int, optional
        Number of times the confidence intervals around the median
        should be bootstrapped (percentile method).

    labels : list of str, optional
        Labels for each dataset. Length must be compatible with
        dimensions of *X*.

    autorange : bool, default: False
        When `True` and the data are distributed such that the 25th and 75th
        percentiles are equal, ``whis`` is set to (0, 100) such that the
        whisker ends are at the minimum and maximum of the data.

    Returns
    -------
    list of dict
        A list of dictionaries containing the results for each column
        of data. Keys of each dictionary are the following:

        ========   ===================================
        Key        Value Description
        ========   ===================================
        label      tick label for the boxplot
        mean       arithmetic mean value
        med        50th percentile
        q1         first quartile (25th percentile)
        q3         third quartile (75th percentile)
        iqr        interquartile range
        cilo       lower notch around the median
        cihi       upper notch around the median
        whislo     end of the lower whisker
        whishi     end of the upper whisker
        fliers     outliers
        ========   ===================================

        The ``label`` key is only present when a label was supplied.

    Raises
    ------
    ValueError
        If *labels* does not have one entry per dataset, or if *whis* is
        neither a real number nor an iterable of percentiles.

    Notes
    -----
    Non-bootstrapping approach to confidence interval uses Gaussian-based
    asymptotic approximation:

    .. math::

        \mathrm{med} \pm 1.57 \times \frac{\mathrm{iqr}}{\sqrt{N}}

    General approach from:
    McGill, R., Tukey, J.W., and Larsen, W.A. (1978) "Variations of
    Boxplots", The American Statistician, 32:12-16.
    """

    def _bootstrap_median(data, N=5000):
        # 95% confidence interval of the median via the percentile method:
        # resample *data* with replacement N times and take the 2.5th and
        # 97.5th percentiles of the resampled medians.
        n_obs = len(data)
        resample_idx = np.random.randint(n_obs, size=(N, n_obs))
        resampled = data[resample_idx]
        medians = np.median(resampled, axis=1, overwrite_input=True)
        return np.percentile(medians, [2.5, 97.5])

    def _compute_conf_interval(data, med, iqr, bootstrap):
        if bootstrap is None:
            # Gaussian-based asymptotic approximation (see Notes).
            half_width = 1.57 * iqr / np.sqrt(len(data))
            return med - half_width, med + half_width
        # Bootstrap estimate of the notch locations around the median.
        interval = _bootstrap_median(data, N=bootstrap)
        return interval[0], interval[1]

    # Keys that are all NaN for an empty dataset, in insertion order.
    nan_keys = ('mean', 'med', 'q1', 'q3', 'iqr',
                'cilo', 'cihi', 'whislo', 'whishi')

    # One dict of statistics per dataset.
    bxpstats = []

    # Normalize X to a list of 1D datasets.
    X = _reshape_2D(X, "X")

    if labels is None:
        labels = itertools.repeat(None)
    elif len(labels) != len(X):
        raise ValueError("Dimensions of labels and X must be compatible")

    requested_whis = whis
    for x, label in zip(X, labels):
        stats = {} if label is None else {'label': label}

        # *whis* may have been overridden by autorange for a previous
        # dataset; start each dataset from the caller's value.
        whis = requested_whis

        # Registered up front; the dict is filled in place below.
        bxpstats.append(stats)

        if len(x) == 0:
            stats['fliers'] = np.array([])
            for key in nan_keys:
                stats[key] = np.nan
            continue

        # Drop masked entries and work on a flat plain array.
        x = np.ma.asarray(x)
        x = x.data[~x.mask].ravel()

        stats['mean'] = np.mean(x)

        q1, med, q3 = np.percentile(x, [25, 50, 75])

        iqr = q3 - q1
        stats['iqr'] = iqr
        if iqr == 0 and autorange:
            whis = (0, 100)

        stats['cilo'], stats['cihi'] = _compute_conf_interval(
            x, med, iqr, bootstrap)

        # Limits beyond which data count as outliers.
        if np.iterable(whis) and not isinstance(whis, str):
            loval, hival = np.percentile(x, whis)
        elif np.isreal(whis):
            loval = q1 - whis * iqr
            hival = q3 + whis * iqr
        else:
            raise ValueError('whis must be a float or list of percentiles')

        # Upper whisker: highest datum within the limit, but never below q3.
        within_hi = x[x <= hival]
        if len(within_hi) == 0 or np.max(within_hi) < q3:
            stats['whishi'] = q3
        else:
            stats['whishi'] = np.max(within_hi)

        # Lower whisker: lowest datum within the limit, but never above q1.
        within_lo = x[x >= loval]
        if len(within_lo) == 0 or np.min(within_lo) > q1:
            stats['whislo'] = q1
        else:
            stats['whislo'] = np.min(within_lo)

        # Everything outside the whiskers, low side first.
        stats['fliers'] = np.concatenate([
            x[x < stats['whislo']],
            x[x > stats['whishi']],
        ])

        stats['q1'], stats['med'], stats['q3'] = q1, med, q3

    return bxpstats
#: Maps short codes for line style to their full name used by backends.
ls_mapper = {
    '-': 'solid',
    '--': 'dashed',
    '-.': 'dashdot',
    ':': 'dotted',
}
#: Maps full names for line styles used by backends to their short codes.
ls_mapper_r = dict(zip(ls_mapper.values(), ls_mapper.keys()))
def contiguous_regions(mask):
    """
    Locate every maximal run of True values in *mask*.

    Returns a list of ``(ind0, ind1)`` pairs such that
    ``mask[ind0:ind1].all()`` is True, covering all such regions.  An empty
    *mask* yields an empty list.
    """
    mask = np.asarray(mask, dtype=bool)
    if mask.size == 0:
        return []

    # Positions where the value flips; +1 so that each index points at the
    # first element *after* the flip.
    (flips,) = np.nonzero(mask[:-1] != mask[1:])
    # List operations are faster for moderately sized arrays.
    boundaries = (flips + 1).tolist()

    # A run touching either end has no flip there; add the bound by hand.
    if mask[0]:
        boundaries.insert(0, 0)
    if mask[-1]:
        boundaries.append(len(mask))

    # Boundaries now alternate start, stop, start, stop, ...
    starts = boundaries[::2]
    stops = boundaries[1::2]
    return list(zip(starts, stops))
def is_math_text(s):
    """
    Return whether the string *s* contains math expressions.

    *s* is converted with `str` first.  The check passes when the number of
    dollar signs that are not preceded by a backslash is positive and even.
    """
    text = str(s)
    unescaped = text.count('$') - text.count('\\$')
    return unescaped > 0 and unescaped % 2 == 0
def _to_unmasked_float_array(x):
    """
    Convert a sequence to a plain float array.

    Inputs exposing a ``mask`` attribute are treated as masked arrays and
    their masked values are replaced by nan; anything else goes straight
    through `numpy.asarray`.
    """
    if not hasattr(x, 'mask'):
        return np.asarray(x, float)
    masked = np.ma.asarray(x, float)
    return masked.filled(np.nan)
def _check_1d(x):
    """Convert scalars to 1D arrays; pass-through arrays as is."""
    # Unpack in case of e.g. Pandas or xarray object
    x = _unpack_to_numpy(x)
    # plot requires `shape` and `ndim`.  Objects that already provide both
    # and have at least one dimension are returned untouched; everything
    # else is forced to a numpy array.  Note this will strip unit
    # information.
    looks_like_array = hasattr(x, 'shape') and hasattr(x, 'ndim')
    if looks_like_array and len(x.shape) >= 1:
        return x
    return np.atleast_1d(x)
def _reshape_2D(X, name):
    """
    Convert ndarrays and lists of iterables to lists of 1D arrays, using
    Fortran ordering.

    1D ndarrays are returned in a singleton list containing them.  2D
    ndarrays are converted to the list of their *columns*.  Lists of
    iterables are converted by applying `numpy.asanyarray` to each of their
    elements.  Empty input yields ``[[]]``.

    *name* is used to generate the error message for invalid inputs, which
    is raised as a `ValueError` when more than 2 dimensions are found.
    """
    # Unpack in case of e.g. Pandas or xarray object
    X = _unpack_to_numpy(X)

    if isinstance(X, np.ndarray):
        # Transpose so that iterating yields columns.
        columns = X.T
        if len(columns) == 0:
            return [[]]
        if columns.ndim == 1 and np.ndim(columns[0]) == 0:
            # 1D array of scalars: directly return it.
            return [columns]
        if columns.ndim in [1, 2]:
            # 2D array, or 1D array of iterables: flatten them first.
            return [np.reshape(column, -1) for column in columns]
        raise ValueError(f'{name} must have 2 or fewer dimensions')

    # From here on X is a generic sequence of items.
    if len(X) == 0:
        return [[]]

    flattened = []
    only_scalars = True
    for item in X:
        # Strings are iterable but are treated as singletons.
        if not isinstance(item, str) and np.iterable(item):
            only_scalars = False
        item = np.asanyarray(item)
        if np.ndim(item) > 1:
            raise ValueError(f'{name} must have 2 or fewer dimensions')
        flattened.append(item.reshape(-1))

    if only_scalars:
        # 1D array of scalars: directly return it.
        return [np.reshape(flattened, -1)]
    # 2D array, or 1D array of iterables: use flattened version.
    return flattened
def violin_stats(X, method, points=100, quantiles=None):
    """
    Compute the data needed to draw a series of violin plots.

    Users can skip this function and pass a user-defined set of dictionaries
    with the same keys to `~.axes.Axes.violinplot` instead of using
    Matplotlib to do the calculations.  See the *Returns* section below for
    the keys that must be present in the dictionaries.

    Parameters
    ----------
    X : array-like
        Sample data that will be used to produce the gaussian kernel density
        estimates.  Must have 2 or fewer dimensions.

    method : callable
        The method used to calculate the kernel density estimate for each
        column of data.  When called via ``method(v, coords)``, it should
        return a vector of the values of the KDE evaluated at the values
        specified in coords.

    points : int, default: 100
        Defines the number of points to evaluate each of the gaussian kernel
        density estimates at.

    quantiles : array-like, default: None
        Defines (if not None) a list of floats in interval [0, 1] for each
        column of data, which represents the quantiles that will be rendered
        for that column of data.  Must have 2 or fewer dimensions.  1D array
        will be treated as a singleton list containing them.

    Returns
    -------
    list of dict
        A list of dictionaries containing the results for each column of
        data.  The dictionaries contain at least the following:

        - coords: A list of scalars containing the coordinates this
          particular kernel density estimate was evaluated at.
        - vals: A list of scalars containing the values of the kernel
          density estimate at each of the coordinates given in *coords*.
        - mean: The mean value for this column of data.
        - median: The median value for this column of data.
        - min: The minimum value for this column of data.
        - max: The maximum value for this column of data.
        - quantiles: The quantile values for this column of data.

    Raises
    ------
    ValueError
        If the number of quantile lists differs from the number of
        datasets.
    """
    # Want X to be a list of data sequences
    X = _reshape_2D(X, "X")

    # Give every dataset its own list of quantiles; an absent or empty
    # *quantiles* means "no quantiles" for each of them.
    if quantiles is None or len(quantiles) == 0:
        quantiles = [[]] * len(X)
    else:
        quantiles = _reshape_2D(quantiles, "quantiles")

    if len(X) != len(quantiles):
        raise ValueError("List of violinplot statistics and quantiles values"
                         " must have the same length")

    vpstats = []
    for data, fractions in zip(X, quantiles):
        lowest = np.min(data)
        highest = np.max(data)
        # *fractions* are in [0, 1]; np.percentile wants percentages.
        quantile_val = np.percentile(data, 100 * fractions)

        # The KDE is evaluated on an even grid spanning the data range.
        coords = np.linspace(lowest, highest, points)
        vpstats.append({
            'vals': method(data, coords),
            'coords': coords,
            'mean': np.mean(data),
            'median': np.median(data),
            'min': lowest,
            'max': highest,
            'quantiles': np.atleast_1d(quantile_val),
        })

    return vpstats
def pts_to_prestep(x, *args):
    """
    Convert continuous line to pre-steps.

    Given a set of ``N`` points, convert to ``2N - 1`` points, which when
    connected linearly give a step function which changes values at the
    beginning of the intervals.

    Parameters
    ----------
    x : array
        The x location of the steps. May be empty.

    y1, ..., yp : array
        y arrays to be turned into steps; all must be the same length as
        ``x``.

    Returns
    -------
    array
        The x and y values converted to steps in the same order as the input;
        can be unpacked as ``x_out, y1_out, ..., yp_out``.  If the input is
        length ``N``, each of these arrays will be length ``2N - 1``.  For
        ``N=0``, the length will be 0.

    Examples
    --------
    >>> x_s, y1_s, y2_s = pts_to_prestep(x, y1, y2)
    """
    # One row for x plus one per y array; ``max`` keeps the width at 0
    # (rather than -1) for empty input.
    steps = np.zeros((1 + len(args), max(2 * len(x) - 1, 0)))
    # In all `pts_to_*step` functions, only assign once using *x* and *args*,
    # as converting to an array may be expensive.
    steps[0, 0::2] = x
    # Each inserted point keeps the x of the point before it ...
    steps[0, 1::2] = steps[0, 0:-2:2]
    steps[1:, 0::2] = args
    # ... and already takes the y of the point after it.
    steps[1:, 1::2] = steps[1:, 2::2]
    return steps
def pts_to_poststep(x, *args):
    """
    Convert continuous line to post-steps.

    Given a set of ``N`` points convert to ``2N - 1`` points, which when
    connected linearly give a step function which changes values at the end
    of the intervals.

    Parameters
    ----------
    x : array
        The x location of the steps. May be empty.

    y1, ..., yp : array
        y arrays to be turned into steps; all must be the same length as
        ``x``.

    Returns
    -------
    array
        The x and y values converted to steps in the same order as the input;
        can be unpacked as ``x_out, y1_out, ..., yp_out``.  If the input is
        length ``N``, each of these arrays will be length ``2N - 1``.  For
        ``N=0``, the length will be 0.

    Examples
    --------
    >>> x_s, y1_s, y2_s = pts_to_poststep(x, y1, y2)
    """
    # One row for x plus one per y array; ``max`` keeps the width at 0
    # (rather than -1) for empty input.
    steps = np.zeros((1 + len(args), max(2 * len(x) - 1, 0)))
    # Assign from *x* and *args* only once, as converting them to arrays
    # may be expensive.
    steps[0, 0::2] = x
    # Each inserted point already sits at the x of the point after it ...
    steps[0, 1::2] = steps[0, 2::2]
    steps[1:, 0::2] = args
    # ... but still holds the y of the point before it.
    steps[1:, 1::2] = steps[1:, 0:-2:2]
    return steps
def pts_to_midstep(x, *args):
    """
    Convert continuous line to mid-steps.

    Given a set of ``N`` points convert to ``2N`` points which when connected
    linearly give a step function which changes values at the middle of the
    intervals.

    Parameters
    ----------
    x : array
        The x location of the steps. May be empty.

    y1, ..., yp : array
        y arrays to be turned into steps; all must be the same length as
        ``x``.

    Returns
    -------
    array
        The x and y values converted to steps in the same order as the input;
        can be unpacked as ``x_out, y1_out, ..., yp_out``.  If the input is
        length ``N``, each of these arrays will be length ``2N``.

    Examples
    --------
    >>> x_s, y1_s, y2_s = pts_to_midstep(x, y1, y2)
    """
    steps = np.zeros((1 + len(args), 2 * len(x)))
    x = np.asanyarray(x)

    # Every interior step happens halfway between two neighbouring points,
    # and each midpoint appears twice (end of one plateau, start of next).
    midpoints = (x[:-1] + x[1:]) / 2
    xs = steps[0]
    xs[1:-1:2] = midpoints
    xs[2::2] = midpoints
    # The outer ends stay at the first and last input point.  Slices rather
    # than plain indices so this also works for zero-sized input.
    xs[:1] = x[:1]
    xs[-1:] = x[-1:]

    # Each y value is held over a pair of consecutive output points.
    ys = steps[1:]
    ys[:, 0::2] = args
    ys[:, 1::2] = ys[:, 0::2]
    return steps
# Maps a step-style name to the function that converts (x, y) accordingly;
# 'default' hands the points back unchanged.
STEP_LOOKUP_MAP = {
    'default': lambda x, y: (x, y),
    'steps': pts_to_prestep,
    'steps-pre': pts_to_prestep,
    'steps-post': pts_to_poststep,
    'steps-mid': pts_to_midstep,
}
def index_of(y):
    """
    A helper function to create reasonable x values for the given *y*.

    This is used for plotting (x, y) if x values are not explicitly given.

    First try ``y.index`` (assuming *y* is a `pandas.Series`), if that
    fails, use ``range(len(y))``.

    This will be extended in the future to deal with more types of
    labeled data.

    Parameters
    ----------
    y : float or array-like

    Returns
    -------
    x, y : ndarray
        The x and y values to plot.

    Raises
    ------
    ValueError
        If *y* cannot be turned into an at-least-1D array.
    """
    # Series-like objects carry their own x values.
    with contextlib.suppress(AttributeError):
        return y.index.to_numpy(), y.to_numpy()

    try:
        values = _check_1d(y)
    except (VisibleDeprecationWarning, ValueError):
        # NumPy 1.19 will warn on ragged input, and we can't actually use it.
        pass
    else:
        positions = np.arange(values.shape[0], dtype=float)
        return positions, values
    raise ValueError('Input could not be cast to an at-least-1D NumPy array')
def safe_first_element(obj):
    """
    Return the first element in *obj*.

    This is a type-independent way of obtaining the first element,
    supporting both index access and the iterator protocol.

    Raises
    ------
    RuntimeError
        If *obj* is an iterator that cannot be indexed (e.g. a generator).
    """
    if not isinstance(obj, collections.abc.Iterator):
        return next(iter(obj))

    # Needed to accept `array.flat` as input: np.flatiter reports as an
    # instance of collections.Iterator but can still be indexed via [].
    # This has the side effect of re-setting the iterator, but that is
    # acceptable.
    try:
        return obj[0]
    except TypeError:
        pass
    raise RuntimeError("matplotlib does not support generators as input")
def _safe_first_finite(obj):
    """
    Return the first finite element in *obj* if one is available.
    Otherwise, return the first element.

    This is a method for internal use.

    This is a type-independent way of obtaining the first finite element,
    supporting both index access and the iterator protocol.

    ``None`` counts as non-finite.  Elements that are not scalars (e.g. the
    rows of a 2D input) and values NumPy cannot test for finiteness are
    treated as finite.  For `numpy.flatiter` input the first element is
    returned without any filtering.

    Raises
    ------
    RuntimeError
        If *obj* is an iterator other than a `numpy.flatiter`.
    """
    def safe_isfinite(val):
        if val is None:
            return False
        try:
            return math.isfinite(val)
        except (TypeError, ValueError):
            # if the outer object is 2d, then val is a 1d array, and
            # - math.isfinite(numpy.zeros(3)) raises TypeError
            # - math.isfinite(torch.zeros(3)) raises ValueError
            pass
        try:
            return np.isfinite(val) if np.isscalar(val) else True
        except TypeError:
            # This is something that NumPy cannot make heads or tails of,
            # assume "finite"
            return True

    if isinstance(obj, np.flatiter):
        # TODO do the finite filtering on this
        return obj[0]
    elif isinstance(obj, collections.abc.Iterator):
        raise RuntimeError("matplotlib does not support generators as input")
    else:
        for val in obj:
            if safe_isfinite(val):
                return val
        # Nothing finite found: fall back to the plain first element.
        return safe_first_element(obj)
def sanitize_sequence(data):
    """
    Convert dictview objects to list. Other inputs are returned unchanged.
    """
    if isinstance(data, collections.abc.MappingView):
        return list(data)
    return data
def normalize_kwargs(kw, alias_mapping=None):
    """
    Helper function to normalize kwarg inputs.

    Every key of *kw* that is a known alias is replaced by its canonical
    name; all other keys are kept as they are.

    Parameters
    ----------
    kw : dict or None
        A dict of keyword arguments.  None is explicitly supported and
        treated as an empty dict, to support functions with an optional
        parameter of the form ``props=None``.

    alias_mapping : dict or Artist subclass or Artist instance, optional
        A mapping between a canonical name to a list of aliases.

        If an Artist subclass or instance is passed, use its properties alias
        mapping (its ``_alias_map`` attribute, or no aliases if it has
        none).

    Returns
    -------
    dict
        A new dict keyed by canonical names.

    Raises
    ------
    TypeError
        If two keys of *kw* resolve to the same canonical name.  This
        matches what Python raises if invalid arguments/keyword arguments
        are passed to a callable.
    """
    from matplotlib.artist import Artist

    if kw is None:
        return {}

    # Resolve *alias_mapping* to a plain {canonical: [aliases]} dict.
    if alias_mapping is None:
        alias_mapping = {}
    elif (isinstance(alias_mapping, type) and issubclass(alias_mapping, Artist)
          or isinstance(alias_mapping, Artist)):
        alias_mapping = getattr(alias_mapping, "_alias_map", {})

    # Invert the mapping: alias -> canonical name.
    to_canonical = {}
    for canonical, alias_list in alias_mapping.items():
        for alias in alias_list:
            to_canonical[alias] = canonical

    spelled_as = {}  # canonical name -> the key it was first supplied under
    normalized = {}
    for key, value in kw.items():
        canonical = to_canonical.get(key, key)
        if canonical in spelled_as:
            raise TypeError(f"Got both {spelled_as[canonical]!r} and "
                            f"{key!r}, which are aliases of one another")
        spelled_as[canonical] = key
        normalized[canonical] = value

    return normalized
@contextlib.contextmanager
def _lock_path(path):
    """
    Context manager for locking a path.

    Usage::

        with _lock_path(path):
            ...

    Another thread or process that attempts to lock the same path will wait
    until this context manager is exited.

    The lock is implemented by creating a temporary file in the parent
    directory, so that directory must exist and be writable.

    Raises
    ------
    TimeoutError
        If the lock file still exists after 50 attempts spaced 0.1 seconds
        apart.
    """
    path = Path(path)
    lock_path = path.with_name(path.name + ".matplotlib-lock")
    retries = 50
    sleeptime = 0.1
    for _ in range(retries):
        try:
            # Mode "x" fails if the file already exists, so successfully
            # creating it is what acquires the lock.
            with lock_path.open("xb"):
                break
        except FileExistsError:
            time.sleep(sleeptime)
    else:
        raise TimeoutError("""\
Lock error: Matplotlib failed to acquire the following lock file:
    {}
This may be due to another process holding this lock file. If you are sure no
other Matplotlib process is running, remove this file and try again.""".format(
            lock_path))
    try:
        yield
    finally:
        # The timeout message above tells users to delete a stale lock file
        # by hand; if that happened while we held the lock, there is nothing
        # left to clean up and the exit must not fail over it.
        lock_path.unlink(missing_ok=True)
def _topmost_artist(
        artists,
        _cached_max=functools.partial(max, key=operator.attrgetter("zorder"))):
    """
    Return the artist with the highest ``zorder`` from a list.

    Among artists sharing the highest zorder the *last* one wins, as it will
    be drawn on top of the others.  `max` keeps the first maximum it meets,
    hence the list is scanned back to front.
    """
    back_to_front = reversed(artists)
    return _cached_max(back_to_front)
def _str_equal(obj, s):
    """
    Return whether *obj* is a string equal to string *s*.

    The type check comes first so that numpy arrays never reach the
    comparison: a naive ``obj == s`` would yield an array for them, which
    cannot be used in a boolean context.
    """
    if not isinstance(obj, str):
        return False
    return obj == s
def _str_lower_equal(obj, s):
    """
    Return whether *obj* is a string equal, when lowercased, to string *s*.

    The type check comes first so that numpy arrays never reach the
    comparison: a naive ``obj == s`` would yield an array for them, which
    cannot be used in a boolean context.
    """
    if not isinstance(obj, str):
        return False
    return obj.lower() == s
def _array_perimeter(arr):
    """
    Get the elements on the perimeter of *arr*.

    Parameters
    ----------
    arr : ndarray, shape (M, N)
        The input array.

    Returns
    -------
    ndarray, shape (2*(M - 1) + 2*(N - 1),)
        The elements on the perimeter of the array::

           [arr[0, 0], ..., arr[0, -1], ..., arr[-1, -1], ..., arr[-1, 0], ...]

    Examples
    --------
    >>> i, j = np.ogrid[:3, :4]
    >>> a = i*10 + j
    >>> a
    array([[ 0,  1,  2,  3],
           [10, 11, 12, 13],
           [20, 21, 22, 23]])
    >>> _array_perimeter(a)
    array([ 0,  1,  2,  3, 13, 23, 22, 21, 20, 10])
    """
    # Each side is a half-open run that stops just short of the next corner,
    # so no corner is emitted twice.
    ascending = slice(0, -1)        # [0 ... -1)
    descending = slice(-1, 0, -1)   # [-1 ... 0)
    top = arr[0, ascending]
    right = arr[ascending, -1]
    bottom = arr[-1, descending]
    left = arr[descending, 0]
    return np.concatenate((top, right, bottom, left))
def _unfold(arr, axis, size, step):
    """
    Append an extra dimension containing sliding windows along *axis*.

    All windows are of size *size* and begin with every *step* elements.

    Parameters
    ----------
    arr : ndarray, shape (N_1, ..., N_k)
        The input array
    axis : int
        Axis along which the windows are extracted
    size : int
        Size of the windows
    step : int
        Stride between first elements of subsequent windows.

    Returns
    -------
    ndarray, shape (N_1, ..., 1 + (N_axis-size)/step, ..., N_k, size)
        A read-only strided view of *arr*.

    Examples
    --------
    >>> i, j = np.ogrid[:3, :7]
    >>> a = i*10 + j
    >>> a
    array([[ 0,  1,  2,  3,  4,  5,  6],
           [10, 11, 12, 13, 14, 15, 16],
           [20, 21, 22, 23, 24, 25, 26]])
    >>> _unfold(a, axis=1, size=3, step=2)
    array([[[ 0,  1,  2],
            [ 2,  3,  4],
            [ 4,  5,  6]],
           [[10, 11, 12],
            [12, 13, 14],
            [14, 15, 16]],
           [[20, 21, 22],
            [22, 23, 24],
            [24, 25, 26]]])
    """
    # The appended axis walks through one window, element by element, using
    # the element stride of *axis*.
    window_shape = list(arr.shape) + [size]
    window_strides = list(arr.strides) + [arr.strides[axis]]
    # Along *axis* itself we now hop from one window start to the next.
    window_shape[axis] = (window_shape[axis] - size) // step + 1
    window_strides[axis] *= step
    return np.lib.stride_tricks.as_strided(
        arr, shape=window_shape, strides=window_strides, writeable=False)
def _array_patch_perimeters(x, rstride, cstride):
    """
    Extract perimeters of patches from *x*.

    Extracted patches are of size (*rstride* + 1) x (*cstride* + 1) and
    share perimeters with their neighbors. The ordering of the vertices
    matches that returned by ``_array_perimeter``.

    Parameters
    ----------
    x : ndarray, shape (N, M)
        Input array
    rstride : int
        Vertical (row) stride between corresponding elements of each patch
    cstride : int
        Horizontal (column) stride between corresponding elements of each
        patch

    Returns
    -------
    ndarray, shape (N/rstride * M/cstride, 2 * (rstride + cstride))
    """
    assert rstride > 0 and cstride > 0
    assert (x.shape[0] - 1) % rstride == 0
    assert (x.shape[1] - 1) % cstride == 0
    # We build up each perimeter from four half-open intervals. Here is an
    # illustrated explanation for rstride == cstride == 3
    #
    #       T T T R
    #       L     R
    #       L     R
    #       L B B B
    #
    # where T means that this element will be in the top array, R for right,
    # B for bottom and L for left. Each of the arrays below has a shape of:
    #
    #    (number of perimeters that can be extracted vertically,
    #     number of perimeters that can be extracted horizontally,
    #     cstride for top and bottom and rstride for left and right)
    #
    # Note that _unfold doesn't incur any memory copies, so the only costly
    # operation here is the np.concatenate.
    top = _unfold(x[:-1:rstride, :-1], axis=1, size=cstride, step=cstride)
    right = _unfold(x[:-1, cstride::cstride],
                    axis=0, size=rstride, step=rstride)
    # Bottom and left run against the index direction, hence the reversal
    # of the window axis.
    bottom = _unfold(x[rstride::rstride, 1:],
                     axis=1, size=cstride, step=cstride)[..., ::-1]
    left = _unfold(x[1:, :-1:cstride],
                   axis=0, size=rstride, step=rstride)[..., ::-1]

    vertices_per_patch = 2 * (rstride + cstride)
    perimeters = np.concatenate((top, right, bottom, left), axis=2)
    return perimeters.reshape(-1, vertices_per_patch)
@contextlib.contextmanager
def _setattr_cm(obj, **kwargs):
    """
    Temporarily set some attributes; restore original state at context exit.

    Attributes that lived in the instance dict, or that resolve to a
    property on the class, get their previous value back.  Attributes that
    did not exist, or that were only inherited from the class, are removed
    from the instance again.
    """
    missing = object()
    # attr -> value to restore, or *missing* to request deletion on exit.
    restore = {}
    for attr in kwargs:
        current = getattr(obj, attr, missing)
        if attr in obj.__dict__ or current is missing:
            # Either a genuine instance attribute or no attribute at all:
            # the looked-up value (or its absence) is exactly what has to
            # be put back.
            restore[attr] = current
        elif isinstance(getattr(type(obj), attr), property):
            # A property (but not a general descriptor): write the original
            # value back through it.
            restore[attr] = current
        else:
            # Otherwise this is _something_ we are going to shadow at the
            # instance dict level from higher up in the MRO.  We are going
            # to assume we can delattr(obj, attr) to clean up after
            # ourselves.  It is possible that this code will fail if used
            # with a non-property custom descriptor which implements
            # __set__ (and __delete__ does not act like a stack).  However,
            # this is an internal tool and we do not currently have any
            # custom descriptors.
            restore[attr] = missing

    try:
        for attr, value in kwargs.items():
            setattr(obj, attr, value)
        yield
    finally:
        for attr, previous in restore.items():
            if previous is missing:
                delattr(obj, attr)
            else:
                setattr(obj, attr, previous)
class _OrderedSet(collections.abc.MutableSet):
    """
    A set that iterates in the order its elements were most recently added.
    """

    def __init__(self):
        # Only the keys matter; every value is None.
        self._od = collections.OrderedDict()

    def __contains__(self, key):
        return key in self._od

    def __iter__(self):
        return iter(self._od)

    def __len__(self):
        return len(self._od)

    def add(self, key):
        # Drop any existing entry first so that re-adding a key moves it to
        # the end (and stores the newly passed key object).
        if key in self._od:
            del self._od[key]
        self._od[key] = None

    def discard(self, key):
        if key in self._od:
            del self._od[key]
2128# Agg's buffers are unmultiplied RGBA8888, which neither PyQt<=5.1 nor cairo
2129# support; however, both do support premultiplied ARGB32.
def _premultiplied_argb32_to_unmultiplied_rgba8888(buf):
    """
    Convert a premultiplied ARGB32 buffer to an unmultiplied RGBA8888 buffer.

    A new array is returned; *buf* is left untouched.
    """
    # The in-memory channel order of ARGB32 depends on the byte order.
    if sys.byteorder == "little":
        to_rgba = [2, 1, 0, 3]
    else:
        to_rgba = [1, 2, 3, 0]
    # .take() ensures C-contiguity of the result.
    rgba = np.take(buf, to_rgba, axis=2)
    rgb = rgba[..., :-1]
    alpha = rgba[..., -1]
    # Un-premultiply alpha.  The formula is the same as in cairo-png.c.
    # Fully transparent pixels are skipped to avoid dividing by zero.
    visible = alpha != 0
    visible_alpha = alpha[visible]
    for index in range(rgb.shape[-1]):
        channel = rgb[..., index]  # A view: writes land in *rgba*.
        channel[visible] = (
            (channel[visible].astype(int) * 255 + visible_alpha // 2)
            // visible_alpha)
    return rgba
def _unmultiplied_rgba8888_to_premultiplied_argb32(rgba8888):
    """
    Convert an unmultiplied RGBA8888 buffer to a premultiplied ARGB32 buffer.

    The channels are read along axis 2 of *rgba8888* and reordered according
    to the platform byte order.  A new array is returned; the input is left
    unchanged.
    """
    little_endian = sys.byteorder == "little"
    channel_order = [2, 1, 0, 3] if little_endian else [3, 0, 1, 2]
    argb32 = np.take(rgba8888, channel_order, axis=2)
    # Both pieces below are views into argb32; the alpha view keeps a
    # trailing length-1 axis so that it broadcasts against the colors.
    if little_endian:
        colors = argb32[..., :-1]
        alpha = argb32[..., -1:]
    else:
        alpha = argb32[..., :1]
        colors = argb32[..., 1:]
    fully_opaque = alpha.min() == 0xff
    # Only bother premultiplying when the alpha channel is not fully opaque,
    # as the cost is not negligible.
    if not fully_opaque:
        # The unsafe cast is needed to do the multiplication in-place in an
        # integer buffer.
        np.multiply(colors, alpha / 0xff, out=colors, casting="unsafe")
    return argb32
def _get_nonzero_slices(buf):
    """
    Return the bounds of the nonzero region of a 2D array as a pair of slices.

    Indexing *buf* with the returned ``(row_slice, column_slice)`` pair gives
    the smallest sub-rectangle of *buf* containing every nonzero entry.  When
    *buf* has no nonzero entry at all, ``(slice(0, 0), slice(0, 0))`` is
    returned.
    """
    # Indices of the columns (axis 1) and rows (axis 0) that hold at least
    # one nonzero entry.
    (cols,) = buf.any(axis=0).nonzero()
    (rows,) = buf.any(axis=1).nonzero()
    if not (len(cols) and len(rows)):
        return slice(0, 0), slice(0, 0)
    first_col, last_col = cols[0], cols[-1]
    first_row, last_row = rows[0], rows[-1]
    # Slice stops are exclusive, hence the + 1.
    return slice(first_row, last_row + 1), slice(first_col, last_col + 1)
def _pformat_subprocess(command):
    """
    Pretty-format a subprocess command for printing/logging purposes.

    A string *command* is returned unchanged.  Any other *command* is treated
    as an iterable of arguments: each one is passed through `os.fspath`,
    shell-quoted, and the results are joined with single spaces.
    """
    if isinstance(command, str):
        return command
    quoted_args = [shlex.quote(os.fspath(arg)) for arg in command]
    return " ".join(quoted_args)
def _check_and_log_subprocess(command, logger, **kwargs):
    """
    Run *command*, returning its stdout output if it succeeds.

    If it fails (exits with nonzero return code), raise an exception whose text
    includes the failed command and captured stdout and stderr output.

    Regardless of the return code, the command is logged at DEBUG level on
    *logger*.  In case of success, the output is likewise logged.

    Parameters
    ----------
    command : str or sequence
        The command, passed as is to `subprocess.run`.
    logger
        Object whose ``debug`` method receives the log records.
    **kwargs
        Forwarded to `subprocess.run`, which is always called with
        ``capture_output=True``.

    Returns
    -------
    The captured standard output, exactly as stored on the completed process
    (not decoded here).

    Raises
    ------
    RuntimeError
        If the command exits with a nonzero return code.
    """
    logger.debug('%s', _pformat_subprocess(command))
    proc = subprocess.run(command, capture_output=True, **kwargs)
    if proc.returncode:
        # Captured output is only decoded when it is bytes.  Undecodable
        # bytes are replaced rather than allowed to raise: a
        # UnicodeDecodeError here would hide the actual command failure.
        stdout, stderr = (
            stream.decode(errors="replace") if isinstance(stream, bytes)
            else stream
            for stream in (proc.stdout, proc.stderr))
        raise RuntimeError(
            f"The command\n"
            f" {_pformat_subprocess(command)}\n"
            f"failed and generated the following output:\n"
            f"{stdout}\n"
            f"and the following error:\n"
            f"{stderr}")
    if proc.stdout:
        logger.debug("stdout:\n%s", proc.stdout)
    if proc.stderr:
        logger.debug("stderr:\n%s", proc.stderr)
    return proc.stdout
def _setup_new_guiapp():
    """
    Perform OS-dependent setup when Matplotlib creates a new GUI application.
    """
    # Windows: taskbar icons are only correct once the process has an
    # explicit app user model id.  An OSError from the getter is taken to
    # mean that none has been set yet (so we're not already embedded), in
    # which case the id is set to "matplotlib".
    app_id = "matplotlib"
    try:
        _c_internal_utils.Win32_GetCurrentProcessExplicitAppUserModelID()
    except OSError:
        _c_internal_utils.Win32_SetCurrentProcessExplicitAppUserModelID(
            app_id)
def _format_approx(number, precision):
    """
    Format the number with at most the number of decimals given as precision.
    Remove trailing zeros and possibly the decimal point.

    Parameters
    ----------
    number
        The value to format with the fixed-point ``f`` presentation type.
    precision : int
        The maximum number of digits after the decimal point.

    Returns
    -------
    str
        The formatted number; never an empty string.
    """
    text = f'{number:.{precision}f}'
    # Zeros may only be stripped when they are decimals.  Without a decimal
    # point (e.g. precision=0) the trailing zeros belong to the integer part,
    # and stripping them would turn '10' into '1'.
    if '.' in text:
        text = text.rstrip('0').rstrip('.')
    return text or '0'
def _g_sig_digits(value, delta):
    """
    Return the number of significant digits to %g-format *value*, assuming that
    it is known with an error of *delta*.
    """
    if delta == 0:
        if value == 0:
            # With both value and delta equal to 0, np.spacing below would
            # return 5e-324, which gives rather silly results.
            return 3
        # delta = 0 may occur when trying to format values over a tiny range;
        # in that case, replace it by the distance to the closest float.
        delta = abs(np.spacing(value))
    # For inf or nan, the precision doesn't matter.
    if not math.isfinite(value):
        return 0
    # Example: value = 45.67 and delta = 0.02.  The value contributes
    # floor(log10(45.67)) + 1 = 2 digits before the decimal point, and the
    # error asks for rounding to -floor(log10(0.02)) = 2 digits after it, for
    # a total of 4 significant digits.  A value of 0 contributes 1 "digit"
    # before the decimal point.
    if value:
        digits_before = math.floor(math.log10(abs(value))) + 1
    else:
        digits_before = 1
    digits_after = -math.floor(math.log10(delta))
    return max(0, digits_before + digits_after)
def _unikey_or_keysym_to_mplkey(unikey, keysym):
    """
    Convert a Unicode key or X keysym to a Matplotlib key name.

    The Unicode key is checked first; this avoids having to list most printable
    keysyms such as ``EuroSign``.  Otherwise the name is derived from the
    lowercased *keysym*.
    """
    # For non-printable characters, gtk3 passes "\0" whereas tk passes an "".
    if unikey and unikey.isprintable():
        return unikey
    name = keysym.lower()
    # keypad_x (including kp_enter).
    if name.startswith("kp_"):
        name = name[len("kp_"):]
    # page_{up,down}
    if name.startswith("page_"):
        name = name.replace("page_", "page")
    # alt_l, ctrl_l, shift_l.
    if name.endswith(("_l", "_r")):
        name = name[:-2]
    # meta should be reported as command on mac
    if name == "meta" and sys.platform == "darwin":
        name = "cmd"
    aliases = {
        "return": "enter",
        "prior": "pageup",  # Used by tk.
        "next": "pagedown",  # Used by tk.
    }
    return aliases.get(name, name)
@functools.cache
def _make_class_factory(mixin_class, fmt, attr_name=None):
    """
    Return a function that creates picklable classes inheriting from a mixin.

    After ::

        factory = _make_class_factory(FooMixin, fmt, attr_name)
        FooAxes = factory(Axes)

    ``FooAxes`` is a class that inherits from ``FooMixin`` and ``Axes`` and
    **is picklable** (picklability is what differentiates this from a plain
    call to `type`).  Its ``__name__`` is set to ``fmt.format(Axes.__name__)``
    and, if *attr_name* is not None, the base class is stored on it under
    that attribute name.

    The return value of ``factory`` is memoized: calls with the same ``Axes``
    class always return the same subclass.  This function is memoized too, so
    equal arguments always yield the same ``factory``.
    """

    @functools.cache
    def class_factory(axes_class):
        # A class that already derives from the mixin has been wrapped
        # before (or needs no wrapping); hand it back untouched.
        if issubclass(axes_class, mixin_class):
            return axes_class

        # The parameter is named "axes_class" for backcompat but is really just
        # a base class; no axes semantics are used.
        base_class = axes_class
        # Everything needed to rebuild the class when unpickling.
        constructor_args = (mixin_class, fmt, attr_name, base_class)

        class subcls(mixin_class, base_class):
            # Better approximation than __module__ = "matplotlib.cbook".
            __module__ = mixin_class.__module__

            def __reduce__(self):
                # Pickle as "recreate the class through the factory, make an
                # empty instance, then restore the state".
                state = self.__getstate__()
                return _picklable_class_constructor, constructor_args, state

        generated_name = fmt.format(base_class.__name__)
        subcls.__name__ = generated_name
        subcls.__qualname__ = generated_name
        if attr_name is not None:
            setattr(subcls, attr_name, base_class)
        return subcls

    class_factory.__module__ = mixin_class.__module__
    return class_factory
def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
    """
    Internal helper for _make_class_factory.

    Look up (or create) the class that the factory for
    ``(mixin_class, fmt, attr_name)`` yields for *base_class*, and return a
    new instance of it made with ``__new__`` only, i.e. without calling
    ``__init__``.
    """
    cls = _make_class_factory(mixin_class, fmt, attr_name)(base_class)
    return cls.__new__(cls)
def _is_torch_array(x):
    """
    Check if 'x' is a PyTorch Tensor.

    Return False, never raise, whenever the check cannot be carried out.
    """
    try:
        # torch is deliberately looked up in sys.modules rather than
        # imported: if somebody has created a torch array, torch should
        # already be loaded.
        tensor_type = sys.modules['torch'].Tensor
        return isinstance(x, tensor_type)
    except Exception:  # TypeError, KeyError, AttributeError, maybe others?
        # Attribute access on imported modules may run arbitrary user code,
        # so all exceptions are deliberately caught.
        return False
def _is_jax_array(x):
    """
    Check if 'x' is a JAX Array.

    Return False, never raise, whenever the check cannot be carried out.
    """
    try:
        # jax is deliberately looked up in sys.modules rather than imported:
        # if somebody has created a jax array, jax should already be loaded.
        array_type = sys.modules['jax'].Array
        return isinstance(x, array_type)
    except Exception:  # TypeError, KeyError, AttributeError, maybe others?
        # Attribute access on imported modules may run arbitrary user code,
        # so all exceptions are deliberately caught.
        return False
def _unpack_to_numpy(x):
    """
    Internal helper to extract data from e.g. pandas and xarray objects.

    The first applicable conversion below wins; when none applies, *x* is
    returned unchanged.
    """
    # Already a numpy array: return it directly.
    if isinstance(x, np.ndarray):
        return x
    # Any to_numpy() method is assumed to actually return a numpy array.
    if hasattr(x, 'to_numpy'):
        return x.to_numpy()
    if hasattr(x, 'values'):
        values = x.values
        # A dict, for example, has a 'values' attribute too, but there it is
        # a method rather than a property; a function must not be returned,
        # so only actual arrays are accepted.
        if isinstance(values, np.ndarray):
            return values
    if not (_is_torch_array(x) or _is_jax_array(x)):
        return x
    converted = x.__array__()
    # In case __array__() method does not return a numpy array in future
    return converted if isinstance(converted, np.ndarray) else x
def _auto_format_str(fmt, value):
    """
    Apply *value* to the format string *fmt*.

    Both unnamed %-style and unnamed {}-style placeholders are supported,
    with %-style taking priority: if *fmt* can be %-formatted with *value*,
    that result is used.  Otherwise {}-formatting is applied.  Strings
    without formatting placeholders are passed through as is.

    Examples
    --------
    >>> _auto_format_str('%.2f m', 0.2)
    '0.20 m'
    >>> _auto_format_str('{} m', 0.2)
    '0.2 m'
    >>> _auto_format_str('const', 0.2)
    'const'
    >>> _auto_format_str('%d or {}', 0.2)
    '0 or {}'
    """
    try:
        # The 1-tuple keeps a tuple-valued *value* from being unpacked into
        # several arguments.
        formatted = fmt % (value,)
    except (TypeError, ValueError):
        # Not %-formattable with a single argument: fall back to str.format.
        formatted = fmt.format(value)
    return formatted