from __future__ import annotations

import dataclasses
import uuid
import warnings
from collections import OrderedDict
from collections.abc import Hashable, Iterable, Iterator, Mapping
from concurrent.futures import Executor
from contextlib import contextmanager, suppress
from contextvars import ContextVar
from functools import partial
from numbers import Integral, Number
from operator import getitem
from typing import TYPE_CHECKING, Any, Literal, TypeVar

from tlz import merge

from dask import config, local
from dask._compatibility import EMSCRIPTEN
from dask._task_spec import DataNode, Dict, List, Task, TaskRef
from dask.core import flatten
from dask.core import get as simple_get
from dask.system import CPU_COUNT
from dask.typing import Key, SchedulerGetCallable
from dask.utils import is_namedtuple_instance, key_split, shorten_traceback

if TYPE_CHECKING:
    from dask._expr import Expr

_DistributedClient = None
_get_distributed_client = None
_DISTRIBUTED_AVAILABLE = None


def _distributed_available() -> bool:
    # Lazy import in get_scheduler can be expensive
    global _DistributedClient, _get_distributed_client, _DISTRIBUTED_AVAILABLE
    if _DISTRIBUTED_AVAILABLE is not None:
        return _DISTRIBUTED_AVAILABLE  # type: ignore[unreachable]
    try:
        from distributed import Client as _DistributedClient
        from distributed.worker import get_client as _get_distributed_client

        _DISTRIBUTED_AVAILABLE = True
    except ImportError:
        _DISTRIBUTED_AVAILABLE = False
    return _DISTRIBUTED_AVAILABLE
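

# Illustrative sketch (added comment, not part of the upstream module): the
# first call pays the cost of importing ``distributed`` (if it is installed)
# and caches the result, so later calls return the cached flag cheaply.
#
#     >>> _distributed_available()  # doctest: +SKIP
#     True
#     >>> _distributed_available()  # cached flag, no import machinery  # doctest: +SKIP
#     True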


__all__ = (
    "DaskMethodsMixin",
    "annotate",
    "get_annotations",
    "is_dask_collection",
    "compute",
    "persist",
    "optimize",
    "visualize",
    "tokenize",
    "normalize_token",
    "get_collection_names",
    "get_name_from_key",
    "replace_name_in_key",
    "clone_key",
)

# Backwards compat
from dask.tokenize import TokenizationError, normalize_token, tokenize  # noqa: F401

_annotations: ContextVar[dict[str, Any] | None] = ContextVar(
    "annotations", default=None
)


def get_annotations() -> dict[str, Any]:
    """Get current annotations.

    Returns
    -------
    Dict of all current annotations

    See Also
    --------
    annotate
    """
    return _annotations.get() or {}


@contextmanager
def annotate(**annotations: Any) -> Iterator[None]:
    """Context manager for setting HighLevelGraph layer annotations.

    Annotations are metadata or soft constraints associated with
    tasks that dask schedulers may choose to respect: they signal intent
    without enforcing hard constraints. As such, they are
    primarily designed for use with the distributed scheduler.

    Almost any object can serve as an annotation, but small Python objects
    are preferred, while large objects such as NumPy arrays are discouraged.

    Callables supplied as an annotation should take a single *key* argument and
    produce the appropriate annotation. Individual task keys in the annotated
    collection are supplied to the callable.

    Parameters
    ----------
    **annotations : key-value pairs

    Examples
    --------

    All tasks within array A should have priority 100 and be retried 3 times
    on failure.

    >>> import dask
    >>> import dask.array as da
    >>> with dask.annotate(priority=100, retries=3):
    ...     A = da.ones((10000, 10000))

    Prioritise tasks within Array A on flattened block ID.

    >>> nblocks = (10, 10)
    >>> with dask.annotate(priority=lambda k: k[1]*nblocks[1] + k[2]):
    ...     A = da.ones((1000, 1000), chunks=(100, 100))

    Annotations may be nested.

    >>> with dask.annotate(priority=1):
    ...     with dask.annotate(retries=3):
    ...         A = da.ones((1000, 1000))
    ...     B = A + 1

    See Also
    --------
    get_annotations
    """

    # Sanity check annotations used in place of
    # legacy distributed Client.{submit, persist, compute} keywords
    if "workers" in annotations:
        if isinstance(annotations["workers"], (list, set, tuple)):
            annotations["workers"] = list(annotations["workers"])
        elif isinstance(annotations["workers"], str):
            annotations["workers"] = [annotations["workers"]]
        elif callable(annotations["workers"]):
            pass
        else:
            raise TypeError(
                "'workers' annotation must be a sequence of str, a str or a callable, but got {}.".format(
                    annotations["workers"]
                )
            )

    if (
        "priority" in annotations
        and not isinstance(annotations["priority"], Number)
        and not callable(annotations["priority"])
    ):
        raise TypeError(
            "'priority' annotation must be a Number or a callable, but got {}".format(
                annotations["priority"]
            )
        )

    if (
        "retries" in annotations
        and not isinstance(annotations["retries"], Number)
        and not callable(annotations["retries"])
    ):
        raise TypeError(
            "'retries' annotation must be a Number or a callable, but got {}".format(
                annotations["retries"]
            )
        )

    if (
        "resources" in annotations
        and not isinstance(annotations["resources"], dict)
        and not callable(annotations["resources"])
    ):
        raise TypeError(
            "'resources' annotation must be a dict or a callable, but got {}".format(
                annotations["resources"]
            )
        )

    if (
        "allow_other_workers" in annotations
        and not isinstance(annotations["allow_other_workers"], bool)
        and not callable(annotations["allow_other_workers"])
    ):
        raise TypeError(
            "'allow_other_workers' annotation must be a bool or a callable, but got {}".format(
                annotations["allow_other_workers"]
            )
        )
    ctx_annot = _annotations.get()
    if ctx_annot is None:
        ctx_annot = {}
    token = _annotations.set(merge(ctx_annot, annotations))
    try:
        yield
    finally:
        _annotations.reset(token)


def is_dask_collection(x) -> bool:
    """Returns ``True`` if ``x`` is a dask collection.

    Parameters
    ----------
    x : Any
        Object to test.

    Returns
    -------
    result : bool
        ``True`` if `x` is a Dask collection.

    Notes
    -----
    The DaskCollection typing.Protocol implementation defines a Dask
    collection as a class that returns a Mapping from the
    ``__dask_graph__`` method. This helper function existed before the
    implementation of the protocol.

    """
    if (
        isinstance(x, type)
        or not hasattr(x, "__dask_graph__")
        or not callable(x.__dask_graph__)
    ):
        return False

    pkg_name = getattr(type(x), "__module__", "")
    if pkg_name.split(".")[0] in ("dask_cudf",):
        # Temporary hack to avoid graph materialization. Note that this won't work with
        # dask_expr.array objects wrapped by xarray or pint. By the time dask_expr.array
        # is published, we hope to be able to rewrite this method completely.
        # Read: https://github.com/dask/dask/pull/10676
        return True
    elif pkg_name.startswith("dask.dataframe.dask_expr"):
        return True
    elif pkg_name.startswith("dask.array._array_expr"):
        return True

    # xarray, pint, and possibly other wrappers always define a __dask_graph__ method,
    # but it may return None if they wrap around a non-dask object.
    # In all known dask collections other than dask-expr,
    # calling __dask_graph__ is cheap.
    return x.__dask_graph__() is not None
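

# Illustrative sketch (added comment, not part of the upstream module):
# lazy collections pass the check, while concrete results and classes do not.
#
#     >>> import dask.array as da
#     >>> is_dask_collection(da.ones(5))  # lazy dask array  # doctest: +SKIP
#     True
#     >>> is_dask_collection(da.ones(5).compute())  # plain numpy array  # doctest: +SKIP
#     False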


class DaskMethodsMixin:
    """A mixin adding standard dask collection methods"""

    __slots__ = ("__weakref__",)

    def visualize(self, filename="mydask", format=None, optimize_graph=False, **kwargs):
        """Render the computation of this object's task graph using graphviz.

        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and we communicate
            with dot using only pipes.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'png'.
        optimize_graph : bool, optional
            If True, the graph is optimized before rendering. Otherwise,
            the graph is displayed as is. Default is False.
        color : {None, 'order'}, optional
            Options to color nodes. Provide ``cmap=`` keyword for additional
            colormap
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.

        Examples
        --------
        >>> x.visualize(filename='dask.pdf')  # doctest: +SKIP
        >>> x.visualize(filename='dask.pdf', color='order')  # doctest: +SKIP

        Returns
        -------
        result : IPython.display.Image, IPython.display.SVG, or None
            See dask.dot.dot_graph for more information.

        See Also
        --------
        dask.visualize
        dask.dot.dot_graph

        Notes
        -----
        For more information on optimization see here:

        https://docs.dask.org/en/latest/optimize.html
        """
        return visualize(
            self,
            filename=filename,
            format=format,
            optimize_graph=optimize_graph,
            **kwargs,
        )

    def persist(self, **kwargs):
        """Persist this dask collection into memory

        This turns a lazy Dask collection into a Dask collection with the same
        metadata, but now with the results fully computed or actively computing
        in the background.

        The action of this function differs significantly depending on the
        active task scheduler. If the task scheduler supports asynchronous
        computing, such as is the case of the dask.distributed scheduler, then
        persist will return *immediately* and the return value's task graph
        will contain Dask Future objects. However, if the task scheduler only
        supports blocking computation then the call to persist will *block*
        and the return value's task graph will contain concrete Python results.

        This function is particularly useful when using distributed systems,
        because the results will be kept in distributed memory, rather than
        returned to the local process as with compute.

        Parameters
        ----------
        scheduler : string, optional
            Which scheduler to use like "threads", "synchronous" or "processes".
            If not provided, the default is to check the global settings first,
            and then fall back to the collection defaults.
        optimize_graph : bool, optional
            If True [default], the graph is optimized before computation.
            Otherwise the graph is run as is. This can be useful for debugging.
        **kwargs
            Extra keywords to forward to the scheduler function.

        Returns
        -------
        New dask collections backed by in-memory data

        See Also
        --------
        dask.persist
        """
        (result,) = persist(self, traverse=False, **kwargs)
        return result

    def compute(self, **kwargs):
        """Compute this dask collection

        This turns a lazy Dask collection into its in-memory equivalent.
        For example a Dask array turns into a NumPy array and a Dask dataframe
        turns into a Pandas dataframe. The entire dataset must fit into memory
        before calling this operation.

        Parameters
        ----------
        scheduler : string, optional
            Which scheduler to use like "threads", "synchronous" or "processes".
            If not provided, the default is to check the global settings first,
            and then fall back to the collection defaults.
        optimize_graph : bool, optional
            If True [default], the graph is optimized before computation.
            Otherwise the graph is run as is. This can be useful for debugging.
        kwargs
            Extra keywords to forward to the scheduler function.

        See Also
        --------
        dask.compute
        """
        (result,) = compute(self, traverse=False, **kwargs)
        return result

    def __await__(self):
        try:
            from distributed import futures_of, wait
        except ImportError as e:
            raise ImportError(
                "Using async/await with dask requires the `distributed` package"
            ) from e

        async def f():
            if futures_of(self):
                await wait(self)
            return self

        return f().__await__()


def compute_as_if_collection(cls, dsk, keys, scheduler=None, get=None, **kwargs):
    """Compute a graph as if it were of type cls.

    Allows for applying the same optimizations and default scheduler."""
    schedule = get_scheduler(scheduler=scheduler, cls=cls, get=get)
    dsk2 = optimization_function(cls)(dsk, keys, **kwargs)
    return schedule(dsk2, keys, **kwargs)


def dont_optimize(dsk, keys, **kwargs):
    return dsk


def optimization_function(x):
    return getattr(x, "__dask_optimize__", dont_optimize)
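

# Illustrative sketch (added comment, not part of the upstream module):
# ``optimization_function`` looks up a class's ``__dask_optimize__`` hook and
# falls back to the identity pass ``dont_optimize`` when none is defined.
#
#     >>> class NoOpt:
#     ...     pass
#     >>> optimization_function(NoOpt) is dont_optimize
#     True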


def collections_to_expr(
    collections: Iterable,
    optimize_graph: bool = True,
) -> Expr:
    """
    Convert many collections into a single dask expression.

    Typically, users should not be required to interact with this function.

    Parameters
    ----------
    collections : Iterable
        An iterable of dask collections to be combined.
    optimize_graph : bool, optional
        If this is True and collections are encountered which are backed by
        legacy HighLevelGraph objects, the returned Expression will run a low
        level task optimization during materialization.
    """
    is_iterable = False
    if isinstance(collections, (tuple, list, set)):
        is_iterable = True
    else:
        collections = [collections]
    if not collections:
        raise ValueError("No collections provided")
    from dask._expr import HLGExpr, _ExprSequence

    graphs = []
    for coll in collections:
        from dask.delayed import Delayed

        if isinstance(coll, Delayed) or not hasattr(coll, "expr"):
            graphs.append(HLGExpr.from_collection(coll, optimize_graph=optimize_graph))
        else:
            graphs.append(coll.expr)

    if len(graphs) > 1 or is_iterable:
        return _ExprSequence(*graphs)
    else:
        return graphs[0]
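

# Illustrative sketch (added comment, not part of the upstream module): a bare
# collection yields its own expression, while an iterable of collections is
# wrapped in an ``_ExprSequence``; the merged graph can then be materialized.
#
#     >>> import dask.array as da
#     >>> expr = collections_to_expr([da.ones(5), da.zeros(5)])  # doctest: +SKIP
#     >>> dsk = expr.__dask_graph__()  # merged low-level graph  # doctest: +SKIP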


def unpack_collections(*args, traverse=True):
    """Extract collections in preparation for compute/persist/etc...

    Intended use is to find all collections in a set of (possibly nested)
    python objects, do something to them (compute, etc...), then repackage them
    in equivalent python objects.

    Parameters
    ----------
    *args
        Any number of objects. If it is a dask collection, it's extracted and
        added to the list of collections returned. By default, python builtin
        collections are also traversed to look for dask collections (for more
        information see the ``traverse`` keyword).
    traverse : bool, optional
        If True (default), builtin python collections are traversed looking for
        any dask collections they might contain.

    Returns
    -------
    collections : list
        A list of all dask collections contained in ``args``
    repack : callable
        A function to call on the transformed collections to repackage them as
        they were in the original ``args``.
    """
    collections = []
    repack_dsk = {}

    collections_token = uuid.uuid4().hex

    def _unpack(expr):
        if is_dask_collection(expr):
            tok = tokenize(expr)
            if tok not in repack_dsk:
                repack_dsk[tok] = Task(
                    tok, getitem, TaskRef(collections_token), len(collections)
                )
                collections.append(expr)
            return TaskRef(tok)

        tok = uuid.uuid4().hex
        tsk: DataNode | Task  # type: ignore[annotation-unchecked]
        if not traverse:
            tsk = DataNode(None, expr)
        else:
            # Treat iterators like lists
            typ = list if isinstance(expr, Iterator) else type(expr)
            if typ in (list, tuple, set):
                tsk = Task(tok, typ, List(*[_unpack(i) for i in expr]))
            elif typ in (dict, OrderedDict):
                tsk = Task(
                    tok, typ, Dict({_unpack(k): _unpack(v) for k, v in expr.items()})
                )
            elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
                tsk = Task(
                    tok,
                    typ,
                    *[_unpack(getattr(expr, f.name)) for f in dataclasses.fields(expr)],
                )
            elif is_namedtuple_instance(expr):
                tsk = Task(tok, typ, *[_unpack(i) for i in expr])
            else:
                return expr

        repack_dsk[tok] = tsk
        return TaskRef(tok)

    out = uuid.uuid4().hex
    repack_dsk[out] = Task(out, tuple, List(*[_unpack(i) for i in args]))

    def repack(results):
        dsk = repack_dsk.copy()
        dsk[collections_token] = DataNode(collections_token, results)
        return simple_get(dsk, out)

    # The original `collections` is kept alive by the closure.
    # This causes the collection to be only freed by the garbage collector
    collections2 = list(collections)
    collections.clear()
    return collections2, repack
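

# Illustrative sketch (added comment, not part of the upstream module): dask
# collections are pulled out of nested containers, and ``repack`` rebuilds the
# original container shape around the computed results.
#
#     >>> from dask import delayed
#     >>> colls, repack = unpack_collections({"x": delayed(1), "y": 2})  # doctest: +SKIP
#     >>> repack([c.compute() for c in colls])  # doctest: +SKIP
#     ({'x': 1, 'y': 2},)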


def optimize(*args, traverse=True, **kwargs):
    """Optimize several dask collections at once.

    Returns equivalent dask collections that all share the same merged and
    optimized underlying graph. This can be useful if converting multiple
    collections to delayed objects, or to manually apply the optimizations at
    strategic points.

    Note that in most cases you shouldn't need to call this function directly.

    .. warning::

        This function triggers a materialization of the collections and loses
        any annotations attached to HLG layers.

    Parameters
    ----------
    *args : objects
        Any number of objects. If a dask object, its graph is optimized and
        merged with all those of all other dask objects before returning an
        equivalent dask collection. Non-dask arguments are passed through
        unchanged.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``optimize``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimizations : list of callables, optional
        Additional optimization passes to perform.
    **kwargs
        Extra keyword arguments to forward to the optimization passes.

    Examples
    --------
    >>> import dask
    >>> import dask.array as da
    >>> a = da.arange(10, chunks=2).sum()
    >>> b = da.arange(10, chunks=2).mean()
    >>> a2, b2 = dask.optimize(a, b)

    >>> a2.compute() == a.compute()
    True
    >>> b2.compute() == b.compute()
    True
    """
    # TODO: This API is problematic. The approach to using postpersist forces
    # us to materialize the graph. Most low level optimizations will
    # materialize as well.
    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    dsk = collections_to_expr(collections)

    postpersists = []
    for a in collections:
        r, s = a.__dask_postpersist__()
        postpersists.append(r(dsk.__dask_graph__(), *s))

    return repack(postpersists)


def compute(
    *args,
    traverse=True,
    optimize_graph=True,
    scheduler=None,
    get=None,
    **kwargs,
):
    """Compute several dask collections at once.

    Parameters
    ----------
    args : object
        Any number of objects. If it is a dask object, it's computed and the
        result is returned. By default, python builtin collections are also
        traversed to look for dask objects (for more information see the
        ``traverse`` keyword). Non-dask arguments are passed through unchanged.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``compute``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    scheduler : string, optional
        Which scheduler to use like "threads", "synchronous" or "processes".
        If not provided, the default is to check the global settings first,
        and then fall back to the collection defaults.
    optimize_graph : bool, optional
        If True [default], the optimizations for each collection are applied
        before computation. Otherwise the graph is run as is. This can be
        useful for debugging.
    get : ``None``
        Should be left as ``None``; the ``get=`` keyword has been removed.
    kwargs
        Extra keywords to forward to the scheduler function.

    Examples
    --------
    >>> import dask
    >>> import dask.array as da
    >>> a = da.arange(10, chunks=2).sum()
    >>> b = da.arange(10, chunks=2).mean()
    >>> dask.compute(a, b)
    (45, 4.5)

    By default, dask objects inside python collections will also be computed:

    >>> dask.compute({'a': a, 'b': b, 'c': 1})
    ({'a': 45, 'b': 4.5, 'c': 1},)
    """
    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    schedule = get_scheduler(
        scheduler=scheduler,
        collections=collections,
        get=get,
    )
    from dask._expr import FinalizeCompute

    expr = collections_to_expr(collections, optimize_graph)
    expr = FinalizeCompute(expr)

    with shorten_traceback():
        # The high level optimize will have to be called client side (for now).
        # The optimize can internally already trigger a computation
        # (e.g. parquet is reading some statistics). To move this to the
        # scheduler we'd need some sort of scheduler-client to trigger a
        # computation from inside the scheduler and continue with optimization
        # once the results are in. An alternative could be to introduce a
        # pre-optimize step for the Expressions that handles steps like these
        # as a dedicated computation.

        # Another caveat is that optimize will only lock in the expression
        # names after optimization. Names are determined using tokenize, and
        # tokenize is not cross-interpreter (let alone cross-host) stable, so
        # we have to lock this in before sending stuff (otherwise we'd need to
        # change the graph submission to a handshake, which introduces all
        # sorts of concurrency control issues).

        expr = expr.optimize()
        keys = list(flatten(expr.__dask_keys__()))

        results = schedule(expr, keys, **kwargs)

    return repack(results)


def visualize(
    *args,
    filename="mydask",
    traverse=True,
    optimize_graph=False,
    maxval=None,
    engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None,
    **kwargs,
):
    """
    Visualize several dask graphs simultaneously.

    Requires ``graphviz`` to be installed. All options that are not the dask
    graph(s) should be passed as keyword arguments.

    Parameters
    ----------
    args : object
        Any number of objects. If it is a dask collection (for example, a
        dask DataFrame, Array, Bag, or Delayed), its associated graph
        will be included in the output of visualize. By default, python builtin
        collections are also traversed to look for dask objects (for more
        information see the ``traverse`` keyword). Arguments lacking an
        associated graph will be ignored.
    filename : str or None, optional
        The name of the file to write to disk. If the provided `filename`
        doesn't include an extension, '.png' will be used by default.
        If `filename` is None, no file will be written, and we communicate
        with dot using only pipes.
    format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
        Format in which to write output file. Default is 'png'.
    traverse : bool, optional
        By default, dask traverses builtin python collections looking for dask
        objects passed to ``visualize``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimize_graph : bool, optional
        If True, the graph is optimized before rendering. Otherwise,
        the graph is displayed as is. Default is False.
    color : {None, 'order', 'ages', 'freed', 'memoryincreases', 'memorydecreases', 'memorypressure'}, optional
        Options to color nodes:

        - None, the default, no colors.
        - 'order', colors the nodes' border based on the order they appear in the graph.
        - 'ages', how long the data of a node is held.
        - 'freed', the number of dependencies released after running a node.
        - 'memoryincreases', how many more outputs are held after the lifetime of a node.
          Large values may indicate nodes that should have run later.
        - 'memorydecreases', how many fewer outputs are held after the lifetime of a node.
          Large values may indicate nodes that should have run sooner.
        - 'memorypressure', the number of data held when the node is run (circle), or
          the data is released (rectangle).
    maxval : {int, float}, optional
        Maximum value for the colormap to normalize from 0 to 1.0. The default,
        ``None``, uses the maximum of the values.
    collapse_outputs : bool, optional
        Whether to collapse output boxes, which often have empty labels.
        Default is False.
    verbose : bool, optional
        Whether to label output and input boxes even if the data aren't chunked.
        Beware: these labels can get very long. Default is False.
    engine : {"graphviz", "ipycytoscape", "cytoscape"}, optional
        The visualization engine to use. If not provided, this checks the dask config
        value "visualization.engine". If that is not set, it tries to import ``graphviz``
        and ``ipycytoscape``, using the first one to succeed.
    **kwargs
        Additional keyword arguments to forward to the visualization engine.

    Examples
    --------
    >>> x.visualize(filename='dask.pdf')  # doctest: +SKIP
    >>> x.visualize(filename='dask.pdf', color='order')  # doctest: +SKIP

    Returns
    -------
    result : IPython.display.Image, IPython.display.SVG, or None
        See dask.dot.dot_graph for more information.

    See Also
    --------
    dask.dot.dot_graph

    Notes
    -----
    For more information on optimization see here:

    https://docs.dask.org/en/latest/optimize.html
    """
    args, _ = unpack_collections(*args, traverse=traverse)

    dsk = collections_to_expr(args, optimize_graph=optimize_graph).__dask_graph__()

    return visualize_dsk(
        dsk=dsk,
        filename=filename,
        traverse=traverse,
        optimize_graph=optimize_graph,
        maxval=maxval,
        engine=engine,
        **kwargs,
    )


def visualize_dsk(
    dsk,
    filename="mydask",
    traverse=True,
    optimize_graph=False,
    maxval=None,
    o=None,
    engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None,
    limit=None,
    **kwargs,
):
    color = kwargs.get("color")
    from dask.order import diagnostics, order

    if color in {
        "order",
        "order-age",
        "order-freed",
        "order-memoryincreases",
        "order-memorydecreases",
        "order-memorypressure",
        "age",
        "freed",
        "memoryincreases",
        "memorydecreases",
        "memorypressure",
        "critical",
        "cpath",
    }:
        import matplotlib.pyplot as plt

        if o is None:
            o_stats = order(dsk, return_stats=True)
            o = {k: v.priority for k, v in o_stats.items()}
        elif isinstance(next(iter(o.values())), int):
            o_stats = order(dsk, return_stats=True)
        else:
            o_stats = o
            o = {k: v.priority for k, v in o.items()}

        try:
            cmap = kwargs.pop("cmap")
        except KeyError:
            cmap = plt.cm.plasma
        if isinstance(cmap, str):
            import matplotlib.pyplot as plt

            cmap = getattr(plt.cm, cmap)

        def label(x):
            return str(values[x])

        data_values = None
        if color != "order":
            info = diagnostics(dsk, o)[0]
            if color.endswith("age"):
                values = {key: val.age for key, val in info.items()}
            elif color.endswith("freed"):
                values = {key: val.num_dependencies_freed for key, val in info.items()}
            elif color.endswith("memorypressure"):
                values = {key: val.num_data_when_run for key, val in info.items()}
                data_values = {
                    key: val.num_data_when_released for key, val in info.items()
                }
            elif color.endswith("memoryincreases"):
                values = {
                    key: max(0, val.num_data_when_released - val.num_data_when_run)
                    for key, val in info.items()
                }
            elif color.endswith("memorydecreases"):
                values = {
                    key: max(0, val.num_data_when_run - val.num_data_when_released)
                    for key, val in info.items()
                }
            elif color.split("-")[-1] in {"critical", "cpath"}:
                values = {key: val.critical_path for key, val in o_stats.items()}
            else:
                raise NotImplementedError(color)

            if color.startswith("order-"):

                def label(x):
                    return f"{o[x]}-{values[x]}"

        else:
            values = o
        if maxval is None:
            maxval = max(1, *values.values())
        colors = {
            k: _colorize(tuple(map(int, cmap(v / maxval, bytes=True))))
            for k, v in values.items()
        }
        if data_values is None:
            data_colors = colors
        else:
            # Color data nodes by their release-time values
            data_colors = {
                k: _colorize(tuple(map(int, cmap(v / maxval, bytes=True))))
                for k, v in data_values.items()
            }

        kwargs["function_attributes"] = {
            k: {"color": v, "label": label(k)} for k, v in colors.items()
        }
        kwargs["data_attributes"] = {k: {"color": v} for k, v in data_colors.items()}
    elif color:
        raise NotImplementedError(f"Unknown value color={color}")

    # Determine which engine to dispatch to, first checking the kwarg, then config,
    # then whichever of graphviz or ipycytoscape are installed, in that order.
    engine = engine or config.get("visualization.engine", None)

    if not engine:
        try:
            import graphviz  # noqa: F401

            engine = "graphviz"
        except ImportError:
            try:
                import ipycytoscape  # noqa: F401

                engine = "cytoscape"
            except ImportError:
                pass
    if engine == "graphviz":
        from dask.dot import dot_graph

        return dot_graph(dsk, filename=filename, **kwargs)
    elif engine in ("cytoscape", "ipycytoscape"):
        from dask.dot import cytoscape_graph

        return cytoscape_graph(dsk, filename=filename, **kwargs)
    elif engine is None:
        raise RuntimeError(
            "No visualization engine detected, please install graphviz or ipycytoscape"
        )
    else:
        raise ValueError(f"Visualization engine {engine} not recognized")


def persist(*args, traverse=True, optimize_graph=True, scheduler=None, **kwargs):
    """Persist multiple Dask collections into memory

    This turns lazy Dask collections into Dask collections with the same
    metadata, but now with their results fully computed or actively computing
    in the background.

    For example a lazy dask.array built up from many lazy calls will now be a
    dask.array of the same shape, dtype, chunks, etc., but now with all of
    those previously lazy tasks either computed in memory as many small
    :class:`numpy.array` (in the single-machine case) or asynchronously running
    in the background on a cluster (in the distributed case).

    This function operates differently if a ``dask.distributed.Client`` exists
    and is connected to a distributed scheduler. In this case this function
    will return as soon as the task graph has been submitted to the cluster,
    but before the computations have completed. Computations will continue
    asynchronously in the background. When using this function with the single
    machine scheduler it blocks until the computations have finished.

    When using Dask on a single machine you should ensure that the dataset fits
    entirely within memory.

    Examples
    --------
    >>> df = dd.read_csv('/path/to/*.csv')  # doctest: +SKIP
    >>> df = df[df.name == 'Alice']  # doctest: +SKIP
    >>> df['in-debt'] = df.balance < 0  # doctest: +SKIP
    >>> df = df.persist()  # triggers computation  # doctest: +SKIP

    >>> df.value().min()  # future computations are now fast  # doctest: +SKIP
    -10
    >>> df.value().max()  # doctest: +SKIP
    100

    >>> from dask import persist  # use persist function on multiple collections
    >>> a, b = persist(a, b)  # doctest: +SKIP

    Parameters
    ----------
    *args: Dask collections
    scheduler : string, optional
        Which scheduler to use like "threads", "synchronous" or "processes".
        If not provided, the default is to check the global settings first,
        and then fall back to the collection defaults.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``persist``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimize_graph : bool, optional
        If True [default], the graph is optimized before computation.
        Otherwise the graph is run as is. This can be useful for debugging.
    **kwargs
        Extra keywords to forward to the scheduler function.

    Returns
    -------
    New dask collections backed by in-memory data
    """
    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    schedule = get_scheduler(scheduler=scheduler, collections=collections)

    # Protocol: scheduler can provide its own persist method for async behavior.
    # For Client-like objects, get_scheduler returns client.get (a bound method),
    # so we check __self__ for the actual client instance.
    client = getattr(schedule, "__self__", schedule)
    if hasattr(client, "persist") and callable(client.persist):
        results = client.persist(collections, optimize_graph=optimize_graph, **kwargs)
        return repack(results)

    expr = collections_to_expr(collections, optimize_graph)
    expr = expr.optimize()
    keys, postpersists = [], []
    for a, akeys in zip(collections, expr.__dask_keys__(), strict=True):
        a_keys = list(flatten(akeys))
        rebuild, state = a.__dask_postpersist__()
        keys.extend(a_keys)
        postpersists.append((rebuild, a_keys, state))

    with shorten_traceback():
        results = schedule(expr, keys, **kwargs)

    d = dict(zip(keys, results))
    results2 = [r({k: d[k] for k in ks}, *s) for r, ks, s in postpersists]
    return repack(results2)


def _colorize(t):
    """Convert (r, g, b) triple to "#RRGGBB" string

    For use with ``visualize(color=...)``

    Examples
    --------
    >>> _colorize((255, 255, 255))
    '#FFFFFF'
    >>> _colorize((0, 32, 128))
    '#002080'
    """
    t = t[:3]
    i = sum(v << 8 * i for i, v in enumerate(reversed(t)))
    return f"#{i:>06X}"


named_schedulers: dict[str, SchedulerGetCallable] = {
    "sync": local.get_sync,
    "synchronous": local.get_sync,
    "single-threaded": local.get_sync,
}

if not EMSCRIPTEN:
    from dask import threaded

    named_schedulers.update(
        {
            "threads": threaded.get,
            "threading": threaded.get,
        }
    )

    from dask import multiprocessing as dask_multiprocessing

    named_schedulers.update(
        {
            "processes": dask_multiprocessing.get,
            "multiprocessing": dask_multiprocessing.get,
        }
    )


get_err_msg = """
The get= keyword has been removed.

Please use the scheduler= keyword instead with the name of
the desired scheduler like 'threads' or 'processes'

    x.compute(scheduler='single-threaded')
    x.compute(scheduler='threads')
    x.compute(scheduler='processes')

or with a function that takes the graph and keys

    x.compute(scheduler=my_scheduler_function)

or with a Dask client

    x.compute(scheduler=client)
""".strip()


def _ensure_not_async(client):
    if client.asynchronous:
        if fallback := config.get("admin.async-client-fallback", None):
            warnings.warn(
                "Distributed Client detected but Client instance is "
                f"asynchronous. Falling back to `{fallback}` scheduler. "
                "To use an asynchronous Client, please use "
                "``Client.compute`` and ``Client.gather`` "
                "instead of the top level ``dask.compute``",
                UserWarning,
            )
            return get_scheduler(scheduler=fallback)
        else:
            raise RuntimeError(
                "Attempting to use an asynchronous "
                "Client in a synchronous context of `dask.compute`"
            )
    return client.get


def get_scheduler(get=None, scheduler=None, collections=None, cls=None):
    """Get scheduler function

    There are various ways to specify the scheduler to use:

    1. Passing in scheduler= parameters
    2. Passing these into global configuration
    3. Using a dask.distributed default Client
    4. Using defaults of a dask collection

    This function centralizes the logic to determine the right scheduler to use
    from those many options
    """
    if get:
        raise TypeError(get_err_msg)

    if scheduler is not None:
        if callable(scheduler):
            return scheduler
        elif "Client" in type(scheduler).__name__ and hasattr(scheduler, "get"):
            return _ensure_not_async(scheduler)
        elif isinstance(scheduler, str):
            scheduler = scheduler.lower()

            client_available = False
            if _distributed_available():
                assert _DistributedClient is not None
                with suppress(ValueError):
                    _DistributedClient.current(allow_global=True)
                    client_available = True
            if scheduler in named_schedulers:
                return named_schedulers[scheduler]
            elif scheduler in ("dask.distributed", "distributed"):
                if not client_available:
                    raise RuntimeError(
                        f"Requested {scheduler} scheduler but no Client active."
                    )
                assert _get_distributed_client is not None
                client = _get_distributed_client()
                return _ensure_not_async(client)
            else:
                raise ValueError(
                    "Expected one of [distributed, {}]".format(
                        ", ".join(sorted(named_schedulers))
                    )
                )
        elif isinstance(scheduler, Executor):
            # Get `num_workers` from `Executor`'s `_max_workers` attribute.
            # If undefined, fallback to `config` or worst case CPU_COUNT.
            num_workers = getattr(scheduler, "_max_workers", None)
            if num_workers is None:
                num_workers = config.get("num_workers", CPU_COUNT)
            assert isinstance(num_workers, Integral) and num_workers > 0
            return partial(local.get_async, scheduler.submit, num_workers)
        else:
            raise ValueError(f"Unexpected scheduler: {scheduler!r}")
        # else:  # try to connect to remote scheduler with this name
        #     return get_client(scheduler).get

    if config.get("scheduler", None):
        return get_scheduler(scheduler=config.get("scheduler", None))

    if config.get("get", None):
        raise ValueError(get_err_msg)

    try:
        from distributed import get_client

        return _ensure_not_async(get_client())
    except (ImportError, ValueError):
        pass

    if cls is not None:
        return cls.__dask_scheduler__

    if collections:
        collections = [c for c in collections if c is not None]
    if collections:
        get = collections[0].__dask_scheduler__
        if not all(c.__dask_scheduler__ == get for c in collections):
            raise ValueError(
                "Compute called on multiple collections with "
                "differing default schedulers. Please specify a "
                "`scheduler=` parameter explicitly in compute or "
                "globally with `dask.config.set`."
            )
        return get

    return None
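

# Illustrative sketch (added comment, not part of the upstream module): named
# schedulers resolve to the matching get functions from ``named_schedulers``,
# while callables are passed through unchanged.
#
#     >>> get_scheduler(scheduler="sync") is local.get_sync
#     True
#     >>> get_scheduler(scheduler=simple_get) is simple_get
#     True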


def wait(x, timeout=None, return_when="ALL_COMPLETED"):
    """Wait until computation has finished

    This is a compatibility alias for ``dask.distributed.wait``.
    If it is applied to Dask collections without Dask Futures, or if Dask
    distributed is not installed, then it is a no-op.
    """
    try:
        from distributed import wait

        return wait(x, timeout=timeout, return_when=return_when)
    except (ImportError, ValueError):
        return x


def get_collection_names(collection) -> set[str]:
    """Infer the collection names from the dask keys, under the assumption that
    all keys are either tuples whose first element is a string name, or plain
    string keys.

    Examples
    --------
    >>> a.__dask_keys__()  # doctest: +SKIP
    ["foo", "bar"]
    >>> get_collection_names(a)  # doctest: +SKIP
    {"foo", "bar"}
    >>> b.__dask_keys__()  # doctest: +SKIP
    [[("foo-123", 0, 0), ("foo-123", 0, 1)], [("foo-123", 1, 0), ("foo-123", 1, 1)]]
    >>> get_collection_names(b)  # doctest: +SKIP
    {"foo-123"}
    """
    if not is_dask_collection(collection):
        raise TypeError(f"Expected Dask collection; got {type(collection)}")
    return {get_name_from_key(k) for k in flatten(collection.__dask_keys__())}


def get_name_from_key(key: Key) -> str:
    """Given a dask collection's key, extract the collection name.

    Parameters
    ----------
    key : string or tuple
        Dask collection's key, which must be either a single string or a tuple
        whose first element is a string (commonly referred to as a collection's
        'name').

    Examples
    --------
    >>> get_name_from_key("foo")
    'foo'
    >>> get_name_from_key(("foo-123", 1, 2))
    'foo-123'
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return key[0]
    if isinstance(key, str):
        return key
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")


KeyOrStrT = TypeVar("KeyOrStrT", Key, str)


def replace_name_in_key(key: KeyOrStrT, rename: Mapping[str, str]) -> KeyOrStrT:
    """Given a dask collection's key, replace the collection name with a new one.

    Parameters
    ----------
    key : string or tuple
        Dask collection's key, which must be either a single string or a tuple
        whose first element is a string (commonly referred to as a collection's
        'name').
    rename :
        Mapping of zero or more names, from old to new. Extraneous names will
        be ignored. Names not found in this mapping won't be replaced.

    Examples
    --------
    >>> replace_name_in_key("foo", {})
    'foo'
    >>> replace_name_in_key("foo", {"foo": "bar"})
    'bar'
    >>> replace_name_in_key(("foo-123", 1, 2), {"foo-123": "bar-456"})
    ('bar-456', 1, 2)
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return (rename.get(key[0], key[0]),) + key[1:]
    if isinstance(key, str):
        return rename.get(key, key)
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")


def clone_key(key: KeyOrStrT, seed: Hashable) -> KeyOrStrT:
    """Clone a key from a Dask collection, producing a new key with the same prefix and
    indices and a token which is a deterministic function of the previous key and seed.

    Examples
    --------
    >>> clone_key("x", 123)  # doctest: +SKIP
    'x-c4fb64ccca807af85082413d7ef01721'
    >>> clone_key("inc-cbb1eca3bafafbb3e8b2419c4eebb387", 123)  # doctest: +SKIP
    'inc-bc629c23014a4472e18b575fdaf29ee7'
    >>> clone_key(("sum-cbb1eca3bafafbb3e8b2419c4eebb387", 4, 3), 123)  # doctest: +SKIP
    ('sum-c053f3774e09bd0f7de6044dbc40e71d', 4, 3)
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return (clone_key(key[0], seed),) + key[1:]
    if isinstance(key, str):
        prefix = key_split(key)
        return prefix + "-" + tokenize(key, seed)
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")