Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/utils.py: 30%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next / previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1from __future__ import annotations
3import codecs
4import functools
5import gc
6import inspect
7import os
8import re
9import shutil
10import sys
11import tempfile
12import types
13import uuid
14import warnings
15from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Set
16from contextlib import ContextDecorator, contextmanager, nullcontext, suppress
17from datetime import datetime, timedelta
18from errno import ENOENT
19from functools import wraps
20from importlib import import_module
21from numbers import Integral, Number
22from operator import add
23from threading import Lock
24from typing import Any, ClassVar, Literal, TypeVar, cast, overload
25from weakref import WeakValueDictionary
27import tlz as toolz
29from dask import config
30from dask.typing import no_default
# Generic type variables shared by the helpers in this module.
K, V, T = TypeVar("K"), TypeVar("V"), TypeVar("T")

# Used in decorators to preserve the signature of the function they decorate.
# See https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)

# Never report plain ASCII as the system encoding; fall back to UTF-8 instead.
system_encoding = sys.getdefaultencoding()
system_encoding = "utf-8" if system_encoding == "ascii" else system_encoding
def apply(func, args, kwargs=None):
    """Call ``func`` with positional ``args`` and optional keyword ``kwargs``.

    This is the same as ``func(*args, **kwargs)``. It exists mainly so that
    keyword arguments can be injected into a low level Dask task graph; most
    Dask users never need it.

    Parameters
    ----------
    func : callable
        The function to call.
    args : tuple
        Positional arguments for ``func`` (eg: ``(arg_1, arg_2, arg_3)``).
    kwargs : dict, optional
        Keyword arguments for ``func`` (eg: ``{"kwarg_1": value}``). When
        empty or ``None``, ``func`` is called with positional arguments only.

    Examples
    --------
    >>> from dask.utils import apply
    >>> def add(number, second_number=5):
    ...     return number + second_number
    ...
    >>> apply(add, (10,), {"second_number": 2})  # equivalent to add(*args, **kwargs)
    12

    >>> task = apply(add, (10,), {"second_number": 2})
    >>> dsk = {'task-name': task}  # adds the task to a low level Dask task graph
    """
    return func(*args, **kwargs) if kwargs else func(*args)
def _deprecated(
    *,
    version: str | None = None,
    after_version: str | None = None,
    message: str | None = None,
    use_instead: str | None = None,
    category: type[Warning] = FutureWarning,
):
    """Build a decorator that emits a deprecation warning on every call.

    Parameters
    ----------
    version : str, optional
        Version of Dask in which the function was deprecated. If specified, the version
        will be included in the default warning message. This should no longer be used
        after the introduction of automated versioning system.
    after_version : str, optional
        Version of Dask after which the function was deprecated. If specified, the
        version will be included in the default warning message. Takes precedence
        over ``version`` when both are given.
    message : str, optional
        Custom warning message to raise. When given, it replaces the default
        message entirely (``use_instead`` and the versions are then ignored).
    use_instead : str, optional
        Name of function to use in place of the deprecated function.
        If specified, this will be included in the default warning
        message.
    category : type[Warning], optional
        Type of warning to raise. Defaults to ``FutureWarning``.

    Examples
    --------

    >>> from dask.utils import _deprecated
    >>> @_deprecated(after_version="X.Y.Z", use_instead="bar")
    ... def foo():
    ...     return "baz"
    """

    def _warning_text(func) -> str:
        # An explicit message wins; otherwise assemble one from the name,
        # the deprecation version and the suggested replacement.
        if message is not None:
            return message
        if after_version is not None:
            status = f"was deprecated after version {after_version} "
        elif version is not None:
            status = f"was deprecated in version {version} "
        else:
            status = "is deprecated "
        text = f"{func.__name__} " + status + "and will be removed in a future release."
        if use_instead is not None:
            text += f" Please use {use_instead} instead."
        return text

    def decorator(func):
        msg = _warning_text(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # stacklevel=2 attributes the warning to the caller of ``func``
            warnings.warn(msg, category=category, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator
def _deprecated_kwarg(
    old_arg_name: str,
    new_arg_name: str | None = None,
    mapping: Mapping[Any, Any] | Callable[[Any], Any] | None = None,
    stacklevel: int = 2,
    comment: str | None = None,
) -> Callable[[F], F]:
    """
    Decorator to deprecate a keyword argument of a function.

    Parameters
    ----------
    old_arg_name : str
        Name of argument in function to deprecate
    new_arg_name : str, optional
        Name of preferred argument in function. Omit to warn that
        ``old_arg_name`` keyword is deprecated.
    mapping : dict or callable, optional
        If mapping is present, use it to translate old arguments to
        new arguments. A callable must do its own value checking;
        values not found in a dict will be forwarded unchanged.
    stacklevel : int, optional
        ``stacklevel`` passed to ``warnings.warn``.
    comment : str, optional
        Additional message to deprecation message. Useful to pass
        on suggestions with the deprecation warning. Appended on a new
        line only when non-empty.

    Raises
    ------
    TypeError
        At decoration time if ``mapping`` is neither dict-like nor callable;
        at call time if both the old and the new keyword are passed.

    Examples
    --------
    The following deprecates 'cols', using 'columns' instead

    >>> @_deprecated_kwarg(old_arg_name='cols', new_arg_name='columns')
    ... def f(columns=''):
    ...     print(columns)
    ...
    >>> f(columns='should work ok')
    should work ok

    >>> f(cols='should raise warning')  # doctest: +SKIP
    FutureWarning: the 'cols' keyword is deprecated, use 'columns' instead.
    should raise warning

    >>> f(cols='should error', columns="can't pass both")  # doctest: +SKIP
    TypeError: Can only specify 'cols' or 'columns', not both.

    >>> @_deprecated_kwarg('old', 'new', {'yes': True, 'no': False})
    ... def f(new=False):
    ...     print('yes!' if new else 'no!')
    ...
    >>> f(old='yes')  # doctest: +SKIP
    FutureWarning: the old='yes' keyword is deprecated, use new=True instead.
    yes!

    To raise a warning that a keyword will be removed entirely in the future

    >>> @_deprecated_kwarg(old_arg_name='cols', new_arg_name=None)
    ... def f(cols='', another_param=''):
    ...     print(cols)
    ...
    >>> f(cols='should raise warning')  # doctest: +SKIP
    FutureWarning: the 'cols' keyword is deprecated and will be removed in a
    future version. Please take steps to stop the use of 'cols'
    should raise warning
    >>> f(another_param='should not raise warning')  # doctest: +SKIP
    should not raise warning
    """
    if mapping is not None and not hasattr(mapping, "get") and not callable(mapping):
        raise TypeError(
            "mapping from old to new argument values must be dict or callable!"
        )

    # Only append the comment when one was supplied. The previous form
    # ``f"\n{comment}" or ""`` was always truthy and appended "\nNone".
    comment_ = f"\n{comment}" if comment else ""

    def _deprecated_kwarg(func: F) -> F:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # ``no_default`` distinguishes "not passed" from an explicit None.
            old_arg_value = kwargs.pop(old_arg_name, no_default)

            if old_arg_value is not no_default:
                if new_arg_name is None:
                    # Keyword is going away entirely: warn, then forward it as-is.
                    msg = (
                        f"the {repr(old_arg_name)} keyword is deprecated and "
                        "will be removed in a future version. Please take "
                        f"steps to stop the use of {repr(old_arg_name)}"
                    ) + comment_
                    warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
                    kwargs[old_arg_name] = old_arg_value
                    return func(*args, **kwargs)

                elif mapping is not None:
                    if callable(mapping):
                        new_arg_value = mapping(old_arg_value)
                    else:
                        new_arg_value = mapping.get(old_arg_value, old_arg_value)
                    msg = (
                        f"the {old_arg_name}={repr(old_arg_value)} keyword is "
                        "deprecated, use "
                        f"{new_arg_name}={repr(new_arg_value)} instead."
                    )
                else:
                    new_arg_value = old_arg_value
                    msg = (
                        f"the {repr(old_arg_name)} keyword is deprecated, "
                        f"use {repr(new_arg_name)} instead."
                    )

                warnings.warn(msg + comment_, FutureWarning, stacklevel=stacklevel)
                if kwargs.get(new_arg_name) is not None:
                    msg = (
                        f"Can only specify {repr(old_arg_name)} "
                        f"or {repr(new_arg_name)}, not both."
                    )
                    raise TypeError(msg)
                kwargs[new_arg_name] = new_arg_value
            return func(*args, **kwargs)

        return cast(F, wrapper)

    return _deprecated_kwarg
def deepmap(func, *seqs):
    """Apply ``func`` elementwise through nested lists (or iterators).

    All ``seqs`` are walked in lockstep. Recursion continues while the first
    sequence is a list or an iterator; anything else is a leaf passed to
    ``func``.

    >>> inc = lambda x: x + 1
    >>> deepmap(inc, [[1, 2], [3, 4]])
    [[2, 3], [4, 5]]

    >>> add = lambda x, y: x + y
    >>> deepmap(add, [[1, 2], [3, 4]], [[10, 20], [30, 40]])
    [[11, 22], [33, 44]]
    """
    if not isinstance(seqs[0], (list, Iterator)):
        return func(*seqs)
    return [deepmap(func, *group) for group in zip(*seqs)]
@_deprecated()
def homogeneous_deepmap(func, seq):
    """Apply ``func`` to every leaf of a uniformly nested list.

    The nesting depth is probed by following first elements only, then the
    work is delegated to ``ndeepmap``. Falsy input is returned unchanged.
    """
    if not seq:
        return seq
    depth, probe = 0, seq
    while isinstance(probe, list):
        depth, probe = depth + 1, probe[0]
    return ndeepmap(depth, func, seq)
def ndeepmap(n, func, seq):
    """Call a function on every element within a nested container

    ``n`` is the number of list levels to descend. With ``n <= 0`` the
    function is applied once: to ``seq[0]`` if ``seq`` is a list, else to
    ``seq`` itself.

    >>> def inc(x):
    ...     return x + 1
    >>> L = [[1, 2], [3, 4, 5]]
    >>> ndeepmap(2, inc, L)
    [[2, 3], [4, 5, 6]]
    """
    if n > 1:
        return [ndeepmap(n - 1, func, sub) for sub in seq]
    if n == 1:
        return [func(element) for element in seq]
    return func(seq[0]) if isinstance(seq, list) else func(seq)
def import_required(mod_name, error_msg):
    """Import ``mod_name`` or fail with a friendly message.

    Raises
    ------
    RuntimeError
        With ``error_msg`` as its message, chained from the original
        ``ImportError``, when the module cannot be imported.
    """
    try:
        module = import_module(mod_name)
    except ImportError as exc:
        raise RuntimeError(error_msg) from exc
    return module
@contextmanager
def tmpfile(extension="", dir=None):
    """
    Context manager yielding a unique temporary path with the given extension.

    The path is reserved with ``tempfile.mkstemp`` and then deleted, so the
    caller receives a name that does not yet exist on disk. On exit, whatever
    was created there (file or directory tree) is removed on a best-effort
    basis.

    Parameters
    ----------
    extension : str
        The extension of the temporary file to be created; a leading ``.``
        is optional.
    dir : str
        If ``dir`` is not None, the file will be created in that directory; otherwise,
        Python's default temporary directory is used.

    Returns
    -------
    out : str
        Path to the temporary file

    See Also
    --------
    NamedTemporaryFile : Built-in alternative for creating temporary files
    tmp_path : pytest fixture for creating a temporary directory unique to the test invocation

    Notes
    -----
    This context manager is particularly useful on Windows for opening temporary files multiple times.
    """
    suffix = extension.lstrip(".")
    suffix = "." + suffix if suffix else suffix
    fd, path = tempfile.mkstemp(suffix, dir=dir)
    os.close(fd)
    os.remove(path)

    try:
        yield path
    finally:
        if os.path.exists(path):
            # Removal can fail (e.g. the file is still open on Windows).
            with suppress(OSError):
                remover = shutil.rmtree if os.path.isdir(path) else os.remove
                remover(path)
@contextmanager
def tmpdir(dir=None):
    """
    Context manager yielding a freshly created, unique temporary directory.

    The directory (or a file that replaced it) is removed on exit; removal
    errors are ignored.

    Parameters
    ----------
    dir : str
        If ``dir`` is not None, the directory will be created in that directory; otherwise,
        Python's default temporary directory is used.

    Returns
    -------
    out : str
        Path to the temporary directory

    Notes
    -----
    This context manager is particularly useful on Windows for opening temporary directories multiple times.
    """
    path = tempfile.mkdtemp(dir=dir)

    try:
        yield path
    finally:
        if os.path.exists(path):
            is_directory = os.path.isdir(path)
            with suppress(OSError):
                if is_directory:
                    shutil.rmtree(path)
                else:
                    os.remove(path)
@contextmanager
def filetext(text, extension="", open=open, mode="w"):
    """Write ``text`` to a temporary file and yield its path.

    ``open`` may be any opener accepting ``(path, mode=...)`` (e.g. a
    compression library's ``open``). The file is closed before the path is
    yielded and removed afterwards by ``tmpfile``.
    """
    with tmpfile(extension=extension) as path:
        handle = open(path, mode=mode)
        try:
            handle.write(text)
        finally:
            # Some openers return objects without ``close``; tolerate that.
            with suppress(AttributeError):
                handle.close()

        yield path
@contextmanager
def changed_cwd(new_cwd):
    """Temporarily switch the working directory to ``new_cwd``.

    The previous working directory is restored on exit, including when the
    body raises.
    """
    previous = os.getcwd()
    os.chdir(new_cwd)
    try:
        yield
    finally:
        os.chdir(previous)
@contextmanager
def tmp_cwd(dir=None):
    """Run the body inside a fresh temporary directory and yield its path.

    Combines ``tmpdir`` (creation/cleanup) with ``changed_cwd`` (switching
    the working directory and restoring it).
    """
    with tmpdir(dir) as path, changed_cwd(path):
        yield path
443class IndexCallable:
444 """Provide getitem syntax for functions
446 >>> def inc(x):
447 ... return x + 1
449 >>> I = IndexCallable(inc)
450 >>> I[3]
451 4
452 """
454 __slots__ = ("fn",)
456 def __init__(self, fn):
457 self.fn = fn
459 def __getitem__(self, key):
460 return self.fn(key)
@contextmanager
def filetexts(d, open=open, mode="t", use_tmpdir=True):
    """Dumps a number of textfiles to disk

    Parameters
    ----------
    d : dict
        a mapping from filename to text like {'a.csv': '1,1\n2,2'}
    open : callable, optional
        Opener used to create each file, called as ``open(filename, "w" + mode)``.
    mode : str, optional
        Mode suffix appended to ``"w"`` (``"t"`` for text, ``"b"`` for bytes).
    use_tmpdir : bool, optional
        Whether to switch into a fresh temporary working directory first.

    Yields
    ------
    list
        The filenames in ``d``.

    The written files are removed on exit even if writing them or the body
    of the ``with`` statement raises.

    Since this is meant for use in tests, this context manager will
    automatically switch to a temporary current directory, to avoid
    race conditions when running tests in parallel.
    """
    with tmp_cwd() if use_tmpdir else nullcontext():
        try:
            for filename, text in d.items():
                # Parent directories may already exist, or ``dirname`` may be
                # empty for bare filenames; both raise OSError, which is fine.
                with suppress(OSError):
                    os.makedirs(os.path.dirname(filename))
                f = open(filename, "w" + mode)
                try:
                    f.write(text)
                finally:
                    # Some openers return objects without ``close``.
                    with suppress(AttributeError):
                        f.close()

            yield list(d)
        finally:
            # Previously this only ran on a clean exit, leaking files when the
            # body raised (notably with ``use_tmpdir=False``).
            for filename in d:
                if os.path.exists(filename):
                    with suppress(OSError):
                        os.remove(filename)
def concrete(seq):
    """Make nested iterators concrete lists

    Iterators are materialised and tuples/lists are rebuilt as lists,
    recursively; any other object is returned as-is.

    >>> data = [[1, 2], [3, 4]]
    >>> seq = iter(map(iter, data))
    >>> concrete(seq)
    [[1, 2], [3, 4]]
    """
    if isinstance(seq, Iterator):
        seq = list(seq)
    if not isinstance(seq, (tuple, list)):
        return seq
    return [concrete(item) for item in seq]
def pseudorandom(n: int, p, random_state=None):
    """Pseudorandom array of integer indexes

    Draws ``n`` bucket indices where bucket ``i`` is chosen with probability
    ``p[i]``.

    Parameters
    ----------
    n : int
        Number of indices to draw.
    p : sequence of float
        Bucket probabilities; must sum to (approximately) 1 and have fewer
        than 256 entries.
    random_state : int or np.random.RandomState, optional
        Seed or generator; non-``RandomState`` values seed a new one.

    Returns
    -------
    numpy.ndarray
        ``int8`` indices for up to 128 buckets, ``int16`` beyond that.

    >>> pseudorandom(5, [0.5, 0.5], random_state=123)
    array([1, 0, 0, 1, 1], dtype=int8)

    >>> pseudorandom(10, [0.5, 0.2, 0.2, 0.1], random_state=5)
    array([0, 2, 0, 3, 0, 1, 2, 1, 0, 0], dtype=int8)
    """
    import numpy as np

    p = list(p)
    cp = np.cumsum([0] + p)
    assert np.allclose(1, cp[-1])
    assert len(p) < 256

    if not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(random_state)

    x = random_state.random_sample(n)
    # int8 only holds indices up to 127; widen when more buckets are requested
    # instead of overflowing.
    dtype = "i1" if len(p) <= 128 else "i2"
    # cp[-1] is only approximately 1, so samples in [cp[-1], 1) match no
    # interval below. Pre-filling with the last bucket replaces the old
    # ``np.empty`` which left such entries uninitialised.
    out = np.full(n, len(p) - 1, dtype=dtype)
    for i, (low, high) in enumerate(zip(cp[:-1], cp[1:])):
        out[(x >= low) & (x < high)] = i
    return out
def random_state_data(n: int, random_state=None) -> list:
    """Return a list of arrays that can initialize
    ``np.random.RandomState``.

    Each array holds 624 little-endian ``uint32`` values drawn from a single
    byte stream.

    Parameters
    ----------
    n : int
        Number of arrays to return.
    random_state : int or np.random.RandomState, optional
        If an int, is used to seed a new ``RandomState``. Any object exposing
        ``normal``, ``beta``, ``bytes`` and ``uniform`` is used as-is.
    """
    import numpy as np

    required = ("normal", "beta", "bytes", "uniform")
    if any(not hasattr(random_state, name) for name in required):
        random_state = np.random.RandomState(random_state)

    raw = random_state.bytes(624 * n * 4)  # `n * 624` 32-bit integers
    words = np.frombuffer(raw, dtype="<u4").reshape((n, -1))
    arrays = list(words)
    assert len(arrays) == n
    return arrays
def is_integer(i) -> bool:
    """Return True for integral numbers, including floats with no fraction.

    >>> is_integer(6)
    True
    >>> is_integer(42.0)
    True
    >>> is_integer('abc')
    False
    """
    if isinstance(i, Integral):
        return True
    return isinstance(i, float) and i.is_integer()
# Builtins that ``takes_multiple_arguments`` reports as single-argument,
# bypassing signature introspection (which fails or misleads for many C
# builtins).
ONE_ARITY_BUILTINS = {
    # type constructors
    bool, bytearray, bytes, complex, dict, float, frozenset, int,
    list, memoryview, range, set, slice, str, tuple, type,
    # descriptor decorators
    classmethod, staticmethod,
    # numeric and character helpers
    abs, chr, hex, oct, ord, round,
    # iteration and aggregation
    all, any, enumerate, iter, len, max, min, next, reversed, sorted, sum, zip,
    # introspection and representation
    ascii, callable, dir, format, hash, id, repr, vars,
    # miscellaneous
    eval, open,
}
# Builtins that ``takes_multiple_arguments`` reports as multi-argument.
MULTI_ARITY_BUILTINS = {
    # attribute access
    delattr, getattr, hasattr, setattr,
    # type checks
    isinstance, issubclass,
    # arithmetic
    divmod, pow,
    # higher-order
    filter, map,
    # code objects
    compile,
}
def getargspec(func):
    """Version of ``inspect.getfullargspec`` that sees through partials and wrappers.

    ``functools.partial`` objects report the spec of the underlying function,
    the full ``__wrapped__`` chain left by ``functools.wraps`` is followed, and
    classes report the spec of their ``__init__``.
    """
    if isinstance(func, functools.partial):
        return getargspec(func.func)

    # Follow every level of ``__wrapped__`` (not just one), so that functions
    # with stacked decorators report the innermost signature rather than an
    # intermediate wrapper's ``(*args, **kwargs)``.
    unwrapped = inspect.unwrap(func)
    if unwrapped is not func:
        # The innermost object may itself be a partial; re-dispatch.
        return getargspec(unwrapped)

    if isinstance(func, type):
        return inspect.getfullargspec(func.__init__)
    return inspect.getfullargspec(func)
def takes_multiple_arguments(func, varargs=True):
    """Does this function take multiple arguments?

    Builtins listed in ``ONE_ARITY_BUILTINS``/``MULTI_ARITY_BUILTINS`` are
    answered directly; everything else is introspected via ``getargspec``.
    Callables that cannot be introspected are reported as single-argument.

    Parameters
    ----------
    func : callable
    varargs : bool, optional
        Whether a ``*args`` parameter counts as taking multiple arguments.

    >>> def f(x, y): pass
    >>> takes_multiple_arguments(f)
    True

    >>> def f(x): pass
    >>> takes_multiple_arguments(f)
    False

    >>> def f(x, y=None): pass
    >>> takes_multiple_arguments(f)
    False

    >>> def f(*args): pass
    >>> takes_multiple_arguments(f)
    True

    >>> class Thing:
    ...     def __init__(self, a): pass
    >>> takes_multiple_arguments(Thing)
    False

    """
    try:
        if func in ONE_ARITY_BUILTINS:
            return False
        if func in MULTI_ARITY_BUILTINS:
            return True
    except TypeError:
        # Set membership hashes ``func``. Unhashable callables (e.g. objects
        # defining ``__eq__`` without ``__hash__``) cannot be one of these
        # builtins, so fall through to introspection instead of crashing.
        pass

    try:
        spec = getargspec(func)
    except Exception:
        return False

    try:
        # For classes, ``self`` in ``__init__`` is not a caller-supplied argument.
        is_constructor = spec.args[0] == "self" and isinstance(func, type)
    except Exception:
        is_constructor = False

    if varargs and spec.varargs:
        return True

    ndefaults = 0 if spec.defaults is None else len(spec.defaults)
    return len(spec.args) - ndefaults - is_constructor > 1
def get_named_args(func) -> list[str]:
    """Get all non ``*args/**kwargs`` arguments for a function

    Positional-only, positional-or-keyword and keyword-only parameter names
    are returned in signature order.
    """
    named_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.KEYWORD_ONLY,
    )
    parameters = inspect.signature(func).parameters
    return [name for name, param in parameters.items() if param.kind in named_kinds]
class Dispatch:
    """Simple single dispatch."""

    def __init__(self, name=None):
        # Exact type -> implementation (also caches resolved subclasses).
        self._lookup = {}
        # Top-level module name -> registration function, run on first need.
        self._lazy = {}
        if name:
            self.__name__ = name

    def register(self, type, func=None):
        """Register dispatch of `func` on arguments of type `type`

        ``type`` may be a tuple of types. Usable directly or as a decorator.
        """

        def wrapper(func):
            if isinstance(type, tuple):
                for t in type:
                    self.register(t, func)
            else:
                self._lookup[type] = func
            return func

        return wrapper(func) if func is not None else wrapper

    def register_lazy(self, toplevel, func=None):
        """
        Register a registration function which will be called if the
        *toplevel* module (e.g. 'pandas') is ever loaded.
        """

        def wrapper(func):
            self._lazy[toplevel] = func
            return func

        return wrapper(func) if func is not None else wrapper

    def dispatch(self, cls):
        """Return the function implementation for the given ``cls``

        Walks ``cls.__mro__``, running any pending lazy registration for the
        top-level module of each class, and caches the result.

        Raises
        ------
        TypeError
            If no implementation is registered for ``cls`` or its bases.
        """
        lk = self._lookup
        if cls in lk:
            return lk[cls]
        for cls2 in cls.__mro__:
            # Is a lazy registration function present?
            try:
                toplevel, _, _ = cls2.__module__.partition(".")
            except Exception:
                continue
            try:
                register = self._lazy[toplevel]
            except KeyError:
                pass
            else:
                register()
                self._lazy.pop(toplevel, None)
                meth = self.dispatch(cls)  # recurse
                lk[cls] = meth
                lk[cls2] = meth
                return meth
            try:
                impl = lk[cls2]
            except KeyError:
                pass
            else:
                if cls is not cls2:
                    # Cache lookup
                    lk[cls] = impl
                return impl
        raise TypeError(f"No dispatch for {cls}")

    def __call__(self, arg, *args, **kwargs):
        """
        Call the corresponding method based on type of argument.
        """
        meth = self.dispatch(type(arg))
        return meth(arg, *args, **kwargs)

    @property
    def __doc__(self):
        try:
            func = self.dispatch(object)
            return func.__doc__
        except TypeError:
            # ``__name__`` is only set when a name was passed to __init__;
            # fall back to the class name instead of raising AttributeError.
            name = getattr(self, "__name__", type(self).__name__)
            return f"Single Dispatch for {name}"
def ensure_not_exists(filename) -> None:
    """
    Remove ``filename`` if present; a missing file is not an error.

    Any other ``OSError`` from the removal (e.g. the path is a directory or
    permission is denied) propagates.
    """
    try:
        os.unlink(filename)
    except OSError as err:
        if err.errno == ENOENT:
            return
        raise
def _skip_doctest(line):
    """Add a ``+SKIP`` doctest directive to a single ``>>>`` line.

    Bare prompts and comment-only examples (which NumPy docstrings contain)
    are left alone, as are lines already marked ``+SKIP``. An existing
    ``# doctest:`` directive is extended rather than duplicated.
    """
    stripped = line.strip()
    if stripped == ">>>" or stripped.startswith(">>> #"):
        return line
    if ">>>" not in stripped or "+SKIP" in stripped:
        return line
    suffix = ", +SKIP" if "# doctest:" in line else " # doctest: +SKIP"
    return line + suffix
def skip_doctest(doc):
    """Mark every runnable doctest line in ``doc`` with ``+SKIP``.

    ``None`` becomes an empty string.
    """
    if doc is None:
        return ""
    return "\n".join(map(_skip_doctest, doc.split("\n")))
def extra_titles(doc):
    """Prefix repeated numpydoc section titles with ``Extra ``.

    A title is any line followed by a non-empty line made only of dashes.
    The first occurrence of each title is kept; later duplicates are renamed
    and their underline lengthened to match.
    """
    lines = doc.split("\n")
    # Collect headings from the unmodified text first, in line order.
    headings = [
        (idx, lines[idx].strip())
        for idx in range(len(lines) - 1)
        if lines[idx + 1].strip() and set(lines[idx + 1].strip()) == {"-"}
    ]

    seen = set()
    for idx, title in headings:
        if title not in seen:
            seen.add(title)
            continue
        renamed = "Extra " + title
        lines[idx] = lines[idx].replace(title, renamed)
        lines[idx + 1] = lines[idx + 1].replace("-" * len(title), "-" * len(renamed))

    return "\n".join(lines)
def ignore_warning(doc, cls, name, extra="", skipblocks=0, inconsistencies=None):
    """Expand docstring by adding disclaimer and extra text

    The disclaimer is inserted after the first paragraph of ``doc`` (or
    after the block reached by skipping ``skipblocks`` further paragraphs),
    indented like the text that follows it. Docstrings without a blank line
    are returned unchanged.

    Parameters
    ----------
    doc : str
        Docstring being copied.
    cls : class or module-like object
        Where the docstring came from; used in the disclaimer.
    name : str
        Name of the copied attribute.
    extra : str, optional
        Additional text placed after the disclaimer.
    skipblocks : int, optional
        Number of further paragraphs to skip before inserting. Each skip
        replaces the kept head with the skipped paragraph (existing
        behaviour). NOTE(review): confirm that dropping those blocks is
        intended. Skipping stops early if no paragraph remains.
    inconsistencies : str, optional
        Known differences, listed after the disclaimer.
    """
    if inspect.isclass(cls):
        l1 = f"This docstring was copied from {cls.__module__}.{cls.__name__}.{name}.\n\n"
    else:
        l1 = f"This docstring was copied from {cls.__name__}.{name}.\n\n"
    l2 = "Some inconsistencies with the Dask version may exist."

    i = doc.find("\n\n")
    if i != -1:
        # Insert our warning
        head = doc[: i + 2]
        tail = doc[i + 2 :]
        while skipblocks > 0:
            i = tail.find("\n\n")
            if i == -1:
                # No further paragraph; previously this sliced with -1 and
                # chopped single characters off the remaining text.
                break
            head = tail[: i + 2]
            tail = tail[i + 2 :]
            skipblocks -= 1
        # Indentation of next line
        indent = re.match(r"\s*", tail).group(0)
        # Insert the warning, indented, with a blank line before and after
        if extra:
            more = [indent, extra.rstrip("\n") + "\n\n"]
        else:
            more = []
        bits = [head, indent, l1, indent, l2, "\n\n"]
        if inconsistencies is not None:
            # Every inserted line carries the same indent so that docstring
            # dedenting (inspect.cleandoc) is not broken by column-0 lines.
            l3 = f"Known inconsistencies: \n{indent} {inconsistencies}"
            bits += [indent, l3, "\n\n"]
        doc = "".join(bits + more + [tail])

    return doc
def unsupported_arguments(doc, args):
    """Mark unsupported arguments with a disclaimer

    For each name in ``args``, the single parameter line matching
    ``<name> :`` (optionally indented) gets `` (Not supported in Dask)``
    appended. Names matching zero or several lines are left untouched.
    """
    lines = doc.split("\n")
    for arg in args:
        # Escape the name: argument names such as ``*args`` or ones containing
        # regex metacharacters must be matched literally.
        pattern = re.compile(r"^\s*" + re.escape(arg) + " ?:")
        subset = [(i, line) for i, line in enumerate(lines) if pattern.match(line)]
        if len(subset) == 1:
            [(i, line)] = subset
            lines[i] = line + " (Not supported in Dask)"
    return "\n".join(lines)
def _derived_from(
    cls, method, ua_args=None, extra="", skipblocks=0, inconsistencies=None
):
    """Helper function for derived_from to ease testing

    Build the docstring for ``method`` from the same-named attribute of
    ``cls``. The original docstring is copied and the "copied from"
    disclaimer (plus ``extra`` and ``inconsistencies``) is inserted via
    ``ignore_warning``. Arguments missing from ``method`` (plus ``ua_args``)
    are tagged via ``unsupported_arguments``. Finally, every doctest line is
    marked ``+SKIP`` and repeated section titles are renamed.

    Parameters
    ----------
    cls : type
        Class whose attribute ``method.__name__`` supplies the docstring.
    method : callable
        The Dask implementation being documented.
    ua_args : list, optional
        Extra argument names to mark as unsupported.
    extra : str, optional
        Text inserted after the disclaimer (or appended if there is no
        original docstring).
    skipblocks : int, optional
        Forwarded to ``ignore_warning``.
    inconsistencies : str, optional
        Forwarded to ``ignore_warning``.

    Returns
    -------
    str
        The assembled docstring.

    Raises
    ------
    AttributeError
        If ``cls`` has no attribute named ``method.__name__``; ``derived_from``
        catches this to substitute a ``NotImplementedError`` stub.
    """
    ua_args = ua_args or []

    # do not use wraps here, as it hides keyword arguments displayed
    # in the doc
    original_method = getattr(cls, method.__name__)

    doc = getattr(original_method, "__doc__", None)

    # Properties and cached properties: introspect the underlying function
    # (needed below for argument names) and take its docstring if the
    # descriptor itself had none.
    if isinstance(original_method, property):
        # some things like SeriesGroupBy.unique are generated.
        original_method = original_method.fget
        if not doc:
            doc = getattr(original_method, "__doc__", None)

    if isinstance(original_method, functools.cached_property):
        original_method = original_method.func
        if not doc:
            doc = getattr(original_method, "__doc__", None)

    if doc is None:
        doc = ""

    # pandas DataFrame/Series sometimes override methods without setting __doc__
    if not doc and cls.__name__ in {"DataFrame", "Series"}:
        for obj in cls.mro():
            obj_method = getattr(obj, method.__name__, None)
            if obj_method is not None and obj_method.__doc__:
                doc = obj_method.__doc__
                break

    # Insert disclaimer that this is a copied docstring
    if doc:
        doc = ignore_warning(
            doc,
            cls,
            method.__name__,
            extra=extra,
            skipblocks=skipblocks,
            inconsistencies=inconsistencies,
        )
    elif extra:
        doc += extra.rstrip("\n") + "\n\n"

    # Mark unsupported arguments
    try:
        method_args = get_named_args(method)
        original_args = get_named_args(original_method)
        not_supported = [m for m in original_args if m not in method_args]
    except ValueError:
        # inspect.signature raises ValueError when no signature is available
        not_supported = []
    if len(ua_args) > 0:
        not_supported.extend(ua_args)
    if len(not_supported) > 0:
        doc = unsupported_arguments(doc, not_supported)

    doc = skip_doctest(doc)
    doc = extra_titles(doc)

    return doc
def derived_from(
    original_klass, version=None, ua_args=None, skipblocks=0, inconsistencies=None
):
    """Decorator that copies a method's docstring from ``original_klass``.

    The resulting docstring has the original's top line, a disclaimer that it
    was auto-derived, the decorated method's own docstring (if any), the body
    of the original docstring, and finally the keywords that exist in the
    original method but not in the Dask version.

    If ``original_klass`` lacks the method (an ``AttributeError`` while
    deriving), the method is replaced by a stub raising
    ``NotImplementedError``.

    Parameters
    ----------
    original_klass: type
        Original class which the method is derived from
    version : str
        Original package version which supports the wrapped method; quoted
        in the stub's error message.
    ua_args : list
        List of keywords which Dask doesn't support. Keywords existing in
        original but not in Dask will automatically be added.
    skipblocks : int
        How many text blocks (paragraphs) to skip from the start of the
        docstring. Useful for cases where the target has extra front-matter.
    inconsistencies: list
        List of known inconsistencies with method whose docstrings are being
        copied.
    """
    ua_args = ua_args or []

    def wrapper(method):
        try:
            own_doc = getattr(method, "__doc__", None) or ""
            method.__doc__ = _derived_from(
                original_klass,
                method,
                ua_args=ua_args,
                extra=own_doc,
                skipblocks=skipblocks,
                inconsistencies=inconsistencies,
            )
            return method

        except AttributeError:
            package = original_klass.__module__.split(".")[0]

            @functools.wraps(method)
            def wrapped(*args, **kwargs):
                message = f"Base package doesn't support '{method.__name__}'."
                if version is not None:
                    message += f" Use {package} {version} or later to use this method."
                raise NotImplementedError(message)

            return wrapped

    return wrapper
def funcname(func) -> str:
    """Get a short (at most 50 characters) name for a callable.

    Unwraps ``functools.partial`` and recognises ``methodcaller``,
    ``toolz.curry``, ``multipledispatch.Dispatcher`` and ``numpy.vectorize``
    objects; lambdas are reported as ``"lambda"``. Objects without a
    ``__name__`` fall back to ``str(func)``.
    """
    # functools.partial (possibly nested)
    while isinstance(func, functools.partial):
        func = func.func

    # methodcaller
    if isinstance(func, methodcaller):
        return func.method[:50]

    module_name = getattr(func, "__module__", None) or ""
    type_name = getattr(type(func), "__name__", None) or ""

    # toolz.curry
    if "toolz" in module_name and type_name == "curry":
        return func.func_name[:50]
    # multipledispatch objects
    if "multipledispatch" in module_name and type_name == "Dispatcher":
        return func.name[:50]
    # numpy.vectorize objects
    if "numpy" in module_name and type_name == "vectorize":
        return ("vectorize_" + funcname(func.pyfunc))[:50]

    # All other callables
    try:
        name = func.__name__
        return "lambda" if name == "<lambda>" else name[:50]
    except AttributeError:
        return str(func)[:50]
def typename(typ: Any, short: bool = False) -> str:
    """
    Return the name of a type

    Non-type arguments are described by their type. Builtins (and types
    without a module) are returned bare; others are qualified by their
    module, or by just the top-level package when ``short`` is true.

    Examples
    --------
    >>> typename(int)
    'int'

    >>> from dask.core import literal
    >>> typename(literal)
    'dask.core.literal'
    >>> typename(literal, short=True)
    'dask.literal'
    """
    if not isinstance(typ, type):
        # Forward ``short``; it was previously dropped for instances.
        return typename(type(typ), short=short)
    try:
        if not typ.__module__ or typ.__module__ == "builtins":
            return typ.__name__
        else:
            if short:
                module, *_ = typ.__module__.split(".")
            else:
                module = typ.__module__
            return module + "." + typ.__name__
    except AttributeError:
        return str(typ)
def ensure_bytes(s) -> bytes:
    """Attempt to turn `s` into bytes.

    Parameters
    ----------
    s : Any
        The object to be converted. Handled correctly:
        * str (encoded with the default UTF-8 codec)
        * bytes (returned unchanged)
        * objects implementing the buffer protocol (memoryview, ndarray, etc.)

    Returns
    -------
    b : bytes

    Raises
    ------
    TypeError
        When `s` cannot be converted

    Examples
    --------
    >>> ensure_bytes('123')
    b'123'
    >>> ensure_bytes(b'123')
    b'123'
    >>> ensure_bytes(bytearray(b'123'))
    b'123'
    """
    if isinstance(s, bytes):
        return s
    if hasattr(s, "encode"):
        return s.encode()
    try:
        converted = bytes(s)
    except Exception as exc:
        raise TypeError(
            f"Object {s} is neither a bytes object nor can be encoded to bytes"
        ) from exc
    return converted
def ensure_unicode(s) -> str:
    """Turn string or bytes to string

    Objects with a ``decode`` method are decoded with it; other bytes-like
    objects go through ``codecs.decode`` (UTF-8).

    Raises
    ------
    TypeError
        When `s` cannot be decoded.

    >>> ensure_unicode('123')
    '123'
    >>> ensure_unicode(b'123')
    '123'
    """
    if isinstance(s, str):
        return s
    if hasattr(s, "decode"):
        return s.decode()
    try:
        decoded = codecs.decode(s)
    except Exception as exc:
        raise TypeError(
            f"Object {s} is neither a str object nor can be decoded to str"
        ) from exc
    return decoded
def digit(n, k, base):
    """Return the ``k``-th digit (0 = least significant) of ``n`` in ``base``.

    >>> digit(1234, 0, 10)
    4
    >>> digit(1234, 1, 10)
    3
    >>> digit(1234, 2, 10)
    2
    >>> digit(1234, 3, 10)
    1
    """
    shifted = n // base**k
    _, remainder = divmod(shifted, base)
    return remainder
def insert(tup, loc, val):
    """Return a tuple copy of ``tup`` with position ``loc`` replaced by ``val``.

    Despite the name, nothing is inserted: the length is unchanged, and
    negative or out-of-range ``loc`` behave as for list assignment.

    >>> insert(('a', 'b', 'c'), 0, 'x')
    ('x', 'b', 'c')
    """
    items = [*tup]
    items[loc] = val
    return tuple(items)
def memory_repr(num):
    """Format a byte count with a binary (1024-based) unit suffix.

    Values of 1024 TB or more are reported in PB; previously the function
    fell off the loop and returned ``None`` for them.

    >>> memory_repr(2048)
    '2.0 KB'
    """
    for x in ["bytes", "KB", "MB", "GB", "TB"]:
        if num < 1024.0:
            return f"{num:3.1f} {x}"
        num /= 1024.0
    return f"{num:3.1f} PB"
def asciitable(columns, rows):
    """Formats an ascii table for given columns and rows.

    Parameters
    ----------
    columns : list
        The column names
    rows : list of tuples
        The rows in the table. Each tuple must be the same length as
        ``columns``. May be empty, in which case only the header is drawn.

    Returns
    -------
    str
        The table, framed by ``+---+`` bars, with cells left-aligned.
    """
    rows = [tuple(str(i) for i in r) for r in rows]
    columns = tuple(str(i) for i in columns)
    # Each column is as wide as its widest cell, header included. Computing
    # this per column (instead of via ``zip(*rows)``) also works for no rows,
    # which previously raised TypeError.
    widths = tuple(
        max([len(c), *(len(r[j]) for r in rows)]) for j, c in enumerate(columns)
    )
    row_template = ("|" + (" %%-%ds |" * len(columns))) % widths
    header = row_template % columns
    bar = "+%s+" % "+".join("-" * (w + 2) for w in widths)
    body = [row_template % r for r in rows]
    return "\n".join([bar, header, bar, *body, bar])
def put_lines(buf, lines):
    """Write ``lines`` to ``buf`` joined by newlines (no trailing newline).

    Non-string items are converted with ``str``.
    """
    if not all(isinstance(x, str) for x in lines):
        lines = list(map(str, lines))
    buf.write("\n".join(lines))
# One shared ``methodcaller`` instance per method name.
_method_cache: dict[str, methodcaller] = {}


class methodcaller:
    """
    Return a callable object that calls the given method on its operand.

    Unlike the builtin `operator.methodcaller`, instances of this class are
    cached and arguments are passed at call time instead of build time.
    """

    __slots__ = ("method",)
    method: str

    @property
    def func(self) -> str:
        # For `funcname` to work
        return self.method

    def __new__(cls, method: str):
        instance = _method_cache.get(method)
        if instance is None:
            instance = object.__new__(cls)
            instance.method = method
            _method_cache[method] = instance
        return instance

    def __call__(self, obj, /, *args, **kwargs):
        bound = getattr(obj, self.method)
        return bound(*args, **kwargs)

    def __reduce__(self):
        # Unpickling goes through __new__, so it returns the cached instance.
        return methodcaller, (self.method,)

    def __str__(self):
        return f"<{type(self).__name__}: {self.method}>"

    __repr__ = __str__
class itemgetter:
    """Variant of operator.itemgetter that supports equality tests

    Two instances are equal when they have the same type and ``index``.
    They hash consistently with that equality, so they can be used in sets
    and as dict keys (provided ``index`` itself is hashable).
    """

    __slots__ = ("index",)

    def __init__(self, index):
        self.index = index

    def __call__(self, x):
        return x[self.index]

    def __reduce__(self):
        return (itemgetter, (self.index,))

    def __eq__(self, other):
        return type(self) is type(other) and self.index == other.index

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None, which made instances
        # unhashable (e.g. ``obj in some_set`` raised TypeError). Hash the
        # same data __eq__ compares; an unhashable ``index`` still raises
        # TypeError.
        return hash((type(self), self.index))
class MethodCache:
    """Namespace whose attributes are :class:`methodcaller` objects.

    Any attribute that normal lookup does not find is turned into
    ``methodcaller(<name>)``, so ``M.count(a, 3)`` calls ``a.count(3)``.

    Examples
    --------
    >>> a = [1, 3, 3]
    >>> M.count(a, 3) == a.count(3)
    True
    """

    def __getattr__(self, item):
        # Only reached when regular attribute lookup fails.
        caller = methodcaller(item)
        return caller

    def __dir__(self):
        # Advertise every method name requested so far.
        return [*_method_cache]


# Shared instance used as ``M.<method_name>``.
M = MethodCache()
class SerializableLock:
    """A Serializable per-process Lock

    This wraps a normal ``threading.Lock`` object and satisfies the same
    interface. However, this lock can also be serialized and sent to different
    processes. It will not block concurrent operations between processes (for
    this you should look at ``multiprocessing.Lock`` or ``locket.lock_file``)
    but will consistently deserialize into the same lock.

    So if we make a lock in one process::

        lock = SerializableLock()

    And then send it over to another process multiple times::

        bytes = pickle.dumps(lock)
        a = pickle.loads(bytes)
        b = pickle.loads(bytes)

    Then the deserialized objects will operate as though they were the same
    lock, and collide as appropriate.

    This is useful for consistently protecting resources on a per-process
    level.

    Locks are shared through a class-level registry keyed by ``token``. The
    registry holds them weakly, so the underlying ``threading.Lock`` is kept
    alive by the objects that reference it.

    The creation of locks is itself not threadsafe.
    """

    # token -> live threading.Lock; values are weak references.
    _locks: ClassVar[WeakValueDictionary[Hashable, Lock]] = WeakValueDictionary()
    token: Hashable
    lock: Lock

    def __init__(self, token: Hashable | None = None):
        """Bind to the shared lock for ``token``, creating it if needed.

        Parameters
        ----------
        token : Hashable, optional
            Identifies the shared lock. Any falsy value (including the
            default ``None``) is replaced with a fresh ``uuid4`` string, so
            each such call gets a new lock.
        """
        self.token = token or str(uuid.uuid4())
        # A single ``get`` returns a strong reference (or None) in one step.
        # Checking ``in`` and then indexing can raise KeyError if the weakly
        # held lock is collected between the two lookups.
        lock = SerializableLock._locks.get(self.token)
        if lock is None:
            lock = Lock()
            SerializableLock._locks[self.token] = lock
        self.lock = lock

    def acquire(self, *args, **kwargs):
        return self.lock.acquire(*args, **kwargs)

    def release(self, *args, **kwargs):
        return self.lock.release(*args, **kwargs)

    def __enter__(self):
        # Return the wrapped lock's value (``True`` for ``threading.Lock``)
        # so ``with lock as acquired:`` behaves like the lock being wrapped.
        return self.lock.__enter__()

    def __exit__(self, *args):
        self.lock.__exit__(*args)

    def locked(self):
        return self.lock.locked()

    def __getstate__(self):
        # Only the token is pickled; unpickling re-binds to the shared lock.
        return self.token

    def __setstate__(self, token):
        self.__init__(token)

    def __str__(self):
        return f"<{self.__class__.__name__}: {self.token}>"

    __repr__ = __str__
def get_scheduler_lock(collection=None, scheduler=None):
    """Return a lock suited to the scheduler that would run ``collection``.

    * the multiprocessing scheduler gets a ``Manager().Lock()`` proxy;
    * the current ``distributed`` client's scheduler gets
      ``distributed.lock.Lock()``;
    * anything else gets a :class:`SerializableLock`.
    """
    from dask import multiprocessing
    from dask.base import get_scheduler

    scheduler_get = get_scheduler(collections=[collection], scheduler=scheduler)

    if scheduler_get == multiprocessing.get:
        return multiprocessing.get_context().Manager().Lock()

    # With a distributed client the lock has to work across processes,
    # which a SerializableLock cannot do.
    try:
        import distributed.lock
        from distributed.worker import get_client

        client = get_client()
    except (ImportError, ValueError):
        return SerializableLock()

    if scheduler_get == client.get:
        return distributed.lock.Lock()
    return SerializableLock()
def ensure_dict(d: Mapping[K, V], *, copy: bool = False) -> dict[K, V]:
    """Convert a generic Mapping into a dict.
    Optimize use case of :class:`~dask.highlevelgraph.HighLevelGraph`.

    Parameters
    ----------
    d : Mapping
        Any mapping. An object with a ``layers`` attribute (such as a
        ``HighLevelGraph``) is flattened by merging each distinct layer.
    copy : bool
        If True, guarantee that the return value is always a shallow copy of d;
        otherwise it may be the input itself.

    Returns
    -------
    dict
    """
    if type(d) is dict:
        # Exact dicts are handed back as-is unless a copy was requested.
        return d.copy() if copy else d

    try:
        layers = d.layers  # type: ignore
    except AttributeError:
        # Ordinary mapping: building a dict already copies it.
        return dict(d)

    # Several names may refer to the same layer object; merge each object
    # once, in order of first appearance.
    distinct_layers = {id(layer): layer for layer in layers.values()}
    merged: dict[K, V] = {}
    for layer in distinct_layers.values():
        merged.update(layer)
    return merged
def ensure_set(s: Set[T], *, copy: bool = False) -> set[T]:
    """Convert a generic Set into a set.

    Parameters
    ----------
    s : Set
        Any set-like object, e.g. a ``frozenset`` or a dict keys view.
    copy : bool
        If True, guarantee that the return value is always a shallow copy of s;
        otherwise it may be the input itself.

    Returns
    -------
    set
    """
    if not copy and type(s) is set:
        return s
    # A non-``set`` input, or a requested copy: build a fresh set.
    return set(s)
class OperatorMethodMixin:
    """A mixin for dynamically implementing operators

    Subclasses implement ``_get_unary_operator`` / ``_get_binary_operator``
    to build the methods; ``_bind_operator`` installs them under the
    matching dunder names.
    """

    __slots__ = ()

    @classmethod
    def _bind_operator(cls, op):
        """bind operator to this class

        Installs ``__<name>__`` for the ``operator`` function ``op`` and,
        except for comparisons and ``getitem``, the reflected ``__r<name>__``.
        """
        name = op.__name__
        if name.endswith("_"):
            # operator.and_ / operator.or_ -> "and" / "or"
            name = name[:-1]
        elif name == "inv":
            name = "invert"

        if name in ("abs", "invert", "neg", "pos"):
            forward = cls._get_unary_operator(op)
        else:
            forward = cls._get_binary_operator(op)
        setattr(cls, f"__{name}__", forward)

        if name not in ("eq", "gt", "ge", "lt", "le", "ne", "getitem"):
            # This also runs for the unary names above, which therefore get
            # an ``__r<name>__`` built by ``_get_binary_operator``.
            setattr(cls, f"__r{name}__", cls._get_binary_operator(op, inv=True))

    @classmethod
    def _get_unary_operator(cls, op):
        """Must return a method used by unary operator"""
        raise NotImplementedError

    @classmethod
    def _get_binary_operator(cls, op, inv=False):
        """Must return a method used by binary operator"""
        raise NotImplementedError
def partial_by_order(*args, **kwargs):
    """Call ``function`` with extra positional arguments spliced in.

    Two keyword arguments are consumed: ``function`` (the callable) and
    ``other``, a sequence of ``(index, value)`` pairs. Each value is inserted
    into the positional arguments at its index, in the given order. All
    remaining keyword arguments are forwarded to ``function``.

    >>> from operator import add
    >>> partial_by_order(5, function=add, other=[(1, 10)])
    15
    """
    func = kwargs.pop("function")
    inserts = kwargs.pop("other")
    positional = [*args]
    for position, value in inserts:
        positional.insert(position, value)
    return func(*positional, **kwargs)
def is_arraylike(x) -> bool:
    """Is this object a numpy array or something similar?

    This function tests specifically for an object that already has
    array attributes (e.g. np.ndarray, dask.array.Array, cupy.ndarray,
    sparse.COO), **NOT** for something that can be coerced into an
    array object (e.g. Python lists and tuples). It is meant for dask
    developers and developers of downstream libraries.

    Note that this function does not correspond with NumPy's
    definition of array_like, which includes any object that can be
    coerced into an array (see definition in the NumPy glossary):
    https://numpy.org/doc/stable/glossary.html

    An object qualifies if it has a tuple ``shape`` with no dask collections
    in it, a ``dtype``, and either implements ``__array_function__`` /
    ``__array_ufunc__`` or is a scipy.sparse-style type.

    Examples
    --------
    >>> import numpy as np
    >>> is_arraylike(np.ones(5))
    True
    >>> is_arraylike(np.ones(()))
    True
    >>> is_arraylike(5)
    False
    >>> is_arraylike('cat')
    False
    """
    from dask.base import is_dask_collection

    speaks_array_protocol = hasattr(x, "__array_function__") or hasattr(
        x, "__array_ufunc__"
    )

    if not hasattr(x, "shape") or not isinstance(x.shape, tuple):
        return False
    if not hasattr(x, "dtype"):
        return False
    if any(is_dask_collection(dim) for dim in x.shape):
        return False
    # scipy.sparse and cupyx.scipy.sparse are accepted even without the array
    # protocols: partial support helps when map_partitions / map_blocks call
    # scikit-learn functions on dask arrays and dataframes.
    # https://github.com/dask/dask/pull/3738
    return speaks_array_protocol or "scipy.sparse" in typename(type(x))
def is_dataframe_like(df) -> bool:
    """Looks like a Pandas DataFrame

    A real ``pandas.DataFrame`` is accepted immediately. Otherwise the class
    must provide ``groupby``, ``head``, ``merge`` and ``mean``, the object
    must have ``dtypes`` and ``columns``, and the class must have neither
    ``name`` nor ``dtype``.
    """
    cls = df.__class__
    # Cheap path for the common case of a genuine pandas DataFrame.
    if cls.__module__ == "pandas.core.frame" and cls.__name__ == "DataFrame":
        return True
    if not all(hasattr(cls, attr) for attr in ("groupby", "head", "merge", "mean")):
        return False
    if not all(hasattr(df, attr) for attr in ("dtypes", "columns")):
        return False
    return not any(hasattr(cls, attr) for attr in ("name", "dtype"))
def is_series_like(s) -> bool:
    """Looks like a Pandas Series

    The class must provide ``groupby``, ``head`` and ``mean``, the object
    must have ``dtype`` and ``name``, and the class name must not contain
    "index" (case-insensitive).
    """
    cls = s.__class__
    if not all(hasattr(cls, attr) for attr in ("groupby", "head", "mean")):
        return False
    if not all(hasattr(s, attr) for attr in ("dtype", "name")):
        return False
    # Index-like types also carry ``name`` and ``dtype``; exclude them by name.
    return "index" not in cls.__name__.lower()
def is_index_like(s) -> bool:
    """Looks like a Pandas Index

    The object must have ``name`` and ``dtype``, and its class name must
    contain "index" (case-insensitive).
    """
    cls_name = s.__class__.__name__
    if not (hasattr(s, "name") and hasattr(s, "dtype")):
        return False
    return "index" in cls_name.lower()
def is_cupy_type(x) -> bool:
    """Return True if the string form of ``type(x)`` mentions ``cupy``."""
    # TODO: avoid explicit reference to CuPy
    type_text = str(type(x))
    return "cupy" in type_text
def natural_sort_key(s: str) -> list[str | int]:
    """
    Sorting `key` function for performing a natural sort on a collection of
    strings

    See https://en.wikipedia.org/wiki/Natural_sort_order

    Parameters
    ----------
    s : str
        A string that is an element of the collection being sorted

    Returns
    -------
    list[str or int]
        The input split into alternating non-digit (str) and digit-run (int)
        parts. The list starts and ends with a (possibly empty) str part, so
        keys of different strings compare str with str and int with int.

    Examples
    --------
    >>> a = ['f0', 'f1', 'f2', 'f8', 'f9', 'f10', 'f11', 'f19', 'f20', 'f21']
    >>> sorted(a)
    ['f0', 'f1', 'f10', 'f11', 'f19', 'f2', 'f20', 'f21', 'f8', 'f9']
    >>> sorted(a, key=natural_sort_key)
    ['f0', 'f1', 'f2', 'f8', 'f9', 'f10', 'f11', 'f19', 'f20', 'f21']
    """
    # ``\d`` matches Unicode decimal digits, exactly what ``str.isdecimal``
    # accepts and ``int`` can parse. ``str.isdigit`` is broader (it accepts
    # e.g. superscript "²"), which made ``int(part)`` raise ValueError.
    return [int(part) if part.isdecimal() else part for part in re.split(r"(\d+)", s)]
def parse_bytes(s: float | str) -> int:
    """Parse byte string to numbers

    ``s`` is a number with an optional unit. Spaces are ignored, the unit is
    case-insensitive, and a bare unit means one of that unit. Numeric inputs
    (any ``numbers.Integral`` or ``float``) are truncated with ``int``.

    >>> from dask.utils import parse_bytes
    >>> parse_bytes('100')
    100
    >>> parse_bytes('100 MB')
    100000000
    >>> parse_bytes('100M')
    100000000
    >>> parse_bytes('5kB')
    5000
    >>> parse_bytes('5.4 kB')
    5400
    >>> parse_bytes('1kiB')
    1024
    >>> parse_bytes('1e6')
    1000000
    >>> parse_bytes('1e6 kB')
    1000000000
    >>> parse_bytes('MB')
    1000000
    >>> parse_bytes(123)
    123
    >>> parse_bytes('1.005kB')  # not truncated to 1004 by float rounding
    1005
    >>> parse_bytes('5 foos')
    Traceback (most recent call last):
        ...
    ValueError: Could not interpret 'foos' as a byte unit
    """
    if isinstance(s, (Integral, float)):
        # ``Integral`` covers ``int``/``bool`` as before, plus integer scalar
        # types registered with ``numbers`` (such as NumPy's) that are not
        # ``int`` subclasses.
        return int(s)
    s = s.replace(" ", "")
    if not any(char.isdigit() for char in s):
        s = "1" + s

    # Split at the trailing run of letters: "5.4kB" -> "5.4" + "kB".
    for i in range(len(s) - 1, -1, -1):
        if not s[i].isalpha():
            break
    index = i + 1

    prefix = s[:index]
    suffix = s[index:]

    try:
        n = float(prefix)
    except ValueError as e:
        raise ValueError("Could not interpret '%s' as a number" % prefix) from e

    try:
        multiplier = byte_sizes[suffix.lower()]
    except KeyError as e:
        raise ValueError("Could not interpret '%s' as a byte unit" % suffix) from e

    result = n * multiplier
    # ``n`` is a float, so the product can land just below an exact integer
    # (1.005 * 1000 == 1004.9999999999999), which plain ``int`` truncates.
    # Results within ~1e-15 (relative) of an integer are taken as that
    # integer; genuinely fractional results are still truncated.
    nearest = round(result)
    if abs(result - nearest) <= 1e-15 * abs(result):
        return nearest
    return int(result)


byte_sizes = {
    "kB": 10**3,
    "MB": 10**6,
    "GB": 10**9,
    "TB": 10**12,
    "PB": 10**15,
    "KiB": 2**10,
    "MiB": 2**20,
    "GiB": 2**30,
    "TiB": 2**40,
    "PiB": 2**50,
    "B": 1,
    "": 1,
}
# parse_bytes lowercases the unit before lookup, so store lowercase keys.
byte_sizes = {k.lower(): v for k, v in byte_sizes.items()}
# Single-letter decimal aliases: "k", "m", "g", "t", "p" (and "b").
byte_sizes.update({k[0]: v for k, v in byte_sizes.items() if k and "i" not in k})
# Binary aliases without the trailing "b": "ki", "mi", "gi", "ti", "pi".
byte_sizes.update({k[:-1]: v for k, v in byte_sizes.items() if k and "i" in k})
def format_time(n: float) -> str:
    """Render a duration given in seconds as a short human-readable string.

    Long durations use two whole-number units (days/hours, hours/minutes,
    minutes/seconds); shorter ones use seconds, milliseconds or microseconds
    with two decimals.

    >>> from dask.utils import format_time
    >>> format_time(1)
    '1.00 s'
    >>> format_time(0.001234)
    '1.23 ms'
    >>> format_time(0.00012345)
    '123.45 us'
    >>> format_time(123.456)
    '123.46 s'
    >>> format_time(1234.567)
    '20m 34s'
    >>> format_time(12345.67)
    '3hr 25m'
    >>> format_time(123456.78)
    '34hr 17m'
    >>> format_time(1234567.89)
    '14d 6hr'
    """
    # More than two days: whole days plus leftover hours.
    if n > 24 * 60 * 60 * 2:
        days = int(n / 3600 / 24)
        hours = int((n - days * 3600 * 24) / 3600)
        return f"{days}d {hours}hr"
    # More than two hours: whole hours plus leftover minutes.
    if n > 60 * 60 * 2:
        hours = int(n / 3600)
        minutes = int((n - hours * 3600) / 60)
        return f"{hours}hr {minutes}m"
    # More than ten minutes: whole minutes plus leftover seconds.
    if n > 60 * 10:
        minutes = int(n / 60)
        seconds = int(n - minutes * 60)
        return f"{minutes}m {seconds}s"
    # Everything shorter: a single unit with two decimals.
    if n >= 1:
        return "%.2f s" % n
    if n >= 1e-3:
        return "%.2f ms" % (n * 1e3)
    return "%.2f us" % (n * 1e6)
def format_time_ago(n: datetime) -> str:
    """Calculate a '3 hours ago' type string from a Python datetime.

    ``n`` may be naive (compared with ``datetime.now()``) or timezone-aware
    (compared with the current time in ``n``'s timezone). A time in the
    future, e.g. because of small clock differences, is reported as
    ``'Just now'``.

    Examples
    --------
    >>> from datetime import datetime, timedelta

    >>> now = datetime.now()
    >>> format_time_ago(now)
    'Just now'

    >>> past = datetime.now() - timedelta(minutes=1)
    >>> format_time_ago(past)
    '1 minute ago'

    >>> past = datetime.now() - timedelta(minutes=2)
    >>> format_time_ago(past)
    '2 minutes ago'

    >>> past = datetime.now() - timedelta(hours=1)
    >>> format_time_ago(past)
    '1 hour ago'

    >>> past = datetime.now() - timedelta(hours=6)
    >>> format_time_ago(past)
    '6 hours ago'

    >>> past = datetime.now() - timedelta(days=1)
    >>> format_time_ago(past)
    '1 day ago'

    >>> past = datetime.now() - timedelta(days=5)
    >>> format_time_ago(past)
    '5 days ago'

    >>> past = datetime.now() - timedelta(days=8)
    >>> format_time_ago(past)
    '1 week ago'

    >>> past = datetime.now() - timedelta(days=16)
    >>> format_time_ago(past)
    '2 weeks ago'

    >>> past = datetime.now() - timedelta(days=190)
    >>> format_time_ago(past)
    '6 months ago'

    >>> past = datetime.now() - timedelta(days=800)
    >>> format_time_ago(past)
    '2 years ago'

    """
    # Checked from coarsest to finest; the first unit with a count >= 1 wins.
    units = {
        "years": lambda diff: diff.days / 365,
        "months": lambda diff: diff.days / 30.436875,  # Average days per month
        "weeks": lambda diff: diff.days / 7,
        "days": lambda diff: diff.days,
        "hours": lambda diff: diff.seconds / 3600,
        "minutes": lambda diff: diff.seconds % 3600 / 60,
    }
    # ``now(None)`` is the naive local time; ``now(tz)`` is aware. This keeps
    # the subtraction naive-with-naive or aware-with-aware.
    diff = datetime.now(n.tzinfo) - n
    if diff < timedelta(0):
        # A negative timedelta is stored as negative days plus positive
        # seconds, so the hour/minute rules would misreport it (one minute
        # in the future came out as "23 hours ago").
        return "Just now"
    for unit, func in units.items():
        dur = int(func(diff))
        if dur > 0:
            if dur == 1:  # De-pluralize
                unit = unit[:-1]
            return f"{dur} {unit} ago"
    return "Just now"
def format_bytes(n: int) -> str:
    """Format bytes as text

    Uses binary (1024-based) prefixes with two decimals and switches to the
    next larger unit once ``n`` reaches 90% of it.

    >>> from dask.utils import format_bytes
    >>> format_bytes(1)
    '1 B'
    >>> format_bytes(1234)
    '1.21 kiB'
    >>> format_bytes(12345678)
    '11.77 MiB'
    >>> format_bytes(1234567890)
    '1.15 GiB'
    >>> format_bytes(1234567890000)
    '1.12 TiB'
    >>> format_bytes(1234567890000000)
    '1.10 PiB'
    >>> format_bytes(2**60 - 1)
    '1.00 EiB'

    For all values < 2**60, the output is always <= 10 characters.
    """
    # The Ei tier is what makes the length guarantee above hold: without it,
    # values close to 2**60 rendered as e.g. '1023.99 PiB' (11 characters).
    for prefix, k in (
        ("Ei", 2**60),
        ("Pi", 2**50),
        ("Ti", 2**40),
        ("Gi", 2**30),
        ("Mi", 2**20),
        ("ki", 2**10),
    ):
        if n >= k * 0.9:
            return f"{n / k:.2f} {prefix}B"
    return f"{n} B"
timedelta_sizes = {
    "s": 1,
    "ms": 1e-3,
    "us": 1e-6,
    "ns": 1e-9,
    "m": 60,
    "h": 3600,
    "d": 3600 * 24,
    "w": 7 * 3600 * 24,
}

# Long unit names and their plurals ("second"/"seconds", ...).
tds2 = {
    "second": 1,
    "minute": 60,
    "hour": 60 * 60,
    "day": 60 * 60 * 24,
    "week": 7 * 60 * 60 * 24,
    "millisecond": 1e-3,
    "microsecond": 1e-6,
    "nanosecond": 1e-9,
}
tds2.update({k + "s": v for k, v in tds2.items()})
timedelta_sizes.update(tds2)
# Uppercase variants. parse_timedelta lowercases the unit before lookup, so
# these mainly show up in its "Valid units" error message.
timedelta_sizes.update({k.upper(): v for k, v in timedelta_sizes.items()})


@overload
def parse_timedelta(s: None, default: str | Literal[False] = "seconds") -> None: ...


@overload
def parse_timedelta(
    s: str | float | timedelta, default: str | Literal[False] = "seconds"
) -> float: ...


def parse_timedelta(s, default="seconds"):
    """Parse timedelta string to number of seconds

    Parameters
    ----------
    s : str, float, timedelta, or None
    default: str or False, optional
        Unit of measure if s does not specify one. Defaults to seconds.
        Set to False to require s to explicitly specify its own unit.

    Returns
    -------
    int, float or None
        ``None`` if ``s`` is None; otherwise the number of seconds, as an
        ``int`` when the value is integral.

    Raises
    ------
    ValueError
        If ``s`` is an empty string, if the numeric part cannot be parsed, or
        if no unit is given and ``default`` is False.
    TypeError
        If a default unit is needed and ``default`` is neither a str nor False.
    KeyError
        If the unit is not recognized.

    Examples
    --------
    >>> from datetime import timedelta
    >>> from dask.utils import parse_timedelta
    >>> parse_timedelta('3s')
    3
    >>> parse_timedelta('3.5 seconds')
    3.5
    >>> parse_timedelta('300ms')
    0.3
    >>> parse_timedelta('.5s')
    0.5
    >>> parse_timedelta(-2)
    -2
    >>> parse_timedelta(timedelta(seconds=3))  # also supports timedeltas
    3
    """
    if s is None:
        return None
    if isinstance(s, timedelta):
        s = s.total_seconds()
        return int(s) if int(s) == s else s
    if isinstance(s, Number):
        s = str(s)
    s = s.replace(" ", "")
    if not s:
        raise ValueError("Cannot parse an empty string as a timedelta")
    # Only a bare unit (no digits at all, e.g. "ms") means one of that unit.
    # Testing just the first character misread ".5s" as "1.5s" and broke
    # signed values such as "-2".
    if not any(char.isdigit() for char in s):
        s = "1" + s

    # Split at the trailing run of letters: "300ms" -> "300" + "ms".
    for i in range(len(s) - 1, -1, -1):
        if not s[i].isalpha():
            break
    index = i + 1

    prefix = s[:index]
    suffix = s[index:] or default
    if suffix is False:
        raise ValueError(f"Missing time unit: {s}")
    if not isinstance(suffix, str):
        raise TypeError(f"default must be str or False, got {default!r}")

    n = float(prefix)

    try:
        multiplier = timedelta_sizes[suffix.lower()]
    except KeyError:
        valid_units = ", ".join(timedelta_sizes.keys())
        raise KeyError(
            f"Invalid time unit: {suffix}. Valid units are: {valid_units}"
        ) from None

    result = n * multiplier
    if int(result) == result:
        result = int(result)
    return result
def has_keyword(func, keyword):
    """Return True if ``keyword`` names a parameter of ``func``.

    Any failure, including objects whose signature cannot be introspected,
    is reported as False.
    """
    try:
        parameters = inspect.signature(func).parameters
        return keyword in parameters
    except Exception:
        return False
def ndimlist(seq):
    """Nesting depth of lists/tuples, following the first element at each level.

    Non-list/tuple values have depth 0; an empty list or tuple has depth 1.
    """
    if isinstance(seq, (list, tuple)):
        return 1 + (ndimlist(seq[0]) if seq else 0)
    return 0
def iter_chunks(sizes, max_size):
    """Split sizes into chunks of total max_size each

    Parameters
    ----------
    sizes : iterable of numbers
        The sizes to be chunked
    max_size : number
        Maximum total size per chunk.
        It must be greater or equal than each size in sizes

    Yields
    ------
    list
        Consecutive runs of ``sizes`` (in order) whose sum is at most
        ``max_size``. A new chunk starts when the next size would push the
        current chunk over ``max_size``.

    Raises
    ------
    AssertionError
        If a single size exceeds ``max_size``.
    """
    chunk, chunk_sum = [], 0
    # Iterate directly; a ``None`` end-of-input sentinel silently stopped at
    # any ``None`` entry in ``sizes``.
    for size in sizes:
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``; without it an oversized item made the generator yield
        # empty chunks forever. AssertionError is kept as the exception type.
        if not size <= max_size:
            raise AssertionError(f"size {size!r} exceeds max_size {max_size!r}")
        if chunk_sum + size > max_size:
            yield chunk
            chunk, chunk_sum = [], 0
        chunk.append(size)
        chunk_sum += size
    if chunk:
        yield chunk
# Runs of lowercase hex letters. key_split drops an 8-letter word consisting
# *entirely* of these as a hash fragment (e.g. "abcdefab").
hex_pattern = re.compile("[a-f]+")


@functools.lru_cache(100000)
def key_split(s):
    """
    Return the human-readable name prefix of a task key.

    The key is split on ``-`` and leading alphabetic words are kept, stopping
    at the first word that is not purely alphabetic or that is an 8-letter
    hex-like word. Tuple keys use their first element, bytes are decoded, and
    anything that cannot be processed yields ``'Other'``.

    >>> key_split('x')
    'x'
    >>> key_split('x-1')
    'x'
    >>> key_split('x-1-2-3')
    'x'
    >>> key_split(('x-2', 1))
    'x'
    >>> key_split("('x-2', 1)")
    'x'
    >>> key_split("('x', 1)")
    'x'
    >>> key_split('hello-world-1')
    'hello-world'
    >>> key_split(b'hello-world-1')
    'hello-world'
    >>> key_split('ae05086432ca935f6eba409a8ecd4896')
    'data'
    >>> key_split('<module.submodule.myclass object at 0xdaf372')
    'myclass'
    >>> key_split(None)
    'Other'
    >>> key_split('x-abcdefab')  # ignores hex
    'x'
    >>> key_split('x-finalize-1')  # 8 letters, but not all hex
    'x-finalize'
    >>> key_split('_(x)')  # strips unpleasant characters
    'x'
    """
    # If we convert the key, recurse to utilize LRU cache better
    if type(s) is bytes:
        return key_split(s.decode())
    if type(s) is tuple:
        return key_split(s[0])
    try:
        words = s.split("-")
        if not words[0][0].isalpha():
            # e.g. a stringified tuple key "('x-2', 1)": keep the text before
            # the first comma and strip quotes, parentheses and underscores.
            result = words[0].split(",")[0].strip("_'()\"")
        else:
            result = words[0]
        for word in words[1:]:
            # ``fullmatch``: only words made up entirely of a-f count as hex.
            # A prefix ``match`` also dropped ordinary words such as
            # "finalize" merely because they start with a letter a-f.
            if word.isalpha() and not (
                len(word) == 8 and hex_pattern.fullmatch(word) is not None
            ):
                result += "-" + word
            else:
                break
        if len(result) == 32 and re.match(r"[a-f0-9]{32}", result):
            # A bare 32-character hex string is reported as "data".
            return "data"
        else:
            if result[0] == "<":
                # "<module.submodule.myclass object at 0x...>" -> "myclass"
                result = result.strip("<>").split()[0].split(".")[-1]
            return sys.intern(result)
    except Exception:
        # e.g. None (no ``split``) or an empty string (no first character).
        return "Other"
2001def stringify(obj, exclusive: Iterable | None = None):
2002 """Convert an object to a string
2004 If ``exclusive`` is specified, search through `obj` and convert
2005 values that are in ``exclusive``.
2007 Note that when searching through dictionaries, only values are
2008 converted, not the keys.
2010 Parameters
2011 ----------
2012 obj : Any
2013 Object (or values within) to convert to string
2014 exclusive: Iterable, optional
2015 Set of values to search for when converting values to strings
2017 Returns
2018 -------
2019 result : type(obj)
2020 Stringified copy of ``obj`` or ``obj`` itself if it is already a
2021 string or bytes.
2023 Examples
2024 --------
2025 >>> stringify(b'x')
2026 b'x'
2027 >>> stringify('x')
2028 'x'
2029 >>> stringify({('a',0):('a',0), ('a',1): ('a',1)})
2030 "{('a', 0): ('a', 0), ('a', 1): ('a', 1)}"
2031 >>> stringify({('a',0):('a',0), ('a',1): ('a',1)}, exclusive={('a',0)})
2032 {('a', 0): "('a', 0)", ('a', 1): ('a', 1)}
2033 """
2035 typ = type(obj)
2036 if typ is str or typ is bytes:
2037 return obj
2038 elif exclusive is None:
2039 return str(obj)
2041 if typ is list:
2042 return [stringify(v, exclusive) for v in obj]
2043 if typ is dict:
2044 return {k: stringify(v, exclusive) for k, v in obj.items()}
2045 try:
2046 if obj in exclusive:
2047 return stringify(obj)
2048 except TypeError: # `obj` not hashable
2049 pass
2050 if typ is tuple: # If the tuple itself isn't a key, check its elements
2051 return tuple(stringify(v, exclusive) for v in obj)
2052 return obj
class cached_property(functools.cached_property):
    """Read-only variant of :func:`functools.cached_property`.

    The value is computed on first access and cached as usual. Defining
    ``__set__`` makes this a data descriptor, so assigning to the attribute
    raises ``AttributeError`` instead of overwriting the cached value.
    """

    def __set__(self, instance, val):
        """Reject every assignment to the cached attribute."""
        message = "Can't set attribute"
        raise AttributeError(message)
class _HashIdWrapper:
    """Wrap an object so that hashing and equality use its identity.

    Two wrappers are equal exactly when they wrap the very same object;
    comparisons with anything that is not a wrapper are left to the other
    operand (``NotImplemented``).
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __hash__(self):
        return id(self.wrapped)

    def __eq__(self, other):
        if isinstance(other, _HashIdWrapper):
            return self.wrapped is other.wrapped
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, _HashIdWrapper):
            return self.wrapped is not other.wrapped
        return NotImplemented
@functools.lru_cache
def _cumsum(seq, initial_zero):
    """Cached running sums of ``seq`` as a tuple.

    ``seq`` may be a ``_HashIdWrapper`` (cached by identity) or a plain
    hashable sequence (cached by value). With ``initial_zero`` the result
    starts with ``0``.
    """
    values = seq.wrapped if isinstance(seq, _HashIdWrapper) else seq
    if not initial_zero:
        return tuple(toolz.accumulate(add, values))
    return tuple(toolz.accumulate(add, values, 0))
@functools.lru_cache
def _max(seq):
    """Cached ``max`` of ``seq``, which may be a ``_HashIdWrapper``."""
    values = seq.wrapped if isinstance(seq, _HashIdWrapper) else seq
    return max(values)
def cached_max(seq):
    """Compute max with caching.

    Caching is by the identity of `seq` rather than the value. It is thus
    important that `seq` is a tuple of immutable objects, and this function
    is intended for use where `seq` is a value that will persist (generally
    block sizes).

    Parameters
    ----------
    seq : tuple
        Values to reduce

    Returns
    -------
    object
        ``max(seq)``: the largest element of ``seq``.

    Raises
    ------
    AssertionError
        If ``seq`` is not a tuple (when assertions are enabled).
    ValueError
        If ``seq`` is empty.
    """
    assert isinstance(seq, tuple)
    # Look up by identity first, to avoid a linear-time __hash__
    # if we've seen this tuple object before.
    return _max(_HashIdWrapper(seq))
def cached_cumsum(seq, initial_zero=False):
    """Compute :meth:`toolz.accumulate` with caching.

    Caching is by the identity of `seq` rather than the value. It is thus
    important that `seq` is a tuple of immutable objects, and this function
    is intended for use where `seq` is a value that will persist (generally
    block sizes). Non-tuple inputs are copied into a temporary tuple and
    cached by value instead.

    Parameters
    ----------
    seq : tuple
        Values to cumulatively sum.
    initial_zero : bool, optional
        If true, the return value is prefixed with a zero.

    Returns
    -------
    tuple
        The running sums of ``seq`` (one longer than ``seq`` when
        ``initial_zero`` is true).
    """
    if isinstance(seq, tuple):
        # Look up by identity first, to avoid a linear-time __hash__
        # if we've seen this tuple object before.
        result = _cumsum(_HashIdWrapper(seq), initial_zero)
    else:
        # Construct a temporary tuple, and look up by value.
        result = _cumsum(tuple(seq), initial_zero)
    return result
def show_versions() -> None:
    """Provide version information for bug reports.

    Writes a newline-terminated, indented JSON object to ``sys.stdout``. It
    contains the Python version, the platform, the dask and distributed
    versions, and the versions of a fixed list of optional dependencies.
    Packages that are not installed are reported as ``null``.
    """

    from json import dumps
    from platform import uname
    from sys import stdout, version_info

    from dask._compatibility import importlib_metadata

    try:
        from distributed import __version__ as distributed_version
    except ImportError:
        distributed_version = None

    from dask import __version__ as dask_version

    deps = [
        "numpy",
        "pandas",
        "cloudpickle",
        "fsspec",
        "bokeh",
        "pyarrow",
        "zarr",
    ]

    result: dict[str, str | None] = {
        # note: only major, minor, micro are extracted
        "Python": ".".join([str(i) for i in version_info[:3]]),
        "Platform": uname().system,
        "dask": dask_version,
        "distributed": distributed_version,
    }

    for modname in deps:
        try:
            result[modname] = importlib_metadata.version(modname)
        except importlib_metadata.PackageNotFoundError:
            result[modname] = None

    # A single ``write``: ``writelines`` on a str writes it one character at
    # a time. The trailing newline keeps the shell prompt off the last line.
    stdout.write(dumps(result, indent=2) + "\n")
def maybe_pluralize(count, noun, plural_form=None):
    """Return ``"<count> <noun>"``, pluralizing the noun unless ``count == 1``.

    The plural is ``plural_form`` when it is given (and non-empty), otherwise
    ``noun + "s"``.
    """
    word = noun if count == 1 else (plural_form or noun + "s")
    return f"{count} {word}"
def is_namedtuple_instance(obj: Any) -> bool:
    """Returns True if obj is an instance of a namedtuple.

    Note: This function checks for the existence of the methods and
    attributes that make up the namedtuple API, so it will return True
    IFF obj's type implements that API.
    """
    if not isinstance(obj, tuple):
        return False
    return all(
        hasattr(obj, attr)
        for attr in ("_make", "_asdict", "_replace", "_fields", "_field_defaults")
    )
def get_default_shuffle_method() -> str:
    """Return the shuffle method to use when none was chosen explicitly.

    In order of preference:

    1. the ``dataframe.shuffle.method`` config value, if set (truthy);
    2. ``"disk"`` if importing ``distributed`` or calling its
       ``default_client()`` raises ImportError or ValueError;
    3. ``"tasks"`` if ``distributed.shuffle.check_minimal_arrow_version``
       cannot be imported or raises ImportError;
    4. ``"p2p"`` otherwise.
    """
    if d := config.get("dataframe.shuffle.method", None):
        return d
    try:
        from distributed import default_client

        default_client()
    except (ImportError, ValueError):
        return "disk"

    try:
        from distributed.shuffle import check_minimal_arrow_version

        check_minimal_arrow_version()
    except ImportError:
        # Catch ImportError, not only its subclass ModuleNotFoundError: a
        # "cannot import name" failure (e.g. mismatched dask/distributed
        # versions) raises plain ImportError and should fall back to "tasks"
        # rather than crash. NOTE(review): check_minimal_arrow_version may
        # also report an unsuitable pyarrow this way; confirm in distributed.
        return "tasks"
    return "p2p"
def get_meta_library(like):
    """Import and return the top-level module that ``like``'s type comes from.

    If ``like`` has a ``_meta`` attribute, that object's type is used instead.
    The module name is the part of ``typename(...)`` before the first ``.``.
    """
    obj = like._meta if hasattr(like, "_meta") else like
    top_level = typename(obj).partition(".")[0]
    return import_module(top_level)
class shorten_traceback:
    """Context manager that removes irrelevant stack elements from traceback.

    * omits frames from modules that match `admin.traceback.shorten`
    * always keeps the first and last frame.

    ``admin.traceback.shorten`` is read from the dask config each time an
    exception is shortened. Its entries are joined into one regular
    expression; a frame is dropped when its code object's filename matches.
    If the setting is empty, the traceback is left untouched. Exceptions are
    never suppressed.
    """

    __slots__ = ()

    def __enter__(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        # Replace the exception's traceback in place; returning None lets the
        # exception propagate.
        if exc_val and exc_tb:
            exc_val.__traceback__ = self.shorten(exc_tb)

    @staticmethod
    def shorten(exc_tb: types.TracebackType) -> types.TracebackType:
        """Unlink matching intermediate frames from the ``exc_tb`` chain.

        The chain is relinked in place through ``tb_next``; the returned
        object is always ``exc_tb`` itself, i.e. the first frame is kept.
        """
        paths = config.get("admin.traceback.shorten")
        if not paths:
            return exc_tb

        # ``.*(p1|p2|...)`` used with ``match``: any pattern occurring
        # anywhere in the filename counts as a match.
        exp = re.compile(".*(" + "|".join(paths) + ")")
        curr: types.TracebackType | None = exc_tb
        # Last frame kept so far; each newly kept frame is linked after it.
        prev: types.TracebackType | None = None

        while curr:
            if prev is None:
                prev = curr  # first frame
            elif not curr.tb_next:
                # always keep last frame
                prev.tb_next = curr
                prev = prev.tb_next
            elif not exp.match(curr.tb_frame.f_code.co_filename):
                # keep if module is not listed in config
                prev.tb_next = curr
                prev = curr
            # Skipped frames are simply not linked; ``curr.tb_next`` still
            # points along the original chain.
            curr = curr.tb_next

        # Uncomment to remove the first frame, which is something you don't want to keep
        # if it matches the regexes. Requires Python >=3.11.
        # if exc_tb.tb_next and exp.match(exc_tb.tb_frame.f_code.co_filename):
        #     return exc_tb.tb_next

        return exc_tb
def unzip(ls, nout):
    """Unzip a list of lists into ``nout`` outputs.

    Returns a list of tuples (the transposed input). For empty input it
    returns ``nout`` empty tuples so callers can still unpack the result.
    """
    columns = [*zip(*ls)]
    return columns if columns else [()] * nout
class disable_gc(ContextDecorator):
    """Context manager to disable garbage collection.

    Can also be used as a decorator (via ``ContextDecorator``). On exit,
    garbage collection is re-enabled only if it was enabled when the matching
    ``__enter__`` ran, so collection that was already disabled stays disabled.

    Parameters
    ----------
    collect : bool, default False
        Stored as ``self.collect``; this class does not otherwise use it.
    """

    def __init__(self, collect=False):
        self.collect = collect
        # GC state seen by the most recent ``__enter__`` (initially the state
        # at construction time).
        self._gc_enabled = gc.isenabled()
        # One saved state per active ``__enter__``. As a decorator, the same
        # instance is re-entered on every call (recursive calls included), so
        # a single flag captured at construction time is not enough.
        self._saved_states = []

    def __enter__(self):
        self._gc_enabled = gc.isenabled()
        self._saved_states.append(self._gc_enabled)
        gc.disable()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self._saved_states.pop():
            gc.enable()
        return False