Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/_task_spec.py: 25%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3""" Task specification for dask
5This module contains the task specification for dask. It is used to represent
6runnable (task) and non-runnable (data) nodes in a dask graph.
8Simple examples of how to express tasks in dask
9-----------------------------------------------
11.. code-block:: python
13 func("a", "b") ~ Task("key", func, "a", "b")
15 [func("a"), func("b")] ~ [Task("key-1", func, "a"), Task("key-2", func, "b")]
17 {"a": func("b")} ~ {"a": Task("a", func, "b")}
19 "literal-string" ~ DataNode("key", "literal-string")
22Keys, Aliases and TaskRefs
23-------------------------
25Keys are used to identify tasks in a dask graph. Every `GraphNode` instance has a
26key attribute that _should_ reference the key in the dask graph.
28.. code-block:: python
30 {"key": Task("key", func, "a")}
32Referencing other tasks is possible by using either one of `Alias` or a
33`TaskRef`.
35.. code-block:: python
37 # TaskRef can be used to provide the name of the reference explicitly
38 t = Task("key", func, TaskRef("key-1"))
40 # If a task is still in scope, the method `ref` can be used for convenience
41 t2 = Task("key2", func2, t.ref())
44Executing a task
45----------------
47A task can be executed by calling it with a dictionary of values. The values
48should contain the dependencies of the task.
50.. code-block:: python
52 t = Task("key", add, TaskRef("a"), TaskRef("b"))
53 assert t.dependencies == {"a", "b"}
54 t({"a": 1, "b": 2}) == 3
56"""
57import functools
58import itertools
59import sys
60from collections import defaultdict
61from collections.abc import Callable, Container, Iterable, Mapping, MutableMapping
62from functools import lru_cache, partial
63from typing import Any, TypeVar, cast
65from dask.sizeof import sizeof
66from dask.typing import Key as KeyType
67from dask.utils import funcname, is_namedtuple_instance
# Generic type variable used so helpers such as ``convert_legacy_task`` can
# declare "returns a GraphNode or the input's own type".
_T = TypeVar("_T")


# Ported from more-itertools
# https://github.com/more-itertools/more-itertools/blob/c8153e2801ade2527f3a6c8b623afae93f5a1ce1/more_itertools/recipes.py#L944-L973
74def _batched(iterable, n, *, strict=False):
75 """Batch data into tuples of length *n*. If the number of items in
76 *iterable* is not divisible by *n*:
77 * The last batch will be shorter if *strict* is ``False``.
78 * :exc:`ValueError` will be raised if *strict* is ``True``.
80 >>> list(batched('ABCDEFG', 3))
81 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
83 On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
84 """
85 if n < 1:
86 raise ValueError("n must be at least one")
87 it = iter(iterable)
88 while batch := tuple(itertools.islice(it, n)):
89 if strict and len(batch) != n:
90 raise ValueError("batched(): incomplete batch")
91 yield batch
# 0x30D00A2 == Python 3.13.0a2, the first release exposing
# ``itertools.batched(..., strict=...)``; prefer the C implementation there.
if sys.hexversion >= 0x30D00A2:

    def batched(iterable, n, *, strict=False):
        # Thin delegation so both branches expose the same signature.
        return itertools.batched(iterable, n, strict=strict)

else:
    # Fall back to the pure-Python port above.
    batched = _batched

batched.__doc__ = _batched.__doc__
# End port
def identity(*args):
    """Return the positional arguments unchanged, packed as a tuple."""
    return tuple(args)
110def _identity_cast(*args, typ):
111 return typ(args)
114_anom_count = itertools.count()
def parse_input(obj: Any) -> object:
    """Tokenize user input into GraphNode objects.

    Note: This is similar to `convert_legacy_task` but does not
    - compare any values to a global set of known keys to infer references/futures
    - parse tuples and interprets them as runnable tasks
    - Deal with SubgraphCallables

    Parameters
    ----------
    obj : Any
        Arbitrary user-provided object; may be a container nesting
        ``GraphNode``/``TaskRef`` instances.

    Returns
    -------
    object
        A ``GraphNode`` wrapper when ``obj`` (or something nested inside it)
        references the graph, otherwise ``obj`` itself, unchanged.
    """
    if isinstance(obj, GraphNode):
        return obj
    if isinstance(obj, TaskRef):
        # Normalize explicit references to Alias nodes.
        return Alias(obj.key)
    if isinstance(obj, dict):
        parsed_dict = {k: parse_input(v) for k, v in obj.items()}
        if any(isinstance(v, GraphNode) for v in parsed_dict.values()):
            return Dict(parsed_dict)
        # No graph nodes inside: fall through and return the *original* dict
        # untouched (the parsed copy is deliberately discarded).
    if isinstance(obj, (list, set, tuple)):
        parsed_collection = tuple(parse_input(o) for o in obj)
        if any(isinstance(o, GraphNode) for o in parsed_collection):
            if isinstance(obj, list):
                return List(*parsed_collection)
            if isinstance(obj, set):
                return Set(*parsed_collection)
            if isinstance(obj, tuple):
                if is_namedtuple_instance(obj):
                    # Namedtuples need reconstruction via their own type.
                    return _wrap_namedtuple_task(None, obj, parse_input)
                return Tuple(*parsed_collection)
    return obj
def _wrap_namedtuple_task(k, obj, parser):
    # Rebuild a namedtuple instance as a Task so that fields holding graph
    # references are parsed and the instance is re-created at execution time.
    if hasattr(obj, "__getnewargs_ex__"):
        new_args, kwargs = obj.__getnewargs_ex__()
        kwargs = {k: parser(v) for k, v in kwargs.items()}
    elif hasattr(obj, "__getnewargs__"):
        new_args = obj.__getnewargs__()
        kwargs = {}
    # NOTE(review): if ``obj`` provides neither __getnewargs_ex__ nor
    # __getnewargs__, ``new_args`` is unbound and the next line raises
    # NameError — presumably every namedtuple reaching here implements one of
    # the two protocols; confirm with callers.

    args_converted = parse_input(type(new_args)(map(parser, new_args)))

    return Task(
        k, partial(_instantiate_named_tuple, type(obj)), args_converted, Dict(kwargs)
    )
176def _instantiate_named_tuple(typ, args, kwargs):
177 return typ(*args, **kwargs)
180class _MultiContainer(Container):
181 container: tuple
182 __slots__ = ("container",)
184 def __init__(self, *container):
185 self.container = container
187 def __contains__(self, o: object) -> bool:
188 for c in self.container:
189 if o in c:
190 return True
191 return False
194SubgraphType = None
def _execute_subgraph(inner_dsk, outkey, inkeys, *dependencies):
    """Run a fused subgraph and return the value computed for ``outkey``.

    ``inkeys`` names the subgraph's external dependencies; ``dependencies``
    carries their already-computed values in the same order.
    """
    final = {}
    final.update(inner_dsk)
    # Inject the externally supplied values as literal data nodes.
    for k, v in zip(inkeys, dependencies):
        final[k] = DataNode(None, v)
    res = execute_graph(final, keys=[outkey])
    return res[outkey]
def convert_legacy_task(
    key: KeyType | None,
    task: _T,
    all_keys: Container,
) -> GraphNode | _T:
    """Convert a single legacy-tuple-style task into a ``GraphNode``.

    ``all_keys`` is the set of known graph keys; hashable literals found in it
    are interpreted as references (Aliases). Values that reference nothing are
    returned unchanged.
    """
    if isinstance(task, GraphNode):
        return task

    # Legacy runnable task: a tuple whose first element is callable.
    if type(task) is tuple and task and callable(task[0]):
        func, args = task[0], task[1:]
        new_args = []
        new: object
        for a in args:
            if isinstance(a, dict):
                new = Dict(a)
            else:
                new = convert_legacy_task(None, a, all_keys)
            new_args.append(new)
        return Task(key, func, *new_args)

    try:
        if isinstance(task, (int, float, str, tuple)):
            if task in all_keys:
                # A bare known key is a reference to another task.
                if key is None:
                    return Alias(task)
                else:
                    return Alias(key, target=task)
    except TypeError:
        # Unhashable
        pass

    if isinstance(task, (list, tuple, set, frozenset)):
        if is_namedtuple_instance(task):
            return _wrap_namedtuple_task(
                key,
                task,
                partial(
                    convert_legacy_task,
                    None,
                    all_keys=all_keys,
                ),
            )
        else:
            parsed_args = tuple(convert_legacy_task(None, t, all_keys) for t in task)
            if any(isinstance(a, GraphNode) for a in parsed_args):
                # Container holds graph nodes: rebuild it lazily at runtime.
                return Task(key, _identity_cast, *parsed_args, typ=type(task))
            else:
                # Pure data: rebuild eagerly, preserving the container type.
                return cast(_T, type(task)(parsed_args))
    elif isinstance(task, TaskRef):
        if key is None:
            return Alias(task.key)
        else:
            return Alias(key, target=task.key)
    else:
        return task
def convert_legacy_graph(
    dsk: Mapping,
    all_keys: Container | None = None,
):
    """Convert a whole legacy graph mapping into ``GraphNode`` values.

    Self-referencing aliases (``key -> key``) are dropped; plain data values
    are wrapped in ``DataNode``.
    """
    known_keys = set(dsk) if all_keys is None else all_keys
    converted = {}
    for graph_key, legacy_value in dsk.items():
        node = convert_legacy_task(graph_key, legacy_value, known_keys)
        if isinstance(node, Alias) and node.target == graph_key:
            # A trivial self-alias adds nothing; skip it.
            continue
        if not isinstance(node, GraphNode):
            node = DataNode(graph_key, node)
        converted[graph_key] = node
    return converted
def resolve_aliases(dsk: dict, keys: set, dependents: dict) -> dict:
    """Remove trivial sequential alias chains.

    Example:

        dsk = {'x': 1, 'y': Alias('x'), 'z': Alias('y')}

        resolve_aliases(dsk, {'z'}, {'x': {'y'}, 'y': {'z'}}) == {'z': 1}

    ``keys`` are the requested outputs (never collapsed away); ``dependents``
    maps each key to the set of keys depending on it.
    """
    if not keys:
        raise ValueError("No keys provided")
    dsk = dict(dsk)  # shallow copy; collapsed entries are popped from it
    work = list(keys)
    seen = set()
    while work:
        k = work.pop()
        if k in seen or k not in dsk:
            continue
        seen.add(k)
        t = dsk[k]
        if isinstance(t, Alias):
            target_key = t.target
            # Rules for when we allow to collapse an alias
            # 1. The target key is not in the keys set. The keys set is what the
            #    user is requesting and by collapsing we'd no longer be able to
            #    return that result.
            # 2. The target key is in fact part of dsk. If it isn't, this could
            #    point to a persisted dependency and we cannot collapse it.
            # 3. The target key has only one dependent which is the key we're
            #    currently looking at. This means that there is a one to one
            #    relation between this and the target key in which case we can
            #    collapse them.
            # Note: If target was an alias as well, we could continue with
            # more advanced optimizations but this isn't implemented, yet
            if (
                target_key not in keys
                and target_key in dsk
                # Note: whenever we're performing a collapse, we're not updating
                # the dependents. The length == 1 should still be sufficient for
                # chains of these aliases
                and len(dependents[target_key]) == 1
            ):
                # Pull the target up under this key (copy, then re-key).
                tnew = dsk.pop(target_key).copy()

                dsk[k] = tnew
                tnew.key = k
                if isinstance(tnew, Alias):
                    # Collapsed onto another alias: revisit this key.
                    work.append(k)
                    seen.discard(k)
                else:
                    work.extend(tnew.dependencies)

        work.extend(t.dependencies)
    return dsk
336class TaskRef:
337 val: KeyType
338 __slots__ = ("key",)
340 def __init__(self, key: KeyType):
341 self.key = key
343 def __str__(self):
344 return str(self.key)
346 def __repr__(self):
347 return f"{type(self).__name__}({self.key!r})"
349 def __hash__(self) -> int:
350 return hash(self.key)
352 def __eq__(self, value: object) -> bool:
353 if not isinstance(value, TaskRef):
354 return False
355 return self.key == value.key
357 def __reduce__(self):
358 return TaskRef, (self.key,)
360 def substitute(self, subs: dict, key: KeyType | None = None) -> TaskRef | GraphNode:
361 if self.key in subs:
362 val = subs[self.key]
363 if isinstance(val, GraphNode):
364 return val.substitute({}, key=self.key)
365 elif isinstance(val, TaskRef):
366 return val
367 else:
368 return TaskRef(val)
369 return self
class GraphNode:
    """Base class for all nodes in a dask task graph.

    Subclasses must keep ``__slots__`` in sync with their class-level
    annotations (``__slots__ = tuple(__annotations__)``), which is also relied
    on by ``get_all_slots``/pickling.
    """

    key: KeyType
    _dependencies: frozenset

    __slots__ = tuple(__annotations__)

    def ref(self):
        # Convenience: an Alias pointing at this node's key.
        return Alias(self.key)

    def copy(self):
        raise NotImplementedError

    @property
    def data_producer(self) -> bool:
        # Overridden by nodes that materialize data (e.g. DataNode).
        return False

    @property
    def dependencies(self) -> frozenset:
        return self._dependencies

    @property
    def block_fusion(self) -> bool:
        # When True, linear fusion must not merge across this node.
        return False

    def _verify_values(self, values: tuple | dict) -> None:
        # Sanity check that every declared dependency has a supplied value.
        if not self.dependencies:
            return
        if missing := set(self.dependencies) - set(values):
            raise RuntimeError(f"Not enough arguments provided: missing keys {missing}")

    def __call__(self, values) -> Any:
        raise NotImplementedError("Not implemented")

    def __eq__(self, value: object) -> bool:
        # Equality is token-based and only defined between identical types.
        if type(value) is not type(self):
            return False

        from dask.tokenize import tokenize

        return tokenize(self) == tokenize(value)

    @property
    def is_coro(self) -> bool:
        return False

    def __sizeof__(self) -> int:
        # Sum the sizes of all slot values plus the type object overhead.
        all_slots = self.get_all_slots()
        return sum(sizeof(getattr(self, sl)) for sl in all_slots) + sys.getsizeof(
            type(self)
        )

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        """Substitute a dependency with a new value. The new value either has to
        be a new valid key or a GraphNode to replace the dependency entirely.

        The GraphNode will not be mutated but instead a shallow copy will be
        returned. The substitution will be performed eagerly.

        Parameters
        ----------
        subs : dict[KeyType, KeyType | GraphNode]
            The mapping describing the substitutions to be made.
        key : KeyType | None, optional
            The key of the new GraphNode object. If None provided, the key of
            the old one will be reused.
        """
        raise NotImplementedError

    @staticmethod
    def fuse(*tasks: GraphNode, key: KeyType | None = None) -> GraphNode:
        """Fuse a set of tasks into a single task.

        The tasks are fused into a single task that will execute the tasks in a
        subgraph. The internal tasks are no longer accessible from the outside.

        All provided tasks must form a valid subgraph that will reduce to a
        single key. If multiple outputs are possible with the provided tasks, an
        exception will be raised.

        The tasks will not be rewritten but instead a new Task will be created
        that will merely reference the old task objects. This way, Task objects
        may be reused in multiple fused tasks.

        Parameters
        ----------
        key : KeyType | None, optional
            The key of the new Task object. If None provided, the key of the
            final task will be used.

        See also
        --------
        GraphNode.substitute : Easier substitution of dependencies
        """
        if any(t.key is None for t in tasks):
            raise ValueError("Cannot fuse tasks with missing keys")
        if len(tasks) == 1:
            return tasks[0].substitute({}, key=key)
        all_keys = set()
        all_deps: set[KeyType] = set()
        for t in tasks:
            all_deps.update(t.dependencies)
            all_keys.add(t.key)
        # Dependencies satisfied outside the fused group stay external.
        external_deps = tuple(sorted(all_deps - all_keys, key=hash))
        # Leafs: keys nothing inside the group depends on — must be exactly one.
        leafs = all_keys - all_deps
        if len(leafs) > 1:
            raise ValueError(f"Cannot fuse tasks with multiple outputs {leafs}")

        outkey = leafs.pop()
        return Task(
            key or outkey,
            _execute_subgraph,
            {t.key: t for t in tasks},
            outkey,
            external_deps,
            *(TaskRef(k) for k in external_deps),
            _data_producer=any(t.data_producer for t in tasks),
        )

    @classmethod
    @lru_cache
    def get_all_slots(cls):
        # Collect slot names across the full MRO (each class contributes its
        # own __slots__).
        slots = list()
        for c in cls.mro():
            slots.extend(getattr(c, "__slots__", ()))
        # Interestingly, sorting this causes the nested containers to pickle
        # more efficiently
        return sorted(set(slots))
503_no_deps: frozenset = frozenset()
class Alias(GraphNode):
    """A node whose value is simply another key's value.

    ``Alias(key)`` points at itself-named data; ``Alias(key, target)`` renames
    ``target`` to ``key``. Nested aliases and TaskRefs are unwrapped to plain
    keys on construction.
    """

    target: KeyType
    __slots__ = tuple(__annotations__)

    def __init__(
        self, key: KeyType | TaskRef, target: Alias | TaskRef | KeyType | None = None
    ):
        if isinstance(key, TaskRef):
            key = key.key
        self.key = key
        if target is None:
            target = key
        if isinstance(target, Alias):
            target = target.target
        if isinstance(target, TaskRef):
            target = target.key
        self.target = target
        self._dependencies = frozenset((self.target,))

    def __reduce__(self):
        return Alias, (self.key, self.target)

    def copy(self):
        return Alias(self.key, self.target)

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        if self.key in subs or self.target in subs:
            sub_key = subs.get(self.key, self.key)
            val = subs.get(self.target, self.target)
            if sub_key == self.key and val == self.target:
                # Mapping exists but changes nothing for this alias.
                return self
            if isinstance(val, (GraphNode, TaskRef)):
                # Target replaced by a node: the node takes over entirely.
                return val.substitute({}, key=key)
            if key is None and isinstance(sub_key, GraphNode):
                raise RuntimeError(
                    f"Invalid substitution encountered {self.key!r} -> {sub_key}"
                )
            return Alias(key or sub_key, val)  # type: ignore
        return self

    def __dask_tokenize__(self):
        return (type(self).__name__, self.key, self.target)

    def __call__(self, values=()):
        self._verify_values(values)
        return values[self.target]

    def __repr__(self):
        if self.key != self.target:
            return f"Alias({self.key!r}->{self.target!r})"
        else:
            return f"Alias({self.key!r})"

    def __eq__(self, value: object) -> bool:
        # Structural equality on (key, target).
        # NOTE(review): defining __eq__ without __hash__ makes Alias
        # unhashable — presumably intentional; confirm no caller hashes it.
        if not isinstance(value, Alias):
            return False
        if self.key != value.key:
            return False
        if self.target != value.target:
            return False
        return True
class DataNode(GraphNode):
    """A non-runnable node holding a literal value.

    With ``key=None`` a unique anonymous key is generated from the value's
    type name and a global counter.
    """

    value: Any
    typ: type

    __slots__ = tuple(__annotations__)

    def __init__(self, key: Any, value: Any):
        if key is None:
            key = (type(value).__name__, next(_anom_count))
        self.key = key
        self.value = value
        self.typ = type(value)
        self._dependencies = _no_deps

    @property
    def data_producer(self) -> bool:
        return True

    def copy(self):
        return DataNode(self.key, self.value)

    def __call__(self, values=()):
        # Literal data: execution just returns the stored value.
        return self.value

    def __repr__(self):
        return f"DataNode({self.value!r})"

    def __reduce__(self):
        return (DataNode, (self.key, self.value))

    def __dask_tokenize__(self):
        from dask.base import tokenize

        return (type(self).__name__, tokenize(self.value))

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> DataNode:
        # Data has no dependencies; only a re-key can change anything.
        if key is not None and key != self.key:
            return DataNode(key, self.value)
        return self

    def __iter__(self):
        return iter(self.value)
def _get_dependencies(obj: object) -> set | frozenset:
    """Recursively collect the task keys that ``obj`` references."""
    if isinstance(obj, TaskRef):
        return {obj.key}
    if isinstance(obj, GraphNode):
        return obj.dependencies
    if isinstance(obj, dict):
        values = obj.values()
        return set().union(*map(_get_dependencies, values)) if obj else _no_deps
    if isinstance(obj, (list, tuple, frozenset, set)):
        return set().union(*map(_get_dependencies, obj)) if obj else _no_deps
    # Plain data contributes nothing.
    return _no_deps
class Task(GraphNode):
    """A runnable node: ``func(*args, **kwargs)`` with graph references.

    Arguments that are ``TaskRef``/``GraphNode`` instances contribute to
    ``dependencies`` and are resolved against the provided value mapping when
    the task is called.
    """

    func: Callable
    args: tuple
    kwargs: dict
    _data_producer: bool
    _token: str | None
    _is_coro: bool | None
    _repr: str | None

    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        key: Any,
        func: Callable,
        /,
        *args: Any,
        _data_producer: bool = False,
        **kwargs: Any,
    ):
        self.key = key
        self.func = func
        if isinstance(func, Task):
            raise TypeError("Cannot nest tasks")

        self.args = args
        self.kwargs = kwargs
        # Lazily build the dependency set only if any argument references one;
        # avoids allocating a set for pure-literal tasks.
        _dependencies: set[KeyType] | None = None
        for a in itertools.chain(args, kwargs.values()):
            if isinstance(a, TaskRef):
                if _dependencies is None:
                    _dependencies = {a.key}
                else:
                    _dependencies.add(a.key)
            elif isinstance(a, GraphNode) and a.dependencies:
                if _dependencies is None:
                    _dependencies = set(a.dependencies)
                else:
                    _dependencies.update(a.dependencies)
        if _dependencies:
            self._dependencies = frozenset(_dependencies)
        else:
            self._dependencies = _no_deps
        self._is_coro = None
        self._token = None
        self._repr = None
        self._data_producer = _data_producer

    @property
    def data_producer(self) -> bool:
        return self._data_producer

    def has_subgraph(self) -> bool:
        # True for tasks produced by GraphNode.fuse.
        return self.func == _execute_subgraph

    def copy(self):
        return type(self)(
            self.key,
            self.func,
            *self.args,
            **self.kwargs,
        )

    def __hash__(self):
        return hash(self._get_token())

    def _get_token(self) -> str:
        # Token is computed once and cached; the key is deliberately excluded
        # so renamed-but-identical tasks tokenize alike.
        if self._token:
            return self._token
        from dask.base import tokenize

        self._token = tokenize(
            (
                type(self).__name__,
                self.func,
                self.args,
                self.kwargs,
            )
        )
        return self._token

    def __dask_tokenize__(self):
        return self._get_token()

    def __repr__(self) -> str:
        # When `Task` is deserialized the constructor will not run and
        # `self._repr` is thus undefined.
        if not hasattr(self, "_repr") or not self._repr:
            head = funcname(self.func)
            tail = ")"
            label_size = 40
            args = self.args
            kwargs = self.kwargs
            if args or kwargs:
                # Rough per-argument budget used to decide between full reprs
                # and "..." placeholders.
                label_size2 = int(
                    (label_size - len(head) - len(tail) - len(str(self.key)))
                    // (len(args) + len(kwargs))
                )
            if args:
                if label_size2 > 5:
                    args_repr = ", ".join(repr(t) for t in args)
                else:
                    args_repr = "..."
            else:
                args_repr = ""
            if kwargs:
                if label_size2 > 5:
                    kwargs_repr = ", " + ", ".join(
                        f"{k}={repr(v)}" for k, v in sorted(kwargs.items())
                    )
                else:
                    kwargs_repr = ", ..."
            else:
                kwargs_repr = ""
            self._repr = f"<Task {self.key!r} {head}({args_repr}{kwargs_repr}{tail}>"
        return self._repr

    def __call__(self, values=()):
        self._verify_values(values)

        def _eval(a):
            # Resolve nested nodes/references against the supplied values.
            if isinstance(a, GraphNode):
                return a({k: values[k] for k in a.dependencies})
            elif isinstance(a, TaskRef):
                return values[a.key]
            else:
                return a

        new_argspec = tuple(map(_eval, self.args))
        if self.kwargs:
            kwargs = {k: _eval(kw) for k, kw in self.kwargs.items()}
            return self.func(*new_argspec, **kwargs)
        return self.func(*new_argspec)

    def __setstate__(self, state):
        slots = self.__class__.get_all_slots()
        for sl, val in zip(slots, state):
            setattr(self, sl, val)

    def __getstate__(self):
        slots = self.__class__.get_all_slots()
        return tuple(getattr(self, sl) for sl in slots)

    @property
    def is_coro(self):
        if self._is_coro is None:
            # Note: Can't use cached_property on objects without __dict__
            try:
                from distributed.utils import iscoroutinefunction

                self._is_coro = iscoroutinefunction(self.func)
            except Exception:
                # Best-effort: without distributed we assume a plain function.
                self._is_coro = False
        return self._is_coro

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Task:
        # Only substitutions that actually touch our dependencies matter.
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        # Extra constructor parameters declared by subclasses (slot-backed)
        # must be forwarded so type(self)(...) round-trips.
        extras = _extra_args(type(self))  # type: ignore
        extra_kwargs = {
            name: getattr(self, name) for name in extras if name not in {"key", "func"}
        }
        if subs_filtered:
            new_args = tuple(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
            new_kwargs = {
                k: (
                    v.substitute(subs_filtered)
                    if isinstance(v, (GraphNode, TaskRef))
                    else v
                )
                for k, v in self.kwargs.items()
            }
            return type(self)(
                key or self.key,
                self.func,
                *new_args,
                **new_kwargs,  # type: ignore[arg-type]
                **extra_kwargs,
            )
        elif key is None or key == self.key:
            return self
        else:
            # Rename
            return type(self)(
                key,
                self.func,
                *self.args,
                **self.kwargs,
                **extra_kwargs,
            )
class NestedContainer(Task, Iterable):
    """A Task that rebuilds a container (list/tuple/set/dict) at runtime.

    Subclasses define ``klass`` (the container type matched on input) and
    ``constructor`` (the callable rebuilding it from the flat argument tuple).
    """

    constructor: Callable
    klass: type

    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        /,
        *args: Any,
        **kwargs: Any,
    ):
        # A single argument of the container type is unpacked into its
        # elements so they become individual task arguments.
        if len(args) == 1 and isinstance(args[0], self.klass):
            args = args[0]  # type: ignore
        super().__init__(
            None,
            self.to_container,
            *args,
            constructor=self.constructor,
            **kwargs,
        )

    def __getstate__(self):
        state = super().__getstate__()
        state = list(state)
        slots = self.__class__.get_all_slots()
        ix = slots.index("kwargs")
        # The constructor as a kwarg is redundant since this is encoded in the
        # class itself. Serializing the builtin types is not trivial
        # This saves about 15% of overhead
        state[ix] = state[ix].copy()
        state[ix].pop("constructor", None)
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        # Restore the constructor dropped by __getstate__.
        self.kwargs["constructor"] = self.__class__.constructor
        return self

    def __repr__(self):
        return f"{type(self).__name__}({self.args})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> NestedContainer:
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        if not subs_filtered:
            return self
        return type(self)(
            *(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
        )

    def __dask_tokenize__(self):
        from dask.tokenize import tokenize

        # Element order is irrelevant for the container's identity, so tokens
        # are sorted for a stable result.
        return (
            type(self).__name__,
            self.klass,
            sorted(tokenize(a) for a in self.args),
        )
        # Fix: removed an unreachable `return super().__dask_tokenize__()`
        # that followed the return above (dead code).

    @staticmethod
    def to_container(*args, constructor):
        return constructor(args)

    def __iter__(self):
        yield from self.args
class List(NestedContainer):
    """NestedContainer that materializes its arguments as a ``list``."""

    constructor = klass = list
class Tuple(NestedContainer):
    """NestedContainer that materializes its arguments as a ``tuple``."""

    constructor = klass = tuple
class Set(NestedContainer):
    """NestedContainer that materializes its arguments as a ``set``."""

    constructor = klass = set
class Dict(NestedContainer, Mapping):
    """NestedContainer that materializes a ``dict``.

    Internally stores a flat, even-length tuple of alternating keys and
    values in ``self.args``.
    """

    klass = dict

    def __init__(self, /, *args: Any, **kwargs: Any):
        if args:
            assert not kwargs
            if len(args) == 1:
                # Accept a single dict, or a single sequence of 2-item pairs;
                # either is flattened into the alternating k/v tuple.
                args = args[0]
                if isinstance(args, dict):  # type: ignore
                    args = tuple(itertools.chain(*args.items()))  # type: ignore
                elif isinstance(args, (list, tuple)):
                    if all(
                        len(el) == 2 if isinstance(el, (list, tuple)) else False
                        for el in args
                    ):
                        args = tuple(itertools.chain(*args))
                else:
                    raise ValueError("Invalid argument provided")

            if len(args) % 2 != 0:
                raise ValueError("Invalid number of arguments provided")

        elif kwargs:
            assert not args
            args = tuple(itertools.chain(*kwargs.items()))

        super().__init__(*args)

    def __repr__(self):
        values = ", ".join(f"{k}: {v}" for k, v in batched(self.args, 2, strict=True))
        return f"Dict({values})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Dict:
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        if not subs_filtered:
            return self

        new_args = []
        for arg in self.args:
            new_arg = (
                arg.substitute(subs_filtered)
                if isinstance(arg, (GraphNode, TaskRef))
                else arg
            )
            new_args.append(new_arg)
        return type(self)(new_args)

    def __iter__(self):
        # Keys occupy the even positions of the flat args tuple.
        yield from self.args[::2]

    def __len__(self):
        return len(self.args) // 2

    def __getitem__(self, key):
        # Linear scan over the k/v pairs; fine for the small dicts stored here.
        for k, v in batched(self.args, 2, strict=True):
            if k == key:
                return v
        raise KeyError(key)

    @staticmethod
    def constructor(args):
        return dict(batched(args, 2, strict=True))
class DependenciesMapping(MutableMapping):
    """Lazy, cached view of ``key -> dependency set`` for a graph.

    Deleting a key marks it as removed (filtered out of every subsequently
    computed dependency set) rather than touching the underlying graph.
    """

    def __init__(self, dsk):
        self.dsk = dsk
        self._removed = set()
        # Set a copy of dsk to avoid dct resizing
        self._cache = dsk.copy()
        self._cache.clear()

    def __getitem__(self, key):
        if (val := self._cache.get(key)) is not None:
            return val
        else:
            v = self.dsk[key]
            try:
                deps = v.dependencies
            except AttributeError:
                # Legacy (non-GraphNode) task: fall back to dask.core.
                from dask.core import get_dependencies

                deps = get_dependencies(self.dsk, task=v)

            if self._removed:
                # deps is a frozenset but for good measure, let's not use -= since
                # that _may_ perform an inplace mutation
                deps = deps - self._removed
            self._cache[key] = deps
            return deps

    def __iter__(self):
        return iter(self.dsk)

    def __delitem__(self, key: Any) -> None:
        # Invalidate everything: cached sets may contain the removed key.
        self._cache.clear()
        self._removed.add(key)

    def __setitem__(self, key: Any, value: Any) -> None:
        raise NotImplementedError

    def __len__(self) -> int:
        return len(self.dsk)
1034class _DevNullMapping(MutableMapping):
1035 def __getitem__(self, key):
1036 raise KeyError(key)
1038 def __setitem__(self, key, value):
1039 pass
1041 def __delitem__(self, key):
1042 pass
1044 def __len__(self):
1045 return 0
1047 def __iter__(self):
1048 return iter(())
def execute_graph(
    dsk: Iterable[GraphNode] | Mapping[KeyType, GraphNode],
    cache: MutableMapping[KeyType, object] | None = None,
    keys: Container[KeyType] | None = None,
) -> MutableMapping[KeyType, object]:
    """Execute a given graph.

    The graph is executed in topological order as defined by dask.order until
    all leaf nodes, i.e. nodes without any dependents, are reached. The returned
    dictionary contains the results of the leaf nodes.

    If keys are required that are not part of the graph, they can be provided in the `cache` argument.

    If `keys` is provided, the result will contain only values that are part of the `keys` set.

    """
    if isinstance(dsk, (list, tuple, set, frozenset)):
        dsk = {t.key: t for t in dsk}
    else:
        assert isinstance(dsk, dict)

    # Count how many tasks still need each intermediate result so it can be
    # released as soon as the count drops to zero.
    refcount: defaultdict[KeyType, int] = defaultdict(int)
    for vals in DependenciesMapping(dsk).values():
        for val in vals:
            refcount[val] += 1

    cache = cache or {}
    from dask.order import order

    priorities = order(dsk)

    for key, node in sorted(dsk.items(), key=lambda it: priorities[it[0]]):
        cache[key] = node(cache)
        for dep in node.dependencies:
            refcount[dep] -= 1
            # Free intermediates early, but never drop a requested key.
            if refcount[dep] == 0 and keys and dep not in keys:
                del cache[dep]

    return cache
def fuse_linear_task_spec(dsk, keys):
    """Fuse linear chains of tasks in ``dsk`` into single subgraph tasks.

    keys are the keys from the graph that are requested by a computation. We
    can't fuse those together.
    """
    from dask.core import reverse_dict
    from dask.optimization import default_fused_keys_renamer

    keys = set(keys)
    dependencies = DependenciesMapping(dsk)
    dependents = reverse_dict(dependencies)

    seen = set()
    result = {}

    for key in dsk:
        if key in seen:
            continue

        seen.add(key)

        deps = dependencies[key]
        dependents_key = dependents[key]

        # NOTE(review): `and` binds tighter than `or` here, so block_fusion
        # alone forces the early exit regardless of the fan-in/fan-out test —
        # presumably intended; confirm against upstream.
        if len(deps) != 1 and len(dependents_key) != 1 or dsk[key].block_fusion:
            result[key] = dsk[key]
            continue

        linear_chain = [dsk[key]]
        top_key = key

        # Walk towards the leafs as long as the nodes have a single dependency
        # and a single dependent, we can't fuse two nodes of an intermediate node
        # is the source for 2 dependents
        while len(deps) == 1:
            (new_key,) = deps
            if new_key in seen:
                break
            seen.add(new_key)
            if new_key not in dsk:
                # This can happen if a future is in the graph, the dependency mapping
                # adds the key that is referenced by the future as a dependency
                # see test_futures_to_delayed_array
                break
            if (
                len(dependents[new_key]) != 1
                or dsk[new_key].block_fusion
                or new_key in keys
            ):
                result[new_key] = dsk[new_key]
                break
            # backwards comp for new names, temporary until is_rootish is removed
            linear_chain.insert(0, dsk[new_key])
            deps = dependencies[new_key]

        # Walk the tree towards the root as long as the nodes have a single dependent
        # and a single dependency, we can't fuse two nodes if node has multiple
        # dependencies
        while len(dependents_key) == 1 and top_key not in keys:
            new_key = dependents_key.pop()
            if new_key in seen:
                break
            seen.add(new_key)
            if len(dependencies[new_key]) != 1 or dsk[new_key].block_fusion:
                # Exit if the dependent has multiple dependencies, triangle
                result[new_key] = dsk[new_key]
                break
            linear_chain.append(dsk[new_key])
            top_key = new_key
            dependents_key = dependents[new_key]

        if len(linear_chain) == 1:
            result[top_key] = linear_chain[0]
        else:
            # Renaming the keys is necessary to preserve the rootish detection for now
            renamed_key = default_fused_keys_renamer([tsk.key for tsk in linear_chain])
            result[renamed_key] = Task.fuse(*linear_chain, key=renamed_key)
            if renamed_key != top_key:
                # Having the same prefixes can result in the same key, i.e. getitem-hash -> getitem-hash
                result[top_key] = Alias(top_key, target=renamed_key)
    return result
def cull(
    dsk: dict[KeyType, GraphNode], keys: Iterable[KeyType]
) -> dict[KeyType, GraphNode]:
    """Return the subgraph of ``dsk`` reachable from ``keys``.

    If every key is requested (by count), the original mapping is returned
    unchanged; otherwise a new dict containing only reachable nodes is built.
    """
    if not isinstance(keys, (list, set, tuple)):
        raise TypeError(
            f"Expected list, set or tuple for keys, got {type(keys).__name__}"
        )
    if len(keys) == len(dsk):
        return dsk
    pending = set(keys)
    visited: set[KeyType] = set()
    reachable = {}
    while pending:
        current = pending.pop()
        # Skip already-handled keys and external references not in the graph.
        if current in visited or current not in dsk:
            continue
        visited.add(current)
        node = dsk[current]
        reachable[current] = node
        pending.update(node.dependencies)
    return reachable
1200@functools.cache
1201def _extra_args(typ: type) -> set[str]:
1202 import inspect
1204 sig = inspect.signature(typ)
1205 extras = set()
1206 for name, param in sig.parameters.items():
1207 if param.kind in (
1208 inspect.Parameter.VAR_POSITIONAL,
1209 inspect.Parameter.VAR_KEYWORD,
1210 ):
1211 continue
1212 if name in typ.get_all_slots(): # type: ignore
1213 extras.add(name)
1214 return extras