Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/matplotlib/cbook.py: 20%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

946 statements  

1""" 

2A collection of utility functions and classes. Originally, many 

3(but not all) were from the Python Cookbook -- hence the name cbook. 

4""" 

5 

6import collections 

7import collections.abc 

8import contextlib 

9import functools 

10import gzip 

11import itertools 

12import math 

13import operator 

14import os 

15from pathlib import Path 

16import shlex 

17import subprocess 

18import sys 

19import time 

20import traceback 

21import types 

22import weakref 

23 

24import numpy as np 

25 

26try: 

27 from numpy.exceptions import VisibleDeprecationWarning # numpy >= 1.25 

28except ImportError: 

29 from numpy import VisibleDeprecationWarning 

30 

31import matplotlib 

32from matplotlib import _api, _c_internal_utils 

33 

34 

def _get_running_interactive_framework():
    """
    Return the interactive framework whose event loop is currently running, if
    any, or "headless" if no event loop can be started, or None.

    Returns
    -------
    Optional[str]
        One of the following values: "qt", "gtk3", "gtk4", "wx", "tk",
        "macosx", "headless", ``None``.
    """
    # ``sys.modules.get(name)`` is used rather than ``name in sys.modules``
    # because entries can also have been explicitly set to None.
    modules = sys.modules
    qt_widgets = next(
        filter(None, (modules.get(name) for name in (
            "PyQt6.QtWidgets", "PySide6.QtWidgets",
            "PyQt5.QtWidgets", "PySide2.QtWidgets"))),
        None)
    if qt_widgets and qt_widgets.QApplication.instance():
        return "qt"
    gtk = modules.get("gi.repository.Gtk")
    if gtk:
        if gtk.MAJOR_VERSION == 4:
            from gi.repository import GLib
            if GLib.main_depth():
                return "gtk4"
        if gtk.MAJOR_VERSION == 3 and gtk.main_level():
            return "gtk3"
    wx = modules.get("wx")
    if wx and wx.GetApp():
        return "wx"
    tkinter = modules.get("tkinter")
    if tkinter:
        mainloop_codes = {tkinter.mainloop.__code__,
                          tkinter.Misc.mainloop.__code__}
        # Walk every thread's frame stack looking for an active tk mainloop.
        for frame in sys._current_frames().values():
            while frame:
                if frame.f_code in mainloop_codes:
                    return "tk"
                frame = frame.f_back
            # Preemptively break reference cycle between locals and the frame.
            del frame
    macosx = modules.get("matplotlib.backends._macosx")
    if macosx and macosx.event_loop_is_running():
        return "macosx"
    if not _c_internal_utils.display_is_valid():
        return "headless"
    return None

83 

84 

def _exception_printer(exc):
    """
    Default exception handler: print the traceback when an interactive
    event loop is running, otherwise re-raise *exc*.
    """
    if _get_running_interactive_framework() not in ["headless", None]:
        traceback.print_exc()
    else:
        raise exc

90 

91 

92class _StrongRef: 

93 """ 

94 Wrapper similar to a weakref, but keeping a strong reference to the object. 

95 """ 

96 

97 def __init__(self, obj): 

98 self._obj = obj 

99 

100 def __call__(self): 

101 return self._obj 

102 

103 def __eq__(self, other): 

104 return isinstance(other, _StrongRef) and self._obj == other._obj 

105 

106 def __hash__(self): 

107 return hash(self._obj) 

108 

109 

110def _weak_or_strong_ref(func, callback): 

111 """ 

112 Return a `WeakMethod` wrapping *func* if possible, else a `_StrongRef`. 

113 """ 

114 try: 

115 return weakref.WeakMethod(func, callback) 

116 except TypeError: 

117 return _StrongRef(func) 

118 

119 

class CallbackRegistry:
    """
    Handle registering, processing, blocking, and disconnecting
    for a set of signals and callbacks:

    >>> def oneat(x):
    ...     print('eat', x)
    >>> def ondrink(x):
    ...     print('drink', x)

    >>> from matplotlib.cbook import CallbackRegistry
    >>> callbacks = CallbackRegistry()

    >>> id_eat = callbacks.connect('eat', oneat)
    >>> id_drink = callbacks.connect('drink', ondrink)

    >>> callbacks.process('drink', 123)
    drink 123
    >>> callbacks.process('eat', 456)
    eat 456
    >>> callbacks.process('be merry', 456)   # nothing will be called

    >>> callbacks.disconnect(id_eat)
    >>> callbacks.process('eat', 456)        # nothing will be called

    >>> with callbacks.blocked(signal='drink'):
    ...     callbacks.process('drink', 123)  # nothing will be called
    >>> callbacks.process('drink', 123)
    drink 123

    In practice, one should always disconnect all callbacks when they are
    no longer needed to avoid dangling references (and thus memory leaks).
    However, real code in Matplotlib rarely does so, and due to its design,
    it is rather difficult to place this kind of code.  To get around this,
    and prevent this class of memory leaks, we instead store weak references
    to bound methods only, so when the destination object needs to die, the
    CallbackRegistry won't keep it alive.

    Parameters
    ----------
    exception_handler : callable, optional
       If not None, *exception_handler* must be a function that takes an
       `Exception` as single parameter.  It gets called with any `Exception`
       raised by the callbacks during `CallbackRegistry.process`, and may
       either re-raise the exception or handle it in another manner.

       The default handler prints the exception (with `traceback.print_exc`) if
       an interactive event loop is running; it re-raises the exception if no
       interactive event loop is running.

    signals : list, optional
        If not None, *signals* is a list of signals that this registry handles:
        attempting to `process` or to `connect` to a signal not in the list
        throws a `ValueError`.  The default, None, does not restrict the
        handled signals.
    """

    # We maintain two mappings:
    #   callbacks: signal -> {cid -> weakref-to-callback}
    #   _func_cid_map: signal -> {weakref-to-callback -> cid}

    def __init__(self, exception_handler=_exception_printer, *, signals=None):
        self._signals = None if signals is None else list(signals)  # Copy it.
        self.exception_handler = exception_handler
        self.callbacks = {}
        self._cid_gen = itertools.count()
        self._func_cid_map = {}
        # A hidden variable that marks cids that need to be pickled.
        self._pickled_cids = set()

    def __getstate__(self):
        """Return picklable state; non-pickled callbacks are dropped."""
        return {
            **vars(self),
            # In general, callbacks may not be pickled, so we just drop them,
            # unless directed otherwise by self._pickled_cids.
            "callbacks": {s: {cid: proxy() for cid, proxy in d.items()
                              if cid in self._pickled_cids}
                          for s, d in self.callbacks.items()},
            # It is simpler to reconstruct this from callbacks in __setstate__.
            "_func_cid_map": None,
            # Store the counter's current value, not the iterator itself.
            "_cid_gen": next(self._cid_gen)
        }

    def __setstate__(self, state):
        """Restore state, re-wrapping unpickled callbacks in weak refs."""
        cid_count = state.pop('_cid_gen')
        vars(self).update(state)
        self.callbacks = {
            s: {cid: _weak_or_strong_ref(func, self._remove_proxy)
                for cid, func in d.items()}
            for s, d in self.callbacks.items()}
        # Rebuild the reverse (proxy -> cid) mapping from callbacks.
        self._func_cid_map = {
            s: {proxy: cid for cid, proxy in d.items()}
            for s, d in self.callbacks.items()}
        # Resume cid generation where the pickled registry left off.
        self._cid_gen = itertools.count(cid_count)

    def connect(self, signal, func):
        """Register *func* to be called when signal *signal* is generated."""
        if self._signals is not None:
            _api.check_in_list(self._signals, signal=signal)
        self._func_cid_map.setdefault(signal, {})
        proxy = _weak_or_strong_ref(func, self._remove_proxy)
        # Reconnecting the same callback returns the existing cid.
        if proxy in self._func_cid_map[signal]:
            return self._func_cid_map[signal][proxy]
        cid = next(self._cid_gen)
        self._func_cid_map[signal][proxy] = cid
        self.callbacks.setdefault(signal, {})
        self.callbacks[signal][cid] = proxy
        return cid

    def _connect_picklable(self, signal, func):
        """
        Like `.connect`, but the callback is kept when pickling/unpickling.

        Currently internal-use only.
        """
        cid = self.connect(signal, func)
        self._pickled_cids.add(cid)
        return cid

    # Keep a reference to sys.is_finalizing, as sys may have been cleared out
    # at that point.
    def _remove_proxy(self, proxy, *, _is_finalizing=sys.is_finalizing):
        """Remove *proxy* (a dead weakref) from the internal mappings."""
        if _is_finalizing():
            # Weakrefs can't be properly torn down at that point anymore.
            return
        for signal, proxy_to_cid in list(self._func_cid_map.items()):
            cid = proxy_to_cid.pop(proxy, None)
            if cid is not None:
                del self.callbacks[signal][cid]
                self._pickled_cids.discard(cid)
                break
        else:
            # Not found
            return
        # Clean up empty dicts
        if len(self.callbacks[signal]) == 0:
            del self.callbacks[signal]
            del self._func_cid_map[signal]

    def disconnect(self, cid):
        """
        Disconnect the callback registered with callback id *cid*.

        No error is raised if such a callback does not exist.
        """
        self._pickled_cids.discard(cid)
        # Clean up callbacks
        for signal, cid_to_proxy in list(self.callbacks.items()):
            proxy = cid_to_proxy.pop(cid, None)
            if proxy is not None:
                break
        else:
            # Not found
            return

        # Remove the matching entry from the reverse mapping too.
        proxy_to_cid = self._func_cid_map[signal]
        for current_proxy, current_cid in list(proxy_to_cid.items()):
            if current_cid == cid:
                assert proxy is current_proxy
                del proxy_to_cid[current_proxy]
        # Clean up empty dicts
        if len(self.callbacks[signal]) == 0:
            del self.callbacks[signal]
            del self._func_cid_map[signal]

    def process(self, s, *args, **kwargs):
        """
        Process signal *s*.

        All of the functions registered to receive callbacks on *s* will be
        called with ``*args`` and ``**kwargs``.
        """
        if self._signals is not None:
            _api.check_in_list(self._signals, signal=s)
        # Iterate over a copy so callbacks may disconnect during processing.
        for ref in list(self.callbacks.get(s, {}).values()):
            func = ref()
            if func is not None:
                try:
                    func(*args, **kwargs)
                # this does not capture KeyboardInterrupt, SystemExit,
                # and GeneratorExit
                except Exception as exc:
                    if self.exception_handler is not None:
                        self.exception_handler(exc)
                    else:
                        raise

    @contextlib.contextmanager
    def blocked(self, *, signal=None):
        """
        Block callback signals from being processed.

        A context manager to temporarily block/disable callback signals
        from being processed by the registered listeners.

        Parameters
        ----------
        signal : str, optional
            The callback signal to block. The default is to block all signals.
        """
        orig = self.callbacks
        try:
            if signal is None:
                # Empty out the callbacks
                self.callbacks = {}
            else:
                # Only remove the specific signal
                self.callbacks = {k: orig[k] for k in orig if k != signal}
            yield
        finally:
            self.callbacks = orig

331 

332 

class silent_list(list):
    """
    A list with a short ``repr()``.

    This is meant to be used for a homogeneous list of artists, so that they
    don't cause long, meaningless output.

    Instead of ::

        [<matplotlib.lines.Line2D object at 0x7f5749fed3c8>,
         <matplotlib.lines.Line2D object at 0x7f5749fed4e0>,
         <matplotlib.lines.Line2D object at 0x7f5758016550>]

    one will get ::

        <a list of 3 Line2D objects>

    If ``self.type`` is None, the type name is obtained from the first item in
    the list (if any).
    """

    def __init__(self, type, seq=None):
        self.type = type
        if seq is not None:
            self.extend(seq)

    def __repr__(self):
        if self.type is None and len(self) == 0:
            return "<an empty list>"
        # Fall back on the first element's class name when no type was given.
        name = type(self[0]).__name__ if self.type is None else self.type
        return f"<a list of {len(self)} {name} objects>"

365 

366 

def _local_over_kwdict(
        local_var, kwargs, *keys,
        warning_cls=_api.MatplotlibDeprecationWarning):
    """
    Return *local_var* if not None, else the first non-None value popped from
    *kwargs* among *keys*.

    All listed *keys* are popped from *kwargs* regardless; a warning of class
    *warning_cls* is emitted for each additional non-None value, which is
    ignored.
    """
    out = local_var
    for key in keys:
        value = kwargs.pop(key, None)
        if value is None:
            continue
        if out is None:
            out = value
        else:
            _api.warn_external(f'"{key}" keyword argument will be ignored',
                               warning_cls)
    return out

380 

381 

def strip_math(s):
    """
    Remove latex formatting from mathtext.

    Only handles fully math and fully non-math strings.
    """
    # Only strings fully wrapped in dollar signs are treated as mathtext.
    if len(s) >= 2 and s[0] == s[-1] == "$":
        s = s[1:-1]
        replacements = (
            (r"\times", "x"),  # Specifically for Formatter support.
            (r"\mathdefault", ""),
            (r"\rm", ""),
            (r"\cal", ""),
            (r"\tt", ""),
            (r"\it", ""),
            ("\\", ""),
            ("{", ""),
            ("}", ""),
        )
        for tex, plain in replacements:
            s = s.replace(tex, plain)
    return s

403 

404 

405def _strip_comment(s): 

406 """Strip everything from the first unquoted #.""" 

407 pos = 0 

408 while True: 

409 quote_pos = s.find('"', pos) 

410 hash_pos = s.find('#', pos) 

411 if quote_pos < 0: 

412 without_comment = s if hash_pos < 0 else s[:hash_pos] 

413 return without_comment.strip() 

414 elif 0 <= hash_pos < quote_pos: 

415 return s[:hash_pos].strip() 

416 else: 

417 closing_quote_pos = s.find('"', quote_pos + 1) 

418 if closing_quote_pos < 0: 

419 raise ValueError( 

420 f"Missing closing quote in: {s!r}. If you need a double-" 

421 'quote inside a string, use escaping: e.g. "the \" char"') 

422 pos = closing_quote_pos + 1 # behind closing quote 

423 

424 

def is_writable_file_like(obj):
    """Return whether *obj* looks like a file object with a *write* method."""
    write = getattr(obj, 'write', None)
    return callable(write)

428 

429 

def file_requires_unicode(x):
    """
    Return whether the given writable file-like object requires Unicode to be
    written to it.
    """
    try:
        # A byte-oriented stream accepts an empty bytes write; a text-mode
        # stream raises TypeError for it.
        x.write(b'')
    except TypeError:
        return True
    return False

441 

442 

def to_filehandle(fname, flag='r', return_opened=False, encoding=None):
    """
    Convert a path to an open file handle or pass-through a file-like object.

    Consider using `open_file_cm` instead, as it allows one to properly close
    newly created file objects more easily.

    Parameters
    ----------
    fname : str or path-like or file-like
        If `str` or `os.PathLike`, the file is opened using the flags specified
        by *flag* and *encoding*.  If a file-like object, it is passed through.
    flag : str, default: 'r'
        Passed as the *mode* argument to `open` when *fname* is `str` or
        `os.PathLike`; ignored if *fname* is file-like.
    return_opened : bool, default: False
        If True, return both the file object and a boolean indicating whether
        this was a new file (that the caller needs to close).  If False, return
        only the new file.
    encoding : str or None, default: None
        Passed as the *encoding* argument to `open` when *fname* is `str` or
        `os.PathLike`; ignored if *fname* is file-like.

    Returns
    -------
    fh : file-like
    opened : bool
        *opened* is only returned if *return_opened* is True.
    """
    if isinstance(fname, os.PathLike):
        fname = os.fspath(fname)
    if isinstance(fname, str):
        # Compressed files are detected by suffix and transparently wrapped.
        if fname.endswith('.gz'):
            fh = gzip.open(fname, flag)
        elif fname.endswith('.bz2'):
            # python may not be compiled with bz2 support,
            # bury import until we need it
            import bz2
            fh = bz2.BZ2File(fname, flag)
        else:
            fh = open(fname, flag, encoding=encoding)
        opened = True
    elif hasattr(fname, 'seek'):
        # Duck-typed file-like object: pass through unchanged.
        fh = fname
        opened = False
    else:
        raise ValueError('fname must be a PathLike or file handle')
    if return_opened:
        return fh, opened
    return fh

493 

494 

def open_file_cm(path_or_file, mode="r", encoding=None):
    r"""Pass through file objects and context-manage path-likes."""
    fh, opened = to_filehandle(path_or_file, mode, True, encoding)
    if opened:
        # Newly opened files act as their own context manager (closing on
        # exit); pre-existing handles are wrapped so they stay open.
        return fh
    return contextlib.nullcontext(fh)

499 

500 

def is_scalar_or_string(val):
    """Return whether the given object is a scalar or string like."""
    if isinstance(val, str):
        return True
    return not np.iterable(val)

504 

505 

@_api.delete_parameter(
    "3.8", "np_load", alternative="open(get_sample_data(..., asfileobj=False))")
def get_sample_data(fname, asfileobj=True, *, np_load=True):
    """
    Return a sample data file.  *fname* is a path relative to the
    :file:`mpl-data/sample_data` directory.  If *asfileobj* is `True`
    return a file object, otherwise just a file path.

    Sample data files are stored in the 'mpl-data/sample_data' directory within
    the Matplotlib package.

    If the filename ends in .gz, the file is implicitly ungzipped.  If the
    filename ends with .npy or .npz, and *asfileobj* is `True`, the file is
    loaded with `numpy.load`.
    """
    path = _get_data_path('sample_data', fname)
    if not asfileobj:
        return str(path)
    # Dispatch on the (case-insensitive) suffix to pick the right opener.
    suffix = path.suffix.lower()
    if suffix == '.gz':
        return gzip.open(path)
    if suffix in ['.npy', '.npz']:
        return np.load(path) if np_load else path.open('rb')
    if suffix in ['.csv', '.xrc', '.txt']:
        return path.open('r')
    return path.open('rb')

537 

538 

def _get_data_path(*args):
    """
    Return the `pathlib.Path` to a resource file provided by Matplotlib.

    ``*args`` specify a path relative to the base data path.
    """
    base = Path(matplotlib.get_data_path())
    return base.joinpath(*args)

546 

547 

def flatten(seq, scalarp=is_scalar_or_string):
    """
    Return a generator of flattened nested containers.

    For example:

    >>> from matplotlib.cbook import flatten
    >>> l = (('John', ['Hunter']), (1, 23), [[([42, (5, 23)], )]])
    >>> print(list(flatten(l)))
    ['John', 'Hunter', 1, 23, 42, 5, 23]

    By: Composite of Holger Krekel and Luther Blissett
    From: https://code.activestate.com/recipes/121294/
    and Recipe 1.12 in cookbook
    """
    for element in seq:
        # Scalars (per *scalarp*) and None are yielded as-is; anything else
        # is recursed into.
        if scalarp(element) or element is None:
            yield element
        else:
            yield from flatten(element, scalarp)

568 

569 

@_api.deprecated("3.8")
class Stack:
    """
    Stack of elements with a movable cursor.

    Mimics home/back/forward in a web browser.
    """

    def __init__(self, default=None):
        self.clear()
        self._default = default

    def __call__(self):
        """Return the current element, or None."""
        if self._elements:
            return self._elements[self._pos]
        return self._default

    def __len__(self):
        return len(self._elements)

    def __getitem__(self, ind):
        return self._elements[ind]

    def forward(self):
        """Move the position forward and return the current element."""
        self._pos = min(self._pos + 1, len(self._elements) - 1)
        return self()

    def back(self):
        """Move the position back and return the current element."""
        if self._pos > 0:
            self._pos -= 1
        return self()

    def push(self, o):
        """
        Push *o* to the stack at current position.  Discard all later elements.

        *o* is returned.
        """
        self._elements = [*self._elements[:self._pos + 1], o]
        self._pos = len(self._elements) - 1
        return self()

    def home(self):
        """
        Push the first element onto the top of the stack.

        The first element is returned.
        """
        if not self._elements:
            return
        self.push(self._elements[0])
        return self()

    def empty(self):
        """Return whether the stack is empty."""
        return not self._elements

    def clear(self):
        """Empty the stack."""
        self._pos = -1
        self._elements = []

    def bubble(self, o):
        """
        Raise all references of *o* to the top of the stack, and return it.

        Raises
        ------
        ValueError
            If *o* is not in the stack.
        """
        if o not in self._elements:
            raise ValueError('Given element not contained in the stack')
        old_elements = self._elements.copy()
        self.clear()
        # Re-push everything except *o*, counting how many copies to re-add
        # on top afterwards.
        count = 0
        for elem in old_elements:
            if elem == o:
                count += 1
            else:
                self.push(elem)
        for _ in range(count):
            self.push(o)
        return o

    def remove(self, o):
        """
        Remove *o* from the stack.

        Raises
        ------
        ValueError
            If *o* is not in the stack.
        """
        if o not in self._elements:
            raise ValueError('Given element not contained in the stack')
        old_elements = self._elements.copy()
        self.clear()
        for elem in old_elements:
            if elem != o:
                self.push(elem)

675 

676 

677class _Stack: 

678 """ 

679 Stack of elements with a movable cursor. 

680 

681 Mimics home/back/forward in a web browser. 

682 """ 

683 

684 def __init__(self): 

685 self._pos = -1 

686 self._elements = [] 

687 

688 def clear(self): 

689 """Empty the stack.""" 

690 self._pos = -1 

691 self._elements = [] 

692 

693 def __call__(self): 

694 """Return the current element, or None.""" 

695 return self._elements[self._pos] if self._elements else None 

696 

697 def __len__(self): 

698 return len(self._elements) 

699 

700 def __getitem__(self, ind): 

701 return self._elements[ind] 

702 

703 def forward(self): 

704 """Move the position forward and return the current element.""" 

705 self._pos = min(self._pos + 1, len(self._elements) - 1) 

706 return self() 

707 

708 def back(self): 

709 """Move the position back and return the current element.""" 

710 self._pos = max(self._pos - 1, 0) 

711 return self() 

712 

713 def push(self, o): 

714 """ 

715 Push *o* to the stack after the current position, and return *o*. 

716 

717 Discard all later elements. 

718 """ 

719 self._elements[self._pos + 1:] = [o] 

720 self._pos = len(self._elements) - 1 

721 return o 

722 

723 def home(self): 

724 """ 

725 Push the first element onto the top of the stack. 

726 

727 The first element is returned. 

728 """ 

729 return self.push(self._elements[0]) if self._elements else None 

730 

731 

def safe_masked_invalid(x, copy=False):
    """
    Return *x* as an array with non-finite entries masked.

    If the dtype does not support `numpy.isfinite` (e.g. strings), the array
    is returned unmasked instead.
    """
    x = np.array(x, subok=True, copy=copy)
    if not x.dtype.isnative:
        # If we have already made a copy, do the byteswap in place, else make a
        # copy with the byte order swapped.
        # Swap to native order.
        x = x.byteswap(inplace=copy).view(x.dtype.newbyteorder('N'))
    try:
        return np.ma.masked_where(~(np.isfinite(x)), x, copy=False)
    except TypeError:
        return x

744 

745 

def print_cycles(objects, outstream=sys.stdout, show_progress=False):
    """
    Print loops of cyclic references in the given *objects*.

    It is often useful to pass in ``gc.garbage`` to find the cycles that are
    preventing some objects from being garbage collected.

    Parameters
    ----------
    objects
        A list of objects to find cycles in.
    outstream
        The stream for output.
    show_progress : bool
        If True, print the number of objects reached as they are found.
    """
    import gc

    def print_path(path):
        # Print one cycle, showing for each step which container slot
        # references the next object in the loop.
        for i, step in enumerate(path):
            # next "wraps around"
            # NOTE: `next` shadows the builtin within this function.
            next = path[(i + 1) % len(path)]

            outstream.write("   %s -- " % type(step))
            if isinstance(step, dict):
                for key, val in step.items():
                    if val is next:
                        outstream.write(f"[{key!r}]")
                        break
                    if key is next:
                        outstream.write(f"[key] = {val!r}")
                        break
            elif isinstance(step, list):
                outstream.write("[%d]" % step.index(next))
            elif isinstance(step, tuple):
                outstream.write("( tuple )")
            else:
                outstream.write(repr(step))
            outstream.write(" ->\n")
        outstream.write("\n")

    def recurse(obj, start, all, current_path):
        # Depth-first walk over gc referents; *all* (shadows the builtin)
        # maps already-visited object ids to None.
        if show_progress:
            outstream.write("%d\r" % len(all))

        all[id(obj)] = None

        referents = gc.get_referents(obj)
        for referent in referents:
            # If we've found our way back to the start, this is
            # a cycle, so print it out
            if referent is start:
                print_path(current_path)

            # Don't go back through the original list of objects, or
            # through temporary references to the object, since those
            # are just an artifact of the cycle detector itself.
            elif referent is objects or isinstance(referent, types.FrameType):
                continue

            # We haven't seen this object before, so recurse
            elif id(referent) not in all:
                recurse(referent, start, all, current_path + [obj])

    for obj in objects:
        outstream.write(f"Examining: {obj!r}\n")
        recurse(obj, obj, {}, [])

813 

814 

class Grouper:
    """
    A disjoint-set data structure.

    Objects can be joined using :meth:`join`, tested for connectedness
    using :meth:`joined`, and all disjoint sets can be retrieved by
    using the object as an iterator.

    The objects being joined must be hashable and weak-referenceable.

    Examples
    --------
    >>> from matplotlib.cbook import Grouper
    >>> class Foo:
    ...     def __init__(self, s):
    ...         self.s = s
    ...     def __repr__(self):
    ...         return self.s
    ...
    >>> a, b, c, d, e, f = [Foo(x) for x in 'abcdef']
    >>> grp = Grouper()
    >>> grp.join(a, b)
    >>> grp.join(b, c)
    >>> grp.join(d, e)
    >>> list(grp)
    [[a, b, c], [d, e]]
    >>> grp.joined(a, b)
    True
    >>> grp.joined(a, c)
    True
    >>> grp.joined(a, d)
    False
    """

    def __init__(self, init=()):
        # Maps each element to the (shared) WeakSet of its group; elements of
        # the same group alias the same WeakSet object.
        self._mapping = weakref.WeakKeyDictionary(
            {x: weakref.WeakSet([x]) for x in init})
        # Records insertion order, used to sort group members on retrieval.
        self._ordering = weakref.WeakKeyDictionary()
        for x in init:
            if x not in self._ordering:
                self._ordering[x] = len(self._ordering)
        self._next_order = len(self._ordering)  # Plain int to simplify pickling.

    def __getstate__(self):
        return {
            **vars(self),
            # Convert weak refs to strong ones.
            "_mapping": {k: set(v) for k, v in self._mapping.items()},
            "_ordering": {**self._ordering},
        }

    def __setstate__(self, state):
        vars(self).update(state)
        # Convert strong refs to weak ones.
        self._mapping = weakref.WeakKeyDictionary(
            {k: weakref.WeakSet(v) for k, v in self._mapping.items()})
        self._ordering = weakref.WeakKeyDictionary(self._ordering)

    def __contains__(self, item):
        return item in self._mapping

    @_api.deprecated("3.8", alternative="none, you no longer need to clean a Grouper")
    def clean(self):
        """Clean dead weak references from the dictionary."""

    def join(self, a, *args):
        """
        Join given arguments into the same set.  Accepts one or more arguments.
        """
        mapping = self._mapping
        try:
            set_a = mapping[a]
        except KeyError:
            set_a = mapping[a] = weakref.WeakSet([a])
            self._ordering[a] = self._next_order
            self._next_order += 1
        for arg in args:
            try:
                set_b = mapping[arg]
            except KeyError:
                set_b = mapping[arg] = weakref.WeakSet([arg])
                self._ordering[arg] = self._next_order
                self._next_order += 1
            if set_b is not set_a:
                # Union by size: merge the smaller set into the larger one.
                if len(set_b) > len(set_a):
                    set_a, set_b = set_b, set_a
                set_a.update(set_b)
                # Re-point all members of the absorbed set at the merged set.
                for elem in set_b:
                    mapping[elem] = set_a

    def joined(self, a, b):
        """Return whether *a* and *b* are members of the same set."""
        # Identity comparison of the shared WeakSets; the sentinel object()
        # makes missing *a* compare unequal to anything.
        return (self._mapping.get(a, object()) is self._mapping.get(b))

    def remove(self, a):
        """Remove *a* from the grouper, doing nothing if it is not there."""
        # The {a} default makes the .remove(a) call a no-op when *a* is absent.
        self._mapping.pop(a, {a}).remove(a)
        self._ordering.pop(a, None)

    def __iter__(self):
        """
        Iterate over each of the disjoint sets as a list.

        The iterator is invalid if interleaved with calls to join().
        """
        # Deduplicate the shared WeakSets by identity.
        unique_groups = {id(group): group for group in self._mapping.values()}
        for group in unique_groups.values():
            yield sorted(group, key=self._ordering.__getitem__)

    def get_siblings(self, a):
        """Return all of the items joined with *a*, including itself."""
        siblings = self._mapping.get(a, [a])
        return sorted(siblings, key=self._ordering.get)

928 

929 

class GrouperView:
    """Immutable view over a `.Grouper`."""

    def __init__(self, grouper):
        self._grouper = grouper

    def __contains__(self, item):
        return item in self._grouper

    def __iter__(self):
        return iter(self._grouper)

    def joined(self, a, b):
        """Forward to the underlying grouper's ``joined``."""
        return self._grouper.joined(a, b)

    def get_siblings(self, a):
        """Forward to the underlying grouper's ``get_siblings``."""
        return self._grouper.get_siblings(a)

938 

939 

def simple_linear_interpolation(a, steps):
    """
    Resample an array with ``steps - 1`` points between original point pairs.

    Along each column of *a*, ``(steps - 1)`` points are introduced between
    each original values; the values are linearly interpolated.

    Parameters
    ----------
    a : array, shape (n, ...)
    steps : int

    Returns
    -------
    array
        shape ``((n - 1) * steps + 1, ...)``
    """
    n = len(a)
    # Flatten trailing dimensions so each column can be interpolated alone.
    flat = a.reshape((n, -1))
    xp = steps * np.arange(n)
    x = np.arange((n - 1) * steps + 1)
    columns = [np.interp(x, xp, col) for col in flat.T]
    return np.column_stack(columns).reshape((len(x),) + a.shape[1:])

962 

963 

def delete_masked_points(*args):
    """
    Find all masked and/or non-finite points in a set of arguments,
    and return the arguments with only the unmasked points remaining.

    Arguments can be in any of 5 categories:

    1) 1-D masked arrays
    2) 1-D ndarrays
    3) ndarrays with more than one dimension
    4) other non-string iterables
    5) anything else

    The first argument must be in one of the first four categories;
    any argument with a length differing from that of the first
    argument (and hence anything in category 5) then will be
    passed through unchanged.

    Masks are obtained from all arguments of the correct length
    in categories 1, 2, and 4; a point is bad if masked in a masked
    array or if it is a nan or inf.  No attempt is made to
    extract a mask from categories 2, 3, and 4 if `numpy.isfinite`
    does not yield a Boolean array.

    All input arguments that are not passed unchanged are returned
    as ndarrays after removing the points or rows corresponding to
    masks in any of the arguments.

    A vastly simpler version of this function was originally
    written as a helper for Axes.scatter().

    """
    if not len(args):
        return ()
    if is_scalar_or_string(args[0]):
        raise ValueError("First argument must be a sequence")
    nrecs = len(args[0])
    margs = []
    # seqlist[i] is True for arguments that participate in filtering (same
    # length as args[0], iterable, and not a string).
    seqlist = [False] * len(args)
    for i, x in enumerate(args):
        if not isinstance(x, str) and np.iterable(x) and len(x) == nrecs:
            seqlist[i] = True
            if isinstance(x, np.ma.MaskedArray):
                if x.ndim > 1:
                    raise ValueError("Masked arrays must be 1-D")
            else:
                # Plain sequences are converted up front so they can be
                # fancy-indexed below.
                x = np.asarray(x)
        margs.append(x)
    masks = []    # List of masks that are True where good.
    for i, x in enumerate(margs):
        if seqlist[i]:
            if x.ndim > 1:
                continue  # Don't try to get nan locations unless 1-D.
            if isinstance(x, np.ma.MaskedArray):
                masks.append(~np.ma.getmaskarray(x))  # invert the mask
                xd = x.data
            else:
                xd = x
            try:
                mask = np.isfinite(xd)
                # Non-numeric dtypes may yield a scalar NotImplemented-like
                # result rather than an ndarray; only keep real arrays.
                if isinstance(mask, np.ndarray):
                    masks.append(mask)
            except Exception:  # Fixme: put in tuple of possible exceptions?
                pass
    if len(masks):
        # Combine all masks; keep only indices good in every argument.
        mask = np.logical_and.reduce(masks)
        igood = mask.nonzero()[0]
        if len(igood) < nrecs:
            for i, x in enumerate(margs):
                if seqlist[i]:
                    margs[i] = x[igood]
    for i, x in enumerate(margs):
        if seqlist[i] and isinstance(x, np.ma.MaskedArray):
            # All bad points are gone, so the mask can be dropped.
            margs[i] = x.filled()
    return margs

1039 

1040 

def _combine_masks(*args):
    """
    Find all masked and/or non-finite points in a set of arguments,
    and return the arguments as masked arrays with a common mask.

    Arguments can be in any of 5 categories:

    1) 1-D masked arrays
    2) 1-D ndarrays
    3) ndarrays with more than one dimension
    4) other non-string iterables
    5) anything else

    The first argument must be in one of the first four categories;
    any argument with a length differing from that of the first
    argument (and hence anything in category 5) then will be
    passed through unchanged.

    Masks are obtained from all arguments of the correct length
    in categories 1, 2, and 4; a point is bad if masked in a masked
    array or if it is a nan or inf. No attempt is made to
    extract a mask from categories 2 and 4 if `numpy.isfinite`
    does not yield a Boolean array. Category 3 is included to
    support RGB or RGBA ndarrays, which are assumed to have only
    valid values and which are passed through unchanged.

    All input arguments that are not passed unchanged are returned
    as masked arrays if any masked points are found, otherwise as
    ndarrays.

    """
    if not len(args):
        return ()
    if is_scalar_or_string(args[0]):
        raise ValueError("First argument must be a sequence")
    # Reference length: arguments of any other length pass through unchanged.
    nrecs = len(args[0])
    margs = []  # Output args; some may be modified.
    seqlist = [False] * len(args)  # Flags: True if output will be masked.
    masks = []  # List of masks.
    for i, x in enumerate(args):
        if is_scalar_or_string(x) or len(x) != nrecs:
            margs.append(x)  # Leave it unmodified.
        else:
            if isinstance(x, np.ma.MaskedArray) and x.ndim > 1:
                raise ValueError("Masked arrays must be 1-D")
            try:
                x = np.asanyarray(x)
            except (VisibleDeprecationWarning, ValueError):
                # NumPy 1.19 raises a warning about ragged arrays, but we want
                # to accept basically anything here.
                x = np.asanyarray(x, dtype=object)
            if x.ndim == 1:
                # Mask out non-finite values too (safe_masked_invalid).
                x = safe_masked_invalid(x)
                seqlist[i] = True
                if np.ma.is_masked(x):
                    masks.append(np.ma.getmaskarray(x))
            margs.append(x)  # Possibly modified.
    if len(masks):
        # A point is bad if it is bad in *any* of the arguments; apply the
        # combined mask to every maskable output.
        mask = np.logical_or.reduce(masks)
        for i, x in enumerate(margs):
            if seqlist[i]:
                margs[i] = np.ma.array(x, mask=mask)
    return margs

1104 

1105 

1106def _broadcast_with_masks(*args, compress=False): 

1107 """ 

1108 Broadcast inputs, combining all masked arrays. 

1109 

1110 Parameters 

1111 ---------- 

1112 *args : array-like 

1113 The inputs to broadcast. 

1114 compress : bool, default: False 

1115 Whether to compress the masked arrays. If False, the masked values 

1116 are replaced by NaNs. 

1117 

1118 Returns 

1119 ------- 

1120 list of array-like 

1121 The broadcasted and masked inputs. 

1122 """ 

1123 # extract the masks, if any 

1124 masks = [k.mask for k in args if isinstance(k, np.ma.MaskedArray)] 

1125 # broadcast to match the shape 

1126 bcast = np.broadcast_arrays(*args, *masks) 

1127 inputs = bcast[:len(args)] 

1128 masks = bcast[len(args):] 

1129 if masks: 

1130 # combine the masks into one 

1131 mask = np.logical_or.reduce(masks) 

1132 # put mask on and compress 

1133 if compress: 

1134 inputs = [np.ma.array(k, mask=mask).compressed() 

1135 for k in inputs] 

1136 else: 

1137 inputs = [np.ma.array(k, mask=mask, dtype=float).filled(np.nan).ravel() 

1138 for k in inputs] 

1139 else: 

1140 inputs = [np.ravel(k) for k in inputs] 

1141 return inputs 

1142 

1143 

def boxplot_stats(X, whis=1.5, bootstrap=None, labels=None, autorange=False):
    r"""
    Return a list of dictionaries of statistics used to draw a series of box
    and whisker plots using `~.Axes.bxp`.

    Parameters
    ----------
    X : array-like
        Data that will be represented in the boxplots. Should have 2 or
        fewer dimensions.

    whis : float or (float, float), default: 1.5
        The position of the whiskers.

        If a float, the lower whisker is at the lowest datum above
        ``Q1 - whis*(Q3-Q1)``, and the upper whisker at the highest datum below
        ``Q3 + whis*(Q3-Q1)``, where Q1 and Q3 are the first and third
        quartiles. The default value of ``whis = 1.5`` corresponds to Tukey's
        original definition of boxplots.

        If a pair of floats, they indicate the percentiles at which to draw the
        whiskers (e.g., (5, 95)). In particular, setting this to (0, 100)
        results in whiskers covering the whole range of the data.

        In the edge case where ``Q1 == Q3``, *whis* is automatically set to
        (0, 100) (cover the whole range of the data) if *autorange* is True.

        Beyond the whiskers, data are considered outliers and are plotted as
        individual points.

    bootstrap : int, optional
        Number of times the confidence intervals around the median
        should be bootstrapped (percentile method).

    labels : list of str, optional
        Labels for each dataset. Length must be compatible with
        dimensions of *X*.

    autorange : bool, optional (False)
        When `True` and the data are distributed such that the 25th and 75th
        percentiles are equal, ``whis`` is set to (0, 100) such that the
        whisker ends are at the minimum and maximum of the data.

    Returns
    -------
    list of dict
        A list of dictionaries containing the results for each column
        of data. Keys of each dictionary are the following:

        ======== ===================================
        Key Value Description
        ======== ===================================
        label tick label for the boxplot
        mean arithmetic mean value
        med 50th percentile
        q1 first quartile (25th percentile)
        q3 third quartile (75th percentile)
        iqr interquartile range
        cilo lower notch around the median
        cihi upper notch around the median
        whislo end of the lower whisker
        whishi end of the upper whisker
        fliers outliers
        ======== ===================================

    Notes
    -----
    Non-bootstrapping approach to confidence interval uses Gaussian-based
    asymptotic approximation:

    .. math::

        \mathrm{med} \pm 1.57 \times \frac{\mathrm{iqr}}{\sqrt{N}}

    General approach from:
    McGill, R., Tukey, J.W., and Larsen, W.A. (1978) "Variations of
    Boxplots", The American Statistician, 32:12-16.
    """

    def _bootstrap_median(data, N=5000):
        # determine 95% confidence intervals of the median
        M = len(data)
        percentiles = [2.5, 97.5]

        # Resample the data with replacement N times and compute the median
        # of each resample; the CI is the 2.5/97.5 percentile of those.
        bs_index = np.random.randint(M, size=(N, M))
        bsData = data[bs_index]
        estimate = np.median(bsData, axis=1, overwrite_input=True)

        CI = np.percentile(estimate, percentiles)
        return CI

    def _compute_conf_interval(data, med, iqr, bootstrap):
        if bootstrap is not None:
            # Do a bootstrap estimate of notch locations.
            # get conf. intervals around median
            CI = _bootstrap_median(data, N=bootstrap)
            notch_min = CI[0]
            notch_max = CI[1]
        else:

            # Gaussian asymptotic approximation (see Notes in the docstring).
            N = len(data)
            notch_min = med - 1.57 * iqr / np.sqrt(N)
            notch_max = med + 1.57 * iqr / np.sqrt(N)

        return notch_min, notch_max

    # output is a list of dicts
    bxpstats = []

    # convert X to a list of lists
    X = _reshape_2D(X, "X")

    ncols = len(X)
    if labels is None:
        labels = itertools.repeat(None)
    elif len(labels) != ncols:
        raise ValueError("Dimensions of labels and X must be compatible")

    input_whis = whis
    for ii, (x, label) in enumerate(zip(X, labels)):

        # empty dict
        stats = {}
        if label is not None:
            stats['label'] = label

        # restore whis to the input values in case it got changed in the loop
        whis = input_whis

        # note tricksiness, append up here and then mutate below
        bxpstats.append(stats)

        # if empty, bail
        if len(x) == 0:
            stats['fliers'] = np.array([])
            stats['mean'] = np.nan
            stats['med'] = np.nan
            stats['q1'] = np.nan
            stats['q3'] = np.nan
            stats['iqr'] = np.nan
            stats['cilo'] = np.nan
            stats['cihi'] = np.nan
            stats['whislo'] = np.nan
            stats['whishi'] = np.nan
            continue

        # up-convert to an array, just to be safe
        x = np.ma.asarray(x)
        x = x.data[~x.mask].ravel()

        # arithmetic mean
        stats['mean'] = np.mean(x)

        # medians and quartiles
        q1, med, q3 = np.percentile(x, [25, 50, 75])

        # interquartile range
        stats['iqr'] = q3 - q1
        if stats['iqr'] == 0 and autorange:
            # Degenerate box: extend the whiskers to the full data range.
            whis = (0, 100)

        # conf. interval around median
        stats['cilo'], stats['cihi'] = _compute_conf_interval(
            x, med, stats['iqr'], bootstrap
        )

        # lowest/highest non-outliers
        if np.iterable(whis) and not isinstance(whis, str):
            loval, hival = np.percentile(x, whis)
        elif np.isreal(whis):
            loval = q1 - whis * stats['iqr']
            hival = q3 + whis * stats['iqr']
        else:
            raise ValueError('whis must be a float or list of percentiles')

        # get high extreme
        wiskhi = x[x <= hival]
        if len(wiskhi) == 0 or np.max(wiskhi) < q3:
            # Whiskers never extend inward past the box edge.
            stats['whishi'] = q3
        else:
            stats['whishi'] = np.max(wiskhi)

        # get low extreme
        wisklo = x[x >= loval]
        if len(wisklo) == 0 or np.min(wisklo) > q1:
            stats['whislo'] = q1
        else:
            stats['whislo'] = np.min(wisklo)

        # compute a single array of outliers
        stats['fliers'] = np.concatenate([
            x[x < stats['whislo']],
            x[x > stats['whishi']],
        ])

        # add in the remaining stats
        stats['q1'], stats['med'], stats['q3'] = q1, med, q3

    return bxpstats

1343 

1344 

#: Maps short codes for line style to their full name used by backends.
#: E.g. ``ls_mapper['--'] == 'dashed'``.
ls_mapper = {'-': 'solid', '--': 'dashed', '-.': 'dashdot', ':': 'dotted'}
#: Maps full names for line styles used by backends to their short codes
#: (the inverse of `ls_mapper`).
ls_mapper_r = {v: k for k, v in ls_mapper.items()}

1349 

1350 

def contiguous_regions(mask):
    """
    Return a list of (ind0, ind1) such that ``mask[ind0:ind1].all()`` is
    True and we cover all such regions.
    """
    mask = np.asarray(mask, dtype=bool)

    if mask.size == 0:
        return []

    # Indices where the value flips between False and True; +1 turns the
    # comparison positions into slice boundaries.
    bounds = (np.nonzero(mask[:-1] != mask[1:])[0] + 1).tolist()

    # If the array starts True, the first region begins at 0; if it ends
    # True, the last region extends to the end of the array.
    if mask[0]:
        bounds.insert(0, 0)
    if mask[-1]:
        bounds.append(len(mask))

    # Boundaries now alternate start, stop, start, stop, ...
    return list(zip(bounds[::2], bounds[1::2]))

1375 

1376 

def is_math_text(s):
    """
    Return whether the string *s* contains math expressions.

    This is done by checking whether *s* contains an even number of
    non-escaped dollar signs.
    """
    text = str(s)
    # Escaped dollars (``\$``) do not delimit math; subtract them out.
    n_dollars = text.count(r'$') - text.count(r'\$')
    return n_dollars > 0 and n_dollars % 2 == 0

1388 

1389 

1390def _to_unmasked_float_array(x): 

1391 """ 

1392 Convert a sequence to a float array; if input was a masked array, masked 

1393 values are converted to nans. 

1394 """ 

1395 if hasattr(x, 'mask'): 

1396 return np.ma.asarray(x, float).filled(np.nan) 

1397 else: 

1398 return np.asarray(x, float) 

1399 

1400 

def _check_1d(x):
    """Convert scalars to 1D arrays; pass-through arrays as is."""
    # Unwrap pandas/xarray-style containers first.
    x = _unpack_to_numpy(x)
    # plot requires `shape` and `ndim`; anything lacking them, or 0-d,
    # is coerced to an array (this strips unit information).
    is_at_least_1d = (hasattr(x, 'shape')
                      and hasattr(x, 'ndim')
                      and len(x.shape) >= 1)
    return x if is_at_least_1d else np.atleast_1d(x)

1414 

1415 

def _reshape_2D(X, name):
    """
    Use Fortran ordering to convert ndarrays and lists of iterables to lists of
    1D arrays.

    Lists of iterables are converted by applying `numpy.asanyarray` to each of
    their elements. 1D ndarrays are returned in a singleton list containing
    them. 2D ndarrays are converted to the list of their *columns*.

    *name* is used to generate the error message for invalid inputs.
    """

    # Unpack in case of e.g. Pandas or xarray object
    X = _unpack_to_numpy(X)

    # Iterate over columns for ndarrays.
    if isinstance(X, np.ndarray):
        # Transpose so that iterating yields columns of the original array.
        X = X.T

        if len(X) == 0:
            return [[]]
        elif X.ndim == 1 and np.ndim(X[0]) == 0:
            # 1D array of scalars: directly return it.
            return [X]
        elif X.ndim in [1, 2]:
            # 2D array, or 1D array of iterables: flatten them first.
            return [np.reshape(x, -1) for x in X]
        else:
            raise ValueError(f'{name} must have 2 or fewer dimensions')

    # Non-ndarray input from here on: iterate over list of iterables.
    if len(X) == 0:
        return [[]]

    result = []
    is_1d = True
    for xi in X:
        # check if this is iterable, except for strings which we
        # treat as singletons.
        if not isinstance(xi, str):
            try:
                iter(xi)
            except TypeError:
                pass
            else:
                # At least one element is itself iterable, so X is 2D-like.
                is_1d = False
        xi = np.asanyarray(xi)
        nd = np.ndim(xi)
        if nd > 1:
            raise ValueError(f'{name} must have 2 or fewer dimensions')
        result.append(xi.reshape(-1))

    if is_1d:
        # 1D array of scalars: directly return it.
        return [np.reshape(result, -1)]
    else:
        # 2D array, or 1D array of iterables: use flattened version.
        return result

1474 

1475 

def violin_stats(X, method, points=100, quantiles=None):
    """
    Return a list of dictionaries of data which can be used to draw a series
    of violin plots.

    See the ``Returns`` section below to view the required keys of the
    dictionary.

    Users can skip this function and pass a user-defined set of dictionaries
    with the same keys to `~.axes.Axes.violinplot` instead of using Matplotlib
    to do the calculations. See the *Returns* section below for the keys
    that must be present in the dictionaries.

    Parameters
    ----------
    X : array-like
        Sample data that will be used to produce the gaussian kernel density
        estimates. Must have 2 or fewer dimensions.

    method : callable
        The method used to calculate the kernel density estimate for each
        column of data. When called via ``method(v, coords)``, it should
        return a vector of the values of the KDE evaluated at the values
        specified in coords.

    points : int, default: 100
        Defines the number of points to evaluate each of the gaussian kernel
        density estimates at.

    quantiles : array-like, default: None
        Defines (if not None) a list of floats in interval [0, 1] for each
        column of data, which represents the quantiles that will be rendered
        for that column of data. Must have 2 or fewer dimensions. 1D array will
        be treated as a singleton list containing them.

    Returns
    -------
    list of dict
        A list of dictionaries containing the results for each column of data.
        The dictionaries contain at least the following:

        - coords: A list of scalars containing the coordinates this particular
          kernel density estimate was evaluated at.
        - vals: A list of scalars containing the values of the kernel density
          estimate at each of the coordinates given in *coords*.
        - mean: The mean value for this column of data.
        - median: The median value for this column of data.
        - min: The minimum value for this column of data.
        - max: The maximum value for this column of data.
        - quantiles: The quantile values for this column of data.
    """

    # List of dictionaries describing each of the violins.
    vpstats = []

    # Want X to be a list of data sequences
    X = _reshape_2D(X, "X")

    # Want quantiles to be as the same shape as data sequences
    if quantiles is not None and len(quantiles) != 0:
        quantiles = _reshape_2D(quantiles, "quantiles")
    # Else, mock quantiles if it's none or empty
    else:
        quantiles = [[]] * len(X)

    # quantiles should have the same size as dataset
    if len(X) != len(quantiles):
        raise ValueError("List of violinplot statistics and quantiles values"
                         " must have the same length")

    # Zip x and quantiles
    for (x, q) in zip(X, quantiles):
        # Dictionary of results for this distribution
        stats = {}

        # Calculate basic stats for the distribution
        min_val = np.min(x)
        max_val = np.max(x)
        # q holds fractions in [0, 1]; np.percentile wants values in [0, 100].
        quantile_val = np.percentile(x, 100 * q)

        # Evaluate the kernel density estimate
        coords = np.linspace(min_val, max_val, points)
        stats['vals'] = method(x, coords)
        stats['coords'] = coords

        # Store additional statistics for this distribution
        stats['mean'] = np.mean(x)
        stats['median'] = np.median(x)
        stats['min'] = min_val
        stats['max'] = max_val
        stats['quantiles'] = np.atleast_1d(quantile_val)

        # Append to output
        vpstats.append(stats)

    return vpstats

1572 

1573 

def pts_to_prestep(x, *args):
    """
    Convert continuous line to pre-steps.

    Given a set of ``N`` points, convert to ``2N - 1`` points, which when
    connected linearly give a step function which changes values at the
    beginning of the intervals.

    Parameters
    ----------
    x : array
        The x location of the steps. May be empty.

    y1, ..., yp : array
        y arrays to be turned into steps; all must be the same length as ``x``.

    Returns
    -------
    array
        The x and y values converted to steps in the same order as the input;
        can be unpacked as ``x_out, y1_out, ..., yp_out``. If the input is
        length ``N``, each of these arrays will be length ``2N - 1``. For
        ``N=0``, the length will be 0.

    Examples
    --------
    >>> x_s, y1_s, y2_s = pts_to_prestep(x, y1, y2)
    """
    # Row 0 holds x; rows 1..p hold the y arrays.  2N - 1 columns (max with 0
    # keeps empty input from producing a negative size).
    steps = np.zeros((1 + len(args), max(2 * len(x) - 1, 0)))
    # In all `pts_to_*step` functions, only assign once using *x* and *args*,
    # as converting to an array may be expensive.
    steps[0, 0::2] = x
    steps[0, 1::2] = steps[0, 0:-2:2]  # x repeats the previous value...
    steps[1:, 0::2] = args
    steps[1:, 1::2] = steps[1:, 2::2]  # ...while y jumps ahead to the next.
    return steps

1610 

1611 

def pts_to_poststep(x, *args):
    """
    Convert continuous line to post-steps.

    Given a set of ``N`` points convert to ``2N - 1`` points, which when
    connected linearly give a step function which changes values at the end of
    the intervals.

    Parameters
    ----------
    x : array
        The x location of the steps. May be empty.

    y1, ..., yp : array
        y arrays to be turned into steps; all must be the same length as ``x``.

    Returns
    -------
    array
        The x and y values converted to steps in the same order as the input;
        can be unpacked as ``x_out, y1_out, ..., yp_out``. If the input is
        length ``N``, each of these arrays will be length ``2N - 1``. For
        ``N=0``, the length will be 0.

    Examples
    --------
    >>> x_s, y1_s, y2_s = pts_to_poststep(x, y1, y2)
    """
    # Row 0 holds x; rows 1..p hold the y arrays.  2N - 1 columns (max with 0
    # keeps empty input from producing a negative size).
    steps = np.zeros((1 + len(args), max(2 * len(x) - 1, 0)))
    steps[0, 0::2] = x
    steps[0, 1::2] = steps[0, 2::2]  # x jumps ahead to the next value...
    steps[1:, 0::2] = args
    steps[1:, 1::2] = steps[1:, 0:-2:2]  # ...while y repeats the previous one.
    return steps

1646 

1647 

def pts_to_midstep(x, *args):
    """
    Convert continuous line to mid-steps.

    Given a set of ``N`` points convert to ``2N`` points which when connected
    linearly give a step function which changes values at the middle of the
    intervals.

    Parameters
    ----------
    x : array
        The x location of the steps. May be empty.

    y1, ..., yp : array
        y arrays to be turned into steps; all must be the same length as
        ``x``.

    Returns
    -------
    array
        The x and y values converted to steps in the same order as the input;
        can be unpacked as ``x_out, y1_out, ..., yp_out``. If the input is
        length ``N``, each of these arrays will be length ``2N``.

    Examples
    --------
    >>> x_s, y1_s, y2_s = pts_to_midstep(x, y1, y2)
    """
    # Row 0 holds x; rows 1..p hold the y arrays, 2N columns total.
    out = np.zeros((1 + len(args), 2 * len(x)))
    x = np.asanyarray(x)
    # Each interior jump happens at the midpoint between consecutive x
    # values, so the midpoint fills both slots of the jump.
    midpoints = (x[:-1] + x[1:]) / 2
    out[0, 1:-1:2] = midpoints
    out[0, 2::2] = midpoints
    out[0, :1] = x[:1]  # Slice assignment also works for zero-sized input.
    out[0, -1:] = x[-1:]
    # Each y value is simply duplicated.
    out[1:, 0::2] = args
    out[1:, 1::2] = out[1:, 0::2]
    return out

1684 

1685 

#: Maps drawstyle names to the corresponding step-conversion function;
#: 'default' passes the data through unchanged.
STEP_LOOKUP_MAP = {'default': lambda x, y: (x, y),
                   'steps': pts_to_prestep,
                   'steps-pre': pts_to_prestep,
                   'steps-post': pts_to_poststep,
                   'steps-mid': pts_to_midstep}

1691 

1692 

def index_of(y):
    """
    A helper function to create reasonable x values for the given *y*.

    This is used for plotting (x, y) if x values are not explicitly given.

    First try ``y.index`` (assuming *y* is a `pandas.Series`), if that
    fails, use ``range(len(y))``.

    This will be extended in the future to deal with more types of
    labeled data.

    Parameters
    ----------
    y : float or array-like

    Returns
    -------
    x, y : ndarray
        The x and y values to plot.

    Raises
    ------
    ValueError
        If *y* cannot be converted to an at-least-1D array.
    """
    try:
        # Pandas-like objects carry their own index; use it as x.
        return y.index.to_numpy(), y.to_numpy()
    except AttributeError:
        pass
    try:
        y = _check_1d(y)
    except (VisibleDeprecationWarning, ValueError):
        # NumPy 1.19 will warn on ragged input, and we can't actually use it.
        pass
    else:
        # Fall back to 0, 1, ..., len(y)-1 as the x values.
        return np.arange(y.shape[0], dtype=float), y
    raise ValueError('Input could not be cast to an at-least-1D NumPy array')

1726 

1727 

def safe_first_element(obj):
    """
    Return the first element in *obj*.

    This is a type-independent way of obtaining the first element,
    supporting both index access and the iterator protocol.
    """
    if not isinstance(obj, collections.abc.Iterator):
        return next(iter(obj))
    # np.flatiter registers as an Iterator but still supports indexing;
    # prefer obj[0], which merely has the side effect of re-setting the
    # iterator.  Plain generators (no indexing) are rejected.
    try:
        return obj[0]
    except TypeError:
        pass
    raise RuntimeError("matplotlib does not support generators as input")

1746 

1747 

1748def _safe_first_finite(obj): 

1749 """ 

1750 Return the first finite element in *obj* if one is available and skip_nonfinite is 

1751 True. Otherwise, return the first element. 

1752 

1753 This is a method for internal use. 

1754 

1755 This is a type-independent way of obtaining the first finite element, supporting 

1756 both index access and the iterator protocol. 

1757 """ 

1758 def safe_isfinite(val): 

1759 if val is None: 

1760 return False 

1761 try: 

1762 return math.isfinite(val) 

1763 except (TypeError, ValueError): 

1764 # if the outer object is 2d, then val is a 1d array, and 

1765 # - math.isfinite(numpy.zeros(3)) raises TypeError 

1766 # - math.isfinite(torch.zeros(3)) raises ValueError 

1767 pass 

1768 try: 

1769 return np.isfinite(val) if np.isscalar(val) else True 

1770 except TypeError: 

1771 # This is something that NumPy cannot make heads or tails of, 

1772 # assume "finite" 

1773 return True 

1774 

1775 if isinstance(obj, np.flatiter): 

1776 # TODO do the finite filtering on this 

1777 return obj[0] 

1778 elif isinstance(obj, collections.abc.Iterator): 

1779 raise RuntimeError("matplotlib does not support generators as input") 

1780 else: 

1781 for val in obj: 

1782 if safe_isfinite(val): 

1783 return val 

1784 return safe_first_element(obj) 

1785 

1786 

def sanitize_sequence(data):
    """
    Convert dictview objects to list. Other inputs are returned unchanged.
    """
    if isinstance(data, collections.abc.MappingView):
        # dict views (keys/values/items) are snapshotted into a list.
        return list(data)
    return data

1793 

1794 

def normalize_kwargs(kw, alias_mapping=None):
    """
    Helper function to normalize kwarg inputs.

    Parameters
    ----------
    kw : dict or None
        A dict of keyword arguments. None is explicitly supported and treated
        as an empty dict, to support functions with an optional parameter of
        the form ``props=None``.

    alias_mapping : dict or Artist subclass or Artist instance, optional
        A mapping between a canonical name to a list of aliases, in order of
        precedence from lowest to highest.

        If the canonical value is not in the list it is assumed to have the
        highest priority.

        If an Artist subclass or instance is passed, use its properties alias
        mapping.

    Raises
    ------
    TypeError
        To match what Python raises if invalid arguments/keyword arguments are
        passed to a callable.
    """
    from matplotlib.artist import Artist

    if kw is None:
        return {}

    # Resolve the alias mapping: Artist classes/instances carry their own.
    if alias_mapping is None:
        alias_mapping = {}
    elif (isinstance(alias_mapping, Artist)
          or isinstance(alias_mapping, type) and issubclass(alias_mapping, Artist)):
        alias_mapping = getattr(alias_mapping, "_alias_map", {})

    # Invert canonical -> [aliases] into alias -> canonical for lookup.
    alias_to_canonical = {alias: canonical
                          for canonical, aliases in alias_mapping.items()
                          for alias in aliases}
    seen = {}  # canonical name -> the kwarg that first set it
    out = {}   # output dictionary

    for key, value in kw.items():
        canonical = alias_to_canonical.get(key, key)
        if canonical in seen:
            # Two different spellings of the same property is an error.
            raise TypeError(f"Got both {seen[canonical]!r} and "
                            f"{key!r}, which are aliases of one another")
        seen[canonical] = key
        out[canonical] = value

    return out

1849 

1850 

1851@contextlib.contextmanager 

1852def _lock_path(path): 

1853 """ 

1854 Context manager for locking a path. 

1855 

1856 Usage:: 

1857 

1858 with _lock_path(path): 

1859 ... 

1860 

1861 Another thread or process that attempts to lock the same path will wait 

1862 until this context manager is exited. 

1863 

1864 The lock is implemented by creating a temporary file in the parent 

1865 directory, so that directory must exist and be writable. 

1866 """ 

1867 path = Path(path) 

1868 lock_path = path.with_name(path.name + ".matplotlib-lock") 

1869 retries = 50 

1870 sleeptime = 0.1 

1871 for _ in range(retries): 

1872 try: 

1873 with lock_path.open("xb"): 

1874 break 

1875 except FileExistsError: 

1876 time.sleep(sleeptime) 

1877 else: 

1878 raise TimeoutError("""\ 

1879Lock error: Matplotlib failed to acquire the following lock file: 

1880 {} 

1881This maybe due to another process holding this lock file. If you are sure no 

1882other Matplotlib process is running, remove this file and try again.""".format( 

1883 lock_path)) 

1884 try: 

1885 yield 

1886 finally: 

1887 lock_path.unlink() 

1888 

1889 

def _topmost_artist(
        artists,
        _cached_max=functools.partial(max, key=operator.attrgetter("zorder"))):
    """
    Get the topmost artist of a list.

    In case of a tie, return the *last* of the tied artists, as it will be
    drawn on top of the others. `max` returns the first maximum in case of
    ties, so we need to iterate over the list in reverse order.
    """
    # _cached_max is a default-argument cache: the zorder-keyed max() partial
    # is built once at function-definition time, not on every call.
    return _cached_max(reversed(artists))

1901 

1902 

1903def _str_equal(obj, s): 

1904 """ 

1905 Return whether *obj* is a string equal to string *s*. 

1906 

1907 This helper solely exists to handle the case where *obj* is a numpy array, 

1908 because in such cases, a naive ``obj == s`` would yield an array, which 

1909 cannot be used in a boolean context. 

1910 """ 

1911 return isinstance(obj, str) and obj == s 

1912 

1913 

1914def _str_lower_equal(obj, s): 

1915 """ 

1916 Return whether *obj* is a string equal, when lowercased, to string *s*. 

1917 

1918 This helper solely exists to handle the case where *obj* is a numpy array, 

1919 because in such cases, a naive ``obj == s`` would yield an array, which 

1920 cannot be used in a boolean context. 

1921 """ 

1922 return isinstance(obj, str) and obj.lower() == s 

1923 

1924 

1925def _array_perimeter(arr): 

1926 """ 

1927 Get the elements on the perimeter of *arr*. 

1928 

1929 Parameters 

1930 ---------- 

1931 arr : ndarray, shape (M, N) 

1932 The input array. 

1933 

1934 Returns 

1935 ------- 

1936 ndarray, shape (2*(M - 1) + 2*(N - 1),) 

1937 The elements on the perimeter of the array:: 

1938 

1939 [arr[0, 0], ..., arr[0, -1], ..., arr[-1, -1], ..., arr[-1, 0], ...] 

1940 

1941 Examples 

1942 -------- 

1943 >>> i, j = np.ogrid[:3, :4] 

1944 >>> a = i*10 + j 

1945 >>> a 

1946 array([[ 0, 1, 2, 3], 

1947 [10, 11, 12, 13], 

1948 [20, 21, 22, 23]]) 

1949 >>> _array_perimeter(a) 

1950 array([ 0, 1, 2, 3, 13, 23, 22, 21, 20, 10]) 

1951 """ 

1952 # note we use Python's half-open ranges to avoid repeating 

1953 # the corners 

1954 forward = np.s_[0:-1] # [0 ... -1) 

1955 backward = np.s_[-1:0:-1] # [-1 ... 0) 

1956 return np.concatenate(( 

1957 arr[0, forward], 

1958 arr[forward, -1], 

1959 arr[-1, backward], 

1960 arr[backward, 0], 

1961 )) 

1962 

1963 

1964def _unfold(arr, axis, size, step): 

1965 """ 

1966 Append an extra dimension containing sliding windows along *axis*. 

1967 

1968 All windows are of size *size* and begin with every *step* elements. 

1969 

1970 Parameters 

1971 ---------- 

1972 arr : ndarray, shape (N_1, ..., N_k) 

1973 The input array 

1974 axis : int 

1975 Axis along which the windows are extracted 

1976 size : int 

1977 Size of the windows 

1978 step : int 

1979 Stride between first elements of subsequent windows. 

1980 

1981 Returns 

1982 ------- 

1983 ndarray, shape (N_1, ..., 1 + (N_axis-size)/step, ..., N_k, size) 

1984 

1985 Examples 

1986 -------- 

1987 >>> i, j = np.ogrid[:3, :7] 

1988 >>> a = i*10 + j 

1989 >>> a 

1990 array([[ 0, 1, 2, 3, 4, 5, 6], 

1991 [10, 11, 12, 13, 14, 15, 16], 

1992 [20, 21, 22, 23, 24, 25, 26]]) 

1993 >>> _unfold(a, axis=1, size=3, step=2) 

1994 array([[[ 0, 1, 2], 

1995 [ 2, 3, 4], 

1996 [ 4, 5, 6]], 

1997 [[10, 11, 12], 

1998 [12, 13, 14], 

1999 [14, 15, 16]], 

2000 [[20, 21, 22], 

2001 [22, 23, 24], 

2002 [24, 25, 26]]]) 

2003 """ 

2004 new_shape = [*arr.shape, size] 

2005 new_strides = [*arr.strides, arr.strides[axis]] 

2006 new_shape[axis] = (new_shape[axis] - size) // step + 1 

2007 new_strides[axis] = new_strides[axis] * step 

2008 return np.lib.stride_tricks.as_strided(arr, 

2009 shape=new_shape, 

2010 strides=new_strides, 

2011 writeable=False) 

2012 

2013 

def _array_patch_perimeters(x, rstride, cstride):
    """
    Extract perimeters of patches from *x*.

    Extracted patches are of size (*rstride* + 1) x (*cstride* + 1) and
    share perimeters with their neighbors.  The ordering of the vertices
    matches that returned by ``_array_perimeter``.

    Parameters
    ----------
    x : ndarray, shape (N, M)
        Input array.
    rstride : int
        Vertical (row) stride between corresponding elements of each patch.
    cstride : int
        Horizontal (column) stride between corresponding elements of each
        patch.

    Returns
    -------
    ndarray, shape (N/rstride * M/cstride, 2 * (rstride + cstride))
    """
    assert rstride > 0 and cstride > 0
    assert (x.shape[0] - 1) % rstride == 0
    assert (x.shape[1] - 1) % cstride == 0
    # Each perimeter is assembled from four half-open runs.  Illustration for
    # rstride == cstride == 3:
    #
    #       T T T R
    #       L     R
    #       L     R
    #       L B B B
    #
    # where T/R/B/L mark which of the top/right/bottom/left arrays each
    # element lands in.  Each of the four arrays below has shape
    #
    #   (number of perimeters that can be extracted vertically,
    #    number of perimeters that can be extracted horizontally,
    #    cstride for top and bottom, rstride for left and right)
    #
    # _unfold produces views without copying, so the only costly operation
    # here is the final np.concatenate.
    top = _unfold(x[:-1:rstride, :-1], 1, cstride, cstride)
    bottom = _unfold(x[rstride::rstride, 1:], 1, cstride, cstride)[..., ::-1]
    right = _unfold(x[:-1, cstride::cstride], 0, rstride, rstride)
    left = _unfold(x[1:, :-1:cstride], 0, rstride, rstride)[..., ::-1]
    stacked = np.concatenate((top, right, bottom, left), axis=2)
    return stacked.reshape(-1, 2 * (rstride + cstride))

2061 

2062 

2063@contextlib.contextmanager 

2064def _setattr_cm(obj, **kwargs): 

2065 """ 

2066 Temporarily set some attributes; restore original state at context exit. 

2067 """ 

2068 sentinel = object() 

2069 origs = {} 

2070 for attr in kwargs: 

2071 orig = getattr(obj, attr, sentinel) 

2072 if attr in obj.__dict__ or orig is sentinel: 

2073 # if we are pulling from the instance dict or the object 

2074 # does not have this attribute we can trust the above 

2075 origs[attr] = orig 

2076 else: 

2077 # if the attribute is not in the instance dict it must be 

2078 # from the class level 

2079 cls_orig = getattr(type(obj), attr) 

2080 # if we are dealing with a property (but not a general descriptor) 

2081 # we want to set the original value back. 

2082 if isinstance(cls_orig, property): 

2083 origs[attr] = orig 

2084 # otherwise this is _something_ we are going to shadow at 

2085 # the instance dict level from higher up in the MRO. We 

2086 # are going to assume we can delattr(obj, attr) to clean 

2087 # up after ourselves. It is possible that this code will 

2088 # fail if used with a non-property custom descriptor which 

2089 # implements __set__ (and __delete__ does not act like a 

2090 # stack). However, this is an internal tool and we do not 

2091 # currently have any custom descriptors. 

2092 else: 

2093 origs[attr] = sentinel 

2094 

2095 try: 

2096 for attr, val in kwargs.items(): 

2097 setattr(obj, attr, val) 

2098 yield 

2099 finally: 

2100 for attr, orig in origs.items(): 

2101 if orig is sentinel: 

2102 delattr(obj, attr) 

2103 else: 

2104 setattr(obj, attr, orig) 

2105 

2106 

2107class _OrderedSet(collections.abc.MutableSet): 

2108 def __init__(self): 

2109 self._od = collections.OrderedDict() 

2110 

2111 def __contains__(self, key): 

2112 return key in self._od 

2113 

2114 def __iter__(self): 

2115 return iter(self._od) 

2116 

2117 def __len__(self): 

2118 return len(self._od) 

2119 

2120 def add(self, key): 

2121 self._od.pop(key, None) 

2122 self._od[key] = None 

2123 

2124 def discard(self, key): 

2125 self._od.pop(key, None) 

2126 

2127 

# Agg's buffers are unmultiplied RGBA8888, which neither PyQt<=5.1 nor cairo
# support; however, both do support premultiplied ARGB32.

2130 

2131 

2132def _premultiplied_argb32_to_unmultiplied_rgba8888(buf): 

2133 """ 

2134 Convert a premultiplied ARGB32 buffer to an unmultiplied RGBA8888 buffer. 

2135 """ 

2136 rgba = np.take( # .take() ensures C-contiguity of the result. 

2137 buf, 

2138 [2, 1, 0, 3] if sys.byteorder == "little" else [1, 2, 3, 0], axis=2) 

2139 rgb = rgba[..., :-1] 

2140 alpha = rgba[..., -1] 

2141 # Un-premultiply alpha. The formula is the same as in cairo-png.c. 

2142 mask = alpha != 0 

2143 for channel in np.rollaxis(rgb, -1): 

2144 channel[mask] = ( 

2145 (channel[mask].astype(int) * 255 + alpha[mask] // 2) 

2146 // alpha[mask]) 

2147 return rgba 

2148 

2149 

2150def _unmultiplied_rgba8888_to_premultiplied_argb32(rgba8888): 

2151 """ 

2152 Convert an unmultiplied RGBA8888 buffer to a premultiplied ARGB32 buffer. 

2153 """ 

2154 if sys.byteorder == "little": 

2155 argb32 = np.take(rgba8888, [2, 1, 0, 3], axis=2) 

2156 rgb24 = argb32[..., :-1] 

2157 alpha8 = argb32[..., -1:] 

2158 else: 

2159 argb32 = np.take(rgba8888, [3, 0, 1, 2], axis=2) 

2160 alpha8 = argb32[..., :1] 

2161 rgb24 = argb32[..., 1:] 

2162 # Only bother premultiplying when the alpha channel is not fully opaque, 

2163 # as the cost is not negligible. The unsafe cast is needed to do the 

2164 # multiplication in-place in an integer buffer. 

2165 if alpha8.min() != 0xff: 

2166 np.multiply(rgb24, alpha8 / 0xff, out=rgb24, casting="unsafe") 

2167 return argb32 

2168 

2169 

2170def _get_nonzero_slices(buf): 

2171 """ 

2172 Return the bounds of the nonzero region of a 2D array as a pair of slices. 

2173 

2174 ``buf[_get_nonzero_slices(buf)]`` is the smallest sub-rectangle in *buf* 

2175 that encloses all non-zero entries in *buf*. If *buf* is fully zero, then 

2176 ``(slice(0, 0), slice(0, 0))`` is returned. 

2177 """ 

2178 x_nz, = buf.any(axis=0).nonzero() 

2179 y_nz, = buf.any(axis=1).nonzero() 

2180 if len(x_nz) and len(y_nz): 

2181 l, r = x_nz[[0, -1]] 

2182 b, t = y_nz[[0, -1]] 

2183 return slice(b, t + 1), slice(l, r + 1) 

2184 else: 

2185 return slice(0, 0), slice(0, 0) 

2186 

2187 

2188def _pformat_subprocess(command): 

2189 """Pretty-format a subprocess command for printing/logging purposes.""" 

2190 return (command if isinstance(command, str) 

2191 else " ".join(shlex.quote(os.fspath(arg)) for arg in command)) 

2192 

2193 

def _check_and_log_subprocess(command, logger, **kwargs):
    """
    Run *command*, returning its stdout output if it succeeds.

    If it fails (exits with nonzero return code), raise an exception whose
    text includes the failed command and captured stdout and stderr output.

    Regardless of the return code, the command is logged at DEBUG level on
    *logger*.  In case of success, the output is likewise logged.
    """
    logger.debug('%s', _pformat_subprocess(command))
    proc = subprocess.run(command, capture_output=True, **kwargs)
    if proc.returncode:
        # Depending on **kwargs (e.g. text=True), the captured streams may
        # already be str; only decode raw bytes.
        stdout = proc.stdout
        stderr = proc.stderr
        if isinstance(stdout, bytes):
            stdout = stdout.decode()
        if isinstance(stderr, bytes):
            stderr = stderr.decode()
        raise RuntimeError(
            f"The command\n"
            f"    {_pformat_subprocess(command)}\n"
            f"failed and generated the following output:\n"
            f"{stdout}\n"
            f"and the following error:\n"
            f"{stderr}")
    if proc.stdout:
        logger.debug("stdout:\n%s", proc.stdout)
    if proc.stderr:
        logger.debug("stderr:\n%s", proc.stderr)
    return proc.stdout

2225 

2226 

def _setup_new_guiapp():
    """
    Perform OS-dependent setup when Matplotlib creates a new GUI application.
    """
    # On Windows, set the app user model id to "matplotlib" so that taskbar
    # icons are correct -- but only if no explicit id has been set yet (an
    # existing id means we are embedded in another application).
    try:
        _c_internal_utils.Win32_GetCurrentProcessExplicitAppUserModelID()
    except OSError:
        _c_internal_utils.Win32_SetCurrentProcessExplicitAppUserModelID(
            "matplotlib")

2239 

2240 

2241def _format_approx(number, precision): 

2242 """ 

2243 Format the number with at most the number of decimals given as precision. 

2244 Remove trailing zeros and possibly the decimal point. 

2245 """ 

2246 return f'{number:.{precision}f}'.rstrip('0').rstrip('.') or '0' 

2247 

2248 

2249def _g_sig_digits(value, delta): 

2250 """ 

2251 Return the number of significant digits to %g-format *value*, assuming that 

2252 it is known with an error of *delta*. 

2253 """ 

2254 if delta == 0: 

2255 if value == 0: 

2256 # if both value and delta are 0, np.spacing below returns 5e-324 

2257 # which results in rather silly results 

2258 return 3 

2259 # delta = 0 may occur when trying to format values over a tiny range; 

2260 # in that case, replace it by the distance to the closest float. 

2261 delta = abs(np.spacing(value)) 

2262 # If e.g. value = 45.67 and delta = 0.02, then we want to round to 2 digits 

2263 # after the decimal point (floor(log10(0.02)) = -2); 45.67 contributes 2 

2264 # digits before the decimal point (floor(log10(45.67)) + 1 = 2): the total 

2265 # is 4 significant digits. A value of 0 contributes 1 "digit" before the 

2266 # decimal point. 

2267 # For inf or nan, the precision doesn't matter. 

2268 return max( 

2269 0, 

2270 (math.floor(math.log10(abs(value))) + 1 if value else 1) 

2271 - math.floor(math.log10(delta))) if math.isfinite(value) else 0 

2272 

2273 

2274def _unikey_or_keysym_to_mplkey(unikey, keysym): 

2275 """ 

2276 Convert a Unicode key or X keysym to a Matplotlib key name. 

2277 

2278 The Unicode key is checked first; this avoids having to list most printable 

2279 keysyms such as ``EuroSign``. 

2280 """ 

2281 # For non-printable characters, gtk3 passes "\0" whereas tk passes an "". 

2282 if unikey and unikey.isprintable(): 

2283 return unikey 

2284 key = keysym.lower() 

2285 if key.startswith("kp_"): # keypad_x (including kp_enter). 

2286 key = key[3:] 

2287 if key.startswith("page_"): # page_{up,down} 

2288 key = key.replace("page_", "page") 

2289 if key.endswith(("_l", "_r")): # alt_l, ctrl_l, shift_l. 

2290 key = key[:-2] 

2291 if sys.platform == "darwin" and key == "meta": 

2292 # meta should be reported as command on mac 

2293 key = "cmd" 

2294 key = { 

2295 "return": "enter", 

2296 "prior": "pageup", # Used by tk. 

2297 "next": "pagedown", # Used by tk. 

2298 }.get(key, key) 

2299 return key 

2300 

2301 

2302@functools.cache 

2303def _make_class_factory(mixin_class, fmt, attr_name=None): 

2304 """ 

2305 Return a function that creates picklable classes inheriting from a mixin. 

2306 

2307 After :: 

2308 

2309 factory = _make_class_factory(FooMixin, fmt, attr_name) 

2310 FooAxes = factory(Axes) 

2311 

2312 ``Foo`` is a class that inherits from ``FooMixin`` and ``Axes`` and **is 

2313 picklable** (picklability is what differentiates this from a plain call to 

2314 `type`). Its ``__name__`` is set to ``fmt.format(Axes.__name__)`` and the 

2315 base class is stored in the ``attr_name`` attribute, if not None. 

2316 

2317 Moreover, the return value of ``factory`` is memoized: calls with the same 

2318 ``Axes`` class always return the same subclass. 

2319 """ 

2320 

2321 @functools.cache 

2322 def class_factory(axes_class): 

2323 # if we have already wrapped this class, declare victory! 

2324 if issubclass(axes_class, mixin_class): 

2325 return axes_class 

2326 

2327 # The parameter is named "axes_class" for backcompat but is really just 

2328 # a base class; no axes semantics are used. 

2329 base_class = axes_class 

2330 

2331 class subcls(mixin_class, base_class): 

2332 # Better approximation than __module__ = "matplotlib.cbook". 

2333 __module__ = mixin_class.__module__ 

2334 

2335 def __reduce__(self): 

2336 return (_picklable_class_constructor, 

2337 (mixin_class, fmt, attr_name, base_class), 

2338 self.__getstate__()) 

2339 

2340 subcls.__name__ = subcls.__qualname__ = fmt.format(base_class.__name__) 

2341 if attr_name is not None: 

2342 setattr(subcls, attr_name, base_class) 

2343 return subcls 

2344 

2345 class_factory.__module__ = mixin_class.__module__ 

2346 return class_factory 

2347 

2348 

def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
    """Internal helper for _make_class_factory."""
    # Recreate (or fetch from the memoizing caches) the dynamic class, then
    # make a bare instance of it; pickle fills in the state afterwards.
    cls = _make_class_factory(mixin_class, fmt, attr_name)(base_class)
    return cls.__new__(cls)

2354 

2355 

2356def _is_torch_array(x): 

2357 """Check if 'x' is a PyTorch Tensor.""" 

2358 try: 

2359 # we're intentionally not attempting to import torch. If somebody 

2360 # has created a torch array, torch should already be in sys.modules 

2361 return isinstance(x, sys.modules['torch'].Tensor) 

2362 except Exception: # TypeError, KeyError, AttributeError, maybe others? 

2363 # we're attempting to access attributes on imported modules which 

2364 # may have arbitrary user code, so we deliberately catch all exceptions 

2365 return False 

2366 

2367 

2368def _is_jax_array(x): 

2369 """Check if 'x' is a JAX Array.""" 

2370 try: 

2371 # we're intentionally not attempting to import jax. If somebody 

2372 # has created a jax array, jax should already be in sys.modules 

2373 return isinstance(x, sys.modules['jax'].Array) 

2374 except Exception: # TypeError, KeyError, AttributeError, maybe others? 

2375 # we're attempting to access attributes on imported modules which 

2376 # may have arbitrary user code, so we deliberately catch all exceptions 

2377 return False 

2378 

2379 

2380def _unpack_to_numpy(x): 

2381 """Internal helper to extract data from e.g. pandas and xarray objects.""" 

2382 if isinstance(x, np.ndarray): 

2383 # If numpy, return directly 

2384 return x 

2385 if hasattr(x, 'to_numpy'): 

2386 # Assume that any to_numpy() method actually returns a numpy array 

2387 return x.to_numpy() 

2388 if hasattr(x, 'values'): 

2389 xtmp = x.values 

2390 # For example a dict has a 'values' attribute, but it is not a property 

2391 # so in this case we do not want to return a function 

2392 if isinstance(xtmp, np.ndarray): 

2393 return xtmp 

2394 if _is_torch_array(x) or _is_jax_array(x): 

2395 xtmp = x.__array__() 

2396 

2397 # In case __array__() method does not return a numpy array in future 

2398 if isinstance(xtmp, np.ndarray): 

2399 return xtmp 

2400 return x 

2401 

2402 

2403def _auto_format_str(fmt, value): 

2404 """ 

2405 Apply *value* to the format string *fmt*. 

2406 

2407 This works both with unnamed %-style formatting and 

2408 unnamed {}-style formatting. %-style formatting has priority. 

2409 If *fmt* is %-style formattable that will be used. Otherwise, 

2410 {}-formatting is applied. Strings without formatting placeholders 

2411 are passed through as is. 

2412 

2413 Examples 

2414 -------- 

2415 >>> _auto_format_str('%.2f m', 0.2) 

2416 '0.20 m' 

2417 >>> _auto_format_str('{} m', 0.2) 

2418 '0.2 m' 

2419 >>> _auto_format_str('const', 0.2) 

2420 'const' 

2421 >>> _auto_format_str('%d or {}', 0.2) 

2422 '0 or {}' 

2423 """ 

2424 try: 

2425 return fmt % (value,) 

2426 except (TypeError, ValueError): 

2427 return fmt.format(value)