from __future__ import annotations

import codecs
import functools
import gc
import inspect
import os
import re
import shutil
import sys
import tempfile
import types
import uuid
import warnings
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Set
from contextlib import ContextDecorator, contextmanager, nullcontext, suppress
from datetime import datetime, timedelta
from errno import ENOENT
from functools import wraps
from importlib import import_module
from numbers import Integral, Number
from operator import add
from threading import Lock
from typing import Any, ClassVar, Literal, TypeVar, cast, overload
from weakref import WeakValueDictionary

import tlz as toolz

from dask import config
from dask.typing import no_default

K = TypeVar("K")
V = TypeVar("V")
T = TypeVar("T")

# used in decorators to preserve the signature of the function it decorates
# see https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)

system_encoding = sys.getdefaultencoding()
if system_encoding == "ascii":
    system_encoding = "utf-8"


def apply(func, args, kwargs=None):
    """Apply a function given its positional and keyword arguments.

    Equivalent to ``func(*args, **kwargs)``.
    Most Dask users will never need to use the ``apply`` function.
    It is typically only used by people who need to inject
    keyword argument values into a low level Dask task graph.

    Parameters
    ----------
    func : callable
        The function you want to apply.
    args : tuple
        A tuple containing all the positional arguments needed for ``func``
        (eg: ``(arg_1, arg_2, arg_3)``)
    kwargs : dict, optional
        A dictionary mapping the keyword arguments
        (eg: ``{"kwarg_1": value, "kwarg_2": value}``)

    Examples
    --------
    >>> from dask.utils import apply
    >>> def add(number, second_number=5):
    ...     return number + second_number
    ...
    >>> apply(add, (10,), {"second_number": 2})  # equivalent to add(*args, **kwargs)
    12

    >>> task = apply(add, (10,), {"second_number": 2})
    >>> dsk = {'task-name': task}  # adds the task to a low level Dask task graph
    """
    if kwargs:
        return func(*args, **kwargs)
    else:
        return func(*args)


def _deprecated(
    *,
    version: str | None = None,
    after_version: str | None = None,
    message: str | None = None,
    use_instead: str | None = None,
    category: type[Warning] = FutureWarning,
):
    """Decorator to mark a function as deprecated

    Parameters
    ----------
    version : str, optional
        Version of Dask in which the function was deprecated. If specified, the version
        will be included in the default warning message. This should no longer be used
        now that the versioning system is automated.
    after_version : str, optional
        Version of Dask after which the function was deprecated. If specified, the
        version will be included in the default warning message.
    message : str, optional
        Custom warning message to raise.
    use_instead : str, optional
        Name of function to use in place of the deprecated function.
        If specified, this will be included in the default warning
        message.
    category : type[Warning], optional
        Type of warning to raise. Defaults to ``FutureWarning``.

    Examples
    --------

    >>> from dask.utils import _deprecated
    >>> @_deprecated(after_version="X.Y.Z", use_instead="bar")
    ... def foo():
    ...     return "baz"
    """

    def decorator(func):
        if message is None:
            msg = f"{func.__name__} "
            if after_version is not None:
                msg += f"was deprecated after version {after_version} "
            elif version is not None:
                msg += f"was deprecated in version {version} "
            else:
                msg += "is deprecated "
            msg += "and will be removed in a future release."

            if use_instead is not None:
                msg += f" Please use {use_instead} instead."
        else:
            msg = message

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(msg, category=category, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator


def _deprecated_kwarg(
    old_arg_name: str,
    new_arg_name: str | None = None,
    mapping: Mapping[Any, Any] | Callable[[Any], Any] | None = None,
    stacklevel: int = 2,
    comment: str | None = None,
) -> Callable[[F], F]:
    """
    Decorator to deprecate a keyword argument of a function.

    Parameters
    ----------
    old_arg_name : str
        Name of argument in function to deprecate
    new_arg_name : str, optional
        Name of preferred argument in function. Omit to warn that
        ``old_arg_name`` keyword is deprecated.
    mapping : dict or callable, optional
        If mapping is present, use it to translate old arguments to
        new arguments. A callable must do its own value checking;
        values not found in a dict will be forwarded unchanged.
    comment : str, optional
        Additional text appended to the deprecation message. Useful to pass
        on suggestions with the deprecation warning.

    Examples
    --------
    The following deprecates 'cols', using 'columns' instead

    >>> @_deprecated_kwarg(old_arg_name='cols', new_arg_name='columns')
    ... def f(columns=''):
    ...     print(columns)
    ...

    >>> f(columns='should work ok')
    should work ok

    >>> f(cols='should raise warning')  # doctest: +SKIP
    FutureWarning: cols is deprecated, use columns instead
      warnings.warn(msg, FutureWarning)
    should raise warning

    >>> f(cols='should error', columns="can't pass both")  # doctest: +SKIP
    TypeError: Can only specify 'cols' or 'columns', not both

    >>> @_deprecated_kwarg('old', 'new', {'yes': True, 'no': False})
    ... def f(new=False):
    ...     print('yes!' if new else 'no!')
    ...

    >>> f(old='yes')  # doctest: +SKIP
    FutureWarning: old='yes' is deprecated, use new=True instead
      warnings.warn(msg, FutureWarning)
    yes!

    To raise a warning that a keyword will be removed entirely in the future

    >>> @_deprecated_kwarg(old_arg_name='cols', new_arg_name=None)
    ... def f(cols='', another_param=''):
    ...     print(cols)
    ...

    >>> f(cols='should raise warning')  # doctest: +SKIP
    FutureWarning: the 'cols' keyword is deprecated and will be removed in a
    future version. Please take steps to stop the use of 'cols'
    should raise warning

    >>> f(another_param='should not raise warning')  # doctest: +SKIP
    should not raise warning

    >>> f(cols='should raise warning', another_param='')  # doctest: +SKIP
    FutureWarning: the 'cols' keyword is deprecated and will be removed in a
    future version. Please take steps to stop the use of 'cols'
    should raise warning
    """
    if mapping is not None and not hasattr(mapping, "get") and not callable(mapping):
        raise TypeError(
            "mapping from old to new argument values must be dict or callable!"
        )

    comment_ = f"\n{comment}" if comment else ""

    def _deprecated_kwarg(func: F) -> F:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Callable[..., Any]:
            old_arg_value = kwargs.pop(old_arg_name, no_default)

            if old_arg_value is not no_default:
                if new_arg_name is None:
                    msg = (
                        f"the {repr(old_arg_name)} keyword is deprecated and "
                        "will be removed in a future version. Please take "
                        f"steps to stop the use of {repr(old_arg_name)}"
                    ) + comment_
                    warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
                    kwargs[old_arg_name] = old_arg_value
                    return func(*args, **kwargs)

                elif mapping is not None:
                    if callable(mapping):
                        new_arg_value = mapping(old_arg_value)
                    else:
                        new_arg_value = mapping.get(old_arg_value, old_arg_value)
                    msg = (
                        f"the {old_arg_name}={repr(old_arg_value)} keyword is "
                        "deprecated, use "
                        f"{new_arg_name}={repr(new_arg_value)} instead."
                    )
                else:
                    new_arg_value = old_arg_value
                    msg = (
                        f"the {repr(old_arg_name)} keyword is deprecated, "
                        f"use {repr(new_arg_name)} instead."
                    )

                warnings.warn(msg + comment_, FutureWarning, stacklevel=stacklevel)
                if kwargs.get(new_arg_name) is not None:
                    msg = (
                        f"Can only specify {repr(old_arg_name)} "
                        f"or {repr(new_arg_name)}, not both."
                    )
                    raise TypeError(msg)
                kwargs[new_arg_name] = new_arg_value
            return func(*args, **kwargs)

        return cast(F, wrapper)

    return _deprecated_kwarg


def deepmap(func, *seqs):
    """Apply function inside nested lists

    >>> inc = lambda x: x + 1
    >>> deepmap(inc, [[1, 2], [3, 4]])
    [[2, 3], [4, 5]]

    >>> add = lambda x, y: x + y
    >>> deepmap(add, [[1, 2], [3, 4]], [[10, 20], [30, 40]])
    [[11, 22], [33, 44]]
    """
    if isinstance(seqs[0], (list, Iterator)):
        return [deepmap(func, *items) for items in zip(*seqs)]
    else:
        return func(*seqs)


@_deprecated()
def homogeneous_deepmap(func, seq):
    if not seq:
        return seq
    n = 0
    tmp = seq
    while isinstance(tmp, list):
        n += 1
        tmp = tmp[0]

    return ndeepmap(n, func, seq)


def ndeepmap(n, func, seq):
    """Call a function on every element within a nested container

    >>> def inc(x):
    ...     return x + 1
    >>> L = [[1, 2], [3, 4, 5]]
    >>> ndeepmap(2, inc, L)
    [[2, 3], [4, 5, 6]]
    """
    if n == 1:
        return [func(item) for item in seq]
    elif n > 1:
        return [ndeepmap(n - 1, func, item) for item in seq]
    elif isinstance(seq, list):
        return func(seq[0])
    else:
        return func(seq)


def import_required(mod_name, error_msg):
    """Attempt to import a required dependency.

    Raises a RuntimeError if the requested module is not available.
    """
    try:
        return import_module(mod_name)
    except ImportError as e:
        raise RuntimeError(error_msg) from e
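

# A minimal usage sketch for ``import_required`` (the module name and error
# message below are hypothetical, not from the original source):
#
# >>> np = import_required("numpy", "numpy is required for this feature; "
# ...                      "install it with `pip install numpy`")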


@contextmanager
def tmpfile(extension="", dir=None):
    """
    Function to create and return a unique temporary file with the given extension, if provided.

    Parameters
    ----------
    extension : str
        The extension of the temporary file to be created
    dir : str
        If ``dir`` is not None, the file will be created in that directory; otherwise,
        Python's default temporary directory is used.

    Returns
    -------
    out : str
        Path to the temporary file

    See Also
    --------
    NamedTemporaryFile : Built-in alternative for creating temporary files
    tmp_path : pytest fixture for creating a temporary directory unique to the test invocation

    Notes
    -----
    This context manager is particularly useful on Windows for opening temporary files multiple times.
    """
    extension = extension.lstrip(".")
    if extension:
        extension = "." + extension
    handle, filename = tempfile.mkstemp(extension, dir=dir)
    os.close(handle)
    os.remove(filename)

    try:
        yield filename
    finally:
        if os.path.exists(filename):
            with suppress(OSError):  # sometimes we can't remove a generated temp file
                if os.path.isdir(filename):
                    shutil.rmtree(filename)
                else:
                    os.remove(filename)
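

# A minimal usage sketch for ``tmpfile`` (illustrative; the file content
# below is made up):
#
# >>> with tmpfile(extension="csv") as fn:
# ...     with open(fn, "w") as f:
# ...         f.write("a,b\n1,2")
# ...     # work with ``fn`` here; the file is removed when the block exits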


@contextmanager
def tmpdir(dir=None):
    """
    Function to create and return a unique temporary directory.

    Parameters
    ----------
    dir : str
        If ``dir`` is not None, the directory will be created in that directory; otherwise,
        Python's default temporary directory is used.

    Returns
    -------
    out : str
        Path to the temporary directory

    Notes
    -----
    This context manager is particularly useful on Windows for opening temporary directories multiple times.
    """
    dirname = tempfile.mkdtemp(dir=dir)

    try:
        yield dirname
    finally:
        if os.path.exists(dirname):
            if os.path.isdir(dirname):
                with suppress(OSError):
                    shutil.rmtree(dirname)
            else:
                with suppress(OSError):
                    os.remove(dirname)
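

# A minimal usage sketch for ``tmpdir`` (illustrative; ``part-0.csv`` is a
# hypothetical file name):
#
# >>> import os
# >>> with tmpdir() as dirname:
# ...     path = os.path.join(dirname, "part-0.csv")
# ...     # write files under ``dirname``; the directory is removed on exit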


@contextmanager
def filetext(text, extension="", open=open, mode="w"):
    with tmpfile(extension=extension) as filename:
        f = open(filename, mode=mode)
        try:
            f.write(text)
        finally:
            try:
                f.close()
            except AttributeError:
                pass

        yield filename


@contextmanager
def changed_cwd(new_cwd):
    old_cwd = os.getcwd()
    os.chdir(new_cwd)
    try:
        yield
    finally:
        os.chdir(old_cwd)


@contextmanager
def tmp_cwd(dir=None):
    with tmpdir(dir) as dirname:
        with changed_cwd(dirname):
            yield dirname


class IndexCallable:
    """Provide getitem syntax for functions

    >>> def inc(x):
    ...     return x + 1

    >>> I = IndexCallable(inc)
    >>> I[3]
    4
    """

    __slots__ = ("fn",)

    def __init__(self, fn):
        self.fn = fn

    def __getitem__(self, key):
        return self.fn(key)


@contextmanager
def filetexts(d, open=open, mode="t", use_tmpdir=True):
    """Dumps a number of textfiles to disk

    Parameters
    ----------
    d : dict
        a mapping from filename to text like {'a.csv': '1,1\n2,2'}

    Since this is meant for use in tests, this context manager will
    automatically switch to a temporary current directory, to avoid
    race conditions when running tests in parallel.
    """
    with tmp_cwd() if use_tmpdir else nullcontext():
        for filename, text in d.items():
            try:
                os.makedirs(os.path.dirname(filename))
            except OSError:
                pass
            f = open(filename, "w" + mode)
            try:
                f.write(text)
            finally:
                try:
                    f.close()
                except AttributeError:
                    pass

        yield list(d)

        for filename in d:
            if os.path.exists(filename):
                with suppress(OSError):
                    os.remove(filename)
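

# A minimal usage sketch for ``filetexts`` (file names and contents are
# made up):
#
# >>> with filetexts({"a.csv": "1,1\n2,2", "b.csv": "3,3\n4,4"}) as filenames:
# ...     sorted(filenames)
# ['a.csv', 'b.csv']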


def concrete(seq):
    """Make nested iterators concrete lists

    >>> data = [[1, 2], [3, 4]]
    >>> seq = iter(map(iter, data))
    >>> concrete(seq)
    [[1, 2], [3, 4]]
    """
    if isinstance(seq, Iterator):
        seq = list(seq)
    if isinstance(seq, (tuple, list)):
        seq = list(map(concrete, seq))
    return seq


def pseudorandom(n: int, p, random_state=None):
    """Pseudorandom array of integer indexes

    >>> pseudorandom(5, [0.5, 0.5], random_state=123)
    array([1, 0, 0, 1, 1], dtype=int8)

    >>> pseudorandom(10, [0.5, 0.2, 0.2, 0.1], random_state=5)
    array([0, 2, 0, 3, 0, 1, 2, 1, 0, 0], dtype=int8)
    """
    import numpy as np

    p = list(p)
    cp = np.cumsum([0] + p)
    assert np.allclose(1, cp[-1])
    assert len(p) < 256

    if not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(random_state)

    x = random_state.random_sample(n)
    out = np.empty(n, dtype="i1")

    for i, (low, high) in enumerate(zip(cp[:-1], cp[1:])):
        out[(x >= low) & (x < high)] = i
    return out


def random_state_data(n: int, random_state=None) -> list:
    """Return a list of arrays that can initialize
    ``np.random.RandomState``.

    Parameters
    ----------
    n : int
        Number of arrays to return.
    random_state : int or np.random.RandomState, optional
        If an int, is used to seed a new ``RandomState``.
    """
    import numpy as np

    if not all(
        hasattr(random_state, attr) for attr in ["normal", "beta", "bytes", "uniform"]
    ):
        random_state = np.random.RandomState(random_state)

    random_data = random_state.bytes(624 * n * 4)  # `n * 624` 32-bit integers
    l = list(np.frombuffer(random_data, dtype="<u4").reshape((n, -1)))
    assert len(l) == n
    return l


def is_integer(i) -> bool:
    """
    >>> is_integer(6)
    True
    >>> is_integer(42.0)
    True
    >>> is_integer('abc')
    False
    """
    return isinstance(i, Integral) or (isinstance(i, float) and i.is_integer())


ONE_ARITY_BUILTINS = {
    abs,
    all,
    any,
    ascii,
    bool,
    bytearray,
    bytes,
    callable,
    chr,
    classmethod,
    complex,
    dict,
    dir,
    enumerate,
    eval,
    float,
    format,
    frozenset,
    hash,
    hex,
    id,
    int,
    iter,
    len,
    list,
    max,
    min,
    next,
    oct,
    open,
    ord,
    range,
    repr,
    reversed,
    round,
    set,
    slice,
    sorted,
    staticmethod,
    str,
    sum,
    tuple,
    type,
    vars,
    zip,
    memoryview,
}
MULTI_ARITY_BUILTINS = {
    compile,
    delattr,
    divmod,
    filter,
    getattr,
    hasattr,
    isinstance,
    issubclass,
    map,
    pow,
    setattr,
}


def getargspec(func):
    """Version of inspect.getargspec that works with partial and wraps."""
    if isinstance(func, functools.partial):
        return getargspec(func.func)

    func = getattr(func, "__wrapped__", func)
    if isinstance(func, type):
        return inspect.getfullargspec(func.__init__)
    else:
        return inspect.getfullargspec(func)


def takes_multiple_arguments(func, varargs=True):
    """Does this function take multiple arguments?

    >>> def f(x, y): pass
    >>> takes_multiple_arguments(f)
    True

    >>> def f(x): pass
    >>> takes_multiple_arguments(f)
    False

    >>> def f(x, y=None): pass
    >>> takes_multiple_arguments(f)
    False

    >>> def f(*args): pass
    >>> takes_multiple_arguments(f)
    True

    >>> class Thing:
    ...     def __init__(self, a): pass
    >>> takes_multiple_arguments(Thing)
    False

    """
    if func in ONE_ARITY_BUILTINS:
        return False
    elif func in MULTI_ARITY_BUILTINS:
        return True

    try:
        spec = getargspec(func)
    except Exception:
        return False

    try:
        is_constructor = spec.args[0] == "self" and isinstance(func, type)
    except Exception:
        is_constructor = False

    if varargs and spec.varargs:
        return True

    ndefaults = 0 if spec.defaults is None else len(spec.defaults)
    return len(spec.args) - ndefaults - is_constructor > 1


def get_named_args(func) -> list[str]:
    """Get all non ``*args/**kwargs`` arguments for a function"""
    s = inspect.signature(func)
    return [
        n
        for n, p in s.parameters.items()
        if p.kind in [p.POSITIONAL_OR_KEYWORD, p.POSITIONAL_ONLY, p.KEYWORD_ONLY]
    ]


class Dispatch:
    """Simple single dispatch."""

    def __init__(self, name=None):
        self._lookup = {}
        self._lazy = {}
        if name:
            self.__name__ = name

    def register(self, type, func=None):
        """Register dispatch of `func` on arguments of type `type`"""

        def wrapper(func):
            if isinstance(type, tuple):
                for t in type:
                    self.register(t, func)
            else:
                self._lookup[type] = func
            return func

        return wrapper(func) if func is not None else wrapper

    def register_lazy(self, toplevel, func=None):
        """
        Register a registration function which will be called if the
        *toplevel* module (e.g. 'pandas') is ever loaded.
        """

        def wrapper(func):
            self._lazy[toplevel] = func
            return func

        return wrapper(func) if func is not None else wrapper

    def dispatch(self, cls):
        """Return the function implementation for the given ``cls``"""
        lk = self._lookup
        if cls in lk:
            return lk[cls]
        for cls2 in cls.__mro__:
            # Is a lazy registration function present?
            try:
                toplevel, _, _ = cls2.__module__.partition(".")
            except Exception:
                continue
            try:
                register = self._lazy[toplevel]
            except KeyError:
                pass
            else:
                register()
                self._lazy.pop(toplevel, None)
                meth = self.dispatch(cls)  # recurse
                lk[cls] = meth
                lk[cls2] = meth
                return meth
            try:
                impl = lk[cls2]
            except KeyError:
                pass
            else:
                if cls is not cls2:
                    # Cache lookup
                    lk[cls] = impl
                return impl
        raise TypeError(f"No dispatch for {cls}")

    def __call__(self, arg, *args, **kwargs):
        """
        Call the corresponding method based on type of argument.
        """
        meth = self.dispatch(type(arg))
        return meth(arg, *args, **kwargs)

    @property
    def __doc__(self):
        try:
            func = self.dispatch(object)
            return func.__doc__
        except TypeError:
            return "Single Dispatch for %s" % self.__name__
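

# A minimal sketch of how ``Dispatch`` is typically used (the ``sizeof``
# name and implementations are illustrative, not from this module):
#
# >>> sizeof = Dispatch(name="sizeof")
# >>> @sizeof.register(list)
# ... def sizeof_list(x):
# ...     return len(x)
# >>> @sizeof.register_lazy("numpy")
# ... def register_numpy():
# ...     import numpy as np
# ...     sizeof.register(np.ndarray, lambda x: x.nbytes)
# >>> sizeof([1, 2, 3])  # dispatches on type(arg); lazy hook fires on demand
# 3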


def ensure_not_exists(filename) -> None:
    """
    Ensure that a file does not exist.
    """
    try:
        os.unlink(filename)
    except OSError as e:
        if e.errno != ENOENT:
            raise


def _skip_doctest(line):
    # NumPy docstring contains cursor and comment only example
    stripped = line.strip()
    if stripped == ">>>" or stripped.startswith(">>> #"):
        return line
    elif ">>>" in stripped and "+SKIP" not in stripped:
        if "# doctest:" in line:
            return line + ", +SKIP"
        else:
            return line + "  # doctest: +SKIP"
    else:
        return line


def skip_doctest(doc):
    if doc is None:
        return ""
    return "\n".join([_skip_doctest(line) for line in doc.split("\n")])


def extra_titles(doc):
    lines = doc.split("\n")
    titles = {
        i: lines[i].strip()
        for i in range(len(lines) - 1)
        if lines[i + 1].strip() and all(c == "-" for c in lines[i + 1].strip())
    }

    seen = set()
    for i, title in sorted(titles.items()):
        if title in seen:
            new_title = "Extra " + title
            lines[i] = lines[i].replace(title, new_title)
            lines[i + 1] = lines[i + 1].replace("-" * len(title), "-" * len(new_title))
        else:
            seen.add(title)

    return "\n".join(lines)


def ignore_warning(doc, cls, name, extra="", skipblocks=0, inconsistencies=None):
    """Expand docstring by adding disclaimer and extra text"""
    import inspect

    if inspect.isclass(cls):
        l1 = "This docstring was copied from {}.{}.{}.\n\n".format(
            cls.__module__,
            cls.__name__,
            name,
        )
    else:
        l1 = f"This docstring was copied from {cls.__name__}.{name}.\n\n"
    l2 = "Some inconsistencies with the Dask version may exist."

    i = doc.find("\n\n")
    if i != -1:
        # Insert our warning
        head = doc[: i + 2]
        tail = doc[i + 2 :]
        while skipblocks > 0:
            i = tail.find("\n\n")
            head = tail[: i + 2]
            tail = tail[i + 2 :]
            skipblocks -= 1
        # Indentation of next line
        indent = re.match(r"\s*", tail).group(0)
        # Insert the warning, indented, with a blank line before and after
        if extra:
            more = [indent, extra.rstrip("\n") + "\n\n"]
        else:
            more = []
        if inconsistencies is not None:
            l3 = f"Known inconsistencies: \n {inconsistencies}"
            bits = [head, indent, l1, l2, "\n\n", l3, "\n\n"] + more + [tail]
        else:
            bits = [head, indent, l1, indent, l2, "\n\n"] + more + [tail]
        doc = "".join(bits)

    return doc


def unsupported_arguments(doc, args):
    """Mark unsupported arguments with a disclaimer"""
    lines = doc.split("\n")
    for arg in args:
        subset = [
            (i, line)
            for i, line in enumerate(lines)
            if re.match(r"^\s*" + arg + " ?:", line)
        ]
        if len(subset) == 1:
            [(i, line)] = subset
            lines[i] = line + "  (Not supported in Dask)"
    return "\n".join(lines)


def _derived_from(
    cls, method, ua_args=None, extra="", skipblocks=0, inconsistencies=None
):
    """Helper function for derived_from to ease testing"""
    ua_args = ua_args or []

    # do not use wraps here, as it hides keyword arguments displayed
    # in the doc
    original_method = getattr(cls, method.__name__)

    doc = getattr(original_method, "__doc__", None)

    if isinstance(original_method, property):
        # some things like SeriesGroupBy.unique are generated.
        original_method = original_method.fget
        if not doc:
            doc = getattr(original_method, "__doc__", None)

    if isinstance(original_method, functools.cached_property):
        original_method = original_method.func
        if not doc:
            doc = getattr(original_method, "__doc__", None)

    if doc is None:
        doc = ""

    # pandas DataFrame/Series sometimes override methods without setting __doc__
    if not doc and cls.__name__ in {"DataFrame", "Series"}:
        for obj in cls.mro():
            obj_method = getattr(obj, method.__name__, None)
            if obj_method is not None and obj_method.__doc__:
                doc = obj_method.__doc__
                break

    # Insert disclaimer that this is a copied docstring
    if doc:
        doc = ignore_warning(
            doc,
            cls,
            method.__name__,
            extra=extra,
            skipblocks=skipblocks,
            inconsistencies=inconsistencies,
        )
    elif extra:
        doc += extra.rstrip("\n") + "\n\n"

    # Mark unsupported arguments
    try:
        method_args = get_named_args(method)
        original_args = get_named_args(original_method)
        not_supported = [m for m in original_args if m not in method_args]
    except ValueError:
        not_supported = []
    if len(ua_args) > 0:
        not_supported.extend(ua_args)
    if len(not_supported) > 0:
        doc = unsupported_arguments(doc, not_supported)

    doc = skip_doctest(doc)
    doc = extra_titles(doc)

    return doc


def derived_from(
    original_klass, version=None, ua_args=None, skipblocks=0, inconsistencies=None
):
    """Decorator to attach original class's docstring to the wrapped method.

    The output structure will be: top line of docstring, disclaimer about this
    being auto-derived, any extra text associated with the method being patched,
    the body of the docstring and finally, the list of keywords that exist in
    the original method but not in the dask version.

    Parameters
    ----------
    original_klass: type
        Original class which the method is derived from
    version : str
        Original package version which supports the wrapped method
    ua_args : list
        List of keywords which Dask doesn't support. Keywords existing in
        original but not in Dask will automatically be added.
    skipblocks : int
        How many text blocks (paragraphs) to skip from the start of the
        docstring. Useful for cases where the target has extra front-matter.
    inconsistencies: list
        List of known inconsistencies with the method whose docstring is
        being copied.
    """
    ua_args = ua_args or []

    def wrapper(method):
        try:
            extra = getattr(method, "__doc__", None) or ""
            method.__doc__ = _derived_from(
                original_klass,
                method,
                ua_args=ua_args,
                extra=extra,
                skipblocks=skipblocks,
                inconsistencies=inconsistencies,
            )
            return method

        except AttributeError:
            module_name = original_klass.__module__.split(".")[0]

            @functools.wraps(method)
            def wrapped(*args, **kwargs):
                msg = f"Base package doesn't support '{method.__name__}'."
                if version is not None:
                    msg2 = " Use {0} {1} or later to use this method."
                    msg += msg2.format(module_name, version)
                raise NotImplementedError(msg)

            return wrapped

    return wrapper
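

# A minimal sketch of ``derived_from`` in use (the classes below are
# hypothetical stand-ins for an upstream class and its Dask wrapper):
#
# >>> class Original:
# ...     def head(self, n=5):
# ...         """Return the first ``n`` rows.
# ...
# ...         More detail here.
# ...         """
# >>> class Wrapper:
# ...     @derived_from(Original)
# ...     def head(self, n=5):
# ...         pass
# >>> "This docstring was copied from" in Wrapper.head.__doc__
# True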


def funcname(func) -> str:
    """Get the name of a function."""
    # functools.partial
    if isinstance(func, functools.partial):
        return funcname(func.func)
    # methodcaller
    if isinstance(func, methodcaller):
        return func.method[:50]

    module_name = getattr(func, "__module__", None) or ""
    type_name = getattr(type(func), "__name__", None) or ""

    # toolz.curry
    if "toolz" in module_name and "curry" == type_name:
        return func.func_name[:50]
    # multipledispatch objects
    if "multipledispatch" in module_name and "Dispatcher" == type_name:
        return func.name[:50]
    # numpy.vectorize objects
    if "numpy" in module_name and "vectorize" == type_name:
        return ("vectorize_" + funcname(func.pyfunc))[:50]

    # All other callables
    try:
        name = func.__name__
        if name == "<lambda>":
            return "lambda"
        return name[:50]
    except AttributeError:
        return str(func)[:50]


def typename(typ: Any, short: bool = False) -> str:
    """
    Return the name of a type

    Examples
    --------
    >>> typename(int)
    'int'

    >>> from dask.core import literal
    >>> typename(literal)
    'dask.core.literal'

    >>> typename(literal, short=True)
    'dask.literal'
    """
    if not isinstance(typ, type):
        return typename(type(typ))
    try:
        if not typ.__module__ or typ.__module__ == "builtins":
            return typ.__name__
        else:
            if short:
                module, *_ = typ.__module__.split(".")
            else:
                module = typ.__module__
            return module + "." + typ.__name__
    except AttributeError:
        return str(typ)


def ensure_bytes(s) -> bytes:
    """Attempt to turn `s` into bytes.

    Parameters
    ----------
    s : Any
        The object to be converted. The following are handled correctly:

        * str
        * bytes
        * objects implementing the buffer protocol (memoryview, ndarray, etc.)

    Returns
    -------
    b : bytes

    Raises
    ------
    TypeError
        When `s` cannot be converted

    Examples
    --------
    >>> ensure_bytes('123')
    b'123'
    >>> ensure_bytes(b'123')
    b'123'
    >>> ensure_bytes(bytearray(b'123'))
    b'123'
    """
    if isinstance(s, bytes):
        return s
    elif hasattr(s, "encode"):
        return s.encode()
    else:
        try:
            return bytes(s)
        except Exception as e:
            raise TypeError(
                f"Object {s} is neither a bytes object nor can be encoded to bytes"
            ) from e


def ensure_unicode(s) -> str:
    """Turn string or bytes to string

    >>> ensure_unicode('123')
    '123'
    >>> ensure_unicode(b'123')
    '123'
    """
    if isinstance(s, str):
        return s
    elif hasattr(s, "decode"):
        return s.decode()
    else:
        try:
            return codecs.decode(s)
        except Exception as e:
            raise TypeError(
                f"Object {s} is neither a str object nor can be decoded to str"
            ) from e


def digit(n, k, base):
    """

    >>> digit(1234, 0, 10)
    4
    >>> digit(1234, 1, 10)
    3
    >>> digit(1234, 2, 10)
    2
    >>> digit(1234, 3, 10)
    1
    """
    return n // base**k % base


def insert(tup, loc, val):
    """

    >>> insert(('a', 'b', 'c'), 0, 'x')
    ('x', 'b', 'c')
    """
    L = list(tup)
    L[loc] = val
    return tuple(L)


def memory_repr(num):
    for x in ["bytes", "KB", "MB", "GB", "TB"]:
        if num < 1024.0:
            return f"{num:3.1f} {x}"
        num /= 1024.0


def asciitable(columns, rows):
    """Formats an ascii table for given columns and rows.

    Parameters
    ----------
    columns : list
        The column names
    rows : list of tuples
        The rows in the table. Each tuple must be the same length as
        ``columns``.
    """
    rows = [tuple(str(i) for i in r) for r in rows]
    columns = tuple(str(i) for i in columns)
    widths = tuple(max(max(map(len, x)), len(c)) for x, c in zip(zip(*rows), columns))
    row_template = ("|" + (" %%-%ds |" * len(columns))) % widths
    header = row_template % tuple(columns)
    bar = "+%s+" % "+".join("-" * (w + 2) for w in widths)
    data = "\n".join(row_template % r for r in rows)
    return "\n".join([bar, header, bar, data, bar])


def put_lines(buf, lines):
    if any(not isinstance(x, str) for x in lines):
        lines = [str(x) for x in lines]
    buf.write("\n".join(lines))


_method_cache: dict[str, methodcaller] = {}


class methodcaller:
    """
    Return a callable object that calls the given method on its operand.

    Unlike the builtin `operator.methodcaller`, instances of this class are
    cached and arguments are passed at call time instead of build time.
    """

    __slots__ = ("method",)
    method: str

    @property
    def func(self) -> str:
        # For `funcname` to work
        return self.method

    def __new__(cls, method: str):
        try:
            return _method_cache[method]
        except KeyError:
            self = object.__new__(cls)
            self.method = method
            _method_cache[method] = self
            return self

    def __call__(self, __obj, *args, **kwargs):
        return getattr(__obj, self.method)(*args, **kwargs)

    def __reduce__(self):
        return (methodcaller, (self.method,))

    def __str__(self):
        return f"<{self.__class__.__name__}: {self.method}>"

    __repr__ = __str__


class itemgetter:
    """Variant of operator.itemgetter that supports equality tests"""

    __slots__ = ("index",)

    def __init__(self, index):
        self.index = index

    def __call__(self, x):
        return x[self.index]

    def __reduce__(self):
        return (itemgetter, (self.index,))

    def __eq__(self, other):
        return type(self) is type(other) and self.index == other.index


class MethodCache:
    """Attribute access on this object returns a methodcaller for that
    attribute.

    Examples
    --------
    >>> a = [1, 3, 3]
    >>> M.count(a, 3) == a.count(3)
    True
    """

    def __getattr__(self, item):
        return methodcaller(item)

    def __dir__(self):
        return list(_method_cache)


M = MethodCache()


class SerializableLock:
    """A Serializable per-process Lock

    This wraps a normal ``threading.Lock`` object and satisfies the same
    interface. However, this lock can also be serialized and sent to different
    processes. It will not block concurrent operations between processes (for
    this you should look at ``multiprocessing.Lock`` or ``locket.lock_file``)
    but will consistently deserialize into the same lock.

    So if we make a lock in one process::

        lock = SerializableLock()

    And then send it over to another process multiple times::

        bytes = pickle.dumps(lock)
        a = pickle.loads(bytes)
        b = pickle.loads(bytes)

    Then the deserialized objects will operate as though they were the same
    lock, and collide as appropriate.

    This is useful for consistently protecting resources on a per-process
    level.

    The creation of locks is itself not threadsafe.
    """

    _locks: ClassVar[WeakValueDictionary[Hashable, Lock]] = WeakValueDictionary()
    token: Hashable
    lock: Lock

    def __init__(self, token: Hashable | None = None):
        self.token = token or str(uuid.uuid4())
        if self.token in SerializableLock._locks:
            self.lock = SerializableLock._locks[self.token]
        else:
            self.lock = Lock()
            SerializableLock._locks[self.token] = self.lock

    def acquire(self, *args, **kwargs):
        return self.lock.acquire(*args, **kwargs)

    def release(self, *args, **kwargs):
        return self.lock.release(*args, **kwargs)

    def __enter__(self):
        self.lock.__enter__()

    def __exit__(self, *args):
        self.lock.__exit__(*args)

    def locked(self):
        return self.lock.locked()

    def __getstate__(self):
        return self.token

    def __setstate__(self, token):
        self.__init__(token)

    def __str__(self):
        return f"<{self.__class__.__name__}: {self.token}>"

    __repr__ = __str__
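

# A minimal sketch of the round-trip behaviour described above (illustrative):
#
# >>> import pickle
# >>> lock = SerializableLock()
# >>> lock2 = pickle.loads(pickle.dumps(lock))
# >>> lock.lock is lock2.lock  # same underlying threading.Lock per process
# True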


def get_scheduler_lock(collection=None, scheduler=None):
    """Get an instance of the appropriate lock for a certain situation based on
    scheduler used."""
    from dask import multiprocessing
    from dask.base import get_scheduler

    actual_get = get_scheduler(collections=[collection], scheduler=scheduler)

    if actual_get == multiprocessing.get:
        return multiprocessing.get_context().Manager().Lock()
    else:
        # if this is a distributed client, we need to lock on
        # the level between processes, SerializableLock won't work
        try:
            import distributed.lock
            from distributed.worker import get_client

            client = get_client()
        except (ImportError, ValueError):
            pass
        else:
            if actual_get == client.get:
                return distributed.lock.Lock()

    return SerializableLock()


def ensure_dict(d: Mapping[K, V], *, copy: bool = False) -> dict[K, V]:
    """Convert a generic Mapping into a dict.

    Optimized for the use case of :class:`~dask.highlevelgraph.HighLevelGraph`.

    Parameters
    ----------
    d : Mapping
    copy : bool
        If True, guarantee that the return value is always a shallow copy of d;
        otherwise it may be the input itself.
    """
    if type(d) is dict:
        return d.copy() if copy else d
    try:
        layers = d.layers  # type: ignore
    except AttributeError:
        return dict(d)

    result = {}
    for layer in toolz.unique(layers.values(), key=id):
        result.update(layer)
    return result
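

# A quick illustration of ``ensure_dict`` semantics (values are arbitrary):
#
# >>> d = {"x": 1}
# >>> ensure_dict(d) is d  # plain dicts pass through unchanged
# True
# >>> ensure_dict(d, copy=True) is d
# False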


def ensure_set(s: Set[T], *, copy: bool = False) -> set[T]:
    """Convert a generic Set into a set.

    Parameters
    ----------
    s : Set
    copy : bool
        If True, guarantee that the return value is always a shallow copy of s;
        otherwise it may be the input itself.
    """
    if type(s) is set:
        return s.copy() if copy else s
    return set(s)


class OperatorMethodMixin:
    """A mixin for dynamically implementing operators"""

    __slots__ = ()

    @classmethod
    def _bind_operator(cls, op):
        """bind operator to this class"""
        name = op.__name__

        if name.endswith("_"):
            # for and_ and or_
            name = name[:-1]
        elif name == "inv":
            name = "invert"

        meth = f"__{name}__"

        if name in ("abs", "invert", "neg", "pos"):
            setattr(cls, meth, cls._get_unary_operator(op))
        else:
            setattr(cls, meth, cls._get_binary_operator(op))

            if name in ("eq", "gt", "ge", "lt", "le", "ne", "getitem"):
                return

            rmeth = f"__r{name}__"
            setattr(cls, rmeth, cls._get_binary_operator(op, inv=True))

    @classmethod
    def _get_unary_operator(cls, op):
        """Must return a method used by unary operator"""
        raise NotImplementedError

    @classmethod
    def _get_binary_operator(cls, op, inv=False):
        """Must return a method used by binary operator"""
        raise NotImplementedError


def partial_by_order(*args, **kwargs):
    """

    >>> from operator import add
    >>> partial_by_order(5, function=add, other=[(1, 10)])
    15
    """
    function = kwargs.pop("function")
    other = kwargs.pop("other")
    args2 = list(args)
    for i, arg in other:
        args2.insert(i, arg)
    return function(*args2, **kwargs)


def is_arraylike(x) -> bool:
    """Is this object a numpy array or something similar?

    This function tests specifically for an object that already has
    array attributes (e.g. np.ndarray, dask.array.Array, cupy.ndarray,
    sparse.COO), **NOT** for something that can be coerced into an
    array object (e.g. Python lists and tuples). It is meant for dask
    developers and developers of downstream libraries.

    Note that this function does not correspond with NumPy's
    definition of array_like, which includes any object that can be
    coerced into an array (see definition in the NumPy glossary):
    https://numpy.org/doc/stable/glossary.html

    Examples
    --------
    >>> import numpy as np
    >>> is_arraylike(np.ones(5))
    True
    >>> is_arraylike(np.ones(()))
    True
    >>> is_arraylike(5)
    False
    >>> is_arraylike('cat')
    False
    """
    from dask.base import is_dask_collection

    is_duck_array = hasattr(x, "__array_function__") or hasattr(x, "__array_ufunc__")

    return bool(
        hasattr(x, "shape")
        and isinstance(x.shape, tuple)
        and hasattr(x, "dtype")
        and not any(is_dask_collection(n) for n in x.shape)
        # We special case scipy.sparse and cupyx.scipy.sparse arrays as having partial
        # support for them is useful in scenarios where we mostly call `map_partitions`
        # or `map_blocks` with scikit-learn functions on dask arrays and dask dataframes.
        # https://github.com/dask/dask/pull/3738
        and (is_duck_array or "scipy.sparse" in typename(type(x)))
    )


def is_dataframe_like(df) -> bool:
    """Looks like a Pandas DataFrame"""
    if (df.__class__.__module__, df.__class__.__name__) == (
        "pandas.core.frame",
        "DataFrame",
    ):
        # fast path for the most likely input
        return True
    typ = df.__class__
    return (
        all(hasattr(typ, name) for name in ("groupby", "head", "merge", "mean"))
        and all(hasattr(df, name) for name in ("dtypes", "columns"))
        and not any(hasattr(typ, name) for name in ("name", "dtype"))
    )


def is_series_like(s) -> bool:
    """Looks like a Pandas Series"""
    typ = s.__class__
    return (
        all(hasattr(typ, name) for name in ("groupby", "head", "mean"))
        and all(hasattr(s, name) for name in ("dtype", "name"))
        and "index" not in typ.__name__.lower()
    )


def is_index_like(s) -> bool:
    """Looks like a Pandas Index"""
    typ = s.__class__
    return (
        all(hasattr(s, name) for name in ("name", "dtype"))
        and "index" in typ.__name__.lower()
    )


def is_cupy_type(x) -> bool:
    # TODO: avoid explicit reference to CuPy
    return "cupy" in str(type(x))


def natural_sort_key(s: str) -> list[str | int]:
    """
    Sorting `key` function for performing a natural sort on a collection of
    strings

    See https://en.wikipedia.org/wiki/Natural_sort_order

    Parameters
    ----------
    s : str
        A string that is an element of the collection being sorted

    Returns
    -------
    list[str or int]
        List of the parts of the input string where each part is either a
        string or an integer

    Examples
    --------
    >>> a = ['f0', 'f1', 'f2', 'f8', 'f9', 'f10', 'f11', 'f19', 'f20', 'f21']
    >>> sorted(a)
    ['f0', 'f1', 'f10', 'f11', 'f19', 'f2', 'f20', 'f21', 'f8', 'f9']
    >>> sorted(a, key=natural_sort_key)
    ['f0', 'f1', 'f2', 'f8', 'f9', 'f10', 'f11', 'f19', 'f20', 'f21']
    """
    return [int(part) if part.isdigit() else part for part in re.split(r"(\d+)", s)]


def parse_bytes(s: float | str) -> int:
    """Parse byte string to numbers

    >>> from dask.utils import parse_bytes
    >>> parse_bytes('100')
    100
    >>> parse_bytes('100 MB')
    100000000
    >>> parse_bytes('100M')
    100000000
    >>> parse_bytes('5kB')
    5000
    >>> parse_bytes('5.4 kB')
    5400
    >>> parse_bytes('1kiB')
    1024
    >>> parse_bytes('1e6')
    1000000
    >>> parse_bytes('1e6 kB')
    1000000000
    >>> parse_bytes('MB')
    1000000
    >>> parse_bytes(123)
    123
    >>> parse_bytes('5 foos')
    Traceback (most recent call last):
        ...
    ValueError: Could not interpret 'foos' as a byte unit
    """
    if isinstance(s, (int, float)):
        return int(s)
    s = s.replace(" ", "")
    if not any(char.isdigit() for char in s):
        s = "1" + s

    for i in range(len(s) - 1, -1, -1):
        if not s[i].isalpha():
            break
    index = i + 1

    prefix = s[:index]
    suffix = s[index:]

    try:
        n = float(prefix)
    except ValueError as e:
        raise ValueError("Could not interpret '%s' as a number" % prefix) from e

    try:
        multiplier = byte_sizes[suffix.lower()]
    except KeyError as e:
        raise ValueError("Could not interpret '%s' as a byte unit" % suffix) from e

    result = n * multiplier
    return int(result)


byte_sizes = {
    "kB": 10**3,
    "MB": 10**6,
    "GB": 10**9,
    "TB": 10**12,
    "PB": 10**15,
    "KiB": 2**10,
    "MiB": 2**20,
    "GiB": 2**30,
    "TiB": 2**40,
    "PiB": 2**50,
    "B": 1,
    "": 1,
}
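# Normalize keys to lowercase, then add shorthand aliases: bare single letters
# ("k", "m", ...) for the decimal units and trimmed binary prefixes
# ("ki", "mi", ...) for the binary units. (Descriptive comment added here.)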
byte_sizes = {k.lower(): v for k, v in byte_sizes.items()}
byte_sizes.update({k[0]: v for k, v in byte_sizes.items() if k and "i" not in k})
byte_sizes.update({k[:-1]: v for k, v in byte_sizes.items() if k and "i" in k})


def format_time(n: float) -> str:
    """Format a duration in seconds as human-readable text

    >>> from dask.utils import format_time
    >>> format_time(1)
    '1.00 s'
    >>> format_time(0.001234)
    '1.23 ms'
    >>> format_time(0.00012345)
    '123.45 us'
    >>> format_time(123.456)
    '123.46 s'
    >>> format_time(1234.567)
    '20m 34s'
    >>> format_time(12345.67)
    '3hr 25m'
    >>> format_time(123456.78)
    '34hr 17m'
    >>> format_time(1234567.89)
    '14d 6hr'
    """
    if n > 24 * 60 * 60 * 2:
        d = int(n / 3600 / 24)
        h = int((n - d * 3600 * 24) / 3600)
        return f"{d}d {h}hr"
    if n > 60 * 60 * 2:
        h = int(n / 3600)
        m = int((n - h * 3600) / 60)
        return f"{h}hr {m}m"
    if n > 60 * 10:
        m = int(n / 60)
        s = int(n - m * 60)
        return f"{m}m {s}s"
    if n >= 1:
        return "%.2f s" % n
    if n >= 1e-3:
        return "%.2f ms" % (n * 1e3)
    return "%.2f us" % (n * 1e6)


def format_time_ago(n: datetime) -> str:
    """Calculate a '3 hours ago' type string from a Python datetime.

    Examples
    --------
    >>> from datetime import datetime, timedelta

    >>> now = datetime.now()
    >>> format_time_ago(now)
    'Just now'

    >>> past = datetime.now() - timedelta(minutes=1)
    >>> format_time_ago(past)
    '1 minute ago'

    >>> past = datetime.now() - timedelta(minutes=2)
    >>> format_time_ago(past)
    '2 minutes ago'

    >>> past = datetime.now() - timedelta(hours=1)
    >>> format_time_ago(past)
    '1 hour ago'

    >>> past = datetime.now() - timedelta(hours=6)
    >>> format_time_ago(past)
    '6 hours ago'

    >>> past = datetime.now() - timedelta(days=1)
    >>> format_time_ago(past)
    '1 day ago'

    >>> past = datetime.now() - timedelta(days=5)
    >>> format_time_ago(past)
    '5 days ago'

    >>> past = datetime.now() - timedelta(days=8)
    >>> format_time_ago(past)
    '1 week ago'

    >>> past = datetime.now() - timedelta(days=16)
    >>> format_time_ago(past)
    '2 weeks ago'

    >>> past = datetime.now() - timedelta(days=190)
    >>> format_time_ago(past)
    '6 months ago'

    >>> past = datetime.now() - timedelta(days=800)
    >>> format_time_ago(past)
    '2 years ago'

    """
    units = {
        "years": lambda diff: diff.days / 365,
        "months": lambda diff: diff.days / 30.436875,  # Average days per month
        "weeks": lambda diff: diff.days / 7,
        "days": lambda diff: diff.days,
        "hours": lambda diff: diff.seconds / 3600,
        "minutes": lambda diff: diff.seconds % 3600 / 60,
    }
    diff = datetime.now() - n
    for unit in units:
        dur = int(units[unit](diff))
        if dur > 0:
            if dur == 1:  # De-pluralize
                unit = unit[:-1]
            return f"{dur} {unit} ago"
    return "Just now"


def format_bytes(n: int) -> str:
    """Format bytes as text

    >>> from dask.utils import format_bytes
    >>> format_bytes(1)
    '1 B'
    >>> format_bytes(1234)
    '1.21 kiB'
    >>> format_bytes(12345678)
    '11.77 MiB'
    >>> format_bytes(1234567890)
    '1.15 GiB'
    >>> format_bytes(1234567890000)
    '1.12 TiB'
    >>> format_bytes(1234567890000000)
    '1.10 PiB'

    For all values < 2**60, the output is always <= 10 characters.
    """
    for prefix, k in (
        ("Pi", 2**50),
        ("Ti", 2**40),
        ("Gi", 2**30),
        ("Mi", 2**20),
        ("ki", 2**10),
    ):
        if n >= k * 0.9:
            return f"{n / k:.2f} {prefix}B"
    return f"{n} B"


timedelta_sizes = {
    "s": 1,
    "ms": 1e-3,
    "us": 1e-6,
    "ns": 1e-9,
    "m": 60,
    "h": 3600,
    "d": 3600 * 24,
    "w": 7 * 3600 * 24,
}

tds2 = {
    "second": 1,
    "minute": 60,
    "hour": 60 * 60,
    "day": 60 * 60 * 24,
    "week": 7 * 60 * 60 * 24,
    "millisecond": 1e-3,
    "microsecond": 1e-6,
    "nanosecond": 1e-9,
}
tds2.update({k + "s": v for k, v in tds2.items()})
timedelta_sizes.update(tds2)
timedelta_sizes.update({k.upper(): v for k, v in timedelta_sizes.items()})


@overload
def parse_timedelta(s: None, default: str | Literal[False] = "seconds") -> None: ...


@overload
def parse_timedelta(
    s: str | float | timedelta, default: str | Literal[False] = "seconds"
) -> float: ...


def parse_timedelta(s, default="seconds"):
    """Parse timedelta string to number of seconds

    Parameters
    ----------
    s : str, float, timedelta, or None
    default: str or False, optional
        Unit of measure if s does not specify one. Defaults to seconds.
        Set to False to require s to explicitly specify its own unit.

    Examples
    --------
    >>> from datetime import timedelta
    >>> from dask.utils import parse_timedelta
    >>> parse_timedelta('3s')
    3
    >>> parse_timedelta('3.5 seconds')
    3.5
    >>> parse_timedelta('300ms')
    0.3
    >>> parse_timedelta(timedelta(seconds=3))  # also supports timedeltas
    3
    """
    if s is None:
        return None
    if isinstance(s, timedelta):
        s = s.total_seconds()
        return int(s) if int(s) == s else s
    if isinstance(s, Number):
        s = str(s)
    s = s.replace(" ", "")
    if not s[0].isdigit():
        s = "1" + s

    for i in range(len(s) - 1, -1, -1):
        if not s[i].isalpha():
            break
    index = i + 1

    prefix = s[:index]
    suffix = s[index:] or default
    if suffix is False:
        raise ValueError(f"Missing time unit: {s}")
    if not isinstance(suffix, str):
        raise TypeError(f"default must be str or False, got {default!r}")

    n = float(prefix)

    try:
        multiplier = timedelta_sizes[suffix.lower()]
    except KeyError:
        valid_units = ", ".join(timedelta_sizes.keys())
        raise KeyError(
            f"Invalid time unit: {suffix}. Valid units are: {valid_units}"
        ) from None

    result = n * multiplier
    if int(result) == result:
        result = int(result)
    return result


def has_keyword(func, keyword):
    try:
        return keyword in inspect.signature(func).parameters
    except Exception:
        return False


def ndimlist(seq):
    if not isinstance(seq, (list, tuple)):
        return 0
    elif not seq:
        return 1
    else:
        return 1 + ndimlist(seq[0])


def iter_chunks(sizes, max_size):
    """Split sizes into chunks of total max_size each

    Parameters
    ----------
    sizes : iterable of numbers
        The sizes to be chunked
    max_size : number
        Maximum total size per chunk.
        It must be greater than or equal to each size in sizes
    """
    chunk, chunk_sum = [], 0
    iter_sizes = iter(sizes)
    size = next(iter_sizes, None)
    while size is not None:
        assert size <= max_size
        if chunk_sum + size <= max_size:
            chunk.append(size)
            chunk_sum += size
            size = next(iter_sizes, None)
        else:
            assert chunk
            yield chunk
            chunk, chunk_sum = [], 0
    if chunk:
        yield chunk
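

# A short illustration of ``iter_chunks`` (the numbers are arbitrary):
#
# >>> list(iter_chunks([1, 2, 3, 4, 5], 6))
# [[1, 2, 3], [4], [5]]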


hex_pattern = re.compile("[a-f]+")


@functools.lru_cache(100000)
def key_split(s):
    """
    >>> key_split('x')
    'x'
    >>> key_split('x-1')
    'x'
    >>> key_split('x-1-2-3')
    'x'
    >>> key_split(('x-2', 1))
    'x'
    >>> key_split("('x-2', 1)")
    'x'
    >>> key_split("('x', 1)")
    'x'
    >>> key_split('hello-world-1')
    'hello-world'
    >>> key_split(b'hello-world-1')
    'hello-world'
    >>> key_split('ae05086432ca935f6eba409a8ecd4896')
    'data'
    >>> key_split('<module.submodule.myclass object at 0xdaf372')
    'myclass'
    >>> key_split(None)
    'Other'
    >>> key_split('x-abcdefab')  # ignores hex
    'x'
    >>> key_split('_(x)')  # strips unpleasant characters
    'x'
    """
    # If we convert the key, recurse to utilize LRU cache better
    if type(s) is bytes:
        return key_split(s.decode())
    if type(s) is tuple:
        return key_split(s[0])
    try:
        words = s.split("-")
        if not words[0][0].isalpha():
            result = words[0].split(",")[0].strip("_'()\"")
        else:
            result = words[0]
        for word in words[1:]:
            if word.isalpha() and not (
                len(word) == 8 and hex_pattern.match(word) is not None
            ):
                result += "-" + word
            else:
                break
        if len(result) == 32 and re.match(r"[a-f0-9]{32}", result):
            return "data"
        else:
            if result[0] == "<":
                result = result.strip("<>").split()[0].split(".")[-1]
            return sys.intern(result)
    except Exception:
        return "Other"


def stringify(obj, exclusive: Iterable | None = None):
    """Convert an object to a string

    If ``exclusive`` is specified, search through `obj` and convert
    values that are in ``exclusive``.

    Note that when searching through dictionaries, only values are
    converted, not the keys.

    Parameters
    ----------
    obj : Any
        Object (or values within) to convert to string
    exclusive: Iterable, optional
        Set of values to search for when converting values to strings

    Returns
    -------
    result : type(obj)
        Stringified copy of ``obj`` or ``obj`` itself if it is already a
        string or bytes.

    Examples
    --------
    >>> stringify(b'x')
    b'x'
    >>> stringify('x')
    'x'
    >>> stringify({('a',0):('a',0), ('a',1): ('a',1)})
    "{('a', 0): ('a', 0), ('a', 1): ('a', 1)}"
    >>> stringify({('a',0):('a',0), ('a',1): ('a',1)}, exclusive={('a',0)})
    {('a', 0): "('a', 0)", ('a', 1): ('a', 1)}
    """

    typ = type(obj)
    if typ is str or typ is bytes:
        return obj
    elif exclusive is None:
        return str(obj)

    if typ is list:
        return [stringify(v, exclusive) for v in obj]
    if typ is dict:
        return {k: stringify(v, exclusive) for k, v in obj.items()}
    try:
        if obj in exclusive:
            return stringify(obj)
    except TypeError:  # `obj` not hashable
        pass
    if typ is tuple:  # If the tuple itself isn't a key, check its elements
        return tuple(stringify(v, exclusive) for v in obj)
    return obj


class cached_property(functools.cached_property):
    """Read only version of functools.cached_property."""

    def __set__(self, instance, val):
        """Raise an error when attempting to set a cached property."""
        raise AttributeError("Can't set attribute")


class _HashIdWrapper:
    """Hash and compare a wrapped object by identity instead of value"""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __eq__(self, other):
        if not isinstance(other, _HashIdWrapper):
            return NotImplemented
        return self.wrapped is other.wrapped

    def __ne__(self, other):
        if not isinstance(other, _HashIdWrapper):
            return NotImplemented
        return self.wrapped is not other.wrapped

    def __hash__(self):
        return id(self.wrapped)


@functools.lru_cache
def _cumsum(seq, initial_zero):
    if isinstance(seq, _HashIdWrapper):
        seq = seq.wrapped
    if initial_zero:
        return tuple(toolz.accumulate(add, seq, 0))
    else:
        return tuple(toolz.accumulate(add, seq))


@functools.lru_cache
def _max(seq):
    if isinstance(seq, _HashIdWrapper):
        seq = seq.wrapped
    return max(seq)


def cached_max(seq):
    """Compute max with caching.

    Caching is by the identity of `seq` rather than the value. It is thus
    important that `seq` is a tuple of immutable objects, and this function
    is intended for use where `seq` is a value that will persist (generally
    block sizes).

    Parameters
    ----------
    seq : tuple
        Values to reduce

    Returns
    -------
    The maximum element of ``seq``
    """
    assert isinstance(seq, tuple)
    # Look up by identity first, to avoid a linear-time __hash__
    # if we've seen this tuple object before.
    result = _max(_HashIdWrapper(seq))
    return result


def cached_cumsum(seq, initial_zero=False):
    """Compute :meth:`toolz.accumulate` with caching.

    Caching is by the identity of `seq` rather than the value. It is thus
    important that `seq` is a tuple of immutable objects, and this function
    is intended for use where `seq` is a value that will persist (generally
    block sizes).

    Parameters
    ----------
    seq : tuple
        Values to cumulatively sum.
    initial_zero : bool, optional
        If true, the return value is prefixed with a zero.

    Returns
    -------
    tuple
    """
    if isinstance(seq, tuple):
        # Look up by identity first, to avoid a linear-time __hash__
        # if we've seen this tuple object before.
        result = _cumsum(_HashIdWrapper(seq), initial_zero)
    else:
        # Construct a temporary tuple, and look up by value.
        result = _cumsum(tuple(seq), initial_zero)
    return result
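

# A short illustration of ``cached_cumsum`` (the values are arbitrary):
#
# >>> x = (1, 2, 3)
# >>> cached_cumsum(x)
# (1, 3, 6)
# >>> cached_cumsum(x, initial_zero=True)
# (0, 1, 3, 6)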


def show_versions() -> None:
    """Provide version information for bug reports."""

    from json import dumps
    from platform import uname
    from sys import stdout, version_info

    from dask._compatibility import importlib_metadata

    try:
        from distributed import __version__ as distributed_version
    except ImportError:
        distributed_version = None

    from dask import __version__ as dask_version

    deps = [
        "numpy",
        "pandas",
        "cloudpickle",
        "fsspec",
        "bokeh",
        "pyarrow",
        "zarr",
    ]

    result: dict[str, str | None] = {
        # note: only major, minor, micro are extracted
        "Python": ".".join([str(i) for i in version_info[:3]]),
        "Platform": uname().system,
        "dask": dask_version,
        "distributed": distributed_version,
    }

    for modname in deps:
        try:
            result[modname] = importlib_metadata.version(modname)
        except importlib_metadata.PackageNotFoundError:
            result[modname] = None

    stdout.writelines(dumps(result, indent=2))

    return


def maybe_pluralize(count, noun, plural_form=None):
    """Pluralize a count-noun string pattern when necessary"""
    if count == 1:
        return f"{count} {noun}"
    else:
        return f"{count} {plural_form or noun + 's'}"


def is_namedtuple_instance(obj: Any) -> bool:
    """Returns True if obj is an instance of a namedtuple.

    Note: This function checks for the existence of the methods and
    attributes that make up the namedtuple API, so it will return True
    IFF obj's type implements that API.
    """
    return (
        isinstance(obj, tuple)
        and hasattr(obj, "_make")
        and hasattr(obj, "_asdict")
        and hasattr(obj, "_replace")
        and hasattr(obj, "_fields")
        and hasattr(obj, "_field_defaults")
    )


def get_default_shuffle_method() -> str:
    if d := config.get("dataframe.shuffle.method", None):
        return d
    try:
        from distributed import default_client

        default_client()
    except (ImportError, ValueError):
        return "disk"

    try:
        from distributed.shuffle import check_minimal_arrow_version

        check_minimal_arrow_version()
    except ModuleNotFoundError:
        return "tasks"
    return "p2p"


def get_meta_library(like):
    if hasattr(like, "_meta"):
        like = like._meta

    return import_module(typename(like).partition(".")[0])


class shorten_traceback:
    """Context manager that removes irrelevant stack elements from traceback.

    * omits frames from modules that match `admin.traceback.shorten`
    * always keeps the first and last frame.
    """

    __slots__ = ()

    def __enter__(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        if exc_val and exc_tb:
            exc_val.__traceback__ = self.shorten(exc_tb)

    @staticmethod
    def shorten(exc_tb: types.TracebackType) -> types.TracebackType:
        paths = config.get("admin.traceback.shorten")
        if not paths:
            return exc_tb

        exp = re.compile(".*(" + "|".join(paths) + ")")
        curr: types.TracebackType | None = exc_tb
        prev: types.TracebackType | None = None

        while curr:
            if prev is None:
                prev = curr  # first frame
            elif not curr.tb_next:
                # always keep last frame
                prev.tb_next = curr
                prev = prev.tb_next
            elif not exp.match(curr.tb_frame.f_code.co_filename):
                # keep if module is not listed in config
                prev.tb_next = curr
                prev = curr
            curr = curr.tb_next

        # Uncomment to remove the first frame, which is something you don't want to keep
        # if it matches the regexes. Requires Python >=3.11.
        # if exc_tb.tb_next and exp.match(exc_tb.tb_frame.f_code.co_filename):
        #     return exc_tb.tb_next

        return exc_tb


def unzip(ls, nout):
    """Unzip a list of lists into ``nout`` outputs."""
    out = list(zip(*ls))
    if not out:
        out = [()] * nout
    return out


class disable_gc(ContextDecorator):
    """Context manager to disable garbage collection."""

    def __init__(self, collect=False):
        self.collect = collect
        self._gc_enabled = gc.isenabled()

    def __enter__(self):
        gc.disable()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self._gc_enabled:
            gc.enable()
        return False