from __future__ import annotations

import functools
import os
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal

import toolz

import dask
from dask._task_spec import Task, convert_legacy_graph
from dask.tokenize import _tokenize_deterministic
from dask.typing import Key
from dask.utils import ensure_dict, funcname, import_required

if TYPE_CHECKING:
    # TODO import from typing (requires Python >=3.10)
    from typing import Any, TypeAlias

    from dask.highlevelgraph import HighLevelGraph

OptimizerStage: TypeAlias = Literal[
    "logical",
    "simplified-logical",
    "tuned-logical",
    "physical",
    "simplified-physical",
    "fused",
]


def _unpack_collections(o):
    from dask.delayed import Delayed

    if isinstance(o, Expr):
        return o

    if hasattr(o, "expr") and not isinstance(o, Delayed):
        return o.expr
    else:
        return o


class Expr:
    _parameters: list[str] = []
    _defaults: dict[str, Any] = {}

    _pickle_functools_cache: bool = True

    operands: list

    _determ_token: str | None

    def __new__(cls, *args, _determ_token=None, **kwargs):
        operands = list(args)
        for parameter in cls._parameters[len(operands) :]:
            try:
                operands.append(kwargs.pop(parameter))
            except KeyError:
                operands.append(cls._defaults[parameter])
        assert not kwargs, kwargs
        inst = object.__new__(cls)

        inst._determ_token = _determ_token
        inst.operands = [_unpack_collections(o) for o in operands]
        # This is typically cached. Make sure the cache is populated by calling
        # it once
        inst._name
        return inst
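
    # Example (illustrative sketch, not part of this module): a subclass
    # declares its positional parameters in ``_parameters`` and optional
    # defaults in ``_defaults``; ``__new__`` maps positional and keyword
    # arguments onto ``operands`` in that order::
    #
    #     class MyAdd(Expr):
    #         _parameters = ["left", "right", "name"]
    #         _defaults = {"name": None}
    #
    #     e = MyAdd(1, 2)
    #     assert e.operands == [1, 2, None]   # "name" filled from _defaults
    #     assert e.operand("right") == 2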

    def _tune_down(self):
        return None

    def _tune_up(self, parent):
        return None

    def finalize_compute(self):
        return self

    def _operands_for_repr(self):
        return [f"{param}={op!r}" for param, op in zip(self._parameters, self.operands)]

    def __str__(self):
        s = ", ".join(self._operands_for_repr())
        return f"{type(self).__name__}({s})"

    def __repr__(self):
        return str(self)

    def _tree_repr_argument_construction(self, i, op, header):
        try:
            param = self._parameters[i]
            default = self._defaults[param]
        except (IndexError, KeyError):
            param = self._parameters[i] if i < len(self._parameters) else ""
            default = "--no-default--"

        if repr(op) != repr(default):
            if param:
                header += f" {param}={op!r}"
            else:
                header += repr(op)
        return header

    def _tree_repr_lines(self, indent=0, recursive=True):
        # Return a list of lines so that tree_repr/pprint can join and print
        # line-wise (joining a plain string would split it per character).
        return [" " * indent + repr(self)]

    def tree_repr(self):
        return os.linesep.join(self._tree_repr_lines())

    def analyze(self, filename: str | None = None, format: str | None = None) -> None:
        from dask.dataframe.dask_expr._expr import Expr as DFExpr
        from dask.dataframe.dask_expr.diagnostics import analyze

        if not isinstance(self, DFExpr):
            raise TypeError(
                "analyze is only supported for dask.dataframe.Expr objects."
            )
        return analyze(self, filename=filename, format=format)

    def explain(
        self, stage: OptimizerStage = "fused", format: str | None = None
    ) -> None:
        from dask.dataframe.dask_expr.diagnostics import explain

        return explain(self, stage, format)

    def pprint(self):
        for line in self._tree_repr_lines():
            print(line)

    def __hash__(self):
        return hash(self._name)

    def __dask_tokenize__(self):
        if not self._determ_token:
            # If the subclass does not implement a __dask_tokenize__ we'll want
            # to tokenize all operands.
            # Note how this differs to the implementation of
            # Expr.deterministic_token
            self._determ_token = _tokenize_deterministic(type(self), *self.operands)
        return self._determ_token

    def __dask_keys__(self):
        """The keys for this expression

        This is used to determine the keys of the output collection
        when this expression is computed.

        Returns
        -------
        keys: list
            The keys for this expression
        """
        return [(self._name, i) for i in range(self.npartitions)]

    @staticmethod
    def _reconstruct(*args):
        typ, *operands, token, cache = args
        inst = typ(*operands, _determ_token=token)
        for k, v in cache.items():
            inst.__dict__[k] = v
        return inst

    def __reduce__(self):
        if dask.config.get("dask-expr-no-serialize", False):
            raise RuntimeError(f"Serializing a {type(self)} object")
        cache = {}
        if type(self)._pickle_functools_cache:
            for k, v in type(self).__dict__.items():
                if isinstance(v, functools.cached_property) and k in self.__dict__:
                    cache[k] = getattr(self, k)

        return Expr._reconstruct, (
            type(self),
            *self.operands,
            self.deterministic_token,
            cache,
        )

    def _depth(self, cache=None):
        """Depth of the expression tree

        Returns
        -------
        depth: int
        """
        if cache is None:
            cache = {}
        if not self.dependencies():
            return 1
        else:
            result = []
            for expr in self.dependencies():
                if expr._name in cache:
                    result.append(cache[expr._name])
                else:
                    result.append(expr._depth(cache) + 1)
                    cache[expr._name] = result[-1]
            return max(result)
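
    # Example (illustrative): depth counts nodes along the longest dependency
    # chain. Reusing the hypothetical MyAdd from the constructor example::
    #
    #     leaf = MyAdd(1, 2)
    #     root = MyAdd(leaf, 3)
    #     assert leaf._depth() == 1
    #     assert root._depth() == 2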

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ["operands", "_determ_token"]:
            object.__setattr__(self, name, value)
            return
        try:
            params = type(self)._parameters
            operands = object.__getattribute__(self, "operands")
            operands[params.index(name)] = value
        except ValueError:
            raise AttributeError(
                f"{type(self).__name__} object has no attribute {name}"
            )

    def operand(self, key):
        # Access an operand unambiguously
        # (e.g. if the key is reserved by a method/property)
        return self.operands[type(self)._parameters.index(key)]

    def dependencies(self):
        # Dependencies are `Expr` operands only
        return [operand for operand in self.operands if isinstance(operand, Expr)]

    def _task(self, key: Key, index: int) -> Task:
        """The task for the i'th partition

        Parameters
        ----------
        key:
            The key of the task, i.e. ``(self._name, index)``
        index:
            The index of the partition of this dataframe

        Examples
        --------
        >>> class Add(Expr):
        ...     def _task(self, key, index):
        ...         return Task(
        ...             key,
        ...             operator.add,
        ...             TaskRef((self.left._name, index)),
        ...             TaskRef((self.right._name, index)),
        ...         )

        Returns
        -------
        task:
            The Dask task to compute this partition

        See Also
        --------
        Expr._layer
        """
        raise NotImplementedError(
            "Expressions should define either _layer (full dictionary) or _task"
            f" (single task). This expression {type(self)} defines neither"
        )

    def _layer(self) -> dict:
        """The graph layer added by this expression.

        Simple expressions that apply one task per partition can choose to only
        implement `Expr._task` instead.

        Examples
        --------
        >>> class Add(Expr):
        ...     def _layer(self):
        ...         return {
        ...             name: Task(
        ...                 name,
        ...                 operator.add,
        ...                 TaskRef((self.left._name, i)),
        ...                 TaskRef((self.right._name, i)),
        ...             )
        ...             for i, name in enumerate(self.__dask_keys__())
        ...         }

        Returns
        -------
        layer: dict
            The Dask task graph added by this expression

        See Also
        --------
        Expr._task
        Expr.__dask_graph__
        """
        return {
            (self._name, i): self._task((self._name, i), i)
            for i in range(self.npartitions)
        }

    def rewrite(self, kind: str, rewritten):
        """Rewrite an expression

        This leverages the ``._{kind}_down`` and ``._{kind}_up``
        methods defined on each class

        Returns
        -------
        expr:
            output expression
        changed:
            whether or not any change occurred
        """
        if self._name in rewritten:
            return rewritten[self._name]

        expr = self
        down_name = f"_{kind}_down"
        up_name = f"_{kind}_up"
        while True:
            _continue = False

            # Rewrite this node
            out = getattr(expr, down_name)()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out
                continue

            # Allow children to rewrite their parents
            for child in expr.dependencies():
                out = getattr(child, up_name)(expr)
                if out is None:
                    out = expr
                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    _continue = True
                    break

            if _continue:
                continue

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    new = operand.rewrite(kind=kind, rewritten=rewritten)
                    rewritten[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)
                continue
            else:
                break

        return expr
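
    # Example (illustrative): a ``rewrite(kind="tune", ...)`` pass dispatches
    # to ``_tune_down``/``_tune_up``. A hypothetical node can rewrite itself
    # by returning a replacement expression; returning None means "no change"::
    #
    #     class Negate(Expr):
    #         _parameters = ["frame"]
    #
    #         def _tune_down(self):
    #             frame = self.operand("frame")
    #             if isinstance(frame, Negate):
    #                 # Negate(Negate(x)) -> x
    #                 return frame.operand("frame")
    #             return None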

    def simplify_once(self, dependents: defaultdict, simplified: dict):
        """Simplify an expression

        This leverages the ``._simplify_down`` and ``._simplify_up``
        methods defined on each class

        Parameters
        ----------

        dependents: defaultdict[list]
            The dependents for every node.
        simplified: dict
            Cache of simplified expressions for these dependents.

        Returns
        -------
        expr:
            output expression
        """
        # Check if we've already simplified for these dependents
        if self._name in simplified:
            return simplified[self._name]

        expr = self

        while True:
            out = expr._simplify_down()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out

            # Allow children to simplify their parents
            for child in expr.dependencies():
                out = child._simplify_up(expr, dependents)
                if out is None:
                    out = expr

                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    break

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    # Bandaid for now, waiting for Singleton
                    dependents[operand._name].append(weakref.ref(expr))
                    new = operand.simplify_once(
                        dependents=dependents, simplified=simplified
                    )
                    simplified[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)

            break

        return expr

    def optimize(self, fuse: bool = False) -> Expr:
        stage: OptimizerStage = "fused" if fuse else "simplified-physical"

        return optimize_until(self, stage)

    def fuse(self) -> Expr:
        return self

    def simplify(self) -> Expr:
        expr = self
        seen = set()
        while True:
            dependents = collect_dependents(expr)
            new = expr.simplify_once(dependents=dependents, simplified={})
            if new._name == expr._name:
                break
            if new._name in seen:
                raise RuntimeError(
                    f"Optimizer does not converge. {expr!r} simplified to {new!r} which was already seen. "
                    "Please report this issue on the dask issue tracker with a minimal reproducer."
                )
            seen.add(new._name)
            expr = new
        return expr

    def _simplify_down(self):
        return

    def _simplify_up(self, parent, dependents):
        return
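
    # Example (illustrative): ``simplify`` repeatedly applies these hooks until
    # a fixed point is reached. A hypothetical node can collapse redundant
    # structure in ``_simplify_down``::
    #
    #     class SelectAll(Expr):
    #         _parameters = ["frame"]
    #
    #         def _simplify_down(self):
    #             frame = self.operand("frame")
    #             if isinstance(frame, SelectAll):
    #                 # SelectAll(SelectAll(x)) -> SelectAll(x)
    #                 return SelectAll(frame.operand("frame"))
    #             return None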

    def lower_once(self, lowered: dict):
        # Check for a cached result
        try:
            return lowered[self._name]
        except KeyError:
            pass

        expr = self

        # Lower this node
        out = expr._lower()
        if out is None:
            out = expr
        if not isinstance(out, Expr):
            return out

        # Lower all children
        new_operands = []
        changed = False
        for operand in out.operands:
            if isinstance(operand, Expr):
                new = operand.lower_once(lowered)
                if new._name != operand._name:
                    changed = True
            else:
                new = operand
            new_operands.append(new)

        if changed:
            out = type(out)(*new_operands)

        # Cache the result and return
        return lowered.setdefault(self._name, out)

    def lower_completely(self) -> Expr:
        """Lower an expression completely

        This calls the ``lower_once`` method in a loop
        until nothing changes. This function does not
        apply any other optimizations (like ``simplify``).

        Returns
        -------
        expr:
            output expression

        See Also
        --------
        Expr.lower_once
        Expr._lower
        """
        # Lower until nothing changes
        expr = self
        lowered: dict = {}
        while True:
            new = expr.lower_once(lowered)
            if new._name == expr._name:
                break
            expr = new
        return expr

    def _lower(self):
        return
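
    # Example (illustrative): ``_lower`` translates a logical node into a
    # physical plan. A hypothetical ``Sum`` could lower itself into a
    # partition-wise reduction followed by a combine step (``TreeReduce`` and
    # ``PartitionwiseSum`` are hypothetical physical nodes)::
    #
    #     class Sum(Expr):
    #         _parameters = ["frame"]
    #
    #         def _lower(self):
    #             return TreeReduce(PartitionwiseSum(self.frame))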

    @functools.cached_property
    def _funcname(self) -> str:
        return funcname(type(self)).lower()

    @property
    def deterministic_token(self):
        if not self._determ_token:
            # Just tokenize self to fall back on __dask_tokenize__
            # Note how this differs to the implementation of __dask_tokenize__
            self._determ_token = self.__dask_tokenize__()
        return self._determ_token

    @functools.cached_property
    def _name(self) -> str:
        return f"{self._funcname}-{self.deterministic_token}"

    @property
    def _meta(self):
        raise NotImplementedError()

    @classmethod
    def _annotations_tombstone(cls) -> _AnnotationsTombstone:
        return _AnnotationsTombstone()

    def __dask_annotations__(self):
        return {}

    def __dask_graph__(self):
        """Traverse expression tree, collect layers

        Subclasses generally do not want to override this method unless custom
        logic is required to treat (e.g. ignore) specific operands during graph
        generation.

        See also
        --------
        Expr._layer
        Expr._task
        """
        stack = [self]
        seen = set()
        layers = []
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            layers.append(expr._layer())
            for operand in expr.dependencies():
                stack.append(operand)

        return toolz.merge(layers)

    @property
    def dask(self):
        return self.__dask_graph__()
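
    # Example (illustrative): the merged key -> task mapping produced here is
    # a plain low-level Dask graph, so it can be executed directly with a
    # synchronous scheduler entry point such as ``dask.local.get_sync``::
    #
    #     expr = ...                            # some concrete Expr instance
    #     dsk = expr.__dask_graph__()
    #     from dask.local import get_sync
    #     results = get_sync(dsk, expr.__dask_keys__())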

    def substitute(self, old, new) -> Expr:
        """Substitute a specific term within the expression

        Note that replacing non-`Expr` terms may produce
        unexpected results, and is not recommended.
        Substituting boolean values is not allowed.

        Parameters
        ----------
        old:
            Old term to find and replace.
        new:
            New term to replace instances of `old` with.

        Examples
        --------
        >>> (df + 10).substitute(10, 20)  # doctest: +SKIP
        df + 20
        """
        return self._substitute(old, new, _seen=set())

    def _substitute(self, old, new, _seen):
        if self._name in _seen:
            return self
        # Check if we are replacing a literal
        if isinstance(old, Expr):
            substitute_literal = False
            if self._name == old._name:
                return new
        else:
            substitute_literal = True
            if isinstance(old, bool):
                raise TypeError("Arguments to `substitute` cannot be bool.")

        new_exprs = []
        update = False
        for operand in self.operands:
            if isinstance(operand, Expr):
                val = operand._substitute(old, new, _seen)
                if operand._name != val._name:
                    update = True
                new_exprs.append(val)
            elif (
                "Fused" in type(self).__name__
                and isinstance(operand, list)
                and all(isinstance(op, Expr) for op in operand)
            ):
                # Special handling for `Fused`.
                # We make no promise to dive through a
                # list operand in general, but NEED to
                # do so for the `Fused.exprs` operand.
                val = []
                for op in operand:
                    val.append(op._substitute(old, new, _seen))
                    if val[-1]._name != op._name:
                        update = True
                new_exprs.append(val)
            elif (
                substitute_literal
                and not isinstance(operand, bool)
                and isinstance(operand, type(old))
                and operand == old
            ):
                new_exprs.append(new)
                update = True
            else:
                new_exprs.append(operand)

        if update:  # Only recreate if something changed
            return type(self)(*new_exprs)
        else:
            _seen.add(self._name)
            return self

    def substitute_parameters(self, substitutions: dict) -> Expr:
        """Substitute specific `Expr` parameters

        Parameters
        ----------
        substitutions:
            Mapping of parameter keys to new values. Keys that
            are not found in ``self._parameters`` will be ignored.
        """
        if not substitutions:
            return self

        changed = False
        new_operands = []
        for i, operand in enumerate(self.operands):
            if i < len(self._parameters) and self._parameters[i] in substitutions:
                new_operands.append(substitutions[self._parameters[i]])
                changed = True
            else:
                new_operands.append(operand)
        if changed:
            return type(self)(*new_operands)
        return self
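
    # Example (illustrative, reusing the hypothetical MyAdd from above)::
    #
    #     e = MyAdd(1, 2, name="x")
    #     e2 = e.substitute_parameters({"right": 20})
    #     assert e2.operands == [1, 20, "x"]
    #     assert e2 is not e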

    def _node_label_args(self):
        """Operands to include in the node label by `visualize`"""
        return self.dependencies()

    def _to_graphviz(
        self,
        rankdir="BT",
        graph_attr=None,
        node_attr=None,
        edge_attr=None,
        **kwargs,
    ):
        from dask.dot import label, name

        graphviz = import_required(
            "graphviz",
            "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
            "python library and the `graphviz` system library.\n\n"
            "Please either conda or pip install as follows:\n\n"
            "  conda install python-graphviz     # either conda install\n"
            "  python -m pip install graphviz    # or pip install and follow installation instructions",
        )

        graph_attr = graph_attr or {}
        node_attr = node_attr or {}
        edge_attr = edge_attr or {}

        graph_attr["rankdir"] = rankdir
        node_attr["shape"] = "box"
        node_attr["fontname"] = "helvetica"

        graph_attr.update(kwargs)
        g = graphviz.Digraph(
            graph_attr=graph_attr,
            node_attr=node_attr,
            edge_attr=edge_attr,
        )

        stack = [self]
        seen = set()
        dependencies = {}
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            dependencies[expr] = set(expr.dependencies())
            for dep in expr.dependencies():
                stack.append(dep)

        cache = {}
        for expr in dependencies:
            expr_name = name(expr)
            attrs = {}

            # Make node label
            deps = [
                funcname(type(dep)) if isinstance(dep, Expr) else str(dep)
                for dep in expr._node_label_args()
            ]
            _label = funcname(type(expr))
            if deps:
                _label = f"{_label}({', '.join(deps)})"
            node_label = label(_label, cache=cache)

            attrs.setdefault("label", str(node_label))
            attrs.setdefault("fontsize", "20")
            g.node(expr_name, **attrs)

        for expr, deps in dependencies.items():
            expr_name = name(expr)
            for dep in deps:
                dep_name = name(dep)
                g.edge(dep_name, expr_name)

        return g

    def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
        """
        Visualize the expression graph.
        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and the graph is
            rendered in the Jupyter notebook only.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Defaults to the extension of
            `filename` ('svg' for the default filename).
        **kwargs
            Additional keyword arguments to forward to ``_to_graphviz``.
        """
        from dask.dot import graphviz_to_file

        g = self._to_graphviz(**kwargs)
        graphviz_to_file(g, filename, format)
        return g

    def walk(self) -> Generator[Expr]:
        """Iterate through all expressions in the tree

        Returns
        -------
        nodes
            Generator of Expr instances in the graph.
            Ordering is a depth-first search of the expression tree
        """
        stack = [self]
        seen = set()
        while stack:
            node = stack.pop()
            if node._name in seen:
                continue
            seen.add(node._name)

            for dep in node.dependencies():
                stack.append(dep)

            yield node

    def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
        """Search the expression graph for a specific operation type

        Parameters
        ----------
        operation
            The operation type to search for.

        Returns
        -------
        nodes
            Generator of `operation` instances. Ordering corresponds
            to a depth-first search of the expression graph.
        """
        assert (
            isinstance(operation, tuple)
            and all(issubclass(e, Expr) for e in operation)
            or issubclass(operation, Expr)  # type: ignore[arg-type]
        ), "`operation` must be an `Expr` subclass or a tuple of `Expr` subclasses"
        return (expr for expr in self.walk() if isinstance(expr, operation))

    def __getattr__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError as err:
            if key.startswith("_meta"):
                # Avoid a recursive loop if/when `self._meta*`
                # produces an `AttributeError`
                raise RuntimeError(
                    f"Failed to generate metadata for {self}. "
                    "This operation may not be supported by the current backend."
                )

            # Allow operands to be accessed as attributes
            # as long as the keys are not already reserved
            # by existing methods/properties
            _parameters = type(self)._parameters
            if key in _parameters:
                idx = _parameters.index(key)
                return self.operands[idx]

            raise AttributeError(
                f"{err}\n\n"
                "This often means that you are attempting to use an unsupported "
                "API function."
            )


class SingletonExpr(Expr):
    """A singleton Expr class

    This is used to treat the subclassed expression as a singleton. Singletons
    are deduplicated by expr._name which is typically based on the dask.tokenize
    output.

    This is a crucial performance optimization for expressions that walk through
    an optimizer and are recreated repeatedly but isn't safe for objects that
    cannot be reliably or quickly tokenized.
    """

    _instances: weakref.WeakValueDictionary[str, SingletonExpr]

    def __new__(cls, *args, _determ_token=None, **kwargs):
        if not hasattr(cls, "_instances"):
            cls._instances = weakref.WeakValueDictionary()
        inst = super().__new__(cls, *args, _determ_token=_determ_token, **kwargs)
        _name = inst._name
        if _name in cls._instances and cls.__init__ == object.__init__:
            return cls._instances[_name]

        cls._instances[_name] = inst
        return inst
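
# Example (illustrative): constructing a SingletonExpr subclass twice with the
# same operands yields the same object, since live instances are deduplicated
# by ``_name`` in a WeakValueDictionary::
#
#     class MySingleton(SingletonExpr):
#         _parameters = ["value"]
#
#     a = MySingleton(1)
#     b = MySingleton(1)
#     assert a is b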


def collect_dependents(expr) -> defaultdict:
    dependents = defaultdict(list)
    stack = [expr]
    seen = set()
    while stack:
        node = stack.pop()
        if node._name in seen:
            continue
        seen.add(node._name)

        for dep in node.dependencies():
            stack.append(dep)
            dependents[dep._name].append(weakref.ref(node))
    return dependents


def optimize(expr: Expr, fuse: bool = True) -> Expr:
    """High level query optimization

    This leverages two optimization passes:

    1.  Class based simplification using the ``_simplify`` function and methods
    2.  Blockwise fusion

    Parameters
    ----------
    expr:
        Input expression to optimize
    fuse:
        whether or not to turn on blockwise fusion

    See Also
    --------
    simplify
    optimize_blockwise_fusion
    """
    stage: OptimizerStage = "fused" if fuse else "simplified-physical"

    return optimize_until(expr, stage)


def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
    result = expr
    if stage == "logical":
        return result

    # Simplify
    expr = result.simplify()
    if stage == "simplified-logical":
        return expr

    # Manipulate Expression to make it more efficient
    if dask.config.get("optimization.tune.active", True):
        expr = expr.rewrite(kind="tune", rewritten={})
    if stage == "tuned-logical":
        return expr

    # Lower
    expr = expr.lower_completely()
    if stage == "physical":
        return expr

    # Simplify again
    expr = expr.simplify()
    if stage == "simplified-physical":
        return expr

    # Final graph-specific optimizations
    expr = expr.fuse()
    if stage == "fused":
        return expr

    raise ValueError(f"Stage {stage!r} not supported.")
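
# Example (illustrative): the stages form a pipeline, so requesting a later
# stage runs all of the earlier passes as well (``some_expr`` is a
# hypothetical Expr instance)::
#
#     logical = optimize_until(some_expr, "logical")    # no-op
#     physical = optimize_until(some_expr, "physical")  # simplify, tune, lower
#     fused = optimize_until(some_expr, "fused")        # full optimization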


class LLGExpr(Expr):
    """Low Level Graph Expression"""

    _parameters = ["dsk"]

    def __dask_keys__(self):
        return list(self.operand("dsk"))

    def _layer(self) -> dict:
        return ensure_dict(self.operand("dsk"))
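
# Example (illustrative): LLGExpr wraps an already-materialized low-level
# graph; every key of the mapping is treated as an output key::
#
#     from dask._task_spec import Task
#
#     dsk = {("x", 0): Task(("x", 0), sum, [1, 2, 3])}
#     expr = LLGExpr(dsk)
#     assert expr.__dask_keys__() == [("x", 0)]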


class HLGExpr(Expr):
    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @property
    def hlg(self):
        return self.operand("dsk")

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when going
        # through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )
        if optimize_graph and not hasattr(collection, "__dask_optimize__"):
            warnings.warn(
                f"Collection {type(collection)} does not define a "
                "`__dask_optimize__` method. In the future this will raise. "
                "If no optimization is desired, please set this to `None`.",
                PendingDeprecationWarning,
            )
            low_level_optimizer = None
        else:
            low_level_optimizer = (
                collection.__dask_optimize__ if optimize_graph else None
            )
        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__(),
        )

    def finalize_compute(self):
        return HLGFinalizeCompute(
            self,
            low_level_optimizer=self.low_level_optimizer,
            output_keys=self.output_keys,
            postcompute=self.postcompute,
        )

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annotations_by_type[annot_type].update(
                        {k: (value(k) if callable(value) else value) for k in layer}
                    )
        return dict(annotations_by_type)

    def __dask_keys__(self):
        if (keys := self.operand("output_keys")) is not None:
            return keys
        dsk = self.hlg
        # Note: This will materialize
        dependencies = dsk.get_all_dependencies()
        leafs = set(dependencies)
        for val in dependencies.values():
            leafs -= val
        self.output_keys = list(leafs)
        return self.output_keys

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        keys = self.__dask_keys__()
        dsk = self.hlg
        if (optimizer := self.low_level_optimizer) is not None:
            dsk = optimizer(dsk, keys)
        return HighLevelGraph.merge(dsk)

    @property
    def deterministic_token(self):
        if not self._determ_token:
            self._determ_token = uuid.uuid4().hex
        return self._determ_token

    def _layer(self) -> dict:
        dsk = self._optimized_dsk
        return ensure_dict(dsk)


class _HLGExprGroup(HLGExpr):
    # Identical to HLGExpr
    # Used internally to determine how output keys are supposed to be returned
    pass


class _HLGExprSequence(Expr):
    def __getitem__(self, other):
        return self.operands[other]

    def _operands_for_repr(self):
        return [
            f"name={self.operand('name')!r}",
            f"dsk={self.operand('dsk')!r}",
        ]

    def _tree_repr_lines(self, indent=0, recursive=True):
        return self._operands_for_repr()

    def finalize_compute(self):
        return _HLGExprSequence(*[op.finalize_compute() for op in self.operands])

    def _tune_down(self):
        if len(self.operands) == 1:
            return None
        from dask.highlevelgraph import HighLevelGraph

        groups = toolz.groupby(
            lambda x: x.low_level_optimizer if isinstance(x, HLGExpr) else None,
            self.operands,
        )
        exprs = []
        changed = False
        for optimizer, group in groups.items():
            if len(group) > 1:
                graphs = [expr.hlg for expr in group]

                changed = True
                dsk = HighLevelGraph.merge(*graphs)
                hlg_group = _HLGExprGroup(
                    dsk=dsk,
                    low_level_optimizer=optimizer,
                    output_keys=[v.__dask_keys__() for v in group],
                    postcompute=[g.postcompute for g in group],
                )
                exprs.append(hlg_group)
            else:
                exprs.append(group[0])
        if not changed:
            return None
        return _HLGExprSequence(*exprs)

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        hlgexpr: HLGExpr
        graphs = []
        # _tune_down ensures there is only one HLGExpr per optimizer/finalizer
        for hlgexpr in self.operands:
            keys = hlgexpr.__dask_keys__()
            dsk = hlgexpr.hlg
            if (optimizer := hlgexpr.low_level_optimizer) is not None:
                dsk = optimizer(dsk, keys)
            graphs.append(dsk)

        return HighLevelGraph.merge(*graphs)

    def __dask_graph__(self):
        # This class has to override this and not just _layer to ensure the HLGs
        # are not optimized individually
        return ensure_dict(self._optimized_dsk)

    _layer = __dask_graph__

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annots = list(
                        (k, (value(k) if callable(value) else value)) for k in layer
                    )
                    annotations_by_type[annot_type].update(
                        {
                            k: v
                            for k, v in annots
                            if not isinstance(v, _AnnotationsTombstone)
                        }
                    )
                    if not annotations_by_type[annot_type]:
                        del annotations_by_type[annot_type]
        return dict(annotations_by_type)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            if isinstance(op, _HLGExprGroup):
                all_keys.extend(op.__dask_keys__())
            else:
                all_keys.append(op.__dask_keys__())
        return all_keys


class _ExprSequence(Expr):
    """A sequence of expressions

    This is used to be able to optimize multiple collections combined, e.g. when
    being computed simultaneously with ``dask.compute((Expr1, Expr2))``.
    """

    def __getitem__(self, other):
        return self.operands[other]

    def _layer(self) -> dict:
        return toolz.merge(op._layer() for op in self.operands)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            all_keys.append(list(op.__dask_keys__()))
        return all_keys

    def __repr__(self):
        return "ExprSequence(" + ", ".join(map(repr, self.operands)) + ")"

    __str__ = __repr__

    def finalize_compute(self):
        return _ExprSequence(
            *(op.finalize_compute() for op in self.operands),
        )

    def __dask_annotations__(self):
        annotations_by_type = {}
        for op in self.operands:
            for k, v in op.__dask_annotations__().items():
                annotations_by_type.setdefault(k, {}).update(v)
        return annotations_by_type

    def __len__(self):
        return len(self.operands)

    def __iter__(self):
        return iter(self.operands)

    def _simplify_down(self):
        from dask.highlevelgraph import HighLevelGraph

        issue_warning = False
        hlgs = []
        if any(
            isinstance(op, (HLGExpr, HLGFinalizeCompute, dict)) for op in self.operands
        ):
            for op in self.operands:
                if isinstance(op, (HLGExpr, HLGFinalizeCompute)):
                    hlgs.append(op)
                elif isinstance(op, dict):
                    hlgs.append(
                        HLGExpr(
                            dsk=HighLevelGraph.from_collections(
                                str(id(op)), op, dependencies=()
                            )
                        )
                    )
                else:
                    issue_warning = True
                    opt = op.optimize()
                    hlgs.append(
                        HLGExpr(
                            dsk=HighLevelGraph.from_collections(
                                opt._name, opt.__dask_graph__(), dependencies=()
                            )
                        )
                    )
        if issue_warning:
            warnings.warn(
                "Computing mixed collections that are backed by "
                "HighlevelGraphs/dicts and Expressions. "
                "This forces Expressions to be materialized. "
                "It is recommended to use only one type and separate the dask."
                "compute calls if necessary.",
                UserWarning,
            )
        if not hlgs:
            return None
        return _HLGExprSequence(*hlgs)


class _AnnotationsTombstone: ...


class FinalizeCompute(Expr):
    _parameters = ["expr"]

    def _simplify_down(self):
        return self.expr.finalize_compute()


def _convert_dask_keys(keys):
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    new_keys = []
    for key in keys:
        if isinstance(key, list):
            new_keys.append(_convert_dask_keys(key))
        else:
            new_keys.append(TaskRef(key))
    return List(*new_keys)
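
# Example (illustrative): nested key lists become nested ``List`` containers
# of ``TaskRef`` objects, so a downstream Task can depend on them::
#
#     _convert_dask_keys([("a", 0), [("b", 0), ("b", 1)]])
#     # -> List(TaskRef(("a", 0)), List(TaskRef(("b", 0)), TaskRef(("b", 1))))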


class HLGFinalizeCompute(HLGExpr):
    def _simplify_down(self):
        if not self.postcompute:
            return self.dsk

        from dask.delayed import Delayed

        # Skip finalization for Delayed
        if self.dsk.postcompute == Delayed.__dask_postcompute__(self.dsk):
            return self.dsk
        return self

    @property
    def _name(self):
        return f"finalize-{super()._name}"

    def __dask_graph__(self):
        # The baseclass __dask_graph__ will not just materialize this layer but
        # also that of its dependencies, i.e. it will render the finalized and
        # the non-finalized graph and combine them. We only want the finalized
        # so we're overriding this.
        # This is an artifact generated since the wrapped expression is
        # identified automatically as a dependency but HLG expressions are not
        # working in this layered way.
        return self._layer()

    @property
    def hlg(self):
        expr = self.operand("dsk")
        layers = expr.dsk.layers.copy()
        deps = expr.dsk.dependencies.copy()
        keys = expr.__dask_keys__()
        if isinstance(expr.postcompute, list):
            postcomputes = expr.postcompute
        else:
            postcomputes = [expr.postcompute]
        tasks = [
            Task(self._name, func, _convert_dask_keys(keys), *extra_args)
            for func, extra_args in postcomputes
        ]
        from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

        leafs = set(deps)
        for val in deps.values():
            leafs -= val
        for t in tasks:
            layers[t.key] = MaterializedLayer({t.key: t})
            deps[t.key] = leafs
        return HighLevelGraph(layers, dependencies=deps)

    def __dask_keys__(self):
        return [self._name]


class ProhibitReuse(Expr):
    """
    An expression that guarantees that all keys are suffixed with a unique id.
    This can be used to break a common subexpression apart.
    """

    _parameters = ["expr"]
    _ALLOWED_TYPES = [HLGExpr, LLGExpr, HLGFinalizeCompute, _HLGExprSequence]

    def __dask_keys__(self):
        return self._modify_keys(self.expr.__dask_keys__())

    @staticmethod
    def _identity(obj):
        return obj

    @functools.cached_property
    def _suffix(self):
        return uuid.uuid4().hex

    def _modify_keys(self, k):
        if isinstance(k, list):
            return [self._modify_keys(kk) for kk in k]
        elif isinstance(k, tuple):
            return (self._modify_keys(k[0]),) + k[1:]
        elif isinstance(k, (int, float)):
            k = str(k)
        return f"{k}-{self._suffix}"
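
    # Example (illustrative): only the name part of tuple keys is rewritten;
    # partition indices are preserved::
    #
    #     expr = ProhibitReuse(some_llg_expr)   # hypothetical wrapped expr
    #     expr._suffix                          # e.g. "abc123"
    #     expr._modify_keys("x")                # -> "x-abc123"
    #     expr._modify_keys(("x", 0))           # -> ("x-abc123", 0)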

    def _simplify_down(self):
        # FIXME: Shuffling cannot be rewritten since the barrier key is
        # hardcoded. Skipping this here should do the trick most of the time
        if not isinstance(
            self.expr,
            tuple(self._ALLOWED_TYPES),
        ):
            return self.expr

    def __dask_graph__(self):
        try:
            from distributed.shuffle._core import P2PBarrierTask
        except ModuleNotFoundError:
            P2PBarrierTask = type(None)
        dsk = convert_legacy_graph(self.expr.__dask_graph__())

        subs = {old_key: self._modify_keys(old_key) for old_key in dsk}
        dsk2 = {}
        for old_key, new_key in subs.items():
            t = dsk[old_key]
            if isinstance(t, P2PBarrierTask):
                warnings.warn(
                    "Cannot block reusing for graphs including a "
                    "P2PBarrierTask. This may cause unexpected results. "
                    "This typically happens when converting a dask "
                    "DataFrame to delayed objects.",
                    UserWarning,
                )
                return dsk
            dsk2[new_key] = Task(
                new_key,
                ProhibitReuse._identity,
                t.substitute(subs),
            )

        dsk2.update(dsk)
        return dsk2

    _layer = __dask_graph__