from __future__ import annotations

import codecs
import contextlib
import functools
import gc
import inspect
import os
import re
import shutil
import sys
import tempfile
import types
import uuid
import warnings
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Set
from contextlib import ContextDecorator, contextmanager, nullcontext, suppress
from datetime import datetime, timedelta
from errno import ENOENT
from functools import wraps
from importlib import import_module
from numbers import Integral, Number
from operator import add
from threading import Lock
from typing import Any, ClassVar, Literal, TypeVar, cast, overload
from weakref import WeakValueDictionary

import tlz as toolz

from dask import config
from dask.typing import no_default

K = TypeVar("K")
V = TypeVar("V")
T = TypeVar("T")

# used in decorators to preserve the signature of the function it decorates
# see https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)

system_encoding = sys.getdefaultencoding()
if system_encoding == "ascii":
    system_encoding = "utf-8"


def apply(func, args, kwargs=None):
    """Apply a function given its positional and keyword arguments.

    Equivalent to ``func(*args, **kwargs)``
    Most Dask users will never need to use the ``apply`` function.
    It is typically only used by people who need to inject
    keyword argument values into a low level Dask task graph.

    Parameters
    ----------
    func : callable
        The function you want to apply.
    args : tuple
        A tuple containing all the positional arguments needed for ``func``
        (eg: ``(arg_1, arg_2, arg_3)``)
    kwargs : dict, optional
        A dictionary mapping the keyword arguments
        (eg: ``{"kwarg_1": value, "kwarg_2": value}``)

    Examples
    --------
    >>> from dask.utils import apply
    >>> def add(number, second_number=5):
    ...     return number + second_number
    ...
    >>> apply(add, (10,), {"second_number": 2})  # equivalent to add(*args, **kwargs)
    12

    >>> task = apply(add, (10,), {"second_number": 2})
    >>> dsk = {'task-name': task}  # adds the task to a low level Dask task graph
    """
    if kwargs:
        return func(*args, **kwargs)
    else:
        return func(*args)


def _deprecated(
    *,
    version: str | None = None,
    after_version: str | None = None,
    message: str | None = None,
    use_instead: str | None = None,
    category: type[Warning] = FutureWarning,
):
    """Decorator to mark a function as deprecated

    Parameters
    ----------
    version : str, optional
        Version of Dask in which the function was deprecated. If specified, the version
        will be included in the default warning message. This should no longer be used
        after the introduction of the automated versioning system.
    after_version : str, optional
        Version of Dask after which the function was deprecated. If specified, the
        version will be included in the default warning message.
    message : str, optional
        Custom warning message to raise.
    use_instead : str, optional
        Name of function to use in place of the deprecated function.
        If specified, this will be included in the default warning
        message.
    category : type[Warning], optional
        Type of warning to raise. Defaults to ``FutureWarning``.

    Examples
    --------

    >>> from dask.utils import _deprecated
    >>> @_deprecated(after_version="X.Y.Z", use_instead="bar")
    ... def foo():
    ...     return "baz"
    """

    def decorator(func):
        if message is None:
            msg = f"{func.__name__} "
            if after_version is not None:
                msg += f"was deprecated after version {after_version} "
            elif version is not None:
                msg += f"was deprecated in version {version} "
            else:
                msg += "is deprecated "
            msg += "and will be removed in a future release."

            if use_instead is not None:
                msg += f" Please use {use_instead} instead."
        else:
            msg = message

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(msg, category=category, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator


def _deprecated_kwarg(
    old_arg_name: str,
    new_arg_name: str | None = None,
    mapping: Mapping[Any, Any] | Callable[[Any], Any] | None = None,
    stacklevel: int = 2,
    comment: str | None = None,
) -> Callable[[F], F]:
    """
    Decorator to deprecate a keyword argument of a function.

    Parameters
    ----------
    old_arg_name : str
        Name of argument in function to deprecate
    new_arg_name : str, optional
        Name of preferred argument in function. Omit to warn that
        ``old_arg_name`` keyword is deprecated.
    mapping : dict or callable, optional
        If mapping is present, use it to translate old arguments to
        new arguments. A callable must do its own value checking;
        values not found in a dict will be forwarded unchanged.
    comment : str, optional
        Additional message to deprecation message. Useful to pass
        on suggestions with the deprecation warning.

    Examples
    --------
    The following deprecates 'cols', using 'columns' instead

    >>> @_deprecated_kwarg(old_arg_name='cols', new_arg_name='columns')
    ... def f(columns=''):
    ...     print(columns)
    ...
    >>> f(columns='should work ok')
    should work ok

    >>> f(cols='should raise warning')  # doctest: +SKIP
    FutureWarning: cols is deprecated, use columns instead
      warnings.warn(msg, FutureWarning)
    should raise warning

    >>> f(cols='should error', columns="can't pass do both")  # doctest: +SKIP
    TypeError: Can only specify 'cols' or 'columns', not both

    >>> @_deprecated_kwarg('old', 'new', {'yes': True, 'no': False})
    ... def f(new=False):
    ...     print('yes!' if new else 'no!')
    ...
    >>> f(old='yes')  # doctest: +SKIP
    FutureWarning: old='yes' is deprecated, use new=True instead
      warnings.warn(msg, FutureWarning)
    yes!

    To raise a warning that a keyword will be removed entirely in the future

    >>> @_deprecated_kwarg(old_arg_name='cols', new_arg_name=None)
    ... def f(cols='', another_param=''):
    ...     print(cols)
    ...
    >>> f(cols='should raise warning')  # doctest: +SKIP
    FutureWarning: the 'cols' keyword is deprecated and will be removed in a
    future version. Please take steps to stop the use of 'cols'
    should raise warning
    >>> f(another_param='should not raise warning')  # doctest: +SKIP
    should not raise warning

    >>> f(cols='should raise warning', another_param='')  # doctest: +SKIP
    FutureWarning: the 'cols' keyword is deprecated and will be removed in a
    future version. Please take steps to stop the use of 'cols'
    should raise warning
    """
    if mapping is not None and not hasattr(mapping, "get") and not callable(mapping):
        raise TypeError(
            "mapping from old to new argument values must be dict or callable!"
        )

    comment_ = f"\n{comment}" if comment else ""

    def _deprecated_kwarg(func: F) -> F:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Callable[..., Any]:
            old_arg_value = kwargs.pop(old_arg_name, no_default)

            if old_arg_value is not no_default:
                if new_arg_name is None:
                    msg = (
                        f"the {old_arg_name!r} keyword is deprecated and "
                        "will be removed in a future version. Please take "
                        f"steps to stop the use of {old_arg_name!r}"
                    ) + comment_
                    warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
                    kwargs[old_arg_name] = old_arg_value
                    return func(*args, **kwargs)

                elif mapping is not None:
                    if callable(mapping):
                        new_arg_value = mapping(old_arg_value)
                    else:
                        new_arg_value = mapping.get(old_arg_value, old_arg_value)
                    msg = (
                        f"the {old_arg_name}={old_arg_value!r} keyword is "
                        "deprecated, use "
                        f"{new_arg_name}={new_arg_value!r} instead."
                    )
                else:
                    new_arg_value = old_arg_value
                    msg = (
                        f"the {old_arg_name!r} keyword is deprecated, "
                        f"use {new_arg_name!r} instead."
                    )

                warnings.warn(msg + comment_, FutureWarning, stacklevel=stacklevel)
                if kwargs.get(new_arg_name) is not None:
                    msg = (
                        f"Can only specify {old_arg_name!r} "
                        f"or {new_arg_name!r}, not both."
                    )
                    raise TypeError(msg)
                kwargs[new_arg_name] = new_arg_value
            return func(*args, **kwargs)

        return cast(F, wrapper)

    return _deprecated_kwarg


def deepmap(func, *seqs):
    """Apply function inside nested lists

    >>> inc = lambda x: x + 1
    >>> deepmap(inc, [[1, 2], [3, 4]])
    [[2, 3], [4, 5]]

    >>> add = lambda x, y: x + y
    >>> deepmap(add, [[1, 2], [3, 4]], [[10, 20], [30, 40]])
    [[11, 22], [33, 44]]
    """
    if isinstance(seqs[0], (list, Iterator)):
        return [deepmap(func, *items) for items in zip(*seqs)]
    else:
        return func(*seqs)


@_deprecated()
def homogeneous_deepmap(func, seq):
    if not seq:
        return seq
    n = 0
    tmp = seq
    while isinstance(tmp, list):
        n += 1
        tmp = tmp[0]

    return ndeepmap(n, func, seq)


def ndeepmap(n, func, seq):
    """Call a function on every element within a nested container

    >>> def inc(x):
    ...     return x + 1
    >>> L = [[1, 2], [3, 4, 5]]
    >>> ndeepmap(2, inc, L)
    [[2, 3], [4, 5, 6]]
    """
    if n == 1:
        return [func(item) for item in seq]
    elif n > 1:
        return [ndeepmap(n - 1, func, item) for item in seq]
    elif isinstance(seq, list):
        return func(seq[0])
    else:
        return func(seq)


def import_required(mod_name, error_msg):
    """Attempt to import a required dependency.

    Raises a RuntimeError if the requested module is not available.
    """
    try:
        return import_module(mod_name)
    except ImportError as e:
        raise RuntimeError(error_msg) from e


@contextmanager
def tmpfile(extension="", dir=None):
    """
    Function to create and return a unique temporary file with the given extension, if provided.

    Parameters
    ----------
    extension : str
        The extension of the temporary file to be created
    dir : str
        If ``dir`` is not None, the file will be created in that directory; otherwise,
        Python's default temporary directory is used.

    Returns
    -------
    out : str
        Path to the temporary file

    See Also
    --------
    NamedTemporaryFile : Built-in alternative for creating temporary files
    tmp_path : pytest fixture for creating a temporary directory unique to the test invocation

    Notes
    -----
    This context manager is particularly useful on Windows for opening temporary files multiple times.
    """
    extension = extension.lstrip(".")
    if extension:
        extension = "." + extension
    handle, filename = tempfile.mkstemp(extension, dir=dir)
    os.close(handle)
    os.remove(filename)

    try:
        yield filename
    finally:
        if os.path.exists(filename):
            with suppress(OSError):  # sometimes we can't remove a generated temp file
                if os.path.isdir(filename):
                    shutil.rmtree(filename)
                else:
                    os.remove(filename)
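

# A minimal usage sketch (hypothetical, not part of the module): ``tmpfile``
# yields a path that does not exist yet; create the file yourself and it is
# removed when the block exits.
#
#     with tmpfile(extension="csv") as fn:
#         with open(fn, "w") as f:
#             f.write("a,b\n1,2\n")
#         ...  # read the file back while still inside the block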


@contextmanager
def tmpdir(dir=None):
    """
    Function to create and return a unique temporary directory.

    Parameters
    ----------
    dir : str
        If ``dir`` is not None, the directory will be created in that directory; otherwise,
        Python's default temporary directory is used.

    Returns
    -------
    out : str
        Path to the temporary directory

    Notes
    -----
    This context manager is particularly useful on Windows for opening temporary directories multiple times.
    """
    dirname = tempfile.mkdtemp(dir=dir)

    try:
        yield dirname
    finally:
        if os.path.exists(dirname):
            if os.path.isdir(dirname):
                with suppress(OSError):
                    shutil.rmtree(dirname)
            else:
                with suppress(OSError):
                    os.remove(dirname)


@contextmanager
def filetext(text, extension="", open=open, mode="w"):
    with tmpfile(extension=extension) as filename:
        f = open(filename, mode=mode)
        try:
            f.write(text)
        finally:
            try:
                f.close()
            except AttributeError:
                pass

        yield filename


@contextmanager
def changed_cwd(new_cwd):
    old_cwd = os.getcwd()
    os.chdir(new_cwd)
    try:
        yield
    finally:
        os.chdir(old_cwd)


@contextmanager
def tmp_cwd(dir=None):
    with tmpdir(dir) as dirname:
        with changed_cwd(dirname):
            yield dirname


class IndexCallable:
    """Provide getitem syntax for functions

    >>> def inc(x):
    ...     return x + 1

    >>> I = IndexCallable(inc)
    >>> I[3]
    4
    """

    __slots__ = ("fn",)

    def __init__(self, fn):
        self.fn = fn

    def __getitem__(self, key):
        return self.fn(key)


@contextmanager
def filetexts(d, open=open, mode="t", use_tmpdir=True):
    """Dumps a number of textfiles to disk

    Parameters
    ----------
    d : dict
        a mapping from filename to text like {'a.csv': '1,1\n2,2'}

    Since this is meant for use in tests, this context manager will
    automatically switch to a temporary current directory, to avoid
    race conditions when running tests in parallel.
    """
    with tmp_cwd() if use_tmpdir else nullcontext():
        for filename, text in d.items():
            try:
                os.makedirs(os.path.dirname(filename))
            except OSError:
                pass
            f = open(filename, "w" + mode)
            try:
                f.write(text)
            finally:
                try:
                    f.close()
                except AttributeError:
                    pass

        yield list(d)

        for filename in d:
            if os.path.exists(filename):
                with suppress(OSError):
                    os.remove(filename)
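

# A minimal usage sketch (hypothetical): write two files into a temporary
# working directory, run test code against them, and let the context manager
# clean them up afterwards.
#
#     with filetexts({"a.csv": "1,1\n2,2", "b.csv": "3,3\n4,4"}) as filenames:
#         assert sorted(filenames) == ["a.csv", "b.csv"]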


def concrete(seq):
    """Make nested iterators concrete lists

    >>> data = [[1, 2], [3, 4]]
    >>> seq = iter(map(iter, data))
    >>> concrete(seq)
    [[1, 2], [3, 4]]
    """
    if isinstance(seq, Iterator):
        seq = list(seq)
    if isinstance(seq, (tuple, list)):
        seq = list(map(concrete, seq))
    return seq


def pseudorandom(n: int, p, random_state=None):
    """Pseudorandom array of integer indexes

    >>> pseudorandom(5, [0.5, 0.5], random_state=123)
    array([1, 0, 0, 1, 1], dtype=int8)

    >>> pseudorandom(10, [0.5, 0.2, 0.2, 0.1], random_state=5)
    array([0, 2, 0, 3, 0, 1, 2, 1, 0, 0], dtype=int8)
    """
    import numpy as np

    p = list(p)
    cp = np.cumsum([0] + p)
    assert np.allclose(1, cp[-1])
    assert len(p) < 256

    if not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(random_state)

    x = random_state.random_sample(n)
    out = np.empty(n, dtype="i1")

    for i, (low, high) in enumerate(zip(cp[:-1], cp[1:])):
        out[(x >= low) & (x < high)] = i
    return out


def random_state_data(n: int, random_state=None) -> list:
    """Return a list of arrays that can initialize
    ``np.random.RandomState``.

    Parameters
    ----------
    n : int
        Number of arrays to return.
    random_state : int or np.random.RandomState, optional
        If an int, is used to seed a new ``RandomState``.
    """
    import numpy as np

    if not all(
        hasattr(random_state, attr) for attr in ["normal", "beta", "bytes", "uniform"]
    ):
        random_state = np.random.RandomState(random_state)

    random_data = random_state.bytes(624 * n * 4)  # `n * 624` 32-bit integers
    l = list(np.frombuffer(random_data, dtype="<u4").reshape((n, -1)))
    assert len(l) == n
    return l
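

# A minimal usage sketch (hypothetical): derive two independent seed arrays
# from one parent seed, e.g. one per partition of a collection.
#
#     import numpy as np
#     seeds = random_state_data(2, random_state=42)
#     states = [np.random.RandomState(s) for s in seeds]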


def is_integer(i) -> bool:
    """
    >>> is_integer(6)
    True
    >>> is_integer(42.0)
    True
    >>> is_integer('abc')
    False
    """
    return isinstance(i, Integral) or (isinstance(i, float) and i.is_integer())


ONE_ARITY_BUILTINS = {
    abs,
    all,
    any,
    ascii,
    bool,
    bytearray,
    bytes,
    callable,
    chr,
    classmethod,
    complex,
    dict,
    dir,
    enumerate,
    eval,
    float,
    format,
    frozenset,
    hash,
    hex,
    id,
    int,
    iter,
    len,
    list,
    max,
    min,
    next,
    oct,
    open,
    ord,
    range,
    repr,
    reversed,
    round,
    set,
    slice,
    sorted,
    staticmethod,
    str,
    sum,
    tuple,
    type,
    vars,
    zip,
    memoryview,
}
MULTI_ARITY_BUILTINS = {
    compile,
    delattr,
    divmod,
    filter,
    getattr,
    hasattr,
    isinstance,
    issubclass,
    map,
    pow,
    setattr,
}


def getargspec(func):
    """Version of inspect.getargspec that works with partial and wraps."""
    if isinstance(func, functools.partial):
        return getargspec(func.func)

    func = getattr(func, "__wrapped__", func)
    if isinstance(func, type):
        return inspect.getfullargspec(func.__init__)
    else:
        return inspect.getfullargspec(func)


def takes_multiple_arguments(func, varargs=True):
    """Does this function take multiple arguments?

    >>> def f(x, y): pass
    >>> takes_multiple_arguments(f)
    True

    >>> def f(x): pass
    >>> takes_multiple_arguments(f)
    False

    >>> def f(x, y=None): pass
    >>> takes_multiple_arguments(f)
    False

    >>> def f(*args): pass
    >>> takes_multiple_arguments(f)
    True

    >>> class Thing:
    ...     def __init__(self, a): pass
    >>> takes_multiple_arguments(Thing)
    False

    """
    if func in ONE_ARITY_BUILTINS:
        return False
    elif func in MULTI_ARITY_BUILTINS:
        return True

    try:
        spec = getargspec(func)
    except Exception:
        return False

    try:
        is_constructor = spec.args[0] == "self" and isinstance(func, type)
    except Exception:
        is_constructor = False

    if varargs and spec.varargs:
        return True

    ndefaults = 0 if spec.defaults is None else len(spec.defaults)
    return len(spec.args) - ndefaults - is_constructor > 1


def get_named_args(func) -> list[str]:
    """Get all non ``*args/**kwargs`` arguments for a function"""
    s = inspect.signature(func)
    return [
        n
        for n, p in s.parameters.items()
        if p.kind in [p.POSITIONAL_OR_KEYWORD, p.POSITIONAL_ONLY, p.KEYWORD_ONLY]
    ]


class Dispatch:
    """Simple single dispatch."""

    def __init__(self, name=None):
        self._lookup = {}
        self._lazy = {}
        if name:
            self.__name__ = name

    def register(self, type, func=None):
        """Register dispatch of `func` on arguments of type `type`"""

        def wrapper(func):
            if isinstance(type, tuple):
                for t in type:
                    self.register(t, func)
            else:
                self._lookup[type] = func
            return func

        return wrapper(func) if func is not None else wrapper

    def register_lazy(self, toplevel, func=None):
        """
        Register a registration function which will be called if the
        *toplevel* module (e.g. 'pandas') is ever loaded.
        """

        def wrapper(func):
            self._lazy[toplevel] = func
            return func

        return wrapper(func) if func is not None else wrapper

    def dispatch(self, cls):
        """Return the function implementation for the given ``cls``"""
        lk = self._lookup
        if cls in lk:
            return lk[cls]
        for cls2 in cls.__mro__:
            # Is a lazy registration function present?
            try:
                toplevel, _, _ = cls2.__module__.partition(".")
            except Exception:
                continue
            try:
                register = self._lazy[toplevel]
            except KeyError:
                pass
            else:
                register()
                self._lazy.pop(toplevel, None)
                meth = self.dispatch(cls)  # recurse
                lk[cls] = meth
                lk[cls2] = meth
                return meth
            try:
                impl = lk[cls2]
            except KeyError:
                pass
            else:
                if cls is not cls2:
                    # Cache lookup
                    lk[cls] = impl
                return impl
        raise TypeError(f"No dispatch for {cls}")

    def __call__(self, arg, *args, **kwargs):
        """
        Call the corresponding method based on type of argument.
        """
        meth = self.dispatch(type(arg))
        return meth(arg, *args, **kwargs)

    @property
    def __doc__(self):
        try:
            func = self.dispatch(object)
            return func.__doc__
        except TypeError:
            return f"Single Dispatch for {self.__name__}"
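

# A minimal usage sketch (hypothetical): register one implementation per type,
# then call the dispatcher like a function. ``register_lazy`` defers its
# callback until a value from the named top-level module is first dispatched.
#
#     sizeof = Dispatch(name="sizeof")
#
#     @sizeof.register(list)
#     def _sizeof_list(x):
#         return len(x)
#
#     @sizeof.register_lazy("pandas")
#     def _register_pandas():
#         import pandas as pd
#         sizeof.register(pd.DataFrame, lambda df: len(df))
#
#     sizeof([1, 2, 3])  # -> 3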


def ensure_not_exists(filename) -> None:
    """
    Ensure that a file does not exist.
    """
    try:
        os.unlink(filename)
    except OSError as e:
        if e.errno != ENOENT:
            raise


def _skip_doctest(line):
    # NumPy docstring contains cursor and comment only example
    stripped = line.strip()
    if stripped == ">>>" or stripped.startswith(">>> #"):
        return line
    elif ">>>" in stripped and "+SKIP" not in stripped:
        if "# doctest:" in line:
            return line + ", +SKIP"
        else:
            return line + " # doctest: +SKIP"
    else:
        return line


def skip_doctest(doc):
    if doc is None:
        return ""
    return "\n".join([_skip_doctest(line) for line in doc.split("\n")])


def extra_titles(doc):
    lines = doc.split("\n")
    titles = {
        i: lines[i].strip()
        for i in range(len(lines) - 1)
        if lines[i + 1].strip() and all(c == "-" for c in lines[i + 1].strip())
    }

    seen = set()
    for i, title in sorted(titles.items()):
        if title in seen:
            new_title = "Extra " + title
            lines[i] = lines[i].replace(title, new_title)
            lines[i + 1] = lines[i + 1].replace("-" * len(title), "-" * len(new_title))
        else:
            seen.add(title)

    return "\n".join(lines)


def ignore_warning(doc, cls, name, extra="", skipblocks=0, inconsistencies=None):
    """Expand docstring by adding disclaimer and extra text"""
    import inspect

    if inspect.isclass(cls):
        l1 = f"This docstring was copied from {cls.__module__}.{cls.__name__}.{name}.\n\n"
    else:
        l1 = f"This docstring was copied from {cls.__name__}.{name}.\n\n"
    l2 = "Some inconsistencies with the Dask version may exist."

    i = doc.find("\n\n")
    if i != -1:
        # Insert our warning
        head = doc[: i + 2]
        tail = doc[i + 2 :]
        while skipblocks > 0:
            i = tail.find("\n\n")
            head = tail[: i + 2]
            tail = tail[i + 2 :]
            skipblocks -= 1
        # Indentation of next line
        indent = re.match(r"\s*", tail).group(0)
        # Insert the warning, indented, with a blank line before and after
        if extra:
            more = [indent, extra.rstrip("\n") + "\n\n"]
        else:
            more = []
        if inconsistencies is not None:
            l3 = f"Known inconsistencies: \n {inconsistencies}"
            bits = [head, indent, l1, l2, "\n\n", l3, "\n\n"] + more + [tail]
        else:
            bits = [head, indent, l1, indent, l2, "\n\n"] + more + [tail]
        doc = "".join(bits)

    return doc


def unsupported_arguments(doc, args):
    """Mark unsupported arguments with a disclaimer"""
    lines = doc.split("\n")
    for arg in args:
        subset = [
            (i, line)
            for i, line in enumerate(lines)
            if re.match(r"^\s*" + arg + " ?:", line)
        ]
        if len(subset) == 1:
            [(i, line)] = subset
            lines[i] = line + " (Not supported in Dask)"
    return "\n".join(lines)


def _derived_from(
    cls, method, ua_args=None, extra="", skipblocks=0, inconsistencies=None
):
    """Helper function for derived_from to ease testing"""
    ua_args = ua_args or []

    # do not use wraps here, as it hides keyword arguments displayed
    # in the doc
    original_method = getattr(cls, method.__name__)

    doc = getattr(original_method, "__doc__", None)

    if isinstance(original_method, property):
        # some things like SeriesGroupBy.unique are generated.
        original_method = original_method.fget
        if not doc:
            doc = getattr(original_method, "__doc__", None)

    if isinstance(original_method, functools.cached_property):
        original_method = original_method.func
        if not doc:
            doc = getattr(original_method, "__doc__", None)

    if doc is None:
        doc = ""

    # pandas DataFrame/Series sometimes override methods without setting __doc__
    if not doc and cls.__name__ in {"DataFrame", "Series"}:
        for obj in cls.mro():
            obj_method = getattr(obj, method.__name__, None)
            if obj_method is not None and obj_method.__doc__:
                doc = obj_method.__doc__
                break

    # Insert disclaimer that this is a copied docstring
    if doc:
        doc = ignore_warning(
            doc,
            cls,
            method.__name__,
            extra=extra,
            skipblocks=skipblocks,
            inconsistencies=inconsistencies,
        )
    elif extra:
        doc += extra.rstrip("\n") + "\n\n"

    # Mark unsupported arguments
    try:
        method_args = get_named_args(method)
        original_args = get_named_args(original_method)
        not_supported = [m for m in original_args if m not in method_args]
    except ValueError:
        not_supported = []
    if len(ua_args) > 0:
        not_supported.extend(ua_args)
    if len(not_supported) > 0:
        doc = unsupported_arguments(doc, not_supported)

    doc = skip_doctest(doc)
    doc = extra_titles(doc)

    return doc


def derived_from(
    original_klass, version=None, ua_args=None, skipblocks=0, inconsistencies=None
):
    """Decorator to attach original class's docstring to the wrapped method.

    The output structure will be: top line of docstring, disclaimer about this
    being auto-derived, any extra text associated with the method being patched,
    the body of the docstring and finally, the list of keywords that exist in
    the original method but not in the dask version.

    Parameters
    ----------
    original_klass: type
        Original class which the method is derived from
    version : str
        Original package version which supports the wrapped method
    ua_args : list
        List of keywords which Dask doesn't support. Keywords existing in
        original but not in Dask will automatically be added.
    skipblocks : int
        How many text blocks (paragraphs) to skip from the start of the
        docstring. Useful for cases where the target has extra front-matter.
    inconsistencies: list
        List of known inconsistencies with method whose docstrings are being
        copied.
    """
    ua_args = ua_args or []

    def wrapper(method):
        try:
            extra = getattr(method, "__doc__", None) or ""
            method.__doc__ = _derived_from(
                original_klass,
                method,
                ua_args=ua_args,
                extra=extra,
                skipblocks=skipblocks,
                inconsistencies=inconsistencies,
            )
            return method

        except AttributeError:
            module_name = original_klass.__module__.split(".")[0]

            @functools.wraps(method)
            def wrapped(*args, **kwargs):
                msg = f"Base package doesn't support '{method.__name__}'."
                if version is not None:
                    msg2 = " Use {0} {1} or later to use this method."
                    msg += msg2.format(module_name, version)
                raise NotImplementedError(msg)

            return wrapped

    return wrapper
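

# A minimal usage sketch (hypothetical): copy a docstring from the "base"
# library's class onto a reimplemented method, disclaimer included.
#
#     import pandas as pd
#
#     class MyFrame:
#         @derived_from(pd.DataFrame)
#         def head(self, n=5):
#             ...  # reimplementation; docstring comes from pandas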


def funcname(func) -> str:
    """Get the name of a function."""
    # functools.partial
    if isinstance(func, functools.partial):
        return funcname(func.func)
    # methodcaller
    if isinstance(func, methodcaller):
        return func.method[:50]

    module_name = getattr(func, "__module__", None) or ""
    type_name = getattr(type(func), "__name__", None) or ""

    # toolz.curry
    if "toolz" in module_name and "curry" == type_name:
        return func.func_name[:50]
    # multipledispatch objects
    if "multipledispatch" in module_name and "Dispatcher" == type_name:
        return func.name[:50]
    # numpy.vectorize objects
    if "numpy" in module_name and "vectorize" == type_name:
        return ("vectorize_" + funcname(func.pyfunc))[:50]

    # All other callables
    try:
        name = func.__name__
        if name == "<lambda>":
            return "lambda"
        return name[:50]
    except AttributeError:
        return str(func)[:50]


def typename(typ: Any, short: bool = False) -> str:
    """
    Return the name of a type

    Examples
    --------
    >>> typename(int)
    'int'

    >>> from dask.core import literal
    >>> typename(literal)
    'dask.core.literal'
    >>> typename(literal, short=True)
    'dask.literal'
    """
    if not isinstance(typ, type):
        return typename(type(typ))
    try:
        if not typ.__module__ or typ.__module__ == "builtins":
            return typ.__name__
        else:
            if short:
                module, *_ = typ.__module__.split(".")
            else:
                module = typ.__module__
            return module + "." + typ.__name__
    except AttributeError:
        return str(typ)


def ensure_bytes(s) -> bytes:
    """Attempt to turn `s` into bytes.

    Parameters
    ----------
    s : Any
        The object to be converted. Will correctly handle
        * str
        * bytes
        * objects implementing the buffer protocol (memoryview, ndarray, etc.)

    Returns
    -------
    b : bytes

    Raises
    ------
    TypeError
        When `s` cannot be converted

    Examples
    --------
    >>> ensure_bytes('123')
    b'123'
    >>> ensure_bytes(b'123')
    b'123'
    >>> ensure_bytes(bytearray(b'123'))
    b'123'
    """
    if isinstance(s, bytes):
        return s
    elif hasattr(s, "encode"):
        return s.encode()
    else:
        try:
            return bytes(s)
        except Exception as e:
            raise TypeError(
                f"Object {s} is neither a bytes object nor can be encoded to bytes"
            ) from e


def ensure_unicode(s) -> str:
    """Turn string or bytes to string

    >>> ensure_unicode('123')
    '123'
    >>> ensure_unicode(b'123')
    '123'
    """
    if isinstance(s, str):
        return s
    elif hasattr(s, "decode"):
        return s.decode()
    else:
        try:
            return codecs.decode(s)
        except Exception as e:
            raise TypeError(
                f"Object {s} is neither a str object nor can be decoded to str"
            ) from e


def digit(n, k, base):
    """

    >>> digit(1234, 0, 10)
    4
    >>> digit(1234, 1, 10)
    3
    >>> digit(1234, 2, 10)
    2
    >>> digit(1234, 3, 10)
    1
    """
    return n // base**k % base


def insert(tup, loc, val):
    """

    >>> insert(('a', 'b', 'c'), 0, 'x')
    ('x', 'b', 'c')
    """
    L = list(tup)
    L[loc] = val
    return tuple(L)


def memory_repr(num):
    for x in ["bytes", "KB", "MB", "GB", "TB"]:
        if num < 1024.0:
            return f"{num:3.1f} {x}"
        num /= 1024.0


def asciitable(columns, rows):
    """Formats an ascii table for given columns and rows.

    Parameters
    ----------
    columns : list
        The column names
    rows : list of tuples
        The rows in the table. Each tuple must be the same length as
        ``columns``.
    """
    rows = [tuple(str(i) for i in r) for r in rows]
    columns = tuple(str(i) for i in columns)
    widths = tuple(max(*map(len, x), len(c)) for x, c in zip(zip(*rows), columns))
    row_template = ("|" + (" %%-%ds |" * len(columns))) % widths
    header = row_template % tuple(columns)
    bar = "+{}+".format("+".join("-" * (w + 2) for w in widths))
    data = "\n".join(row_template % r for r in rows)
    return "\n".join([bar, header, bar, data, bar])
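

# A minimal usage sketch (hypothetical):
#
#     print(asciitable(["name", "size"], [("x", 10), ("y", 200)]))
#
# produces a bordered table like:
#
#     +------+------+
#     | name | size |
#     +------+------+
#     | x    | 10   |
#     | y    | 200  |
#     +------+------+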


def put_lines(buf, lines):
    if any(not isinstance(x, str) for x in lines):
        lines = [str(x) for x in lines]
    buf.write("\n".join(lines))


_method_cache: dict[str, methodcaller] = {}


class methodcaller:
    """
    Return a callable object that calls the given method on its operand.

    Unlike the builtin `operator.methodcaller`, instances of this class are
    cached and arguments are passed at call time instead of build time.
    """

    __slots__ = ("method",)
    method: str

    @property
    def func(self) -> str:
        # For `funcname` to work
        return self.method

    def __new__(cls, method: str):
        try:
            return _method_cache[method]
        except KeyError:
            self = object.__new__(cls)
            self.method = method
            _method_cache[method] = self
            return self

    def __call__(self, __obj, *args, **kwargs):
        return getattr(__obj, self.method)(*args, **kwargs)

    def __reduce__(self):
        return (methodcaller, (self.method,))

    def __str__(self):
        return f"<{self.__class__.__name__}: {self.method}>"

    __repr__ = __str__


class itemgetter:
    """Variant of operator.itemgetter that supports equality tests"""

    __slots__ = ("index",)

    def __init__(self, index):
        self.index = index

    def __call__(self, x):
        return x[self.index]

    def __reduce__(self):
        return (itemgetter, (self.index,))

    def __eq__(self, other):
        return type(self) is type(other) and self.index == other.index


class MethodCache:
    """Attribute access on this object returns a methodcaller for that
    attribute.

    Examples
    --------
    >>> a = [1, 3, 3]
    >>> M.count(a, 3) == a.count(3)
    True
    """

    def __getattr__(self, item):
        return methodcaller(item)

    def __dir__(self):
        return list(_method_cache)


M = MethodCache()


class SerializableLock:
    """A Serializable per-process Lock

    This wraps a normal ``threading.Lock`` object and satisfies the same
    interface. However, this lock can also be serialized and sent to different
    processes. It will not block concurrent operations between processes (for
    this you should look at ``multiprocessing.Lock`` or ``locket.lock_file``)
    but will consistently deserialize into the same lock.

    So if we make a lock in one process::

        lock = SerializableLock()

    And then send it over to another process multiple times::

        bytes = pickle.dumps(lock)
        a = pickle.loads(bytes)
        b = pickle.loads(bytes)

    Then the deserialized objects will operate as though they were the same
    lock, and collide as appropriate.

    This is useful for consistently protecting resources on a per-process
    level.

    The creation of locks is itself not threadsafe.
    """

    _locks: ClassVar[WeakValueDictionary[Hashable, Lock]] = WeakValueDictionary()
    token: Hashable
    lock: Lock

    def __init__(self, token: Hashable | None = None):
        self.token = token or str(uuid.uuid4())
        if self.token in SerializableLock._locks:
            self.lock = SerializableLock._locks[self.token]
        else:
            self.lock = Lock()
            SerializableLock._locks[self.token] = self.lock

    def acquire(self, *args, **kwargs):
        return self.lock.acquire(*args, **kwargs)

    def release(self, *args, **kwargs):
        return self.lock.release(*args, **kwargs)

    def __enter__(self):
        self.lock.__enter__()

    def __exit__(self, *args):
        self.lock.__exit__(*args)

    def locked(self):
        return self.lock.locked()

    def __getstate__(self):
        return self.token

    def __setstate__(self, token):
        self.__init__(token)

    def __str__(self):
        return f"<{self.__class__.__name__}: {self.token}>"

    __repr__ = __str__
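

# A minimal round-trip sketch (hypothetical): pickled copies share a token,
# so within one process they resolve to the same underlying ``threading.Lock``.
#
#     import pickle
#     lock = SerializableLock()
#     copy = pickle.loads(pickle.dumps(lock))
#     assert copy.token == lock.token and copy.lock is lock.lock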


def get_scheduler_lock(collection=None, scheduler=None):
    """Get an instance of the appropriate lock for a certain situation based on
    scheduler used."""
    from dask import multiprocessing
    from dask.base import get_scheduler

    actual_get = get_scheduler(collections=[collection], scheduler=scheduler)

    if actual_get == multiprocessing.get:
        return multiprocessing.get_context().Manager().Lock()
    else:
        # if this is a distributed client, we need to lock on
        # the level between processes, SerializableLock won't work
        try:
            import distributed.lock
            from distributed.worker import get_client

            client = get_client()
        except (ImportError, ValueError):
            pass
        else:
            if actual_get == client.get:
                return distributed.lock.Lock()

    return SerializableLock()


def ensure_dict(d: Mapping[K, V], *, copy: bool = False) -> dict[K, V]:
    """Convert a generic Mapping into a dict.
    Optimize use case of :class:`~dask.highlevelgraph.HighLevelGraph`.

    Parameters
    ----------
    d : Mapping
    copy : bool
        If True, guarantee that the return value is always a shallow copy of d;
        otherwise it may be the input itself.
    """
    if type(d) is dict:
        return d.copy() if copy else d
    try:
        layers = d.layers  # type: ignore
    except AttributeError:
        return dict(d)

    result = {}
    for layer in toolz.unique(layers.values(), key=id):
        result.update(layer)
    return result


def ensure_set(s: Set[T], *, copy: bool = False) -> set[T]:
    """Convert a generic Set into a set.

    Parameters
    ----------
    s : Set
    copy : bool
        If True, guarantee that the return value is always a shallow copy of s;
        otherwise it may be the input itself.
    """
    if type(s) is set:
        return s.copy() if copy else s
    return set(s)


class OperatorMethodMixin:
    """A mixin for dynamically implementing operators"""

    __slots__ = ()

    @classmethod
    def _bind_operator(cls, op):
        """bind operator to this class"""
        name = op.__name__

        if name.endswith("_"):
            # for and_ and or_
            name = name[:-1]
        elif name == "inv":
            name = "invert"

        meth = f"__{name}__"

        if name in ("abs", "invert", "neg", "pos"):
            setattr(cls, meth, cls._get_unary_operator(op))
        else:
            setattr(cls, meth, cls._get_binary_operator(op))

            if name in ("eq", "gt", "ge", "lt", "le", "ne", "getitem"):
                return

            rmeth = f"__r{name}__"
            setattr(cls, rmeth, cls._get_binary_operator(op, inv=True))

    @classmethod
    def _get_unary_operator(cls, op):
        """Must return a method used by unary operator"""
        raise NotImplementedError

    @classmethod
    def _get_binary_operator(cls, op, inv=False):
        """Must return a method used by binary operator"""
        raise NotImplementedError


def partial_by_order(*args, **kwargs):
    """

    >>> from operator import add
    >>> partial_by_order(5, function=add, other=[(1, 10)])
    15
    """
    function = kwargs.pop("function")
    other = kwargs.pop("other")
    args2 = list(args)
    for i, arg in other:
        args2.insert(i, arg)
    return function(*args2, **kwargs)


def is_arraylike(x) -> bool:
    """Is this object a numpy array or something similar?

    This function tests specifically for an object that already has
    array attributes (e.g. np.ndarray, dask.array.Array, cupy.ndarray,
    sparse.COO), **NOT** for something that can be coerced into an
    array object (e.g. Python lists and tuples). It is meant for dask
    developers and developers of downstream libraries.

    Note that this function does not correspond with NumPy's
    definition of array_like, which includes any object that can be
    coerced into an array (see definition in the NumPy glossary):
    https://numpy.org/doc/stable/glossary.html

    Examples
    --------
    >>> import numpy as np
    >>> is_arraylike(np.ones(5))
    True
    >>> is_arraylike(np.ones(()))
    True
    >>> is_arraylike(5)
    False
    >>> is_arraylike('cat')
    False
    """
    from dask.base import is_dask_collection

    is_duck_array = hasattr(x, "__array_function__") or hasattr(x, "__array_ufunc__")

    return bool(
        hasattr(x, "shape")
        and isinstance(x.shape, tuple)
        and hasattr(x, "dtype")
        and not any(is_dask_collection(n) for n in x.shape)
        # We special case scipy.sparse and cupyx.scipy.sparse arrays as having partial
        # support for them is useful in scenarios where we mostly call `map_partitions`
        # or `map_blocks` with scikit-learn functions on dask arrays and dask dataframes.
        # https://github.com/dask/dask/pull/3738
        and (is_duck_array or "scipy.sparse" in typename(type(x)))
    )


def is_dataframe_like(df) -> bool:
    """Looks like a Pandas DataFrame"""
    if (df.__class__.__module__, df.__class__.__name__) == (
        "pandas.core.frame",
        "DataFrame",
    ):
        # fast exec for most likely input
        return True
    typ = df.__class__
    return (
        all(hasattr(typ, name) for name in ("groupby", "head", "merge", "mean"))
        and all(hasattr(df, name) for name in ("dtypes", "columns"))
        and not any(hasattr(typ, name) for name in ("name", "dtype"))
    )


def is_series_like(s) -> bool:
    """Looks like a Pandas Series"""
    typ = s.__class__
    return (
        all(hasattr(typ, name) for name in ("groupby", "head", "mean"))
        and all(hasattr(s, name) for name in ("dtype", "name"))
        and "index" not in typ.__name__.lower()
    )


def is_index_like(s) -> bool:
    """Looks like a Pandas Index"""
    typ = s.__class__
    return (
        all(hasattr(s, name) for name in ("name", "dtype"))
        and "index" in typ.__name__.lower()
    )


def is_cupy_type(x) -> bool:
    # TODO: avoid explicit reference to CuPy
    return "cupy" in str(type(x))


def natural_sort_key(s: str) -> list[str | int]:
    """
    Sorting `key` function for performing a natural sort on a collection of
    strings

    See https://en.wikipedia.org/wiki/Natural_sort_order

    Parameters
    ----------
    s : str
        A string that is an element of the collection being sorted

    Returns
    -------
    list[str or int]
        List of the parts of the input string where each part is either a
        string or an integer

    Examples
    --------
    >>> a = ['f0', 'f1', 'f2', 'f8', 'f9', 'f10', 'f11', 'f19', 'f20', 'f21']
    >>> sorted(a)
    ['f0', 'f1', 'f10', 'f11', 'f19', 'f2', 'f20', 'f21', 'f8', 'f9']
    >>> sorted(a, key=natural_sort_key)
    ['f0', 'f1', 'f2', 'f8', 'f9', 'f10', 'f11', 'f19', 'f20', 'f21']
    """
    return [int(part) if part.isdigit() else part for part in re.split(r"(\d+)", s)]


def parse_bytes(s: float | str) -> int:
    """Parse byte string to numbers

    >>> from dask.utils import parse_bytes
    >>> parse_bytes('100')
    100
    >>> parse_bytes('100 MB')
    100000000
    >>> parse_bytes('100M')
    100000000
    >>> parse_bytes('5kB')
    5000
    >>> parse_bytes('5.4 kB')
    5400
    >>> parse_bytes('1kiB')
    1024
    >>> parse_bytes('1e6')
    1000000
    >>> parse_bytes('1e6 kB')
    1000000000
    >>> parse_bytes('MB')
    1000000
    >>> parse_bytes(123)
    123
    >>> parse_bytes('5 foos')
    Traceback (most recent call last):
        ...
    ValueError: Could not interpret 'foos' as a byte unit
    """
    if isinstance(s, (int, float)):
        return int(s)
    s = s.replace(" ", "")
    if not any(char.isdigit() for char in s):
        s = "1" + s

    for i in range(len(s) - 1, -1, -1):
        if not s[i].isalpha():
            break
    index = i + 1

    prefix = s[:index]
    suffix = s[index:]

    try:
        n = float(prefix)
    except ValueError as e:
        raise ValueError(f"Could not interpret '{prefix}' as a number") from e

    try:
        multiplier = byte_sizes[suffix.lower()]
    except KeyError as e:
        raise ValueError(f"Could not interpret '{suffix}' as a byte unit") from e

    result = n * multiplier
    return int(result)


byte_sizes = {
    "kB": 10**3,
    "MB": 10**6,
    "GB": 10**9,
    "TB": 10**12,
    "PB": 10**15,
    "KiB": 2**10,
    "MiB": 2**20,
    "GiB": 2**30,
    "TiB": 2**40,
    "PiB": 2**50,
    "B": 1,
    "": 1,
}
byte_sizes = {k.lower(): v for k, v in byte_sizes.items()}
byte_sizes.update({k[0]: v for k, v in byte_sizes.items() if k and "i" not in k})
byte_sizes.update({k[:-1]: v for k, v in byte_sizes.items() if k and "i" in k})
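# The three updates above generate the lookup aliases used by ``parse_bytes``:
# first every key is lowercased ("MB" -> "mb"), then single-letter decimal
# aliases are added ("m" -> 10**6), and finally trailing-"b" binary aliases
# are added ("ki" -> 2**10), so "100M", "5kB" and "1kiB" all resolve.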


def format_time(n: float) -> str:
    """format integers as time

    >>> from dask.utils import format_time
    >>> format_time(1)
    '1.00 s'
    >>> format_time(0.001234)
    '1.23 ms'
    >>> format_time(0.00012345)
    '123.45 us'
    >>> format_time(123.456)
    '123.46 s'
    >>> format_time(1234.567)
    '20m 34s'
    >>> format_time(12345.67)
    '3hr 25m'
    >>> format_time(123456.78)
    '34hr 17m'
    >>> format_time(1234567.89)
    '14d 6hr'
    """
    if n > 24 * 60 * 60 * 2:
        d = int(n / 3600 / 24)
        h = int((n - d * 3600 * 24) / 3600)
        return f"{d}d {h}hr"
    if n > 60 * 60 * 2:
        h = int(n / 3600)
        m = int((n - h * 3600) / 60)
        return f"{h}hr {m}m"
    if n > 60 * 10:
        m = int(n / 60)
        s = int(n - m * 60)
        return f"{m}m {s}s"
    if n >= 1:
        return f"{n:.2f} s"
    if n >= 1e-3:
        return "%.2f ms" % (n * 1e3)
    return "%.2f us" % (n * 1e6)


def format_time_ago(n: datetime) -> str:
    """Calculate a '3 hours ago' type string from a Python datetime.

    Examples
    --------
    >>> from datetime import datetime, timedelta

    >>> now = datetime.now()
    >>> format_time_ago(now)
    'Just now'

    >>> past = datetime.now() - timedelta(minutes=1)
    >>> format_time_ago(past)
    '1 minute ago'

    >>> past = datetime.now() - timedelta(minutes=2)
    >>> format_time_ago(past)
    '2 minutes ago'

    >>> past = datetime.now() - timedelta(hours=1)
    >>> format_time_ago(past)
    '1 hour ago'

    >>> past = datetime.now() - timedelta(hours=6)
    >>> format_time_ago(past)
    '6 hours ago'

    >>> past = datetime.now() - timedelta(days=1)
    >>> format_time_ago(past)
    '1 day ago'

    >>> past = datetime.now() - timedelta(days=5)
    >>> format_time_ago(past)
    '5 days ago'

    >>> past = datetime.now() - timedelta(days=8)
    >>> format_time_ago(past)
    '1 week ago'

    >>> past = datetime.now() - timedelta(days=16)
    >>> format_time_ago(past)
    '2 weeks ago'

    >>> past = datetime.now() - timedelta(days=190)
    >>> format_time_ago(past)
    '6 months ago'

    >>> past = datetime.now() - timedelta(days=800)
    >>> format_time_ago(past)
    '2 years ago'

    """
    units = {
        "years": lambda diff: diff.days / 365,
        "months": lambda diff: diff.days / 30.436875,  # Average days per month
        "weeks": lambda diff: diff.days / 7,
        "days": lambda diff: diff.days,
        "hours": lambda diff: diff.seconds / 3600,
        "minutes": lambda diff: diff.seconds % 3600 / 60,
    }
    diff = datetime.now() - n
    for unit, func in units.items():
        dur = int(func(diff))
        if dur > 0:
            if dur == 1:  # De-pluralize
                unit = unit[:-1]
            return f"{dur} {unit} ago"
    return "Just now"


def format_bytes(n: int) -> str:
    """Format bytes as text

    >>> from dask.utils import format_bytes
    >>> format_bytes(1)
    '1 B'
    >>> format_bytes(1234)
    '1.21 kiB'
    >>> format_bytes(12345678)
    '11.77 MiB'
    >>> format_bytes(1234567890)
    '1.15 GiB'
    >>> format_bytes(1234567890000)
    '1.12 TiB'
    >>> format_bytes(1234567890000000)
    '1.10 PiB'

    For all values < 2**60, the output is always <= 10 characters.
    """
    for prefix, k in (
        ("Pi", 2**50),
        ("Ti", 2**40),
        ("Gi", 2**30),
        ("Mi", 2**20),
        ("ki", 2**10),
    ):
        if n >= k * 0.9:
            return f"{n / k:.2f} {prefix}B"
    return f"{n} B"


timedelta_sizes = {
    "s": 1,
    "ms": 1e-3,
    "us": 1e-6,
    "ns": 1e-9,
    "m": 60,
    "h": 3600,
    "d": 3600 * 24,
    "w": 7 * 3600 * 24,
}

tds2 = {
    "second": 1,
    "minute": 60,
    "hour": 60 * 60,
    "day": 60 * 60 * 24,
    "week": 7 * 60 * 60 * 24,
    "millisecond": 1e-3,
    "microsecond": 1e-6,
    "nanosecond": 1e-9,
}
tds2.update({k + "s": v for k, v in tds2.items()})
timedelta_sizes.update(tds2)
timedelta_sizes.update({k.upper(): v for k, v in timedelta_sizes.items()})


@overload
def parse_timedelta(s: None, default: str | Literal[False] = "seconds") -> None: ...


@overload
def parse_timedelta(
    s: str | float | timedelta, default: str | Literal[False] = "seconds"
) -> float: ...


def parse_timedelta(s, default="seconds"):
    """Parse timedelta string to number of seconds

    Parameters
    ----------
    s : str, float, timedelta, or None
    default: str or False, optional
        Unit of measure if s does not specify one. Defaults to seconds.
        Set to False to require s to explicitly specify its own unit.

    Examples
    --------
    >>> from datetime import timedelta
    >>> from dask.utils import parse_timedelta
    >>> parse_timedelta('3s')
    3
    >>> parse_timedelta('3.5 seconds')
    3.5
    >>> parse_timedelta('300ms')
    0.3
    >>> parse_timedelta(timedelta(seconds=3))  # also supports timedeltas
    3
    """
    if s is None:
        return None
    if isinstance(s, timedelta):
        s = s.total_seconds()
        return int(s) if int(s) == s else s
    if isinstance(s, Number):
        s = str(s)
    s = s.replace(" ", "")
    if not s[0].isdigit():
        s = "1" + s

    for i in range(len(s) - 1, -1, -1):
        if not s[i].isalpha():
            break
    index = i + 1

    prefix = s[:index]
    suffix = s[index:] or default
    if suffix is False:
        raise ValueError(f"Missing time unit: {s}")
    if not isinstance(suffix, str):
        raise TypeError(f"default must be str or False, got {default!r}")

    n = float(prefix)

    try:
        multiplier = timedelta_sizes[suffix.lower()]
    except KeyError:
        valid_units = ", ".join(timedelta_sizes.keys())
        raise KeyError(
            f"Invalid time unit: {suffix}. Valid units are: {valid_units}"
        ) from None

    result = n * multiplier
    if int(result) == result:
        result = int(result)
    return result


def has_keyword(func, keyword):
    try:
        return keyword in inspect.signature(func).parameters
    except Exception:
        return False


def ndimlist(seq):
    if not isinstance(seq, (list, tuple)):
        return 0
    elif not seq:
        return 1
    else:
        return 1 + ndimlist(seq[0])


def iter_chunks(sizes, max_size):
    """Split sizes into chunks of total max_size each

    Parameters
    ----------
    sizes : iterable of numbers
        The sizes to be chunked
    max_size : number
        Maximum total size per chunk.
        It must be greater than or equal to each size in sizes
    """
    chunk, chunk_sum = [], 0
    iter_sizes = iter(sizes)
    size = next(iter_sizes, None)
    while size is not None:
        assert size <= max_size
        if chunk_sum + size <= max_size:
            chunk.append(size)
            chunk_sum += size
            size = next(iter_sizes, None)
        else:
            assert chunk
            yield chunk
            chunk, chunk_sum = [], 0
    if chunk:
        yield chunk
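

# A minimal usage sketch (hypothetical): greedily pack sizes into chunks
# whose totals stay within the limit, preserving input order.
#
#     list(iter_chunks([1, 5, 5, 2, 3], max_size=7))  # -> [[1, 5], [5, 2], [3]]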


hex_pattern = re.compile("[a-f]+")


@functools.lru_cache(100000)
def key_split(s):
    """
    >>> key_split('x')
    'x'
    >>> key_split('x-1')
    'x'
    >>> key_split('x-1-2-3')
    'x'
    >>> key_split(('x-2', 1))
    'x'
    >>> key_split("('x-2', 1)")
    'x'
    >>> key_split("('x', 1)")
    'x'
    >>> key_split('hello-world-1')
    'hello-world'
    >>> key_split(b'hello-world-1')
    'hello-world'
    >>> key_split('ae05086432ca935f6eba409a8ecd4896')
    'data'
    >>> key_split('<module.submodule.myclass object at 0xdaf372')
    'myclass'
    >>> key_split(None)
    'Other'
    >>> key_split('x-abcdefab')  # ignores hex
    'x'
    >>> key_split('_(x)')  # strips unpleasant characters
    'x'
    """
    # If we convert the key, recurse to utilize LRU cache better
    if type(s) is bytes:
        return key_split(s.decode())
    if type(s) is tuple:
        return key_split(s[0])
    try:
        words = s.split("-")
        if not words[0][0].isalpha():
            result = words[0].split(",")[0].strip("_'()\"")
        else:
            result = words[0]
        for word in words[1:]:
            if word.isalpha() and not (
                len(word) == 8 and hex_pattern.match(word) is not None
            ):
                result += "-" + word
            else:
                break
        if len(result) == 32 and re.match(r"[a-f0-9]{32}", result):
            return "data"
        else:
            if result[0] == "<":
                result = result.strip("<>").split()[0].split(".")[-1]
            return sys.intern(result)
    except Exception:
        return "Other"


def stringify(obj, exclusive: Iterable | None = None):
    """Convert an object to a string

    If ``exclusive`` is specified, search through `obj` and convert
    values that are in ``exclusive``.

    Note that when searching through dictionaries, only values are
    converted, not the keys.

    Parameters
    ----------
    obj : Any
        Object (or values within) to convert to string
    exclusive: Iterable, optional
        Set of values to search for when converting values to strings

    Returns
    -------
    result : type(obj)
        Stringified copy of ``obj`` or ``obj`` itself if it is already a
        string or bytes.

    Examples
    --------
    >>> stringify(b'x')
    b'x'
    >>> stringify('x')
    'x'
    >>> stringify({('a',0):('a',0), ('a',1): ('a',1)})
    "{('a', 0): ('a', 0), ('a', 1): ('a', 1)}"
    >>> stringify({('a',0):('a',0), ('a',1): ('a',1)}, exclusive={('a',0)})
    {('a', 0): "('a', 0)", ('a', 1): ('a', 1)}
    """

    typ = type(obj)
    if typ is str or typ is bytes:
        return obj
    elif exclusive is None:
        return str(obj)

    if typ is list:
        return [stringify(v, exclusive) for v in obj]
    if typ is dict:
        return {k: stringify(v, exclusive) for k, v in obj.items()}
    try:
        if obj in exclusive:
            return stringify(obj)
    except TypeError:  # `obj` not hashable
        pass
    if typ is tuple:  # If the tuple itself isn't a key, check its elements
        return tuple(stringify(v, exclusive) for v in obj)
    return obj


class cached_property(functools.cached_property):
    """Read only version of functools.cached_property."""

    def __set__(self, instance, val):
        """Raise an error when attempting to set a cached property."""
        raise AttributeError("Can't set attribute")


class _HashIdWrapper:
    """Hash and compare a wrapped object by identity instead of value"""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __eq__(self, other):
        if not isinstance(other, _HashIdWrapper):
            return NotImplemented
        return self.wrapped is other.wrapped

    def __ne__(self, other):
        if not isinstance(other, _HashIdWrapper):
            return NotImplemented
        return self.wrapped is not other.wrapped

    def __hash__(self):
        return id(self.wrapped)


@functools.lru_cache
def _cumsum(seq, initial_zero):
    if isinstance(seq, _HashIdWrapper):
        seq = seq.wrapped
    if initial_zero:
        return tuple(toolz.accumulate(add, seq, 0))
    else:
        return tuple(toolz.accumulate(add, seq))


@functools.lru_cache
def _max(seq):
    if isinstance(seq, _HashIdWrapper):
        seq = seq.wrapped
    return max(seq)


def cached_max(seq):
    """Compute max with caching.

    Caching is by the identity of `seq` rather than the value. It is thus
    important that `seq` is a tuple of immutable objects, and this function
    is intended for use where `seq` is a value that will persist (generally
    block sizes).

    Parameters
    ----------
    seq : tuple
        Values to reduce

    Returns
    -------
    The maximum value in ``seq``
    """
    assert isinstance(seq, tuple)
    # Look up by identity first, to avoid a linear-time __hash__
    # if we've seen this tuple object before.
    result = _max(_HashIdWrapper(seq))
    return result


def cached_cumsum(seq, initial_zero=False):
    """Compute :meth:`toolz.accumulate` with caching.

    Caching is by the identity of `seq` rather than the value. It is thus
    important that `seq` is a tuple of immutable objects, and this function
    is intended for use where `seq` is a value that will persist (generally
    block sizes).

    Parameters
    ----------
    seq : tuple
        Values to cumulatively sum.
    initial_zero : bool, optional
        If true, the return value is prefixed with a zero.

    Returns
    -------
    tuple
    """
    if isinstance(seq, tuple):
        # Look up by identity first, to avoid a linear-time __hash__
        # if we've seen this tuple object before.
        result = _cumsum(_HashIdWrapper(seq), initial_zero)
    else:
        # Construct a temporary tuple, and look up by value.
        result = _cumsum(tuple(seq), initial_zero)
    return result
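

# A minimal usage sketch (hypothetical): repeated calls on the same tuple
# object hit the identity-keyed lru_cache instead of re-hashing the values.
#
#     chunks = (2, 3, 4)
#     cached_cumsum(chunks)                     # -> (2, 5, 9)
#     cached_cumsum(chunks, initial_zero=True)  # -> (0, 2, 5, 9)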


def show_versions() -> None:
    """Provide version information for bug reports."""

    from json import dumps
    from platform import uname
    from sys import stdout, version_info

    from dask._compatibility import importlib_metadata

    try:
        from distributed import __version__ as distributed_version
    except ImportError:
        distributed_version = None

    from dask import __version__ as dask_version

    deps = [
        "numpy",
        "pandas",
        "cloudpickle",
        "fsspec",
        "bokeh",
        "pyarrow",
        "zarr",
    ]

    result: dict[str, str | None] = {
        # note: only major, minor, micro are extracted
        "Python": ".".join([str(i) for i in version_info[:3]]),
        "Platform": uname().system,
        "dask": dask_version,
        "distributed": distributed_version,
    }

    for modname in deps:
        try:
            result[modname] = importlib_metadata.version(modname)
        except importlib_metadata.PackageNotFoundError:
            result[modname] = None

    stdout.writelines(dumps(result, indent=2))


def maybe_pluralize(count, noun, plural_form=None):
    """Pluralize a count-noun string pattern when necessary"""
    if count == 1:
        return f"{count} {noun}"
    else:
        return f"{count} {plural_form or noun + 's'}"


def is_namedtuple_instance(obj: Any) -> bool:
    """Returns True if obj is an instance of a namedtuple.

    Note: This function checks for the existence of the methods and
    attributes that make up the namedtuple API, so it will return True
    IFF obj's type implements that API.
    """
    return (
        isinstance(obj, tuple)
        and hasattr(obj, "_make")
        and hasattr(obj, "_asdict")
        and hasattr(obj, "_replace")
        and hasattr(obj, "_fields")
        and hasattr(obj, "_field_defaults")
    )


def get_default_shuffle_method() -> str:
    if d := config.get("dataframe.shuffle.method", None):
        return d
    try:
        from distributed import default_client

        default_client()
    except (ImportError, ValueError):
        return "disk"

    try:
        from distributed.shuffle import check_minimal_arrow_version

        check_minimal_arrow_version()
    except ModuleNotFoundError:
        return "tasks"
    return "p2p"


def get_meta_library(like):
    if hasattr(like, "_meta"):
        like = like._meta

    return import_module(typename(like).partition(".")[0])


class shorten_traceback:
    """Context manager that removes irrelevant stack elements from traceback.

    * omits frames from modules that match `admin.traceback.shorten`
    * always keeps the first and last frame.
    """

    __slots__ = ()

    def __enter__(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        if exc_val and exc_tb:
            exc_val.__traceback__ = self.shorten(exc_tb)

    @staticmethod
    def shorten(exc_tb: types.TracebackType) -> types.TracebackType:
        paths = config.get("admin.traceback.shorten")
        if not paths:
            return exc_tb

        exp = re.compile(".*(" + "|".join(paths) + ")")
        curr: types.TracebackType | None = exc_tb
        prev: types.TracebackType | None = None

        while curr:
            if prev is None:
                prev = curr  # first frame
            elif not curr.tb_next:
                # always keep last frame
                prev.tb_next = curr
                prev = prev.tb_next
            elif not exp.match(curr.tb_frame.f_code.co_filename):
                # keep if module is not listed in config
                prev.tb_next = curr
                prev = curr
            curr = curr.tb_next

        # Uncomment to remove the first frame, which is something you don't want to keep
        # if it matches the regexes. Requires Python >=3.11.
        # if exc_tb.tb_next and exp.match(exc_tb.tb_frame.f_code.co_filename):
        #     return exc_tb.tb_next

        return exc_tb


def unzip(ls, nout):
    """Unzip a list of lists into ``nout`` outputs."""
    out = list(zip(*ls))
    if not out:
        out = [()] * nout
    return out
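

# A minimal usage sketch (hypothetical):
#
#     unzip([(1, 'a'), (2, 'b')], 2)  # -> [(1, 2), ('a', 'b')]
#     unzip([], 2)                    # -> [(), ()]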


class disable_gc(ContextDecorator):
    """Context manager to disable garbage collection."""

    def __init__(self, collect=False):
        self.collect = collect
        self._gc_enabled = gc.isenabled()

    def __enter__(self):
        gc.disable()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self._gc_enabled:
            gc.enable()
        return False


def is_empty(obj):
    """
    Duck-typed check for “emptiness” of an object.

    Works for standard sequences (lists, tuples, etc.), NumPy arrays,
    and sparse-like objects (e.g., SciPy sparse arrays).

    The function checks:
    1. If the object supports len(), returns True if len(obj) == 0.
    2. If the object has a `.nnz` attribute (number of non-zero elements),
       returns True if `.nnz == 0`.
    3. If the object has a `.shape` attribute, returns True if any
       dimension is zero.
    4. Otherwise, returns False (assumes non-empty).

    Parameters
    ----------
    obj : any
        The object to check for emptiness.

    Returns
    -------
    bool
        True if the object is considered empty, False otherwise.
    """
    # Check standard sequences
    with contextlib.suppress(Exception):
        return len(obj) == 0

    # Sparse-like objects
    with contextlib.suppress(Exception):
        return obj.nnz == 0

    with contextlib.suppress(Exception):
        return 0 in obj.shape

    # Fallback: assume non-empty
    return False
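

# A minimal usage sketch (hypothetical): each branch above handles a
# different duck type.
#
#     is_empty([])                # -> True, via len()
#     is_empty((0,))              # -> False, via len()
#
#     import numpy as np
#     is_empty(np.empty((0, 3)))  # -> True, len() is 0 along the first axis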