Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/_task_spec.py: 25%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3""" Task specification for dask
5This module contains the task specification for dask. It is used to represent
6runnable (task) and non-runnable (data) nodes in a dask graph.
8Simple examples of how to express tasks in dask
9-----------------------------------------------
11.. code-block:: python
13 func("a", "b") ~ Task("key", func, "a", "b")
15 [func("a"), func("b")] ~ [Task("key-1", func, "a"), Task("key-2", func, "b")]
17 {"a": func("b")} ~ {"a": Task("a", func, "b")}
19 "literal-string" ~ DataNode("key", "literal-string")
22Keys, Aliases and TaskRefs
23-------------------------
25Keys are used to identify tasks in a dask graph. Every `GraphNode` instance has a
26key attribute that _should_ reference the key in the dask graph.
28.. code-block:: python
30 {"key": Task("key", func, "a")}
32Referencing other tasks is possible by using either one of `Alias` or a
33`TaskRef`.
35.. code-block:: python
37 # TaskRef can be used to provide the name of the reference explicitly
38 t = Task("key", func, TaskRef("key-1"))
40 # If a task is still in scope, the method `ref` can be used for convenience
41 t2 = Task("key2", func2, t.ref())
44Executing a task
45----------------
47A task can be executed by calling it with a dictionary of values. The values
48should contain the dependencies of the task.
50.. code-block:: python
52 t = Task("key", add, TaskRef("a"), TaskRef("b"))
53 assert t.dependencies == {"a", "b"}
54 t({"a": 1, "b": 2}) == 3
56"""
57import functools
58import itertools
59import sys
60from collections import defaultdict
61from collections.abc import Callable, Container, Iterable, Mapping, MutableMapping
62from functools import lru_cache, partial
63from typing import Any, TypeVar, cast
65from dask.sizeof import sizeof
66from dask.typing import Key as KeyType
67from dask.utils import funcname, is_namedtuple_instance
69_T = TypeVar("_T")
72# Ported from more-itertools
73# https://github.com/more-itertools/more-itertools/blob/c8153e2801ade2527f3a6c8b623afae93f5a1ce1/more_itertools/recipes.py#L944-L973
74def _batched(iterable, n, *, strict=False):
75 """Batch data into tuples of length *n*. If the number of items in
76 *iterable* is not divisible by *n*:
77 * The last batch will be shorter if *strict* is ``False``.
78 * :exc:`ValueError` will be raised if *strict* is ``True``.
80 >>> list(batched('ABCDEFG', 3))
81 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
83 On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
84 """
85 if n < 1:
86 raise ValueError("n must be at least one")
87 it = iter(iterable)
88 while batch := tuple(itertools.islice(it, n)):
89 if strict and len(batch) != n:
90 raise ValueError("batched(): incomplete batch")
91 yield batch
# Select the implementation of ``batched``: from Python 3.13.0a2 onward the
# C-accelerated itertools.batched supports the ``strict`` keyword, so we
# delegate to it; otherwise use the pure-Python port above.
if sys.hexversion >= 0x30D00A2:

    def batched(iterable, n, *, strict=False):
        return itertools.batched(iterable, n, strict=strict)

else:
    batched = _batched

    batched.__doc__ = _batched.__doc__
# End port
def identity(*args):
    """Return the positional arguments unchanged, packed into a tuple."""
    return args
110def _identity_cast(*args, typ):
111 return typ(args)
114_anom_count = itertools.count()
def parse_input(obj: Any) -> object:
    """Tokenize user input into GraphNode objects

    Note: This is similar to `convert_legacy_task` but does not
    - compare any values to a global set of known keys to infer references/futures
    - parse tuples and interprets them as runnable tasks
    - Deal with SubgraphCallables

    Parameters
    ----------
    obj : Any
        Arbitrary user input. GraphNodes are passed through, futures become
        ``Alias`` nodes, and containers are parsed recursively.

    Returns
    -------
    object
        A GraphNode (or container node such as ``Dict``/``List``/``Set``/
        ``Tuple``) if ``obj`` contains graph references; otherwise ``obj``
        itself, unchanged.
    """
    if isinstance(obj, GraphNode):
        return obj
    if _is_dask_future(obj):
        # Futures are references to already-computed keys.
        return Alias(obj.key)
    if isinstance(obj, dict):
        parsed_dict = {k: parse_input(v) for k, v in obj.items()}
        if any(isinstance(v, GraphNode) for v in parsed_dict.values()):
            return Dict(parsed_dict)
    if isinstance(obj, (list, set, tuple)):
        parsed_collection = tuple(parse_input(o) for o in obj)
        if any(isinstance(o, GraphNode) for o in parsed_collection):
            if isinstance(obj, list):
                return List(*parsed_collection)
            if isinstance(obj, set):
                return Set(*parsed_collection)
            if isinstance(obj, tuple):
                # Namedtuples need special reconstruction via their own class.
                if is_namedtuple_instance(obj):
                    return _wrap_namedtuple_task(None, obj, parse_input)
                return Tuple(*parsed_collection)
    # Containers without any GraphNode content (and all other objects) are
    # returned as-is, i.e. the parsed copies above are discarded.
    return obj
def _wrap_namedtuple_task(k, obj, parser):
    """Build a Task with key ``k`` that reconstructs the namedtuple ``obj``,
    parsing its constructor arguments with ``parser``."""
    if hasattr(obj, "__getnewargs_ex__"):
        new_args, kwargs = obj.__getnewargs_ex__()
        # NOTE: the comprehension variable shadows the ``k`` parameter here;
        # ``k`` is not used again until the Task construction below, where the
        # original parameter value is needed — the shadowing is confined to
        # this expression.
        kwargs = {k: parser(v) for k, v in kwargs.items()}
    elif hasattr(obj, "__getnewargs__"):
        new_args = obj.__getnewargs__()
        kwargs = {}
    # NOTE(review): if ``obj`` defines neither __getnewargs_ex__ nor
    # __getnewargs__, ``new_args``/``kwargs`` are unbound and the next line
    # raises NameError — presumably namedtuple instances always provide one
    # of the two; confirm before relying on this for exotic types.

    args_converted = parse_input(type(new_args)(map(parser, new_args)))

    return Task(
        k, partial(_instantiate_named_tuple, type(obj)), args_converted, Dict(kwargs)
    )
176def _instantiate_named_tuple(typ, args, kwargs):
177 return typ(*args, **kwargs)
180class _MultiContainer(Container):
181 container: tuple
182 __slots__ = ("container",)
184 def __init__(self, *container):
185 self.container = container
187 def __contains__(self, o: object) -> bool:
188 return any(o in c for c in self.container)
191SubgraphType = None
def _execute_subgraph(inner_dsk, outkey, inkeys, *dependencies):
    """Run a fused subgraph and return the value computed for ``outkey``.

    ``dependencies`` holds the concrete values for the external keys listed
    in ``inkeys``; they are injected into the graph as literal DataNodes.
    """
    graph = dict(inner_dsk)
    for name, value in zip(inkeys, dependencies):
        graph[name] = DataNode(None, value)
    return execute_graph(graph, keys=[outkey])[outkey]
def convert_legacy_task(
    key: KeyType | None,
    task: _T,
    all_keys: Container,
) -> GraphNode | _T:
    """Convert a single legacy (tuple-based) task into a GraphNode.

    ``all_keys`` is the set of known graph keys; hashable values found in it
    are interpreted as references (Aliases) rather than literals. Values that
    contain no graph references are returned unchanged.
    """
    if isinstance(task, GraphNode):
        return task

    # Legacy runnable task: a tuple whose first element is callable.
    if type(task) is tuple and task and callable(task[0]):
        func, args = task[0], task[1:]
        new_args = []
        new: object
        for a in args:
            if isinstance(a, dict):
                new = Dict(a)
            else:
                new = convert_legacy_task(None, a, all_keys)
            new_args.append(new)
        return Task(key, func, *new_args)
    try:
        if isinstance(task, (int, float, str, tuple)):
            if task in all_keys:
                # The value is itself a key: emit an alias to it.
                if key is None:
                    return Alias(task)
                else:
                    return Alias(key, target=task)
    except TypeError:
        # Unhashable
        pass

    if isinstance(task, (list, tuple, set, frozenset)):
        if is_namedtuple_instance(task):
            return _wrap_namedtuple_task(
                key,
                task,
                partial(
                    convert_legacy_task,
                    None,
                    all_keys=all_keys,
                ),
            )
        else:
            parsed_args = tuple(convert_legacy_task(None, t, all_keys) for t in task)
            if any(isinstance(a, GraphNode) for a in parsed_args):
                # The container holds graph references: wrap it in a task that
                # rebuilds the original container type at execution time.
                return Task(key, _identity_cast, *parsed_args, typ=type(task))
            else:
                return cast(_T, type(task)(parsed_args))
    elif _is_dask_future(task):
        if key is None:
            return Alias(task.key)  # type: ignore[attr-defined]
        else:
            return Alias(key, target=task.key)  # type: ignore[attr-defined]
    else:
        # Plain literal: passed through untouched.
        return task
def convert_legacy_graph(
    dsk: Mapping,
    all_keys: Container | None = None,
):
    """Convert a legacy (tuple-based) dask graph into GraphNode form.

    Parameters
    ----------
    dsk :
        The legacy graph mapping keys to legacy task expressions.
    all_keys :
        Container of keys used to detect references; defaults to the keys of
        ``dsk`` itself.
    """
    if all_keys is None:
        all_keys = set(dsk)
    converted = {}
    for graph_key, legacy_value in dsk.items():
        node = convert_legacy_task(graph_key, legacy_value, all_keys)
        if isinstance(node, Alias) and node.target == graph_key:
            # A self-referencing alias carries no information; drop it.
            continue
        if not isinstance(node, GraphNode):
            # Plain literals are wrapped so every graph value is a GraphNode.
            node = DataNode(graph_key, node)
        converted[graph_key] = node
    return converted
def resolve_aliases(dsk: dict, keys: set, dependents: dict) -> dict:
    """Remove trivial sequential alias chains

    Example:

        dsk = {'x': 1, 'y': Alias('x'), 'z': Alias('y')}

        resolve_aliases(dsk, {'z'}, {'x': {'y'}, 'y': {'z'}}) == {'z': 1}

    Parameters
    ----------
    dsk :
        Graph to optimize (not mutated; a shallow copy is made).
    keys :
        Output keys requested by the user; these must be preserved.
    dependents :
        Mapping from each key to the set of keys depending on it.
    """
    if not keys:
        raise ValueError("No keys provided")
    dsk = dict(dsk)
    work = list(keys)
    seen = set()
    while work:
        k = work.pop()
        if k in seen or k not in dsk:
            continue
        seen.add(k)
        t = dsk[k]
        if isinstance(t, Alias):
            target_key = t.target
            # Rules for when we allow to collapse an alias
            # 1. The target key is not in the keys set. The keys set is what the
            #    user is requesting and by collapsing we'd no longer be able to
            #    return that result.
            # 2. The target key is in fact part of dsk. If it isn't this could
            #    point to a persisted dependency and we cannot collapse it.
            # 3. The target key has only one dependent which is the key we're
            #    currently looking at. This means that there is a one to one
            #    relation between this and the target key in which case we can
            #    collapse them.
            # Note: If target was an alias as well, we could continue with
            # more advanced optimizations but this isn't implemented, yet
            if (
                target_key not in keys
                and target_key in dsk
                # Note: whenever we're performing a collapse, we're not updating
                # the dependents. The length == 1 should still be sufficient for
                # chains of these aliases
                and len(dependents[target_key]) == 1
            ):
                # Collapse: the target node takes over this key; the alias
                # entry disappears from the graph.
                tnew = dsk.pop(target_key).copy()

                dsk[k] = tnew
                tnew.key = k
                if isinstance(tnew, Alias):
                    # The collapsed node is itself an alias; revisit this key
                    # so chains of aliases fully collapse.
                    work.append(k)
                    seen.discard(k)
                else:
                    work.extend(tnew.dependencies)
        work.extend(t.dependencies)
    return dsk
333class TaskRef:
334 val: KeyType
335 __slots__ = ("key",)
337 def __init__(self, key: KeyType):
338 self.key = key
340 def __str__(self):
341 return str(self.key)
343 def __repr__(self):
344 return f"{type(self).__name__}({self.key!r})"
346 def __hash__(self) -> int:
347 return hash(self.key)
349 def __eq__(self, value: object) -> bool:
350 if not isinstance(value, TaskRef):
351 return False
352 return self.key == value.key
354 def __reduce__(self):
355 return TaskRef, (self.key,)
357 def substitute(self, subs: dict, key: KeyType | None = None) -> TaskRef | GraphNode:
358 if self.key in subs:
359 val = subs[self.key]
360 if isinstance(val, GraphNode):
361 return val.substitute({}, key=self.key)
362 elif isinstance(val, TaskRef):
363 return val
364 else:
365 return TaskRef(val)
366 return self
def _is_dask_future(obj: object) -> bool:
    """Check if obj is a dask Future (TaskRef or duck-typed with __dask_future__).

    This supports both distributed.Future (which inherits from TaskRef) and
    third-party scheduler futures that set __dask_future__ = True.
    """
    if isinstance(obj, TaskRef):
        return True
    return getattr(obj, "__dask_future__", False)
class GraphNode:
    """Base class for all nodes in a dask task graph.

    Every node carries a ``key`` identifying it in the graph and a frozenset
    ``_dependencies`` of keys it needs before it can execute.
    """

    key: KeyType
    _dependencies: frozenset

    __slots__ = tuple(__annotations__)

    def ref(self):
        """Return an Alias pointing at this node's key."""
        return Alias(self.key)

    def copy(self):
        raise NotImplementedError

    @property
    def data_producer(self) -> bool:
        # Overridden by subclasses that materialize data (e.g. DataNode).
        return False

    @property
    def dependencies(self) -> frozenset:
        return self._dependencies

    @property
    def block_fusion(self) -> bool:
        # When True, linear fusion must not merge across this node.
        return False

    def _verify_values(self, values: tuple | dict) -> None:
        # Raise if ``values`` is missing any of this node's dependencies.
        if not self.dependencies:
            return
        if missing := set(self.dependencies) - set(values):
            raise RuntimeError(f"Not enough arguments provided: missing keys {missing}")

    def __call__(self, values) -> Any:
        raise NotImplementedError("Not implemented")

    def __eq__(self, value: object) -> bool:
        # Equality is defined via tokenization so structurally identical
        # nodes compare equal; only exact same types can be equal.
        if type(value) is not type(self):
            return False

        from dask.tokenize import tokenize

        return tokenize(self) == tokenize(value)

    @property
    def is_coro(self) -> bool:
        return False

    def __sizeof__(self) -> int:
        # Sum the sizes of all slot values plus the type object itself.
        all_slots = self.get_all_slots()
        return sum(sizeof(getattr(self, sl)) for sl in all_slots) + sys.getsizeof(
            type(self)
        )

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        """Substitute a dependency with a new value. The new value either has to
        be a new valid key or a GraphNode to replace the dependency entirely.

        The GraphNode will not be mutated but instead a shallow copy will be
        returned. The substitution will be performed eagerly.

        Parameters
        ----------
        subs : dict[KeyType, KeyType | GraphNode]
            The mapping describing the substitutions to be made.
        key : KeyType | None, optional
            The key of the new GraphNode object. If None provided, the key of
            the old one will be reused.
        """
        raise NotImplementedError

    @staticmethod
    def fuse(*tasks: GraphNode, key: KeyType | None = None) -> GraphNode:
        """Fuse a set of tasks into a single task.

        The tasks are fused into a single task that will execute the tasks in a
        subgraph. The internal tasks are no longer accessible from the outside.

        All provided tasks must form a valid subgraph that will reduce to a
        single key. If multiple outputs are possible with the provided tasks, an
        exception will be raised.

        The tasks will not be rewritten but instead a new Task will be created
        that will merely reference the old task objects. This way, Task objects
        may be reused in multiple fused tasks.

        Parameters
        ----------
        key : KeyType | None, optional
            The key of the new Task object. If None provided, the key of the
            final task will be used.

        See also
        --------
        GraphNode.substitute : Easier substitution of dependencies
        """
        if any(t.key is None for t in tasks):
            raise ValueError("Cannot fuse tasks with missing keys")
        if len(tasks) == 1:
            return tasks[0].substitute({}, key=key)
        all_keys = set()
        all_deps: set[KeyType] = set()
        for t in tasks:
            all_deps.update(t.dependencies)
            all_keys.add(t.key)
        # Dependencies coming from outside the fused group stay visible as
        # the new task's dependencies. Sorting (by hash) makes the order
        # deterministic.
        external_deps = tuple(sorted(all_deps - all_keys, key=hash))
        # Keys nobody inside the group depends on are the group's outputs.
        leafs = all_keys - all_deps
        if len(leafs) > 1:
            raise ValueError(f"Cannot fuse tasks with multiple outputs {leafs}")

        outkey = leafs.pop()
        return Task(
            key or outkey,
            _execute_subgraph,
            {t.key: t for t in tasks},
            outkey,
            external_deps,
            *(TaskRef(k) for k in external_deps),
            _data_producer=any(t.data_producer for t in tasks),
        )

    @classmethod
    @lru_cache
    def get_all_slots(cls):
        # Collect __slots__ across the full MRO (subclasses extend slots).
        slots = list()
        for c in cls.mro():
            slots.extend(getattr(c, "__slots__", ()))
        # Interestingly, sorting this causes the nested containers to pickle
        # more efficiently
        return sorted(set(slots))
509_no_deps: frozenset = frozenset()
class Alias(GraphNode):
    """A node that resolves to the value of another key (``target``).

    Aliases of aliases and TaskRef inputs are flattened in the constructor so
    that ``target`` is always a plain key.
    """

    target: KeyType
    __slots__ = tuple(__annotations__)

    def __init__(
        self, key: KeyType | TaskRef, target: Alias | TaskRef | KeyType | None = None
    ):
        if isinstance(key, TaskRef):
            key = key.key
        self.key = key
        if target is None:
            # Self-referencing alias by default.
            target = key
        if isinstance(target, Alias):
            target = target.target
        if isinstance(target, TaskRef):
            target = target.key
        self.target = target
        self._dependencies = frozenset((self.target,))

    def __reduce__(self):
        return Alias, (self.key, self.target)

    def copy(self):
        return Alias(self.key, self.target)

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        """Apply substitutions to both our key and our target (see GraphNode)."""
        if self.key in subs or self.target in subs:
            sub_key = subs.get(self.key, self.key)
            val = subs.get(self.target, self.target)
            if sub_key == self.key and val == self.target:
                # Noop
                return self
            if isinstance(val, (GraphNode, TaskRef)):
                return val.substitute({}, key=key)
            if key is None and isinstance(sub_key, GraphNode):
                raise RuntimeError(
                    f"Invalid substitution encountered {self.key!r} -> {sub_key}"
                )
            return Alias(key or sub_key, val)  # type: ignore [arg-type]
        return self

    def __dask_tokenize__(self):
        return (type(self).__name__, self.key, self.target)

    def __call__(self, values=()):
        self._verify_values(values)
        return values[self.target]

    def __repr__(self):
        if self.key != self.target:
            return f"Alias({self.key!r}->{self.target!r})"
        else:
            return f"Alias({self.key!r})"

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, Alias):
            return False
        if self.key != value.key:
            return False
        return self.target == value.target
class DataNode(GraphNode):
    """A graph node holding a literal value (no computation, no dependencies)."""

    value: Any
    typ: type
    __slots__ = tuple(__annotations__)

    def __init__(self, key: Any, value: Any):
        if key is None:
            # Anonymous data nodes get a unique, type-prefixed key.
            key = (type(value).__name__, next(_anom_count))
        self.key = key
        self.value = value
        self.typ = type(value)
        self._dependencies = _no_deps

    @property
    def data_producer(self) -> bool:
        return True

    def copy(self):
        return DataNode(self.key, self.value)

    def __call__(self, values=()):
        return self.value

    def __repr__(self):
        return f"DataNode({self.value!r})"

    def __reduce__(self):
        return (DataNode, (self.key, self.value))

    def __dask_tokenize__(self):
        from dask.base import tokenize

        return (type(self).__name__, tokenize(self.value))

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> DataNode:
        # Nothing to substitute in a literal; only re-keying is possible.
        if key is not None and key != self.key:
            return DataNode(key, self.value)
        return self

    def __iter__(self):
        return iter(self.value)
def _get_dependencies(obj: object) -> set | frozenset:
    """Recursively collect the set of task keys that ``obj`` references.

    Futures contribute their single key; GraphNodes contribute their declared
    dependencies; containers are walked recursively. Anything else has no
    dependencies.
    """
    if _is_dask_future(obj):
        # Fixed: previously returned the bare ``obj.key``. Callers union the
        # results (``set().union(...)``), so a bare str/tuple key would be
        # iterated element-wise instead of treated as one key.
        return {obj.key}  # type: ignore[attr-defined]
    elif isinstance(obj, GraphNode):
        return obj.dependencies
    elif isinstance(obj, dict):
        if not obj:
            return _no_deps
        return set().union(*map(_get_dependencies, obj.values()))
    elif isinstance(obj, (list, tuple, frozenset, set)):
        if not obj:
            return _no_deps
        return set().union(*map(_get_dependencies, obj))
    return _no_deps
class Task(GraphNode):
    """A runnable graph node: ``func(*args, **kwargs)``.

    Dependencies are collected eagerly at construction time from any TaskRef
    or GraphNode found in ``args``/``kwargs``. Token, repr and coroutine-ness
    are computed lazily and cached in slots.
    """

    func: Callable
    args: tuple
    kwargs: dict
    _data_producer: bool
    _token: str | None
    _is_coro: bool | None
    _repr: str | None

    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        key: Any,
        func: Callable,
        /,
        *args: Any,
        _data_producer: bool = False,
        **kwargs: Any,
    ):
        self.key = key
        self.func = func
        if isinstance(func, Task):
            raise TypeError("Cannot nest tasks")

        self.args = args
        self.kwargs = kwargs
        # Collect dependencies from all arguments in a single pass; the
        # lazy-set dance avoids allocating a set for dependency-free tasks.
        _dependencies: set[KeyType] | None = None
        for a in itertools.chain(args, kwargs.values()):
            if isinstance(a, TaskRef):
                if _dependencies is None:
                    _dependencies = {a.key}
                else:
                    _dependencies.add(a.key)
            elif isinstance(a, GraphNode) and a.dependencies:
                if _dependencies is None:
                    _dependencies = set(a.dependencies)
                else:
                    _dependencies.update(a.dependencies)
        if _dependencies:
            self._dependencies = frozenset(_dependencies)
        else:
            self._dependencies = _no_deps
        # Lazily-computed caches.
        self._is_coro = None
        self._token = None
        self._repr = None
        self._data_producer = _data_producer

    @property
    def data_producer(self) -> bool:
        return self._data_producer

    def has_subgraph(self) -> bool:
        # True for tasks produced by GraphNode.fuse.
        return self.func == _execute_subgraph

    def copy(self):
        return type(self)(
            self.key,
            self.func,
            *self.args,
            **self.kwargs,
        )

    def __hash__(self):
        return hash(self._get_token())

    def _get_token(self) -> str:
        # Token is cached; note that the key is deliberately not part of it.
        if self._token:
            return self._token
        from dask.base import tokenize

        self._token = tokenize(
            (
                type(self).__name__,
                self.func,
                self.args,
                self.kwargs,
            )
        )
        return self._token

    def __dask_tokenize__(self):
        return self._get_token()

    def __repr__(self) -> str:
        # When `Task` is deserialized the constructor will not run and
        # `self._repr` is thus undefined.
        if not hasattr(self, "_repr") or not self._repr:
            head = funcname(self.func)
            tail = ")"
            label_size = 40
            args = self.args
            kwargs = self.kwargs
            if args or kwargs:
                # Rough per-argument budget for the label.
                label_size2 = int(
                    (label_size - len(head) - len(tail) - len(str(self.key)))
                    // (len(args) + len(kwargs))
                )
            if args:
                if label_size2 > 5:
                    args_repr = ", ".join(repr(t) for t in args)
                else:
                    args_repr = "..."
            else:
                args_repr = ""
            if kwargs:
                if label_size2 > 5:
                    kwargs_repr = ", " + ", ".join(
                        f"{k}={v!r}" for k, v in sorted(kwargs.items())
                    )
                else:
                    kwargs_repr = ", ..."
            else:
                kwargs_repr = ""
            self._repr = f"<Task {self.key!r} {head}({args_repr}{kwargs_repr}{tail}>"
        return self._repr

    def __call__(self, values=()):
        self._verify_values(values)

        def _eval(a):
            # Nested nodes are executed with just the values they need;
            # references are looked up; literals pass through.
            if isinstance(a, GraphNode):
                return a({k: values[k] for k in a.dependencies})
            elif isinstance(a, TaskRef):
                return values[a.key]
            else:
                return a

        new_argspec = tuple(map(_eval, self.args))
        if self.kwargs:
            kwargs = {k: _eval(kw) for k, kw in self.kwargs.items()}
            return self.func(*new_argspec, **kwargs)
        return self.func(*new_argspec)

    def __setstate__(self, state):
        # State is the tuple of slot values in get_all_slots() order.
        slots = self.__class__.get_all_slots()
        for sl, val in zip(slots, state):
            setattr(self, sl, val)

    def __getstate__(self):
        slots = self.__class__.get_all_slots()
        return tuple(getattr(self, sl) for sl in slots)

    @property
    def is_coro(self):
        if self._is_coro is None:
            # Note: Can't use cached_property on objects without __dict__
            try:
                from distributed.utils import iscoroutinefunction

                self._is_coro = iscoroutinefunction(self.func)
            except Exception:
                # distributed not installed (or probing failed): assume sync.
                self._is_coro = False
        return self._is_coro

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Task:
        """Return a copy with dependencies replaced per ``subs`` (see GraphNode)."""
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        # Extra constructor arguments that subclasses store in slots (e.g.
        # NestedContainer's ``constructor``) must be forwarded on copy.
        extras = _extra_args(type(self))  # type: ignore[arg-type]
        extra_kwargs = {
            name: getattr(self, name) for name in extras if name not in {"key", "func"}
        }
        if subs_filtered:
            new_args = tuple(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
            new_kwargs = {
                k: (
                    v.substitute(subs_filtered)
                    if isinstance(v, (GraphNode, TaskRef))
                    else v
                )
                for k, v in self.kwargs.items()
            }
            return type(self)(
                key or self.key,
                self.func,
                *new_args,
                **new_kwargs,  # type: ignore[arg-type]
                **extra_kwargs,
            )
        elif key is None or key == self.key:
            return self
        else:
            # Rename
            return type(self)(
                key,
                self.func,
                *self.args,
                **self.kwargs,
                **extra_kwargs,
            )
class NestedContainer(Task, Iterable):
    """Task that rebuilds a (possibly nested) container at execution time.

    Subclasses define ``klass`` (the container type accepted on input) and
    ``constructor`` (the callable that rebuilds the container from the task's
    arguments).
    """

    constructor: Callable
    klass: type
    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        /,
        *args: Any,
        **kwargs: Any,
    ):
        if len(args) == 1 and isinstance(args[0], self.klass):
            # Allow passing the container itself instead of splatted elements.
            args = args[0]  # type: ignore[assignment]
        super().__init__(
            None,
            self.to_container,
            *args,
            constructor=self.constructor,
            **kwargs,
        )

    def __getstate__(self):
        state = super().__getstate__()
        state = list(state)
        slots = self.__class__.get_all_slots()
        ix = slots.index("kwargs")
        # The constructor as a kwarg is redundant since this is encoded in the
        # class itself. Serializing the builtin types is not trivial
        # This saves about 15% of overhead
        state[ix] = state[ix].copy()
        state[ix].pop("constructor", None)
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        # Restore the constructor dropped in __getstate__.
        self.kwargs["constructor"] = self.__class__.constructor
        return self

    def __repr__(self):
        return f"{type(self).__name__}({self.args})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> NestedContainer:
        """Return a copy with dependencies replaced per ``subs`` (see GraphNode)."""
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        if not subs_filtered:
            return self
        return type(self)(
            *(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
        )

    def __dask_tokenize__(self):
        from dask.tokenize import tokenize

        # Tokenize each element individually and sort, so the token does not
        # depend on argument order.
        # Fixed: removed an unreachable ``return super().__dask_tokenize__()``
        # that followed this return statement.
        return (
            type(self).__name__,
            self.klass,
            sorted(tokenize(a) for a in self.args),
        )

    @staticmethod
    def to_container(*args, constructor):
        return constructor(args)

    def __iter__(self):
        yield from self.args
class List(NestedContainer):
    """NestedContainer that materializes its arguments as a ``list``."""

    constructor = klass = list
class Tuple(NestedContainer):
    """NestedContainer that materializes its arguments as a ``tuple``."""

    constructor = klass = tuple
class Set(NestedContainer):
    """NestedContainer that materializes its arguments as a ``set``."""

    constructor = klass = set
class Dict(NestedContainer, Mapping):
    """NestedContainer that materializes a ``dict``.

    Internally the key/value pairs are stored as a flat, interleaved
    ``args`` tuple: ``(k0, v0, k1, v1, ...)``.
    """

    klass = dict

    def __init__(self, /, *args: Any, **kwargs: Any):
        if args:
            assert not kwargs
            if len(args) == 1:
                # Accept a single dict or a single list/tuple of 2-element
                # pairs; both are flattened into the interleaved form. A flat
                # sequence of alternating keys/values is also accepted (the
                # ``all(...)`` check fails and the sequence is used as-is).
                args = args[0]
                if isinstance(args, dict):  # type: ignore[unreachable]
                    args = tuple(itertools.chain(*args.items()))  # type: ignore[unreachable]
                elif isinstance(args, (list, tuple)):
                    if all(
                        len(el) == 2 if isinstance(el, (list, tuple)) else False
                        for el in args
                    ):
                        args = tuple(itertools.chain(*args))
                else:
                    raise ValueError("Invalid argument provided")

            if len(args) % 2 != 0:
                raise ValueError("Invalid number of arguments provided")

        elif kwargs:
            assert not args
            args = tuple(itertools.chain(*kwargs.items()))

        super().__init__(*args)

    def __repr__(self):
        values = ", ".join(f"{k}: {v}" for k, v in batched(self.args, 2, strict=True))
        return f"Dict({values})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Dict:
        """Return a copy with dependencies replaced per ``subs`` (see GraphNode)."""
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        if not subs_filtered:
            return self

        new_args = []
        for arg in self.args:
            new_arg = (
                arg.substitute(subs_filtered)
                if isinstance(arg, (GraphNode, TaskRef))
                else arg
            )
            new_args.append(new_arg)
        # Pass the flat interleaved list back through __init__.
        return type(self)(new_args)

    def __iter__(self):
        # Keys live at the even positions of the interleaved args.
        yield from self.args[::2]

    def __len__(self):
        return len(self.args) // 2

    def __getitem__(self, key):
        # Linear scan over the interleaved pairs.
        for k, v in batched(self.args, 2, strict=True):
            if k == key:
                return v
        raise KeyError(key)

    @staticmethod
    def constructor(args):
        return dict(batched(args, 2, strict=True))
class DependenciesMapping(MutableMapping):
    """Lazy, cached view of ``{key: dependencies}`` for a graph.

    Works both for GraphNode values (using ``.dependencies``) and legacy
    task tuples (falling back to ``dask.core.get_dependencies``). Deleting a
    key hides it from every subsequently computed dependency set instead of
    removing it from the underlying graph.
    """

    def __init__(self, dsk):
        self.dsk = dsk
        self._removed = set()
        # Set a copy of dsk to avoid dct resizing
        self._cache = dsk.copy()
        self._cache.clear()

    def __getitem__(self, key):
        if (val := self._cache.get(key)) is not None:
            return val
        else:
            v = self.dsk[key]
            try:
                deps = v.dependencies
            except AttributeError:
                # Legacy task tuple: compute dependencies the old way.
                from dask.core import get_dependencies

                deps = get_dependencies(self.dsk, task=v)

            if self._removed:
                # deps is a frozenset but for good measure, let's not use -= since
                # that _may_ perform an inplace mutation
                deps = deps - self._removed
            self._cache[key] = deps
            return deps

    def __iter__(self):
        return iter(self.dsk)

    def __delitem__(self, key: Any) -> None:
        # Invalidate all cached sets; the removed key must disappear from
        # every dependency set, not just its own entry.
        self._cache.clear()
        self._removed.add(key)

    def __setitem__(self, key: Any, value: Any) -> None:
        raise NotImplementedError

    def __len__(self) -> int:
        return len(self.dsk)
1038class _DevNullMapping(MutableMapping):
1039 def __getitem__(self, key):
1040 raise KeyError(key)
1042 def __setitem__(self, key, value):
1043 pass
1045 def __delitem__(self, key):
1046 pass
1048 def __len__(self):
1049 return 0
1051 def __iter__(self):
1052 return iter(())
def execute_graph(
    dsk: Iterable[GraphNode] | Mapping[KeyType, GraphNode],
    cache: MutableMapping[KeyType, object] | None = None,
    keys: Container[KeyType] | None = None,
) -> MutableMapping[KeyType, object]:
    """Execute a given graph.

    The graph is executed in topological order as defined by dask.order until
    all leaf nodes, i.e. nodes without any dependents, are reached. The returned
    dictionary contains the results of the leaf nodes.

    If keys are required that are not part of the graph, they can be provided in the `cache` argument.

    If `keys` is provided, the result will contain only values that are part of the `keys` set.

    """
    if isinstance(dsk, (list, tuple, set, frozenset)):
        dsk = {t.key: t for t in dsk}
    else:
        assert isinstance(dsk, dict)

    # Count how many nodes depend on each key so intermediates can be freed.
    refcount: defaultdict[KeyType, int] = defaultdict(int)
    for vals in DependenciesMapping(dsk).values():
        for val in vals:
            refcount[val] += 1

    cache = cache or {}
    from dask.order import order

    priorities = order(dsk)

    # Execute in priority (topological) order; each node receives the shared
    # cache, from which it looks up its dependencies.
    for key, node in sorted(dsk.items(), key=lambda it: priorities[it[0]]):
        cache[key] = node(cache)
        for dep in node.dependencies:
            refcount[dep] -= 1
            if refcount[dep] == 0 and keys and dep not in keys:
                # No remaining dependents and not requested: free it.
                del cache[dep]

    return cache
def fuse_linear_task_spec(dsk, keys):
    """Fuse linear chains of tasks in ``dsk`` into single subgraph tasks.

    keys are the keys from the graph that are requested by a computation. We
    can't fuse those together.
    """
    from dask.core import reverse_dict
    from dask.optimization import default_fused_keys_renamer

    keys = set(keys)
    dependencies = DependenciesMapping(dsk)
    dependents = reverse_dict(dependencies)

    seen = set()
    result = {}

    for key in dsk:
        if key in seen:
            continue

        seen.add(key)

        deps = dependencies[key]
        dependents_key = dependents[key]

        # Nodes that are neither the middle of a chain (one dep) nor the start
        # of one (one dependent), or that explicitly block fusion, pass through.
        if len(deps) != 1 and len(dependents_key) != 1 or dsk[key].block_fusion:
            result[key] = dsk[key]
            continue

        linear_chain = [dsk[key]]
        top_key = key

        # Walk towards the leafs as long as the nodes have a single dependency
        # and a single dependent, we can't fuse two nodes of an intermediate node
        # is the source for 2 dependents
        while len(deps) == 1:
            (new_key,) = deps
            if new_key in seen:
                break
            seen.add(new_key)
            if new_key not in dsk:
                # This can happen if a future is in the graph, the dependency mapping
                # adds the key that is referenced by the future as a dependency
                # see test_futures_to_delayed_array
                break
            if (
                len(dependents[new_key]) != 1
                or dsk[new_key].block_fusion
                or new_key in keys
            ):
                result[new_key] = dsk[new_key]
                break
            # backwards comp for new names, temporary until is_rootish is removed
            linear_chain.insert(0, dsk[new_key])
            deps = dependencies[new_key]

        # Walk the tree towards the root as long as the nodes have a single dependent
        # and a single dependency, we can't fuse two nodes if node has multiple
        # dependencies
        while len(dependents_key) == 1 and top_key not in keys:
            new_key = dependents_key.pop()
            if new_key in seen:
                break
            seen.add(new_key)
            if len(dependencies[new_key]) != 1 or dsk[new_key].block_fusion:
                # Exit if the dependent has multiple dependencies, triangle
                result[new_key] = dsk[new_key]
                break
            linear_chain.append(dsk[new_key])
            top_key = new_key
            dependents_key = dependents[new_key]

        if len(linear_chain) == 1:
            result[top_key] = linear_chain[0]
        else:
            # Renaming the keys is necessary to preserve the rootish detection for now
            renamed_key = default_fused_keys_renamer([tsk.key for tsk in linear_chain])
            result[renamed_key] = Task.fuse(*linear_chain, key=renamed_key)
            if renamed_key != top_key:
                # Having the same prefixes can result in the same key, i.e. getitem-hash -> getitem-hash
                result[top_key] = Alias(top_key, target=renamed_key)
    return result
def cull(
    dsk: dict[KeyType, GraphNode], keys: Iterable[KeyType]
) -> dict[KeyType, GraphNode]:
    """Return the subgraph of ``dsk`` reachable from ``keys``.

    When every key is requested, the original dict is returned unchanged.
    Keys not present in ``dsk`` (e.g. persisted dependencies) are ignored.
    """
    if not isinstance(keys, (list, set, tuple)):
        raise TypeError(
            f"Expected list, set or tuple for keys, got {type(keys).__name__}"
        )
    if len(keys) == len(dsk):
        # Nothing can be culled.
        return dsk
    pending = set(keys)
    visited: set[KeyType] = set()
    reachable = {}
    # Breadth-less graph walk: pop a key, record its node, enqueue its deps.
    while pending:
        current = pending.pop()
        if current in visited or current not in dsk:
            continue
        visited.add(current)
        node = dsk[current]
        reachable[current] = node
        pending.update(node.dependencies)
    return reachable
1204@functools.cache
1205def _extra_args(typ: type) -> set[str]:
1206 import inspect
1208 sig = inspect.signature(typ)
1209 extras = set()
1210 for name, param in sig.parameters.items():
1211 if param.kind in (
1212 inspect.Parameter.VAR_POSITIONAL,
1213 inspect.Parameter.VAR_KEYWORD,
1214 ):
1215 continue
1216 if name in typ.get_all_slots(): # type: ignore[attr-defined]
1217 extras.add(name)
1218 return extras