from __future__ import annotations

import functools
import os
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal

import toolz

import dask
from dask._task_spec import Task, convert_legacy_graph
from dask.tokenize import _tokenize_deterministic
from dask.typing import Key
from dask.utils import ensure_dict, funcname, import_required

if TYPE_CHECKING:
    # TODO import from typing (requires Python >=3.10)
    from typing import Any, TypeAlias

    from dask.highlevelgraph import HighLevelGraph

OptimizerStage: TypeAlias = Literal[
    "logical",
    "simplified-logical",
    "tuned-logical",
    "physical",
    "simplified-physical",
    "fused",
]


def _unpack_collections(o):
    from dask.delayed import Delayed

    if isinstance(o, Expr):
        return o

    if hasattr(o, "expr") and not isinstance(o, Delayed):
        return o.expr
    else:
        return o


class Expr:
    _parameters: list[str] = []
    _defaults: dict[str, Any] = {}

    _pickle_functools_cache: bool = True

    operands: list

    _determ_token: str | None

    def __new__(cls, *args, _determ_token=None, **kwargs):
        operands = list(args)
        for parameter in cls._parameters[len(operands) :]:
            try:
                operands.append(kwargs.pop(parameter))
            except KeyError:
                operands.append(cls._defaults[parameter])
        assert not kwargs, kwargs
        inst = object.__new__(cls)

        inst._determ_token = _determ_token
        inst.operands = [_unpack_collections(o) for o in operands]
        # This is typically cached. Make sure the cache is populated by calling
        # it once
        inst._name
        return inst
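
    # Illustration (a hedged sketch; ``Between`` is a hypothetical subclass,
    # not part of this module): given
    #
    #     class Between(Expr):
    #         _parameters = ["frame", "low", "high"]
    #         _defaults = {"high": None}
    #
    # ``Between(df, low=0)`` fills ``operands`` as ``[df, 0, None]``:
    # positional arguments first, then keyword arguments matched by parameter
    # name, then ``_defaults`` for anything still missing.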

    def _tune_down(self):
        return None

    def _tune_up(self, parent):
        return None

    def finalize_compute(self):
        return self

    def _operands_for_repr(self):
        return [
            f"{param}={repr(op)}" for param, op in zip(self._parameters, self.operands)
        ]

    def __str__(self):
        s = ", ".join(self._operands_for_repr())
        return f"{type(self).__name__}({s})"

    def __repr__(self):
        return str(self)

    def _tree_repr_argument_construction(self, i, op, header):
        try:
            param = self._parameters[i]
            default = self._defaults[param]
        except (IndexError, KeyError):
            param = self._parameters[i] if i < len(self._parameters) else ""
            default = "--no-default--"

        if repr(op) != repr(default):
            if param:
                header += f" {param}={repr(op)}"
            else:
                header += repr(op)
        return header

    def _tree_repr_lines(self, indent=0, recursive=True):
        # Return a list of lines: `tree_repr` joins them with os.linesep and
        # `pprint` iterates over them, so a bare string would be split into
        # characters.
        return [" " * indent + repr(self)]

    def tree_repr(self):
        return os.linesep.join(self._tree_repr_lines())

    def analyze(self, filename: str | None = None, format: str | None = None) -> None:
        from dask.dataframe.dask_expr._expr import Expr as DFExpr
        from dask.dataframe.dask_expr.diagnostics import analyze

        if not isinstance(self, DFExpr):
            raise TypeError(
                "analyze is only supported for dask.dataframe.Expr objects."
            )
        return analyze(self, filename=filename, format=format)

    def explain(
        self, stage: OptimizerStage = "fused", format: str | None = None
    ) -> None:
        from dask.dataframe.dask_expr.diagnostics import explain

        return explain(self, stage, format)

    def pprint(self):
        for line in self._tree_repr_lines():
            print(line)

    def __hash__(self):
        return hash(self._name)

    def __dask_tokenize__(self):
        if not self._determ_token:
            # If the subclass does not implement a __dask_tokenize__, we
            # tokenize all operands.
            # Note how this differs from the implementation of
            # Expr.deterministic_token
            self._determ_token = _tokenize_deterministic(type(self), *self.operands)
        return self._determ_token

    def __dask_keys__(self):
        """The keys for this expression

        This is used to determine the keys of the output collection
        when this expression is computed.

        Returns
        -------
        keys: list
            The keys for this expression
        """
        return [(self._name, i) for i in range(self.npartitions)]

    @staticmethod
    def _reconstruct(*args):
        typ, *operands, token, cache = args
        inst = typ(*operands, _determ_token=token)
        for k, v in cache.items():
            inst.__dict__[k] = v
        return inst

    def __reduce__(self):
        if dask.config.get("dask-expr-no-serialize", False):
            raise RuntimeError(f"Serializing a {type(self)} object")
        cache = {}
        if type(self)._pickle_functools_cache:
            for k, v in type(self).__dict__.items():
                if isinstance(v, functools.cached_property) and k in self.__dict__:
                    cache[k] = getattr(self, k)

        return Expr._reconstruct, tuple(
            [type(self), *self.operands, self.deterministic_token, cache]
        )
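
    # Round-trip sketch: ``pickle.loads(pickle.dumps(expr))`` rebuilds the
    # expression via ``Expr._reconstruct`` from its type, operands, and
    # deterministic token, then restores any pickled ``cached_property``
    # values directly into ``__dict__``.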

    def _depth(self, cache=None):
        """Depth of the expression tree

        Returns
        -------
        depth: int
        """
        if cache is None:
            cache = {}
        if not self.dependencies():
            return 1
        else:
            result = []
            for expr in self.dependencies():
                if expr._name in cache:
                    result.append(cache[expr._name])
                else:
                    result.append(expr._depth(cache) + 1)
                    cache[expr._name] = result[-1]
            return max(result)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ["operands", "_determ_token"]:
            object.__setattr__(self, name, value)
            return
        try:
            params = type(self)._parameters
            operands = object.__getattribute__(self, "operands")
            operands[params.index(name)] = value
        except ValueError:
            raise AttributeError(
                f"{type(self).__name__} object has no attribute {name}"
            )

    def operand(self, key):
        # Access an operand unambiguously
        # (e.g. if the key is reserved by a method/property)
        return self.operands[type(self)._parameters.index(key)]

    def dependencies(self):
        # Dependencies are `Expr` operands only
        return [operand for operand in self.operands if isinstance(operand, Expr)]

    def _task(self, key: Key, index: int) -> Task:
        """The task for the i'th partition

        Parameters
        ----------
        key:
            The key of the task, i.e. ``(self._name, index)``.
        index:
            The index of the partition of the output collection

        Examples
        --------
        >>> class Add(Expr):
        ...     def _task(self, key, index):
        ...         return Task(
        ...             key,
        ...             operator.add,
        ...             TaskRef((self.left._name, index)),
        ...             TaskRef((self.right._name, index)),
        ...         )

        Returns
        -------
        task:
            The Dask task to compute this partition

        See Also
        --------
        Expr._layer
        """
        raise NotImplementedError(
            "Expressions should define either _layer (full dictionary) or _task"
            f" (single task). This expression {type(self)} defines neither"
        )

    def _layer(self) -> dict:
        """The graph layer added by this expression.

        Simple expressions that apply one task per partition can choose to only
        implement `Expr._task` instead.

        Examples
        --------
        >>> class Add(Expr):
        ...     def _layer(self):
        ...         return {
        ...             name: Task(
        ...                 name,
        ...                 operator.add,
        ...                 TaskRef((self.left._name, i)),
        ...                 TaskRef((self.right._name, i)),
        ...             )
        ...             for i, name in enumerate(self.__dask_keys__())
        ...         }

        Returns
        -------
        layer: dict
            The Dask task graph added by this expression

        See Also
        --------
        Expr._task
        Expr.__dask_graph__
        """
        return {
            (self._name, i): self._task((self._name, i), i)
            for i in range(self.npartitions)
        }

    def rewrite(self, kind: str, rewritten):
        """Rewrite an expression

        This leverages the ``._{kind}_down`` and ``._{kind}_up``
        methods defined on each class

        Returns
        -------
        expr:
            output expression
        """
        if self._name in rewritten:
            return rewritten[self._name]

        expr = self
        down_name = f"_{kind}_down"
        up_name = f"_{kind}_up"
        while True:
            _continue = False

            # Rewrite this node
            out = getattr(expr, down_name)()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out
                continue

            # Allow children to rewrite their parents
            for child in expr.dependencies():
                out = getattr(child, up_name)(expr)
                if out is None:
                    out = expr
                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    _continue = True
                    break

            if _continue:
                continue

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    new = operand.rewrite(kind=kind, rewritten=rewritten)
                    rewritten[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)
                continue
            else:
                break

        return expr
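
    # A hedged sketch of the rewrite protocol for ``kind="tune"``: each node's
    # ``_tune_down()`` may return a replacement expression (or None to keep
    # the node), and each child's ``_tune_up(parent)`` may rewrite its parent.
    # ``rewrite`` repeats both passes, then recurses into the children, until
    # the expression name stops changing.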

    def simplify_once(self, dependents: defaultdict, simplified: dict):
        """Simplify an expression

        This leverages the ``._simplify_down`` and ``._simplify_up``
        methods defined on each class

        Parameters
        ----------
        dependents: defaultdict[list]
            The dependents for every node.
        simplified: dict
            Cache of simplified expressions for these dependents.

        Returns
        -------
        expr:
            output expression
        """
        # Check if we've already simplified for these dependents
        if self._name in simplified:
            return simplified[self._name]

        expr = self

        while True:
            out = expr._simplify_down()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out

            # Allow children to simplify their parents
            for child in expr.dependencies():
                out = child._simplify_up(expr, dependents)
                if out is None:
                    out = expr

                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    break

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    # Bandaid for now, waiting for Singleton
                    dependents[operand._name].append(weakref.ref(expr))
                    new = operand.simplify_once(
                        dependents=dependents, simplified=simplified
                    )
                    simplified[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)

            break

        return expr

    def optimize(self, fuse: bool = False) -> Expr:
        stage: OptimizerStage = "fused" if fuse else "simplified-physical"

        return optimize_until(self, stage)

    def fuse(self) -> Expr:
        return self

    def simplify(self) -> Expr:
        expr = self
        seen = set()
        while True:
            dependents = collect_dependents(expr)
            new = expr.simplify_once(dependents=dependents, simplified={})
            if new._name == expr._name:
                break
            if new._name in seen:
                raise RuntimeError(
                    f"Optimizer does not converge. {expr!r} simplified to {new!r} "
                    "which was already seen. Please report this issue on the dask "
                    "issue tracker with a minimal reproducer."
                )
            seen.add(new._name)
            expr = new
        return expr

    def _simplify_down(self):
        return

    def _simplify_up(self, parent, dependents):
        return
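
    # A hedged sketch of a simplification rule (``Double`` and ``Zeros`` are
    # hypothetical subclasses, not part of this module):
    #
    #     class Double(Expr):
    #         _parameters = ["frame"]
    #
    #         def _simplify_down(self):
    #             # Doubling all-zero data is a no-op; collapse to the child.
    #             if isinstance(self.frame, Zeros):
    #                 return self.frame
    #
    # Returning None (the default) signals that no rewrite applies.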

    def lower_once(self, lowered: dict):
        # Check for a cached result
        try:
            return lowered[self._name]
        except KeyError:
            pass

        expr = self

        # Lower this node
        out = expr._lower()
        if out is None:
            out = expr
        if not isinstance(out, Expr):
            return out

        # Lower all children
        new_operands = []
        changed = False
        for operand in out.operands:
            if isinstance(operand, Expr):
                new = operand.lower_once(lowered)
                if new._name != operand._name:
                    changed = True
            else:
                new = operand
            new_operands.append(new)

        if changed:
            out = type(out)(*new_operands)

        # Cache the result and return
        return lowered.setdefault(self._name, out)

    def lower_completely(self) -> Expr:
        """Lower an expression completely

        This calls the ``lower_once`` method in a loop
        until nothing changes. This function does not
        apply any other optimizations (like ``simplify``).

        Returns
        -------
        expr:
            output expression

        See Also
        --------
        Expr.lower_once
        Expr._lower
        """
        # Lower until nothing changes
        expr = self
        lowered: dict = {}
        while True:
            new = expr.lower_once(lowered)
            if new._name == expr._name:
                break
            expr = new
        return expr

    def _lower(self):
        return

    @functools.cached_property
    def _funcname(self) -> str:
        return funcname(type(self)).lower()

    @property
    def deterministic_token(self):
        if not self._determ_token:
            # Just tokenize self to fall back on __dask_tokenize__.
            # Note how this differs from the implementation of
            # __dask_tokenize__
            self._determ_token = self.__dask_tokenize__()
        return self._determ_token

    @functools.cached_property
    def _name(self) -> str:
        return self._funcname + "-" + self.deterministic_token

    @property
    def _meta(self):
        raise NotImplementedError()

    @classmethod
    def _annotations_tombstone(cls) -> _AnnotationsTombstone:
        return _AnnotationsTombstone()

    def __dask_annotations__(self):
        return {}

    def __dask_graph__(self):
        """Traverse expression tree, collect layers

        Subclasses generally do not want to override this method unless custom
        logic is required to treat (e.g. ignore) specific operands during graph
        generation.

        See also
        --------
        Expr._layer
        Expr._task
        """
        stack = [self]
        seen = set()
        layers = []
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            layers.append(expr._layer())
            for operand in expr.dependencies():
                stack.append(operand)

        return toolz.merge(layers)

    @property
    def dask(self):
        return self.__dask_graph__()

    def substitute(self, old, new) -> Expr:
        """Substitute a specific term within the expression

        Note that replacing non-`Expr` terms may produce
        unexpected results, and is not recommended.
        Substituting boolean values is not allowed.

        Parameters
        ----------
        old:
            Old term to find and replace.
        new:
            New term to replace instances of `old` with.

        Examples
        --------
        >>> (df + 10).substitute(10, 20)  # doctest: +SKIP
        df + 20
        """
        return self._substitute(old, new, _seen=set())

    def _substitute(self, old, new, _seen):
        if self._name in _seen:
            return self
        # Check if we are replacing a literal
        if isinstance(old, Expr):
            substitute_literal = False
            if self._name == old._name:
                return new
        else:
            substitute_literal = True
            if isinstance(old, bool):
                raise TypeError("Arguments to `substitute` cannot be bool.")

        new_exprs = []
        update = False
        for operand in self.operands:
            if isinstance(operand, Expr):
                val = operand._substitute(old, new, _seen)
                if operand._name != val._name:
                    update = True
                new_exprs.append(val)
            elif (
                "Fused" in type(self).__name__
                and isinstance(operand, list)
                and all(isinstance(op, Expr) for op in operand)
            ):
                # Special handling for `Fused`.
                # We make no promise to dive through a
                # list operand in general, but NEED to
                # do so for the `Fused.exprs` operand.
                val = []
                for op in operand:
                    val.append(op._substitute(old, new, _seen))
                    if val[-1]._name != op._name:
                        update = True
                new_exprs.append(val)
            elif (
                substitute_literal
                and not isinstance(operand, bool)
                and isinstance(operand, type(old))
                and operand == old
            ):
                new_exprs.append(new)
                update = True
            else:
                new_exprs.append(operand)

        if update:  # Only recreate if something changed
            return type(self)(*new_exprs)
        else:
            _seen.add(self._name)
        return self

    def substitute_parameters(self, substitutions: dict) -> Expr:
        """Substitute specific `Expr` parameters

        Parameters
        ----------
        substitutions:
            Mapping of parameter keys to new values. Keys that
            are not found in ``self._parameters`` will be ignored.
        """
        if not substitutions:
            return self

        changed = False
        new_operands = []
        for i, operand in enumerate(self.operands):
            if i < len(self._parameters) and self._parameters[i] in substitutions:
                new_operands.append(substitutions[self._parameters[i]])
                changed = True
            else:
                new_operands.append(operand)
        if changed:
            return type(self)(*new_operands)
        return self

    def _node_label_args(self):
        """Operands to include in the node label by `visualize`"""
        return self.dependencies()

    def _to_graphviz(
        self,
        rankdir="BT",
        graph_attr=None,
        node_attr=None,
        edge_attr=None,
        **kwargs,
    ):
        from dask.dot import label, name

        graphviz = import_required(
            "graphviz",
            "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
            "python library and the `graphviz` system library.\n\n"
            "Please install with conda or pip as follows:\n\n"
            "  conda install python-graphviz     # either conda install\n"
            "  python -m pip install graphviz    # or pip install and follow installation instructions",
        )

        graph_attr = graph_attr or {}
        node_attr = node_attr or {}
        edge_attr = edge_attr or {}

        graph_attr["rankdir"] = rankdir
        node_attr["shape"] = "box"
        node_attr["fontname"] = "helvetica"

        graph_attr.update(kwargs)
        g = graphviz.Digraph(
            graph_attr=graph_attr,
            node_attr=node_attr,
            edge_attr=edge_attr,
        )

        stack = [self]
        seen = set()
        dependencies = {}
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            dependencies[expr] = set(expr.dependencies())
            for dep in expr.dependencies():
                stack.append(dep)

        cache = {}
        for expr in dependencies:
            expr_name = name(expr)
            attrs = {}

            # Make node label
            deps = [
                funcname(type(dep)) if isinstance(dep, Expr) else str(dep)
                for dep in expr._node_label_args()
            ]
            _label = funcname(type(expr))
            if deps:
                _label = f"{_label}({', '.join(deps)})"
            node_label = label(_label, cache=cache)

            attrs.setdefault("label", str(node_label))
            attrs.setdefault("fontsize", "20")
            g.node(expr_name, **attrs)

        for expr, deps in dependencies.items():
            expr_name = name(expr)
            for dep in deps:
                dep_name = name(dep)
                g.edge(dep_name, expr_name)

        return g

    def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
        """
        Visualize the expression graph.
        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and the graph is
            rendered in the Jupyter notebook only.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'svg'.
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.
        """
        from dask.dot import graphviz_to_file

        g = self._to_graphviz(**kwargs)
        graphviz_to_file(g, filename, format)
        return g

    def walk(self) -> Generator[Expr]:
        """Iterate through all expressions in the tree

        Returns
        -------
        nodes
            Generator of Expr instances in the graph.
            Ordering is a depth-first search of the expression tree
        """
        stack = [self]
        seen = set()
        while stack:
            node = stack.pop()
            if node._name in seen:
                continue
            seen.add(node._name)

            for dep in node.dependencies():
                stack.append(dep)

            yield node

    def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
        """Search the expression graph for a specific operation type

        Parameters
        ----------
        operation
            The operation type to search for.

        Returns
        -------
        nodes
            Generator of `operation` instances. Ordering corresponds
            to a depth-first search of the expression graph.
        """
        assert (
            isinstance(operation, tuple)
            and all(issubclass(e, Expr) for e in operation)
            or issubclass(operation, Expr)  # type: ignore
        ), "`operation` must be an `Expr` subclass or a tuple of `Expr` subclasses"
        return (expr for expr in self.walk() if isinstance(expr, operation))
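
    # Usage sketch (``Add`` and ``Mul`` are hypothetical Expr subclasses):
    #
    #     adds = list(expr.find_operations(Add))
    #     nodes = list(expr.find_operations((Add, Mul)))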

    def __getattr__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError as err:
            if key.startswith("_meta"):
                # Avoid a recursive loop if/when `self._meta*`
                # produces an `AttributeError`
                raise RuntimeError(
                    f"Failed to generate metadata for {self}. "
                    "This operation may not be supported by the current backend."
                )

            # Allow operands to be accessed as attributes
            # as long as the keys are not already reserved
            # by existing methods/properties
            _parameters = type(self)._parameters
            if key in _parameters:
                idx = _parameters.index(key)
                return self.operands[idx]

            raise AttributeError(
                f"{err}\n\n"
                "This often means that you are attempting to use an unsupported "
                "API function."
            )
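

# A minimal end-to-end sketch (illustrative only; ``_ExampleLiteral`` is a
# hypothetical subclass, not part of the public API). It shows how
# ``_parameters``/``_defaults`` map onto operands and how implementing
# ``_task`` alone is enough for the default ``_layer``/``__dask_graph__`` to
# assemble one task per partition.
class _ExampleLiteral(Expr):
    _parameters = ["value", "npartitions"]
    _defaults = {"npartitions": 1}

    def _task(self, key: Key, index: int) -> Task:
        # Each partition simply materializes the literal value.
        return Task(key, lambda v: v, self.value)


# Usage (kept as comments so importing this module stays side-effect free):
#     expr = _ExampleLiteral("x", npartitions=2)
#     expr.__dask_keys__()   # [(expr._name, 0), (expr._name, 1)]
#     expr.__dask_graph__()  # dict mapping those keys to Task objects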


class SingletonExpr(Expr):
    """A singleton Expr class

    This is used to treat the subclassed expression as a singleton. Singletons
    are deduplicated by expr._name, which is typically based on the
    dask.tokenize output.

    This is a crucial performance optimization for expressions that walk
    through an optimizer and are recreated repeatedly, but it isn't safe for
    objects that cannot be reliably or quickly tokenized.
    """

    _instances: weakref.WeakValueDictionary[str, SingletonExpr]

    def __new__(cls, *args, _determ_token=None, **kwargs):
        if not hasattr(cls, "_instances"):
            cls._instances = weakref.WeakValueDictionary()
        inst = super().__new__(cls, *args, _determ_token=_determ_token, **kwargs)
        _name = inst._name
        if _name in cls._instances and cls.__init__ == object.__init__:
            return cls._instances[_name]

        cls._instances[_name] = inst
        return inst
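

# Deduplication sketch (hedged; ``_MySingleton`` is hypothetical): while a
# reference is alive, structurally identical instances are the same object,
# keyed by ``_name``:
#
#     class _MySingleton(SingletonExpr):
#         _parameters = ["value"]
#
#     _MySingleton(1) is _MySingleton(1)  # True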


def collect_dependents(expr) -> defaultdict:
    dependents = defaultdict(list)
    stack = [expr]
    seen = set()
    while stack:
        node = stack.pop()
        if node._name in seen:
            continue
        seen.add(node._name)

        for dep in node.dependencies():
            stack.append(dep)
            dependents[dep._name].append(weakref.ref(node))
    return dependents


def optimize(expr: Expr, fuse: bool = True) -> Expr:
    """High level query optimization

    This leverages two optimization passes:

    1. Class based simplification using the ``_simplify`` function and methods
    2. Blockwise fusion

    Parameters
    ----------
    expr:
        Input expression to optimize
    fuse:
        whether or not to turn on blockwise fusion

    See Also
    --------
    simplify
    optimize_blockwise_fusion
    """
    stage: OptimizerStage = "fused" if fuse else "simplified-physical"

    return optimize_until(expr, stage)


def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
    result = expr
    if stage == "logical":
        return result

    # Simplify
    expr = result.simplify()
    if stage == "simplified-logical":
        return expr

    # Manipulate Expression to make it more efficient
    expr = expr.rewrite(kind="tune", rewritten={})
    if stage == "tuned-logical":
        return expr

    # Lower
    expr = expr.lower_completely()
    if stage == "physical":
        return expr

    # Simplify again
    expr = expr.simplify()
    if stage == "simplified-physical":
        return expr

    # Final graph-specific optimizations
    expr = expr.fuse()
    if stage == "fused":
        return expr

    raise ValueError(f"Stage {stage!r} not supported.")
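

# Usage sketch: stages are cumulative, so
#
#     optimize_until(expr, "tuned-logical")
#
# runs simplification plus the "tune" rewrite but stops before lowering,
# while ``optimize(expr, fuse=True)`` runs the full pipeline up to "fused".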


class LLGExpr(Expr):
    """Low Level Graph Expression"""

    _parameters = ["dsk"]

    def __dask_keys__(self):
        return list(self.operand("dsk"))

    def _layer(self) -> dict:
        return ensure_dict(self.operand("dsk"))


class HLGExpr(Expr):
    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @property
    def hlg(self):
        return self.operand("dsk")

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when
        # going through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )
        if optimize_graph and not hasattr(collection, "__dask_optimize__"):
            warnings.warn(
                f"Collection {type(collection)} does not define a "
                "`__dask_optimize__` method. In the future this will raise. "
                "If no optimization is desired, please set this to `None`.",
                PendingDeprecationWarning,
            )
            low_level_optimizer = None
        else:
            low_level_optimizer = (
                collection.__dask_optimize__ if optimize_graph else None
            )
        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__(),
        )

    def finalize_compute(self):
        return HLGFinalizeCompute(
            self,
            low_level_optimizer=self.low_level_optimizer,
            output_keys=self.output_keys,
            postcompute=self.postcompute,
        )

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # Optimization has to be called (and cached) since blockwise fusion
        # can alter annotations; see
        # `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annotations_by_type[annot_type].update(
                        {k: (value(k) if callable(value) else value) for k in layer}
                    )
        return dict(annotations_by_type)

    def __dask_keys__(self):
        if (keys := self.operand("output_keys")) is not None:
            return keys
        dsk = self.hlg
        # Note: This will materialize
        dependencies = dsk.get_all_dependencies()
        leafs = set(dependencies)
        for val in dependencies.values():
            leafs -= val
        self.output_keys = list(leafs)
        return self.output_keys

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        keys = self.__dask_keys__()
        dsk = self.hlg
        if (optimizer := self.low_level_optimizer) is not None:
            dsk = optimizer(dsk, keys)
        return HighLevelGraph.merge(dsk)

    @property
    def deterministic_token(self):
        if not self._determ_token:
            self._determ_token = uuid.uuid4().hex
        return self._determ_token

    def _layer(self) -> dict:
        dsk = self._optimized_dsk
        return ensure_dict(dsk)


class _HLGExprGroup(HLGExpr):
    # Identical to HLGExpr; used internally to determine how output keys
    # are supposed to be returned
    pass


class _HLGExprSequence(Expr):
    def __getitem__(self, other):
        return self.operands[other]

    def _operands_for_repr(self):
        return [
            f"name={self.operand('name')!r}",
            f"dsk={self.operand('dsk')!r}",
        ]

    def _tree_repr_lines(self, indent=0, recursive=True):
        return self._operands_for_repr()

    def finalize_compute(self):
        return _HLGExprSequence(*[op.finalize_compute() for op in self.operands])

    def _tune_down(self):
        if len(self.operands) == 1:
            return None
        from dask.highlevelgraph import HighLevelGraph

        groups = toolz.groupby(
            lambda x: x.low_level_optimizer if isinstance(x, HLGExpr) else None,
            self.operands,
        )
        exprs = []
        changed = False
        for optimizer, group in groups.items():
            if len(group) > 1:
                graphs = [expr.hlg for expr in group]

                changed = True
                dsk = HighLevelGraph.merge(*graphs)
                hlg_group = _HLGExprGroup(
                    dsk=dsk,
                    low_level_optimizer=optimizer,
                    output_keys=[v.__dask_keys__() for v in group],
                    postcompute=[g.postcompute for g in group],
                )
                exprs.append(hlg_group)
            else:
                exprs.append(group[0])
        if not changed:
            return None
        return _HLGExprSequence(*exprs)

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        hlgexpr: HLGExpr
        graphs = []
        # simplify_down ensures there is only one HLGExpr per
        # optimizer/finalizer
        for hlgexpr in self.operands:
            keys = hlgexpr.__dask_keys__()
            dsk = hlgexpr.hlg
            if (optimizer := hlgexpr.low_level_optimizer) is not None:
                dsk = optimizer(dsk, keys)
            graphs.append(dsk)

        return HighLevelGraph.merge(*graphs)

    def __dask_graph__(self):
        # This class has to override this method (not just _layer) to ensure
        # the HLGs are not optimized individually
        return ensure_dict(self._optimized_dsk)

    _layer = __dask_graph__

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # Optimization has to be called (and cached) since blockwise fusion
        # can alter annotations; see
        # `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annots = list(
                        (k, (value(k) if callable(value) else value)) for k in layer
                    )
                    annotations_by_type[annot_type].update(
                        {
                            k: v
                            for k, v in annots
                            if not isinstance(v, _AnnotationsTombstone)
                        }
                    )
                    if not annotations_by_type[annot_type]:
                        del annotations_by_type[annot_type]
        return dict(annotations_by_type)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            if isinstance(op, _HLGExprGroup):
                all_keys.extend(op.__dask_keys__())
            else:
                all_keys.append(op.__dask_keys__())
        return all_keys


class _ExprSequence(Expr):
    """A sequence of expressions

    This is used to be able to optimize multiple collections combined, e.g.
    when being computed simultaneously with ``dask.compute((Expr1, Expr2))``.
    """

    def __getitem__(self, other):
        return self.operands[other]

    def _layer(self) -> dict:
        return toolz.merge(op._layer() for op in self.operands)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            all_keys.append(list(op.__dask_keys__()))
        return all_keys

    def __repr__(self):
        return "ExprSequence(" + ", ".join(map(repr, self.operands)) + ")"

    __str__ = __repr__

    def finalize_compute(self):
        return _ExprSequence(
            *(op.finalize_compute() for op in self.operands),
        )

    def __dask_annotations__(self):
        annotations_by_type = {}
        for op in self.operands:
            for k, v in op.__dask_annotations__().items():
                annotations_by_type.setdefault(k, {}).update(v)
        return annotations_by_type

    def __len__(self):
        return len(self.operands)

    def __iter__(self):
        return iter(self.operands)

    def _simplify_down(self):
        from dask.highlevelgraph import HighLevelGraph

        issue_warning = False
        hlgs = []
        for op in self.operands:
            if isinstance(op, (HLGExpr, HLGFinalizeCompute)):
                hlgs.append(op)
            elif isinstance(op, dict):
                hlgs.append(
                    HLGExpr(
                        dsk=HighLevelGraph.from_collections(
                            str(id(op)), op, dependencies=()
                        )
                    )
                )
            elif hlgs:
                issue_warning = True
                opt = op.optimize()
                hlgs.append(
                    HLGExpr(
                        dsk=HighLevelGraph.from_collections(
                            opt._name, opt.__dask_graph__(), dependencies=()
                        )
                    )
                )
        if issue_warning:
            warnings.warn(
                "Computing mixed collections that are backed by "
                "HighLevelGraphs/dicts and Expressions. "
                "This forces Expressions to be materialized. "
                "It is recommended to use only one type and to separate the "
                "dask.compute calls if necessary.",
                UserWarning,
            )
        if not hlgs:
            return None
        return _HLGExprSequence(*hlgs)


class _AnnotationsTombstone: ...


class FinalizeCompute(Expr):
    _parameters = ["expr"]

    def _simplify_down(self):
        return self.expr.finalize_compute()


def _convert_dask_keys(keys):
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    new_keys = []
    for key in keys:
        if isinstance(key, list):
            new_keys.append(_convert_dask_keys(key))
        else:
            new_keys.append(TaskRef(key))
    return List(*new_keys)
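

# Conversion sketch (keys are illustrative):
#
#     _convert_dask_keys([("a", 0), [("b", 0), ("b", 1)]])
#
# returns
#
#     List(TaskRef(("a", 0)), List(TaskRef(("b", 0)), TaskRef(("b", 1))))
#
# i.e. nested key lists become nested List nodes with TaskRef leaves.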


class HLGFinalizeCompute(HLGExpr):
    def _simplify_down(self):
        if not self.postcompute:
            return self.dsk

        from dask.delayed import Delayed

        # Skip finalization for Delayed
        if self.dsk.postcompute == Delayed.__dask_postcompute__(self.dsk):
            return self.dsk
        return self

    @property
    def _name(self):
        return f"finalize-{super()._name}"

    def __dask_graph__(self):
        # The base class __dask_graph__ would not just materialize this layer
        # but also that of its dependencies, i.e. it would render both the
        # finalized and the non-finalized graph and combine them. We only want
        # the finalized graph, so we override this.
        # This is an artifact of the wrapped expression being identified
        # automatically as a dependency, while HLG expressions do not work in
        # this layered way.
        return self._layer()

    @property
    def hlg(self):
        expr = self.operand("dsk")
        layers = expr.dsk.layers.copy()
        deps = expr.dsk.dependencies.copy()
        keys = expr.__dask_keys__()
        if isinstance(expr.postcompute, list):
            postcomputes = expr.postcompute
        else:
            postcomputes = [expr.postcompute]
        tasks = [
            Task(self._name, func, _convert_dask_keys(keys), *extra_args)
            for func, extra_args in postcomputes
        ]
        from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

        leafs = set(deps)
        for val in deps.values():
            leafs -= val
        for t in tasks:
            layers[t.key] = MaterializedLayer({t.key: t})
            deps[t.key] = leafs
        return HighLevelGraph(layers, dependencies=deps)

    def __dask_keys__(self):
        return [self._name]


class ProhibitReuse(Expr):
    """
    An expression that guarantees that all keys are suffixed with a unique id.
    This can be used to break a common subexpression apart.
    """

    _parameters = ["expr"]
    _ALLOWED_TYPES = [HLGExpr, LLGExpr, HLGFinalizeCompute, _HLGExprSequence]

    def __dask_keys__(self):
        return self._modify_keys(self.expr.__dask_keys__())

    @staticmethod
    def _identity(obj):
        return obj

    @functools.cached_property
    def _suffix(self):
        return uuid.uuid4().hex

    def _modify_keys(self, k):
        if isinstance(k, list):
            return [self._modify_keys(kk) for kk in k]
        elif isinstance(k, tuple):
            return (self._modify_keys(k[0]),) + k[1:]
        elif isinstance(k, (int, float)):
            k = str(k)
        return f"{k}-{self._suffix}"
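
    # Key rewriting sketch (uuid suffix shortened for readability):
    #
    #     self._modify_keys(("x", 0))    -> ("x-3f2a...", 0)
    #     self._modify_keys(["x", "y"])  -> ["x-3f2a...", "y-3f2a..."]
    #
    # The same suffix is reused for every key of this expression instance.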

    def _simplify_down(self):
        # FIXME: Shuffling cannot be rewritten since the barrier key is
        # hardcoded. Skipping this here should do the trick most of the time
        if not isinstance(
            self.expr,
            tuple(self._ALLOWED_TYPES),
        ):
            return self.expr

    def __dask_graph__(self):
        try:
            from distributed.shuffle._core import P2PBarrierTask
        except ModuleNotFoundError:
            P2PBarrierTask = type(None)
        dsk = convert_legacy_graph(self.expr.__dask_graph__())

        subs = {old_key: self._modify_keys(old_key) for old_key in dsk}
        dsk2 = {}
        for old_key, new_key in subs.items():
            t = dsk[old_key]
            if isinstance(t, P2PBarrierTask):
                warnings.warn(
                    "Cannot block reusing for graphs including a "
                    "P2PBarrierTask. This may cause unexpected results. "
                    "This typically happens when converting a dask "
                    "DataFrame to delayed objects.",
                    UserWarning,
                )
                return dsk
            dsk2[new_key] = Task(
                new_key,
                ProhibitReuse._identity,
                t.substitute(subs),
            )

        dsk2.update(dsk)
        return dsk2

    _layer = __dask_graph__