from __future__ import annotations

import functools
import os
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Literal, TypeAlias

import toolz

import dask
from dask._task_spec import Task, convert_legacy_graph
from dask.tokenize import _tokenize_deterministic
from dask.typing import Key
from dask.utils import ensure_dict, funcname, import_required

if TYPE_CHECKING:
    from dask.highlevelgraph import HighLevelGraph

OptimizerStage: TypeAlias = Literal[
    "logical",
    "simplified-logical",
    "tuned-logical",
    "physical",
    "simplified-physical",
    "fused",
]


def _unpack_collections(o):
    from dask.delayed import Delayed

    if isinstance(o, Expr):
        return o

    if hasattr(o, "expr") and not isinstance(o, Delayed):
        return o.expr
    else:
        return o


class Expr:
    _parameters: list[str] = []
    _defaults: dict[str, Any] = {}

    _pickle_functools_cache: bool = True

    operands: list

    _determ_token: str | None

    def __new__(cls, *args, _determ_token=None, **kwargs):
        operands = list(args)
        for parameter in cls._parameters[len(operands) :]:
            try:
                operands.append(kwargs.pop(parameter))
            except KeyError:
                operands.append(cls._defaults[parameter])
        assert not kwargs, kwargs
        inst = object.__new__(cls)

        inst._determ_token = _determ_token
        inst.operands = [_unpack_collections(o) for o in operands]
        # This is typically cached. Make sure the cache is populated by
        # calling it once.
        inst._name
        return inst
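
    # Illustrative sketch (not part of the original module): subclasses bind
    # positional and keyword arguments to ``_parameters``, falling back to
    # ``_defaults`` for anything omitted. ``Add``, ``left`` and ``right``
    # below are hypothetical names.
    #
    #     class Add(Expr):
    #         _parameters = ["left", "right", "fill_value"]
    #         _defaults = {"fill_value": None}
    #
    #     expr = Add(left, right)       # fill_value falls back to None
    #     expr.operand("fill_value")    # -> None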

    def _tune_down(self):
        return None

    def _tune_up(self, parent):
        return None

    def finalize_compute(self):
        return self

    def _operands_for_repr(self):
        return [f"{param}={op!r}" for param, op in zip(self._parameters, self.operands)]

    def __str__(self):
        s = ", ".join(self._operands_for_repr())
        return f"{type(self).__name__}({s})"

    def __repr__(self):
        return str(self)

    def _tree_repr_argument_construction(self, i, op, header):
        try:
            param = self._parameters[i]
            default = self._defaults[param]
        except (IndexError, KeyError):
            param = self._parameters[i] if i < len(self._parameters) else ""
            default = "--no-default--"

        if repr(op) != repr(default):
            if param:
                header += f" {param}={op!r}"
            else:
                header += repr(op)
        return header

    def _tree_repr_lines(self, indent=0, recursive=True):
        # Return a list of lines so that tree_repr can join them and pprint
        # can iterate over them line by line.
        return [" " * indent + repr(self)]

    def tree_repr(self):
        return os.linesep.join(self._tree_repr_lines())

    def analyze(self, filename: str | None = None, format: str | None = None) -> None:
        from dask.dataframe.dask_expr._expr import Expr as DFExpr
        from dask.dataframe.dask_expr.diagnostics import analyze

        if not isinstance(self, DFExpr):
            raise TypeError(
                "analyze is only supported for dask.dataframe.Expr objects."
            )
        return analyze(self, filename=filename, format=format)

    def explain(
        self, stage: OptimizerStage = "fused", format: str | None = None
    ) -> None:
        from dask.dataframe.dask_expr.diagnostics import explain

        return explain(self, stage, format)

    def pprint(self):
        for line in self._tree_repr_lines():
            print(line)

    def __hash__(self):
        return hash(self._name)

    def __dask_tokenize__(self):
        if not self._determ_token:
            # If the subclass does not implement a __dask_tokenize__ we'll
            # want to tokenize all operands.
            # Note how this differs from the implementation of
            # Expr.deterministic_token
            self._determ_token = _tokenize_deterministic(type(self), *self.operands)
        return self._determ_token

    def __dask_keys__(self):
        """The keys for this expression

        This is used to determine the keys of the output collection
        when this expression is computed.

        Returns
        -------
        keys: list
            The keys for this expression
        """
        return [(self._name, i) for i in range(self.npartitions)]

    @staticmethod
    def _reconstruct(*args):
        typ, *operands, token, cache = args
        inst = typ(*operands, _determ_token=token)
        for k, v in cache.items():
            inst.__dict__[k] = v
        return inst

    def __reduce__(self):
        if dask.config.get("dask-expr-no-serialize", False):
            raise RuntimeError(f"Serializing a {type(self)} object")
        cache = {}
        if type(self)._pickle_functools_cache:
            for k, v in type(self).__dict__.items():
                if isinstance(v, functools.cached_property) and k in self.__dict__:
                    cache[k] = getattr(self, k)

        return Expr._reconstruct, (
            type(self),
            *self.operands,
            self.deterministic_token,
            cache,
        )

    def _depth(self, cache=None):
        """Depth of the expression tree

        Returns
        -------
        depth: int
        """
        if cache is None:
            cache = {}
        if not self.dependencies():
            return 1
        else:
            result = []
            for expr in self.dependencies():
                if expr._name in cache:
                    result.append(cache[expr._name])
                else:
                    result.append(expr._depth(cache) + 1)
                    cache[expr._name] = result[-1]
            return max(result)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ["operands", "_determ_token"]:
            object.__setattr__(self, name, value)
            return
        try:
            params = type(self)._parameters
            operands = object.__getattribute__(self, "operands")
            operands[params.index(name)] = value
        except ValueError:
            raise AttributeError(
                f"{type(self).__name__} object has no attribute {name}"
            )

    def operand(self, key):
        # Access an operand unambiguously
        # (e.g. if the key is reserved by a method/property)
        return self.operands[type(self)._parameters.index(key)]

    def dependencies(self):
        # Dependencies are `Expr` operands only
        return [operand for operand in self.operands if isinstance(operand, Expr)]

    def _task(self, key: Key, index: int) -> Task:
        """The task for the i'th partition

        Parameters
        ----------
        key:
            The key for this output partition's task, i.e.
            ``self.__dask_keys__()[index]``
        index:
            The index of the partition of this dataframe

        Examples
        --------
        >>> class Add(Expr):
        ...     def _task(self, key, index):
        ...         return Task(
        ...             key,
        ...             operator.add,
        ...             TaskRef((self.left._name, index)),
        ...             TaskRef((self.right._name, index)),
        ...         )

        Returns
        -------
        task:
            The Dask task to compute this partition

        See Also
        --------
        Expr._layer
        """
        raise NotImplementedError(
            "Expressions should define either _layer (full dictionary) or _task"
            f" (single task). This expression {type(self)} defines neither"
        )

    def _layer(self) -> dict:
        """The graph layer added by this expression.

        Simple expressions that apply one task per partition can choose to only
        implement `Expr._task` instead.

        Examples
        --------
        >>> class Add(Expr):
        ...     def _layer(self):
        ...         return {
        ...             name: Task(
        ...                 name,
        ...                 operator.add,
        ...                 TaskRef((self.left._name, i)),
        ...                 TaskRef((self.right._name, i)),
        ...             )
        ...             for i, name in enumerate(self.__dask_keys__())
        ...         }

        Returns
        -------
        layer: dict
            The Dask task graph added by this expression

        See Also
        --------
        Expr._task
        Expr.__dask_graph__
        """
        return {
            (self._name, i): self._task((self._name, i), i)
            for i in range(self.npartitions)
        }

    def rewrite(self, kind: str, rewritten):
        """Rewrite an expression

        This leverages the ``._{kind}_down`` and ``._{kind}_up``
        methods defined on each class

        Returns
        -------
        expr:
            output expression
        changed:
            whether or not any change occurred
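
        Examples
        --------
        Illustrative sketch (``MyExpr`` is a hypothetical subclass): a
        ``kind`` of ``"tune"`` dispatches to ``_tune_down`` / ``_tune_up``,
        which is how ``optimize_until`` applies the tuning pass.

        >>> expr = MyExpr(frame)                            # doctest: +SKIP
        >>> expr = expr.rewrite(kind="tune", rewritten={})  # doctest: +SKIP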
        """
        if self._name in rewritten:
            return rewritten[self._name]

        expr = self
        down_name = f"_{kind}_down"
        up_name = f"_{kind}_up"
        while True:
            _continue = False

            # Rewrite this node
            out = getattr(expr, down_name)()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out
                continue

            # Allow children to rewrite their parents
            for child in expr.dependencies():
                out = getattr(child, up_name)(expr)
                if out is None:
                    out = expr
                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    _continue = True
                    break

            if _continue:
                continue

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    new = operand.rewrite(kind=kind, rewritten=rewritten)
                    rewritten[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)
                continue
            else:
                break

        return expr

    def simplify_once(self, dependents: defaultdict, simplified: dict):
        """Simplify an expression

        This leverages the ``._simplify_down`` and ``._simplify_up``
        methods defined on each class

        Parameters
        ----------
        dependents: defaultdict[list]
            The dependents for every node.
        simplified: dict
            Cache of simplified expressions for these dependents.

        Returns
        -------
        expr:
            output expression
        """
        # Check if we've already simplified for these dependents
        if self._name in simplified:
            return simplified[self._name]

        expr = self

        while True:
            out = expr._simplify_down()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out

            # Allow children to simplify their parents
            for child in expr.dependencies():
                out = child._simplify_up(expr, dependents)
                if out is None:
                    out = expr

                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    break

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    # Bandaid for now, waiting for Singleton
                    dependents[operand._name].append(weakref.ref(expr))
                    new = operand.simplify_once(
                        dependents=dependents, simplified=simplified
                    )
                    simplified[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)

            break

        return expr

    def optimize(self, fuse: bool = False) -> Expr:
        stage: OptimizerStage = "fused" if fuse else "simplified-physical"

        return optimize_until(self, stage)

    def fuse(self) -> Expr:
        return self

    def simplify(self) -> Expr:
        expr = self
        seen = set()
        while True:
            dependents = collect_dependents(expr)
            new = expr.simplify_once(dependents=dependents, simplified={})
            if new._name == expr._name:
                break
            if new._name in seen:
                raise RuntimeError(
                    f"Optimizer does not converge. {expr!r} simplified to {new!r} "
                    "which was already seen. Please report this issue on the dask "
                    "issue tracker with a minimal reproducer."
                )
            seen.add(new._name)
            expr = new
        return expr
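
    # Illustrative sketch (not part of the original module): ``simplify``
    # iterates ``simplify_once`` until a fixed point is reached, i.e. until
    # the expression name stops changing. ``MyExpr`` is hypothetical.
    #
    #     expr = MyExpr(frame)
    #     expr = expr.simplify()   # repeated _simplify_down / _simplify_up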

    def _simplify_down(self):
        return

    def _simplify_up(self, parent, dependents):
        return

    def lower_once(self, lowered: dict):
        # Check for a cached result
        try:
            return lowered[self._name]
        except KeyError:
            pass

        expr = self

        # Lower this node
        out = expr._lower()
        if out is None:
            out = expr
        if not isinstance(out, Expr):
            return out

        # Lower all children
        new_operands = []
        changed = False
        for operand in out.operands:
            if isinstance(operand, Expr):
                new = operand.lower_once(lowered)
                if new._name != operand._name:
                    changed = True
            else:
                new = operand
            new_operands.append(new)

        if changed:
            out = type(out)(*new_operands)

        # Cache the result and return
        return lowered.setdefault(self._name, out)

    def lower_completely(self) -> Expr:
        """Lower an expression completely

        This calls the ``lower_once`` method in a loop
        until nothing changes. This function does not
        apply any other optimizations (like ``simplify``).

        Returns
        -------
        expr:
            output expression

        See Also
        --------
        Expr.lower_once
        Expr._lower
        """
        # Lower until nothing changes
        expr = self
        lowered: dict = {}
        while True:
            new = expr.lower_once(lowered)
            if new._name == expr._name:
                break
            expr = new
        return expr
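
    # Illustrative sketch (not part of the original module): a logical node
    # can rewrite itself into physical nodes via ``_lower``. ``Sum``,
    # ``ChunkSum`` and ``TreeReduce`` below are hypothetical.
    #
    #     class Sum(Expr):                                  # logical
    #         def _lower(self):
    #             return TreeReduce(ChunkSum(self.frame))   # physical plan
    #
    #     physical = Sum(frame).lower_completely()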

    def _lower(self):
        return

    @functools.cached_property
    def _funcname(self) -> str:
        return funcname(type(self)).lower()

    @property
    def deterministic_token(self):
        if not self._determ_token:
            # Just tokenize self to fall back on __dask_tokenize__
            # Note how this differs from the implementation of __dask_tokenize__
            self._determ_token = self.__dask_tokenize__()
        return self._determ_token

    @functools.cached_property
    def _name(self) -> str:
        return f"{self._funcname}-{self.deterministic_token}"

    @property
    def _meta(self):
        raise NotImplementedError()

    @classmethod
    def _annotations_tombstone(cls) -> _AnnotationsTombstone:
        return _AnnotationsTombstone()

    def __dask_annotations__(self):
        return {}

    def __dask_graph__(self):
        """Traverse expression tree, collect layers

        Subclasses generally do not want to override this method unless custom
        logic is required to treat (e.g. ignore) specific operands during graph
        generation.

        See also
        --------
        Expr._layer
        Expr._task
        """
        stack = [self]
        seen = set()
        layers = []
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            layers.append(expr._layer())
            for operand in expr.dependencies():
                stack.append(operand)

        return toolz.merge(layers)

    @property
    def dask(self):
        return self.__dask_graph__()
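
    # Illustrative sketch (not part of the original module): the merged
    # layers form a flat key -> task mapping that schedulers consume.
    #
    #     graph = expr.__dask_graph__()  # {(expr._name, i): Task, ...}
    #     keys = expr.__dask_keys__()    # [(expr._name, 0), (expr._name, 1), ...]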

    def substitute(self, old, new) -> Expr:
        """Substitute a specific term within the expression

        Note that replacing non-`Expr` terms may produce
        unexpected results, and is not recommended.
        Substituting boolean values is not allowed.

        Parameters
        ----------
        old:
            Old term to find and replace.
        new:
            New term to replace instances of `old` with.

        Examples
        --------
        >>> (df + 10).substitute(10, 20)  # doctest: +SKIP
        df + 20
        """
        return self._substitute(old, new, _seen=set())

    def _substitute(self, old, new, _seen):
        if self._name in _seen:
            return self
        # Check if we are replacing a literal
        if isinstance(old, Expr):
            substitute_literal = False
            if self._name == old._name:
                return new
        else:
            substitute_literal = True
            if isinstance(old, bool):
                raise TypeError("Arguments to `substitute` cannot be bool.")

        new_exprs = []
        update = False
        for operand in self.operands:
            if isinstance(operand, Expr):
                val = operand._substitute(old, new, _seen)
                if operand._name != val._name:
                    update = True
                new_exprs.append(val)
            elif (
                "Fused" in type(self).__name__
                and isinstance(operand, list)
                and all(isinstance(op, Expr) for op in operand)
            ):
                # Special handling for `Fused`.
                # We make no promise to dive through a
                # list operand in general, but NEED to
                # do so for the `Fused.exprs` operand.
                val = []
                for op in operand:
                    val.append(op._substitute(old, new, _seen))
                    if val[-1]._name != op._name:
                        update = True
                new_exprs.append(val)
            elif (
                substitute_literal
                and not isinstance(operand, bool)
                and isinstance(operand, type(old))
                and operand == old
            ):
                new_exprs.append(new)
                update = True
            else:
                new_exprs.append(operand)

        if update:  # Only recreate if something changed
            return type(self)(*new_exprs)
        else:
            _seen.add(self._name)
        return self

    def substitute_parameters(self, substitutions: dict) -> Expr:
        """Substitute specific `Expr` parameters

        Parameters
        ----------
        substitutions:
            Mapping of parameter keys to new values. Keys that
            are not found in ``self._parameters`` will be ignored.
        """
        if not substitutions:
            return self

        changed = False
        new_operands = []
        for i, operand in enumerate(self.operands):
            if i < len(self._parameters) and self._parameters[i] in substitutions:
                new_operands.append(substitutions[self._parameters[i]])
                changed = True
            else:
                new_operands.append(operand)
        if changed:
            return type(self)(*new_operands)
        return self
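
    # Illustrative sketch (not part of the original module): parameters are
    # replaced by name, and a new expression is created only when something
    # changed. ``"npartitions"`` here is a hypothetical parameter name.
    #
    #     expr2 = expr.substitute_parameters({"npartitions": 4})
    #     expr2 is expr  # True if "npartitions" is not in expr._parameters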

    def _node_label_args(self):
        """Operands to include in the node label by `visualize`"""
        return self.dependencies()

    def _to_graphviz(
        self,
        rankdir="BT",
        graph_attr=None,
        node_attr=None,
        edge_attr=None,
        **kwargs,
    ):
        from dask.dot import label, name

        graphviz = import_required(
            "graphviz",
            "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
            "python library and the `graphviz` system library.\n\n"
            "Please either conda or pip install as follows:\n\n"
            "  conda install python-graphviz     # either conda install\n"
            "  python -m pip install graphviz    # or pip install and follow installation instructions",
        )

        graph_attr = graph_attr or {}
        node_attr = node_attr or {}
        edge_attr = edge_attr or {}

        graph_attr["rankdir"] = rankdir
        node_attr["shape"] = "box"
        node_attr["fontname"] = "helvetica"

        graph_attr.update(kwargs)
        g = graphviz.Digraph(
            graph_attr=graph_attr,
            node_attr=node_attr,
            edge_attr=edge_attr,
        )

        stack = [self]
        seen = set()
        dependencies = {}
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            dependencies[expr] = set(expr.dependencies())
            for dep in expr.dependencies():
                stack.append(dep)

        cache = {}
        for expr in dependencies:
            expr_name = name(expr)
            attrs = {}

            # Make node label
            deps = [
                funcname(type(dep)) if isinstance(dep, Expr) else str(dep)
                for dep in expr._node_label_args()
            ]
            _label = funcname(type(expr))
            if deps:
                _label = f"{_label}({', '.join(deps)})"
            node_label = label(_label, cache=cache)

            attrs.setdefault("label", str(node_label))
            attrs.setdefault("fontsize", "20")
            g.node(expr_name, **attrs)

        for expr, deps in dependencies.items():
            expr_name = name(expr)
            for dep in deps:
                dep_name = name(dep)
                g.edge(dep_name, expr_name)

        return g

    def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
        """
        Visualize the expression graph.
        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and the graph is
            rendered in the Jupyter notebook only.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'svg'.
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.
        """
        from dask.dot import graphviz_to_file

        g = self._to_graphviz(**kwargs)
        graphviz_to_file(g, filename, format)
        return g

    def walk(self) -> Generator[Expr]:
        """Iterate through all expressions in the tree

        Returns
        -------
        nodes
            Generator of Expr instances in the graph.
            Ordering is a depth-first search of the expression tree
        """
        stack = [self]
        seen = set()
        while stack:
            node = stack.pop()
            if node._name in seen:
                continue
            seen.add(node._name)

            for dep in node.dependencies():
                stack.append(dep)

            yield node

    def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
        """Search the expression graph for a specific operation type

        Parameters
        ----------
        operation
            The operation type to search for.

        Returns
        -------
        nodes
            Generator of `operation` instances. Ordering corresponds
            to a depth-first search of the expression graph.
        """
        assert (
            isinstance(operation, tuple)
            and all(issubclass(e, Expr) for e in operation)
            or issubclass(operation, Expr)  # type: ignore[arg-type]
        ), "`operation` must be an `Expr` subclass"
        return (expr for expr in self.walk() if isinstance(expr, operation))
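
    # Illustrative sketch (not part of the original module): ``walk`` and
    # ``find_operations`` traverse the tree depth-first, deduplicated by
    # ``_name``. ``ReadParquet`` is a hypothetical operation type.
    #
    #     names = [type(e).__name__ for e in expr.walk()]
    #     reads = list(expr.find_operations(ReadParquet))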

    def __getattr__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError as err:
            if key.startswith("_meta"):
                # Avoid a recursive loop if/when `self._meta*`
                # produces an `AttributeError`
                raise RuntimeError(
                    f"Failed to generate metadata for {self}. "
                    "This operation may not be supported by the current backend."
                )

            # Allow operands to be accessed as attributes
            # as long as the keys are not already reserved
            # by existing methods/properties
            _parameters = type(self)._parameters
            if key in _parameters:
                idx = _parameters.index(key)
                return self.operands[idx]

            raise AttributeError(
                f"{err}\n\n"
                "This often means that you are attempting to use an unsupported "
                "API function."
            )


class SingletonExpr(Expr):
    """A singleton Expr class

    This is used to treat the subclassed expression as a singleton. Singletons
    are deduplicated by ``expr._name``, which is typically based on the
    ``dask.tokenize`` output.

    This is a crucial performance optimization for expressions that walk
    through an optimizer and are recreated repeatedly, but it isn't safe for
    objects that cannot be reliably or quickly tokenized.
    """

    _instances: weakref.WeakValueDictionary[str, SingletonExpr]

    def __new__(cls, *args, _determ_token=None, **kwargs):
        if not hasattr(cls, "_instances"):
            cls._instances = weakref.WeakValueDictionary()
        inst = super().__new__(cls, *args, _determ_token=_determ_token, **kwargs)
        _name = inst._name
        if _name in cls._instances and cls.__init__ == object.__init__:
            return cls._instances[_name]

        cls._instances[_name] = inst
        return inst
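

# Illustrative sketch (not part of the original module): two structurally
# identical singleton expressions resolve to the same object. ``MyExpr`` is
# a hypothetical subclass.
#
#     class MyExpr(SingletonExpr):
#         _parameters = ["x"]
#
#     a = MyExpr(1)
#     b = MyExpr(1)
#     assert a is b  # deduplicated via cls._instances[inst._name]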


def collect_dependents(expr) -> defaultdict:
    dependents = defaultdict(list)
    stack = [expr]
    seen = set()
    while stack:
        node = stack.pop()
        if node._name in seen:
            continue
        seen.add(node._name)

        for dep in node.dependencies():
            stack.append(dep)
            dependents[dep._name].append(weakref.ref(node))
    return dependents


def optimize(expr: Expr, fuse: bool = True) -> Expr:
    """High level query optimization

    This leverages two optimization passes:

    1. Class based simplification using the ``_simplify`` function and methods
    2. Blockwise fusion

    Parameters
    ----------
    expr:
        Input expression to optimize
    fuse:
        whether or not to turn on blockwise fusion

    See Also
    --------
    simplify
    optimize_blockwise_fusion
    """
    stage: OptimizerStage = "fused" if fuse else "simplified-physical"

    return optimize_until(expr, stage)


def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
    result = expr
    if stage == "logical":
        return result

    # Simplify
    expr = result.simplify()
    if stage == "simplified-logical":
        return expr

    # Manipulate Expression to make it more efficient
    if dask.config.get("optimization.tune.active", True):
        expr = expr.rewrite(kind="tune", rewritten={})
    if stage == "tuned-logical":
        return expr

    # Lower
    expr = expr.lower_completely()
    if stage == "physical":
        return expr

    # Simplify again
    expr = expr.simplify()
    if stage == "simplified-physical":
        return expr

    # Final graph-specific optimizations
    expr = expr.fuse()
    if stage == "fused":
        return expr

    raise ValueError(f"Stage {stage!r} not supported.")
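

# Illustrative sketch (not part of the original module): the optimizer stages
# form a pipeline, and ``optimize_until`` stops after the requested stage.
#
#     expr.optimize(fuse=False)              # runs up to "simplified-physical"
#     optimize_until(expr, "tuned-logical")  # simplify + tune rewrite only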


class LLGExpr(Expr):
    """Low Level Graph Expression"""

    _parameters = ["dsk"]

    def __dask_keys__(self):
        return list(self.operand("dsk"))

    def _layer(self) -> dict:
        return ensure_dict(self.operand("dsk"))


class HLGExpr(Expr):
    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @property
    def hlg(self):
        return self.operand("dsk")

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when
        # going through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )
        if optimize_graph and not hasattr(collection, "__dask_optimize__"):
            warnings.warn(
                f"Collection {type(collection)} does not define a "
                "`__dask_optimize__` method. In the future this will raise. "
                "If no optimization is desired, please set this to `None`.",
                PendingDeprecationWarning,
            )
            low_level_optimizer = None
        else:
            low_level_optimizer = (
                collection.__dask_optimize__ if optimize_graph else None
            )
        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__(),
        )

    def finalize_compute(self):
        return HLGFinalizeCompute(
            self,
            low_level_optimizer=self.low_level_optimizer,
            output_keys=self.output_keys,
            postcompute=self.postcompute,
        )

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annotations_by_type[annot_type].update(
                        {k: (value(k) if callable(value) else value) for k in layer}
                    )
        return dict(annotations_by_type)

    def __dask_keys__(self):
        if (keys := self.operand("output_keys")) is not None:
            return keys
        dsk = self.hlg
        # Note: This will materialize
        dependencies = dsk.get_all_dependencies()
        leafs = set(dependencies)
        for val in dependencies.values():
            leafs -= val
        self.output_keys = list(leafs)
        return self.output_keys

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        keys = self.__dask_keys__()
        dsk = self.hlg
        if (optimizer := self.low_level_optimizer) is not None:
            dsk = optimizer(dsk, keys)
        return HighLevelGraph.merge(dsk)

    @property
    def deterministic_token(self):
        if not self._determ_token:
            self._determ_token = uuid.uuid4().hex
        return self._determ_token

    def _layer(self) -> dict:
        dsk = self._optimized_dsk
        return ensure_dict(dsk)


class _HLGExprGroup(HLGExpr):
    # Identical to HLGExpr
    # Used internally to determine how output keys are supposed to be returned
    pass


class _HLGExprSequence(Expr):
    def __getitem__(self, other):
        return self.operands[other]

    def _operands_for_repr(self):
        return [
            f"name={self.operand('name')!r}",
            f"dsk={self.operand('dsk')!r}",
        ]

    def _tree_repr_lines(self, indent=0, recursive=True):
        return self._operands_for_repr()

    def finalize_compute(self):
        return _HLGExprSequence(*[op.finalize_compute() for op in self.operands])

    def _tune_down(self):
        if len(self.operands) == 1:
            return None
        from dask.highlevelgraph import HighLevelGraph

        groups = toolz.groupby(
            lambda x: x.low_level_optimizer if isinstance(x, HLGExpr) else None,
            self.operands,
        )
        exprs = []
        changed = False
        for optimizer, group in groups.items():
            if len(group) > 1:
                graphs = [expr.hlg for expr in group]

                changed = True
                dsk = HighLevelGraph.merge(*graphs)
                hlg_group = _HLGExprGroup(
                    dsk=dsk,
                    low_level_optimizer=optimizer,
                    output_keys=[v.__dask_keys__() for v in group],
                    postcompute=[g.postcompute for g in group],
                )
                exprs.append(hlg_group)
            else:
                exprs.append(group[0])
        if not changed:
            return None
        return _HLGExprSequence(*exprs)

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        hlgexpr: HLGExpr
        graphs = []
        # _tune_down ensures there is only one HLGExpr per optimizer/finalizer
        for hlgexpr in self.operands:
            keys = hlgexpr.__dask_keys__()
            dsk = hlgexpr.hlg
            if (optimizer := hlgexpr.low_level_optimizer) is not None:
                dsk = optimizer(dsk, keys)
            graphs.append(dsk)

        return HighLevelGraph.merge(*graphs)

    def __dask_graph__(self):
        # This class has to override this and not just _layer to ensure the
        # HLGs are not optimized individually
        return ensure_dict(self._optimized_dsk)

    _layer = __dask_graph__

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annots = list(
                        (k, (value(k) if callable(value) else value)) for k in layer
                    )
                    annotations_by_type[annot_type].update(
                        {
                            k: v
                            for k, v in annots
                            if not isinstance(v, _AnnotationsTombstone)
                        }
                    )
                    if not annotations_by_type[annot_type]:
                        del annotations_by_type[annot_type]
        return dict(annotations_by_type)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            if isinstance(op, _HLGExprGroup):
                all_keys.extend(op.__dask_keys__())
            else:
                all_keys.append(op.__dask_keys__())
        return all_keys


class _ExprSequence(Expr):
    """A sequence of expressions

    This is used to be able to optimize multiple collections combined, e.g.
    when being computed simultaneously with ``dask.compute((Expr1, Expr2))``.
    """

    def __getitem__(self, other):
        return self.operands[other]

    def _layer(self) -> dict:
        return toolz.merge(op._layer() for op in self.operands)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            all_keys.append(list(op.__dask_keys__()))
        return all_keys

    def __repr__(self):
        return "ExprSequence(" + ", ".join(map(repr, self.operands)) + ")"

    __str__ = __repr__

    def finalize_compute(self):
        return _ExprSequence(
            *(op.finalize_compute() for op in self.operands),
        )

    def __dask_annotations__(self):
        annotations_by_type = {}
        for op in self.operands:
            for k, v in op.__dask_annotations__().items():
                annotations_by_type.setdefault(k, {}).update(v)
        return annotations_by_type

    def __len__(self):
        return len(self.operands)

    def __iter__(self):
        return iter(self.operands)

    def _simplify_down(self):
        from dask.highlevelgraph import HighLevelGraph

        issue_warning = False
        hlgs = []
        if any(
            isinstance(op, (HLGExpr, HLGFinalizeCompute, dict)) for op in self.operands
        ):
            for op in self.operands:
                if isinstance(op, (HLGExpr, HLGFinalizeCompute)):
                    hlgs.append(op)
                elif isinstance(op, dict):
                    hlgs.append(
                        HLGExpr(
                            dsk=HighLevelGraph.from_collections(
                                str(id(op)), op, dependencies=()
                            )
                        )
                    )
                else:
                    issue_warning = True
                    opt = op.optimize()
                    hlgs.append(
                        HLGExpr(
                            dsk=HighLevelGraph.from_collections(
                                opt._name, opt.__dask_graph__(), dependencies=()
                            )
                        )
                    )
        if issue_warning:
            warnings.warn(
                "Computing mixed collections that are backed by "
                "HighlevelGraphs/dicts and Expressions. "
                "This forces Expressions to be materialized. "
                "It is recommended to use only one type and separate the "
                "dask.compute calls if necessary.",
                UserWarning,
            )
        if not hlgs:
            return None
        return _HLGExprSequence(*hlgs)


class _AnnotationsTombstone: ...


class FinalizeCompute(Expr):
    _parameters = ["expr"]

    def _simplify_down(self):
        return self.expr.finalize_compute()


def _convert_dask_keys(keys):
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    new_keys = []
    for key in keys:
        if isinstance(key, list):
            new_keys.append(_convert_dask_keys(key))
        else:
            new_keys.append(TaskRef(key))
    return List(*new_keys)
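

# Illustrative sketch (not part of the original module): nested key lists are
# mirrored as nested ``List``s of ``TaskRef``s so a finalize task can
# reference its inputs.
#
#     _convert_dask_keys([("a", 0), [("b", 0), ("b", 1)]])
#     # -> List(TaskRef(("a", 0)), List(TaskRef(("b", 0)), TaskRef(("b", 1))))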


class HLGFinalizeCompute(HLGExpr):
    def _simplify_down(self):
        if not self.postcompute:
            return self.dsk

        from dask.delayed import Delayed

        # Skip finalization for Delayed
        if self.dsk.postcompute == Delayed.__dask_postcompute__(self.dsk):
            return self.dsk
        return self

    @property
    def _name(self):
        return f"finalize-{super()._name}"

    def __dask_graph__(self):
        # The baseclass __dask_graph__ will not just materialize this layer
        # but also that of its dependencies, i.e. it would render both the
        # finalized and the non-finalized graph and combine them. We only want
        # the finalized graph, so we override this.
        # This is an artifact generated since the wrapped expression is
        # identified automatically as a dependency but HLG expressions are not
        # working in this layered way.
        return self._layer()

    @property
    def hlg(self):
        expr = self.operand("dsk")
        layers = expr.dsk.layers.copy()
        deps = expr.dsk.dependencies.copy()
        keys = expr.__dask_keys__()
        if isinstance(expr.postcompute, list):
            postcomputes = expr.postcompute
        else:
            postcomputes = [expr.postcompute]
        tasks = [
            Task(self._name, func, _convert_dask_keys(keys), *extra_args)
            for func, extra_args in postcomputes
        ]
        from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

        leafs = set(deps)
        for val in deps.values():
            leafs -= val
        for t in tasks:
            layers[t.key] = MaterializedLayer({t.key: t})
            deps[t.key] = leafs
        return HighLevelGraph(layers, dependencies=deps)

    def __dask_keys__(self):
        return [self._name]


class ProhibitReuse(Expr):
    """
    An expression that guarantees that all keys are suffixed with a unique id.
    This can be used to break a common subexpression apart.
    """

    _parameters = ["expr"]
    _ALLOWED_TYPES = [HLGExpr, LLGExpr, HLGFinalizeCompute, _HLGExprSequence]

    def __dask_keys__(self):
        return self._modify_keys(self.expr.__dask_keys__())

    @staticmethod
    def _identity(obj):
        return obj

    @functools.cached_property
    def _suffix(self):
        return uuid.uuid4().hex

    def _modify_keys(self, k):
        if isinstance(k, list):
            return [self._modify_keys(kk) for kk in k]
        elif isinstance(k, tuple):
            return (self._modify_keys(k[0]),) + k[1:]
        elif isinstance(k, (int, float)):
            k = str(k)
        return f"{k}-{self._suffix}"
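
    # Illustrative sketch (not part of the original module): key names get a
    # per-expression uuid suffix, while tuple keys keep their partition index.
    #
    #     self._modify_keys(("sum", 0))  # -> ("sum-<uuid>", 0)
    #     self._modify_keys("x")         # -> "x-<uuid>"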

    def _simplify_down(self):
        # FIXME: Shuffling cannot be rewritten since the barrier key is
        # hardcoded. Skipping this here should do the trick most of the time
        if not isinstance(
            self.expr,
            tuple(self._ALLOWED_TYPES),
        ):
            return self.expr

    def __dask_graph__(self):
        try:
            from distributed.shuffle._core import P2PBarrierTask
        except ModuleNotFoundError:
            P2PBarrierTask = type(None)
        dsk = convert_legacy_graph(self.expr.__dask_graph__())

        subs = {old_key: self._modify_keys(old_key) for old_key in dsk}
        dsk2 = {}
        for old_key, new_key in subs.items():
            t = dsk[old_key]
            if isinstance(t, P2PBarrierTask):
                warnings.warn(
                    "Cannot block reusing for graphs including a "
                    "P2PBarrierTask. This may cause unexpected results. "
                    "This typically happens when converting a dask "
                    "DataFrame to delayed objects.",
                    UserWarning,
                )
                return dsk
            dsk2[new_key] = Task(
                new_key,
                ProhibitReuse._identity,
                t.substitute(subs),
            )
        dsk2.update(dsk)
        return dsk2

    _layer = __dask_graph__