Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/core.py: 32%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3from collections import defaultdict
4from collections.abc import Collection, Iterable, Mapping, MutableMapping
5from typing import Any, Literal, TypeVar, cast, overload
7import toolz
9from dask._task_spec import (
10 DependenciesMapping,
11 TaskRef,
12 convert_legacy_graph,
13 execute_graph,
14)
15from dask.typing import Graph, Key, NoDefault, no_default
def ishashable(x):
    """Report whether ``x`` can be passed to ``hash`` without a ``TypeError``.

    Examples
    --------

    >>> ishashable(1)
    True
    >>> ishashable([1])
    False

    See Also
    --------
    iskey
    """
    try:
        hash(x)
    except TypeError:
        return False
    else:
        return True
def istask(x):
    """Is x a runnable task?

    A task is either a ``GraphNode`` that is not a ``DataNode``, or a
    non-empty plain ``tuple`` whose first element is callable.

    The result is always a real ``bool``; in particular an empty tuple
    yields ``False`` rather than the empty tuple itself.

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> istask((inc, 1))
    True
    >>> istask(1)
    False
    >>> istask(())
    False
    """
    # Imported lazily, exactly as the original did.
    from dask._task_spec import DataNode, GraphNode

    if isinstance(x, GraphNode):
        # A graph node counts as a task unless it is a DataNode.
        return not isinstance(x, DataNode)
    # Exact type check on purpose: tuple subclasses are not treated as tasks.
    return type(x) is tuple and len(x) > 0 and callable(x[0])
def preorder_traversal(task):
    """Yield the elements of ``task`` in pre-order.

    Nested tasks are descended into directly.  A nested list first yields
    the ``list`` type itself as a marker and is then descended into.  Any
    other element is yielded unchanged.
    """
    for element in task:
        runnable = istask(element)
        if not runnable and not isinstance(element, list):
            yield element
            continue
        if not runnable:
            # Marker announcing that the following items came from a list.
            yield list
        yield from preorder_traversal(element)
def lists_to_tuples(res, keys):
    """Convert ``res`` to nested tuples wherever ``keys`` is a nested list.

    ``keys`` drives the recursion: where it is a list, the matching part of
    ``res`` is rebuilt as a tuple (pairing elements with ``zip``); anywhere
    else ``res`` is returned untouched.
    """
    if not isinstance(keys, list):
        return res
    converted = [
        lists_to_tuples(sub_res, sub_keys) for sub_res, sub_keys in zip(res, keys)
    ]
    return tuple(converted)
def _pack_result(result: Mapping, keys: list | Key) -> Any:
    """Look up ``keys`` in ``result``, mirroring the nesting of ``keys``.

    A list of keys (possibly nested) produces a tuple of looked-up values
    with the same nesting; anything else is treated as a single key.
    """
    if not isinstance(keys, list):
        return result[keys]
    packed = [_pack_result(result, sub_keys) for sub_keys in keys]
    return tuple(packed)
def get(dsk: Mapping, out: list | Key, cache: MutableMapping | None = None) -> Any:
    """Get value from Dask

    Parameters
    ----------
    dsk
        The graph, as a mapping from keys to values/tasks.
    out
        A single key, or a (possibly nested) list of keys, to compute.
    cache
        Optional mutable mapping handed to ``execute_graph``; a fresh dict
        is used when omitted.  Its keys are also passed to
        ``convert_legacy_graph`` as known keys.

    Returns
    -------
    The value for ``out`` when it is a single key, otherwise a tuple of
    values with the same nesting as ``out``.

    Raises
    ------
    KeyError
        If any requested key is missing from ``dsk``.

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> d = {'x': 1, 'y': (inc, 'x')}

    >>> get(d, 'x')
    1
    >>> get(d, 'y')
    2
    """
    # Wrap ``out`` in a list before flattening so that a single key is
    # validated as one key.  Flattening ``out`` directly would iterate into a
    # tuple key such as ('x', 0) and fail on a bare int key.
    requested = list(flatten([out]))
    for k in requested:
        if k not in dsk:
            raise KeyError(f"{k} is not a key in the graph")
    if cache is None:
        cache = {}

    dsk2 = convert_legacy_graph(dsk, all_keys=set(dsk) | set(cache))
    result = execute_graph(dsk2, cache, keys=set(requested))
    return _pack_result(result, out)
def keys_in_tasks(keys: Collection[Key], tasks: Iterable[Any], as_list: bool = False):
    """Returns the keys in `keys` that are also in `tasks`

    The search proceeds level by level: legacy tuple tasks contribute their
    arguments, lists their items, dicts their values, ``GraphNode`` objects
    their ``dependencies`` and ``TaskRef`` objects their ``key``.  Everything
    else is tested for membership in ``keys``.

    Parameters
    ----------
    keys
        Collection of candidate keys.
    tasks
        Iterable of tasks (or arbitrary values) to search.
    as_list
        When true, return a list (duplicates kept, in discovery order)
        instead of a set.

    Examples
    --------
    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> dsk = {'x': 1,
    ...        'y': (inc, 'x'),
    ...        'z': (add, 'x', 'y'),
    ...        'w': (inc, 'z'),
    ...        'a': (add, (inc, 'x'), 1)}

    >>> keys_in_tasks(dsk, ['x', 'y', 'j'])  # doctest: +SKIP
    {'x', 'y'}
    """
    from dask._task_spec import GraphNode

    found: list[Key] = []
    pending = tasks
    while pending:
        upcoming = []
        for node in pending:
            kind = type(node)
            if kind is tuple and node and callable(node[0]):
                # Legacy tuple task: only its arguments can hold keys.
                upcoming.extend(node[1:])
            elif kind is list:
                upcoming.extend(node)
            elif kind is dict:
                upcoming.extend(node.values())
            elif isinstance(node, GraphNode):
                upcoming.extend(node.dependencies)
            elif isinstance(node, TaskRef):
                upcoming.append(node.key)
            else:
                try:
                    if node in keys:
                        found.append(node)
                except TypeError:
                    # Unhashable values cannot be keys; skip them.
                    pass
        pending = upcoming
    if as_list:
        return found
    return set(found)
def iskey(key: object) -> bool:
    """Return True if the given object is a potential dask key; False otherwise.

    A key is an ``int``, ``float`` or ``str`` (exact types, so e.g. ``bool``
    does not qualify), or a plain tuple whose elements are all keys.

    See Also
    --------
    ishashable
    validate_key
    dask.typing.Key
    """
    kind = type(key)
    if kind is not tuple:
        return kind in {int, float, str}
    for part in cast(tuple, key):
        if not iskey(part):
            return False
    return True
def validate_key(key: object) -> None:
    """Validate the format of a dask key.

    Returns ``None`` when ``iskey(key)`` holds.

    Raises
    ------
    TypeError
        If ``key`` is not a valid key.  For a tuple, the message names the
        index of the offending element and the error raised for that element
        is chained as the cause.

    See Also
    --------
    iskey
    """
    if iskey(key):
        return

    if type(key) is tuple:
        # ``index`` is read by the f-string below, so it must outlive the loop.
        index = None
        parts = enumerate(cast(tuple, key))
        try:
            for index, part in parts:  # noqa: B007
                validate_key(part)
        except TypeError as err:
            message = (
                f"Composite key contains unexpected key type at {index=} ({key=!r})"
            )
            raise TypeError(message) from err
    raise TypeError(f"Unexpected key type {type(key)} ({key=!r})")
@overload
def get_dependencies(
    dsk: Graph,
    key: Key | None = ...,
    task: Key | NoDefault = ...,
    as_list: Literal[False] = ...,
) -> set[Key]: ...


@overload
def get_dependencies(
    dsk: Graph,
    key: Key | None,
    task: Key | NoDefault,
    as_list: Literal[True],
) -> list[Key]: ...


def get_dependencies(
    dsk: Graph,
    key: Key | None = None,
    task: Key | NoDefault = no_default,
    as_list: bool = False,
) -> set[Key] | list[Key]:
    """Get the immediate tasks on which this task depends

    Exactly one of ``key`` (looked up in ``dsk``) or ``task`` (used as is)
    should be supplied; ``key`` wins when both are given.

    Raises
    ------
    ValueError
        If neither ``key`` nor ``task`` is provided.

    Examples
    --------
    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> dsk = {'x': 1,
    ...        'y': (inc, 'x'),
    ...        'z': (add, 'x', 'y'),
    ...        'w': (inc, 'z'),
    ...        'a': (add, (inc, 'x'), 1)}

    >>> get_dependencies(dsk, 'x')
    set()

    >>> get_dependencies(dsk, 'y')
    {'x'}

    >>> get_dependencies(dsk, 'z')  # doctest: +SKIP
    {'x', 'y'}

    >>> get_dependencies(dsk, 'w')  # Only direct dependencies
    {'z'}

    >>> get_dependencies(dsk, 'a')  # Ignore non-keys
    {'x'}

    >>> get_dependencies(dsk, task=(inc, 'x'))  # provide tasks directly
    {'x'}
    """
    if key is None and task is no_default:
        raise ValueError("Provide either key or task")

    target = task if key is None else dsk[key]
    return keys_in_tasks(dsk, [target], as_list=as_list)
def get_deps(dsk: Graph) -> tuple[dict[Key, set[Key]], dict[Key, set[Key]]]:
    """Get dependencies and dependents from a dask graph

    Returns a pair ``(dependencies, dependents)``: the first maps every key
    to the keys it depends on, the second is the reversed mapping.

    >>> inc = lambda x: x + 1
    >>> dsk = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')}
    >>> dependencies, dependents = get_deps(dsk)
    >>> dependencies
    {'a': set(), 'b': {'a'}, 'c': {'b'}}
    >>> dependents  # doctest: +SKIP
    {'a': {'b'}, 'b': {'c'}, 'c': set()}
    """
    dependencies = {}
    for name, value in dsk.items():
        dependencies[name] = get_dependencies(dsk, task=value)
    return dependencies, reverse_dict(dependencies)
def flatten(seq, container=list):
    """Lazily flatten nested ``container`` instances found inside ``seq``.

    Only items that are instances of ``container`` (``list`` by default) are
    descended into; a string passed as ``seq`` is yielded whole.

    >>> list(flatten([1]))
    [1]

    >>> list(flatten([[1, 2], [1, 2]]))
    [1, 2, 1, 2]

    >>> list(flatten([[[1], [2]], [[1], [2]]]))
    [1, 2, 1, 2]

    >>> list(flatten(((1, 2), (1, 2))))  # Don't flatten tuples
    [(1, 2), (1, 2)]

    >>> list(flatten((1, 2, [3, 4])))  # support heterogeneous
    [1, 2, 3, 4]
    """
    if isinstance(seq, str):
        # A bare string is a single item, not a sequence of characters.
        yield seq
        return
    for element in seq:
        if not isinstance(element, container):
            yield element
            continue
        yield from flatten(element, container=container)
T_ = TypeVar("T_")


def reverse_dict(d: Mapping[T_, Iterable[T_]]) -> dict[T_, set[T_]]:
    """Invert a mapping of key -> values into value -> set of keys.

    Every key of ``d`` appears in the result, even when nothing points to
    it (it then maps to an empty set).

    >>> a, b, c = 'abc'
    >>> d = {a: [b, c], b: [c]}
    >>> reverse_dict(d)  # doctest: +SKIP
    {'a': set(), 'b': {'a'}, 'c': {'a', 'b'}}
    """
    inverted: dict[T_, set[T_]] = {}
    for source, targets in d.items():
        # Register the source first so it is present even without dependents.
        inverted.setdefault(source, set())
        for target in targets:
            inverted.setdefault(target, set()).add(source)
    return inverted
def subs(task, key, val):
    """Perform a substitution on a task

    Occurrences of ``key`` inside ``task`` are replaced by ``val``.  Nested
    tasks and lists are processed recursively.  At the top level a non-task
    value matches only when it has exactly the same type as ``key`` and
    compares equal; task arguments match on hash and equality.

    Examples
    --------
    >>> def inc(x):
    ...     return x + 1

    >>> subs((inc, 'x'), 'x', 1)  # doctest: +ELLIPSIS
    (<function inc at ...>, 1)
    """
    kind = type(task)
    if kind is tuple and task and callable(task[0]):
        lookup = {key}
        rebuilt = []
        for arg in task[1:]:
            arg_kind = type(arg)
            if arg_kind is tuple and arg and callable(arg[0]):
                rebuilt.append(subs(arg, key, val))
            elif arg_kind is list:
                rebuilt.append([subs(member, key, val) for member in arg])
            else:
                try:
                    matched = arg in lookup  # Hash and equality match
                except TypeError:  # not hashable
                    matched = False
                rebuilt.append(val if matched else arg)
        return task[:1] + tuple(rebuilt)

    # Not a task: compare against the key directly.  The comparison is
    # best-effort; any failure while comparing is treated as "no match".
    try:
        if kind is type(key) and task == key:
            return val
    except Exception:
        pass
    if kind is list:
        return [subs(member, key, val) for member in task]
    return task
def _toposort(dsk, keys=None, returncycle=False, dependencies=None):
    """Topologically sort ``dsk``, or find a cycle in it.

    Parameters
    ----------
    dsk
        The graph.  Iterated for the start keys when ``keys`` is None and
        wrapped in ``DependenciesMapping`` when ``dependencies`` is None.
    keys
        Start key(s) for the traversal.  ``None`` means every key of
        ``dsk``; a non-list value is treated as a single key.
    returncycle
        When false (the default) return the ordering and raise on a cycle.
        When true, return the cycle instead (or ``[]`` if there is none).
    dependencies
        Optional mapping from each key to the keys it depends on.

    Returns
    -------
    list
        With ``returncycle=False``: keys ordered so that each key appears
        after its dependencies.  With ``returncycle=True``: the nodes of one
        detected cycle, or an empty list.

    Raises
    ------
    RuntimeError
        If a cycle is found and ``returncycle`` is false.
    """
    # Stack-based depth-first search traversal.  This is based on Tarjan's
    # method for topological sorting (see wikipedia for pseudocode)
    if keys is None:
        keys = dsk
    elif not isinstance(keys, list):
        keys = [keys]
    if not returncycle:
        ordered = []

    # Nodes whose descendents have been completely explored.
    # These nodes are guaranteed to not be part of a cycle.
    completed = set()

    # All nodes that have been visited in the current traversal.  Because
    # we are doing depth-first search, going "deeper" should never result
    # in visiting a node that has already been seen.  The `seen` and
    # `completed` sets are mutually exclusive; it is okay to visit a node
    # that has already been added to `completed`.
    seen = set()

    if dependencies is None:
        dependencies = DependenciesMapping(dsk)

    for key in keys:
        if key in completed:
            continue
        nodes = [key]
        while nodes:
            # Keep current node on the stack until all descendants are visited
            cur = nodes[-1]
            if cur in completed:
                # Already fully traversed descendants of cur
                nodes.pop()
                continue
            seen.add(cur)

            # Add direct descendants of cur to nodes stack
            next_nodes = []
            for nxt in dependencies[cur]:
                if nxt not in completed:
                    if nxt in seen:
                        # Cycle detected!
                        # Let's report only the nodes that directly participate in the cycle.
                        # We use `priorities` below to greedily construct a short cycle.
                        # Shorter cycles may exist.
                        priorities = {}
                        prev = nodes[-1]
                        # Give priority to nodes that were seen earlier.
                        while nodes[-1] != nxt:
                            priorities[nodes.pop()] = -len(priorities)
                        priorities[nxt] = -len(priorities)
                        # We're going to get the cycle by walking backwards along dependents,
                        # so calculate dependents only for the nodes in play.
                        inplay = set(priorities)
                        dependents = reverse_dict(
                            {k: inplay.intersection(dependencies[k]) for k in inplay}
                        )
                        # Begin with the node that was seen twice and the node `prev` from
                        # which we detected the cycle.
                        cycle = [nodes.pop()]
                        cycle.append(prev)
                        while prev != cycle[0]:
                            # Greedily take a step that takes us closest to completing the cycle.
                            # This may not give us the shortest cycle, but we get *a* short cycle.
                            deps = dependents[cycle[-1]]
                            prev = min(deps, key=priorities.__getitem__)
                            cycle.append(prev)
                        cycle.reverse()

                        if returncycle:
                            return cycle
                        else:
                            cycle = "->".join(str(x) for x in cycle)
                            raise RuntimeError("Cycle detected in Dask: %s" % cycle)
                    next_nodes.append(nxt)

            if next_nodes:
                nodes.extend(next_nodes)
            else:
                # cur has no more descendants to explore, so we're done with it
                if not returncycle:
                    ordered.append(cur)
                completed.add(cur)
                seen.remove(cur)
                nodes.pop()
    if returncycle:
        # The whole traversal finished without meeting a cycle.
        return []
    return ordered
def toposort(dsk, dependencies=None):
    """Return a list of keys of dask sorted in topological order.

    ``dependencies`` is forwarded unchanged to ``_toposort``.
    """
    ordering = _toposort(dsk, dependencies=dependencies)
    return ordering
def getcycle(d, keys):
    """Return a list of nodes that form a cycle if Dask is not a DAG.

    Returns an empty list if no cycle is found.

    ``keys`` may be a single key or list of keys.

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> d = {'x': (inc, 'z'), 'y': (inc, 'x'), 'z': (inc, 'y')}
    >>> getcycle(d, 'x')
    ['x', 'z', 'y', 'x']

    See Also
    --------
    isdag
    """
    cycle = _toposort(d, keys, returncycle=True)
    return cycle
def isdag(d, keys):
    """Does Dask form a directed acyclic graph when calculating keys?

    ``keys`` may be a single key or list of keys.

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> isdag({'x': 0, 'y': (inc, 'x')}, 'y')
    True
    >>> isdag({'x': (inc, 'y'), 'y': (inc, 'x')}, 'y')
    False

    See Also
    --------
    getcycle
    """
    cycle = getcycle(d, keys)
    # An empty cycle list means the graph is acyclic for these keys.
    return not cycle
class literal:
    """A small serializable object to wrap literal values without copying

    Calling an instance returns the wrapped object itself.
    """

    __slots__ = ("data",)

    def __init__(self, data):
        self.data = data

    def __call__(self):
        return self.data

    def __reduce__(self):
        # Pickle as "call literal(data)" so the wrapper round-trips.
        constructor_args = (self.data,)
        return literal, constructor_args

    def __repr__(self):
        type_name = type(self.data).__name__
        return f"literal<type={type_name}>"
def quote(x):
    """Ensure that this value remains this value in a dask graph

    Some values in dask graph take on special meaning. Sometimes we want to
    ensure that our data is not interpreted but remains literal.

    Tasks, lists and dicts are wrapped as ``(literal(x),)``; any other value
    is returned unchanged.

    >>> add = lambda x, y: x + y
    >>> quote((add, 1, 2))
    (literal<type=tuple>,)
    """
    needs_wrapping = istask(x) or type(x) is list or type(x) is dict
    if not needs_wrapping:
        return x
    return (literal(x),)
def reshapelist(shape, seq):
    """Reshape iterator to nested shape

    Parameters
    ----------
    shape
        Non-empty sequence of dimension sizes.
    seq
        The items to arrange.  It must support ``len`` whenever ``shape``
        has more than one dimension.

    Returns
    -------
    list
        A flat list for a one-dimensional shape, otherwise a list of nested
        lists built by splitting ``seq`` into equally sized chunks.

    >>> reshapelist((2, 3), range(6))
    [[0, 1, 2], [3, 4, 5]]
    """
    if len(shape) == 1:
        return list(seq)
    # Floor division keeps the chunk size exact for any length; the previous
    # true division went through a float.  ``int`` keeps the result a plain
    # Python integer even if ``shape`` holds non-builtin numeric types.
    n = int(len(seq) // shape[0])
    return [reshapelist(shape[1:], part) for part in toolz.partition(n, seq)]