Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/delayed.py: 38%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import operator
4import types
5import uuid
6import warnings
7from collections.abc import Sequence
8from dataclasses import fields, is_dataclass, replace
9from functools import partial
11import toolz
12from tlz import concat, curry, merge
14from dask import base, config, utils
15from dask._expr import FinalizeCompute, ProhibitReuse, _ExprSequence
16from dask._task_spec import (
17 DataNode,
18 Dict,
19 GraphNode,
20 List,
21 Task,
22 TaskRef,
23 convert_legacy_graph,
24 fuse_linear_task_spec,
25)
26from dask.base import (
27 DaskMethodsMixin,
28 collections_to_expr,
29 is_dask_collection,
30 named_schedulers,
31 replace_name_in_key,
32)
33from dask.base import tokenize as _tokenize
34from dask.context import globalmethod
35from dask.core import flatten, quote
36from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
37from dask.typing import Graph, NestedKeys
38from dask.utils import (
39 OperatorMethodMixin,
40 apply,
41 ensure_dict,
42 funcname,
43 is_namedtuple_instance,
44 methodcaller,
45 unzip,
46)
# Names exported by ``from dask.delayed import *``.
__all__ = ["Delayed", "delayed"]


# Scheduler used for ``Delayed`` objects: the entry registered under
# "threads" when there is one, otherwise the "sync" scheduler. Note that the
# "sync" entry is looked up eagerly, so it must always be registered.
DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
def finalize(collection):
    """Wrap the finalized form of a dask collection in a ``Delayed``.

    Parameters
    ----------
    collection : dask collection
        Must satisfy ``is_dask_collection``.

    Returns
    -------
    Delayed
        Keyed ``"finalize-<token>"`` and backed by the collection's
        ``finalize_compute()`` expression.
    """
    assert is_dask_collection(collection)

    # The token is computed before the expression is built, as before.
    key = f"finalize-{tokenize(collection)}"
    return Delayed(key, collections_to_expr(collection).finalize_compute())
def _convert_dask_keys(keys: NestedKeys) -> List:
    """Turn a (possibly nested) list of keys into a ``List`` of ``TaskRef``.

    Nested lists are converted recursively; every other entry is wrapped in
    a ``TaskRef``.
    """
    assert isinstance(keys, list)
    converted: list[List | TaskRef] = [
        _convert_dask_keys(entry) if isinstance(entry, list) else TaskRef(entry)
        for entry in keys
    ]
    return List(*converted)
73def _get_partial(key, dct, default):
74 return dct.get(key, default)
def _finalize_args_collections(args, collections):
    """Optimize ``collections`` together and rewire ``args`` to the result.

    The collections are wrapped in one ``_ExprSequence`` and optimized as a
    unit. For every key of the optimized sequence a new ``Delayed`` is built
    whose graph is a single-layer ``HighLevelGraph`` holding the culled
    low-level graph for that key. Keys referenced by ``args`` are then
    substituted from the old keys to the new ones.

    Parameters
    ----------
    args : object exposing ``substitute(mapping)``
        The task structure that references the collections' keys.
    collections : iterable
        Objects whose ``__dask_keys__()[0]`` is the key referenced by ``args``.

    Returns
    -------
    args : the substituted task structure
    collections : tuple of Delayed
    """
    # Remember the pre-optimization key of every collection so that ``args``
    # can be re-pointed afterwards.
    old_keys = [c.__dask_keys__()[0] for c in collections]
    from dask._task_spec import cull

    collections = _ExprSequence(*collections).optimize()
    new_keys = collections.__dask_keys__()
    dsk = convert_legacy_graph(collections.__dask_graph__())
    annots = collections.__dask_annotations__()
    outcollections = []
    for k in new_keys:
        # Annotations are defined per HLG Layer but after this transformation
        # these no longer properly exist which is why __dask_annotations__
        # returns a fully materialized dictionary {annot: {key: value}}
        # Introducing a tombstone with a callable is the only way I found how we
        # could revert this transformation (not necessarily efficient but
        # well...)
        layer_annotations = {
            annot: partial(
                _get_partial, dct=key_val, default=collections._annotations_tombstone()
            )
            for annot, key_val in annots.items()
        }
        # One single-layer graph per output key, containing only the tasks
        # that key depends on.
        hlg = HighLevelGraph(
            {
                k[0]: MaterializedLayer(
                    cull(dsk, [k[0]]),
                    annotations=layer_annotations,
                )
            },
            dependencies={k[0]: set()},
        )
        outcollections.append(Delayed(k[0], hlg))
    collections = tuple(outcollections)
    # NOTE(review): ``old`` is a key while ``new`` is the container that is
    # indexed with ``[0]`` below, so ``old != new`` looks like it is always
    # true; ``old != new[0]`` may have been intended -- confirm.
    subs = {old: new[0] for old, new in zip(old_keys, new_keys) if old != new}
    args = args.substitute(subs)
    return args, collections
def unpack_collections(expr, _return_collections=True):
    """Normalize a python object and merge all sub-graphs.

    - Replace ``Delayed`` with their keys
    - Convert literals to things the schedulers can handle
    - Extract dask graphs from all enclosed values.

    Note, that the returned _task_ is not necessarily runnable and the caller is
    responsible to deal with the output types accordingly.

    The task is one of
    - `TaskRef` as a pointer to the collection returned in collections. This is
      not callable and should not be a top-level member of a dask task graph.
    - A runnable task (i.e. subclass `GraphNode`) which can be embedded
      directly into a task graph. This indicates that a dask collection was
      encountered on a deeper nesting level and this runnable task restores the
      input nesting with the computed dask collection replaced.
    - The unaltered object as provided if no dask collections are found.

    Parameters
    ----------
    expr : object
        The object to be normalized. This function knows how to handle
        dask collections, as well as most builtin python types.

    _return_collections: bool, optional
        Internal use only!
        The recursive calls on nested elements pass ``False``. In that mode
        dask collections are returned as finalized expressions instead of
        ``Delayed`` objects and the ``_finalize_args_collections`` step is
        skipped; it is applied once by the outermost call.


    Returns
    -------
    task : object
    collections : a tuple of collections

    Examples
    --------
    >>> import dask
    >>> a = delayed(1, 'a')
    >>> b = delayed(2, 'b')
    >>> task, collections = unpack_collections([a, b, 3])
    >>> task
    List((TaskRef('a'), TaskRef('b'), 3))
    >>> collections
    (Delayed('a'), Delayed('b'))

    >>> task, collections = unpack_collections({a: 1, b: 2})
    >>> task
    Dict(a: 1, b: 2)
    >>> collections
    (Delayed('a'), Delayed('b'))
    """
    if isinstance(expr, Delayed):
        if _return_collections:
            return TaskRef(expr._key), (expr,)
        else:
            # Nested call: hand back the finalized expression; the outermost
            # call turns the gathered expressions into collections.
            expr = collections_to_expr(expr).finalize_compute()
            (name,) = expr.__dask_keys__()
            return TaskRef(name), (expr,)

    # FIXME: Make this not trigger materialization
    # Currently this is checking with hasattr for __dask_graph__ which triggers
    # a materialization
    if base.is_dask_collection(expr):
        if _return_collections:
            expr2 = ProhibitReuse(collections_to_expr(expr).finalize_compute())
            finalized = expr2.optimize()
            # FIXME: Make this also go away
            dsk = finalized.__dask_graph__()
            keys = list(flatten(finalized.__dask_keys__()))
            if len(keys) > 1:
                # `finalize_compute` _should_ guarantee that we only have one key
                raise RuntimeError(
                    "Cannot unpack dask collections which don't finalize to a "
                    f"single key. Got {type(expr)} with {keys=}",
                )

            # Re-enter through the ``Delayed`` branch above.
            return unpack_collections(Delayed(keys[0], dsk))
        else:
            expr = collections_to_expr(expr).finalize_compute()
            (name,) = expr.__dask_keys__()
            return TaskRef(name), (expr,)

    # Materialize builtin list/tuple/set iterators so they are handled by the
    # container branch below.
    if type(expr) is type(iter(list())):
        expr = list(expr)
    elif type(expr) is type(iter(tuple())):
        expr = tuple(expr)
    elif type(expr) is type(iter(set())):
        expr = set(expr)

    # Exact type checks: subclasses of the builtin containers are not
    # traversed by the branches below.
    typ = type(expr)

    if typ in (list, tuple, set):
        args, collections = utils.unzip(
            (unpack_collections(e, _return_collections=False) for e in expr), 2
        )

        # De-duplicate by object identity while keeping first-seen order.
        collections = tuple(toolz.unique(toolz.concat(collections), key=id))
        # The List constructor also checks for futures
        args = List(*args)
        if not collections and not args.dependencies:
            # Nothing dask-related inside: return the input unchanged.
            return expr, ()
        if _return_collections:
            args, collections = _finalize_args_collections(args, collections)
        # Ensure output type matches input type
        if typ is not list:
            args = Task(None, typ, args)
        return args, collections

    if typ is dict:
        # Keys and values are unpacked separately (as lists) and zipped back
        # together into a ``Dict`` node.
        keyargs, kcollections = unpack_collections(
            [k for k in expr.keys()], _return_collections=False
        )
        valargs, valcollections = unpack_collections(
            [v for v in expr.values()], _return_collections=False
        )
        # NOTE(review): unlike the list/tuple/set branch, the concatenated
        # collections are not de-duplicated here.
        collections = kcollections + valcollections
        args = Dict([[k, v] for k, v in zip(keyargs, valargs)])
        if not collections and not args.dependencies:
            return expr, ()
        if _return_collections:
            args, collections = _finalize_args_collections(args, collections)
        return args, collections

    if typ is slice:
        args, collections = unpack_collections(
            [expr.start, expr.stop, expr.step], _return_collections=False
        )
        if not collections and not isinstance(args, GraphNode):
            return expr, ()

        if _return_collections:
            args, collections = _finalize_args_collections(args, collections)
        # Rebuild the slice at compute time from the unpacked components.
        return Task(None, apply, slice, args), collections

    if is_dataclass(expr):
        args, collections = unpack_collections(
            [
                [f.name, getattr(expr, f.name)]
                for f in fields(expr)
                if hasattr(expr, f.name)  # if init=False, field might not exist
            ],
            _return_collections=False,
        )
        if not collections and not isinstance(args, GraphNode):
            return expr, ()

        if _return_collections:
            args, collections = _finalize_args_collections(args, collections)
        # Dry run: check up front that the dataclass can be rebuilt from its
        # field values, so that failures surface here with a helpful message.
        try:
            _fields = {
                f.name: getattr(expr, f.name)
                for f in fields(expr)
                if hasattr(expr, f.name)
            }
            replace(expr, **_fields)
        except (TypeError, ValueError) as e:
            if isinstance(e, ValueError) or "is declared with init=False" in str(e):
                raise ValueError(
                    f"Failed to unpack {typ} instance. "
                    "Note that using fields with `init=False` are not supported."
                ) from e
            else:
                raise TypeError(
                    f"Failed to unpack {typ} instance. "
                    "Note that using a custom __init__ is not supported."
                ) from e
        # Rebuild as ``typ(**dict(args))`` at compute time.
        return Task(None, apply, typ, (), Task(None, dict, args)), collections

    if utils.is_namedtuple_instance(expr):
        args, collections = unpack_collections(
            tuple(v for v in expr), _return_collections=False
        )
        if not collections:
            return expr, ()
        if _return_collections:
            args, collections = _finalize_args_collections(args, collections)
        return Task(None, _reconstruct_namedtuple, typ, args), collections

    # Any other object is treated as an opaque literal.
    return expr, ()
296def _reconstruct_namedtuple(typ, fields):
297 return typ(*fields)
def to_task_dask(expr):
    """Normalize a python object and merge all sub-graphs.

    Deprecated in favor of ``unpack_collections``; every call (including the
    recursive calls made for nested containers) emits a warning.

    - Replace ``Delayed`` with their keys
    - Convert literals to things the schedulers can handle
    - Extract dask graphs from all enclosed values

    Parameters
    ----------
    expr : object
        The object to be normalized. This function knows how to handle
        ``Delayed``s, as well as most builtin python types.

    Returns
    -------
    task : normalized task to be run
    dask : a merged dask graph that forms the dag for this task

    Examples
    --------
    >>> import dask
    >>> a = delayed(1, 'a')
    >>> b = delayed(2, 'b')
    >>> task, dask = to_task_dask([a, b, 3])  # doctest: +SKIP
    >>> task  # doctest: +SKIP
    ['a', 'b', 3]
    >>> dict(dask)  # doctest: +SKIP
    {'a': 1, 'b': 2}

    >>> task, dask = to_task_dask({a: 1, b: 2})  # doctest: +SKIP
    >>> task  # doctest: +SKIP
    (dict, [['a', 1], ['b', 2]])
    >>> dict(dask)  # doctest: +SKIP
    {'a': 1, 'b': 2}
    """
    # The message previously named a non-existent "to_dask_dask" function.
    warnings.warn(
        "The dask.delayed.to_task_dask function has been "
        "Deprecated in favor of unpack_collections",
        stacklevel=2,
    )

    if isinstance(expr, Delayed):
        return expr.key, expr.dask

    if is_dask_collection(expr):
        # Other collections are finalized and optimized down to one key.
        expr = collections_to_expr(expr)
        expr = FinalizeCompute(expr)
        expr = expr.optimize()
        (name,) = expr.__dask_keys__()
        return name, expr.__dask_graph__()

    # Materialize builtin list/tuple/set iterators so they are handled by the
    # container branch below.
    if type(expr) is type(iter(list())):
        expr = list(expr)
    elif type(expr) is type(iter(tuple())):
        expr = tuple(expr)
    elif type(expr) is type(iter(set())):
        expr = set(expr)
    # Exact type checks: container subclasses are not traversed.
    typ = type(expr)

    if typ in (list, tuple, set):
        args, dasks = unzip((to_task_dask(e) for e in expr), 2)
        args = list(args)
        dsk = merge(dasks)
        # Ensure output type matches input type
        return (args, dsk) if typ is list else ((typ, args), dsk)

    if typ is dict:
        args, dsk = to_task_dask([[k, v] for k, v in expr.items()])
        return (dict, args), dsk

    if is_dataclass(expr):
        args, dsk = to_task_dask(
            [
                [f.name, getattr(expr, f.name)]
                for f in fields(expr)
                if hasattr(expr, f.name)  # if init=False, field might not exist
            ]
        )

        # Legacy task tuple that rebuilds the dataclass as ``typ(**dict(args))``.
        return (apply, typ, (), (dict, args)), dsk

    if is_namedtuple_instance(expr):
        args, dsk = to_task_dask([v for v in expr])
        return (typ, *args), dsk

    if typ is slice:
        args, dsk = to_task_dask([expr.start, expr.stop, expr.step])
        return (slice,) + tuple(args), dsk

    # Anything else is an opaque literal with an empty graph.
    return expr, {}
def tokenize(*args, pure=None, **kwargs):
    """Produce the name token for a task.

    Parameters
    ----------
    args : object
        Python objects that summarize the task.
    pure : boolean, optional
        When true, the arguments are passed to ``dask.base.tokenize``. When
        false, a fresh random UUID string is returned and the arguments are
        ignored. When ``None`` (the default), the ``delayed_pure`` config
        value is used, which itself falls back to ``False``.
    **kwargs
        Forwarded to ``dask.base.tokenize`` in the pure case.

    Returns
    -------
    str
    """
    use_pure = config.get("delayed_pure", False) if pure is None else pure

    if not use_pure:
        return str(uuid.uuid4())
    return _tokenize(*args, **kwargs)
@curry
def delayed(obj, name=None, pure=None, nout=None, traverse=True):
    """Wraps a function or object to produce a ``Delayed``.

    ``Delayed`` objects act as proxies for the object they wrap, but all
    operations on them are done lazily by building up a dask graph internally.

    Parameters
    ----------
    obj : object
        The function or object to wrap
    name : Dask key, optional
        The key to use in the underlying graph for the wrapped object. Defaults
        to hashing content. Note that this only affects the name of the object
        wrapped by this call to delayed, and *not* the output of delayed
        function calls - for that use ``dask_key_name=`` as described below.

        .. note::

            Because this ``name`` is used as the key in task graphs, you should
            ensure that it uniquely identifies ``obj``. If you'd like to provide
            a descriptive name that is still unique, combine the descriptive name
            with :func:`dask.base.tokenize` of ``obj``. See
            :ref:`graphs` for more.

    pure : bool, optional
        Indicates whether calling the resulting ``Delayed`` object is a pure
        operation. If True, arguments to the call are hashed to produce
        deterministic keys. If not provided, the default is to check the global
        ``delayed_pure`` setting, and fallback to ``False`` if unset.
    nout : int, optional
        The number of outputs returned from calling the resulting ``Delayed``
        object. If provided, the ``Delayed`` output of the call can be iterated
        into ``nout`` objects, allowing for unpacking of results. By default
        iteration over ``Delayed`` objects will error. Note, that ``nout=1``
        expects ``obj`` to return a tuple of length 1, and consequently for
        ``nout=0``, ``obj`` should return an empty tuple.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``delayed``. For large collections this can be
        expensive. If ``obj`` doesn't contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.

    Raises
    ------
    ValueError
        If ``nout`` is neither ``None`` nor a non-negative ``int``.

    Examples
    --------
    Apply to functions to delay execution:

    >>> from dask import delayed
    >>> def inc(x):
    ...     return x + 1

    >>> inc(10)
    11

    >>> x = delayed(inc, pure=True)(10)
    >>> type(x) == Delayed
    True
    >>> x.compute()
    11

    Can be used as a decorator:

    >>> @delayed(pure=True)
    ... def add(a, b):
    ...     return a + b
    >>> add(1, 2).compute()
    3

    ``delayed`` also accepts an optional keyword ``pure``. If False, then
    subsequent calls will always produce a different ``Delayed``. This is
    useful for non-pure functions (such as ``time`` or ``random``).

    >>> from random import random
    >>> out1 = delayed(random, pure=False)()
    >>> out2 = delayed(random, pure=False)()
    >>> out1.key == out2.key
    False

    If you know a function is pure (output only depends on the input, with no
    global state), then you can set ``pure=True``. This will attempt to apply a
    consistent name to the output, but will fallback on the same behavior of
    ``pure=False`` if this fails.

    >>> @delayed(pure=True)
    ... def add(a, b):
    ...     return a + b
    >>> out1 = add(1, 2)
    >>> out2 = add(1, 2)
    >>> out1.key == out2.key
    True

    Instead of setting ``pure`` as a property of the callable, you can also set
    it contextually using the ``delayed_pure`` setting. Note that this
    influences the *call* and not the *creation* of the callable:

    >>> @delayed
    ... def mul(a, b):
    ...     return a * b
    >>> import dask
    >>> with dask.config.set(delayed_pure=True):
    ...     print(mul(1, 2).key == mul(1, 2).key)
    True
    >>> with dask.config.set(delayed_pure=False):
    ...     print(mul(1, 2).key == mul(1, 2).key)
    False

    The key name of the result of calling a delayed object is determined by
    hashing the arguments by default. To explicitly set the name, you can use
    the ``dask_key_name`` keyword when calling the function:

    >>> add(1, 2)   # doctest: +SKIP
    Delayed('add-3dce7c56edd1ac2614add714086e950f')
    >>> add(1, 2, dask_key_name='three')
    Delayed('three')

    Note that objects with the same key name are assumed to have the same
    result. If you set the names explicitly you should make sure your key names
    are different for different results.

    >>> add(1, 2, dask_key_name='three')
    Delayed('three')
    >>> add(2, 1, dask_key_name='three')
    Delayed('three')
    >>> add(2, 2, dask_key_name='four')
    Delayed('four')

    ``delayed`` can also be applied to objects to make operations on them lazy:

    >>> a = delayed([1, 2, 3])
    >>> isinstance(a, Delayed)
    True
    >>> a.compute()
    [1, 2, 3]

    The key name of a delayed object is hashed by default if ``pure=True`` or
    is generated randomly if ``pure=False`` (default). To explicitly set the
    name, you can use the ``name`` keyword. To ensure that the key is unique
    you should include the tokenized value as well, or otherwise ensure that
    it's unique:

    >>> from dask.base import tokenize
    >>> data = [1, 2, 3]
    >>> a = delayed(data, name='mylist-' + tokenize(data))
    >>> a  # doctest: +SKIP
    Delayed('mylist-55af65871cb378a4fa6de1660c3e8fb7')

    Delayed results act as a proxy to the underlying object. Many operators
    are supported:

    >>> (a + [1, 2]).compute()
    [1, 2, 3, 1, 2]
    >>> a[1].compute()
    2

    Method and attribute access also works:

    >>> a.count(2).compute()
    1

    Note that if a method doesn't exist, no error will be thrown until runtime:

    >>> res = a.not_a_real_method()  # doctest: +SKIP
    >>> res.compute()  # doctest: +SKIP
    AttributeError("'list' object has no attribute 'not_a_real_method'")

    "Magic" methods (e.g. operators and attribute access) are assumed to be
    pure, meaning that subsequent calls must return the same results. This
    behavior is not overridable through the ``delayed`` call, but can be
    modified using other ways as described below.

    To invoke an impure attribute or operator, you'd need to use it in a
    delayed function with ``pure=False``:

    >>> class Incrementer:
    ...     def __init__(self):
    ...         self._n = 0
    ...     @property
    ...     def n(self):
    ...         self._n += 1
    ...         return self._n
    ...
    >>> x = delayed(Incrementer())
    >>> x.n.key == x.n.key
    True
    >>> get_n = delayed(lambda x: x.n, pure=False)
    >>> get_n(x).key == get_n(x).key
    False

    In contrast, methods are assumed to be impure by default, meaning that
    subsequent calls may return different results. To assume purity, set
    ``pure=True``. This allows sharing of any intermediate values.

    >>> a.count(2, pure=True).key == a.count(2, pure=True).key
    True

    As with function calls, method calls also respect the global
    ``delayed_pure`` setting and support the ``dask_key_name`` keyword:

    >>> a.count(2, dask_key_name="count_2")
    Delayed('count_2')
    >>> import dask
    >>> with dask.config.set(delayed_pure=True):
    ...     print(a.count(2).key == a.count(2).key)
    True
    """
    # Already lazy: nothing to wrap.
    if isinstance(obj, Delayed):
        return obj

    if is_dask_collection(obj) or traverse:
        task, collections = unpack_collections(obj)
    else:
        task = quote(obj)
        collections = set()

    if not (nout is None or (type(nout) is int and nout >= 0)):
        # An f-string is used instead of ``"%s" % nout``: with %-formatting a
        # tuple ``nout`` is unpacked as the argument list and raises TypeError
        # instead of the intended ValueError.
        raise ValueError(f"nout must be None or a non-negative integer, got {nout}")
    if task is obj:
        # ``obj`` came back unchanged, so it becomes a leaf that is stored in
        # the graph as-is.
        if isinstance(obj, TaskRef):
            # A TaskRef always keeps its own key, even if ``name`` was given.
            name = obj.key
        elif not name:
            try:
                prefix = obj.__name__
            except AttributeError:
                prefix = type(obj).__name__
            token = tokenize(obj, nout, pure=pure)
            name = f"{prefix}-{token}"
        return DelayedLeaf(obj, name, pure=pure, nout=nout)
    else:
        # ``obj`` was rewritten into a task (it contains dask objects or had
        # to be quoted): store that task under ``name``.
        if not name:
            name = f"{type(obj).__name__}-{tokenize(task, pure=pure)}"
        layer = {name: task}
        if isinstance(task, GraphNode):
            task.key = name
        graph = HighLevelGraph.from_collections(name, layer, dependencies=collections)
        return Delayed(name, graph, nout)
650def _swap(method, self, other):
651 return method(other, self)
def right(method):
    """Return the reflected ("right-hand") variant of a binary operator.

    The result is a ``functools.partial`` of ``_swap`` bound to ``method``,
    so calling it as ``f(a, b)`` evaluates ``method(b, a)``.
    """
    reflected = partial(_swap, method)
    return reflected
def optimize(dsk, keys, **kwargs):
    """Optimize the graph of ``Delayed`` objects for the requested keys.

    Parameters
    ----------
    dsk : graph
        Either a ``HighLevelGraph`` or a low-level graph.
    keys : key or list/set of keys
        A single key is wrapped in a list.
    **kwargs
        Forwarded to ``fuse_linear_task_spec`` when fusion is enabled.

    Returns
    -------
    HighLevelGraph
        The graph culled to ``keys``.
    """
    key_seq = keys if isinstance(keys, (list, set)) else [keys]

    graph = dsk
    if config.get("optimization.fuse.delayed"):
        # Fusion works on a plain dict of tasks.
        graph = fuse_linear_task_spec(ensure_dict(graph), key_seq, **kwargs)

    if not isinstance(graph, HighLevelGraph):
        # Low-level graphs get a single layer named after their ``id``.
        graph = HighLevelGraph.from_collections(id(graph), graph, dependencies=())
    return graph.cull(set(flatten(key_seq)))
class Delayed(DaskMethodsMixin, OperatorMethodMixin):
    """Lazy stand-in for the value stored under a single key of a dask graph.

    An instance stores the key, the graph that produces it, an optional
    length (enabling ``len`` and iteration) and the name of the graph layer
    the key belongs to.
    """

    __slots__ = ("_key", "_dask", "_length", "_layer")

    def __init__(self, key, dsk, length=None, layer=None):
        """Store key, graph and length, and validate the layer name.

        Raises ``ValueError`` if ``dsk`` is a ``HighLevelGraph`` that has no
        layer called ``layer`` (or ``key`` when ``layer`` is falsy).
        """
        self._key = key
        self._dask = dsk
        self._length = length

        # NOTE: ``layer`` is passed by ``to_delayed`` in other collections;
        # normal Delayed use leaves it unset and falls back to the key.
        layer_name = layer or key
        self._layer = layer_name
        if isinstance(dsk, HighLevelGraph) and layer_name not in dsk.layers:
            raise ValueError(
                f"Layer {self._layer} not in the HighLevelGraph's layers: {list(dsk.layers)}"
            )

    @property
    def key(self):
        """The graph key this object stands for."""
        return self._key

    @property
    def dask(self):
        """The graph that computes :attr:`key`."""
        return self._dask

    def __dask_graph__(self) -> Graph:
        return self.dask

    def __dask_keys__(self) -> NestedKeys:
        return [self.key]

    def __dask_layers__(self) -> Sequence[str]:
        return (self._layer,)

    def __dask_tokenize__(self):
        return self.key

    __dask_scheduler__ = staticmethod(DEFAULT_GET)
    __dask_optimize__ = globalmethod(optimize, key="delayed_optimize")

    def __dask_postcompute__(self):
        return single_key, ()

    def __dask_postpersist__(self):
        return self._rebuild, ()

    def _rebuild(self, dsk, *, rename=None):
        """Create a new ``Delayed`` on ``dsk``, optionally renaming the key."""
        if rename:
            new_key = replace_name_in_key(self.key, rename)
        else:
            new_key = self.key
        # FIXME Delayed is currently the only collection type that supports both high- and low-level graphs.
        # The HLG output of `optimize` will have a layer name that doesn't match `key`.
        # Remove this when Delayed is HLG-only (because `optimize` will only be passed HLGs, so it won't have
        # to generate random layer names).
        single_layer = isinstance(dsk, HighLevelGraph) and len(dsk.layers) == 1
        layer = next(iter(dsk.layers)) if single_layer else None
        return Delayed(new_key, dsk, self._length, layer=layer)

    def __repr__(self):
        return f"Delayed({self.key!r})"

    def __hash__(self):
        return hash(self.key)

    def __dir__(self):
        return dir(type(self))

    def __getattr__(self, attr):
        """Return a lazy attribute access for public attribute names."""
        # Underscore-prefixed names are never proxied.
        if attr.startswith("_"):
            raise AttributeError(f"Attribute {attr} not found")

        if attr == "visualise":
            # added to warn users in case of spelling error
            # for more details: https://github.com/dask/dask/issues/5721
            warnings.warn(
                "dask.delayed objects have no `visualise` method. "
                "Perhaps you meant `visualize`?"
            )

        return DelayedAttr(self, attr)

    def __setattr__(self, attr, val):
        try:
            object.__setattr__(self, attr, val)
        except AttributeError:
            # attr is neither in type(self).__slots__ nor in the __slots__ of any of its
            # parent classes, and all the parent classes define __slots__ too.
            # This last bit needs to be unit tested: if any of the parent classes omit
            # the __slots__ declaration, self will gain a __dict__ and this branch will
            # become unreachable.
            raise TypeError("Delayed objects are immutable")

    def __setitem__(self, index, val):
        raise TypeError("Delayed objects are immutable")

    def __iter__(self):
        """Yield ``self[i]`` for each index below the declared length."""
        length = self._length
        if length is None:
            raise TypeError("Delayed objects of unspecified length are not iterable")
        for index in range(length):
            yield self[index]

    def __len__(self):
        length = self._length
        if length is None:
            raise TypeError("Delayed objects of unspecified length have no len()")
        return length

    def __call__(self, *args, pure=None, dask_key_name=None, **kwargs):
        """Lazily call the wrapped value with the given arguments."""
        func = delayed(apply, pure=pure)
        # ``dask_key_name`` is only forwarded when it was actually supplied.
        extra = {} if dask_key_name is None else {"dask_key_name": dask_key_name}
        return func(self, args, kwargs, **extra)

    def __bool__(self):
        raise TypeError("Truth of Delayed objects is not supported")

    __nonzero__ = __bool__

    def __get__(self, instance, cls):
        # Descriptor protocol: bind to ``instance`` when accessed through one.
        return self if instance is None else types.MethodType(self, instance)

    @classmethod
    def _get_binary_operator(cls, op, inv=False):
        """Return a callable applying ``op`` lazily (reflected when ``inv``)."""
        target = right(op) if inv else op
        method = delayed(target, pure=True)
        return lambda *args, **kwargs: method(*args, **kwargs)

    _get_unary_operator = _get_binary_operator
def call_function(func, func_token, args, kwargs, pure=None, nout=None):
    """Build the ``Delayed`` that represents ``func(*args, **kwargs)``.

    Parameters
    ----------
    func : callable
        The function to call at compute time.
    func_token : object
        Stands in for ``func`` when the key token is generated.
    args : sequence
        Positional arguments; dask collections inside them are unpacked.
    kwargs : dict
        Keyword arguments. ``dask_key_name`` and ``pure`` are popped from
        this dict (it is mutated) and are not passed on to ``func``.
    pure : bool, optional
        Default purity, overridden by a ``pure`` entry in ``kwargs``.
    nout : int, optional
        Length of the resulting ``Delayed``.

    Returns
    -------
    Delayed
    """
    explicit_key = kwargs.pop("dask_key_name", None)
    pure = kwargs.pop("pure", pure)

    if explicit_key is not None:
        name = explicit_key
    else:
        name = f"{funcname(func)}-{tokenize(func_token, *args, pure=pure, **kwargs)}"

    # Unpack positional arguments one by one and gather their collections.
    task_args, nested = unzip(map(unpack_collections, args), 2)
    dependencies = list(concat(nested))

    # Keyword arguments are unpacked together, as one dict.
    task_kwargs, kw_collections = unpack_collections(kwargs)
    dependencies.extend(kw_collections)
    task = Task(name, func, *task_args, **task_kwargs)

    graph = HighLevelGraph.from_collections(
        name, {name: task}, dependencies=dependencies
    )
    return Delayed(name, graph, length=nout)
class DelayedLeaf(Delayed):
    # A ``Delayed`` that wraps a concrete object directly. Its graph is built
    # on demand and calling it schedules a call of the wrapped object.
    # (Documented with comments because ``__doc__`` is a property below.)
    __slots__ = ("_obj", "_pure", "_nout")

    def __init__(self, obj, key, pure=None, nout=None):
        # No graph is stored; see the ``dask`` property.
        super().__init__(key, None, length=nout)
        self._obj = obj
        self._pure = pure
        self._nout = nout

    @property
    def dask(self):
        # Graph nodes and task references are stored as-is; any other object
        # is wrapped in a ``DataNode`` under this leaf's key.
        wrapped = self._obj
        if isinstance(wrapped, (TaskRef, GraphNode)):
            node = wrapped
        else:
            node = DataNode(self._key, wrapped)
        return HighLevelGraph.from_collections(
            self._key, {self._key: node}, dependencies=()
        )

    def __call__(self, *args, **kwargs):
        # The leaf's key serves as the function token for the call.
        return call_function(
            self._obj, self._key, args, kwargs, pure=self._pure, nout=self._nout
        )

    # The properties below forward introspection to the wrapped object.
    @property
    def __name__(self):
        return self._obj.__name__

    @property
    def __doc__(self):
        return self._obj.__doc__

    @property
    def __wrapped__(self):
        return self._obj
class DelayedAttr(Delayed):
    """Lazy attribute access ``obj.attr`` on another ``Delayed``.

    Calling an instance schedules a method call of ``attr`` on ``obj``.
    """

    __slots__ = ("_obj", "_attr")

    def __init__(self, obj, attr):
        # The key is always tokenized deterministically (``pure=True``).
        token = tokenize(obj, attr, pure=True)
        super().__init__(f"getattr-{token}", None)
        self._obj = obj
        self._attr = attr

    def __getattr__(self, attr):
        # Calling np.dtype(dask.delayed(...)) used to result in a segfault, as
        # numpy recursively tries to get `dtype` from the object. This is
        # likely a bug in numpy. For now, we can do a dumb check for whether
        # `x.dtype().dtype()` is called (which shouldn't ever show up in real
        # code). See https://github.com/dask/dask/pull/4374#issuecomment-454381465
        if self._attr == "dtype" and attr == "dtype":
            raise AttributeError("Attribute dtype not found")
        return super().__getattr__(attr)

    @property
    def dask(self):
        """Single-layer graph that runs ``getattr`` on the parent's key."""
        getattr_task = (getattr, self._obj._key, self._attr)
        return HighLevelGraph.from_collections(
            self._key, {self._key: getattr_task}, dependencies=[self._obj]
        )

    def __call__(self, *args, **kwargs):
        # The parent object is passed as the first positional argument.
        call_args = (self._obj,) + args
        return call_function(methodcaller(self._attr), self._attr, call_args, kwargs)
# Attach lazy versions of the standard operators to ``Delayed`` through
# ``_bind_operator``. The loop variable ``op`` stays defined at module level
# after the loop.
for op in (
    # unary
    operator.abs,
    operator.neg,
    operator.pos,
    operator.invert,
    # arithmetic
    operator.add,
    operator.sub,
    operator.mul,
    operator.floordiv,
    operator.truediv,
    operator.mod,
    operator.pow,
    # bitwise
    operator.and_,
    operator.or_,
    operator.xor,
    operator.lshift,
    operator.rshift,
    # comparison
    operator.eq,
    operator.ge,
    operator.gt,
    operator.ne,
    operator.le,
    operator.lt,
    # indexing
    operator.getitem,
):
    Delayed._bind_operator(op)


# ``matmul`` is bound separately; an AttributeError (e.g. ``operator.matmul``
# missing) is deliberately ignored.
try:
    Delayed._bind_operator(operator.matmul)
except AttributeError:
    pass
def single_key(seq):
    """Return the first element of ``seq``, a list of keys.

    Used as the post-compute step for ``Delayed``, whose key list holds
    exactly one key.
    """
    first = seq[0]
    return first