Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/utils.py: 30%


935 statements  

1from __future__ import annotations 

2 

3import codecs 

4import contextlib 

5import functools 

6import gc 

7import inspect 

8import os 

9import re 

10import shutil 

11import sys 

12import tempfile 

13import types 

14import uuid 

15import warnings 

16from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Set 

17from contextlib import ContextDecorator, contextmanager, nullcontext, suppress 

18from datetime import datetime, timedelta 

19from errno import ENOENT 

20from functools import wraps 

21from importlib import import_module 

22from numbers import Integral, Number 

23from operator import add 

24from threading import Lock 

25from typing import Any, ClassVar, Literal, TypeVar, cast, overload 

26from weakref import WeakValueDictionary 

27 

28import tlz as toolz 

29 

30from dask import config 

31from dask.typing import no_default 

32 

33K = TypeVar("K") 

34V = TypeVar("V") 

35T = TypeVar("T") 

36 

37# used in decorators to preserve the signature of the function it decorates 

38# see https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators 

39FuncType = Callable[..., Any] 

40F = TypeVar("F", bound=FuncType) 

41 

42system_encoding = sys.getdefaultencoding() 

43if system_encoding == "ascii": 

44 system_encoding = "utf-8" 

45 

46 

47def apply(func, args, kwargs=None): 

48 """Apply a function given its positional and keyword arguments. 

49 

50 Equivalent to ``func(*args, **kwargs)`` 

51 Most Dask users will never need to use the ``apply`` function. 

52 It is typically only used by people who need to inject 

53 keyword argument values into a low level Dask task graph. 

54 

55 Parameters 

56 ---------- 

57 func : callable 

58 The function you want to apply. 

59 args : tuple 

60 A tuple containing all the positional arguments needed for ``func`` 

61 (eg: ``(arg_1, arg_2, arg_3)``) 

62 kwargs : dict, optional 

63 A dictionary of keyword arguments to pass to ``func`` 

64 (eg: ``{"kwarg_1": value, "kwarg_2": value}``) 

65 

66 Examples 

67 -------- 

68 >>> from dask.utils import apply 

69 >>> def add(number, second_number=5): 

70 ... return number + second_number 

71 ... 

72 >>> apply(add, (10,), {"second_number": 2}) # equivalent to add(*args, **kwargs) 

73 12 

74 

75 >>> task = apply(add, (10,), {"second_number": 2}) 

76 >>> dsk = {'task-name': task} # adds the task to a low level Dask task graph 

77 """ 

78 if kwargs: 

79 return func(*args, **kwargs) 

80 else: 

81 return func(*args) 

82 

83 

84def _deprecated( 

85 *, 

86 version: str | None = None, 

87 after_version: str | None = None, 

88 message: str | None = None, 

89 use_instead: str | None = None, 

90 category: type[Warning] = FutureWarning, 

91): 

92 """Decorator to mark a function as deprecated 

93 

94 Parameters 

95 ---------- 

96 version : str, optional 

97 Version of Dask in which the function was deprecated. If specified, the version 

98 will be included in the default warning message. This should no longer be used 

99 after the introduction of the automated versioning system. 

100 after_version : str, optional 

101 Version of Dask after which the function was deprecated. If specified, the 

102 version will be included in the default warning message. 

103 message : str, optional 

104 Custom warning message to raise. 

105 use_instead : str, optional 

106 Name of function to use in place of the deprecated function. 

107 If specified, this will be included in the default warning 

108 message. 

109 category : type[Warning], optional 

110 Type of warning to raise. Defaults to ``FutureWarning``. 

111 

112 Examples 

113 -------- 

114 

115 >>> from dask.utils import _deprecated 

116 >>> @_deprecated(after_version="X.Y.Z", use_instead="bar") 

117 ... def foo(): 

118 ... return "baz" 

119 """ 

120 

121 def decorator(func): 

122 if message is None: 

123 msg = f"{func.__name__} " 

124 if after_version is not None: 

125 msg += f"was deprecated after version {after_version} " 

126 elif version is not None: 

127 msg += f"was deprecated in version {version} " 

128 else: 

129 msg += "is deprecated " 

130 msg += "and will be removed in a future release." 

131 

132 if use_instead is not None: 

133 msg += f" Please use {use_instead} instead." 

134 else: 

135 msg = message 

136 

137 @functools.wraps(func) 

138 def wrapper(*args, **kwargs): 

139 warnings.warn(msg, category=category, stacklevel=2) 

140 return func(*args, **kwargs) 

141 

142 return wrapper 

143 

144 return decorator 

145 

146 
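A minimal usage sketch of the decorator above (illustrative only; `old_sum` and the version string are made up), assuming it runs in this module's namespace:

import warnings

@_deprecated(after_version="2024.1.0", use_instead="builtins.sum")
def old_sum(xs):
    return sum(xs)

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    assert old_sum([1, 2, 3]) == 6          # still works, but warns
    assert issubclass(w[-1].category, FutureWarning)
    assert "builtins.sum" in str(w[-1].message)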

147def _deprecated_kwarg( 

148 old_arg_name: str, 

149 new_arg_name: str | None = None, 

150 mapping: Mapping[Any, Any] | Callable[[Any], Any] | None = None, 

151 stacklevel: int = 2, 

152 comment: str | None = None, 

153) -> Callable[[F], F]: 

154 """ 

155 Decorator to deprecate a keyword argument of a function. 

156 

157 Parameters 

158 ---------- 

159 old_arg_name : str 

160 Name of argument in function to deprecate 

161 new_arg_name : str, optional 

162 Name of preferred argument in function. Omit to warn that 

163 ``old_arg_name`` keyword is deprecated. 

164 mapping : dict or callable, optional 

165 If mapping is present, use it to translate old arguments to 

166 new arguments. A callable must do its own value checking; 

167 values not found in a dict will be forwarded unchanged. 

168 comment : str, optional 

169 Additional message appended to the deprecation message. Useful to pass 

170 on suggestions with the deprecation warning. 

171 

172 Examples 

173 -------- 

174 The following deprecates 'cols', using 'columns' instead 

175 

176 >>> @_deprecated_kwarg(old_arg_name='cols', new_arg_name='columns') 

177 ... def f(columns=''): 

178 ... print(columns) 

179 ... 

180 >>> f(columns='should work ok') 

181 should work ok 

182 

183 >>> f(cols='should raise warning') # doctest: +SKIP 

184 FutureWarning: cols is deprecated, use columns instead 

185 warnings.warn(msg, FutureWarning) 

186 should raise warning 

187 

188 >>> f(cols='should error', columns="can't pass both") # doctest: +SKIP 

189 TypeError: Can only specify 'cols' or 'columns', not both 

190 

191 >>> @_deprecated_kwarg('old', 'new', {'yes': True, 'no': False}) 

192 ... def f(new=False): 

193 ... print('yes!' if new else 'no!') 

194 ... 

195 >>> f(old='yes') # doctest: +SKIP 

196 FutureWarning: old='yes' is deprecated, use new=True instead 

197 warnings.warn(msg, FutureWarning) 

198 yes! 

199 

200 To raise a warning that a keyword will be removed entirely in the future 

201 

202 >>> @_deprecated_kwarg(old_arg_name='cols', new_arg_name=None) 

203 ... def f(cols='', another_param=''): 

204 ... print(cols) 

205 ... 

206 >>> f(cols='should raise warning') # doctest: +SKIP 

207 FutureWarning: the 'cols' keyword is deprecated and will be removed in a 

208 future version. Please take steps to stop the use of 'cols' 

209 should raise warning 

210 >>> f(another_param='should not raise warning') # doctest: +SKIP 

211 should not raise warning 

212 

213 >>> f(cols='should raise warning', another_param='') # doctest: +SKIP 

214 FutureWarning: the 'cols' keyword is deprecated and will be removed in a 

215 future version. Please take steps to stop the use of 'cols' 

216 should raise warning 

217 """ 

218 if mapping is not None and not hasattr(mapping, "get") and not callable(mapping): 

219 raise TypeError( 

220 "mapping from old to new argument values must be dict or callable!" 

221 ) 

222 

223 comment_ = f"\n{comment}" if comment else ""  # the f-string itself is always truthy, so guard on `comment` 

224 

225 def _deprecated_kwarg(func: F) -> F: 

226 @wraps(func) 

227 def wrapper(*args, **kwargs) -> Callable[..., Any]: 

228 old_arg_value = kwargs.pop(old_arg_name, no_default) 

229 

230 if old_arg_value is not no_default: 

231 if new_arg_name is None: 

232 msg = ( 

233 f"the {old_arg_name!r} keyword is deprecated and " 

234 "will be removed in a future version. Please take " 

235 f"steps to stop the use of {old_arg_name!r}" 

236 ) + comment_ 

237 warnings.warn(msg, FutureWarning, stacklevel=stacklevel) 

238 kwargs[old_arg_name] = old_arg_value 

239 return func(*args, **kwargs) 

240 

241 elif mapping is not None: 

242 if callable(mapping): 

243 new_arg_value = mapping(old_arg_value) 

244 else: 

245 new_arg_value = mapping.get(old_arg_value, old_arg_value) 

246 msg = ( 

247 f"the {old_arg_name}={old_arg_value!r} keyword is " 

248 "deprecated, use " 

249 f"{new_arg_name}={new_arg_value!r} instead." 

250 ) 

251 else: 

252 new_arg_value = old_arg_value 

253 msg = ( 

254 f"the {old_arg_name!r} keyword is deprecated, " 

255 f"use {new_arg_name!r} instead." 

256 ) 

257 

258 warnings.warn(msg + comment_, FutureWarning, stacklevel=stacklevel) 

259 if kwargs.get(new_arg_name) is not None: 

260 msg = ( 

261 f"Can only specify {old_arg_name!r} " 

262 f"or {new_arg_name!r}, not both." 

263 ) 

264 raise TypeError(msg) 

265 kwargs[new_arg_name] = new_arg_value 

266 return func(*args, **kwargs) 

267 

268 return cast(F, wrapper) 

269 

270 return _deprecated_kwarg 

271 

272 

273def deepmap(func, *seqs): 

274 """Apply function inside nested lists 

275 

276 >>> inc = lambda x: x + 1 

277 >>> deepmap(inc, [[1, 2], [3, 4]]) 

278 [[2, 3], [4, 5]] 

279 

280 >>> add = lambda x, y: x + y 

281 >>> deepmap(add, [[1, 2], [3, 4]], [[10, 20], [30, 40]]) 

282 [[11, 22], [33, 44]] 

283 """ 

284 if isinstance(seqs[0], (list, Iterator)): 

285 return [deepmap(func, *items) for items in zip(*seqs)] 

286 else: 

287 return func(*seqs) 

288 

289 

290@_deprecated() 

291def homogeneous_deepmap(func, seq): 

292 if not seq: 

293 return seq 

294 n = 0 

295 tmp = seq 

296 while isinstance(tmp, list): 

297 n += 1 

298 tmp = tmp[0] 

299 

300 return ndeepmap(n, func, seq) 

301 

302 

303def ndeepmap(n, func, seq): 

304 """Call a function on every element within a nested container 

305 

306 >>> def inc(x): 

307 ... return x + 1 

308 >>> L = [[1, 2], [3, 4, 5]] 

309 >>> ndeepmap(2, inc, L) 

310 [[2, 3], [4, 5, 6]] 

311 """ 

312 if n == 1: 

313 return [func(item) for item in seq] 

314 elif n > 1: 

315 return [ndeepmap(n - 1, func, item) for item in seq] 

316 elif isinstance(seq, list): 

317 return func(seq[0]) 

318 else: 

319 return func(seq) 

320 

321 

322def import_required(mod_name, error_msg): 

323 """Attempt to import a required dependency. 

324 

325 Raises a RuntimeError if the requested module is not available. 

326 """ 

327 try: 

328 return import_module(mod_name) 

329 except ImportError as e: 

330 raise RuntimeError(error_msg) from e 

331 

332 

333@contextmanager 

334def tmpfile(extension="", dir=None): 

335 """ 

336 Context manager yielding the path to a unique temporary file with the given extension, if provided. 

337 

338 Parameters 

339 ---------- 

340 extension : str 

341 The extension of the temporary file to be created 

342 dir : str 

343 If ``dir`` is not None, the file will be created in that directory; otherwise, 

344 Python's default temporary directory is used. 

345 

346 Returns 

347 ------- 

348 out : str 

349 Path to the temporary file 

350 

351 See Also 

352 -------- 

353 NamedTemporaryFile : Built-in alternative for creating temporary files 

354 tmp_path : pytest fixture for creating a temporary directory unique to the test invocation 

355 

356 Notes 

357 ----- 

358 This context manager is particularly useful on Windows for opening temporary files multiple times. 

359 """ 

360 extension = extension.lstrip(".") 

361 if extension: 

362 extension = "." + extension 

363 handle, filename = tempfile.mkstemp(extension, dir=dir) 

364 os.close(handle) 

365 os.remove(filename) 

366 

367 try: 

368 yield filename 

369 finally: 

370 if os.path.exists(filename): 

371 with suppress(OSError): # sometimes we can't remove a generated temp file 

372 if os.path.isdir(filename): 

373 shutil.rmtree(filename) 

374 else: 

375 os.remove(filename) 

376 

377 
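A short usage sketch: ``tmpfile`` yields a path that does not yet exist (the ``mkstemp`` handle is closed and the file removed before yielding), so the caller creates the file and cleanup happens on exit:

with tmpfile(extension="csv") as path:
    assert not os.path.exists(path)    # nothing created yet
    assert path.endswith(".csv")
    with open(path, "w") as f:
        f.write("a,b\n1,2\n")
assert not os.path.exists(path)        # removed on exit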

378@contextmanager 

379def tmpdir(dir=None): 

380 """ 

381 Context manager yielding the path to a unique temporary directory. 

382 

383 Parameters 

384 ---------- 

385 dir : str 

386 If ``dir`` is not None, the directory will be created in that directory; otherwise, 

387 Python's default temporary directory is used. 

388 

389 Returns 

390 ------- 

391 out : str 

392 Path to the temporary directory 

393 

394 Notes 

395 ----- 

396 This context manager is particularly useful on Windows for opening temporary directories multiple times. 

397 """ 

398 dirname = tempfile.mkdtemp(dir=dir) 

399 

400 try: 

401 yield dirname 

402 finally: 

403 if os.path.exists(dirname): 

404 if os.path.isdir(dirname): 

405 with suppress(OSError): 

406 shutil.rmtree(dirname) 

407 else: 

408 with suppress(OSError): 

409 os.remove(dirname) 

410 

411 

412@contextmanager 

413def filetext(text, extension="", open=open, mode="w"): 

414 with tmpfile(extension=extension) as filename: 

415 f = open(filename, mode=mode) 

416 try: 

417 f.write(text) 

418 finally: 

419 try: 

420 f.close() 

421 except AttributeError: 

422 pass 

423 

424 yield filename 

425 

426 

427@contextmanager 

428def changed_cwd(new_cwd): 

429 old_cwd = os.getcwd() 

430 os.chdir(new_cwd) 

431 try: 

432 yield 

433 finally: 

434 os.chdir(old_cwd) 

435 

436 

437@contextmanager 

438def tmp_cwd(dir=None): 

439 with tmpdir(dir) as dirname: 

440 with changed_cwd(dirname): 

441 yield dirname 

442 

443 

444class IndexCallable: 

445 """Provide getitem syntax for functions 

446 

447 >>> def inc(x): 

448 ... return x + 1 

449 

450 >>> I = IndexCallable(inc) 

451 >>> I[3] 

452 4 

453 """ 

454 

455 __slots__ = ("fn",) 

456 

457 def __init__(self, fn): 

458 self.fn = fn 

459 

460 def __getitem__(self, key): 

461 return self.fn(key) 

462 

463 

464@contextmanager 

465def filetexts(d, open=open, mode="t", use_tmpdir=True): 

466 """Dumps a number of textfiles to disk 

467 

468 Parameters 

469 ---------- 

470 d : dict 

471 a mapping from filename to text like {'a.csv': '1,1\n2,2'} 

472 

473 Since this is meant for use in tests, this context manager will 

474 automatically switch to a temporary current directory, to avoid 

475 race conditions when running tests in parallel. 

476 """ 

477 with tmp_cwd() if use_tmpdir else nullcontext(): 

478 for filename, text in d.items(): 

479 try: 

480 os.makedirs(os.path.dirname(filename)) 

481 except OSError: 

482 pass 

483 f = open(filename, "w" + mode) 

484 try: 

485 f.write(text) 

486 finally: 

487 try: 

488 f.close() 

489 except AttributeError: 

490 pass 

491 

492 yield list(d) 

493 

494 for filename in d: 

495 if os.path.exists(filename): 

496 with suppress(OSError): 

497 os.remove(filename) 

498 

499 

500def concrete(seq): 

501 """Make nested iterators concrete lists 

502 

503 >>> data = [[1, 2], [3, 4]] 

504 >>> seq = iter(map(iter, data)) 

505 >>> concrete(seq) 

506 [[1, 2], [3, 4]] 

507 """ 

508 if isinstance(seq, Iterator): 

509 seq = list(seq) 

510 if isinstance(seq, (tuple, list)): 

511 seq = list(map(concrete, seq)) 

512 return seq 

513 

514 

515def pseudorandom(n: int, p, random_state=None): 

516 """Pseudorandom array of integer indexes 

517 

518 >>> pseudorandom(5, [0.5, 0.5], random_state=123) 

519 array([1, 0, 0, 1, 1], dtype=int8) 

520 

521 >>> pseudorandom(10, [0.5, 0.2, 0.2, 0.1], random_state=5) 

522 array([0, 2, 0, 3, 0, 1, 2, 1, 0, 0], dtype=int8) 

523 """ 

524 import numpy as np 

525 

526 p = list(p) 

527 cp = np.cumsum([0] + p) 

528 assert np.allclose(1, cp[-1]) 

529 assert len(p) < 256 

530 

531 if not isinstance(random_state, np.random.RandomState): 

532 random_state = np.random.RandomState(random_state) 

533 

534 x = random_state.random_sample(n) 

535 out = np.empty(n, dtype="i1") 

536 

537 for i, (low, high) in enumerate(zip(cp[:-1], cp[1:])): 

538 out[(x >= low) & (x < high)] = i 

539 return out 

540 

541 

542def random_state_data(n: int, random_state=None) -> list: 

543 """Return a list of arrays that can initialize 

544 ``np.random.RandomState``. 

545 

546 Parameters 

547 ---------- 

548 n : int 

549 Number of arrays to return. 

550 random_state : int or np.random.RandomState, optional 

551 If an int, is used to seed a new ``RandomState``. 

552 """ 

553 import numpy as np 

554 

555 if not all( 

556 hasattr(random_state, attr) for attr in ["normal", "beta", "bytes", "uniform"] 

557 ): 

558 random_state = np.random.RandomState(random_state) 

559 

560 random_data = random_state.bytes(624 * n * 4) # `n * 624` 32-bit integers 

561 l = list(np.frombuffer(random_data, dtype="<u4").reshape((n, -1))) 

562 assert len(l) == n 

563 return l 

564 

565 
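A brief sketch of the intended use: each returned array seeds an independent, reproducible generator (Dask uses one per chunk/partition; the root seed 42 here is arbitrary):

import numpy as np

seeds = random_state_data(3, random_state=42)
states = [np.random.RandomState(s) for s in seeds]
draws = [st.random_sample(2) for st in states]

# Re-deriving from the same root seed reproduces the same streams.
states2 = [np.random.RandomState(s) for s in random_state_data(3, random_state=42)]
assert all((a == b.random_sample(2)).all() for a, b in zip(draws, states2))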

566def is_integer(i) -> bool: 

567 """ 

568 >>> is_integer(6) 

569 True 

570 >>> is_integer(42.0) 

571 True 

572 >>> is_integer('abc') 

573 False 

574 """ 

575 return isinstance(i, Integral) or (isinstance(i, float) and i.is_integer()) 

576 

577 

578ONE_ARITY_BUILTINS = { 

579 abs, 

580 all, 

581 any, 

582 ascii, 

583 bool, 

584 bytearray, 

585 bytes, 

586 callable, 

587 chr, 

588 classmethod, 

589 complex, 

590 dict, 

591 dir, 

592 enumerate, 

593 eval, 

594 float, 

595 format, 

596 frozenset, 

597 hash, 

598 hex, 

599 id, 

600 int, 

601 iter, 

602 len, 

603 list, 

604 max, 

605 min, 

606 next, 

607 oct, 

608 open, 

609 ord, 

610 range, 

611 repr, 

612 reversed, 

613 round, 

614 set, 

615 slice, 

616 sorted, 

617 staticmethod, 

618 str, 

619 sum, 

620 tuple, 

621 type, 

622 vars, 

623 zip, 

624 memoryview, 

625} 

626MULTI_ARITY_BUILTINS = { 

627 compile, 

628 delattr, 

629 divmod, 

630 filter, 

631 getattr, 

632 hasattr, 

633 isinstance, 

634 issubclass, 

635 map, 

636 pow, 

637 setattr, 

638} 

639 

640 

641def getargspec(func): 

642 """Version of inspect.getargspec that works with partial and warps.""" 

643 if isinstance(func, functools.partial): 

644 return getargspec(func.func) 

645 

646 func = getattr(func, "__wrapped__", func) 

647 if isinstance(func, type): 

648 return inspect.getfullargspec(func.__init__) 

649 else: 

650 return inspect.getfullargspec(func) 

651 

652 

653def takes_multiple_arguments(func, varargs=True): 

654 """Does this function take multiple arguments? 

655 

656 >>> def f(x, y): pass 

657 >>> takes_multiple_arguments(f) 

658 True 

659 

660 >>> def f(x): pass 

661 >>> takes_multiple_arguments(f) 

662 False 

663 

664 >>> def f(x, y=None): pass 

665 >>> takes_multiple_arguments(f) 

666 False 

667 

668 >>> def f(*args): pass 

669 >>> takes_multiple_arguments(f) 

670 True 

671 

672 >>> class Thing: 

673 ... def __init__(self, a): pass 

674 >>> takes_multiple_arguments(Thing) 

675 False 

676 

677 """ 

678 if func in ONE_ARITY_BUILTINS: 

679 return False 

680 elif func in MULTI_ARITY_BUILTINS: 

681 return True 

682 

683 try: 

684 spec = getargspec(func) 

685 except Exception: 

686 return False 

687 

688 try: 

689 is_constructor = spec.args[0] == "self" and isinstance(func, type) 

690 except Exception: 

691 is_constructor = False 

692 

693 if varargs and spec.varargs: 

694 return True 

695 

696 ndefaults = 0 if spec.defaults is None else len(spec.defaults) 

697 return len(spec.args) - ndefaults - is_constructor > 1 

698 

699 

700def get_named_args(func) -> list[str]: 

701 """Get all non ``*args/**kwargs`` arguments for a function""" 

702 s = inspect.signature(func) 

703 return [ 

704 n 

705 for n, p in s.parameters.items() 

706 if p.kind in [p.POSITIONAL_OR_KEYWORD, p.POSITIONAL_ONLY, p.KEYWORD_ONLY] 

707 ] 

708 

709 

710class Dispatch: 

711 """Simple single dispatch.""" 

712 

713 def __init__(self, name=None): 

714 self._lookup = {} 

715 self._lazy = {} 

716 if name: 

717 self.__name__ = name 

718 

719 def register(self, type, func=None): 

720 """Register dispatch of `func` on arguments of type `type`""" 

721 

722 def wrapper(func): 

723 if isinstance(type, tuple): 

724 for t in type: 

725 self.register(t, func) 

726 else: 

727 self._lookup[type] = func 

728 return func 

729 

730 return wrapper(func) if func is not None else wrapper 

731 

732 def register_lazy(self, toplevel, func=None): 

733 """ 

734 Register a registration function which will be called if the 

735 *toplevel* module (e.g. 'pandas') is ever loaded. 

736 """ 

737 

738 def wrapper(func): 

739 self._lazy[toplevel] = func 

740 return func 

741 

742 return wrapper(func) if func is not None else wrapper 

743 

744 def dispatch(self, cls): 

745 """Return the function implementation for the given ``cls``""" 

746 lk = self._lookup 

747 if cls in lk: 

748 return lk[cls] 

749 for cls2 in cls.__mro__: 

750 # Is a lazy registration function present? 

751 try: 

752 toplevel, _, _ = cls2.__module__.partition(".") 

753 except Exception: 

754 continue 

755 try: 

756 register = self._lazy[toplevel] 

757 except KeyError: 

758 pass 

759 else: 

760 register() 

761 self._lazy.pop(toplevel, None) 

762 meth = self.dispatch(cls) # recurse 

763 lk[cls] = meth 

764 lk[cls2] = meth 

765 return meth 

766 try: 

767 impl = lk[cls2] 

768 except KeyError: 

769 pass 

770 else: 

771 if cls is not cls2: 

772 # Cache lookup 

773 lk[cls] = impl 

774 return impl 

775 raise TypeError(f"No dispatch for {cls}") 

776 

777 def __call__(self, arg, *args, **kwargs): 

778 """ 

779 Call the corresponding method based on type of argument. 

780 """ 

781 meth = self.dispatch(type(arg)) 

782 return meth(arg, *args, **kwargs) 

783 

784 @property 

785 def __doc__(self): 

786 try: 

787 func = self.dispatch(object) 

788 return func.__doc__ 

789 except TypeError: 

790 return f"Single Dispatch for {self.__name__}" 

791 

792 
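A small sketch of both registration paths (the byte counts are illustrative, not dask's real ``sizeof`` values): ``register`` binds immediately, while ``register_lazy`` defers until a dispatched type's top-level module name matches:

my_sizeof = Dispatch(name="my_sizeof")

@my_sizeof.register(int)
def _(x):
    return 28

@my_sizeof.register_lazy("decimal")
def _register_decimal():
    import decimal

    @my_sizeof.register(decimal.Decimal)
    def _(x):
        return 104

assert my_sizeof(1) == 28

import decimal
assert my_sizeof(decimal.Decimal("1.5")) == 104  # lazy hook fires on first dispatch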

793def ensure_not_exists(filename) -> None: 

794 """ 

795 Ensure that a file does not exist. 

796 """ 

797 try: 

798 os.unlink(filename) 

799 except OSError as e: 

800 if e.errno != ENOENT: 

801 raise 

802 

803 

804def _skip_doctest(line): 

805 # NumPy docstring contains cursor and comment only example 

806 stripped = line.strip() 

807 if stripped == ">>>" or stripped.startswith(">>> #"): 

808 return line 

809 elif ">>>" in stripped and "+SKIP" not in stripped: 

810 if "# doctest:" in line: 

811 return line + ", +SKIP" 

812 else: 

813 return line + " # doctest: +SKIP" 

814 else: 

815 return line 

816 

817 

818def skip_doctest(doc): 

819 if doc is None: 

820 return "" 

821 return "\n".join([_skip_doctest(line) for line in doc.split("\n")]) 

822 

823 

824def extra_titles(doc): 

825 lines = doc.split("\n") 

826 titles = { 

827 i: lines[i].strip() 

828 for i in range(len(lines) - 1) 

829 if lines[i + 1].strip() and all(c == "-" for c in lines[i + 1].strip()) 

830 } 

831 

832 seen = set() 

833 for i, title in sorted(titles.items()): 

834 if title in seen: 

835 new_title = "Extra " + title 

836 lines[i] = lines[i].replace(title, new_title) 

837 lines[i + 1] = lines[i + 1].replace("-" * len(title), "-" * len(new_title)) 

838 else: 

839 seen.add(title) 

840 

841 return "\n".join(lines) 

842 

843 

844def ignore_warning(doc, cls, name, extra="", skipblocks=0, inconsistencies=None): 

845 """Expand docstring by adding disclaimer and extra text""" 

846 import inspect 

847 

848 if inspect.isclass(cls): 

849 l1 = f"This docstring was copied from {cls.__module__}.{cls.__name__}.{name}.\n\n" 

850 else: 

851 l1 = f"This docstring was copied from {cls.__name__}.{name}.\n\n" 

852 l2 = "Some inconsistencies with the Dask version may exist." 

853 

854 i = doc.find("\n\n") 

855 if i != -1: 

856 # Insert our warning 

857 head = doc[: i + 2] 

858 tail = doc[i + 2 :] 

859 while skipblocks > 0: 

860 i = tail.find("\n\n") 

861 head = tail[: i + 2] 

862 tail = tail[i + 2 :] 

863 skipblocks -= 1 

864 # Indentation of next line 

865 indent = re.match(r"\s*", tail).group(0) 

866 # Insert the warning, indented, with a blank line before and after 

867 if extra: 

868 more = [indent, extra.rstrip("\n") + "\n\n"] 

869 else: 

870 more = [] 

871 if inconsistencies is not None: 

872 l3 = f"Known inconsistencies: \n {inconsistencies}" 

873 bits = [head, indent, l1, l2, "\n\n", l3, "\n\n"] + more + [tail] 

874 else: 

875 bits = [head, indent, l1, indent, l2, "\n\n"] + more + [tail] 

876 doc = "".join(bits) 

877 

878 return doc 

879 

880 

881def unsupported_arguments(doc, args): 

882 """Mark unsupported arguments with a disclaimer""" 

883 lines = doc.split("\n") 

884 for arg in args: 

885 subset = [ 

886 (i, line) 

887 for i, line in enumerate(lines) 

888 if re.match(r"^\s*" + arg + " ?:", line) 

889 ] 

890 if len(subset) == 1: 

891 [(i, line)] = subset 

892 lines[i] = line + " (Not supported in Dask)" 

893 return "\n".join(lines) 

894 

895 

896def _derived_from( 

897 cls, method, ua_args=None, extra="", skipblocks=0, inconsistencies=None 

898): 

899 """Helper function for derived_from to ease testing""" 

900 ua_args = ua_args or [] 

901 

902 # do not use wraps here, as it hides keyword arguments displayed 

903 # in the doc 

904 original_method = getattr(cls, method.__name__) 

905 

906 doc = getattr(original_method, "__doc__", None) 

907 

908 if isinstance(original_method, property): 

909 # some things like SeriesGroupBy.unique are generated. 

910 original_method = original_method.fget 

911 if not doc: 

912 doc = getattr(original_method, "__doc__", None) 

913 

914 if isinstance(original_method, functools.cached_property): 

915 original_method = original_method.func 

916 if not doc: 

917 doc = getattr(original_method, "__doc__", None) 

918 

919 if doc is None: 

920 doc = "" 

921 

922 # pandas DataFrame/Series sometimes override methods without setting __doc__ 

923 if not doc and cls.__name__ in {"DataFrame", "Series"}: 

924 for obj in cls.mro(): 

925 obj_method = getattr(obj, method.__name__, None) 

926 if obj_method is not None and obj_method.__doc__: 

927 doc = obj_method.__doc__ 

928 break 

929 

930 # Insert disclaimer that this is a copied docstring 

931 if doc: 

932 doc = ignore_warning( 

933 doc, 

934 cls, 

935 method.__name__, 

936 extra=extra, 

937 skipblocks=skipblocks, 

938 inconsistencies=inconsistencies, 

939 ) 

940 elif extra: 

941 doc += extra.rstrip("\n") + "\n\n" 

942 

943 # Mark unsupported arguments 

944 try: 

945 method_args = get_named_args(method) 

946 original_args = get_named_args(original_method) 

947 not_supported = [m for m in original_args if m not in method_args] 

948 except ValueError: 

949 not_supported = [] 

950 if len(ua_args) > 0: 

951 not_supported.extend(ua_args) 

952 if len(not_supported) > 0: 

953 doc = unsupported_arguments(doc, not_supported) 

954 

955 doc = skip_doctest(doc) 

956 doc = extra_titles(doc) 

957 

958 return doc 

959 

960 

961def derived_from( 

962 original_klass, version=None, ua_args=None, skipblocks=0, inconsistencies=None 

963): 

964 """Decorator to attach original class's docstring to the wrapped method. 

965 

966 The output structure will be: top line of docstring, disclaimer about this 

967 being auto-derived, any extra text associated with the method being patched, 

968 the body of the docstring and finally, the list of keywords that exist in 

969 the original method but not in the dask version. 

970 

971 Parameters 

972 ---------- 

973 original_klass: type 

974 Original class which the method is derived from 

975 version : str 

976 Original package version which supports the wrapped method 

977 ua_args : list 

978 List of keywords which Dask doesn't support. Keywords existing in 

979 original but not in Dask will automatically be added. 

980 skipblocks : int 

981 How many text blocks (paragraphs) to skip from the start of the 

982 docstring. Useful for cases where the target has extra front-matter. 

983 inconsistencies: list 

984 List of known inconsistencies with method whose docstrings are being 

985 copied. 

986 """ 

987 ua_args = ua_args or [] 

988 

989 def wrapper(method): 

990 try: 

991 extra = getattr(method, "__doc__", None) or "" 

992 method.__doc__ = _derived_from( 

993 original_klass, 

994 method, 

995 ua_args=ua_args, 

996 extra=extra, 

997 skipblocks=skipblocks, 

998 inconsistencies=inconsistencies, 

999 ) 

1000 return method 

1001 

1002 except AttributeError: 

1003 module_name = original_klass.__module__.split(".")[0] 

1004 

1005 @functools.wraps(method) 

1006 def wrapped(*args, **kwargs): 

1007 msg = f"Base package doesn't support '{method.__name__}'." 

1008 if version is not None: 

1009 msg2 = " Use {0} {1} or later to use this method." 

1010 msg += msg2.format(module_name, version) 

1011 raise NotImplementedError(msg) 

1012 

1013 return wrapped 

1014 

1015 return wrapper 

1016 

1017 
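A toy sketch of the docstring surgery (both classes are stand-ins, not real pandas/dask types): the wrapper gains the original docstring plus the disclaimer, and ``fill_value`` is flagged because the wrapping signature lacks it:

class Original:
    def head(self, n=5, fill_value=None):
        """Return the first n rows.

        Parameters
        ----------
        n : int
            Number of rows.
        fill_value : optional
            Padding value for short frames.
        """

class Wrapper:
    @derived_from(Original)
    def head(self, n=5):
        ...

assert "This docstring was copied from" in Wrapper.head.__doc__
assert "fill_value : optional (Not supported in Dask)" in Wrapper.head.__doc__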

1018def funcname(func) -> str: 

1019 """Get the name of a function.""" 

1020 # functools.partial 

1021 if isinstance(func, functools.partial): 

1022 return funcname(func.func) 

1023 # methodcaller 

1024 if isinstance(func, methodcaller): 

1025 return func.method[:50] 

1026 

1027 module_name = getattr(func, "__module__", None) or "" 

1028 type_name = getattr(type(func), "__name__", None) or "" 

1029 

1030 # toolz.curry 

1031 if "toolz" in module_name and "curry" == type_name: 

1032 return func.func_name[:50] 

1033 # multipledispatch objects 

1034 if "multipledispatch" in module_name and "Dispatcher" == type_name: 

1035 return func.name[:50] 

1036 # numpy.vectorize objects 

1037 if "numpy" in module_name and "vectorize" == type_name: 

1038 return ("vectorize_" + funcname(func.pyfunc))[:50] 

1039 

1040 # All other callables 

1041 try: 

1042 name = func.__name__ 

1043 if name == "<lambda>": 

1044 return "lambda" 

1045 return name[:50] 

1046 except AttributeError: 

1047 return str(func)[:50] 

1048 

1049 

1050def typename(typ: Any, short: bool = False) -> str: 

1051 """ 

1052 Return the name of a type 

1053 

1054 Examples 

1055 -------- 

1056 >>> typename(int) 

1057 'int' 

1058 

1059 >>> from dask.core import literal 

1060 >>> typename(literal) 

1061 'dask.core.literal' 

1062 >>> typename(literal, short=True) 

1063 'dask.literal' 

1064 """ 

1065 if not isinstance(typ, type): 

1066 return typename(type(typ)) 

1067 try: 

1068 if not typ.__module__ or typ.__module__ == "builtins": 

1069 return typ.__name__ 

1070 else: 

1071 if short: 

1072 module, *_ = typ.__module__.split(".") 

1073 else: 

1074 module = typ.__module__ 

1075 return module + "." + typ.__name__ 

1076 except AttributeError: 

1077 return str(typ) 

1078 

1079 

1080def ensure_bytes(s) -> bytes: 

1081 """Attempt to turn `s` into bytes. 

1082 

1083 Parameters 

1084 ---------- 

1085 s : Any 

1086 The object to be converted. The following are handled correctly: 

1087 * str 

1088 * bytes 

1089 * objects implementing the buffer protocol (memoryview, ndarray, etc.) 

1090 

1091 Returns 

1092 ------- 

1093 b : bytes 

1094 

1095 Raises 

1096 ------ 

1097 TypeError 

1098 When `s` cannot be converted 

1099 

1100 Examples 

1101 -------- 

1102 >>> ensure_bytes('123') 

1103 b'123' 

1104 >>> ensure_bytes(b'123') 

1105 b'123' 

1106 >>> ensure_bytes(bytearray(b'123')) 

1107 b'123' 

1108 """ 

1109 if isinstance(s, bytes): 

1110 return s 

1111 elif hasattr(s, "encode"): 

1112 return s.encode() 

1113 else: 

1114 try: 

1115 return bytes(s) 

1116 except Exception as e: 

1117 raise TypeError( 

1118 f"Object {s} is neither a bytes object nor can be encoded to bytes" 

1119 ) from e 

1120 

1121 

1122def ensure_unicode(s) -> str: 

1123 """Turn string or bytes to string 

1124 

1125 >>> ensure_unicode('123') 

1126 '123' 

1127 >>> ensure_unicode(b'123') 

1128 '123' 

1129 """ 

1130 if isinstance(s, str): 

1131 return s 

1132 elif hasattr(s, "decode"): 

1133 return s.decode() 

1134 else: 

1135 try: 

1136 return codecs.decode(s) 

1137 except Exception as e: 

1138 raise TypeError( 

1139 f"Object {s} is neither a str object nor can be decoded to str" 

1140 ) from e 

1141 

1142 

1143def digit(n, k, base): 

1144 """ 

1145 

1146 >>> digit(1234, 0, 10) 

1147 4 

1148 >>> digit(1234, 1, 10) 

1149 3 

1150 >>> digit(1234, 2, 10) 

1151 2 

1152 >>> digit(1234, 3, 10) 

1153 1 

1154 """ 

1155 return n // base**k % base 

1156 

1157 

1158def insert(tup, loc, val): 

1159 """ 

1160 

1161 >>> insert(('a', 'b', 'c'), 0, 'x') 

1162 ('x', 'b', 'c') 

1163 """ 

1164 L = list(tup) 

1165 L[loc] = val 

1166 return tuple(L) 

1167 

1168 

1169def memory_repr(num): 

1170 for x in ["bytes", "KB", "MB", "GB", "TB"]: 

1171 if num < 1024.0: 

1172 return f"{num:3.1f} {x}" 

1173 num /= 1024.0 

1174 

1175 

1176def asciitable(columns, rows): 

1177 """Formats an ascii table for given columns and rows. 

1178 

1179 Parameters 

1180 ---------- 

1181 columns : list 

1182 The column names 

1183 rows : list of tuples 

1184 The rows in the table. Each tuple must be the same length as 

1185 ``columns``. 

1186 """ 

1187 rows = [tuple(str(i) for i in r) for r in rows] 

1188 columns = tuple(str(i) for i in columns) 

1189 widths = tuple(max(*map(len, x), len(c)) for x, c in zip(zip(*rows), columns)) 

1190 row_template = ("|" + (" %%-%ds |" * len(columns))) % widths 

1191 header = row_template % tuple(columns) 

1192 bar = "+{}+".format("+".join("-" * (w + 2) for w in widths)) 

1193 data = "\n".join(row_template % r for r in rows) 

1194 return "\n".join([bar, header, bar, data, bar]) 

1195 

1196 

1197def put_lines(buf, lines): 

1198 if any(not isinstance(x, str) for x in lines): 

1199 lines = [str(x) for x in lines] 

1200 buf.write("\n".join(lines)) 

1201 

1202 

1203_method_cache: dict[str, methodcaller] = {} 

1204 

1205 

1206class methodcaller: 

1207 """ 

1208 Return a callable object that calls the given method on its operand. 

1209 

1210 Unlike the builtin `operator.methodcaller`, instances of this class are 

1211 cached and arguments are passed at call time instead of build time. 

1212 """ 

1213 

1214 __slots__ = ("method",) 

1215 method: str 

1216 

1217 @property 

1218 def func(self) -> str: 

1219 # For `funcname` to work 

1220 return self.method 

1221 

1222 def __new__(cls, method: str): 

1223 try: 

1224 return _method_cache[method] 

1225 except KeyError: 

1226 self = object.__new__(cls) 

1227 self.method = method 

1228 _method_cache[method] = self 

1229 return self 

1230 

1231 def __call__(self, __obj, *args, **kwargs): 

1232 return getattr(__obj, self.method)(*args, **kwargs) 

1233 

1234 def __reduce__(self): 

1235 return (methodcaller, (self.method,)) 

1236 

1237 def __str__(self): 

1238 return f"<{self.__class__.__name__}: {self.method}>" 

1239 

1240 __repr__ = __str__ 

1241 

1242 

1243class itemgetter: 

1244 """Variant of operator.itemgetter that supports equality tests""" 

1245 

1246 __slots__ = ("index",) 

1247 

1248 def __init__(self, index): 

1249 self.index = index 

1250 

1251 def __call__(self, x): 

1252 return x[self.index] 

1253 

1254 def __reduce__(self): 

1255 return (itemgetter, (self.index,)) 

1256 

1257 def __eq__(self, other): 

1258 return type(self) is type(other) and self.index == other.index 

1259 

1260 

1261class MethodCache: 

1262 """Attribute access on this object returns a methodcaller for that 

1263 attribute. 

1264 

1265 Examples 

1266 -------- 

1267 >>> a = [1, 3, 3] 

1268 >>> M.count(a, 3) == a.count(3) 

1269 True 

1270 """ 

1271 

1272 def __getattr__(self, item): 

1273 return methodcaller(item) 

1274 

1275 def __dir__(self): 

1276 return list(_method_cache) 

1277 

1278 

1279M = MethodCache() 

1280 

1281 

1282class SerializableLock: 

1283 """A Serializable per-process Lock 

1284 

1285 This wraps a normal ``threading.Lock`` object and satisfies the same 

1286 interface. However, this lock can also be serialized and sent to different 

1287 processes. It will not block concurrent operations between processes (for 

1288 this you should look at ``multiprocessing.Lock`` or ``locket.lock_file``) 

1289 but will consistently deserialize into the same lock. 

1290 

1291 So if we make a lock in one process:: 

1292 

1293 lock = SerializableLock() 

1294 

1295 And then send it over to another process multiple times:: 

1296 

1297 bytes = pickle.dumps(lock) 

1298 a = pickle.loads(bytes) 

1299 b = pickle.loads(bytes) 

1300 

1301 Then the deserialized objects will operate as though they were the same 

1302 lock, and collide as appropriate. 

1303 

1304 This is useful for consistently protecting resources on a per-process 

1305 level. 

1306 

1307 The creation of locks is itself not threadsafe. 

1308 """ 

1309 

1310 _locks: ClassVar[WeakValueDictionary[Hashable, Lock]] = WeakValueDictionary() 

1311 token: Hashable 

1312 lock: Lock 

1313 

1314 def __init__(self, token: Hashable | None = None): 

1315 self.token = token or str(uuid.uuid4()) 

1316 if self.token in SerializableLock._locks: 

1317 self.lock = SerializableLock._locks[self.token] 

1318 else: 

1319 self.lock = Lock() 

1320 SerializableLock._locks[self.token] = self.lock 

1321 

1322 def acquire(self, *args, **kwargs): 

1323 return self.lock.acquire(*args, **kwargs) 

1324 

1325 def release(self, *args, **kwargs): 

1326 return self.lock.release(*args, **kwargs) 

1327 

1328 def __enter__(self): 

1329 self.lock.__enter__() 

1330 

1331 def __exit__(self, *args): 

1332 self.lock.__exit__(*args) 

1333 

1334 def locked(self): 

1335 return self.lock.locked() 

1336 

1337 def __getstate__(self): 

1338 return self.token 

1339 

1340 def __setstate__(self, token): 

1341 self.__init__(token) 

1342 

1343 def __str__(self): 

1344 return f"<{self.__class__.__name__}: {self.token}>" 

1345 

1346 __repr__ = __str__ 

1347 

1348 
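A quick sketch of the round-trip behavior described in the docstring above: within one process, unpickled copies share the underlying ``threading.Lock`` via the token:

import pickle

lock = SerializableLock()
a = pickle.loads(pickle.dumps(lock))
b = pickle.loads(pickle.dumps(lock))
assert a.lock is b.lock is lock.lock  # same per-process lock
with a:
    assert b.locked()                 # copies collide, as intended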

1349def get_scheduler_lock(collection=None, scheduler=None): 

1350 """Get an instance of the appropriate lock for a certain situation based on 

1351 scheduler used.""" 

1352 from dask import multiprocessing 

1353 from dask.base import get_scheduler 

1354 

1355 actual_get = get_scheduler(collections=[collection], scheduler=scheduler) 

1356 

1357 if actual_get == multiprocessing.get: 

1358 return multiprocessing.get_context().Manager().Lock() 

1359 else: 

1360 # if this is a distributed client, we need to lock on 

1361 # the level between processes, SerializableLock won't work 

1362 try: 

1363 import distributed.lock 

1364 from distributed.worker import get_client 

1365 

1366 client = get_client() 

1367 except (ImportError, ValueError): 

1368 pass 

1369 else: 

1370 if actual_get == client.get: 

1371 return distributed.lock.Lock() 

1372 

1373 return SerializableLock() 

1374 

1375 

1376def ensure_dict(d: Mapping[K, V], *, copy: bool = False) -> dict[K, V]: 

1377 """Convert a generic Mapping into a dict. 

1378 Optimize use case of :class:`~dask.highlevelgraph.HighLevelGraph`. 

1379 

1380 Parameters 

1381 ---------- 

1382 d : Mapping 

1383 copy : bool 

1384 If True, guarantee that the return value is always a shallow copy of d; 

1385 otherwise it may be the input itself. 

1386 """ 

1387 if type(d) is dict: 

1388 return d.copy() if copy else d 

1389 try: 

1390 layers = d.layers # type: ignore 

1391 except AttributeError: 

1392 return dict(d) 

1393 

1394 result = {} 

1395 for layer in toolz.unique(layers.values(), key=id): 

1396 result.update(layer) 

1397 return result 

1398 

1399 
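A sketch of the ``.layers`` fast path with a hypothetical stand-in (the real consumer is ``dask.highlevelgraph.HighLevelGraph``): layers that are the same object are merged only once, thanks to ``toolz.unique(..., key=id)``:

class FakeGraph:
    def __init__(self, layers):
        self.layers = layers

shared = {"x": 1}
g = FakeGraph({"a": shared, "b": shared, "c": {"y": 2}})
assert ensure_dict(g) == {"x": 1, "y": 2}  # 'shared' merged once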

1400def ensure_set(s: Set[T], *, copy: bool = False) -> set[T]: 

1401 """Convert a generic Set into a set. 

1402 

1403 Parameters 

1404 ---------- 

1405 s : Set 

1406 copy : bool 

1407 If True, guarantee that the return value is always a shallow copy of s; 

1408 otherwise it may be the input itself. 

1409 """ 

1410 if type(s) is set: 

1411 return s.copy() if copy else s 

1412 return set(s) 

1413 

1414 

1415class OperatorMethodMixin: 

1416 """A mixin for dynamically implementing operators""" 

1417 

1418 __slots__ = () 

1419 

1420 @classmethod 

1421 def _bind_operator(cls, op): 

1422 """bind operator to this class""" 

1423 name = op.__name__ 

1424 

1425 if name.endswith("_"): 

1426 # for and_ and or_ 

1427 name = name[:-1] 

1428 elif name == "inv": 

1429 name = "invert" 

1430 

1431 meth = f"__{name}__" 

1432 

1433 if name in ("abs", "invert", "neg", "pos"): 

1434 setattr(cls, meth, cls._get_unary_operator(op)) 

1435 else: 

1436 setattr(cls, meth, cls._get_binary_operator(op)) 

1437 

1438 if name in ("eq", "gt", "ge", "lt", "le", "ne", "getitem"): 

1439 return 

1440 

1441 rmeth = f"__r{name}__" 

1442 setattr(cls, rmeth, cls._get_binary_operator(op, inv=True)) 

1443 

1444 @classmethod 

1445 def _get_unary_operator(cls, op): 

1446 """Must return a method used by unary operator""" 

1447 raise NotImplementedError 

1448 

1449 @classmethod 

1450 def _get_binary_operator(cls, op, inv=False): 

1451 """Must return a method used by binary operator""" 

1452 raise NotImplementedError 

1453 

1454 

1455def partial_by_order(*args, **kwargs): 

1456 """ 

1457 

1458 >>> from operator import add 

1459 >>> partial_by_order(5, function=add, other=[(1, 10)]) 

1460 15 

1461 """ 

1462 function = kwargs.pop("function") 

1463 other = kwargs.pop("other") 

1464 args2 = list(args) 

1465 for i, arg in other: 

1466 args2.insert(i, arg) 

1467 return function(*args2, **kwargs) 

1468 

1469 

1470def is_arraylike(x) -> bool: 

1471 """Is this object a numpy array or something similar? 

1472 

1473 This function tests specifically for an object that already has 

1474 array attributes (e.g. np.ndarray, dask.array.Array, cupy.ndarray, 

1475 sparse.COO), **NOT** for something that can be coerced into an 

1476 array object (e.g. Python lists and tuples). It is meant for dask 

1477 developers and developers of downstream libraries. 

1478 

1479 Note that this function does not correspond with NumPy's 

1480 definition of array_like, which includes any object that can be 

1481 coerced into an array (see definition in the NumPy glossary): 

1482 https://numpy.org/doc/stable/glossary.html 

1483 

1484 Examples 

1485 -------- 

1486 >>> import numpy as np 

1487 >>> is_arraylike(np.ones(5)) 

1488 True 

1489 >>> is_arraylike(np.ones(())) 

1490 True 

1491 >>> is_arraylike(5) 

1492 False 

1493 >>> is_arraylike('cat') 

1494 False 

1495 """ 

1496 from dask.base import is_dask_collection 

1497 

1498 is_duck_array = hasattr(x, "__array_function__") or hasattr(x, "__array_ufunc__") 

1499 

1500 return bool( 

1501 hasattr(x, "shape") 

1502 and isinstance(x.shape, tuple) 

1503 and hasattr(x, "dtype") 

1504 and not any(is_dask_collection(n) for n in x.shape) 

1505 # We special case scipy.sparse and cupyx.scipy.sparse arrays as having partial 

1506 # support for them is useful in scenarios where we mostly call `map_partitions` 

1507 # or `map_blocks` with scikit-learn functions on dask arrays and dask dataframes. 

1508 # https://github.com/dask/dask/pull/3738 

1509 and (is_duck_array or "scipy.sparse" in typename(type(x))) 

1510 ) 

1511 

1512 

1513def is_dataframe_like(df) -> bool: 

1514 """Looks like a Pandas DataFrame""" 

1515 if (df.__class__.__module__, df.__class__.__name__) == ( 

1516 "pandas.core.frame", 

1517 "DataFrame", 

1518 ): 

1519 # fast path for the most likely input 

1520 return True 

1521 typ = df.__class__ 

1522 return ( 

1523 all(hasattr(typ, name) for name in ("groupby", "head", "merge", "mean")) 

1524 and all(hasattr(df, name) for name in ("dtypes", "columns")) 

1525 and not any(hasattr(typ, name) for name in ("name", "dtype")) 

1526 ) 

1527 

1528 

1529def is_series_like(s) -> bool: 

1530 """Looks like a Pandas Series""" 

1531 typ = s.__class__ 

1532 return ( 

1533 all(hasattr(typ, name) for name in ("groupby", "head", "mean")) 

1534 and all(hasattr(s, name) for name in ("dtype", "name")) 

1535 and "index" not in typ.__name__.lower() 

1536 ) 

1537 

1538 

1539def is_index_like(s) -> bool: 

1540 """Looks like a Pandas Index""" 

1541 typ = s.__class__ 

1542 return ( 

1543 all(hasattr(s, name) for name in ("name", "dtype")) 

1544 and "index" in typ.__name__.lower() 

1545 ) 

1546 

1547 

1548def is_cupy_type(x) -> bool: 

1549 # TODO: avoid explicit reference to CuPy 

1550 return "cupy" in str(type(x)) 

1551 

1552 

1553def natural_sort_key(s: str) -> list[str | int]: 

1554 """ 

1555 Sorting `key` function for performing a natural sort on a collection of 

1556 strings 

1557 

1558 See https://en.wikipedia.org/wiki/Natural_sort_order 

1559 

1560 Parameters 

1561 ---------- 

1562 s : str 

1563 A string that is an element of the collection being sorted 

1564 

1565 Returns 

1566 ------- 

1567 list[str | int] 

1568 List of the parts of the input string where each part is either a 

1569 string or an integer 

1570 

1571 Examples 

1572 -------- 

1573 >>> a = ['f0', 'f1', 'f2', 'f8', 'f9', 'f10', 'f11', 'f19', 'f20', 'f21'] 

1574 >>> sorted(a) 

1575 ['f0', 'f1', 'f10', 'f11', 'f19', 'f2', 'f20', 'f21', 'f8', 'f9'] 

1576 >>> sorted(a, key=natural_sort_key) 

1577 ['f0', 'f1', 'f2', 'f8', 'f9', 'f10', 'f11', 'f19', 'f20', 'f21'] 

1578 """ 

1579 return [int(part) if part.isdigit() else part for part in re.split(r"(\d+)", s)] 

1580 

1581 

1582def parse_bytes(s: float | str) -> int: 

1583 """Parse byte string to numbers 

1584 

1585 >>> from dask.utils import parse_bytes 

1586 >>> parse_bytes('100') 

1587 100 

1588 >>> parse_bytes('100 MB') 

1589 100000000 

1590 >>> parse_bytes('100M') 

1591 100000000 

1592 >>> parse_bytes('5kB') 

1593 5000 

1594 >>> parse_bytes('5.4 kB') 

1595 5400 

1596 >>> parse_bytes('1kiB') 

1597 1024 

1598 >>> parse_bytes('1e6') 

1599 1000000 

1600 >>> parse_bytes('1e6 kB') 

1601 1000000000 

1602 >>> parse_bytes('MB') 

1603 1000000 

1604 >>> parse_bytes(123) 

1605 123 

1606 >>> parse_bytes('5 foos') 

1607 Traceback (most recent call last): 

1608 ... 

1609 ValueError: Could not interpret 'foos' as a byte unit 

1610 """ 

1611 if isinstance(s, (int, float)): 

1612 return int(s) 

1613 s = s.replace(" ", "") 

1614 if not any(char.isdigit() for char in s): 

1615 s = "1" + s 

1616 

1617 for i in range(len(s) - 1, -1, -1): 

1618 if not s[i].isalpha(): 

1619 break 

1620 index = i + 1 

1621 

1622 prefix = s[:index] 

1623 suffix = s[index:] 

1624 

1625 try: 

1626 n = float(prefix) 

1627 except ValueError as e: 

1628 raise ValueError(f"Could not interpret '{prefix}' as a number") from e 

1629 

1630 try: 

1631 multiplier = byte_sizes[suffix.lower()] 

1632 except KeyError as e: 

1633 raise ValueError(f"Could not interpret '{suffix}' as a byte unit") from e 

1634 

1635 result = n * multiplier 

1636 return int(result) 

1637 

1638 

1639byte_sizes = { 

1640 "kB": 10**3, 

1641 "MB": 10**6, 

1642 "GB": 10**9, 

1643 "TB": 10**12, 

1644 "PB": 10**15, 

1645 "KiB": 2**10, 

1646 "MiB": 2**20, 

1647 "GiB": 2**30, 

1648 "TiB": 2**40, 

1649 "PiB": 2**50, 

1650 "B": 1, 

1651 "": 1, 

1652} 

1653byte_sizes = {k.lower(): v for k, v in byte_sizes.items()} 

1654byte_sizes.update({k[0]: v for k, v in byte_sizes.items() if k and "i" not in k}) 

1655byte_sizes.update({k[:-1]: v for k, v in byte_sizes.items() if k and "i" in k}) 

1656 
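The three statements above build aliases in stages: lowercase every key, add single-letter decimal aliases ("k" for "kb"), then "i"-suffixed binary aliases ("ki" for "kib"). A few spot checks of the resulting table:

assert byte_sizes["kb"] == byte_sizes["k"] == 10**3
assert byte_sizes["kib"] == byte_sizes["ki"] == 2**10
assert parse_bytes("1k") == 1000
assert parse_bytes("1ki") == 1024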

1657 

1658def format_time(n: float) -> str: 

1659 """format integers as time 

1660 

1661 >>> from dask.utils import format_time 

1662 >>> format_time(1) 

1663 '1.00 s' 

1664 >>> format_time(0.001234) 

1665 '1.23 ms' 

1666 >>> format_time(0.00012345) 

1667 '123.45 us' 

1668 >>> format_time(123.456) 

1669 '123.46 s' 

1670 >>> format_time(1234.567) 

1671 '20m 34s' 

1672 >>> format_time(12345.67) 

1673 '3hr 25m' 

1674 >>> format_time(123456.78) 

1675 '34hr 17m' 

1676 >>> format_time(1234567.89) 

1677 '14d 6hr' 

1678 """ 

1679 if n > 24 * 60 * 60 * 2: 

1680 d = int(n / 3600 / 24) 

1681 h = int((n - d * 3600 * 24) / 3600) 

1682 return f"{d}d {h}hr" 

1683 if n > 60 * 60 * 2: 

1684 h = int(n / 3600) 

1685 m = int((n - h * 3600) / 60) 

1686 return f"{h}hr {m}m" 

1687 if n > 60 * 10: 

1688 m = int(n / 60) 

1689 s = int(n - m * 60) 

1690 return f"{m}m {s}s" 

1691 if n >= 1: 

1692 return f"{n:.2f} s" 

1693 if n >= 1e-3: 

1694 return "%.2f ms" % (n * 1e3) 

1695 return "%.2f us" % (n * 1e6) 

1696 

1697 

1698def format_time_ago(n: datetime) -> str: 

1699 """Calculate a '3 hours ago' type string from a Python datetime. 

1700 

1701 Examples 

1702 -------- 

1703 >>> from datetime import datetime, timedelta 

1704 

1705 >>> now = datetime.now() 

1706 >>> format_time_ago(now) 

1707 'Just now' 

1708 

1709 >>> past = datetime.now() - timedelta(minutes=1) 

1710 >>> format_time_ago(past) 

1711 '1 minute ago' 

1712 

1713 >>> past = datetime.now() - timedelta(minutes=2) 

1714 >>> format_time_ago(past) 

1715 '2 minutes ago' 

1716 

1717 >>> past = datetime.now() - timedelta(hours=1) 

1718 >>> format_time_ago(past) 

1719 '1 hour ago' 

1720 

1721 >>> past = datetime.now() - timedelta(hours=6) 

1722 >>> format_time_ago(past) 

1723 '6 hours ago' 

1724 

1725 >>> past = datetime.now() - timedelta(days=1) 

1726 >>> format_time_ago(past) 

1727 '1 day ago' 

1728 

1729 >>> past = datetime.now() - timedelta(days=5) 

1730 >>> format_time_ago(past) 

1731 '5 days ago' 

1732 

1733 >>> past = datetime.now() - timedelta(days=8) 

1734 >>> format_time_ago(past) 

1735 '1 week ago' 

1736 

1737 >>> past = datetime.now() - timedelta(days=16) 

1738 >>> format_time_ago(past) 

1739 '2 weeks ago' 

1740 

1741 >>> past = datetime.now() - timedelta(days=190) 

1742 >>> format_time_ago(past) 

1743 '6 months ago' 

1744 

1745 >>> past = datetime.now() - timedelta(days=800) 

1746 >>> format_time_ago(past) 

1747 '2 years ago' 

1748 

1749 """ 

1750 units = { 

1751 "years": lambda diff: diff.days / 365, 

1752 "months": lambda diff: diff.days / 30.436875, # Average days per month 

1753 "weeks": lambda diff: diff.days / 7, 

1754 "days": lambda diff: diff.days, 

1755 "hours": lambda diff: diff.seconds / 3600, 

1756 "minutes": lambda diff: diff.seconds % 3600 / 60, 

1757 } 

1758 diff = datetime.now() - n 

1759 for unit, func in units.items(): 

1760 dur = int(func(diff)) 

1761 if dur > 0: 

1762 if dur == 1: # De-pluralize 

1763 unit = unit[:-1] 

1764 return f"{dur} {unit} ago" 

1765 return "Just now" 

1766 

1767 

1768def format_bytes(n: int) -> str: 

1769 """Format bytes as text 

1770 

1771 >>> from dask.utils import format_bytes 

1772 >>> format_bytes(1) 

1773 '1 B' 

1774 >>> format_bytes(1234) 

1775 '1.21 kiB' 

1776 >>> format_bytes(12345678) 

1777 '11.77 MiB' 

1778 >>> format_bytes(1234567890) 

1779 '1.15 GiB' 

1780 >>> format_bytes(1234567890000) 

1781 '1.12 TiB' 

1782 >>> format_bytes(1234567890000000) 

1783 '1.10 PiB' 

1784 

1785 For all values < 2**60, the output is always <= 10 characters. 

1786 """ 

1787 for prefix, k in ( 

1788 ("Pi", 2**50), 

1789 ("Ti", 2**40), 

1790 ("Gi", 2**30), 

1791 ("Mi", 2**20), 

1792 ("ki", 2**10), 

1793 ): 

1794 if n >= k * 0.9: 

1795 return f"{n / k:.2f} {prefix}B" 

1796 return f"{n} B" 

1797 

1798 

1799timedelta_sizes = { 

1800 "s": 1, 

1801 "ms": 1e-3, 

1802 "us": 1e-6, 

1803 "ns": 1e-9, 

1804 "m": 60, 

1805 "h": 3600, 

1806 "d": 3600 * 24, 

1807 "w": 7 * 3600 * 24, 

1808} 

1809 

1810tds2 = { 

1811 "second": 1, 

1812 "minute": 60, 

1813 "hour": 60 * 60, 

1814 "day": 60 * 60 * 24, 

1815 "week": 7 * 60 * 60 * 24, 

1816 "millisecond": 1e-3, 

1817 "microsecond": 1e-6, 

1818 "nanosecond": 1e-9, 

1819} 

1820tds2.update({k + "s": v for k, v in tds2.items()}) 

1821timedelta_sizes.update(tds2) 

1822timedelta_sizes.update({k.upper(): v for k, v in timedelta_sizes.items()}) 

1823 

1824 

1825@overload 

1826def parse_timedelta(s: None, default: str | Literal[False] = "seconds") -> None: ... 

1827 

1828 

1829@overload 

1830def parse_timedelta( 

1831 s: str | float | timedelta, default: str | Literal[False] = "seconds" 

1832) -> float: ... 

1833 

1834 

1835def parse_timedelta(s, default="seconds"): 

1836 """Parse timedelta string to number of seconds 

1837 

1838 Parameters 

1839 ---------- 

1840 s : str, float, timedelta, or None 

1841 default: str or False, optional 

1842 Unit of measure if s does not specify one. Defaults to seconds. 

1843 Set to False to require s to explicitly specify its own unit. 

1844 

1845 Examples 

1846 -------- 

1847 >>> from datetime import timedelta 

1848 >>> from dask.utils import parse_timedelta 

1849 >>> parse_timedelta('3s') 

1850 3 

1851 >>> parse_timedelta('3.5 seconds') 

1852 3.5 

1853 >>> parse_timedelta('300ms') 

1854 0.3 

1855 >>> parse_timedelta(timedelta(seconds=3)) # also supports timedeltas 

1856 3 

1857 """ 

1858 if s is None: 

1859 return None 

1860 if isinstance(s, timedelta): 

1861 s = s.total_seconds() 

1862 return int(s) if int(s) == s else s 

1863 if isinstance(s, Number): 

1864 s = str(s) 

1865 s = s.replace(" ", "") 

1866 if not s[0].isdigit(): 

1867 s = "1" + s 

1868 

1869 for i in range(len(s) - 1, -1, -1): 

1870 if not s[i].isalpha(): 

1871 break 

1872 index = i + 1 

1873 

1874 prefix = s[:index] 

1875 suffix = s[index:] or default 

1876 if suffix is False: 

1877 raise ValueError(f"Missing time unit: {s}") 

1878 if not isinstance(suffix, str): 

1879 raise TypeError(f"default must be str or False, got {default!r}") 

1880 

1881 n = float(prefix) 

1882 

1883 try: 

1884 multiplier = timedelta_sizes[suffix.lower()] 

1885 except KeyError: 

1886 valid_units = ", ".join(timedelta_sizes.keys()) 

1887 raise KeyError( 

1888 f"Invalid time unit: {suffix}. Valid units are: {valid_units}" 

1889 ) from None 

1890 

1891 result = n * multiplier 

1892 if int(result) == result: 

1893 result = int(result) 

1894 return result 

1895 

1896 

1897def has_keyword(func, keyword): 

1898 try: 

1899 return keyword in inspect.signature(func).parameters 

1900 except Exception: 

1901 return False 

1902 

1903 

1904def ndimlist(seq): 

1905 if not isinstance(seq, (list, tuple)): 

1906 return 0 

1907 elif not seq: 

1908 return 1 

1909 else: 

1910 return 1 + ndimlist(seq[0]) 

1911 

1912 

1913def iter_chunks(sizes, max_size): 

1914 """Split sizes into chunks of total max_size each 

1915 

1916 Parameters 

1917 ---------- 

1918 sizes : iterable of numbers 

1919 The sizes to be chunked 

1920 max_size : number 

1921 Maximum total size per chunk. 

1922 It must be greater than or equal to each size in sizes 

1923 """ 

1924 chunk, chunk_sum = [], 0 

1925 iter_sizes = iter(sizes) 

1926 size = next(iter_sizes, None) 

1927 while size is not None: 

1928 assert size <= max_size 

1929 if chunk_sum + size <= max_size: 

1930 chunk.append(size) 

1931 chunk_sum += size 

1932 size = next(iter_sizes, None) 

1933 else: 

1934 assert chunk 

1935 yield chunk 

1936 chunk, chunk_sum = [], 0 

1937 if chunk: 

1938 yield chunk 

1939 

1940 
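A small sketch of the greedy packing: sizes are consumed left to right, a chunk is emitted as soon as the next size would overflow ``max_size``, and order is preserved:

assert list(iter_chunks([1, 2, 3, 4, 5], max_size=6)) == [[1, 2, 3], [4], [5]]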

1941hex_pattern = re.compile("[a-f]+") 

1942 

1943 

1944@functools.lru_cache(100000) 

1945def key_split(s): 

1946 """ 

1947 >>> key_split('x') 

1948 'x' 

1949 >>> key_split('x-1') 

1950 'x' 

1951 >>> key_split('x-1-2-3') 

1952 'x' 

1953 >>> key_split(('x-2', 1)) 

1954 'x' 

1955 >>> key_split("('x-2', 1)") 

1956 'x' 

1957 >>> key_split("('x', 1)") 

1958 'x' 

1959 >>> key_split('hello-world-1') 

1960 'hello-world' 

1961 >>> key_split(b'hello-world-1') 

1962 'hello-world' 

1963 >>> key_split('ae05086432ca935f6eba409a8ecd4896') 

1964 'data' 

1965 >>> key_split('<module.submodule.myclass object at 0xdaf372') 

1966 'myclass' 

1967 >>> key_split(None) 

1968 'Other' 

1969 >>> key_split('x-abcdefab') # ignores hex 

1970 'x' 

1971 >>> key_split('_(x)') # strips unpleasant characters 

1972 'x' 

1973 """ 

1974 # If we convert the key, recurse to utilize LRU cache better 

1975 if type(s) is bytes: 

1976 return key_split(s.decode()) 

1977 if type(s) is tuple: 

1978 return key_split(s[0]) 

1979 try: 

1980 words = s.split("-") 

1981 if not words[0][0].isalpha(): 

1982 result = words[0].split(",")[0].strip("_'()\"") 

1983 else: 

1984 result = words[0] 

1985 for word in words[1:]: 

1986 if word.isalpha() and not ( 

1987 len(word) == 8 and hex_pattern.match(word) is not None 

1988 ): 

1989 result += "-" + word 

1990 else: 

1991 break 

1992 if len(result) == 32 and re.match(r"[a-f0-9]{32}", result): 

1993 return "data" 

1994 else: 

1995 if result[0] == "<": 

1996 result = result.strip("<>").split()[0].split(".")[-1] 

1997 return sys.intern(result) 

1998 except Exception: 

1999 return "Other" 

2000 

2001 

2002def stringify(obj, exclusive: Iterable | None = None): 

2003 """Convert an object to a string 

2004 

2005 If ``exclusive`` is specified, search through `obj` and convert 

2006 values that are in ``exclusive``. 

2007 

2008 Note that when searching through dictionaries, only values are 

2009 converted, not the keys. 

2010 

2011 Parameters 

2012 ---------- 

2013 obj : Any 

2014 Object (or values within) to convert to string 

2015 exclusive: Iterable, optional 

2016 Set of values to search for when converting values to strings 

2017 

2018 Returns 

2019 ------- 

2020 result : type(obj) 

2021 Stringified copy of ``obj`` or ``obj`` itself if it is already a 

2022 string or bytes. 

2023 

2024 Examples 

2025 -------- 

2026 >>> stringify(b'x') 

2027 b'x' 

2028 >>> stringify('x') 

2029 'x' 

2030 >>> stringify({('a',0):('a',0), ('a',1): ('a',1)}) 

2031 "{('a', 0): ('a', 0), ('a', 1): ('a', 1)}" 

2032 >>> stringify({('a',0):('a',0), ('a',1): ('a',1)}, exclusive={('a',0)}) 

2033 {('a', 0): "('a', 0)", ('a', 1): ('a', 1)} 

2034 """ 

2035 

2036 typ = type(obj) 

2037 if typ is str or typ is bytes: 

2038 return obj 

2039 elif exclusive is None: 

2040 return str(obj) 

2041 

2042 if typ is list: 

2043 return [stringify(v, exclusive) for v in obj] 

2044 if typ is dict: 

2045 return {k: stringify(v, exclusive) for k, v in obj.items()} 

2046 try: 

2047 if obj in exclusive: 

2048 return stringify(obj) 

2049 except TypeError: # `obj` not hashable 

2050 pass 

2051 if typ is tuple: # If the tuple itself isn't a key, check its elements 

2052 return tuple(stringify(v, exclusive) for v in obj) 

2053 return obj 

2054 

2055 

2056class cached_property(functools.cached_property): 

2057 """Read only version of functools.cached_property.""" 

2058 

2059 def __set__(self, instance, val): 

2060 """Raise an error when attempting to set a cached property.""" 

2061 raise AttributeError("Can't set attribute") 

2062 

2063 
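# Editor's sketch (not part of the module): unlike functools.cached_property,
# this read-only variant rejects assignment to the attribute.
#
#     >>> class C:
#     ...     @cached_property
#     ...     def x(self):
#     ...         return 1
#     >>> c = C()
#     >>> c.x
#     1
#     >>> c.x = 2
#     Traceback (most recent call last):
#         ...
#     AttributeError: Can't set attribute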

2064class _HashIdWrapper: 

2065 """Hash and compare a wrapped object by identity instead of value""" 

2066 

2067 def __init__(self, wrapped): 

2068 self.wrapped = wrapped 

2069 

2070 def __eq__(self, other): 

2071 if not isinstance(other, _HashIdWrapper): 

2072 return NotImplemented 

2073 return self.wrapped is other.wrapped 

2074 

2075 def __ne__(self, other): 

2076 if not isinstance(other, _HashIdWrapper): 

2077 return NotImplemented 

2078 return self.wrapped is not other.wrapped 

2079 

2080 def __hash__(self): 

2081 return id(self.wrapped) 

2082 

2083 
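# Editor's sketch (not part of the module): tuples that are equal by value
# compare unequal once wrapped, because identity is what is compared.
# tuple([...]) is used below to guarantee two distinct objects.
#
#     >>> a = tuple([1, 2])
#     >>> b = tuple([1, 2])
#     >>> _HashIdWrapper(a) == _HashIdWrapper(a)
#     True
#     >>> _HashIdWrapper(a) == _HashIdWrapper(b)
#     False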

2084@functools.lru_cache 

2085def _cumsum(seq, initial_zero): 

2086 if isinstance(seq, _HashIdWrapper): 

2087 seq = seq.wrapped 

2088 if initial_zero: 

2089 return tuple(toolz.accumulate(add, seq, 0)) 

2090 else: 

2091 return tuple(toolz.accumulate(add, seq)) 

2092 

2093 

2094@functools.lru_cache 

2095def _max(seq): 

2096 if isinstance(seq, _HashIdWrapper): 

2097 seq = seq.wrapped 

2098 return max(seq) 

2099 

2100 

2101def cached_max(seq): 

2102 """Compute max with caching. 

2103 

2104 Caching is by the identity of `seq` rather than the value. It is thus 

2105 important that `seq` is a tuple of immutable objects, and this function 

2106 is intended for use where `seq` is a value that will persist (generally 

2107 block sizes). 

2108 

2109 Parameters 

2110 ---------- 

2111 seq : tuple 

2112 Values to reduce 

2113 

2114 Returns 

2115 ------- 

2116 The maximum element of `seq`. 

2117 """ 

2118 assert isinstance(seq, tuple) 

2119 # Look up by identity first, to avoid a linear-time __hash__ 

2120 # if we've seen this tuple object before. 

2121 result = _max(_HashIdWrapper(seq)) 

2122 return result 

2123 

2124 
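# Editor's sketch (not part of the module): repeated calls with the same
# tuple object hit the identity-keyed cache instead of rescanning the values.
#
#     >>> sizes = tuple([3, 1, 2])
#     >>> cached_max(sizes)
#     3
#     >>> cached_max(sizes)  # served from _max's lru_cache via _HashIdWrapper
#     3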

2125def cached_cumsum(seq, initial_zero=False): 

2126 """Compute :meth:`toolz.accumulate` with caching. 

2127 

2128 Caching is by the identity of `seq` rather than the value. It is thus 

2129 important that `seq` is a tuple of immutable objects, and this function 

2130 is intended for use where `seq` is a value that will persist (generally 

2131 block sizes). 

2132 

2133 Parameters 

2134 ---------- 

2135 seq : tuple 

2136 Values to cumulatively sum. 

2137 initial_zero : bool, optional 

2138 If true, the return value is prefixed with a zero. 

2139 

2140 Returns 

2141 ------- 

2142 tuple 

2143 """ 

2144 if isinstance(seq, tuple): 

2145 # Look up by identity first, to avoid a linear-time __hash__ 

2146 # if we've seen this tuple object before. 

2147 result = _cumsum(_HashIdWrapper(seq), initial_zero) 

2148 else: 

2149 # Construct a temporary tuple, and look up by value. 

2150 result = _cumsum(tuple(seq), initial_zero) 

2151 return result 

2152 

2153 
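# Editor's sketch (not part of the module):
#
#     >>> cached_cumsum((1, 2, 3))
#     (1, 3, 6)
#     >>> cached_cumsum((1, 2, 3), initial_zero=True)
#     (0, 1, 3, 6)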

2154def show_versions() -> None: 

2155 """Provide version information for bug reports.""" 

2156 

2157 from json import dumps 

2158 from platform import uname 

2159 from sys import stdout, version_info 

2160 

2161 from dask._compatibility import importlib_metadata 

2162 

2163 try: 

2164 from distributed import __version__ as distributed_version 

2165 except ImportError: 

2166 distributed_version = None 

2167 

2168 from dask import __version__ as dask_version 

2169 

2170 deps = [ 

2171 "numpy", 

2172 "pandas", 

2173 "cloudpickle", 

2174 "fsspec", 

2175 "bokeh", 

2176 "pyarrow", 

2177 "zarr", 

2178 ] 

2179 

2180 result: dict[str, str | None] = { 

2181 # note: only major, minor, micro are extracted 

2182 "Python": ".".join([str(i) for i in version_info[:3]]), 

2183 "Platform": uname().system, 

2184 "dask": dask_version, 

2185 "distributed": distributed_version, 

2186 } 

2187 

2188 for modname in deps: 

2189 try: 

2190 result[modname] = importlib_metadata.version(modname) 

2191 except importlib_metadata.PackageNotFoundError: 

2192 result[modname] = None 

2193 

2194 stdout.write(dumps(result, indent=2)) 

2195 

2196 

2197def maybe_pluralize(count, noun, plural_form=None): 

2198 """Pluralize a count-noun string pattern when necessary""" 

2199 if count == 1: 

2200 return f"{count} {noun}" 

2201 else: 

2202 return f"{count} {plural_form or noun + 's'}" 

2203 

2204 
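# Editor's sketch (not part of the module): an explicit plural form overrides
# the default "+s" behavior.
#
#     >>> maybe_pluralize(1, "task")
#     '1 task'
#     >>> maybe_pluralize(3, "task")
#     '3 tasks'
#     >>> maybe_pluralize(2, "index", "indices")
#     '2 indices'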

2205def is_namedtuple_instance(obj: Any) -> bool: 

2206 """Returns True if obj is an instance of a namedtuple. 

2207 

2208 Note: This function checks for the existence of the methods and 

2209 attributes that make up the namedtuple API, so it will return True 

2210 IFF obj's type implements that API. 

2211 """ 

2212 return ( 

2213 isinstance(obj, tuple) 

2214 and hasattr(obj, "_make") 

2215 and hasattr(obj, "_asdict") 

2216 and hasattr(obj, "_replace") 

2217 and hasattr(obj, "_fields") 

2218 and hasattr(obj, "_field_defaults") 

2219 ) 

2220 

2221 
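# Editor's sketch (not part of the module): plain tuples lack the namedtuple
# API attributes checked above.
#
#     >>> from collections import namedtuple
#     >>> Point = namedtuple("Point", ["x", "y"])
#     >>> is_namedtuple_instance(Point(1, 2))
#     True
#     >>> is_namedtuple_instance((1, 2))
#     False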

2222def get_default_shuffle_method() -> str: 

2223 if d := config.get("dataframe.shuffle.method", None): 

2224 return d 

2225 try: 

2226 from distributed import default_client 

2227 

2228 default_client() 

2229 except (ImportError, ValueError): 

2230 return "disk" 

2231 

2232 try: 

2233 from distributed.shuffle import check_minimal_arrow_version 

2234 

2235 check_minimal_arrow_version() 

2236 except ModuleNotFoundError: 

2237 return "tasks" 

2238 return "p2p" 

2239 

2240 
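# Editor's sketch (not part of the module): the config key checked above
# short-circuits the distributed-client probing, so an explicit setting wins.
#
#     >>> import dask
#     >>> with dask.config.set({"dataframe.shuffle.method": "tasks"}):
#     ...     get_default_shuffle_method()
#     'tasks'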

2241def get_meta_library(like): 

2242 if hasattr(like, "_meta"): 

2243 like = like._meta 

2244 

2245 return import_module(typename(like).partition(".")[0]) 

2246 

2247 
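# Editor's sketch (not part of the module; assumes pandas is installed): the
# top-level package of the object's type name is imported and returned.
#
#     >>> import pandas as pd
#     >>> get_meta_library(pd.DataFrame({"a": [1]}))  # doctest: +ELLIPSIS
#     <module 'pandas' ...>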

2248class shorten_traceback: 

2249 """Context manager that removes irrelevant stack elements from traceback. 

2250 

2251 * omits frames from modules that match `admin.traceback.shorten` 

2252 * always keeps the first and last frame. 

2253 """ 

2254 

2255 __slots__ = () 

2256 

2257 def __enter__(self) -> None: 

2258 pass 

2259 

2260 def __exit__( 

2261 self, 

2262 exc_type: type[BaseException] | None, 

2263 exc_val: BaseException | None, 

2264 exc_tb: types.TracebackType | None, 

2265 ) -> None: 

2266 if exc_val and exc_tb: 

2267 exc_val.__traceback__ = self.shorten(exc_tb) 

2268 

2269 @staticmethod 

2270 def shorten(exc_tb: types.TracebackType) -> types.TracebackType: 

2271 paths = config.get("admin.traceback.shorten") 

2272 if not paths: 

2273 return exc_tb 

2274 

2275 exp = re.compile(".*(" + "|".join(paths) + ")") 

2276 curr: types.TracebackType | None = exc_tb 

2277 prev: types.TracebackType | None = None 

2278 

2279 while curr: 

2280 if prev is None: 

2281 prev = curr # first frame 

2282 elif not curr.tb_next: 

2283 # always keep last frame 

2284 prev.tb_next = curr 

2285 prev = prev.tb_next 

2286 elif not exp.match(curr.tb_frame.f_code.co_filename): 

2287 # keep if module is not listed in config 

2288 prev.tb_next = curr 

2289 prev = curr 

2290 curr = curr.tb_next 

2291 

2292 # Uncomment to also drop the first frame when it matches one of the 

2293 # configured regexes. Requires Python >= 3.11. 

2294 # if exc_tb.tb_next and exp.match(exc_tb.tb_frame.f_code.co_filename): 

2295 # return exc_tb.tb_next 

2296 

2297 return exc_tb 

2298 

2299 
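# Editor's sketch (not part of the module): typical use wraps a user-facing
# entry point, so frames from paths matching the "admin.traceback.shorten"
# config regexes are pruned from any raised exception's traceback.
#
#     with shorten_traceback():
#         result = some_user_computation()  # hypothetical call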

2300def unzip(ls, nout): 

2301 """Unzip a list of lists into ``nout`` outputs.""" 

2302 out = list(zip(*ls)) 

2303 if not out: 

2304 out = [()] * nout 

2305 return out 

2306 

2307 
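# Editor's sketch (not part of the module): nout only matters for empty
# input, where zip(*ls) alone could not know how many outputs to produce.
#
#     >>> unzip([(1, "a"), (2, "b"), (3, "c")], 2)
#     [(1, 2, 3), ('a', 'b', 'c')]
#     >>> unzip([], 2)
#     [(), ()]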

2308class disable_gc(ContextDecorator): 

2309 """Context manager to disable garbage collection.""" 

2310 

2311 def __init__(self, collect=False): 

2312 self.collect = collect 

2313 

2314 def __enter__(self): 

2315 self._gc_enabled = gc.isenabled()  # capture state at entry, not construction 

2316 gc.disable() 

2317 return self 

2318 

2319 def __exit__(self, exc_type, exc_value, traceback): 

2320 if self._gc_enabled: 

2321 gc.enable() 

2322 return False 

2323 

2324 
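# Editor's sketch (not part of the module): usable as a context manager or,
# via ContextDecorator, as a decorator around allocation-heavy code.
#
#     with disable_gc():
#         build_many_small_objects()  # hypothetical call
#
#     @disable_gc()
#     def hot_path():  # hypothetical function
#         ...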

2325def is_empty(obj): 

2326 """ 

2327 Duck-typed check for “emptiness” of an object. 

2328 

2329 Works for standard sequences (lists, tuples, etc.), NumPy arrays, 

2330 and sparse-like objects (e.g., SciPy sparse arrays). 

2331 

2332 The function checks: 

2333 1. If the object supports len(), returns True if len(obj) == 0. 

2334 2. If the object has a `.nnz` attribute (number of non-zero elements), 

2335 returns True if `.nnz == 0`. 

2336 3. If the object has a `.shape` attribute, returns True if any 

2337 dimension is zero. 

2338 4. Otherwise, returns False (assumes non-empty). 

2339 

2340 Parameters 

2341 ---------- 

2342 obj : any 

2343 The object to check for emptiness. 

2344 

2345 Returns 

2346 ------- 

2347 bool 

2348 True if the object is considered empty, False otherwise. 

2349 """ 

2350 # Check standard sequences 

2351 with contextlib.suppress(Exception): 

2352 return len(obj) == 0 

2353 

2354 # Sparse-like objects 

2355 with contextlib.suppress(Exception): 

2356 return obj.nnz == 0 

2357 

2358 with contextlib.suppress(Exception): 

2359 return 0 in obj.shape 

2360 

2361 # Fallback: assume non-empty 

2362 return False
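# Editor's sketch (not part of the module): the three probes fire in order.
# A sparse-like stand-in is mocked here rather than importing scipy.
#
#     >>> is_empty([])
#     True
#     >>> is_empty((1, 2))
#     False
#     >>> class FakeSparse:
#     ...     nnz = 0
#     ...     def __len__(self):
#     ...         raise TypeError("length is ambiguous")
#     >>> is_empty(FakeSparse())
#     True
#     >>> is_empty(object())  # no len, nnz, or shape -> assumed non-empty
#     False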