from __future__ import annotations

import functools
import os
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal

import toolz

import dask
from dask._task_spec import Task, convert_legacy_graph
from dask.tokenize import _tokenize_deterministic
from dask.typing import Key
from dask.utils import ensure_dict, funcname, import_required

if TYPE_CHECKING:
    # TODO import from typing (requires Python >=3.10)
    from typing import Any, TypeAlias

    from dask.highlevelgraph import HighLevelGraph

OptimizerStage: TypeAlias = Literal[
    "logical",
    "simplified-logical",
    "tuned-logical",
    "physical",
    "simplified-physical",
    "fused",
]


def _unpack_collections(o):
    from dask.delayed import Delayed

    if isinstance(o, Expr):
        return o

    if hasattr(o, "expr") and not isinstance(o, Delayed):
        return o.expr
    else:
        return o
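
# A minimal sketch of how ``_unpack_collections`` behaves (``coll`` is a
# hypothetical collection whose ``.expr`` attribute holds an Expr; it is not
# defined in this module):
#
#     >>> _unpack_collections(some_expr) is some_expr   # doctest: +SKIP
#     True
#     >>> _unpack_collections(coll) is coll.expr        # doctest: +SKIP
#     True
#     >>> _unpack_collections(42)                       # doctest: +SKIP
#     42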


class Expr:
    _parameters: list[str] = []
    _defaults: dict[str, Any] = {}

    _pickle_functools_cache: bool = True

    operands: list

    _determ_token: str | None

    def __new__(cls, *args, _determ_token=None, **kwargs):
        operands = list(args)
        for parameter in cls._parameters[len(operands) :]:
            try:
                operands.append(kwargs.pop(parameter))
            except KeyError:
                operands.append(cls._defaults[parameter])
        assert not kwargs, kwargs
        inst = object.__new__(cls)

        inst._determ_token = _determ_token
        inst.operands = [_unpack_collections(o) for o in operands]
        # This is typically cached. Make sure the cache is populated by
        # calling it once.
        inst._name
        return inst
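
    # A hedged sketch of the operand plumbing above (hypothetical subclass,
    # not part of dask): positional args fill ``_parameters`` in order, and
    # missing trailing parameters fall back to ``_defaults``.
    #
    #     >>> class MyExpr(Expr):                  # doctest: +SKIP
    #     ...     _parameters = ["frame", "factor"]
    #     ...     _defaults = {"factor": 1}
    #     >>> MyExpr("f").operands                 # doctest: +SKIP
    #     ['f', 1]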

    def _tune_down(self):
        return None

    def _tune_up(self, parent):
        return None

    def finalize_compute(self):
        return self

    def _operands_for_repr(self):
        return [
            f"{param}={repr(op)}" for param, op in zip(self._parameters, self.operands)
        ]

    def __str__(self):
        s = ", ".join(self._operands_for_repr())
        return f"{type(self).__name__}({s})"

    def __repr__(self):
        return str(self)

    def _tree_repr_argument_construction(self, i, op, header):
        try:
            param = self._parameters[i]
            default = self._defaults[param]
        except (IndexError, KeyError):
            param = self._parameters[i] if i < len(self._parameters) else ""
            default = "--no-default--"

        if repr(op) != repr(default):
            if param:
                header += f" {param}={repr(op)}"
            else:
                header += repr(op)
        return header

    def _tree_repr_lines(self, indent=0, recursive=True):
        return " " * indent + repr(self)

    def tree_repr(self):
        return os.linesep.join(self._tree_repr_lines())

    def analyze(self, filename: str | None = None, format: str | None = None) -> None:
        from dask.dataframe.dask_expr._expr import Expr as DFExpr
        from dask.dataframe.dask_expr.diagnostics import analyze

        if not isinstance(self, DFExpr):
            raise TypeError(
                "analyze is only supported for dask.dataframe.Expr objects."
            )
        return analyze(self, filename=filename, format=format)

    def explain(
        self, stage: OptimizerStage = "fused", format: str | None = None
    ) -> None:
        from dask.dataframe.dask_expr.diagnostics import explain

        return explain(self, stage, format)

    def pprint(self):
        for line in self._tree_repr_lines():
            print(line)

    def __hash__(self):
        return hash(self._name)

    def __dask_tokenize__(self):
        if not self._determ_token:
            # If the subclass does not implement a __dask_tokenize__ we'll
            # want to tokenize all operands.
            # Note how this differs from the implementation of
            # Expr.deterministic_token
            self._determ_token = _tokenize_deterministic(type(self), *self.operands)
        return self._determ_token

    def __dask_keys__(self):
        """The keys for this expression

        This is used to determine the keys of the output collection
        when this expression is computed.

        Returns
        -------
        keys: list
            The keys for this expression
        """
        return [(self._name, i) for i in range(self.npartitions)]

    @staticmethod
    def _reconstruct(*args):
        typ, *operands, token, cache = args
        inst = typ(*operands, _determ_token=token)
        for k, v in cache.items():
            inst.__dict__[k] = v
        return inst

    def __reduce__(self):
        if dask.config.get("dask-expr-no-serialize", False):
            raise RuntimeError(f"Serializing a {type(self)} object")
        cache = {}
        if type(self)._pickle_functools_cache:
            for k, v in type(self).__dict__.items():
                if isinstance(v, functools.cached_property) and k in self.__dict__:
                    cache[k] = getattr(self, k)

        return Expr._reconstruct, (
            type(self),
            *self.operands,
            self.deterministic_token,
            cache,
        )

    def _depth(self, cache=None):
        """Depth of the expression tree

        Returns
        -------
        depth: int
        """
        if cache is None:
            cache = {}
        if not self.dependencies():
            return 1
        else:
            result = []
            for expr in self.dependencies():
                if expr._name in cache:
                    result.append(cache[expr._name])
                else:
                    result.append(expr._depth(cache) + 1)
                    cache[expr._name] = result[-1]
            return max(result)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ["operands", "_determ_token"]:
            object.__setattr__(self, name, value)
            return
        try:
            params = type(self)._parameters
            operands = object.__getattribute__(self, "operands")
            operands[params.index(name)] = value
        except ValueError:
            raise AttributeError(
                f"{type(self).__name__} object has no attribute {name}"
            )

    def operand(self, key):
        # Access an operand unambiguously
        # (e.g. if the key is reserved by a method/property)
        return self.operands[type(self)._parameters.index(key)]

    def dependencies(self):
        # Dependencies are `Expr` operands only
        return [operand for operand in self.operands if isinstance(operand, Expr)]

    def _task(self, key: Key, index: int) -> Task:
        """The task for the i'th partition

        Parameters
        ----------
        key:
            The key for the task, as produced by ``__dask_keys__``.
        index:
            The index of the partition of this dataframe

        Examples
        --------
        >>> class Add(Expr):
        ...     def _task(self, key, index):
        ...         return Task(
        ...             key,
        ...             operator.add,
        ...             TaskRef((self.left._name, index)),
        ...             TaskRef((self.right._name, index)),
        ...         )

        Returns
        -------
        task:
            The Dask task to compute this partition

        See Also
        --------
        Expr._layer
        """
        raise NotImplementedError(
            "Expressions should define either _layer (full dictionary) or _task"
            f" (single task). This expression {type(self)} defines neither"
        )

    def _layer(self) -> dict:
        """The graph layer added by this expression.

        Simple expressions that apply one task per partition can choose to only
        implement `Expr._task` instead.

        Examples
        --------
        >>> class Add(Expr):
        ...     def _layer(self):
        ...         return {
        ...             name: Task(
        ...                 name,
        ...                 operator.add,
        ...                 TaskRef((self.left._name, i)),
        ...                 TaskRef((self.right._name, i)),
        ...             )
        ...             for i, name in enumerate(self.__dask_keys__())
        ...         }

        Returns
        -------
        layer: dict
            The Dask task graph added by this expression

        See Also
        --------
        Expr._task
        Expr.__dask_graph__
        """
        return {
            (self._name, i): self._task((self._name, i), i)
            for i in range(self.npartitions)
        }

    def rewrite(self, kind: str, rewritten):
        """Rewrite an expression

        This leverages the ``._{kind}_down`` and ``._{kind}_up``
        methods defined on each class

        Returns
        -------
        expr:
            output expression
        """
        if self._name in rewritten:
            return rewritten[self._name]

        expr = self
        down_name = f"_{kind}_down"
        up_name = f"_{kind}_up"
        while True:
            _continue = False

            # Rewrite this node
            out = getattr(expr, down_name)()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out
                continue

            # Allow children to rewrite their parents
            for child in expr.dependencies():
                out = getattr(child, up_name)(expr)
                if out is None:
                    out = expr
                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    _continue = True
                    break

            if _continue:
                continue

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    new = operand.rewrite(kind=kind, rewritten=rewritten)
                    rewritten[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)
                continue
            else:
                break

        return expr
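
    # A hedged sketch of the rewrite protocol (``Double``, ``Quadruple`` and
    # ``x`` are hypothetical, not part of dask): a node may replace itself on
    # the way down via ``_{kind}_down``, or replace its parent on the way up
    # via ``_{kind}_up``.
    #
    #     >>> class Double(Expr):                              # doctest: +SKIP
    #     ...     _parameters = ["child"]
    #     ...     def _tune_down(self):
    #     ...         # Collapse Double(Double(x)) into Quadruple(x)
    #     ...         if isinstance(self.child, Double):
    #     ...             return Quadruple(self.child.child)
    #     >>> Double(Double(x)).rewrite("tune", rewritten={})  # doctest: +SKIP
    #     Quadruple(x)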

    def simplify_once(self, dependents: defaultdict, simplified: dict):
        """Simplify an expression

        This leverages the ``._simplify_down`` and ``._simplify_up``
        methods defined on each class

        Parameters
        ----------
        dependents: defaultdict[list]
            The dependents for every node.
        simplified: dict
            Cache of simplified expressions for these dependents.

        Returns
        -------
        expr:
            output expression
        """
        # Check if we've already simplified for these dependents
        if self._name in simplified:
            return simplified[self._name]

        expr = self

        while True:
            out = expr._simplify_down()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out

            # Allow children to simplify their parents
            for child in expr.dependencies():
                out = child._simplify_up(expr, dependents)
                if out is None:
                    out = expr

                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    break

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    # Bandaid for now, waiting for Singleton
                    dependents[operand._name].append(weakref.ref(expr))
                    new = operand.simplify_once(
                        dependents=dependents, simplified=simplified
                    )
                    simplified[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)

            break

        return expr

    def optimize(self, fuse: bool = False) -> Expr:
        stage: OptimizerStage = "fused" if fuse else "simplified-physical"

        return optimize_until(self, stage)

    def fuse(self) -> Expr:
        return self

    def simplify(self) -> Expr:
        expr = self
        seen = set()
        while True:
            dependents = collect_dependents(expr)
            new = expr.simplify_once(dependents=dependents, simplified={})
            if new._name == expr._name:
                break
            if new._name in seen:
                raise RuntimeError(
                    f"Optimizer does not converge. {expr!r} simplified to {new!r} "
                    "which was already seen. Please report this issue on the dask "
                    "issue tracker with a minimal reproducer."
                )
            seen.add(new._name)
            expr = new
        return expr

    def _simplify_down(self):
        return

    def _simplify_up(self, parent, dependents):
        return
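
    # A hedged sketch of a ``_simplify_down`` rule (``Head`` is hypothetical,
    # not part of this module): ``simplify`` keeps applying such rules until
    # the expression name stops changing.
    #
    #     >>> class Head(Expr):                    # doctest: +SKIP
    #     ...     _parameters = ["frame", "n"]
    #     ...     def _simplify_down(self):
    #     ...         # Head(Head(x, 5), 10) -> Head(x, 5)
    #     ...         if isinstance(self.frame, Head):
    #     ...             return Head(self.frame.frame, min(self.n, self.frame.n))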

    def lower_once(self, lowered: dict):
        # Check for a cached result
        try:
            return lowered[self._name]
        except KeyError:
            pass

        expr = self

        # Lower this node
        out = expr._lower()
        if out is None:
            out = expr
        if not isinstance(out, Expr):
            return out

        # Lower all children
        new_operands = []
        changed = False
        for operand in out.operands:
            if isinstance(operand, Expr):
                new = operand.lower_once(lowered)
                if new._name != operand._name:
                    changed = True
            else:
                new = operand
            new_operands.append(new)

        if changed:
            out = type(out)(*new_operands)

        # Cache the result and return
        return lowered.setdefault(self._name, out)

    def lower_completely(self) -> Expr:
        """Lower an expression completely

        This calls the ``lower_once`` method in a loop
        until nothing changes. This function does not
        apply any other optimizations (like ``simplify``).

        Returns
        -------
        expr:
            output expression

        See Also
        --------
        Expr.lower_once
        Expr._lower
        """
        # Lower until nothing changes
        expr = self
        lowered: dict = {}
        while True:
            new = expr.lower_once(lowered)
            if new._name == expr._name:
                break
            expr = new
        return expr

    def _lower(self):
        return

    @functools.cached_property
    def _funcname(self) -> str:
        return funcname(type(self)).lower()

    @property
    def deterministic_token(self):
        if not self._determ_token:
            # Just tokenize self to fall back on __dask_tokenize__
            # Note how this differs from the implementation of
            # __dask_tokenize__
            self._determ_token = self.__dask_tokenize__()
        return self._determ_token

    @functools.cached_property
    def _name(self) -> str:
        return self._funcname + "-" + self.deterministic_token

    @property
    def _meta(self):
        raise NotImplementedError()

    @classmethod
    def _annotations_tombstone(cls) -> _AnnotationsTombstone:
        return _AnnotationsTombstone()

    def __dask_annotations__(self):
        return {}

    def __dask_graph__(self):
        """Traverse expression tree, collect layers

        Subclasses generally do not want to override this method unless custom
        logic is required to treat (e.g. ignore) specific operands during graph
        generation.

        See also
        --------
        Expr._layer
        Expr._task
        """
        stack = [self]
        seen = set()
        layers = []
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            layers.append(expr._layer())
            for operand in expr.dependencies():
                stack.append(operand)

        return toolz.merge(layers)

    @property
    def dask(self):
        return self.__dask_graph__()

    def substitute(self, old, new) -> Expr:
        """Substitute a specific term within the expression

        Note that replacing non-`Expr` terms may produce
        unexpected results, and is not recommended.
        Substituting boolean values is not allowed.

        Parameters
        ----------
        old:
            Old term to find and replace.
        new:
            New term to replace instances of `old` with.

        Examples
        --------
        >>> (df + 10).substitute(10, 20)  # doctest: +SKIP
        df + 20
        """
        return self._substitute(old, new, _seen=set())

    def _substitute(self, old, new, _seen):
        if self._name in _seen:
            return self
        # Check if we are replacing a literal
        if isinstance(old, Expr):
            substitute_literal = False
            if self._name == old._name:
                return new
        else:
            substitute_literal = True
            if isinstance(old, bool):
                raise TypeError("Arguments to `substitute` cannot be bool.")

        new_exprs = []
        update = False
        for operand in self.operands:
            if isinstance(operand, Expr):
                val = operand._substitute(old, new, _seen)
                if operand._name != val._name:
                    update = True
                new_exprs.append(val)
            elif (
                "Fused" in type(self).__name__
                and isinstance(operand, list)
                and all(isinstance(op, Expr) for op in operand)
            ):
                # Special handling for `Fused`.
                # We make no promise to dive through a
                # list operand in general, but NEED to
                # do so for the `Fused.exprs` operand.
                val = []
                for op in operand:
                    val.append(op._substitute(old, new, _seen))
                    if val[-1]._name != op._name:
                        update = True
                new_exprs.append(val)
            elif (
                substitute_literal
                and not isinstance(operand, bool)
                and isinstance(operand, type(old))
                and operand == old
            ):
                new_exprs.append(new)
                update = True
            else:
                new_exprs.append(operand)

        if update:  # Only recreate if something changed
            return type(self)(*new_exprs)
        else:
            _seen.add(self._name)
        return self

    def substitute_parameters(self, substitutions: dict) -> Expr:
        """Substitute specific `Expr` parameters

        Parameters
        ----------
        substitutions:
            Mapping of parameter keys to new values. Keys that
            are not found in ``self._parameters`` will be ignored.
        """
        if not substitutions:
            return self

        changed = False
        new_operands = []
        for i, operand in enumerate(self.operands):
            if i < len(self._parameters) and self._parameters[i] in substitutions:
                new_operands.append(substitutions[self._parameters[i]])
                changed = True
            else:
                new_operands.append(operand)
        if changed:
            return type(self)(*new_operands)
        return self
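
    # A hedged usage sketch (the parameter name is hypothetical):
    #
    #     >>> expr.substitute_parameters({"npartitions": 4})  # doctest: +SKIP
    #
    # returns a new expression of the same type with the matching operand
    # replaced, while unknown keys are silently ignored.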

    def _node_label_args(self):
        """Operands to include in the node label by `visualize`"""
        return self.dependencies()

    def _to_graphviz(
        self,
        rankdir="BT",
        graph_attr=None,
        node_attr=None,
        edge_attr=None,
        **kwargs,
    ):
        from dask.dot import label, name

        graphviz = import_required(
            "graphviz",
            "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
            "python library and the `graphviz` system library.\n\n"
            "Please either conda or pip install as follows:\n\n"
            "  conda install python-graphviz     # either conda install\n"
            "  python -m pip install graphviz    # or pip install and follow installation instructions",
        )

        graph_attr = graph_attr or {}
        node_attr = node_attr or {}
        edge_attr = edge_attr or {}

        graph_attr["rankdir"] = rankdir
        node_attr["shape"] = "box"
        node_attr["fontname"] = "helvetica"

        graph_attr.update(kwargs)
        g = graphviz.Digraph(
            graph_attr=graph_attr,
            node_attr=node_attr,
            edge_attr=edge_attr,
        )

        stack = [self]
        seen = set()
        dependencies = {}
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            dependencies[expr] = set(expr.dependencies())
            for dep in expr.dependencies():
                stack.append(dep)

        cache = {}
        for expr in dependencies:
            expr_name = name(expr)
            attrs = {}

            # Make node label
            deps = [
                funcname(type(dep)) if isinstance(dep, Expr) else str(dep)
                for dep in expr._node_label_args()
            ]
            _label = funcname(type(expr))
            if deps:
                _label = f"{_label}({', '.join(deps)})"
            node_label = label(_label, cache=cache)

            attrs.setdefault("label", str(node_label))
            attrs.setdefault("fontsize", "20")
            g.node(expr_name, **attrs)

        for expr, deps in dependencies.items():
            expr_name = name(expr)
            for dep in deps:
                dep_name = name(dep)
                g.edge(dep_name, expr_name)

        return g

    def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
        """
        Visualize the expression graph.
        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and the graph is
            rendered in the Jupyter notebook only.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'svg'.
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.
        """
        from dask.dot import graphviz_to_file

        g = self._to_graphviz(**kwargs)
        graphviz_to_file(g, filename, format)
        return g

    def walk(self) -> Generator[Expr]:
        """Iterate through all expressions in the tree

        Returns
        -------
        nodes
            Generator of Expr instances in the graph.
            Ordering is a depth-first search of the expression tree
        """
        stack = [self]
        seen = set()
        while stack:
            node = stack.pop()
            if node._name in seen:
                continue
            seen.add(node._name)

            for dep in node.dependencies():
                stack.append(dep)

            yield node

    def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
        """Search the expression graph for a specific operation type

        Parameters
        ----------
        operation
            The operation type to search for.

        Returns
        -------
        nodes
            Generator of `operation` instances. Ordering corresponds
            to a depth-first search of the expression graph.
        """
        assert (
            isinstance(operation, tuple)
            and all(issubclass(e, Expr) for e in operation)
            or issubclass(operation, Expr)  # type: ignore
        ), "`operation` must be an `Expr` subclass or a tuple of `Expr` subclasses"
        return (expr for expr in self.walk() if isinstance(expr, operation))

    def __getattr__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError as err:
            if key.startswith("_meta"):
                # Avoid a recursive loop if/when `self._meta*`
                # produces an `AttributeError`
                raise RuntimeError(
                    f"Failed to generate metadata for {self}. "
                    "This operation may not be supported by the current backend."
                )

            # Allow operands to be accessed as attributes
            # as long as the keys are not already reserved
            # by existing methods/properties
            _parameters = type(self)._parameters
            if key in _parameters:
                idx = _parameters.index(key)
                return self.operands[idx]

            raise AttributeError(
                f"{err}\n\n"
                "This often means that you are attempting to use an unsupported "
                "API function."
            )


class SingletonExpr(Expr):
    """A singleton Expr class

    This is used to treat the subclassed expression as a singleton. Singletons
    are deduplicated by expr._name, which is typically based on the
    dask.tokenize output.

    This is a crucial performance optimization for expressions that walk
    through an optimizer and are recreated repeatedly, but it isn't safe for
    objects that cannot be reliably or quickly tokenized.
    """

    _instances: weakref.WeakValueDictionary[str, SingletonExpr]

    def __new__(cls, *args, _determ_token=None, **kwargs):
        if not hasattr(cls, "_instances"):
            cls._instances = weakref.WeakValueDictionary()
        inst = super().__new__(cls, *args, _determ_token=_determ_token, **kwargs)
        _name = inst._name
        if _name in cls._instances and cls.__init__ == object.__init__:
            return cls._instances[_name]

        cls._instances[_name] = inst
        return inst
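
# A hedged sketch of the singleton behavior (hypothetical subclass, not part
# of dask): two structurally identical expressions share one instance because
# they share a deterministic token, and hence a ``_name``.
#
#     >>> class MySingleton(SingletonExpr):        # doctest: +SKIP
#     ...     _parameters = ["value"]
#     >>> MySingleton(1) is MySingleton(1)         # doctest: +SKIP
#     True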


def collect_dependents(expr) -> defaultdict:
    dependents = defaultdict(list)
    stack = [expr]
    seen = set()
    while stack:
        node = stack.pop()
        if node._name in seen:
            continue
        seen.add(node._name)

        for dep in node.dependencies():
            stack.append(dep)
            dependents[dep._name].append(weakref.ref(node))
    return dependents
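
# The resulting mapping goes from a child's ``_name`` to weak references of
# its parent expressions; weakrefs avoid keeping otherwise-dead nodes alive
# during optimization. A hedged sketch (``parent_expr``/``child_expr`` are
# hypothetical):
#
#     >>> deps = collect_dependents(parent_expr)       # doctest: +SKIP
#     >>> [ref() for ref in deps[child_expr._name]]    # doctest: +SKIP
#     [parent_expr]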


def optimize(expr: Expr, fuse: bool = True) -> Expr:
    """High level query optimization

    This leverages two optimization passes:

    1.  Class based simplification using the ``_simplify`` function and methods
    2.  Blockwise fusion

    Parameters
    ----------
    expr:
        Input expression to optimize
    fuse:
        whether or not to turn on blockwise fusion

    See Also
    --------
    simplify
    optimize_blockwise_fusion
    """
    stage: OptimizerStage = "fused" if fuse else "simplified-physical"

    return optimize_until(expr, stage)


def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
    result = expr
    if stage == "logical":
        return result

    # Simplify
    expr = result.simplify()
    if stage == "simplified-logical":
        return expr

    # Manipulate Expression to make it more efficient
    expr = expr.rewrite(kind="tune", rewritten={})
    if stage == "tuned-logical":
        return expr

    # Lower
    expr = expr.lower_completely()
    if stage == "physical":
        return expr

    # Simplify again
    expr = expr.simplify()
    if stage == "simplified-physical":
        return expr

    # Final graph-specific optimizations
    expr = expr.fuse()
    if stage == "fused":
        return expr

    raise ValueError(f"Stage {stage!r} not supported.")
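
# The stages form a pipeline, so each ``OptimizerStage`` value names a prefix
# of the passes above. A hedged usage sketch:
#
#     >>> optimize_until(expr, "tuned-logical")  # doctest: +SKIP
#
# runs simplification plus the "tune" rewrite, but skips lowering and fusion.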


class LLGExpr(Expr):
    """Low Level Graph Expression"""

    _parameters = ["dsk"]

    def __dask_keys__(self):
        return list(self.operand("dsk"))

    def _layer(self) -> dict:
        return ensure_dict(self.operand("dsk"))
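
# A hedged sketch: an ``LLGExpr`` simply wraps an already-materialized task
# graph, so its keys are the dict's keys and its layer is the dict itself.
#
#     >>> LLGExpr({"a": 1, "b": 2}).__dask_keys__()  # doctest: +SKIP
#     ['a', 'b']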


class HLGExpr(Expr):
    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @property
    def hlg(self):
        return self.operand("dsk")

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when
        # going through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )
        if optimize_graph and not hasattr(collection, "__dask_optimize__"):
            warnings.warn(
                f"Collection {type(collection)} does not define a "
                "`__dask_optimize__` method. In the future this will raise. "
                "If no optimization is desired, please set this to `None`.",
                PendingDeprecationWarning,
            )
            low_level_optimizer = None
        else:
            low_level_optimizer = (
                collection.__dask_optimize__ if optimize_graph else None
            )
        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__(),
        )

    def finalize_compute(self):
        return HLGFinalizeCompute(
            self,
            low_level_optimizer=self.low_level_optimizer,
            output_keys=self.output_keys,
            postcompute=self.postcompute,
        )

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # Optimization has to be called (and cached) since blockwise fusion
        # can alter annotations; see
        # `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annotations_by_type[annot_type].update(
                        {k: (value(k) if callable(value) else value) for k in layer}
                    )
        return dict(annotations_by_type)

    def __dask_keys__(self):
        if (keys := self.operand("output_keys")) is not None:
            return keys
        dsk = self.hlg
        # Note: This will materialize
        dependencies = dsk.get_all_dependencies()
        leafs = set(dependencies)
        for val in dependencies.values():
            leafs -= val
        self.output_keys = list(leafs)
        return self.output_keys

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        keys = self.__dask_keys__()
        dsk = self.hlg
        if (optimizer := self.low_level_optimizer) is not None:
            dsk = optimizer(dsk, keys)
        return HighLevelGraph.merge(dsk)

    @property
    def deterministic_token(self):
        if not self._determ_token:
            self._determ_token = uuid.uuid4().hex
        return self._determ_token

    def _layer(self) -> dict:
        dsk = self._optimized_dsk
        return ensure_dict(dsk)


class _HLGExprGroup(HLGExpr):
    # Identical to HLGExpr
    # Used internally to determine how output keys are supposed to be returned
    pass


class _HLGExprSequence(Expr):
    def __getitem__(self, other):
        return self.operands[other]

    def _operands_for_repr(self):
        return [
            f"name={self.operand('name')!r}",
            f"dsk={self.operand('dsk')!r}",
        ]

    def _tree_repr_lines(self, indent=0, recursive=True):
        return self._operands_for_repr()

    def finalize_compute(self):
        return _HLGExprSequence(*[op.finalize_compute() for op in self.operands])

    def _tune_down(self):
        if len(self.operands) == 1:
            return None
        from dask.highlevelgraph import HighLevelGraph

        groups = toolz.groupby(
            lambda x: x.low_level_optimizer if isinstance(x, HLGExpr) else None,
            self.operands,
        )
        exprs = []
        changed = False
        for optimizer, group in groups.items():
            if len(group) > 1:
                graphs = [expr.hlg for expr in group]

                changed = True
                dsk = HighLevelGraph.merge(*graphs)
                hlg_group = _HLGExprGroup(
                    dsk=dsk,
                    low_level_optimizer=optimizer,
                    output_keys=[v.__dask_keys__() for v in group],
                    postcompute=[g.postcompute for g in group],
                )
                exprs.append(hlg_group)
            else:
                exprs.append(group[0])
        if not changed:
            return None
        return _HLGExprSequence(*exprs)

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        hlgexpr: HLGExpr
        graphs = []
        # _tune_down ensures there is only one HLGExpr per optimizer/finalizer
        for hlgexpr in self.operands:
            keys = hlgexpr.__dask_keys__()
            dsk = hlgexpr.hlg
            if (optimizer := hlgexpr.low_level_optimizer) is not None:
                dsk = optimizer(dsk, keys)
            graphs.append(dsk)

        return HighLevelGraph.merge(*graphs)

    def __dask_graph__(self):
        # This class has to override this method (and not just _layer) to
        # ensure the HLGs are not optimized individually
        return ensure_dict(self._optimized_dsk)

    _layer = __dask_graph__

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # Optimization has to be called (and cached) since blockwise fusion
        # can alter annotations; see
        # `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annots = list(
                        (k, (value(k) if callable(value) else value)) for k in layer
                    )
                    annotations_by_type[annot_type].update(
                        {
                            k: v
                            for k, v in annots
                            if not isinstance(v, _AnnotationsTombstone)
                        }
                    )
                    if not annotations_by_type[annot_type]:
                        del annotations_by_type[annot_type]
        return dict(annotations_by_type)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            if isinstance(op, _HLGExprGroup):
                all_keys.extend(op.__dask_keys__())
            else:
                all_keys.append(op.__dask_keys__())
        return all_keys


class _ExprSequence(Expr):
    """A sequence of expressions

    This is used to optimize multiple collections together, e.g. when they are
    computed simultaneously with ``dask.compute((Expr1, Expr2))``.
    """

    def __getitem__(self, other):
        return self.operands[other]

    def _layer(self) -> dict:
        return toolz.merge(op._layer() for op in self.operands)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            all_keys.append(list(op.__dask_keys__()))
        return all_keys

    def __repr__(self):
        return "ExprSequence(" + ", ".join(map(repr, self.operands)) + ")"

    __str__ = __repr__

    def finalize_compute(self):
        return _ExprSequence(
            *(op.finalize_compute() for op in self.operands),
        )

    def __dask_annotations__(self):
        annotations_by_type = {}
        for op in self.operands:
            for k, v in op.__dask_annotations__().items():
                annotations_by_type.setdefault(k, {}).update(v)
        return annotations_by_type

    def __len__(self):
        return len(self.operands)

    def __iter__(self):
        return iter(self.operands)

    def _simplify_down(self):
        from dask.highlevelgraph import HighLevelGraph

        issue_warning = False
        hlgs = []
        for op in self.operands:
            if isinstance(op, (HLGExpr, HLGFinalizeCompute)):
                hlgs.append(op)
            elif isinstance(op, dict):
                hlgs.append(
                    HLGExpr(
                        dsk=HighLevelGraph.from_collections(
                            str(id(op)), op, dependencies=()
                        )
                    )
                )
            elif hlgs:
                issue_warning = True
                opt = op.optimize()
                hlgs.append(
                    HLGExpr(
                        dsk=HighLevelGraph.from_collections(
                            opt._name, opt.__dask_graph__(), dependencies=()
                        )
                    )
                )
        if issue_warning:
            warnings.warn(
                "Computing mixed collections that are backed by "
                "HighLevelGraphs/dicts and Expressions. "
                "This forces Expressions to be materialized. "
                "It is recommended to use only one type and separate the "
                "dask.compute calls if necessary.",
                UserWarning,
            )
        if not hlgs:
            return None
        return _HLGExprSequence(*hlgs)


class _AnnotationsTombstone: ...


class FinalizeCompute(Expr):
    _parameters = ["expr"]

    def _simplify_down(self):
        return self.expr.finalize_compute()


def _convert_dask_keys(keys):
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    new_keys = []
    for key in keys:
        if isinstance(key, list):
            new_keys.append(_convert_dask_keys(key))
        else:
            new_keys.append(TaskRef(key))
    return List(*new_keys)
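
# A hedged sketch: nested key lists become nested ``List`` containers of
# ``TaskRef`` objects, preserving the shape of ``__dask_keys__`` output (the
# repr shown is illustrative):
#
#     >>> _convert_dask_keys([("x", 0), [("y", 0)]])  # doctest: +SKIP
#     List(TaskRef(('x', 0)), List(TaskRef(('y', 0))))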


class HLGFinalizeCompute(HLGExpr):
    def _simplify_down(self):
        if not self.postcompute:
            return self.dsk

        from dask.delayed import Delayed

        # Skip finalization for Delayed
        if self.dsk.postcompute == Delayed.__dask_postcompute__(self.dsk):
            return self.dsk
        return self

    @property
    def _name(self):
        return f"finalize-{super()._name}"

    def __dask_graph__(self):
        # The base class __dask_graph__ will not just materialize this layer
        # but also that of its dependencies, i.e. it will render the finalized
        # and the non-finalized graph and combine them. We only want the
        # finalized one, so we're overriding this.
        # This is an artifact generated since the wrapped expression is
        # identified automatically as a dependency but HLG expressions are not
        # working in this layered way.
        return self._layer()

    @property
    def hlg(self):
        expr = self.operand("dsk")
        layers = expr.dsk.layers.copy()
        deps = expr.dsk.dependencies.copy()
        keys = expr.__dask_keys__()
        if isinstance(expr.postcompute, list):
            postcomputes = expr.postcompute
        else:
            postcomputes = [expr.postcompute]
        tasks = [
            Task(self._name, func, _convert_dask_keys(keys), *extra_args)
            for func, extra_args in postcomputes
        ]
        from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

        leafs = set(deps)
        for val in deps.values():
            leafs -= val
        for t in tasks:
            layers[t.key] = MaterializedLayer({t.key: t})
            deps[t.key] = leafs
        return HighLevelGraph(layers, dependencies=deps)

    def __dask_keys__(self):
        return [self._name]


class ProhibitReuse(Expr):
    """
    An expression that guarantees that all keys are suffixed with a unique id.
    This can be used to break a common subexpression apart.
    """

    _parameters = ["expr"]
    _ALLOWED_TYPES = [HLGExpr, LLGExpr, HLGFinalizeCompute, _HLGExprSequence]

    def __dask_keys__(self):
        return self._modify_keys(self.expr.__dask_keys__())

    @staticmethod
    def _identity(obj):
        return obj

    @functools.cached_property
    def _suffix(self):
        return uuid.uuid4().hex

    def _modify_keys(self, k):
        if isinstance(k, list):
            return [self._modify_keys(kk) for kk in k]
        elif isinstance(k, tuple):
            return (self._modify_keys(k[0]),) + k[1:]
        elif isinstance(k, (int, float)):
            k = str(k)
        return f"{k}-{self._suffix}"
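
    # A hedged sketch of the key rewriting (``pr`` is a hypothetical
    # ProhibitReuse instance; ``<uuid>`` stands in for the random suffix):
    #
    #     >>> pr._modify_keys(("x", 0))  # doctest: +SKIP
    #     ('x-<uuid>', 0)
    #     >>> pr._modify_keys("y")       # doctest: +SKIP
    #     'y-<uuid>'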

    def _simplify_down(self):
        # FIXME: Shuffling cannot be rewritten since the barrier key is
        # hardcoded. Skipping this here should do the trick most of the time
        if not isinstance(
            self.expr,
            tuple(self._ALLOWED_TYPES),
        ):
            return self.expr

    def __dask_graph__(self):
        try:
            from distributed.shuffle._core import P2PBarrierTask
        except ModuleNotFoundError:
            P2PBarrierTask = type(None)
        dsk = convert_legacy_graph(self.expr.__dask_graph__())

        subs = {old_key: self._modify_keys(old_key) for old_key in dsk}
        dsk2 = {}
        for old_key, new_key in subs.items():
            t = dsk[old_key]
            if isinstance(t, P2PBarrierTask):
                warnings.warn(
                    "Cannot block reusing for graphs including a "
                    "P2PBarrierTask. This may cause unexpected results. "
                    "This typically happens when converting a dask "
                    "DataFrame to delayed objects.",
                    UserWarning,
                )
                return dsk
            dsk2[new_key] = Task(
                new_key,
                ProhibitReuse._identity,
                t.substitute(subs),
            )

        dsk2.update(dsk)
        return dsk2

    _layer = __dask_graph__