from __future__ import annotations

import dataclasses
import inspect
import uuid
import warnings
from collections import OrderedDict
from collections.abc import Hashable, Iterable, Iterator, Mapping
from concurrent.futures import Executor
from contextlib import contextmanager, suppress
from contextvars import ContextVar
from functools import partial
from numbers import Integral, Number
from operator import getitem
from typing import TYPE_CHECKING, Any, Literal, TypeVar

from tlz import merge

from dask import config, local
from dask._compatibility import EMSCRIPTEN
from dask._task_spec import DataNode, Dict, List, Task, TaskRef
from dask.core import flatten
from dask.core import get as simple_get
from dask.system import CPU_COUNT
from dask.typing import Key, SchedulerGetCallable
from dask.utils import is_namedtuple_instance, key_split, shorten_traceback

if TYPE_CHECKING:
    from dask._expr import Expr

_DistributedClient = None
_get_distributed_client = None
_DISTRIBUTED_AVAILABLE = None


def _distributed_available() -> bool:
    # Lazy import in get_scheduler can be expensive
    global _DistributedClient, _get_distributed_client, _DISTRIBUTED_AVAILABLE
    if _DISTRIBUTED_AVAILABLE is not None:
        return _DISTRIBUTED_AVAILABLE  # type: ignore[unreachable]
    try:
        from distributed import Client as _DistributedClient
        from distributed.worker import get_client as _get_distributed_client

        _DISTRIBUTED_AVAILABLE = True
    except ImportError:
        _DISTRIBUTED_AVAILABLE = False
    return _DISTRIBUTED_AVAILABLE


__all__ = (
    "DaskMethodsMixin",
    "annotate",
    "get_annotations",
    "is_dask_collection",
    "compute",
    "persist",
    "optimize",
    "visualize",
    "tokenize",
    "normalize_token",
    "get_collection_names",
    "get_name_from_key",
    "replace_name_in_key",
    "clone_key",
)

# Backwards compat
from dask.tokenize import TokenizationError, normalize_token, tokenize  # noqa: F401

_annotations: ContextVar[dict[str, Any] | None] = ContextVar(
    "annotations", default=None
)


def get_annotations() -> dict[str, Any]:
    """Get current annotations.

    Returns
    -------
    Dict of all current annotations

    See Also
    --------
    annotate
    """
    return _annotations.get() or {}
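

# A minimal usage sketch (editor's illustration, not part of the library
# source): inside an ``annotate`` block, ``get_annotations`` reflects the
# merged, context-local annotations; outside any block it returns ``{}``.
#
#     with annotate(retries=3):
#         assert get_annotations() == {"retries": 3}
#         with annotate(priority=10):
#             assert get_annotations() == {"retries": 3, "priority": 10}
#     assert get_annotations() == {}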


@contextmanager
def annotate(**annotations: Any) -> Iterator[None]:
    """Context Manager for setting HighLevelGraph Layer annotations.

    Annotations are metadata or soft constraints associated with
    tasks that dask schedulers may choose to respect: they signal intent
    without enforcing hard constraints. As such, they are
    primarily designed for use with the distributed scheduler.

    Almost any object can serve as an annotation, but small Python objects
    are preferred, while large objects such as NumPy arrays are discouraged.

    Callables supplied as an annotation should take a single *key* argument and
    produce the appropriate annotation. Individual task keys in the annotated
    collection are supplied to the callable.

    Parameters
    ----------
    **annotations : key-value pairs

    Examples
    --------

    All tasks within array A should have priority 100 and be retried 3 times
    on failure.

    >>> import dask
    >>> import dask.array as da
    >>> with dask.annotate(priority=100, retries=3):
    ...     A = da.ones((10000, 10000))

    Prioritise tasks within Array A on flattened block ID.

    >>> nblocks = (10, 10)
    >>> with dask.annotate(priority=lambda k: k[1]*nblocks[1] + k[2]):
    ...     A = da.ones((1000, 1000), chunks=(100, 100))

    Annotations may be nested.

    >>> with dask.annotate(priority=1):
    ...     with dask.annotate(retries=3):
    ...         A = da.ones((1000, 1000))
    ...     B = A + 1

    See Also
    --------
    get_annotations
    """
    # Sanity check annotations used in place of
    # legacy distributed Client.{submit, persist, compute} keywords
    if "workers" in annotations:
        if isinstance(annotations["workers"], (list, set, tuple)):
            annotations["workers"] = list(annotations["workers"])
        elif isinstance(annotations["workers"], str):
            annotations["workers"] = [annotations["workers"]]
        elif callable(annotations["workers"]):
            pass
        else:
            raise TypeError(
                "'workers' annotation must be a sequence of str, a str or a callable, but got %s."
                % annotations["workers"]
            )

    if (
        "priority" in annotations
        and not isinstance(annotations["priority"], Number)
        and not callable(annotations["priority"])
    ):
        raise TypeError(
            "'priority' annotation must be a Number or a callable, but got %s"
            % annotations["priority"]
        )

    if (
        "retries" in annotations
        and not isinstance(annotations["retries"], Number)
        and not callable(annotations["retries"])
    ):
        raise TypeError(
            "'retries' annotation must be a Number or a callable, but got %s"
            % annotations["retries"]
        )

    if (
        "resources" in annotations
        and not isinstance(annotations["resources"], dict)
        and not callable(annotations["resources"])
    ):
        raise TypeError(
            "'resources' annotation must be a dict or a callable, but got %s"
            % annotations["resources"]
        )

    if (
        "allow_other_workers" in annotations
        and not isinstance(annotations["allow_other_workers"], bool)
        and not callable(annotations["allow_other_workers"])
    ):
        raise TypeError(
            "'allow_other_workers' annotation must be a bool or a callable, but got %s"
            % annotations["allow_other_workers"]
        )

    ctx_annot = _annotations.get()
    if ctx_annot is None:
        ctx_annot = {}
    token = _annotations.set(merge(ctx_annot, annotations))
    try:
        yield
    finally:
        _annotations.reset(token)


def is_dask_collection(x) -> bool:
    """Returns ``True`` if ``x`` is a dask collection.

    Parameters
    ----------
    x : Any
        Object to test.

    Returns
    -------
    result : bool
        ``True`` if `x` is a Dask collection.

    Notes
    -----
    The DaskCollection typing.Protocol implementation defines a Dask
    collection as a class that returns a Mapping from the
    ``__dask_graph__`` method. This helper function existed before the
    implementation of the protocol.

    """
    if (
        isinstance(x, type)
        or not hasattr(x, "__dask_graph__")
        or not callable(x.__dask_graph__)
    ):
        return False

    pkg_name = getattr(type(x), "__module__", "")
    if pkg_name.split(".")[0] in ("dask_cudf",):
        # Temporary hack to avoid graph materialization. Note that this won't work with
        # dask_expr.array objects wrapped by xarray or pint. By the time dask_expr.array
        # is published, we hope to be able to rewrite this method completely.
        # Read: https://github.com/dask/dask/pull/10676
        return True
    elif pkg_name.startswith("dask.dataframe.dask_expr"):
        return True
    elif pkg_name.startswith("dask.array._array_expr"):
        return True

    # xarray, pint, and possibly other wrappers always define a __dask_graph__ method,
    # but it may return None if they wrap around a non-dask object.
    # In all known dask collections other than dask-expr,
    # calling __dask_graph__ is cheap.
    return x.__dask_graph__() is not None
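

# A minimal sketch of the duck-typing contract above (editor's illustration):
# instances with a ``__dask_graph__`` method qualify, classes and plain
# Python objects do not.
#
#     import dask.array as da
#     assert is_dask_collection(da.ones(3))       # instance -> True
#     assert not is_dask_collection(da.Array)     # the class itself -> False
#     assert not is_dask_collection([1, 2, 3])    # no __dask_graph__ -> False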


class DaskMethodsMixin:
    """A mixin adding standard dask collection methods"""

    __slots__ = ("__weakref__",)

    def visualize(self, filename="mydask", format=None, optimize_graph=False, **kwargs):
        """Render the computation of this object's task graph using graphviz.

        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and we communicate
            with dot using only pipes.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'png'.
        optimize_graph : bool, optional
            If True, the graph is optimized before rendering. Otherwise,
            the graph is displayed as is. Default is False.
        color : {None, 'order'}, optional
            Options to color nodes. Provide ``cmap=`` keyword for additional
            colormap
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.

        Examples
        --------
        >>> x.visualize(filename='dask.pdf')  # doctest: +SKIP
        >>> x.visualize(filename='dask.pdf', color='order')  # doctest: +SKIP

        Returns
        -------
        result : IPython.display.Image, IPython.display.SVG, or None
            See dask.dot.dot_graph for more information.

        See Also
        --------
        dask.visualize
        dask.dot.dot_graph

        Notes
        -----
        For more information on optimization see here:

        https://docs.dask.org/en/latest/optimize.html
        """
        return visualize(
            self,
            filename=filename,
            format=format,
            optimize_graph=optimize_graph,
            **kwargs,
        )

    def persist(self, **kwargs):
        """Persist this dask collection into memory

        This turns a lazy Dask collection into a Dask collection with the same
        metadata, but now with the results fully computed or actively computing
        in the background.

        The action of this function differs significantly depending on the
        active task scheduler. If the task scheduler supports asynchronous
        computing, such as is the case of the dask.distributed scheduler, then
        persist will return *immediately* and the return value's task graph
        will contain Dask Future objects. However, if the task scheduler only
        supports blocking computation then the call to persist will *block*
        and the return value's task graph will contain concrete Python results.

        This function is particularly useful when using distributed systems,
        because the results will be kept in distributed memory, rather than
        returned to the local process as with compute.

        Parameters
        ----------
        scheduler : string, optional
            Which scheduler to use like "threads", "synchronous" or "processes".
            If not provided, the default is to check the global settings first,
            and then fall back to the collection defaults.
        optimize_graph : bool, optional
            If True [default], the graph is optimized before computation.
            Otherwise the graph is run as is. This can be useful for debugging.
        **kwargs
            Extra keywords to forward to the scheduler function.

        Returns
        -------
        New dask collections backed by in-memory data

        See Also
        --------
        dask.persist
        """
        (result,) = persist(self, traverse=False, **kwargs)
        return result

    def compute(self, **kwargs):
        """Compute this dask collection

        This turns a lazy Dask collection into its in-memory equivalent.
        For example a Dask array turns into a NumPy array and a Dask dataframe
        turns into a Pandas dataframe. The entire dataset must fit into memory
        before calling this operation.

        Parameters
        ----------
        scheduler : string, optional
            Which scheduler to use like "threads", "synchronous" or "processes".
            If not provided, the default is to check the global settings first,
            and then fall back to the collection defaults.
        optimize_graph : bool, optional
            If True [default], the graph is optimized before computation.
            Otherwise the graph is run as is. This can be useful for debugging.
        kwargs
            Extra keywords to forward to the scheduler function.

        See Also
        --------
        dask.compute
        """
        (result,) = compute(self, traverse=False, **kwargs)
        return result

    def __await__(self):
        try:
            from distributed import futures_of, wait
        except ImportError as e:
            raise ImportError(
                "Using async/await with dask requires the `distributed` package"
            ) from e

        async def f():
            if futures_of(self):
                await wait(self)
            return self

        return f().__await__()


def compute_as_if_collection(cls, dsk, keys, scheduler=None, get=None, **kwargs):
    """Compute a graph as if it were of type cls.

    Allows for applying the same optimizations and default scheduler."""
    schedule = get_scheduler(scheduler=scheduler, cls=cls, get=get)
    dsk2 = optimization_function(cls)(dsk, keys, **kwargs)
    return schedule(dsk2, keys, **kwargs)


def dont_optimize(dsk, keys, **kwargs):
    return dsk


def optimization_function(x):
    return getattr(x, "__dask_optimize__", dont_optimize)
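

# A minimal sketch of the dispatch above (editor's illustration): collection
# classes may expose an optional ``__dask_optimize__`` hook, and anything
# without one falls back to the identity pass ``dont_optimize``.
#
#     import dask.array as da
#     opt = optimization_function(da.Array)   # the class hook, if defined
#     noop = optimization_function(object)    # falls back to dont_optimize
#     assert noop({"x": 1}, ["x"]) == {"x": 1}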


def collections_to_expr(
    collections: Iterable,
    optimize_graph: bool = True,
) -> Expr:
    """
    Convert many collections into a single dask expression.

    Typically, users should not be required to interact with this function.

    Parameters
    ----------
    collections : Iterable
        An iterable of dask collections to be combined.
    optimize_graph : bool, optional
        If this is True and collections are encountered which are backed by
        legacy HighLevelGraph objects, the returned Expression will run a low
        level task optimization during materialization.
    """
    is_iterable = False
    if isinstance(collections, (tuple, list, set)):
        is_iterable = True
    else:
        collections = [collections]
    if not collections:
        raise ValueError("No collections provided")
    from dask._expr import HLGExpr, _ExprSequence

    graphs = []
    for coll in collections:
        from dask.delayed import Delayed

        if isinstance(coll, Delayed) or not hasattr(coll, "expr"):
            graphs.append(HLGExpr.from_collection(coll, optimize_graph=optimize_graph))
        else:
            graphs.append(coll.expr)

    if len(graphs) > 1 or is_iterable:
        return _ExprSequence(*graphs)
    else:
        return graphs[0]
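

# A minimal sketch (editor's illustration): a bare collection yields its own
# expression, while any list/tuple/set input is wrapped in an _ExprSequence,
# even with a single element.
#
#     import dask.array as da
#     x = da.ones(3)
#     expr = collections_to_expr(x)      # x's own expression
#     seq = collections_to_expr([x])     # _ExprSequence around it
#     graph = seq.__dask_graph__()       # materializes a task graph Mapping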


def unpack_collections(*args, traverse=True):
    """Extract collections in preparation for compute/persist/etc...

    Intended use is to find all collections in a set of (possibly nested)
    python objects, do something to them (compute, etc...), then repackage them
    in equivalent python objects.

    Parameters
    ----------
    *args
        Any number of objects. If it is a dask collection, it's extracted and
        added to the list of collections returned. By default, python builtin
        collections are also traversed to look for dask collections (for more
        information see the ``traverse`` keyword).
    traverse : bool, optional
        If True (default), builtin python collections are traversed looking for
        any dask collections they might contain.

    Returns
    -------
    collections : list
        A list of all dask collections contained in ``args``
    repack : callable
        A function to call on the transformed collections to repackage them as
        they were in the original ``args``.
    """
    collections = []
    repack_dsk = {}

    collections_token = uuid.uuid4().hex

    def _unpack(expr):
        if is_dask_collection(expr):
            tok = tokenize(expr)
            if tok not in repack_dsk:
                repack_dsk[tok] = Task(
                    tok, getitem, TaskRef(collections_token), len(collections)
                )
                collections.append(expr)
            return TaskRef(tok)

        tok = uuid.uuid4().hex
        tsk: DataNode | Task  # type: ignore
        if not traverse:
            tsk = DataNode(None, expr)
        else:
            # Treat iterators like lists
            typ = list if isinstance(expr, Iterator) else type(expr)
            if typ in (list, tuple, set):
                tsk = Task(tok, typ, List(*[_unpack(i) for i in expr]))
            elif typ in (dict, OrderedDict):
                tsk = Task(
                    tok, typ, Dict({_unpack(k): _unpack(v) for k, v in expr.items()})
                )
            elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
                tsk = Task(
                    tok,
                    typ,
                    *[_unpack(getattr(expr, f.name)) for f in dataclasses.fields(expr)],
                )
            elif is_namedtuple_instance(expr):
                tsk = Task(tok, typ, *[_unpack(i) for i in expr])
            else:
                return expr

        repack_dsk[tok] = tsk
        return TaskRef(tok)

    out = uuid.uuid4().hex
    repack_dsk[out] = Task(out, tuple, List(*[_unpack(i) for i in args]))

    def repack(results):
        dsk = repack_dsk.copy()
        dsk[collections_token] = DataNode(collections_token, results)
        return simple_get(dsk, out)

    # The original `collections` is kept alive by the closure
    # This causes the collection to be only freed by the garbage collector
    collections2 = list(collections)
    collections.clear()
    return collections2, repack
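

# A minimal round-trip sketch (editor's illustration): extract collections
# from a nested structure, compute them by hand, and repack the results in
# place. Duplicate collections are deduplicated via tokenize, so ``d`` below
# appears once in the extracted list.
#
#     import dask
#     d = dask.delayed(1)
#     collections, repack = unpack_collections({"a": d, "b": [d, 2]})
#     results = [c.compute() for c in collections]  # stand-in for a scheduler
#     repack(results)                               # ({'a': 1, 'b': [1, 2]},)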


def optimize(*args, traverse=True, **kwargs):
    """Optimize several dask collections at once.

    Returns equivalent dask collections that all share the same merged and
    optimized underlying graph. This can be useful if converting multiple
    collections to delayed objects, or to manually apply the optimizations at
    strategic points.

    Note that in most cases you shouldn't need to call this function directly.

    .. warning::

       This function triggers a materialization of the collections and loses
       any annotations attached to HLG layers.

    Parameters
    ----------
    *args : objects
        Any number of objects. If a dask object, its graph is optimized and
        merged with all those of all other dask objects before returning an
        equivalent dask collection. Non-dask arguments are passed through
        unchanged.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``optimize``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimizations : list of callables, optional
        Additional optimization passes to perform.
    **kwargs
        Extra keyword arguments to forward to the optimization passes.

    Examples
    --------
    >>> import dask
    >>> import dask.array as da
    >>> a = da.arange(10, chunks=2).sum()
    >>> b = da.arange(10, chunks=2).mean()
    >>> a2, b2 = dask.optimize(a, b)

    >>> a2.compute() == a.compute()
    np.True_
    >>> b2.compute() == b.compute()
    np.True_
    """
    # TODO: This API is problematic. The approach to using postpersist forces us
    # to materialize the graph. Most low level optimizations will materialize as
    # well
    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    dsk = collections_to_expr(collections)

    postpersists = []
    for a in collections:
        r, s = a.__dask_postpersist__()
        postpersists.append(r(dsk.__dask_graph__(), *s))

    return repack(postpersists)


def compute(
    *args,
    traverse=True,
    optimize_graph=True,
    scheduler=None,
    get=None,
    **kwargs,
):
    """Compute several dask collections at once.

    Parameters
    ----------
    args : object
        Any number of objects. If it is a dask object, it's computed and the
        result is returned. By default, python builtin collections are also
        traversed to look for dask objects (for more information see the
        ``traverse`` keyword). Non-dask arguments are passed through unchanged.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``compute``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    scheduler : string, optional
        Which scheduler to use like "threads", "synchronous" or "processes".
        If not provided, the default is to check the global settings first,
        and then fall back to the collection defaults.
    optimize_graph : bool, optional
        If True [default], the optimizations for each collection are applied
        before computation. Otherwise the graph is run as is. This can be
        useful for debugging.
    get : ``None``
        Should be left as ``None``. The ``get=`` keyword has been removed.
    kwargs
        Extra keywords to forward to the scheduler function.

    Examples
    --------
    >>> import dask
    >>> import dask.array as da
    >>> a = da.arange(10, chunks=2).sum()
    >>> b = da.arange(10, chunks=2).mean()
    >>> dask.compute(a, b)
    (np.int64(45), np.float64(4.5))

    By default, dask objects inside python collections will also be computed:

    >>> dask.compute({'a': a, 'b': b, 'c': 1})
    ({'a': np.int64(45), 'b': np.float64(4.5), 'c': 1},)
    """
    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    schedule = get_scheduler(
        scheduler=scheduler,
        collections=collections,
        get=get,
    )
    from dask._expr import FinalizeCompute

    expr = collections_to_expr(collections, optimize_graph)
    expr = FinalizeCompute(expr)

    with shorten_traceback():
        # The high level optimize will have to be called client side (for now).
        # Optimization can already trigger a computation internally
        # (e.g. parquet is reading some statistics). To move this to the
        # scheduler we'd need some sort of scheduler-client to trigger a
        # computation from inside the scheduler and continue with optimization
        # once the results are in. An alternative could be to introduce a
        # pre-optimize step for the Expressions that handles steps like these as
        # a dedicated computation

        # Another caveat is that optimize will only lock in the expression names
        # after optimization. Names are determined using tokenize and tokenize
        # is not cross-interpreter (let alone cross-host) stable such that we
        # have to lock this in before sending stuff (otherwise we'd need to
        # change the graph submission to a handshake which introduces all sorts
        # of concurrency control issues)

        expr = expr.optimize()
        keys = list(flatten(expr.__dask_keys__()))

        results = schedule(expr, keys, **kwargs)

    return repack(results)


def visualize(
    *args,
    filename="mydask",
    traverse=True,
    optimize_graph=False,
    maxval=None,
    engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None,
    **kwargs,
):
    """
    Visualize several dask graphs simultaneously.

    Requires ``graphviz`` to be installed. All options that are not the dask
    graph(s) should be passed as keyword arguments.

    Parameters
    ----------
    args : object
        Any number of objects. If it is a dask collection (for example, a
        dask DataFrame, Array, Bag, or Delayed), its associated graph
        will be included in the output of visualize. By default, python builtin
        collections are also traversed to look for dask objects (for more
        information see the ``traverse`` keyword). Arguments lacking an
        associated graph will be ignored.
    filename : str or None, optional
        The name of the file to write to disk. If the provided `filename`
        doesn't include an extension, '.png' will be used by default.
        If `filename` is None, no file will be written, and we communicate
        with dot using only pipes.
    format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
        Format in which to write output file. Default is 'png'.
    traverse : bool, optional
        By default, dask traverses builtin python collections looking for dask
        objects passed to ``visualize``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimize_graph : bool, optional
        If True, the graph is optimized before rendering. Otherwise,
        the graph is displayed as is. Default is False.
    color : {None, 'order', 'ages', 'freed', 'memoryincreases', 'memorydecreases', 'memorypressure'}, optional
        Options to color nodes:

        - None, the default, no colors.
        - 'order', colors the nodes' border based on the order they appear in the graph.
        - 'ages', how long the data of a node is held.
        - 'freed', the number of dependencies released after running a node.
        - 'memoryincreases', how many more outputs are held after the lifetime of a node.
          Large values may indicate nodes that should have run later.
        - 'memorydecreases', how many fewer outputs are held after the lifetime of a node.
          Large values may indicate nodes that should have run sooner.
        - 'memorypressure', the number of data held when the node is run (circle), or
          the data is released (rectangle).
    maxval : {int, float}, optional
        Maximum value for the colormap to normalize from 0 to 1.0. The default
        ``None`` uses the maximum of the values.
    collapse_outputs : bool, optional
        Whether to collapse output boxes, which often have empty labels.
        Default is False.
    verbose : bool, optional
        Whether to label output and input boxes even if the data aren't chunked.
        Beware: these labels can get very long. Default is False.
    engine : {"graphviz", "ipycytoscape", "cytoscape"}, optional
        The visualization engine to use. If not provided, this checks the dask config
        value "visualization.engine". If that is not set, it tries to import ``graphviz``
        and ``ipycytoscape``, using the first one to succeed.
    **kwargs
        Additional keyword arguments to forward to the visualization engine.

    Examples
    --------
    >>> x.visualize(filename='dask.pdf')  # doctest: +SKIP
    >>> x.visualize(filename='dask.pdf', color='order')  # doctest: +SKIP

    Returns
    -------
    result : IPython.display.Image, IPython.display.SVG, or None
        See dask.dot.dot_graph for more information.

    See Also
    --------
    dask.dot.dot_graph

    Notes
    -----
    For more information on optimization see here:

    https://docs.dask.org/en/latest/optimize.html
    """
    args, _ = unpack_collections(*args, traverse=traverse)

    dsk = collections_to_expr(args, optimize_graph=optimize_graph).__dask_graph__()

    return visualize_dsk(
        dsk=dsk,
        filename=filename,
        traverse=traverse,
        optimize_graph=optimize_graph,
        maxval=maxval,
        engine=engine,
        **kwargs,
    )


def visualize_dsk(
    dsk,
    filename="mydask",
    traverse=True,
    optimize_graph=False,
    maxval=None,
    o=None,
    engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None,
    limit=None,
    **kwargs,
):
    color = kwargs.get("color")
    from dask.order import diagnostics, order

    if color in {
        "order",
        "order-age",
        "order-freed",
        "order-memoryincreases",
        "order-memorydecreases",
        "order-memorypressure",
        "age",
        "freed",
        "memoryincreases",
        "memorydecreases",
        "memorypressure",
        "critical",
        "cpath",
    }:
        import matplotlib.pyplot as plt

        if o is None:
            o_stats = order(dsk, return_stats=True)
            o = {k: v.priority for k, v in o_stats.items()}
        elif isinstance(next(iter(o.values())), int):
            o_stats = order(dsk, return_stats=True)
        else:
            o_stats = o
            o = {k: v.priority for k, v in o.items()}

        try:
            cmap = kwargs.pop("cmap")
        except KeyError:
            cmap = plt.cm.plasma
        if isinstance(cmap, str):
            import matplotlib.pyplot as plt

            cmap = getattr(plt.cm, cmap)

        def label(x):
            return str(values[x])

        data_values = None
        if color != "order":
            info = diagnostics(dsk, o)[0]
            if color.endswith("age"):
                values = {key: val.age for key, val in info.items()}
            elif color.endswith("freed"):
                values = {key: val.num_dependencies_freed for key, val in info.items()}
            elif color.endswith("memorypressure"):
                values = {key: val.num_data_when_run for key, val in info.items()}
                data_values = {
                    key: val.num_data_when_released for key, val in info.items()
                }
            elif color.endswith("memoryincreases"):
                values = {
                    key: max(0, val.num_data_when_released - val.num_data_when_run)
                    for key, val in info.items()
                }
            elif color.endswith("memorydecreases"):
                values = {
                    key: max(0, val.num_data_when_run - val.num_data_when_released)
                    for key, val in info.items()
                }
            elif color.split("-")[-1] in {"critical", "cpath"}:
                values = {key: val.critical_path for key, val in o_stats.items()}
            else:
                raise NotImplementedError(color)

            if color.startswith("order-"):

                def label(x):
                    return str(o[x]) + "-" + str(values[x])

        else:
            values = o
        if maxval is None:
            maxval = max(1, max(values.values()))
        colors = {
            k: _colorize(tuple(map(int, cmap(v / maxval, bytes=True))))
            for k, v in values.items()
        }
        if data_values is None:
            data_colors = colors
        else:
            # Color the data boxes from their own values, not the task values
            data_colors = {
                k: _colorize(tuple(map(int, cmap(v / maxval, bytes=True))))
                for k, v in data_values.items()
            }

        kwargs["function_attributes"] = {
            k: {"color": v, "label": label(k)} for k, v in colors.items()
        }
        kwargs["data_attributes"] = {k: {"color": v} for k, v in data_colors.items()}
    elif color:
        raise NotImplementedError("Unknown value color=%s" % color)

    # Determine which engine to dispatch to, first checking the kwarg, then config,
    # then whichever of graphviz or ipycytoscape are installed, in that order.
    engine = engine or config.get("visualization.engine", None)

    if not engine:
        try:
            import graphviz  # noqa: F401

            engine = "graphviz"
        except ImportError:
            try:
                import ipycytoscape  # noqa: F401

                engine = "cytoscape"
            except ImportError:
                pass

    if engine == "graphviz":
        from dask.dot import dot_graph

        return dot_graph(dsk, filename=filename, **kwargs)
    elif engine in ("cytoscape", "ipycytoscape"):
        from dask.dot import cytoscape_graph

        return cytoscape_graph(dsk, filename=filename, **kwargs)
    elif engine is None:
        raise RuntimeError(
            "No visualization engine detected, please install graphviz or ipycytoscape"
        )
    else:
        raise ValueError(f"Visualization engine {engine} not recognized")


def persist(*args, traverse=True, optimize_graph=True, scheduler=None, **kwargs):
    """Persist multiple Dask collections into memory

    This turns lazy Dask collections into Dask collections with the same
    metadata, but now with their results fully computed or actively computing
    in the background.

    For example a lazy dask.array built up from many lazy calls will now be a
    dask.array of the same shape, dtype, chunks, etc., but now with all of
    those previously lazy tasks either computed in memory as many small :class:`numpy.array`
    (in the single-machine case) or asynchronously running in the
    background on a cluster (in the distributed case).

    This function operates differently if a ``dask.distributed.Client`` exists
    and is connected to a distributed scheduler. In this case this function
    will return as soon as the task graph has been submitted to the cluster,
    but before the computations have completed. Computations will continue
    asynchronously in the background. When using this function with the single
    machine scheduler it blocks until the computations have finished.

    When using Dask on a single machine you should ensure that the dataset fits
    entirely within memory.

    Examples
    --------
    >>> df = dd.read_csv('/path/to/*.csv')  # doctest: +SKIP
    >>> df = df[df.name == 'Alice']  # doctest: +SKIP
    >>> df['in-debt'] = df.balance < 0  # doctest: +SKIP
    >>> df = df.persist()  # triggers computation  # doctest: +SKIP

    >>> df.value().min()  # future computations are now fast  # doctest: +SKIP
    -10
    >>> df.value().max()  # doctest: +SKIP
    100

    >>> from dask import persist  # use persist function on multiple collections
    >>> a, b = persist(a, b)  # doctest: +SKIP

    Parameters
    ----------
    *args: Dask collections
    scheduler : string, optional
        Which scheduler to use like "threads", "synchronous" or "processes".
        If not provided, the default is to check the global settings first,
        and then fall back to the collection defaults.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``persist``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimize_graph : bool, optional
        If True [default], the graph is optimized before computation.
        Otherwise the graph is run as is. This can be useful for debugging.
    **kwargs
        Extra keywords to forward to the scheduler function.

    Returns
    -------
    New dask collections backed by in-memory data
    """
    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    schedule = get_scheduler(scheduler=scheduler, collections=collections)

    if inspect.ismethod(schedule):
        try:
            from distributed.client import default_client
        except ImportError:
            pass
        else:
            try:
                client = default_client()
            except ValueError:
                pass
            else:
                if client.get == schedule:
                    results = client.persist(
                        collections, optimize_graph=optimize_graph, **kwargs
                    )
                    return repack(results)

    expr = collections_to_expr(collections, optimize_graph)
    expr = expr.optimize()
    keys, postpersists = [], []
    for a, akeys in zip(collections, expr.__dask_keys__(), strict=True):
        a_keys = list(flatten(akeys))
        rebuild, state = a.__dask_postpersist__()
        keys.extend(a_keys)
        postpersists.append((rebuild, a_keys, state))

    with shorten_traceback():
        results = schedule(expr, keys, **kwargs)

    d = dict(zip(keys, results))
    results2 = [r({k: d[k] for k in ks}, *s) for r, ks, s in postpersists]
    return repack(results2)


def _colorize(t):
    """Convert (r, g, b) triple to "#RRGGBB" string

    For use with ``visualize(color=...)``

    Examples
    --------
    >>> _colorize((255, 255, 255))
    '#FFFFFF'
    >>> _colorize((0, 32, 128))
    '#002080'
    """
    t = t[:3]
    i = sum(v * 256 ** (len(t) - i - 1) for i, v in enumerate(t))
    h = hex(int(i))[2:].upper()
    h = "0" * (6 - len(h)) + h
    return "#" + h


named_schedulers: dict[str, SchedulerGetCallable] = {
    "sync": local.get_sync,
    "synchronous": local.get_sync,
    "single-threaded": local.get_sync,
}

if not EMSCRIPTEN:
    from dask import threaded

    named_schedulers.update(
        {
            "threads": threaded.get,
            "threading": threaded.get,
        }
    )

    from dask import multiprocessing as dask_multiprocessing

    named_schedulers.update(
        {
            "processes": dask_multiprocessing.get,
            "multiprocessing": dask_multiprocessing.get,
        }
    )


get_err_msg = """
The get= keyword has been removed.

Please use the scheduler= keyword instead with the name of
the desired scheduler like 'threads' or 'processes'

    x.compute(scheduler='single-threaded')
    x.compute(scheduler='threads')
    x.compute(scheduler='processes')

or with a function that takes the graph and keys

    x.compute(scheduler=my_scheduler_function)

or with a Dask client

    x.compute(scheduler=client)
""".strip()


def _ensure_not_async(client):
    if client.asynchronous:
        if fallback := config.get("admin.async-client-fallback", None):
            warnings.warn(
                "Distributed Client detected but Client instance is "
                f"asynchronous. Falling back to `{fallback}` scheduler. "
                "To use an asynchronous Client, please use "
                "``Client.compute`` and ``Client.gather`` "
                "instead of the top level ``dask.compute``",
                UserWarning,
            )
            return get_scheduler(scheduler=fallback)
        else:
            raise RuntimeError(
                "Attempting to use an asynchronous "
                "Client in a synchronous context of `dask.compute`"
            )
    return client.get
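

# A minimal sketch of the fallback behaviour above (editor's illustration):
# with the "admin.async-client-fallback" config key set, an asynchronous
# Client downgrades to the named local scheduler instead of raising.
#
#     import dask
#     with dask.config.set({"admin.async-client-fallback": "threads"}):
#         ...  # dask.compute(...) warns and uses the threaded scheduler
#              # when the active distributed Client is asynchronous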


def get_scheduler(get=None, scheduler=None, collections=None, cls=None):
    """Get scheduler function

    There are various ways to specify the scheduler to use:

    1. Passing in scheduler= parameters
    2. Passing these into global configuration
    3. Using a dask.distributed default Client
    4. Using defaults of a dask collection

    This function centralizes the logic to determine the right scheduler to use
    from those many options
    """
    if get:
        raise TypeError(get_err_msg)

    if scheduler is not None:
        if callable(scheduler):
            return scheduler
        elif "Client" in type(scheduler).__name__ and hasattr(scheduler, "get"):
            return _ensure_not_async(scheduler)
        elif isinstance(scheduler, str):
            scheduler = scheduler.lower()

            client_available = False
            if _distributed_available():
                assert _DistributedClient is not None
                with suppress(ValueError):
                    _DistributedClient.current(allow_global=True)
                    client_available = True
            if scheduler in named_schedulers:
                return named_schedulers[scheduler]
            elif scheduler in ("dask.distributed", "distributed"):
                if not client_available:
                    raise RuntimeError(
                        f"Requested {scheduler} scheduler but no Client active."
                    )
                assert _get_distributed_client is not None
                client = _get_distributed_client()
                return _ensure_not_async(client)
            else:
                raise ValueError(
                    "Expected one of [distributed, %s]"
                    % ", ".join(sorted(named_schedulers))
                )
        elif isinstance(scheduler, Executor):
            # Get `num_workers` from `Executor`'s `_max_workers` attribute.
            # If undefined, fallback to `config` or worst case CPU_COUNT.
            num_workers = getattr(scheduler, "_max_workers", None)
            if num_workers is None:
                num_workers = config.get("num_workers", CPU_COUNT)
            assert isinstance(num_workers, Integral) and num_workers > 0
            return partial(local.get_async, scheduler.submit, num_workers)
        else:
            raise ValueError("Unexpected scheduler: %s" % repr(scheduler))
        # else:  # try to connect to remote scheduler with this name
        #     return get_client(scheduler).get

    if config.get("scheduler", None):
        return get_scheduler(scheduler=config.get("scheduler", None))

    if config.get("get", None):
        raise ValueError(get_err_msg)

    try:
        from distributed import get_client

        return _ensure_not_async(get_client())
    except (ImportError, ValueError):
        pass

    if cls is not None:
        return cls.__dask_scheduler__

    if collections:
        collections = [c for c in collections if c is not None]
    if collections:
        get = collections[0].__dask_scheduler__
        if not all(c.__dask_scheduler__ == get for c in collections):
            raise ValueError(
                "Compute called on multiple collections with "
                "differing default schedulers. Please specify a "
                "`scheduler=` parameter explicitly in compute or "
                "globally with `dask.config.set`."
            )
        return get

    return None
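

# A minimal sketch of the resolution order documented above (editor's
# illustration): an explicit argument wins, then global config, then the
# collection's own default.
#
#     import dask
#     get_scheduler(scheduler="sync")           # -> local.get_sync
#     with dask.config.set(scheduler="threads"):
#         get_scheduler()                       # -> threaded.get
#     import dask.array as da
#     get_scheduler(collections=[da.ones(3)])   # -> the array default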


def wait(x, timeout=None, return_when="ALL_COMPLETED"):
    """Wait until computation has finished

    This is a compatibility alias for ``dask.distributed.wait``.
    If it is applied to Dask collections without Dask Futures, or if Dask
    distributed is not installed, then it is a no-op.
    """
    try:
        from distributed import wait

        return wait(x, timeout=timeout, return_when=return_when)
    except (ImportError, ValueError):
        return x


def get_collection_names(collection) -> set[str]:
    """Infer the collection names from the dask keys, under the assumption that
    all keys are either strings, or tuples whose first element is a string
    (the collection name).

    Examples
    --------
    >>> a.__dask_keys__()  # doctest: +SKIP
    ["foo", "bar"]
    >>> get_collection_names(a)  # doctest: +SKIP
    {"foo", "bar"}
    >>> b.__dask_keys__()  # doctest: +SKIP
    [[("foo-123", 0, 0), ("foo-123", 0, 1)], [("foo-123", 1, 0), ("foo-123", 1, 1)]]
    >>> get_collection_names(b)  # doctest: +SKIP
    {"foo-123"}
    """
    if not is_dask_collection(collection):
        raise TypeError(f"Expected Dask collection; got {type(collection)}")
    return {get_name_from_key(k) for k in flatten(collection.__dask_keys__())}


def get_name_from_key(key: Key) -> str:
    """Given a dask collection's key, extract the collection name.

    Parameters
    ----------
    key: string or tuple
        Dask collection's key, which must be either a single string or a tuple whose
        first element is a string (commonly referred to as a collection's 'name').

    Examples
    --------
    >>> get_name_from_key("foo")
    'foo'
    >>> get_name_from_key(("foo-123", 1, 2))
    'foo-123'
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return key[0]
    if isinstance(key, str):
        return key
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")


KeyOrStrT = TypeVar("KeyOrStrT", Key, str)


def replace_name_in_key(key: KeyOrStrT, rename: Mapping[str, str]) -> KeyOrStrT:
    """Given a dask collection's key, replace the collection name with a new one.

    Parameters
    ----------
    key: string or tuple
        Dask collection's key, which must be either a single string or a tuple whose
        first element is a string (commonly referred to as a collection's 'name').
    rename:
        Mapping of zero or more old names to new names. Extraneous names will be
        ignored. Names not found in this mapping won't be replaced.

    Examples
    --------
    >>> replace_name_in_key("foo", {})
    'foo'
    >>> replace_name_in_key("foo", {"foo": "bar"})
    'bar'
    >>> replace_name_in_key(("foo-123", 1, 2), {"foo-123": "bar-456"})
    ('bar-456', 1, 2)
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return (rename.get(key[0], key[0]),) + key[1:]
    if isinstance(key, str):
        return rename.get(key, key)
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")


def clone_key(key: KeyOrStrT, seed: Hashable) -> KeyOrStrT:
    """Clone a key from a Dask collection, producing a new key with the same prefix and
    indices and a token which is a deterministic function of the previous key and seed.

    Examples
    --------
    >>> clone_key("x", 123)  # doctest: +SKIP
    'x-c4fb64ccca807af85082413d7ef01721'
    >>> clone_key("inc-cbb1eca3bafafbb3e8b2419c4eebb387", 123)  # doctest: +SKIP
    'inc-bc629c23014a4472e18b575fdaf29ee7'
    >>> clone_key(("sum-cbb1eca3bafafbb3e8b2419c4eebb387", 4, 3), 123)  # doctest: +SKIP
    ('sum-c053f3774e09bd0f7de6044dbc40e71d', 4, 3)
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return (clone_key(key[0], seed),) + key[1:]
    if isinstance(key, str):
        prefix = key_split(key)
        return prefix + "-" + tokenize(key, seed)
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")