1from __future__ import annotations
3import dataclasses
4import inspect
5import uuid
6import warnings
7from collections import OrderedDict
8from collections.abc import Hashable, Iterable, Iterator, Mapping
9from concurrent.futures import Executor
10from contextlib import contextmanager, suppress
11from contextvars import ContextVar
12from functools import partial
13from numbers import Integral, Number
14from operator import getitem
15from typing import TYPE_CHECKING, Any, Literal, TypeVar
17from tlz import merge
19from dask import config, local
20from dask._compatibility import EMSCRIPTEN
21from dask._task_spec import DataNode, Dict, List, Task, TaskRef
22from dask.core import flatten
23from dask.core import get as simple_get
24from dask.system import CPU_COUNT
25from dask.typing import Key, SchedulerGetCallable
26from dask.utils import is_namedtuple_instance, key_split, shorten_traceback
28if TYPE_CHECKING:
29 from dask._expr import Expr
31_DistributedClient = None
32_get_distributed_client = None
33_DISTRIBUTED_AVAILABLE = None
36def _distributed_available() -> bool:
37 # Lazy import in get_scheduler can be expensive
38 global _DistributedClient, _get_distributed_client, _DISTRIBUTED_AVAILABLE
39 if _DISTRIBUTED_AVAILABLE is not None:
40 return _DISTRIBUTED_AVAILABLE # type: ignore[unreachable]
41 try:
42 from distributed import Client as _DistributedClient
43 from distributed.worker import get_client as _get_distributed_client
45 _DISTRIBUTED_AVAILABLE = True
46 except ImportError:
47 _DISTRIBUTED_AVAILABLE = False
48 return _DISTRIBUTED_AVAILABLE
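# Illustrative sketch (not part of the module): the import check is cached in the
# module-level globals above, so only the first call pays the import cost.
def _example_distributed_available_caching():
    first = _distributed_available()
    # Subsequent calls return the cached boolean without re-importing.
    assert _distributed_available() is first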
51__all__ = (
52 "DaskMethodsMixin",
53 "annotate",
54 "get_annotations",
55 "is_dask_collection",
56 "compute",
57 "persist",
58 "optimize",
59 "visualize",
60 "tokenize",
61 "normalize_token",
62 "get_collection_names",
63 "get_name_from_key",
64 "replace_name_in_key",
65 "clone_key",
66)
68# Backwards compat
69from dask.tokenize import TokenizationError, normalize_token, tokenize # noqa: F401
71_annotations: ContextVar[dict[str, Any] | None] = ContextVar(
72 "annotations", default=None
73)
76def get_annotations() -> dict[str, Any]:
77 """Get current annotations.
79 Returns
80 -------
81 Dict of all current annotations
83 See Also
84 --------
85 annotate
86 """
87 return _annotations.get() or {}
90@contextmanager
91def annotate(**annotations: Any) -> Iterator[None]:
92 """Context Manager for setting HighLevelGraph Layer annotations.
94 Annotations are metadata or soft constraints associated with
95 tasks that dask schedulers may choose to respect: They signal intent
96 without enforcing hard constraints. As such, they are
97 primarily designed for use with the distributed scheduler.
99 Almost any object can serve as an annotation, but small Python objects
100 are preferred, while large objects such as NumPy arrays are discouraged.
102 Callables supplied as an annotation should take a single *key* argument and
103 produce the appropriate annotation. Individual task keys in the annotated collection
104 are supplied to the callable.
106 Parameters
107 ----------
108 **annotations : key-value pairs
110 Examples
111 --------
113 All tasks within array A should have priority 100 and be retried 3 times
114 on failure.
116 >>> import dask
117 >>> import dask.array as da
118 >>> with dask.annotate(priority=100, retries=3):
119 ... A = da.ones((10000, 10000))
121 Prioritise tasks within Array A on flattened block ID.
123 >>> nblocks = (10, 10)
124 >>> with dask.annotate(priority=lambda k: k[1]*nblocks[1] + k[2]):
125 ... A = da.ones((1000, 1000), chunks=(100, 100))
127 Annotations may be nested.
129 >>> with dask.annotate(priority=1):
130 ... with dask.annotate(retries=3):
131 ... A = da.ones((1000, 1000))
132 ... B = A + 1
134 See Also
135 --------
136 get_annotations
137 """
139 # Sanity check annotations used in place of
140 # legacy distributed Client.{submit, persist, compute} keywords
141 if "workers" in annotations:
142 if isinstance(annotations["workers"], (list, set, tuple)):
143 annotations["workers"] = list(annotations["workers"])
144 elif isinstance(annotations["workers"], str):
145 annotations["workers"] = [annotations["workers"]]
146 elif callable(annotations["workers"]):
147 pass
148 else:
149 raise TypeError(
150 "'workers' annotation must be a sequence of str, a str or a callable, but got {}.".format(
151 annotations["workers"]
152 )
153 )
155 if (
156 "priority" in annotations
157 and not isinstance(annotations["priority"], Number)
158 and not callable(annotations["priority"])
159 ):
160 raise TypeError(
161 "'priority' annotation must be a Number or a callable, but got {}".format(
162 annotations["priority"]
163 )
164 )
166 if (
167 "retries" in annotations
168 and not isinstance(annotations["retries"], Number)
169 and not callable(annotations["retries"])
170 ):
171 raise TypeError(
172 "'retries' annotation must be a Number or a callable, but got {}".format(
173 annotations["retries"]
174 )
175 )
177 if (
178 "resources" in annotations
179 and not isinstance(annotations["resources"], dict)
180 and not callable(annotations["resources"])
181 ):
182 raise TypeError(
183 "'resources' annotation must be a dict, but got {}".format(
184 annotations["resources"]
185 )
186 )
188 if (
189 "allow_other_workers" in annotations
190 and not isinstance(annotations["allow_other_workers"], bool)
191 and not callable(annotations["allow_other_workers"])
192 ):
193 raise TypeError(
194 "'allow_other_workers' annotations must be a bool or a callable, but got {}".format(
195 annotations["allow_other_workers"]
196 )
197 )
198 ctx_annot = _annotations.get()
199 if ctx_annot is None:
200 ctx_annot = {}
201 token = _annotations.set(merge(ctx_annot, annotations))
202 try:
203 yield
204 finally:
205 _annotations.reset(token)
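# Illustrative sketch (not part of the module): nested ``annotate`` blocks merge
# their annotations, and ``get_annotations`` reflects whatever is currently in
# scope; the previous state is restored when each block exits.
def _example_annotation_scoping():
    assert get_annotations() == {}
    with annotate(priority=1):
        with annotate(retries=3):
            assert get_annotations() == {"priority": 1, "retries": 3}
        assert get_annotations() == {"priority": 1}
    assert get_annotations() == {}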
208def is_dask_collection(x) -> bool:
209 """Returns ``True`` if ``x`` is a dask collection.
211 Parameters
212 ----------
213 x : Any
214 Object to test.
216 Returns
217 -------
218 result : bool
219 ``True`` if `x` is a Dask collection.
221 Notes
222 -----
223 The DaskCollection typing.Protocol implementation defines a Dask
224 collection as a class that returns a Mapping from the
225 ``__dask_graph__`` method. This helper function existed before the
226 implementation of the protocol.
228 """
229 if (
230 isinstance(x, type)
231 or not hasattr(x, "__dask_graph__")
232 or not callable(x.__dask_graph__)
233 ):
234 return False
236 pkg_name = getattr(type(x), "__module__", "")
237 if pkg_name.split(".")[0] in ("dask_cudf",):
238 # Temporary hack to avoid graph materialization. Note that this won't work with
239 # dask_expr.array objects wrapped by xarray or pint. By the time dask_expr.array
240 # is published, we hope to be able to rewrite this method completely.
241 # Read: https://github.com/dask/dask/pull/10676
242 return True
243 elif pkg_name.startswith("dask.dataframe.dask_expr"):
244 return True
245 elif pkg_name.startswith("dask.array._array_expr"):
246 return True
248 # xarray, pint, and possibly other wrappers always define a __dask_graph__ method,
249 # but it may return None if they wrap around a non-dask object.
250 # In all known dask collections other than dask-expr,
251 # calling __dask_graph__ is cheap.
252 return x.__dask_graph__() is not None
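# Illustrative sketch (not part of the module): ``is_dask_collection`` keys off
# the ``__dask_graph__`` protocol, so instances of dask collections pass while
# plain Python objects and collection *classes* do not.
def _example_is_dask_collection():
    import dask.array as da  # assumes dask.array and its dependencies are installed
    assert is_dask_collection(da.ones(5, chunks=5))
    assert not is_dask_collection([1, 2, 3])
    assert not is_dask_collection(da.Array)  # a type, not an instance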
255class DaskMethodsMixin:
256 """A mixin adding standard dask collection methods"""
258 __slots__ = ("__weakref__",)
260 def visualize(self, filename="mydask", format=None, optimize_graph=False, **kwargs):
261 """Render the computation of this object's task graph using graphviz.
263 Requires ``graphviz`` to be installed.
265 Parameters
266 ----------
267 filename : str or None, optional
268 The name of the file to write to disk. If the provided `filename`
269 doesn't include an extension, '.png' will be used by default.
270 If `filename` is None, no file will be written, and we communicate
271 with dot using only pipes.
272 format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
273 Format in which to write output file. Default is 'png'.
274 optimize_graph : bool, optional
275 If True, the graph is optimized before rendering. Otherwise,
276 the graph is displayed as is. Default is False.
277 color: {None, 'order'}, optional
278 Options to color nodes. Provide the ``cmap=`` keyword to choose an
279 alternative colormap.
280 **kwargs
281 Additional keyword arguments to forward to ``to_graphviz``.
283 Examples
284 --------
285 >>> x.visualize(filename='dask.pdf') # doctest: +SKIP
286 >>> x.visualize(filename='dask.pdf', color='order') # doctest: +SKIP
288 Returns
289 -------
290 result : IPython.display.Image, IPython.display.SVG, or None
291 See dask.dot.dot_graph for more information.
293 See Also
294 --------
295 dask.visualize
296 dask.dot.dot_graph
298 Notes
299 -----
300 For more information on optimization see here:
302 https://docs.dask.org/en/latest/optimize.html
303 """
304 return visualize(
305 self,
306 filename=filename,
307 format=format,
308 optimize_graph=optimize_graph,
309 **kwargs,
310 )
312 def persist(self, **kwargs):
313 """Persist this dask collection into memory
315 This turns a lazy Dask collection into a Dask collection with the same
316 metadata, but now with the results fully computed or actively computing
317 in the background.
319 The action of this function differs significantly depending on the active
320 task scheduler. If the task scheduler supports asynchronous computing,
321 such as is the case of the dask.distributed scheduler, then persist
322 will return *immediately* and the return value's task graph will
323 contain Dask Future objects. However if the task scheduler only
324 supports blocking computation then the call to persist will *block*
325 and the return value's task graph will contain concrete Python results.
327 This function is particularly useful when using distributed systems,
328 because the results will be kept in distributed memory, rather than
329 returned to the local process as with compute.
331 Parameters
332 ----------
333 scheduler : string, optional
334 Which scheduler to use like "threads", "synchronous" or "processes".
335 If not provided, the default is to check the global settings first,
336 and then fall back to the collection defaults.
337 optimize_graph : bool, optional
338 If True [default], the graph is optimized before computation.
339 Otherwise the graph is run as is. This can be useful for debugging.
340 **kwargs
341 Extra keywords to forward to the scheduler function.
343 Returns
344 -------
345 New dask collections backed by in-memory data
347 See Also
348 --------
349 dask.persist
350 """
351 (result,) = persist(self, traverse=False, **kwargs)
352 return result
354 def compute(self, **kwargs):
355 """Compute this dask collection
357 This turns a lazy Dask collection into its in-memory equivalent.
358 For example a Dask array turns into a NumPy array and a Dask dataframe
359 turns into a Pandas dataframe. The entire dataset must fit into memory
360 before calling this operation.
362 Parameters
363 ----------
364 scheduler : string, optional
365 Which scheduler to use like "threads", "synchronous" or "processes".
366 If not provided, the default is to check the global settings first,
367 and then fall back to the collection defaults.
368 optimize_graph : bool, optional
369 If True [default], the graph is optimized before computation.
370 Otherwise the graph is run as is. This can be useful for debugging.
371 kwargs
372 Extra keywords to forward to the scheduler function.
374 See Also
375 --------
376 dask.compute
377 """
378 (result,) = compute(self, traverse=False, **kwargs)
379 return result
381 def __await__(self):
382 try:
383 from distributed import futures_of, wait
384 except ImportError as e:
385 raise ImportError(
386 "Using async/await with dask requires the `distributed` package"
387 ) from e
389 async def f():
390 if futures_of(self):
391 await wait(self)
392 return self
394 return f().__await__()
397def compute_as_if_collection(cls, dsk, keys, scheduler=None, get=None, **kwargs):
398 """Compute a graph as if it were of type cls.
400 Allows for applying the same optimizations and default scheduler."""
401 schedule = get_scheduler(scheduler=scheduler, cls=cls, get=get)
402 dsk2 = optimization_function(cls)(dsk, keys, **kwargs)
403 return schedule(dsk2, keys, **kwargs)
406def dont_optimize(dsk, keys, **kwargs):
407 return dsk
410def optimization_function(x):
411 return getattr(x, "__dask_optimize__", dont_optimize)
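# Illustrative sketch (not part of the module): ``optimization_function`` looks up
# ``__dask_optimize__`` and falls back to the no-op ``dont_optimize`` when a class
# does not define its own optimizer.
def _example_optimization_function():
    class PlainObject:
        pass
    assert optimization_function(PlainObject()) is dont_optimize
    dsk = {"x": 1}
    assert dont_optimize(dsk, ["x"]) is dsk  # no-op pass-through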
414def collections_to_expr(
415 collections: Iterable,
416 optimize_graph: bool = True,
417) -> Expr:
418 """
419 Convert many collections into a single dask expression.
421 Typically, users should not be required to interact with this function.
423 Parameters
424 ----------
425 collections : Iterable
426 An iterable of dask collections to be combined.
427 optimize_graph : bool, optional
428 If this is True and collections are encountered which are backed by
429 legacy HighLevelGraph objects, the returned Expression will run a low
430 level task optimization during materialization.
431 """
432 is_iterable = False
433 if isinstance(collections, (tuple, list, set)):
434 is_iterable = True
435 else:
436 collections = [collections]
437 if not collections:
438 raise ValueError("No collections provided")
439 from dask._expr import HLGExpr, _ExprSequence
441 graphs = []
442 for coll in collections:
443 from dask.delayed import Delayed
445 if isinstance(coll, Delayed) or not hasattr(coll, "expr"):
446 graphs.append(HLGExpr.from_collection(coll, optimize_graph=optimize_graph))
447 else:
448 graphs.append(coll.expr)
450 if len(graphs) > 1 or is_iterable:
451 return _ExprSequence(*graphs)
452 else:
453 return graphs[0]
456def unpack_collections(*args, traverse=True):
457 """Extract collections in preparation for compute/persist/etc...
459 Intended use is to find all collections in a set of (possibly nested)
460 python objects, do something to them (compute, etc...), then repackage them
461 in equivalent python objects.
463 Parameters
464 ----------
465 *args
466 Any number of objects. If it is a dask collection, it's extracted and
467 added to the list of collections returned. By default, python builtin
468 collections are also traversed to look for dask collections (for more
469 information see the ``traverse`` keyword).
470 traverse : bool, optional
471 If True (default), builtin python collections are traversed looking for
472 any dask collections they might contain.
474 Returns
475 -------
476 collections : list
477 A list of all dask collections contained in ``args``
478 repack : callable
479 A function to call on the transformed collections to repackage them as
480 they were in the original ``args``.
481 """
483 collections = []
484 repack_dsk = {}
486 collections_token = uuid.uuid4().hex
488 def _unpack(expr):
489 if is_dask_collection(expr):
490 tok = tokenize(expr)
491 if tok not in repack_dsk:
492 repack_dsk[tok] = Task(
493 tok, getitem, TaskRef(collections_token), len(collections)
494 )
495 collections.append(expr)
496 return TaskRef(tok)
498 tok = uuid.uuid4().hex
499 tsk: DataNode | Task # type: ignore
500 if not traverse:
501 tsk = DataNode(None, expr)
502 else:
503 # Treat iterators like lists
504 typ = list if isinstance(expr, Iterator) else type(expr)
505 if typ in (list, tuple, set):
506 tsk = Task(tok, typ, List(*[_unpack(i) for i in expr]))
507 elif typ in (dict, OrderedDict):
508 tsk = Task(
509 tok, typ, Dict({_unpack(k): _unpack(v) for k, v in expr.items()})
510 )
511 elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
512 tsk = Task(
513 tok,
514 typ,
515 *[_unpack(getattr(expr, f.name)) for f in dataclasses.fields(expr)],
516 )
517 elif is_namedtuple_instance(expr):
518 tsk = Task(tok, typ, *[_unpack(i) for i in expr])
519 else:
520 return expr
522 repack_dsk[tok] = tsk
523 return TaskRef(tok)
525 out = uuid.uuid4().hex
526 repack_dsk[out] = Task(out, tuple, List(*[_unpack(i) for i in args]))
528 def repack(results):
529 dsk = repack_dsk.copy()
530 dsk[collections_token] = DataNode(collections_token, results)
531 return simple_get(dsk, out)
533 # The original `collections` is kept alive by the closure.
534 # This causes the collections to be freed only by the garbage collector.
535 collections2 = list(collections)
536 collections.clear()
537 return collections2, repack
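# Illustrative sketch (not part of the module): ``unpack_collections`` extracts
# the dask collections from nested containers and hands back a ``repack``
# callable that splices computed results into the original structure.
def _example_unpack_collections():
    import dask.array as da  # assumes dask.array is installed
    a = da.ones(5, chunks=5)
    collections, repack = unpack_collections({"x": a, "y": 1})
    assert len(collections) == 1 and collections[0] is a
    # One result per extracted collection rebuilds the original layout.
    assert repack(["computed-a"]) == ({"x": "computed-a", "y": 1},)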
540def optimize(*args, traverse=True, **kwargs):
541 """Optimize several dask collections at once.
543 Returns equivalent dask collections that all share the same merged and
544 optimized underlying graph. This can be useful if converting multiple
545 collections to delayed objects, or to manually apply the optimizations at
546 strategic points.
548 Note that in most cases you shouldn't need to call this function directly.
550 Warning::
552 This function triggers a materialization of the collections and loses
553 any annotations attached to HLG layers.
555 Parameters
556 ----------
557 *args : objects
558 Any number of objects. If a dask object, its graph is optimized and
559 merged with all those of all other dask objects before returning an
560 equivalent dask collection. Non-dask arguments are passed through
561 unchanged.
562 traverse : bool, optional
563 By default dask traverses builtin python collections looking for dask
564 objects passed to ``optimize``. For large collections this can be
565 expensive. If none of the arguments contain any dask objects, set
566 ``traverse=False`` to avoid doing this traversal.
567 optimizations : list of callables, optional
568 Additional optimization passes to perform.
569 **kwargs
570 Extra keyword arguments to forward to the optimization passes.
572 Examples
573 --------
574 >>> import dask
575 >>> import dask.array as da
576 >>> a = da.arange(10, chunks=2).sum()
577 >>> b = da.arange(10, chunks=2).mean()
578 >>> a2, b2 = dask.optimize(a, b)
580 >>> a2.compute() == a.compute()
581 np.True_
582 >>> b2.compute() == b.compute()
583 np.True_
584 """
585 # TODO: This API is problematic. The approach to using postpersist forces us
586 # to materialize the graph. Most low level optimizations will materialize as
587 # well
588 collections, repack = unpack_collections(*args, traverse=traverse)
589 if not collections:
590 return args
592 dsk = collections_to_expr(collections)
594 postpersists = []
595 for a in collections:
596 r, s = a.__dask_postpersist__()
597 postpersists.append(r(dsk.__dask_graph__(), *s))
599 return repack(postpersists)
602def compute(
603 *args,
604 traverse=True,
605 optimize_graph=True,
606 scheduler=None,
607 get=None,
608 **kwargs,
609):
610 """Compute several dask collections at once.
612 Parameters
613 ----------
614 args : object
615 Any number of objects. If it is a dask object, it's computed and the
616 result is returned. By default, python builtin collections are also
617 traversed to look for dask objects (for more information see the
618 ``traverse`` keyword). Non-dask arguments are passed through unchanged.
619 traverse : bool, optional
620 By default dask traverses builtin python collections looking for dask
621 objects passed to ``compute``. For large collections this can be
622 expensive. If none of the arguments contain any dask objects, set
623 ``traverse=False`` to avoid doing this traversal.
624 scheduler : string, optional
625 Which scheduler to use like "threads", "synchronous" or "processes".
626 If not provided, the default is to check the global settings first,
627 and then fall back to the collection defaults.
628 optimize_graph : bool, optional
629 If True [default], the optimizations for each collection are applied
630 before computation. Otherwise the graph is run as is. This can be
631 useful for debugging.
632 get : ``None``
633 Should be left as ``None``. The ``get=`` keyword has been removed.
634 kwargs
635 Extra keywords to forward to the scheduler function.
637 Examples
638 --------
639 >>> import dask
640 >>> import dask.array as da
641 >>> a = da.arange(10, chunks=2).sum()
642 >>> b = da.arange(10, chunks=2).mean()
643 >>> dask.compute(a, b)
644 (np.int64(45), np.float64(4.5))
646 By default, dask objects inside python collections will also be computed:
648 >>> dask.compute({'a': a, 'b': b, 'c': 1})
649 ({'a': np.int64(45), 'b': np.float64(4.5), 'c': 1},)
650 """
652 collections, repack = unpack_collections(*args, traverse=traverse)
653 if not collections:
654 return args
656 schedule = get_scheduler(
657 scheduler=scheduler,
658 collections=collections,
659 get=get,
660 )
661 from dask._expr import FinalizeCompute
663 expr = collections_to_expr(collections, optimize_graph)
664 expr = FinalizeCompute(expr)
666 with shorten_traceback():
667 # The high level optimize will have to be called client side (for now)
668 # The optimize can internally trigger already a computation
669 # (e.g. parquet is reading some statistics). To move this to the
670 # scheduler we'd need some sort of scheduler-client to trigger a
671 # computation from inside the scheduler and continue with optimization
672 # once the results are in. An alternative could be to introduce a
673 # pre-optimize step for the Expressions that handles steps like these as
674 # a dedicated computation
676 # Another caveat is that optimize will only lock in the expression names
677 # after optimization. Names are determined using tokenize and tokenize
678 # is not cross-interpreter (let alone cross-host) stable such that we
679 # have to lock this in before sending stuff (otherwise we'd need to
680 # change the graph submission to a handshake which introduces all sorts
681 # of concurrency control issues)
683 expr = expr.optimize()
684 keys = list(flatten(expr.__dask_keys__()))
686 results = schedule(expr, keys, **kwargs)
688 return repack(results)
691def visualize(
692 *args,
693 filename="mydask",
694 traverse=True,
695 optimize_graph=False,
696 maxval=None,
697 engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None,
698 **kwargs,
699):
700 """
701 Visualize several dask graphs simultaneously.
703 Requires ``graphviz`` to be installed. All options that are not the dask
704 graph(s) should be passed as keyword arguments.
706 Parameters
707 ----------
708 args : object
709 Any number of objects. If it is a dask collection (for example, a
710 dask DataFrame, Array, Bag, or Delayed), its associated graph
711 will be included in the output of visualize. By default, python builtin
712 collections are also traversed to look for dask objects (for more
713 information see the ``traverse`` keyword). Arguments lacking an
714 associated graph will be ignored.
715 filename : str or None, optional
716 The name of the file to write to disk. If the provided `filename`
717 doesn't include an extension, '.png' will be used by default.
718 If `filename` is None, no file will be written, and we communicate
719 with dot using only pipes.
720 format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
721 Format in which to write output file. Default is 'png'.
722 traverse : bool, optional
723 By default, dask traverses builtin python collections looking for dask
724 objects passed to ``visualize``. For large collections this can be
725 expensive. If none of the arguments contain any dask objects, set
726 ``traverse=False`` to avoid doing this traversal.
727 optimize_graph : bool, optional
728 If True, the graph is optimized before rendering. Otherwise,
729 the graph is displayed as is. Default is False.
730 color : {None, 'order', 'ages', 'freed', 'memoryincreases', 'memorydecreases', 'memorypressure'}, optional
731 Options to color nodes (pass ``cmap=`` to choose the colormap):
733 - None, the default, no colors.
734 - 'order', colors the nodes' border based on the order they appear in the graph.
735 - 'ages', how long the data of a node is held.
736 - 'freed', the number of dependencies released after running a node.
737 - 'memoryincreases', how many more outputs are held after the lifetime of a node.
738 Large values may indicate nodes that should have run later.
739 - 'memorydecreases', how many fewer outputs are held after the lifetime of a node.
740 Large values may indicate nodes that should have run sooner.
741 - 'memorypressure', the number of data held when the node is run (circle), or
742 the data is released (rectangle).
743 maxval : {int, float}, optional
744 Maximum value for the colormap to normalize from 0 to 1.0. The default
745 ``None`` uses the maximum of the values.
746 collapse_outputs : bool, optional
747 Whether to collapse output boxes, which often have empty labels.
748 Default is False.
749 verbose : bool, optional
750 Whether to label output and input boxes even if the data aren't chunked.
751 Beware: these labels can get very long. Default is False.
752 engine : {"graphviz", "ipycytoscape", "cytoscape"}, optional.
753 The visualization engine to use. If not provided, this checks the dask config
754 value "visualization.engine". If that is not set, it tries to import ``graphviz``
755 and ``ipycytoscape``, using the first one to succeed.
756 **kwargs
757 Additional keyword arguments to forward to the visualization engine.
759 Examples
760 --------
761 >>> x.visualize(filename='dask.pdf') # doctest: +SKIP
762 >>> x.visualize(filename='dask.pdf', color='order') # doctest: +SKIP
764 Returns
765 -------
766 result : IPython.display.Image, IPython.display.SVG, or None
767 See dask.dot.dot_graph for more information.
769 See Also
770 --------
771 dask.dot.dot_graph
773 Notes
774 -----
775 For more information on optimization see here:
777 https://docs.dask.org/en/latest/optimize.html
778 """
779 args, _ = unpack_collections(*args, traverse=traverse)
781 dsk = collections_to_expr(args, optimize_graph=optimize_graph).__dask_graph__()
783 return visualize_dsk(
784 dsk=dsk,
785 filename=filename,
786 traverse=traverse,
787 optimize_graph=optimize_graph,
788 maxval=maxval,
789 engine=engine,
790 **kwargs,
791 )
794def visualize_dsk(
795 dsk,
796 filename="mydask",
797 traverse=True,
798 optimize_graph=False,
799 maxval=None,
800 o=None,
801 engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None,
802 limit=None,
803 **kwargs,
804):
805 color = kwargs.get("color")
806 from dask.order import diagnostics, order
808 if color in {
809 "order",
810 "order-age",
811 "order-freed",
812 "order-memoryincreases",
813 "order-memorydecreases",
814 "order-memorypressure",
815 "age",
816 "freed",
817 "memoryincreases",
818 "memorydecreases",
819 "memorypressure",
820 "critical",
821 "cpath",
822 }:
823 import matplotlib.pyplot as plt
825 if o is None:
826 o_stats = order(dsk, return_stats=True)
827 o = {k: v.priority for k, v in o_stats.items()}
828 elif isinstance(next(iter(o.values())), int):
829 o_stats = order(dsk, return_stats=True)
830 else:
831 o_stats = o
832 o = {k: v.priority for k, v in o.items()}
834 try:
835 cmap = kwargs.pop("cmap")
836 except KeyError:
837 cmap = plt.cm.plasma
838 if isinstance(cmap, str):
839 import matplotlib.pyplot as plt
841 cmap = getattr(plt.cm, cmap)
843 def label(x):
844 return str(values[x])
846 data_values = None
847 if color != "order":
848 info = diagnostics(dsk, o)[0]
849 if color.endswith("age"):
850 values = {key: val.age for key, val in info.items()}
851 elif color.endswith("freed"):
852 values = {key: val.num_dependencies_freed for key, val in info.items()}
853 elif color.endswith("memorypressure"):
854 values = {key: val.num_data_when_run for key, val in info.items()}
855 data_values = {
856 key: val.num_data_when_released for key, val in info.items()
857 }
858 elif color.endswith("memoryincreases"):
859 values = {
860 key: max(0, val.num_data_when_released - val.num_data_when_run)
861 for key, val in info.items()
862 }
863 elif color.endswith("memorydecreases"):
864 values = {
865 key: max(0, val.num_data_when_run - val.num_data_when_released)
866 for key, val in info.items()
867 }
868 elif color.split("-")[-1] in {"critical", "cpath"}:
869 values = {key: val.critical_path for key, val in o_stats.items()}
870 else:
871 raise NotImplementedError(color)
873 if color.startswith("order-"):
875 def label(x):
876 return str(o[x]) + "-" + str(values[x])
878 else:
879 values = o
880 if maxval is None:
881 maxval = max(1, *values.values())
882 colors = {
883 k: _colorize(tuple(map(int, cmap(v / maxval, bytes=True))))
884 for k, v in values.items()
885 }
886 if data_values is None:
887 data_colors = colors
888 else:
889 data_colors = {
890 k: _colorize(tuple(map(int, cmap(v / maxval, bytes=True))))
891 for k, v in data_values.items()
892 }
894 kwargs["function_attributes"] = {
895 k: {"color": v, "label": label(k)} for k, v in colors.items()
896 }
897 kwargs["data_attributes"] = {k: {"color": v} for k, v in data_colors.items()}
898 elif color:
899 raise NotImplementedError(f"Unknown value color={color}")
901 # Determine which engine to dispatch to, first checking the kwarg, then config,
902 # then whichever of graphviz or ipycytoscape are installed, in that order.
903 engine = engine or config.get("visualization.engine", None)
905 if not engine:
906 try:
907 import graphviz # noqa: F401
909 engine = "graphviz"
910 except ImportError:
911 try:
912 import ipycytoscape # noqa: F401
914 engine = "cytoscape"
915 except ImportError:
916 pass
917 if engine == "graphviz":
918 from dask.dot import dot_graph
920 return dot_graph(dsk, filename=filename, **kwargs)
921 elif engine in ("cytoscape", "ipycytoscape"):
922 from dask.dot import cytoscape_graph
924 return cytoscape_graph(dsk, filename=filename, **kwargs)
925 elif engine is None:
926 raise RuntimeError(
927 "No visualization engine detected, please install graphviz or ipycytoscape"
928 )
929 else:
930 raise ValueError(f"Visualization engine {engine} not recognized")
933def persist(*args, traverse=True, optimize_graph=True, scheduler=None, **kwargs):
934 """Persist multiple Dask collections into memory
936 This turns lazy Dask collections into Dask collections with the same
937 metadata, but now with their results fully computed or actively computing
938 in the background.
940 For example a lazy dask.array built up from many lazy calls will now be a
941 dask.array of the same shape, dtype, chunks, etc., but now with all of
942 those previously lazy tasks either computed in memory as many small :class:`numpy.array`
943 (in the single-machine case) or asynchronously running in the
944 background on a cluster (in the distributed case).
946 This function operates differently if a ``dask.distributed.Client`` exists
947 and is connected to a distributed scheduler. In this case this function
948 will return as soon as the task graph has been submitted to the cluster,
949 but before the computations have completed. Computations will continue
950 asynchronously in the background. When using this function with the single
951 machine scheduler it blocks until the computations have finished.
953 When using Dask on a single machine you should ensure that the dataset fits
954 entirely within memory.
956 Examples
957 --------
958 >>> df = dd.read_csv('/path/to/*.csv') # doctest: +SKIP
959 >>> df = df[df.name == 'Alice'] # doctest: +SKIP
960 >>> df['in-debt'] = df.balance < 0 # doctest: +SKIP
961 >>> df = df.persist() # triggers computation # doctest: +SKIP
963 >>> df.value().min() # future computations are now fast # doctest: +SKIP
964 -10
965 >>> df.value().max() # doctest: +SKIP
966 100
968 >>> from dask import persist # use persist function on multiple collections
969 >>> a, b = persist(a, b) # doctest: +SKIP
971 Parameters
972 ----------
973 *args: Dask collections
974 scheduler : string, optional
975 Which scheduler to use like "threads", "synchronous" or "processes".
976 If not provided, the default is to check the global settings first,
977 and then fall back to the collection defaults.
978 traverse : bool, optional
979 By default dask traverses builtin python collections looking for dask
980 objects passed to ``persist``. For large collections this can be
981 expensive. If none of the arguments contain any dask objects, set
982 ``traverse=False`` to avoid doing this traversal.
983 optimize_graph : bool, optional
984 If True [default], the graph is optimized before computation.
985 Otherwise the graph is run as is. This can be useful for debugging.
986 **kwargs
987 Extra keywords to forward to the scheduler function.
989 Returns
990 -------
991 New dask collections backed by in-memory data
992 """
993 collections, repack = unpack_collections(*args, traverse=traverse)
994 if not collections:
995 return args
997 schedule = get_scheduler(scheduler=scheduler, collections=collections)
999 if inspect.ismethod(schedule):
1000 try:
1001 from distributed.client import default_client
1002 except ImportError:
1003 pass
1004 else:
1005 try:
1006 client = default_client()
1007 except ValueError:
1008 pass
1009 else:
1010 if client.get == schedule:
1011 results = client.persist(
1012 collections, optimize_graph=optimize_graph, **kwargs
1013 )
1014 return repack(results)
1016 expr = collections_to_expr(collections, optimize_graph)
1017 expr = expr.optimize()
1018 keys, postpersists = [], []
1019 for a, akeys in zip(collections, expr.__dask_keys__(), strict=True):
1020 a_keys = list(flatten(akeys))
1021 rebuild, state = a.__dask_postpersist__()
1022 keys.extend(a_keys)
1023 postpersists.append((rebuild, a_keys, state))
1025 with shorten_traceback():
1026 results = schedule(expr, keys, **kwargs)
1028 d = dict(zip(keys, results))
1029 results2 = [r({k: d[k] for k in ks}, *s) for r, ks, s in postpersists]
1030 return repack(results2)
1033def _colorize(t):
1034 """Convert (r, g, b) triple to "#RRGGBB" string
1036 For use with ``visualize(color=...)``
1038 Examples
1039 --------
1040 >>> _colorize((255, 255, 255))
1041 '#FFFFFF'
1042 >>> _colorize((0, 32, 128))
1043 '#002080'
1044 """
1045 t = t[:3]
1046 i = sum(v * 256 ** (len(t) - i - 1) for i, v in enumerate(t))
1047 h = hex(int(i))[2:].upper()
1048 h = "0" * (6 - len(h)) + h
1049 return "#" + h
1052named_schedulers: dict[str, SchedulerGetCallable] = {
1053 "sync": local.get_sync,
1054 "synchronous": local.get_sync,
1055 "single-threaded": local.get_sync,
1056}
1058if not EMSCRIPTEN:
1059 from dask import threaded
1061 named_schedulers.update(
1062 {
1063 "threads": threaded.get,
1064 "threading": threaded.get,
1065 }
1066 )
1068 from dask import multiprocessing as dask_multiprocessing
1070 named_schedulers.update(
1071 {
1072 "processes": dask_multiprocessing.get,
1073 "multiprocessing": dask_multiprocessing.get,
1074 }
1075 )
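# Illustrative sketch (not part of the module): each entry in ``named_schedulers``
# is a ``get`` callable taking a graph and keys, so the single-machine schedulers
# can be driven directly with a plain dict graph.
def _example_named_scheduler():
    from operator import add
    dsk = {"x": 1, "y": (add, "x", 2)}
    assert named_schedulers["sync"](dsk, "y") == 3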
1078get_err_msg = """
1079The get= keyword has been removed.
1081Please use the scheduler= keyword instead with the name of
1082the desired scheduler like 'threads' or 'processes'
1084 x.compute(scheduler='single-threaded')
1085 x.compute(scheduler='threads')
1086 x.compute(scheduler='processes')
1088or with a function that takes the graph and keys
1090 x.compute(scheduler=my_scheduler_function)
1092or with a Dask client
1094 x.compute(scheduler=client)
1095""".strip()
1098def _ensure_not_async(client):
1099 if client.asynchronous:
1100 if fallback := config.get("admin.async-client-fallback", None):
1101 warnings.warn(
1102 "Distributed Client detected but Client instance is "
1103 f"asynchronous. Falling back to `{fallback}` scheduler. "
1104 "To use an asynchronous Client, please use "
1105 "``Client.compute`` and ``Client.gather`` "
1106 "instead of the top level ``dask.compute``",
1107 UserWarning,
1108 )
1109 return get_scheduler(scheduler=fallback)
1110 else:
1111 raise RuntimeError(
1112 "Attempting to use an asynchronous "
1113 "Client in a synchronous context of `dask.compute`"
1114 )
1115 return client.get
1118def get_scheduler(get=None, scheduler=None, collections=None, cls=None):
1119 """Get scheduler function
1121 There are various ways to specify the scheduler to use:
1123 1. Passing in scheduler= parameters
1124 2. Passing these into global configuration
1125 3. Using a dask.distributed default Client
1126 4. Using defaults of a dask collection
1128 This function centralizes the logic to determine the right scheduler to use
1129 from those many options
1130 """
1131 if get:
1132 raise TypeError(get_err_msg)
1134 if scheduler is not None:
1135 if callable(scheduler):
1136 return scheduler
1137 elif "Client" in type(scheduler).__name__ and hasattr(scheduler, "get"):
1138 return _ensure_not_async(scheduler)
1139 elif isinstance(scheduler, str):
1140 scheduler = scheduler.lower()
1142 client_available = False
1143 if _distributed_available():
1144 assert _DistributedClient is not None
1145 with suppress(ValueError):
1146 _DistributedClient.current(allow_global=True)
1147 client_available = True
1148 if scheduler in named_schedulers:
1149 return named_schedulers[scheduler]
1150 elif scheduler in ("dask.distributed", "distributed"):
1151 if not client_available:
1152 raise RuntimeError(
1153 f"Requested {scheduler} scheduler but no Client active."
1154 )
1155 assert _get_distributed_client is not None
1156 client = _get_distributed_client()
1157 return _ensure_not_async(client)
1158 else:
1159 raise ValueError(
1160 "Expected one of [distributed, {}]".format(
1161 ", ".join(sorted(named_schedulers))
1162 )
1163 )
1164 elif isinstance(scheduler, Executor):
1165 # Get `num_workers` from `Executor`'s `_max_workers` attribute.
1166 # If undefined, fallback to `config` or worst case CPU_COUNT.
1167 num_workers = getattr(scheduler, "_max_workers", None)
1168 if num_workers is None:
1169 num_workers = config.get("num_workers", CPU_COUNT)
1170 assert isinstance(num_workers, Integral) and num_workers > 0
1171 return partial(local.get_async, scheduler.submit, num_workers)
1172 else:
1173 raise ValueError(f"Unexpected scheduler: {scheduler!r}")
1174 # else: # try to connect to remote scheduler with this name
1175 # return get_client(scheduler).get
1177 if config.get("scheduler", None):
1178 return get_scheduler(scheduler=config.get("scheduler", None))
1180 if config.get("get", None):
1181 raise ValueError(get_err_msg)
1183 try:
1184 from distributed import get_client
1186 return _ensure_not_async(get_client())
1187 except (ImportError, ValueError):
1188 pass
1190 if cls is not None:
1191 return cls.__dask_scheduler__
1193 if collections:
1194 collections = [c for c in collections if c is not None]
1195 if collections:
1196 get = collections[0].__dask_scheduler__
1197 if not all(c.__dask_scheduler__ == get for c in collections):
1198 raise ValueError(
1199 "Compute called on multiple collections with "
1200 "differing default schedulers. Please specify a "
1201 "scheduler=` parameter explicitly in compute or "
1202 "globally with `dask.config.set`."
1203 )
1204 return get
1206 return None
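# Illustrative sketch (not part of the module): common resolution paths through
# ``get_scheduler``. A known scheduler name maps to the matching ``get``
# callable, and an arbitrary callable is returned unchanged.
def _example_get_scheduler():
    assert get_scheduler(scheduler="sync") is named_schedulers["sync"]
    def my_get(dsk, keys, **kwargs):
        return simple_get(dsk, keys)
    assert get_scheduler(scheduler=my_get) is my_get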
1209def wait(x, timeout=None, return_when="ALL_COMPLETED"):
1210 """Wait until computation has finished
1212 This is a compatibility alias for ``dask.distributed.wait``.
1213 If it is applied to Dask collections without Dask Futures, or if Dask
1214 distributed is not installed, it is a no-op.
1215 """
1216 try:
1217 from distributed import wait
1219 return wait(x, timeout=timeout, return_when=return_when)
1220 except (ImportError, ValueError):
1221 return x
1224def get_collection_names(collection) -> set[str]:
1225 """Infer the collection names from the dask keys, under the assumption that all keys
1226 are either tuples with matching first element, and that element is a string, or
1227 there is exactly one key and it is a string.
1229 Examples
1230 --------
1231 >>> a.__dask_keys__() # doctest: +SKIP
1232 ["foo", "bar"]
1233 >>> get_collection_names(a) # doctest: +SKIP
1234 {"foo", "bar"}
1235 >>> b.__dask_keys__() # doctest: +SKIP
1236 [[("foo-123", 0, 0), ("foo-123", 0, 1)], [("foo-123", 1, 0), ("foo-123", 1, 1)]]
1237 >>> get_collection_names(b) # doctest: +SKIP
1238 {"foo-123"}
1239 """
1240 if not is_dask_collection(collection):
1241 raise TypeError(f"Expected Dask collection; got {type(collection)}")
1242 return {get_name_from_key(k) for k in flatten(collection.__dask_keys__())}
1245def get_name_from_key(key: Key) -> str:
1246 """Given a dask collection's key, extract the collection name.
1248 Parameters
1249 ----------
1250 key: string or tuple
1251 Dask collection's key, which must be either a single string or a tuple whose
1252 first element is a string (commonly referred to as a collection's 'name'),
1254 Examples
1255 --------
1256 >>> get_name_from_key("foo")
1257 'foo'
1258 >>> get_name_from_key(("foo-123", 1, 2))
1259 'foo-123'
1260 """
1261 if isinstance(key, tuple) and key and isinstance(key[0], str):
1262 return key[0]
1263 if isinstance(key, str):
1264 return key
1265 raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")
1268KeyOrStrT = TypeVar("KeyOrStrT", Key, str)
1271def replace_name_in_key(key: KeyOrStrT, rename: Mapping[str, str]) -> KeyOrStrT:
1272 """Given a dask collection's key, replace the collection name with a new one.
1274 Parameters
1275 ----------
1276 key: string or tuple
1277 Dask collection's key, which must be either a single string or a tuple whose
1278 first element is a string (commonly referred to as a collection's 'name'),
1279 rename:
1280 Mapping of zero or more old names to new names. Extraneous names will be ignored.
1281 Names not found in this mapping won't be replaced.
1283 Examples
1284 --------
1285 >>> replace_name_in_key("foo", {})
1286 'foo'
1287 >>> replace_name_in_key("foo", {"foo": "bar"})
1288 'bar'
1289 >>> replace_name_in_key(("foo-123", 1, 2), {"foo-123": "bar-456"})
1290 ('bar-456', 1, 2)
1291 """
1292 if isinstance(key, tuple) and key and isinstance(key[0], str):
1293 return (rename.get(key[0], key[0]),) + key[1:]
1294 if isinstance(key, str):
1295 return rename.get(key, key)
1296 raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")
1299def clone_key(key: KeyOrStrT, seed: Hashable) -> KeyOrStrT:
1300 """Clone a key from a Dask collection, producing a new key with the same prefix and
1301 indices and a token which is a deterministic function of the previous key and seed.
1303 Examples
1304 --------
1305 >>> clone_key("x", 123) # doctest: +SKIP
1306 'x-c4fb64ccca807af85082413d7ef01721'
1307 >>> clone_key("inc-cbb1eca3bafafbb3e8b2419c4eebb387", 123) # doctest: +SKIP
1308 'inc-bc629c23014a4472e18b575fdaf29ee7'
1309 >>> clone_key(("sum-cbb1eca3bafafbb3e8b2419c4eebb387", 4, 3), 123) # doctest: +SKIP
1310 ('sum-c053f3774e09bd0f7de6044dbc40e71d', 4, 3)
1311 """
1312 if isinstance(key, tuple) and key and isinstance(key[0], str):
1313 return (clone_key(key[0], seed),) + key[1:]
1314 if isinstance(key, str):
1315 prefix = key_split(key)
1316 return prefix + "-" + tokenize(key, seed)
1317 raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")