Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/_task_spec.py: 25%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3""" Task specification for dask
5This module contains the task specification for dask. It is used to represent
6runnable (task) and non-runnable (data) nodes in a dask graph.
8Simple examples of how to express tasks in dask
9-----------------------------------------------
11.. code-block:: python
13 func("a", "b") ~ Task("key", func, "a", "b")
15 [func("a"), func("b")] ~ [Task("key-1", func, "a"), Task("key-2", func, "b")]
17 {"a": func("b")} ~ {"a": Task("a", func, "b")}
19 "literal-string" ~ DataNode("key", "literal-string")
22Keys, Aliases and TaskRefs
23-------------------------
25Keys are used to identify tasks in a dask graph. Every `GraphNode` instance has a
26key attribute that _should_ reference the key in the dask graph.
28.. code-block:: python
30 {"key": Task("key", func, "a")}
32Referencing other tasks is possible by using either one of `Alias` or a
33`TaskRef`.
35.. code-block:: python
37 # TaskRef can be used to provide the name of the reference explicitly
38 t = Task("key", func, TaskRef("key-1"))
40 # If a task is still in scope, the method `ref` can be used for convenience
41 t2 = Task("key2", func2, t.ref())
44Executing a task
45----------------
47A task can be executed by calling it with a dictionary of values. The values
48should contain the dependencies of the task.
50.. code-block:: python
52 t = Task("key", add, TaskRef("a"), TaskRef("b"))
53 assert t.dependencies == {"a", "b"}
54 t({"a": 1, "b": 2}) == 3
56"""
57import functools
58import itertools
59import sys
60from collections import defaultdict
61from collections.abc import Callable, Container, Iterable, Mapping, MutableMapping
62from functools import lru_cache, partial
63from typing import Any, TypeVar, cast
65from dask.sizeof import sizeof
66from dask.typing import Key as KeyType
67from dask.utils import funcname, is_namedtuple_instance
69_T = TypeVar("_T")
72# Ported from more-itertools
73# https://github.com/more-itertools/more-itertools/blob/c8153e2801ade2527f3a6c8b623afae93f5a1ce1/more_itertools/recipes.py#L944-L973
74def _batched(iterable, n, *, strict=False):
75 """Batch data into tuples of length *n*. If the number of items in
76 *iterable* is not divisible by *n*:
77 * The last batch will be shorter if *strict* is ``False``.
78 * :exc:`ValueError` will be raised if *strict* is ``True``.
80 >>> list(batched('ABCDEFG', 3))
81 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
83 On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
84 """
85 if n < 1:
86 raise ValueError("n must be at least one")
87 it = iter(iterable)
88 while batch := tuple(itertools.islice(it, n)):
89 if strict and len(batch) != n:
90 raise ValueError("batched(): incomplete batch")
91 yield batch
# 0x30D00A2 corresponds to CPython 3.13.0a2, the first release in which
# itertools.batched supports the ``strict`` keyword argument.
if sys.hexversion >= 0x30D00A2:

    def batched(iterable, n, *, strict=False):
        # Delegate to the C implementation; see _batched for documentation.
        return itertools.batched(iterable, n, strict=strict)

else:
    # Older interpreters use the pure-Python port above.
    batched = _batched

    batched.__doc__ = _batched.__doc__
# End port
def identity(*args):
    """Return all positional arguments unchanged, packed into a tuple."""
    return args
110def _identity_cast(*args, typ):
111 return typ(args)
# Monotonic counter used to build unique keys for anonymous (key-less)
# DataNode instances; see DataNode.__init__.
_anom_count = itertools.count()
def parse_input(obj: Any) -> object:
    """Tokenize user input into GraphNode objects.

    Nested ``dict``/``list``/``set``/``tuple`` inputs are wrapped in their
    GraphNode container counterparts (``Dict``, ``List``, ``Set``,
    ``Tuple``) only if at least one of their parsed elements is itself a
    ``GraphNode``; otherwise the original object is returned unchanged.

    Note: This is similar to `convert_legacy_task` but does not
    - compare any values to a global set of known keys to infer references/futures
    - parse tuples and interprets them as runnable tasks
    - Deal with SubgraphCallables

    Parameters
    ----------
    obj : Any
        Arbitrary user input: a GraphNode, a TaskRef, a literal, or a
        (possibly nested) collection of those.

    Returns
    -------
    object
        Either a ``GraphNode`` wrapping *obj* (or its converted elements),
        or *obj* itself if no conversion was necessary.
    """
    if isinstance(obj, GraphNode):
        return obj

    if isinstance(obj, TaskRef):
        # Normalize references to Alias nodes.
        return Alias(obj.key)

    if isinstance(obj, dict):
        parsed_dict = {k: parse_input(v) for k, v in obj.items()}
        if any(isinstance(v, GraphNode) for v in parsed_dict.values()):
            return Dict(parsed_dict)

    if isinstance(obj, (list, set, tuple)):
        parsed_collection = tuple(parse_input(o) for o in obj)
        if any(isinstance(o, GraphNode) for o in parsed_collection):
            if isinstance(obj, list):
                return List(*parsed_collection)
            if isinstance(obj, set):
                return Set(*parsed_collection)
            if isinstance(obj, tuple):
                if is_namedtuple_instance(obj):
                    return _wrap_namedtuple_task(None, obj, parse_input)
                return Tuple(*parsed_collection)

    # No GraphNode found anywhere inside: return the input untouched.
    return obj
def _wrap_namedtuple_task(k, obj, parser):
    """Wrap a namedtuple-like instance ``obj`` in a Task keyed by ``k``.

    The constructor arguments (obtained via the pickle protocol's
    ``__getnewargs_ex__``/``__getnewargs__``) are converted with
    ``parser`` and the instance is rebuilt at execution time by
    ``_instantiate_named_tuple``.
    """
    if hasattr(obj, "__getnewargs_ex__"):
        new_args, kwargs = obj.__getnewargs_ex__()
        kwargs = {k: parser(v) for k, v in kwargs.items()}
    elif hasattr(obj, "__getnewargs__"):
        new_args = obj.__getnewargs__()
        kwargs = {}

    args_converted = parse_input(type(new_args)(map(parser, new_args)))

    return Task(
        k, partial(_instantiate_named_tuple, type(obj)), args_converted, Dict(kwargs)
    )
176def _instantiate_named_tuple(typ, args, kwargs):
177 return typ(*args, **kwargs)
180class _MultiContainer(Container):
181 container: tuple
182 __slots__ = ("container",)
184 def __init__(self, *container):
185 self.container = container
187 def __contains__(self, o: object) -> bool:
188 return any(o in c for c in self.container)
# Module-level placeholder; appears unreferenced in this module —
# presumably kept for backwards compatibility (TODO confirm with callers).
SubgraphType = None
def _execute_subgraph(inner_dsk, outkey, inkeys, *dependencies):
    """Execute a fused subgraph and return the value of ``outkey``.

    ``inkeys`` are the external dependency keys of the subgraph and
    ``dependencies`` are their already-computed values (zipped pairwise
    with ``inkeys``).
    """
    final = {}
    final.update(inner_dsk)
    # Inject the externally computed values as literal data nodes so the
    # inner tasks can resolve their references against them.
    for k, v in zip(inkeys, dependencies):
        final[k] = DataNode(None, v)
    res = execute_graph(final, keys=[outkey])
    return res[outkey]
def convert_legacy_task(
    key: KeyType | None,
    task: _T,
    all_keys: Container,
) -> GraphNode | _T:
    """Convert a single legacy (tuple based) task to a GraphNode.

    Parameters
    ----------
    key:
        The key of the task in the graph, or None for nested/anonymous
        values.
    task:
        The legacy task: a ``(callable, *args)`` tuple, a literal, a key
        reference, or a (possibly nested) collection of those.
    all_keys:
        All keys known in the graph. Hashable values found in this
        container are interpreted as references and become ``Alias``es.
    """
    if isinstance(task, GraphNode):
        return task

    # Legacy runnable task: a tuple whose first element is a callable.
    if type(task) is tuple and task and callable(task[0]):
        func, args = task[0], task[1:]
        new_args = []
        new: object
        for a in args:
            if isinstance(a, dict):
                new = Dict(a)
            else:
                new = convert_legacy_task(None, a, all_keys)
            new_args.append(new)
        return Task(key, func, *new_args)

    try:
        if isinstance(task, (int, float, str, tuple)):
            # The value matches a key in the graph -> it is a reference.
            if task in all_keys:
                if key is None:
                    return Alias(task)
                else:
                    return Alias(key, target=task)
    except TypeError:
        # Unhashable
        pass

    if isinstance(task, (list, tuple, set, frozenset)):
        if is_namedtuple_instance(task):
            return _wrap_namedtuple_task(
                key,
                task,
                partial(
                    convert_legacy_task,
                    None,
                    all_keys=all_keys,
                ),
            )
        else:
            parsed_args = tuple(convert_legacy_task(None, t, all_keys) for t in task)
            if any(isinstance(a, GraphNode) for a in parsed_args):
                # Rebuild the container at runtime once the nested nodes
                # have been resolved.
                return Task(key, _identity_cast, *parsed_args, typ=type(task))
            else:
                return cast(_T, type(task)(parsed_args))
    elif isinstance(task, TaskRef):
        if key is None:
            return Alias(task.key)
        else:
            return Alias(key, target=task.key)
    else:
        return task
def convert_legacy_graph(
    dsk: Mapping,
    all_keys: Container | None = None,
):
    """Convert a legacy (tuple based) dask graph into GraphNode objects.

    Trivial self-aliases (``key -> key``) are dropped, and any value
    that does not convert to a ``GraphNode`` is wrapped in a
    ``DataNode``.

    Parameters
    ----------
    dsk:
        The legacy graph to convert.
    all_keys:
        Keys used for reference detection; defaults to the keys of
        ``dsk``.
    """
    if all_keys is None:
        all_keys = set(dsk)
    converted = {}
    for graph_key, legacy_value in dsk.items():
        node = convert_legacy_task(graph_key, legacy_value, all_keys)
        if isinstance(node, Alias) and node.target == graph_key:
            # A self-referencing alias carries no information; drop it.
            continue
        if not isinstance(node, GraphNode):
            node = DataNode(graph_key, node)
        converted[graph_key] = node
    return converted
def resolve_aliases(dsk: dict, keys: set, dependents: dict) -> dict:
    """Remove trivial sequential alias chains

    Example:

        dsk = {'x': 1, 'y': Alias('x'), 'z': Alias('y')}

        resolve_aliases(dsk, {'z'}, {'x': {'y'}, 'y': {'z'}}) == {'z': 1}

    Parameters
    ----------
    dsk:
        The graph to optimize. Not mutated; a shallow copy is taken.
    keys:
        The keys requested by the user; these must survive unchanged.
    dependents:
        Reverse dependency mapping (key -> set of keys depending on it).
    """
    if not keys:
        raise ValueError("No keys provided")
    dsk = dict(dsk)
    work = list(keys)
    seen = set()
    while work:
        k = work.pop()
        if k in seen or k not in dsk:
            continue
        seen.add(k)
        t = dsk[k]
        if isinstance(t, Alias):
            target_key = t.target
            # Rules for when we allow to collapse an alias
            # 1. The target key is not in the keys set. The keys set is what the
            #    user is requesting and by collapsing we'd no longer be able to
            #    return that result.
            # 2. The target key is in fact part of dsk. If it isn't this could
            #    point to a persisted dependency and we cannot collapse it.
            # 3. The target key has only one dependent which is the key we're
            #    currently looking at. This means that there is a one to one
            #    relation between this and the target key in which case we can
            #    collapse them.
            # Note: If target was an alias as well, we could continue with
            # more advanced optimizations but this isn't implemented, yet
            if (
                target_key not in keys
                and target_key in dsk
                # Note: whenever we're performing a collapse, we're not updating
                # the dependents. The length == 1 should still be sufficient for
                # chains of these aliases
                and len(dependents[target_key]) == 1
            ):
                # Collapse: the target node takes over key ``k``.
                tnew = dsk.pop(target_key).copy()

                dsk[k] = tnew
                tnew.key = k
                if isinstance(tnew, Alias):
                    # Still an alias: revisit ``k`` to collapse the chain further.
                    work.append(k)
                    seen.discard(k)
                else:
                    work.extend(tnew.dependencies)

        work.extend(t.dependencies)
    return dsk
class TaskRef:
    """A lightweight reference to a task by key.

    Used inside ``Task`` arguments to declare a dependency on another
    task; at execution time the reference is replaced by the computed
    value of ``key``.
    """

    # Fixed: the annotation previously read ``val`` although the only
    # slot (and attribute) is ``key``.
    key: KeyType
    __slots__ = ("key",)

    def __init__(self, key: KeyType):
        self.key = key

    def __str__(self):
        return str(self.key)

    def __repr__(self):
        return f"{type(self).__name__}({self.key!r})"

    def __hash__(self) -> int:
        return hash(self.key)

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, TaskRef):
            return False
        return self.key == value.key

    def __reduce__(self):
        return TaskRef, (self.key,)

    def substitute(self, subs: dict, key: KeyType | None = None) -> TaskRef | GraphNode:
        """Replace this reference according to ``subs``.

        Returns ``self`` unchanged when ``self.key`` is not in ``subs``.
        A ``GraphNode`` replacement is inlined in place of the reference
        (keeping this reference's key); a plain key yields a new
        ``TaskRef``.
        """
        if self.key in subs:
            val = subs[self.key]
            if isinstance(val, GraphNode):
                return val.substitute({}, key=self.key)
            elif isinstance(val, TaskRef):
                return val
            else:
                return TaskRef(val)
        return self
class GraphNode:
    """Base class for all nodes in a dask task graph."""

    key: KeyType
    _dependencies: frozenset

    __slots__ = tuple(__annotations__)

    def ref(self):
        # Convenience helper to reference this node from another task.
        return Alias(self.key)

    def copy(self):
        raise NotImplementedError

    @property
    def data_producer(self) -> bool:
        # Overridden by nodes that carry literal data (see DataNode).
        return False

    @property
    def dependencies(self) -> frozenset:
        return self._dependencies

    @property
    def block_fusion(self) -> bool:
        # When True, fuse_linear_task_spec will not fuse across this node.
        return False

    def _verify_values(self, values: tuple | dict) -> None:
        # Raise if ``values`` does not cover all declared dependencies.
        if not self.dependencies:
            return
        if missing := set(self.dependencies) - set(values):
            raise RuntimeError(f"Not enough arguments provided: missing keys {missing}")

    def __call__(self, values) -> Any:
        raise NotImplementedError("Not implemented")

    def __eq__(self, value: object) -> bool:
        # Equality is content (token) based, not identity based.
        if type(value) is not type(self):
            return False

        from dask.tokenize import tokenize

        return tokenize(self) == tokenize(value)

    @property
    def is_coro(self) -> bool:
        return False

    def __sizeof__(self) -> int:
        # Size is the sum of all slot values plus the type overhead.
        all_slots = self.get_all_slots()
        return sum(sizeof(getattr(self, sl)) for sl in all_slots) + sys.getsizeof(
            type(self)
        )

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        """Substitute a dependency with a new value. The new value either has to
        be a new valid key or a GraphNode to replace the dependency entirely.

        The GraphNode will not be mutated but instead a shallow copy will be
        returned. The substitution will be performed eagerly.

        Parameters
        ----------
        subs : dict[KeyType, KeyType | GraphNode]
            The mapping describing the substitutions to be made.
        key : KeyType | None, optional
            The key of the new GraphNode object. If None provided, the key of
            the old one will be reused.
        """
        raise NotImplementedError

    @staticmethod
    def fuse(*tasks: GraphNode, key: KeyType | None = None) -> GraphNode:
        """Fuse a set of tasks into a single task.

        The tasks are fused into a single task that will execute the tasks in a
        subgraph. The internal tasks are no longer accessible from the outside.

        All provided tasks must form a valid subgraph that will reduce to a
        single key. If multiple outputs are possible with the provided tasks, an
        exception will be raised.

        The tasks will not be rewritten but instead a new Task will be created
        that will merely reference the old task objects. This way, Task objects
        may be reused in multiple fused tasks.

        Parameters
        ----------
        key : KeyType | None, optional
            The key of the new Task object. If None provided, the key of the
            final task will be used.

        See also
        --------
        GraphNode.substitute : Easier substitution of dependencies
        """
        if any(t.key is None for t in tasks):
            raise ValueError("Cannot fuse tasks with missing keys")
        if len(tasks) == 1:
            return tasks[0].substitute({}, key=key)
        all_keys = set()
        all_deps: set[KeyType] = set()
        for t in tasks:
            all_deps.update(t.dependencies)
            all_keys.add(t.key)
        # Dependencies coming from outside the fused group. Sorting (by
        # hash) makes the argument order deterministic.
        external_deps = tuple(sorted(all_deps - all_keys, key=hash))
        # Keys no other fused task depends on: candidate outputs.
        leafs = all_keys - all_deps
        if len(leafs) > 1:
            raise ValueError(f"Cannot fuse tasks with multiple outputs {leafs}")

        outkey = leafs.pop()
        return Task(
            key or outkey,
            _execute_subgraph,
            {t.key: t for t in tasks},
            outkey,
            external_deps,
            *(TaskRef(k) for k in external_deps),
            _data_producer=any(t.data_producer for t in tasks),
        )

    @classmethod
    @lru_cache
    def get_all_slots(cls):
        # Collect __slots__ across the whole MRO; used for pickling
        # (__getstate__/__setstate__) and __sizeof__.
        slots = list()
        for c in cls.mro():
            slots.extend(getattr(c, "__slots__", ()))
        # Interestingly, sorting this causes the nested containers to pickle
        # more efficiently
        return sorted(set(slots))
# Shared empty dependency set, reused so dependency-free nodes don't each
# allocate their own frozenset.
_no_deps: frozenset = frozenset()
class Alias(GraphNode):
    """A node that maps one key to another key (a reference)."""

    target: KeyType
    __slots__ = tuple(__annotations__)

    def __init__(
        self, key: KeyType | TaskRef, target: Alias | TaskRef | KeyType | None = None
    ):
        if isinstance(key, TaskRef):
            key = key.key
        self.key = key
        if target is None:
            # Default: the alias points at the key of the same name.
            target = key
        # Unwrap Alias/TaskRef targets down to a plain key.
        if isinstance(target, Alias):
            target = target.target
        if isinstance(target, TaskRef):
            target = target.key
        self.target = target
        self._dependencies = frozenset((self.target,))

    def __reduce__(self):
        return Alias, (self.key, self.target)

    def copy(self):
        return Alias(self.key, self.target)

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        """Substitute key and/or target per ``subs``; ``self`` if unchanged."""
        if self.key in subs or self.target in subs:
            sub_key = subs.get(self.key, self.key)
            val = subs.get(self.target, self.target)
            if sub_key == self.key and val == self.target:
                # Noop
                return self
            if isinstance(val, (GraphNode, TaskRef)):
                # Replacing the target with a node dissolves the alias.
                return val.substitute({}, key=key)
            if key is None and isinstance(sub_key, GraphNode):
                raise RuntimeError(
                    f"Invalid substitution encountered {self.key!r} -> {sub_key}"
                )
            return Alias(key or sub_key, val)  # type: ignore
        return self

    def __dask_tokenize__(self):
        return (type(self).__name__, self.key, self.target)

    def __call__(self, values=()):
        # Executing an alias simply looks up the target's value.
        self._verify_values(values)
        return values[self.target]

    def __repr__(self):
        if self.key != self.target:
            return f"Alias({self.key!r}->{self.target!r})"
        else:
            return f"Alias({self.key!r})"

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, Alias):
            return False
        if self.key != value.key:
            return False
        return self.target == value.target
class DataNode(GraphNode):
    """A graph node holding a literal value instead of a computation."""

    value: Any
    typ: type
    __slots__ = tuple(__annotations__)

    def __init__(self, key: Any, value: Any):
        if key is None:
            # Anonymous data: generate a unique key like ("int", 17).
            key = (type(value).__name__, next(_anom_count))
        self.key = key
        self.value = value
        self.typ = type(value)
        self._dependencies = _no_deps

    @property
    def data_producer(self) -> bool:
        return True

    def copy(self):
        return DataNode(self.key, self.value)

    def __call__(self, values=()):
        # Executing a data node simply returns the stored value.
        return self.value

    def __repr__(self):
        return f"DataNode({self.value!r})"

    def __reduce__(self):
        return (DataNode, (self.key, self.value))

    def __dask_tokenize__(self):
        from dask.base import tokenize

        return (type(self).__name__, tokenize(self.value))

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> DataNode:
        # Data nodes have no dependencies; only the key may be renamed.
        if key is not None and key != self.key:
            return DataNode(key, self.value)
        return self

    def __iter__(self):
        return iter(self.value)
def _get_dependencies(obj: object) -> set | frozenset:
    """Collect all task keys referenced by ``obj``, recursing into
    dicts, lists, tuples, sets and frozensets."""
    if isinstance(obj, TaskRef):
        return {obj.key}
    elif isinstance(obj, GraphNode):
        return obj.dependencies
    elif isinstance(obj, dict):
        if not obj:
            return _no_deps
        return set().union(*map(_get_dependencies, obj.values()))
    elif isinstance(obj, (list, tuple, frozenset, set)):
        if not obj:
            return _no_deps
        return set().union(*map(_get_dependencies, obj))
    # Anything else carries no dependencies.
    return _no_deps
class Task(GraphNode):
    """A runnable graph node: ``func`` applied to ``args`` and ``kwargs``.

    Arguments that are ``TaskRef`` or ``GraphNode`` instances declare
    dependencies, which are resolved from the provided values when the
    task is called.
    """

    func: Callable
    args: tuple
    kwargs: dict
    _data_producer: bool
    _token: str | None
    _is_coro: bool | None
    _repr: str | None

    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        key: Any,
        func: Callable,
        /,
        *args: Any,
        _data_producer: bool = False,
        **kwargs: Any,
    ):
        self.key = key
        self.func = func
        if isinstance(func, Task):
            raise TypeError("Cannot nest tasks")

        self.args = args
        self.kwargs = kwargs
        # Collect dependencies from args and kwargs. The set is only
        # allocated once the first dependency is found.
        _dependencies: set[KeyType] | None = None
        for a in itertools.chain(args, kwargs.values()):
            if isinstance(a, TaskRef):
                if _dependencies is None:
                    _dependencies = {a.key}
                else:
                    _dependencies.add(a.key)
            elif isinstance(a, GraphNode) and a.dependencies:
                if _dependencies is None:
                    _dependencies = set(a.dependencies)
                else:
                    _dependencies.update(a.dependencies)
        if _dependencies:
            self._dependencies = frozenset(_dependencies)
        else:
            self._dependencies = _no_deps
        # Lazily computed caches (see _get_token, is_coro, __repr__).
        self._is_coro = None
        self._token = None
        self._repr = None
        self._data_producer = _data_producer

    @property
    def data_producer(self) -> bool:
        return self._data_producer

    def has_subgraph(self) -> bool:
        # True for tasks produced by GraphNode.fuse.
        return self.func == _execute_subgraph

    def copy(self):
        return type(self)(
            self.key,
            self.func,
            *self.args,
            **self.kwargs,
        )

    def __hash__(self):
        return hash(self._get_token())

    def _get_token(self) -> str:
        # Token is computed lazily and cached. The key is not part of the
        # token; only func, args and kwargs are.
        if self._token:
            return self._token
        from dask.base import tokenize

        self._token = tokenize(
            (
                type(self).__name__,
                self.func,
                self.args,
                self.kwargs,
            )
        )
        return self._token

    def __dask_tokenize__(self):
        return self._get_token()

    def __repr__(self) -> str:
        # When `Task` is deserialized the constructor will not run and
        # `self._repr` is thus undefined.
        if not hasattr(self, "_repr") or not self._repr:
            head = funcname(self.func)
            tail = ")"
            label_size = 40
            args = self.args
            kwargs = self.kwargs
            if args or kwargs:
                # Character budget per argument; below the threshold the
                # arguments are elided with "...".
                label_size2 = int(
                    (label_size - len(head) - len(tail) - len(str(self.key)))
                    // (len(args) + len(kwargs))
                )
            if args:
                if label_size2 > 5:
                    args_repr = ", ".join(repr(t) for t in args)
                else:
                    args_repr = "..."
            else:
                args_repr = ""
            if kwargs:
                if label_size2 > 5:
                    kwargs_repr = ", " + ", ".join(
                        f"{k}={repr(v)}" for k, v in sorted(kwargs.items())
                    )
                else:
                    kwargs_repr = ", ..."
            else:
                kwargs_repr = ""
            self._repr = f"<Task {self.key!r} {head}({args_repr}{kwargs_repr}{tail}>"
        return self._repr

    def __call__(self, values=()):
        """Execute the task: resolve dependencies from ``values``, call ``func``."""
        self._verify_values(values)

        def _eval(a):
            if isinstance(a, GraphNode):
                # Nested nodes are executed with only their own dependencies.
                return a({k: values[k] for k in a.dependencies})
            elif isinstance(a, TaskRef):
                return values[a.key]
            else:
                return a

        new_argspec = tuple(map(_eval, self.args))
        if self.kwargs:
            kwargs = {k: _eval(kw) for k, kw in self.kwargs.items()}
            return self.func(*new_argspec, **kwargs)
        return self.func(*new_argspec)

    def __setstate__(self, state):
        # State is a tuple ordered like get_all_slots(); see __getstate__.
        slots = self.__class__.get_all_slots()
        for sl, val in zip(slots, state):
            setattr(self, sl, val)

    def __getstate__(self):
        slots = self.__class__.get_all_slots()
        return tuple(getattr(self, sl) for sl in slots)

    @property
    def is_coro(self):
        if self._is_coro is None:
            # Note: Can't use cached_property on objects without __dict__
            try:
                from distributed.utils import iscoroutinefunction

                self._is_coro = iscoroutinefunction(self.func)
            except Exception:
                # distributed unavailable (or detection failed): assume sync.
                self._is_coro = False
        return self._is_coro

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Task:
        """Return a copy with dependencies replaced per ``subs``.

        Returns ``self`` when nothing changes and no rename is requested.
        Extra constructor state of subclasses is carried over via
        ``_extra_args``.
        """
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        extras = _extra_args(type(self))  # type: ignore
        extra_kwargs = {
            name: getattr(self, name) for name in extras if name not in {"key", "func"}
        }
        if subs_filtered:
            new_args = tuple(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
            new_kwargs = {
                k: (
                    v.substitute(subs_filtered)
                    if isinstance(v, (GraphNode, TaskRef))
                    else v
                )
                for k, v in self.kwargs.items()
            }
            return type(self)(
                key or self.key,
                self.func,
                *new_args,
                **new_kwargs,  # type: ignore[arg-type]
                **extra_kwargs,
            )
        elif key is None or key == self.key:
            return self
        else:
            # Rename
            return type(self)(
                key,
                self.func,
                *self.args,
                **self.kwargs,
                **extra_kwargs,
            )
class NestedContainer(Task, Iterable):
    """Base class for tasks that materialize a builtin container.

    Subclasses define ``klass`` (the container type, used for type
    checks) and ``constructor`` (the callable that rebuilds the
    container from the flat argument tuple). The elements are stored as
    task ``args`` so that nested ``GraphNode``/``TaskRef`` elements are
    resolved before the container is constructed.
    """

    constructor: Callable
    klass: type
    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        /,
        *args: Any,
        **kwargs: Any,
    ):
        if len(args) == 1 and isinstance(args[0], self.klass):
            # Accept the container itself instead of its unpacked elements.
            args = args[0]  # type: ignore
        super().__init__(
            None,
            self.to_container,
            *args,
            constructor=self.constructor,
            **kwargs,
        )

    def __getstate__(self):
        state = super().__getstate__()
        state = list(state)
        slots = self.__class__.get_all_slots()
        ix = slots.index("kwargs")
        # The constructor as a kwarg is redundant since this is encoded in the
        # class itself. Serializing the builtin types is not trivial
        # This saves about 15% of overhead
        state[ix] = state[ix].copy()
        state[ix].pop("constructor", None)
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        # Restore the kwarg that __getstate__ stripped.
        self.kwargs["constructor"] = self.__class__.constructor
        return self

    def __repr__(self):
        return f"{type(self).__name__}({self.args})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> NestedContainer:
        """Return a copy with dependency keys replaced per ``subs``.

        Returns ``self`` unchanged when no substitution applies.
        """
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        if not subs_filtered:
            return self
        return type(self)(
            *(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
        )

    def __dask_tokenize__(self):
        from dask.tokenize import tokenize

        # Sorting makes the token independent of the iteration order of
        # the elements.
        # Fixed: removed an unreachable ``return super().__dask_tokenize__()``
        # that followed this return statement.
        return (
            type(self).__name__,
            self.klass,
            sorted(tokenize(a) for a in self.args),
        )

    @staticmethod
    def to_container(*args, constructor):
        return constructor(args)

    def __iter__(self):
        yield from self.args
class List(NestedContainer):
    """NestedContainer materializing a builtin ``list``."""

    constructor = klass = list
class Tuple(NestedContainer):
    """NestedContainer materializing a builtin ``tuple``."""

    constructor = klass = tuple
class Set(NestedContainer):
    """NestedContainer materializing a builtin ``set``."""

    constructor = klass = set
class Dict(NestedContainer, Mapping):
    """NestedContainer materializing a builtin ``dict``.

    The key/value pairs are stored flattened in ``self.args`` as
    ``(k0, v0, k1, v1, ...)``.
    """

    klass = dict

    def __init__(self, /, *args: Any, **kwargs: Any):
        if args:
            assert not kwargs
            if len(args) == 1:
                args = args[0]
                if isinstance(args, dict):  # type: ignore
                    # Flatten a mapping into alternating key/value items.
                    args = tuple(itertools.chain(*args.items()))  # type: ignore
                elif isinstance(args, (list, tuple)):
                    # A sequence of (key, value) pairs is also accepted.
                    if all(
                        len(el) == 2 if isinstance(el, (list, tuple)) else False
                        for el in args
                    ):
                        args = tuple(itertools.chain(*args))
                    else:
                        raise ValueError("Invalid argument provided")

            if len(args) % 2 != 0:
                raise ValueError("Invalid number of arguments provided")

        elif kwargs:
            assert not args
            args = tuple(itertools.chain(*kwargs.items()))

        super().__init__(*args)

    def __repr__(self):
        values = ", ".join(f"{k}: {v}" for k, v in batched(self.args, 2, strict=True))
        return f"Dict({values})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Dict:
        """Return a copy with dependency keys replaced per ``subs``."""
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        if not subs_filtered:
            return self

        new_args = []
        for arg in self.args:
            new_arg = (
                arg.substitute(subs_filtered)
                if isinstance(arg, (GraphNode, TaskRef))
                else arg
            )
            new_args.append(new_arg)
        # NOTE(review): the flat item list is handed back to __init__ as a
        # single argument; __init__ only accepts such a list when every
        # element is a 2-element list/tuple — confirm that substitute is
        # only exercised with keys/values of that shape.
        return type(self)(new_args)

    def __iter__(self):
        # Iterate over the keys: every other element of the flat args.
        yield from self.args[::2]

    def __len__(self):
        return len(self.args) // 2

    def __getitem__(self, key):
        # Linear scan over the flattened (key, value) pairs.
        for k, v in batched(self.args, 2, strict=True):
            if k == key:
                return v
        raise KeyError(key)

    @staticmethod
    def constructor(args):
        return dict(batched(args, 2, strict=True))
class DependenciesMapping(MutableMapping):
    """Lazily computed mapping from graph key to its dependency set.

    Dependency sets are cached on first access. ``__delitem__`` only
    marks a key as removed so it is filtered from dependency sets
    computed afterwards; ``__iter__`` and ``__len__`` still reflect the
    full underlying graph.
    """

    def __init__(self, dsk):
        self.dsk = dsk
        self._removed = set()
        # Set a copy of dsk to avoid dct resizing
        self._cache = dsk.copy()
        self._cache.clear()

    def __getitem__(self, key):
        if (val := self._cache.get(key)) is not None:
            return val
        else:
            v = self.dsk[key]
            try:
                deps = v.dependencies
            except AttributeError:
                # Legacy (non-GraphNode) task: fall back to dask.core.
                from dask.core import get_dependencies

                deps = get_dependencies(self.dsk, task=v)

            if self._removed:
                # deps is a frozenset but for good measure, let's not use -= since
                # that _may_ perform an inplace mutation
                deps = deps - self._removed
            self._cache[key] = deps
            return deps

    def __iter__(self):
        return iter(self.dsk)

    def __delitem__(self, key: Any) -> None:
        # Invalidate the whole cache: cached sets may contain the removed key.
        self._cache.clear()
        self._removed.add(key)

    def __setitem__(self, key: Any, value: Any) -> None:
        raise NotImplementedError

    def __len__(self) -> int:
        return len(self.dsk)
1029class _DevNullMapping(MutableMapping):
1030 def __getitem__(self, key):
1031 raise KeyError(key)
1033 def __setitem__(self, key, value):
1034 pass
1036 def __delitem__(self, key):
1037 pass
1039 def __len__(self):
1040 return 0
1042 def __iter__(self):
1043 return iter(())
def execute_graph(
    dsk: Iterable[GraphNode] | Mapping[KeyType, GraphNode],
    cache: MutableMapping[KeyType, object] | None = None,
    keys: Container[KeyType] | None = None,
) -> MutableMapping[KeyType, object]:
    """Execute a given graph.

    The graph is executed in topological order as defined by dask.order until
    all leaf nodes, i.e. nodes without any dependents, are reached. The returned
    dictionary contains the results of the leaf nodes.

    If keys are required that are not part of the graph, they can be provided in the `cache` argument.

    If `keys` is provided, the result will contain only values that are part of the `keys` set.

    """
    if isinstance(dsk, (list, tuple, set, frozenset)):
        dsk = {t.key: t for t in dsk}
    else:
        assert isinstance(dsk, dict)

    # Reference counts allow freeing intermediate results as soon as the
    # last consumer has run.
    refcount: defaultdict[KeyType, int] = defaultdict(int)
    for vals in DependenciesMapping(dsk).values():
        for val in vals:
            refcount[val] += 1

    cache = cache or {}
    from dask.order import order

    priorities = order(dsk)

    # Run nodes in dask.order priority order (a valid topological order).
    for key, node in sorted(dsk.items(), key=lambda it: priorities[it[0]]):
        cache[key] = node(cache)
        for dep in node.dependencies:
            refcount[dep] -= 1
            if refcount[dep] == 0 and keys and dep not in keys:
                # No remaining consumers and not requested: release it.
                del cache[dep]

    return cache
def fuse_linear_task_spec(dsk, keys):
    """Fuse linear chains of tasks in ``dsk`` into single subgraph tasks.

    keys are the keys from the graph that are requested by a computation. We
    can't fuse those together.

    A chain is walked in both directions (towards the leafs and towards
    the root) as long as every node on it has exactly one dependency and
    one dependent and does not set ``block_fusion``.
    """
    from dask.core import reverse_dict
    from dask.optimization import default_fused_keys_renamer

    keys = set(keys)
    dependencies = DependenciesMapping(dsk)
    dependents = reverse_dict(dependencies)

    seen = set()
    result = {}

    for key in dsk:
        if key in seen:
            continue

        seen.add(key)

        deps = dependencies[key]
        dependents_key = dependents[key]

        # A node with multiple dependencies AND multiple dependents can
        # never be part of a linear chain; keep it as-is.
        if len(deps) != 1 and len(dependents_key) != 1 or dsk[key].block_fusion:
            result[key] = dsk[key]
            continue

        linear_chain = [dsk[key]]
        top_key = key

        # Walk towards the leafs as long as the nodes have a single dependency
        # and a single dependent, we can't fuse two nodes of an intermediate node
        # is the source for 2 dependents
        while len(deps) == 1:
            (new_key,) = deps
            if new_key in seen:
                break
            seen.add(new_key)
            if new_key not in dsk:
                # This can happen if a future is in the graph, the dependency mapping
                # adds the key that is referenced by the future as a dependency
                # see test_futures_to_delayed_array
                break
            if (
                len(dependents[new_key]) != 1
                or dsk[new_key].block_fusion
                or new_key in keys
            ):
                result[new_key] = dsk[new_key]
                break
            # backwards comp for new names, temporary until is_rootish is removed
            linear_chain.insert(0, dsk[new_key])
            deps = dependencies[new_key]

        # Walk the tree towards the root as long as the nodes have a single dependent
        # and a single dependency, we can't fuse two nodes if node has multiple
        # dependencies
        while len(dependents_key) == 1 and top_key not in keys:
            new_key = dependents_key.pop()
            if new_key in seen:
                break
            seen.add(new_key)
            if len(dependencies[new_key]) != 1 or dsk[new_key].block_fusion:
                # Exit if the dependent has multiple dependencies, triangle
                result[new_key] = dsk[new_key]
                break
            linear_chain.append(dsk[new_key])
            top_key = new_key
            dependents_key = dependents[new_key]

        if len(linear_chain) == 1:
            # Nothing to fuse.
            result[top_key] = linear_chain[0]
        else:
            # Renaming the keys is necessary to preserve the rootish detection for now
            renamed_key = default_fused_keys_renamer([tsk.key for tsk in linear_chain])
            result[renamed_key] = Task.fuse(*linear_chain, key=renamed_key)
            if renamed_key != top_key:
                # Having the same prefixes can result in the same key, i.e. getitem-hash -> getitem-hash
                result[top_key] = Alias(top_key, target=renamed_key)
    return result
def cull(
    dsk: dict[KeyType, GraphNode], keys: Iterable[KeyType]
) -> dict[KeyType, GraphNode]:
    """Return the subgraph of ``dsk`` required to compute ``keys``.

    Walks dependencies starting from ``keys`` and keeps only the
    reachable nodes. When every key is requested, ``dsk`` is returned
    unchanged.

    Raises
    ------
    TypeError
        If ``keys`` is not a list, set or tuple.
    """
    if not isinstance(keys, (list, set, tuple)):
        raise TypeError(
            f"Expected list, set or tuple for keys, got {type(keys).__name__}"
        )
    if len(keys) == len(dsk):
        # Everything is requested; nothing can be pruned.
        return dsk
    pending = set(keys)
    visited: set[KeyType] = set()
    culled: dict[KeyType, GraphNode] = {}
    while pending:
        current = pending.pop()
        # Skip already-handled keys and references to external data.
        if current in visited or current not in dsk:
            continue
        visited.add(current)
        node = dsk[current]
        culled[current] = node
        pending.update(node.dependencies)
    return culled
1195@functools.cache
1196def _extra_args(typ: type) -> set[str]:
1197 import inspect
1199 sig = inspect.signature(typ)
1200 extras = set()
1201 for name, param in sig.parameters.items():
1202 if param.kind in (
1203 inspect.Parameter.VAR_POSITIONAL,
1204 inspect.Parameter.VAR_KEYWORD,
1205 ):
1206 continue
1207 if name in typ.get_all_slots(): # type: ignore
1208 extras.add(name)
1209 return extras