from __future__ import annotations

import functools
import os
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal

import toolz

import dask
from dask._task_spec import Task, convert_legacy_graph
from dask.tokenize import _tokenize_deterministic
from dask.typing import Key
from dask.utils import ensure_dict, funcname, import_required

if TYPE_CHECKING:
    # TODO import from typing (requires Python >=3.10)
    from typing import Any, TypeAlias

    from dask.highlevelgraph import HighLevelGraph

OptimizerStage: TypeAlias = Literal[
    "logical",
    "simplified-logical",
    "tuned-logical",
    "physical",
    "simplified-physical",
    "fused",
]


def _unpack_collections(o):
    from dask.delayed import Delayed

    if isinstance(o, Expr):
        return o

    if hasattr(o, "expr") and not isinstance(o, Delayed):
        return o.expr
    else:
        return o


class Expr:
    _parameters: list[str] = []
    _defaults: dict[str, Any] = {}

    _pickle_functools_cache: bool = True

    operands: list

    _determ_token: str | None

    def __new__(cls, *args, _determ_token=None, **kwargs):
        operands = list(args)
        for parameter in cls._parameters[len(operands) :]:
            try:
                operands.append(kwargs.pop(parameter))
            except KeyError:
                operands.append(cls._defaults[parameter])
        assert not kwargs, kwargs
        inst = object.__new__(cls)

        inst._determ_token = _determ_token
        inst.operands = [_unpack_collections(o) for o in operands]
        # This is typically cached. Make sure the cache is populated by
        # calling it once.
        inst._name
        return inst
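
    # Operand handling sketch (``Add`` below is illustrative, not part of
    # dask): positional arguments fill ``_parameters`` in order, keyword
    # arguments fill the rest, and ``_defaults`` covers anything omitted.
    #
    #     class Add(Expr):
    #         _parameters = ["left", "right", "fill_value"]
    #         _defaults = {"fill_value": 0}
    #
    #     Add(a, b)            # fill_value -> 0 via _defaults
    #     Add(a, right=b)      # keyword fills the remaining parameter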

    def _tune_down(self):
        return None

    def _tune_up(self, parent):
        return None

    def finalize_compute(self):
        return self

    def _operands_for_repr(self):
        return [f"{param}={op!r}" for param, op in zip(self._parameters, self.operands)]

    def __str__(self):
        s = ", ".join(self._operands_for_repr())
        return f"{type(self).__name__}({s})"

    def __repr__(self):
        return str(self)

    def _tree_repr_argument_construction(self, i, op, header):
        try:
            param = self._parameters[i]
            default = self._defaults[param]
        except (IndexError, KeyError):
            param = self._parameters[i] if i < len(self._parameters) else ""
            default = "--no-default--"

        if repr(op) != repr(default):
            if param:
                header += f" {param}={op!r}"
            else:
                header += repr(op)
        return header

    def _tree_repr_lines(self, indent=0, recursive=True):
        # Return a list of lines; ``tree_repr`` joins them and ``pprint``
        # iterates over them line by line.
        return [" " * indent + repr(self)]

    def tree_repr(self):
        return os.linesep.join(self._tree_repr_lines())

    def analyze(self, filename: str | None = None, format: str | None = None) -> None:
        from dask.dataframe.dask_expr._expr import Expr as DFExpr
        from dask.dataframe.dask_expr.diagnostics import analyze

        if not isinstance(self, DFExpr):
            raise TypeError(
                "analyze is only supported for dask.dataframe.Expr objects."
            )
        return analyze(self, filename=filename, format=format)

    def explain(
        self, stage: OptimizerStage = "fused", format: str | None = None
    ) -> None:
        from dask.dataframe.dask_expr.diagnostics import explain

        return explain(self, stage, format)

    def pprint(self):
        for line in self._tree_repr_lines():
            print(line)

    def __hash__(self):
        return hash(self._name)

    def __dask_tokenize__(self):
        if not self._determ_token:
            # If the subclass does not implement a __dask_tokenize__ we'll
            # want to tokenize all operands.
            # Note how this differs from the implementation of
            # Expr.deterministic_token
            self._determ_token = _tokenize_deterministic(type(self), *self.operands)
        return self._determ_token

    def __dask_keys__(self):
        """The keys for this expression

        This is used to determine the keys of the output collection
        when this expression is computed.

        Returns
        -------
        keys: list
            The keys for this expression
        """
        return [(self._name, i) for i in range(self.npartitions)]
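
    # For an expression with, say, three partitions and a name such as
    # "add-<token>", this yields
    # [("add-<token>", 0), ("add-<token>", 1), ("add-<token>", 2)].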

    @staticmethod
    def _reconstruct(*args):
        typ, *operands, token, cache = args
        inst = typ(*operands, _determ_token=token)
        for k, v in cache.items():
            inst.__dict__[k] = v
        return inst

    def __reduce__(self):
        if dask.config.get("dask-expr-no-serialize", False):
            raise RuntimeError(f"Serializing a {type(self)} object")
        cache = {}
        if type(self)._pickle_functools_cache:
            for k, v in type(self).__dict__.items():
                if isinstance(v, functools.cached_property) and k in self.__dict__:
                    cache[k] = getattr(self, k)

        return Expr._reconstruct, (
            type(self),
            *self.operands,
            self.deterministic_token,
            cache,
        )

    def _depth(self, cache=None):
        """Depth of the expression tree

        Returns
        -------
        depth: int
        """
        if cache is None:
            cache = {}
        if not self.dependencies():
            return 1
        else:
            result = []
            for expr in self.dependencies():
                if expr._name in cache:
                    result.append(cache[expr._name])
                else:
                    result.append(expr._depth(cache) + 1)
                    cache[expr._name] = result[-1]
            return max(result)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ["operands", "_determ_token"]:
            object.__setattr__(self, name, value)
            return
        try:
            params = type(self)._parameters
            operands = object.__getattribute__(self, "operands")
            operands[params.index(name)] = value
        except ValueError:
            raise AttributeError(
                f"{type(self).__name__} object has no attribute {name}"
            )

    def operand(self, key):
        # Access an operand unambiguously
        # (e.g. if the key is reserved by a method/property)
        return self.operands[type(self)._parameters.index(key)]

    def dependencies(self):
        # Dependencies are `Expr` operands only
        return [operand for operand in self.operands if isinstance(operand, Expr)]

    def _task(self, key: Key, index: int) -> Task:
        """The task for the i'th partition

        Parameters
        ----------
        key:
            The key of the task
        index:
            The index of the partition of this dataframe

        Examples
        --------
        >>> class Add(Expr):
        ...     def _task(self, key, index):
        ...         return Task(
        ...             key,
        ...             operator.add,
        ...             TaskRef((self.left._name, index)),
        ...             TaskRef((self.right._name, index)),
        ...         )

        Returns
        -------
        task:
            The Dask task to compute this partition

        See Also
        --------
        Expr._layer
        """
        raise NotImplementedError(
            "Expressions should define either _layer (full dictionary) or _task"
            f" (single task). This expression {type(self)} defines neither"
        )

    def _layer(self) -> dict:
        """The graph layer added by this expression.

        Simple expressions that apply one task per partition can choose to
        only implement `Expr._task` instead.

        Examples
        --------
        >>> class Add(Expr):
        ...     def _layer(self):
        ...         return {
        ...             name: Task(
        ...                 name,
        ...                 operator.add,
        ...                 TaskRef((self.left._name, i)),
        ...                 TaskRef((self.right._name, i)),
        ...             )
        ...             for i, name in enumerate(self.__dask_keys__())
        ...         }

        Returns
        -------
        layer: dict
            The Dask task graph added by this expression

        See Also
        --------
        Expr._task
        Expr.__dask_graph__
        """
        return {
            (self._name, i): self._task((self._name, i), i)
            for i in range(self.npartitions)
        }

    def rewrite(self, kind: str, rewritten):
        """Rewrite an expression

        This leverages the ``._{kind}_down`` and ``._{kind}_up``
        methods defined on each class

        Returns
        -------
        expr:
            output expression
        """
        if self._name in rewritten:
            return rewritten[self._name]

        expr = self
        down_name = f"_{kind}_down"
        up_name = f"_{kind}_up"
        while True:
            _continue = False

            # Rewrite this node
            out = getattr(expr, down_name)()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out
                continue

            # Allow children to rewrite their parents
            for child in expr.dependencies():
                out = getattr(child, up_name)(expr)
                if out is None:
                    out = expr
                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    _continue = True
                    break

            if _continue:
                continue

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    new = operand.rewrite(kind=kind, rewritten=rewritten)
                    rewritten[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)
                continue
            else:
                break

        return expr
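
    # A sketch of the rewrite protocol with a hypothetical rule (``Double``
    # and ``Quadruple`` are not part of dask): ``_{kind}_down`` lets a node
    # replace itself, ``_{kind}_up`` lets a child replace its parent, and
    # returning ``None`` means "no change".
    #
    #     class Double(Expr):
    #         def _tune_down(self):
    #             if isinstance(self.operands[0], Double):
    #                 return Quadruple(self.operands[0].operands[0])
    #
    #     expr.rewrite(kind="tune", rewritten={})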

    def simplify_once(self, dependents: defaultdict, simplified: dict):
        """Simplify an expression

        This leverages the ``._simplify_down`` and ``._simplify_up``
        methods defined on each class

        Parameters
        ----------
        dependents: defaultdict[list]
            The dependents for every node.
        simplified: dict
            Cache of simplified expressions for these dependents.

        Returns
        -------
        expr:
            output expression
        """
        # Check if we've already simplified for these dependents
        if self._name in simplified:
            return simplified[self._name]

        expr = self

        while True:
            out = expr._simplify_down()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out

            # Allow children to simplify their parents
            for child in expr.dependencies():
                out = child._simplify_up(expr, dependents)
                if out is None:
                    out = expr

                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    break

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    # Bandaid for now, waiting for Singleton
                    dependents[operand._name].append(weakref.ref(expr))
                    new = operand.simplify_once(
                        dependents=dependents, simplified=simplified
                    )
                    simplified[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)

            break

        return expr

    def optimize(self, fuse: bool = False) -> Expr:
        stage: OptimizerStage = "fused" if fuse else "simplified-physical"

        return optimize_until(self, stage)

    def fuse(self) -> Expr:
        return self

    def simplify(self) -> Expr:
        expr = self
        seen = set()
        while True:
            dependents = collect_dependents(expr)
            new = expr.simplify_once(dependents=dependents, simplified={})
            if new._name == expr._name:
                break
            if new._name in seen:
                raise RuntimeError(
                    f"Optimizer does not converge. {expr!r} simplified to "
                    f"{new!r}, which was already seen. Please report this "
                    "issue on the dask issue tracker with a minimal reproducer."
                )
            seen.add(new._name)
            expr = new
        return expr
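
    # ``simplify`` iterates ``simplify_once`` to a fixed point: it stops once
    # a pass returns an expression with an unchanged ``_name`` and raises if
    # a previously seen name reappears, which would indicate a rewrite cycle
    # rather than convergence.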

    def _simplify_down(self):
        return

    def _simplify_up(self, parent, dependents):
        return

    def lower_once(self, lowered: dict):
        # Check for a cached result
        try:
            return lowered[self._name]
        except KeyError:
            pass

        expr = self

        # Lower this node
        out = expr._lower()
        if out is None:
            out = expr
        if not isinstance(out, Expr):
            return out

        # Lower all children
        new_operands = []
        changed = False
        for operand in out.operands:
            if isinstance(operand, Expr):
                new = operand.lower_once(lowered)
                if new._name != operand._name:
                    changed = True
            else:
                new = operand
            new_operands.append(new)

        if changed:
            out = type(out)(*new_operands)

        # Cache the result and return
        return lowered.setdefault(self._name, out)

    def lower_completely(self) -> Expr:
        """Lower an expression completely

        This calls the ``lower_once`` method in a loop
        until nothing changes. This function does not
        apply any other optimizations (like ``simplify``).

        Returns
        -------
        expr:
            output expression

        See Also
        --------
        Expr.lower_once
        Expr._lower
        """
        # Lower until nothing changes
        expr = self
        lowered: dict = {}
        while True:
            new = expr.lower_once(lowered)
            if new._name == expr._name:
                break
            expr = new
        return expr

    def _lower(self):
        return

    @functools.cached_property
    def _funcname(self) -> str:
        return funcname(type(self)).lower()

    @property
    def deterministic_token(self):
        if not self._determ_token:
            # Just tokenize self to fall back on __dask_tokenize__
            # Note how this differs from the implementation of
            # __dask_tokenize__
            self._determ_token = self.__dask_tokenize__()
        return self._determ_token

    @functools.cached_property
    def _name(self) -> str:
        return self._funcname + "-" + self.deterministic_token

    @property
    def _meta(self):
        raise NotImplementedError()

    @classmethod
    def _annotations_tombstone(cls) -> _AnnotationsTombstone:
        return _AnnotationsTombstone()

    def __dask_annotations__(self):
        return {}

    def __dask_graph__(self):
        """Traverse expression tree, collect layers

        Subclasses generally do not want to override this method unless custom
        logic is required to treat (e.g. ignore) specific operands during graph
        generation.

        See also
        --------
        Expr._layer
        Expr._task
        """
        stack = [self]
        seen = set()
        layers = []
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            layers.append(expr._layer())
            for operand in expr.dependencies():
                stack.append(operand)

        return toolz.merge(layers)

    @property
    def dask(self):
        return self.__dask_graph__()

    def substitute(self, old, new) -> Expr:
        """Substitute a specific term within the expression

        Note that replacing non-`Expr` terms may produce
        unexpected results, and is not recommended.
        Substituting boolean values is not allowed.

        Parameters
        ----------
        old:
            Old term to find and replace.
        new:
            New term to replace instances of `old` with.

        Examples
        --------
        >>> (df + 10).substitute(10, 20)  # doctest: +SKIP
        df + 20
        """
        return self._substitute(old, new, _seen=set())

    def _substitute(self, old, new, _seen):
        if self._name in _seen:
            return self
        # Check if we are replacing a literal
        if isinstance(old, Expr):
            substitute_literal = False
            if self._name == old._name:
                return new
        else:
            substitute_literal = True
            if isinstance(old, bool):
                raise TypeError("Arguments to `substitute` cannot be bool.")

        new_exprs = []
        update = False
        for operand in self.operands:
            if isinstance(operand, Expr):
                val = operand._substitute(old, new, _seen)
                if operand._name != val._name:
                    update = True
                new_exprs.append(val)
            elif (
                "Fused" in type(self).__name__
                and isinstance(operand, list)
                and all(isinstance(op, Expr) for op in operand)
            ):
                # Special handling for `Fused`.
                # We make no promise to dive through a
                # list operand in general, but NEED to
                # do so for the `Fused.exprs` operand.
                val = []
                for op in operand:
                    val.append(op._substitute(old, new, _seen))
                    if val[-1]._name != op._name:
                        update = True
                new_exprs.append(val)
            elif (
                substitute_literal
                and not isinstance(operand, bool)
                and isinstance(operand, type(old))
                and operand == old
            ):
                new_exprs.append(new)
                update = True
            else:
                new_exprs.append(operand)

        if update:  # Only recreate if something changed
            return type(self)(*new_exprs)
        else:
            _seen.add(self._name)
            return self

    def substitute_parameters(self, substitutions: dict) -> Expr:
        """Substitute specific `Expr` parameters

        Parameters
        ----------
        substitutions:
            Mapping of parameter keys to new values. Keys that
            are not found in ``self._parameters`` will be ignored.
        """
        if not substitutions:
            return self

        changed = False
        new_operands = []
        for i, operand in enumerate(self.operands):
            if i < len(self._parameters) and self._parameters[i] in substitutions:
                new_operands.append(substitutions[self._parameters[i]])
                changed = True
            else:
                new_operands.append(operand)
        if changed:
            return type(self)(*new_operands)
        return self
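
    # For example, with hypothetical parameter names, an expression defined
    # with ``_parameters = ["npartitions", "seed"]`` could be rebuilt via
    #
    #     expr.substitute_parameters({"npartitions": 4})
    #
    # which replaces only the ``npartitions`` operand and leaves ``seed``
    # untouched.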

    def _node_label_args(self):
        """Operands to include in the node label by `visualize`"""
        return self.dependencies()

    def _to_graphviz(
        self,
        rankdir="BT",
        graph_attr=None,
        node_attr=None,
        edge_attr=None,
        **kwargs,
    ):
        from dask.dot import label, name

        graphviz = import_required(
            "graphviz",
            "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
            "python library and the `graphviz` system library.\n\n"
            "Please install with either conda or pip as follows:\n\n"
            "  conda install python-graphviz     # either conda install\n"
            "  python -m pip install graphviz    # or pip install and follow installation instructions",
        )

        graph_attr = graph_attr or {}
        node_attr = node_attr or {}
        edge_attr = edge_attr or {}

        graph_attr["rankdir"] = rankdir
        node_attr["shape"] = "box"
        node_attr["fontname"] = "helvetica"

        graph_attr.update(kwargs)
        g = graphviz.Digraph(
            graph_attr=graph_attr,
            node_attr=node_attr,
            edge_attr=edge_attr,
        )

        stack = [self]
        seen = set()
        dependencies = {}
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            dependencies[expr] = set(expr.dependencies())
            for dep in expr.dependencies():
                stack.append(dep)

        cache = {}
        for expr in dependencies:
            expr_name = name(expr)
            attrs = {}

            # Make node label
            deps = [
                funcname(type(dep)) if isinstance(dep, Expr) else str(dep)
                for dep in expr._node_label_args()
            ]
            _label = funcname(type(expr))
            if deps:
                _label = f"{_label}({', '.join(deps)})"
            node_label = label(_label, cache=cache)

            attrs.setdefault("label", str(node_label))
            attrs.setdefault("fontsize", "20")
            g.node(expr_name, **attrs)

        for expr, deps in dependencies.items():
            expr_name = name(expr)
            for dep in deps:
                dep_name = name(dep)
                g.edge(dep_name, expr_name)

        return g

    def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
        """
        Visualize the expression graph.
        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.svg' will be used by default.
            If `filename` is None, no file will be written, and the graph is
            rendered in the Jupyter notebook only.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'svg'.
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.
        """
        from dask.dot import graphviz_to_file

        g = self._to_graphviz(**kwargs)
        graphviz_to_file(g, filename, format)
        return g

    def walk(self) -> Generator[Expr]:
        """Iterate through all expressions in the tree

        Returns
        -------
        nodes
            Generator of Expr instances in the graph.
            Ordering is a depth-first search of the expression tree
        """
        stack = [self]
        seen = set()
        while stack:
            node = stack.pop()
            if node._name in seen:
                continue
            seen.add(node._name)

            for dep in node.dependencies():
                stack.append(dep)

            yield node

    def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
        """Search the expression graph for a specific operation type

        Parameters
        ----------
        operation
            The operation type to search for.

        Returns
        -------
        nodes
            Generator of `operation` instances. Ordering corresponds
            to a depth-first search of the expression graph.
        """
        assert (
            isinstance(operation, tuple)
            and all(issubclass(e, Expr) for e in operation)
            or issubclass(operation, Expr)  # type: ignore
        ), "`operation` must be an `Expr` subclass"
        return (expr for expr in self.walk() if isinstance(expr, operation))
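
    # Usage sketch (``ReadParquet`` stands in for any concrete Expr
    # subclass):
    #
    #     for node in expr.find_operations(ReadParquet):
    #         print(node._name)
    #
    # Since ``walk`` deduplicates by ``_name``, shared subexpressions are
    # yielded only once.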

    def __getattr__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError as err:
            if key.startswith("_meta"):
                # Avoid a recursive loop if/when `self._meta*`
                # produces an `AttributeError`
                raise RuntimeError(
                    f"Failed to generate metadata for {self}. "
                    "This operation may not be supported by the current backend."
                )

            # Allow operands to be accessed as attributes
            # as long as the keys are not already reserved
            # by existing methods/properties
            _parameters = type(self)._parameters
            if key in _parameters:
                idx = _parameters.index(key)
                return self.operands[idx]

            raise AttributeError(
                f"{err}\n\n"
                "This often means that you are attempting to use an unsupported "
                "API function."
            )


class SingletonExpr(Expr):
    """A singleton Expr class

    This is used to treat the subclassed expression as a singleton. Singletons
    are deduplicated by ``expr._name``, which is typically based on the
    ``dask.tokenize`` output.

    This is a crucial performance optimization for expressions that walk
    through an optimizer and are recreated repeatedly, but it is not safe for
    objects that cannot be reliably or quickly tokenized.
    """

    _instances: weakref.WeakValueDictionary[str, SingletonExpr]

    def __new__(cls, *args, _determ_token=None, **kwargs):
        if not hasattr(cls, "_instances"):
            cls._instances = weakref.WeakValueDictionary()
        inst = super().__new__(cls, *args, _determ_token=_determ_token, **kwargs)
        _name = inst._name
        if _name in cls._instances and cls.__init__ == object.__init__:
            return cls._instances[_name]

        cls._instances[_name] = inst
        return inst
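
# Because instances are cached in a WeakValueDictionary keyed by ``_name``,
# constructing an equivalent singleton expression twice returns the same
# object, provided the subclass defines no custom __init__ (``MySingleton``
# is a hypothetical subclass):
#
#     a = MySingleton(x)
#     b = MySingleton(x)
#     assert a is b  # deduplicated by _name while `a` is still alive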


def collect_dependents(expr) -> defaultdict:
    dependents = defaultdict(list)
    stack = [expr]
    seen = set()
    while stack:
        node = stack.pop()
        if node._name in seen:
            continue
        seen.add(node._name)

        for dep in node.dependencies():
            stack.append(dep)
            dependents[dep._name].append(weakref.ref(node))
    return dependents
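
# The result maps each node's name to weak references of its parents: for a
# hypothetical ``c = Add(a, b)``, ``collect_dependents(c)[a._name]`` holds a
# ``weakref.ref`` to ``c``. Weak references avoid pinning otherwise
# collectable expressions alive during simplification.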


def optimize(expr: Expr, fuse: bool = True) -> Expr:
    """High level query optimization

    This leverages two optimization passes:

    1. Class based simplification using the ``_simplify`` function and methods
    2. Blockwise fusion

    Parameters
    ----------
    expr:
        Input expression to optimize
    fuse:
        whether or not to turn on blockwise fusion

    See Also
    --------
    simplify
    optimize_blockwise_fusion
    """
    stage: OptimizerStage = "fused" if fuse else "simplified-physical"

    return optimize_until(expr, stage)


def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
    result = expr
    if stage == "logical":
        return result

    # Simplify
    expr = result.simplify()
    if stage == "simplified-logical":
        return expr

    # Manipulate the expression to make it more efficient
    expr = expr.rewrite(kind="tune", rewritten={})
    if stage == "tuned-logical":
        return expr

    # Lower
    expr = expr.lower_completely()
    if stage == "physical":
        return expr

    # Simplify again
    expr = expr.simplify()
    if stage == "simplified-physical":
        return expr

    # Final graph-specific optimizations
    expr = expr.fuse()
    if stage == "fused":
        return expr

    raise ValueError(f"Stage {stage!r} not supported.")
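
# The stages form a strict pipeline mirroring ``OptimizerStage``:
# "logical" -> "simplified-logical" -> "tuned-logical" -> "physical" ->
# "simplified-physical" -> "fused". For instance, to stop right before
# blockwise fusion:
#
#     expr = optimize_until(expr, "simplified-physical")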


class LLGExpr(Expr):
    """Low Level Graph Expression"""

    _parameters = ["dsk"]

    def __dask_keys__(self):
        return list(self.operand("dsk"))

    def _layer(self) -> dict:
        return ensure_dict(self.operand("dsk"))


class HLGExpr(Expr):
    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @property
    def hlg(self):
        return self.operand("dsk")

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when
        # going through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )
        if optimize_graph and not hasattr(collection, "__dask_optimize__"):
            warnings.warn(
                f"Collection {type(collection)} does not define a "
                "`__dask_optimize__` method. In the future this will raise. "
                "If no optimization is desired, please set this to `None`.",
                PendingDeprecationWarning,
            )
            low_level_optimizer = None
        else:
            low_level_optimizer = (
                collection.__dask_optimize__ if optimize_graph else None
            )
        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__(),
        )

    def finalize_compute(self):
        return HLGFinalizeCompute(
            self,
            low_level_optimizer=self.low_level_optimizer,
            output_keys=self.output_keys,
            postcompute=self.postcompute,
        )

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion
        # can alter annotations, see
        # `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annotations_by_type[annot_type].update(
                        {k: (value(k) if callable(value) else value) for k in layer}
                    )
        return dict(annotations_by_type)
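
    # The returned mapping groups annotations first by annotation type and
    # then by key, e.g. (illustrative values):
    #
    #     {"retries": {("x", 0): 2, ("x", 1): 2},
    #      "priority": {("x", 0): 10}}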

    def __dask_keys__(self):
        if (keys := self.operand("output_keys")) is not None:
            return keys
        dsk = self.hlg
        # Note: This will materialize
        dependencies = dsk.get_all_dependencies()
        leafs = set(dependencies)
        for val in dependencies.values():
            leafs -= val
        self.output_keys = list(leafs)
        return self.output_keys

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        keys = self.__dask_keys__()
        dsk = self.hlg
        if (optimizer := self.low_level_optimizer) is not None:
            dsk = optimizer(dsk, keys)
        return HighLevelGraph.merge(dsk)

    @property
    def deterministic_token(self):
        if not self._determ_token:
            self._determ_token = uuid.uuid4().hex
        return self._determ_token

    def _layer(self) -> dict:
        dsk = self._optimized_dsk
        return ensure_dict(dsk)


class _HLGExprGroup(HLGExpr):
    # Identical to HLGExpr
    # Used internally to determine how output keys are supposed to be returned
    pass


class _HLGExprSequence(Expr):
    def __getitem__(self, other):
        return self.operands[other]

    def _operands_for_repr(self):
        return [
            f"name={self.operand('name')!r}",
            f"dsk={self.operand('dsk')!r}",
        ]

    def _tree_repr_lines(self, indent=0, recursive=True):
        return self._operands_for_repr()

    def finalize_compute(self):
        return _HLGExprSequence(*[op.finalize_compute() for op in self.operands])

    def _tune_down(self):
        if len(self.operands) == 1:
            return None
        from dask.highlevelgraph import HighLevelGraph

        groups = toolz.groupby(
            lambda x: x.low_level_optimizer if isinstance(x, HLGExpr) else None,
            self.operands,
        )
        exprs = []
        changed = False
        for optimizer, group in groups.items():
            if len(group) > 1:
                graphs = [expr.hlg for expr in group]

                changed = True
                dsk = HighLevelGraph.merge(*graphs)
                hlg_group = _HLGExprGroup(
                    dsk=dsk,
                    low_level_optimizer=optimizer,
                    output_keys=[v.__dask_keys__() for v in group],
                    postcompute=[g.postcompute for g in group],
                )
                exprs.append(hlg_group)
            else:
                exprs.append(group[0])
        if not changed:
            return None
        return _HLGExprSequence(*exprs)

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        hlgexpr: HLGExpr
        graphs = []
        # _tune_down ensures there is only one HLGExpr per optimizer/finalizer
        for hlgexpr in self.operands:
            keys = hlgexpr.__dask_keys__()
            dsk = hlgexpr.hlg
            if (optimizer := hlgexpr.low_level_optimizer) is not None:
                dsk = optimizer(dsk, keys)
            graphs.append(dsk)

        return HighLevelGraph.merge(*graphs)

    def __dask_graph__(self):
        # This class has to override this method, and not just _layer, to
        # ensure the HLGs are not optimized individually
        return ensure_dict(self._optimized_dsk)

    _layer = __dask_graph__

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion
        # can alter annotations, see
        # `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annots = list(
                        (k, (value(k) if callable(value) else value)) for k in layer
                    )
                    annotations_by_type[annot_type].update(
                        {
                            k: v
                            for k, v in annots
                            if not isinstance(v, _AnnotationsTombstone)
                        }
                    )
                    if not annotations_by_type[annot_type]:
                        del annotations_by_type[annot_type]
        return dict(annotations_by_type)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            if isinstance(op, _HLGExprGroup):
                all_keys.extend(op.__dask_keys__())
            else:
                all_keys.append(op.__dask_keys__())
        return all_keys


class _ExprSequence(Expr):
    """A sequence of expressions

    This is used to be able to optimize multiple collections combined, e.g.
    when being computed simultaneously with ``dask.compute((Expr1, Expr2))``.
    """

    def __getitem__(self, other):
        return self.operands[other]

    def _layer(self) -> dict:
        return toolz.merge(op._layer() for op in self.operands)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            all_keys.append(list(op.__dask_keys__()))
        return all_keys

    def __repr__(self):
        return "ExprSequence(" + ", ".join(map(repr, self.operands)) + ")"

    __str__ = __repr__

    def finalize_compute(self):
        return _ExprSequence(
            *(op.finalize_compute() for op in self.operands),
        )

    def __dask_annotations__(self):
        annotations_by_type = {}
        for op in self.operands:
            for k, v in op.__dask_annotations__().items():
                annotations_by_type.setdefault(k, {}).update(v)
        return annotations_by_type

    def __len__(self):
        return len(self.operands)

    def __iter__(self):
        return iter(self.operands)

    def _simplify_down(self):
        from dask.highlevelgraph import HighLevelGraph

        issue_warning = False
        hlgs = []
        if any(
            isinstance(op, (HLGExpr, HLGFinalizeCompute, dict)) for op in self.operands
        ):
            for op in self.operands:
                if isinstance(op, (HLGExpr, HLGFinalizeCompute)):
                    hlgs.append(op)
                elif isinstance(op, dict):
                    hlgs.append(
                        HLGExpr(
                            dsk=HighLevelGraph.from_collections(
                                str(id(op)), op, dependencies=()
                            )
                        )
                    )
                else:
                    issue_warning = True
                    opt = op.optimize()
                    hlgs.append(
                        HLGExpr(
                            dsk=HighLevelGraph.from_collections(
                                opt._name, opt.__dask_graph__(), dependencies=()
                            )
                        )
                    )
        if issue_warning:
            warnings.warn(
                "Computing mixed collections that are backed by "
                "HighLevelGraphs/dicts and Expressions. "
                "This forces Expressions to be materialized. "
                "It is recommended to use only one type and to separate the "
                "dask.compute calls if necessary.",
                UserWarning,
            )
        if not hlgs:
            return None
        return _HLGExprSequence(*hlgs)


class _AnnotationsTombstone: ...


class FinalizeCompute(Expr):
    _parameters = ["expr"]

    def _simplify_down(self):
        return self.expr.finalize_compute()


def _convert_dask_keys(keys):
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    new_keys = []
    for key in keys:
        if isinstance(key, list):
            new_keys.append(_convert_dask_keys(key))
        else:
            new_keys.append(TaskRef(key))
    return List(*new_keys)
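
# ``_convert_dask_keys`` mirrors the (possibly nested) structure of the lists
# returned by ``__dask_keys__``, e.g. (illustrative keys):
#
#     _convert_dask_keys([("x", 0), [("y", 0), ("y", 1)]])
#     # -> List(TaskRef(("x", 0)), List(TaskRef(("y", 0)), TaskRef(("y", 1))))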


class HLGFinalizeCompute(HLGExpr):
    def _simplify_down(self):
        if not self.postcompute:
            return self.dsk

        from dask.delayed import Delayed

        # Skip finalization for Delayed
        if self.dsk.postcompute == Delayed.__dask_postcompute__(self.dsk):
            return self.dsk
        return self

    @property
    def _name(self):
        return f"finalize-{super()._name}"

    def __dask_graph__(self):
        # The baseclass __dask_graph__ will not just materialize this layer
        # but also that of its dependencies, i.e. it would render both the
        # finalized and the non-finalized graph and combine them. We only
        # want the finalized graph, so we're overriding this.
        # This is an artifact of the wrapped expression being identified
        # automatically as a dependency, while HLG expressions do not work
        # in this layered way.
        return self._layer()

    @property
    def hlg(self):
        expr = self.operand("dsk")
        layers = expr.dsk.layers.copy()
        deps = expr.dsk.dependencies.copy()
        keys = expr.__dask_keys__()
        if isinstance(expr.postcompute, list):
            postcomputes = expr.postcompute
        else:
            postcomputes = [expr.postcompute]
        tasks = [
            Task(self._name, func, _convert_dask_keys(keys), *extra_args)
            for func, extra_args in postcomputes
        ]
        from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

        leafs = set(deps)
        for val in deps.values():
            leafs -= val
        for t in tasks:
            layers[t.key] = MaterializedLayer({t.key: t})
            deps[t.key] = leafs
        return HighLevelGraph(layers, dependencies=deps)

    def __dask_keys__(self):
        return [self._name]


class ProhibitReuse(Expr):
    """
    An expression that guarantees that all keys are suffixed with a unique id.
    This can be used to break a common subexpression apart.
    """

    _parameters = ["expr"]
    _ALLOWED_TYPES = [HLGExpr, LLGExpr, HLGFinalizeCompute, _HLGExprSequence]

    def __dask_keys__(self):
        return self._modify_keys(self.expr.__dask_keys__())

    @staticmethod
    def _identity(obj):
        return obj

    @functools.cached_property
    def _suffix(self):
        return uuid.uuid4().hex

    def _modify_keys(self, k):
        if isinstance(k, list):
            return [self._modify_keys(kk) for kk in k]
        elif isinstance(k, tuple):
            return (self._modify_keys(k[0]),) + k[1:]
        elif isinstance(k, (int, float)):
            k = str(k)
        return f"{k}-{self._suffix}"
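
    # Keys are rewritten recursively while preserving their structure, e.g.
    # (with an illustrative suffix):
    #
    #     ("x", 0)  ->  ("x-<uuid>", 0)
    #     "y"       ->  "y-<uuid>"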

    def _simplify_down(self):
        # FIXME: Shuffling cannot be rewritten since the barrier key is
        # hardcoded. Skipping this here should do the trick most of the time
        if not isinstance(
            self.expr,
            tuple(self._ALLOWED_TYPES),
        ):
            return self.expr

    def __dask_graph__(self):
        try:
            from distributed.shuffle._core import P2PBarrierTask
        except ModuleNotFoundError:
            P2PBarrierTask = type(None)
        dsk = convert_legacy_graph(self.expr.__dask_graph__())

        subs = {old_key: self._modify_keys(old_key) for old_key in dsk}
        dsk2 = {}
        for old_key, new_key in subs.items():
            t = dsk[old_key]
            if isinstance(t, P2PBarrierTask):
                warnings.warn(
                    "Cannot block reuse for graphs including a "
                    "P2PBarrierTask. This may cause unexpected results. "
                    "This typically happens when converting a dask "
                    "DataFrame to delayed objects.",
                    UserWarning,
                )
                return dsk
            dsk2[new_key] = Task(
                new_key,
                ProhibitReuse._identity,
                t.substitute(subs),
            )
        dsk2.update(dsk)
        return dsk2

    _layer = __dask_graph__