1from __future__ import annotations
2
3import abc
4import copy
5import functools
6import html
7from collections.abc import (
8 Collection,
9 Hashable,
10 ItemsView,
11 Iterable,
12 Iterator,
13 KeysView,
14 Mapping,
15 Sequence,
16 Set,
17 ValuesView,
18)
19from typing import Any
20
21import tlz as toolz
22
23import dask
24from dask import config
25from dask._task_spec import GraphNode
26from dask.base import clone_key, flatten, is_dask_collection
27from dask.core import keys_in_tasks, reverse_dict
28from dask.tokenize import normalize_token, tokenize
29from dask.typing import DaskCollection, Graph, Key
30from dask.utils import ensure_dict, import_required, key_split
31from dask.widgets import get_template
32
33
def compute_layer_dependencies(layers):
    """Return the dependencies between layers.

    Parameters
    ----------
    layers : Mapping
        Mapping of layer name to layer, where every layer is itself a mapping
        of key to task.

    Returns
    -------
    dict
        Maps every layer name to the set of names of the *other* layers that
        hold a key referenced by one of its tasks.

    Raises
    ------
    RuntimeError
        If a referenced key cannot be attributed to any layer.
    """
    # Index every key by the layer that holds it in one pass, so resolving a
    # dependency is a single dict lookup instead of a scan over all layers.
    # ``setdefault`` keeps the first layer (in iteration order) whenever a key
    # is duplicated between layers.
    key_to_layer = {}
    for name, layer in layers.items():
        for key in layer:
            key_to_layer.setdefault(key, name)

    all_keys = set(key_to_layer)
    ret = {name: set() for name in layers}
    for name, layer in layers.items():
        # Only keys living outside of this layer count as layer dependencies
        for key in keys_in_tasks(all_keys - layer.keys(), layer.values()):
            try:
                ret[name].add(key_to_layer[key])
            except KeyError:
                raise RuntimeError(f"{repr(key)} not found") from None
    return ret
49
50
class Layer(Graph):
    """High level graph layer

    This abstract class establishes a protocol for high level graph layers.

    The main motivation of a layer is to represent a collection of tasks
    symbolically in order to speedup a series of operations significantly.
    Ideally, a layer should stay in this symbolic state until execution
    but in practice some operations will force the layer to generate all
    its internal tasks. We say that the layer has been materialized.

    Most of the default implementations in this class will materialize the
    layer. It is up to derived classes to implement non-materializing
    implementations.
    """

    # Task annotations of this layer (None when there are none)
    annotations: Mapping[str, Any] | None
    # Annotations describing the collection this layer belongs to
    collection_annotations: Mapping[str, Any] | None

    def __init__(
        self,
        annotations: Mapping[str, Any] | None = None,
        collection_annotations: Mapping[str, Any] | None = None,
    ):
        """Initialize Layer object.

        Parameters
        ----------
        annotations : Mapping[str, Any], optional
            By default, None.
            Annotations are metadata or soft constraints associated with tasks
            that dask schedulers may choose to respect:
            They signal intent without enforcing hard constraints.
            As such, they are primarily designed for use with the distributed
            scheduler. See the dask.annotate function for more information.
        collection_annotations : Mapping[str, Any], optional. By default, None.
            Experimental, intended to assist with visualizing the performance
            characteristics of Dask computations.
            These annotations are *not* passed to the distributed scheduler.
        """
        # A truthy explicit argument wins; otherwise take a copy of whatever
        # ``dask.get_annotations()`` returns. A falsy result becomes None.
        self.annotations = annotations or dask.get_annotations().copy() or None
        # Same precedence here, falling back to a shallow copy of the
        # "collection_annotations" config value (None when it is not set).
        self.collection_annotations = collection_annotations or copy.copy(
            config.get("collection_annotations", None)
        )

    @functools.cached_property
    def has_legacy_tasks(self):
        """Check if the layer has legacy tasks

        A value counts as a legacy task when it is *not* a ``GraphNode``
        instance. The check iterates over all values of the layer and, being
        a ``cached_property``, its result is computed once per instance.
        """
        return any(not isinstance(v, GraphNode) for v in self.values())

    @abc.abstractmethod
    def is_materialized(self) -> bool:
        """Return whether the layer is materialized or not"""
        return True

    @abc.abstractmethod
    def get_output_keys(self) -> Set[Key]:
        """Return a set of all output keys

        Output keys are all keys in the layer that might be referenced by
        other layers.

        Classes overriding this implementation should not cause the layer
        to be materialized.

        Returns
        -------
        keys: Set
            All output keys
        """
        return self.keys()  # this implementation will materialize the graph

    def cull(
        self, keys: set[Key], all_hlg_keys: Collection[Key]
    ) -> tuple[Layer, Mapping[Key, set[Key]]]:
        """Remove unnecessary tasks from the layer

        In other words, return a new Layer with only the tasks required to
        calculate `keys` and a map of external key dependencies.

        Parameters
        ----------
        keys:
            Keys to keep. Keys that are not part of this layer are ignored.
        all_hlg_keys:
            All keys in the high level graph. Only consulted when the layer
            holds legacy tasks.

        Examples
        --------
        >>> inc = lambda x: x + 1
        >>> add = lambda x, y: x + y
        >>> d = MaterializedLayer({'x': 1, 'y': (inc, 'x'), 'out': (add, 'x', 10)})
        >>> _, deps = d.cull({'out'}, d.keys())
        >>> deps
        {'out': {'x'}, 'x': set()}

        Returns
        -------
        layer: Layer
            Culled layer
        deps: Map
            Map of external key dependencies
        """

        if self.has_legacy_tasks:
            if len(keys) == len(self):
                # Nothing to cull if preserving all existing keys
                return (
                    self,
                    {k: self.get_dependencies(k, all_hlg_keys) for k in self.keys()},
                )
            ret_deps = {}
            seen = set()
            out = {}
            # Work on a copy: the caller's set must not be consumed
            work = keys.copy()
            while work:
                k = work.pop()
                if k not in self:
                    continue
                out[k] = self[k]
                ret_deps[k] = self.get_dependencies(k, all_hlg_keys)
                # Only dependencies held by this layer are traversed further
                for d in ret_deps[k]:
                    if d not in seen:
                        if d in self:
                            seen.add(d)
                            work.add(d)

            return MaterializedLayer(out, annotations=self.annotations), ret_deps
        else:
            # Imported locally; note that it shadows nothing at module level
            from dask._task_spec import cull

            out = cull(dict(self), keys)
            # Every remaining value exposes its own ``dependencies``
            return MaterializedLayer(out, annotations=self.annotations), {
                k: set(v.dependencies) for k, v in out.items()
            }

    def get_dependencies(self, key: Key, all_hlg_keys: Collection[Key]) -> set:
        """Get dependencies of `key` in the layer

        Parameters
        ----------
        key:
            The key to find dependencies of
        all_hlg_keys:
            All keys in the high level graph.

        Returns
        -------
        deps: set
            A set of dependencies
        """
        return keys_in_tasks(all_hlg_keys, [self[key]])

    def clone(
        self,
        keys: set,
        seed: Hashable,
        bind_to: Key | None = None,
    ) -> tuple[Layer, bool]:
        """Clone selected keys in the layer, as well as references to keys in other
        layers

        Parameters
        ----------
        keys
            Keys to be replaced. This never includes keys not listed by
            :meth:`get_output_keys`. It must also include any keys that are outside
            of this layer that may be referenced by it.
        seed
            Common hashable used to alter the keys; see :func:`dask.base.clone_key`
        bind_to
            Optional key to bind the leaf nodes to. A leaf node here is one that does
            not reference any replaced keys; in other words it's a node where the
            replacement graph traversal stops; it may still have dependencies on
            non-replaced nodes.
            A bound node will not be computed until after ``bind_to`` has been computed.

        Returns
        -------
        - New layer
        - True if the ``bind_to`` key was injected anywhere; False otherwise

        Notes
        -----
        This method should be overridden by subclasses to avoid materializing the layer.
        """
        from dask.graph_manipulation import chunks

        # Flag shared with ``clone_value`` through ``nonlocal``: it is reset to
        # True before each task is rewritten and flipped to False as soon as a
        # reference to a replaced key is found inside that task.
        is_leaf: bool

        def clone_value(o):
            """Variant of distributed.utils_comm.subs_multiple, which allows injecting
            bind_to
            """
            nonlocal is_leaf

            # Exact type checks: subclasses of tuple/list/dict are treated as
            # opaque values and fall through to the ``else`` branch
            typ = type(o)
            if typ is tuple and o and callable(o[0]):
                return (o[0],) + tuple(clone_value(i) for i in o[1:])
            elif typ is list:
                return [clone_value(i) for i in o]
            elif typ is dict:
                return {k: clone_value(v) for k, v in o.items()}
            else:
                try:
                    if o not in keys:
                        return o
                except TypeError:
                    # The membership test raised TypeError, so ``o`` is
                    # returned unchanged
                    return o
                is_leaf = False
                return clone_key(o, seed)

        dsk_new = {}
        bound = False

        for key, value in self.items():
            if key in keys:
                key = clone_key(key, seed)
                is_leaf = True
                value = clone_value(value)
                if bind_to is not None and is_leaf:
                    value = (chunks.bind, value, bind_to)
                    bound = True

            # Keys that are not replaced are copied over as they are
            dsk_new[key] = value

        return MaterializedLayer(dsk_new), bound

    def __copy__(self):
        """Default shallow copy implementation"""
        # Bypass ``__init__`` and share all attribute values with the original
        obj = type(self).__new__(self.__class__)
        obj.__dict__.update(self.__dict__)
        return obj

    def _repr_html_(self, layer_index="", highlevelgraph_key="", dependencies=()):
        """Render the layer with the ``highlevelgraph_layer.html.j2`` template

        The short name shown is derived, in this order, from
        ``highlevelgraph_key``, from ``self.name`` if the attribute exists, or
        from the class name.
        """
        if highlevelgraph_key != "":
            shortname = key_split(highlevelgraph_key)
        elif hasattr(self, "name"):
            shortname = key_split(self.name)
        else:
            shortname = self.__class__.__name__

        # An SVG is only produced for layers annotated as dask arrays that
        # also carry a non-empty "chunks" annotation
        svg_repr = ""
        if (
            self.collection_annotations
            and self.collection_annotations.get("type") == "dask.array.core.Array"
        ):
            chunks = self.collection_annotations.get("chunks")
            if chunks:
                from dask.array.svg import svg

                svg_repr = svg(chunks)

        return get_template("highlevelgraph_layer.html.j2").render(
            materialized=self.is_materialized(),
            shortname=shortname,
            layer_index=layer_index,
            highlevelgraph_key=highlevelgraph_key,
            info=self.layer_info_dict(),
            dependencies=dependencies,
            svg_repr=svg_repr,
        )

    def layer_info_dict(self):
        """Return a dict summarizing the layer for the HTML representation

        Besides the layer type, materialization state and number of outputs,
        it holds all annotations and collection annotations (except
        "chunks"), with their values converted to HTML-escaped strings.
        """
        info = {
            "layer_type": type(self).__name__,
            "is_materialized": self.is_materialized(),
            "number of outputs": f"{len(self.get_output_keys())}",
        }
        if self.annotations is not None:
            for key, val in self.annotations.items():
                info[key] = html.escape(str(val))
        if self.collection_annotations is not None:
            for key, val in self.collection_annotations.items():
                # Hide verbose chunk details from the HTML table
                if key != "chunks":
                    info[key] = html.escape(str(val))
        return info
326
327
class MaterializedLayer(Layer):
    """A `Layer` whose tasks are all held explicitly in a mapping

    Parameters
    ----------
    mapping: Mapping
        The mapping between keys and tasks, typically a dask graph.
    annotations, collection_annotations:
        Passed on unchanged to the ``Layer`` initializer.
    """

    def __init__(self, mapping: Mapping, annotations=None, collection_annotations=None):
        super().__init__(
            collection_annotations=collection_annotations,
            annotations=annotations,
        )
        if isinstance(mapping, Mapping):
            self.mapping = mapping
            return
        raise TypeError(f"mapping must be a Mapping. Instead got {type(mapping)}")

    def __contains__(self, k):
        """Tell whether ``k`` is a key of the wrapped mapping"""
        tasks = self.mapping
        return k in tasks

    def __getitem__(self, k):
        """Look up the task stored under ``k`` in the wrapped mapping"""
        tasks = self.mapping
        return tasks[k]

    def __iter__(self):
        """Iterate over the keys of the wrapped mapping"""
        tasks = self.mapping
        return iter(tasks)

    def __len__(self):
        """Number of entries in the wrapped mapping"""
        tasks = self.mapping
        return len(tasks)

    def is_materialized(self):
        """Always True for this class"""
        return True

    def get_output_keys(self):
        """Return the keys view of the layer"""
        output_keys = self.keys()
        return output_keys
362
363
class HighLevelGraph(Graph):
    """Task graph composed of layers of dependent subgraphs

    This object encodes a Dask task graph that is composed of layers of
    dependent subgraphs, such as commonly occurs when building task graphs
    using high level collections like Dask array, bag, or dataframe.

    Typically each high level array, bag, or dataframe operation takes the task
    graphs of the input collections, merges them, and then adds one or more new
    layers of tasks for the new operation. These layers typically have at
    least as many tasks as there are partitions or chunks in the collection.
    The HighLevelGraph object stores the subgraphs for each operation
    separately in sub-graphs, and also stores the dependency structure between
    them.

    Parameters
    ----------
    layers : Mapping[str, Mapping]
        The subgraph layers, keyed by a unique name
    dependencies : Mapping[str, set[str]]
        The set of layers on which each layer depends
    key_dependencies : dict[Key, set], optional
        Mapping (some) keys in the high level graph to their dependencies. If
        a key is missing, its dependencies will be calculated on-the-fly.

    Examples
    --------
    Here is an idealized example that shows the internal state of a
    HighLevelGraph

    >>> import dask.dataframe as dd

    >>> df = dd.read_csv('myfile.*.csv')  # doctest: +SKIP
    >>> df = df + 100  # doctest: +SKIP
    >>> df = df[df.name == 'Alice']  # doctest: +SKIP

    >>> graph = df.__dask_graph__()  # doctest: +SKIP
    >>> graph.layers  # doctest: +SKIP
    {
     'read-csv': {('read-csv', 0): (pandas.read_csv, 'myfile.0.csv'),
                  ('read-csv', 1): (pandas.read_csv, 'myfile.1.csv'),
                  ('read-csv', 2): (pandas.read_csv, 'myfile.2.csv'),
                  ('read-csv', 3): (pandas.read_csv, 'myfile.3.csv')},
     'add': {('add', 0): (operator.add, ('read-csv', 0), 100),
             ('add', 1): (operator.add, ('read-csv', 1), 100),
             ('add', 2): (operator.add, ('read-csv', 2), 100),
             ('add', 3): (operator.add, ('read-csv', 3), 100)},
     'filter': {('filter', 0): (lambda part: part[part.name == 'Alice'], ('add', 0)),
                ('filter', 1): (lambda part: part[part.name == 'Alice'], ('add', 1)),
                ('filter', 2): (lambda part: part[part.name == 'Alice'], ('add', 2)),
                ('filter', 3): (lambda part: part[part.name == 'Alice'], ('add', 3))}
    }

    >>> graph.dependencies  # doctest: +SKIP
    {
     'read-csv': set(),
     'add': {'read-csv'},
     'filter': {'add'}
    }

    See Also
    --------
    HighLevelGraph.from_collections :
        typically used by developers to make new HighLevelGraphs
    """

    layers: Mapping[str, Layer]
    dependencies: Mapping[str, set[str]]
    key_dependencies: dict[Key, set[Key]]
    # Lazily populated caches, see ``to_dict`` and ``get_all_external_keys``
    _to_dict: dict
    _all_external_keys: set

    def __init__(
        self,
        layers: Mapping[str, Graph],
        dependencies: Mapping[str, set[str]],
        key_dependencies: dict[Key, set[Key]] | None = None,
    ):
        self.dependencies = dependencies
        self.key_dependencies = key_dependencies or {}
        # Makes sure that all layers are `Layer`
        self.layers = {
            k: v if isinstance(v, Layer) else MaterializedLayer(v)
            for k, v in layers.items()
        }

    @classmethod
    def _from_collection(cls, name, layer, collection):
        """`from_collections` optimized for a single collection"""
        if not is_dask_collection(collection):
            raise TypeError(type(collection))

        graph = collection.__dask_graph__()
        if isinstance(graph, HighLevelGraph):
            layers = ensure_dict(graph.layers, copy=True)
            layers[name] = layer
            deps = ensure_dict(graph.dependencies, copy=True)
            deps[name] = set(collection.__dask_layers__())
        else:
            # Low level graph: wrap it as a single layer without dependencies
            key = _get_some_layer_name(collection)
            layers = {name: layer, key: graph}
            deps = {name: {key}, key: set()}

        return cls(layers, deps)

    @classmethod
    def from_collections(
        cls,
        name: str,
        layer: Graph,
        dependencies: Sequence[DaskCollection] = (),
    ) -> HighLevelGraph:
        """Construct a HighLevelGraph from a new layer and a set of collections

        This constructs a HighLevelGraph in the common case where we have a single
        new layer and a set of old collections on which we want to depend.

        This pulls out the ``__dask_layers__()`` method of the collections if
        they exist, and adds them to the dependencies for this new layer. It
        also merges all of the layers from all of the dependent collections
        together into the new layers for this graph.

        Parameters
        ----------
        name : str
            The name of the new layer
        layer : Mapping
            The graph layer itself
        dependencies : List of Dask collections
            A list of other dask collections (like arrays or dataframes) that
            have graphs themselves

        Raises
        ------
        TypeError
            If one of ``dependencies`` is not a dask collection

        Examples
        --------

        In typical usage we make a new task layer, and then pass that layer
        along with all dependent collections to this method.

        >>> def add(self, other):
        ...     name = 'add-' + tokenize(self, other)
        ...     layer = {(name, i): (add, input_key, other)
        ...              for i, input_key in enumerate(self.__dask_keys__())}
        ...     graph = HighLevelGraph.from_collections(name, layer, dependencies=[self])
        ...     return new_collection(name, graph)
        """
        if len(dependencies) == 1:
            return cls._from_collection(name, layer, dependencies[0])
        layers = {name: layer}
        name_dep: set[str] = set()
        deps: dict[str, set[str]] = {name: name_dep}
        # The same collection object may be passed more than once
        for collection in toolz.unique(dependencies, key=id):
            if is_dask_collection(collection):
                graph = collection.__dask_graph__()
                if isinstance(graph, HighLevelGraph):
                    layers.update(graph.layers)
                    deps.update(graph.dependencies)
                    name_dep |= set(collection.__dask_layers__())
                else:
                    key = _get_some_layer_name(collection)
                    layers[key] = graph
                    name_dep.add(key)
                    deps[key] = set()
            else:
                raise TypeError(type(collection))

        return cls(layers, deps)

    def __getitem__(self, key: Key) -> Any:
        # Attempt O(1) direct access first, under the assumption that layer names match
        # either the keys (Scalar, Item, Delayed) or the first element of the key tuples
        # (Array, Bag, DataFrame, Series). This assumption is not always true.
        try:
            return self.layers[key][key]  # type: ignore
        except KeyError:
            pass
        try:
            return self.layers[key[0]][key]  # type: ignore
        except (KeyError, IndexError, TypeError):
            pass

        # Fall back to O(n) access
        for d in self.layers.values():
            try:
                return d[key]
            except KeyError:
                pass

        raise KeyError(key)

    def __len__(self) -> int:
        # NOTE: this will double-count keys that are duplicated between layers, so it's
        # possible that `len(hlg) > len(hlg.to_dict())`. However, duplicate keys should
        # not occur through normal use, and their existence would usually be a bug.
        # So we ignore this case in favor of better performance.
        # https://github.com/dask/dask/issues/7271
        return sum(len(layer) for layer in self.layers.values())

    def __iter__(self) -> Iterator[Key]:
        return iter(self.to_dict())

    def to_dict(self) -> dict[Key, Any]:
        """Efficiently convert to plain dict. This method is faster than dict(self)."""
        try:
            return self._to_dict
        except AttributeError:
            # First call: build the dict once and cache it on the instance
            out = self._to_dict = ensure_dict(self)
            return out

    def keys(self) -> KeysView:
        """Get all keys of all the layers.

        This will in many cases materialize layers, which makes it a relatively
        expensive operation. See :meth:`get_all_external_keys` for a faster alternative.
        """
        return self.to_dict().keys()

    def get_all_external_keys(self) -> set[Key]:
        """Get all output keys of all layers

        This will in most cases _not_ materialize any layers, which makes
        it a relative cheap operation.

        Returns
        -------
        keys: set
            A set of all external keys
        """
        try:
            return self._all_external_keys
        except AttributeError:
            keys: set = set()
            for layer in self.layers.values():
                # Note: don't use `keys |= ...`, because the RHS is a
                # collections.abc.Set rather than a real set, and this will
                # cause a whole new set to be constructed.
                keys.update(layer.get_output_keys())
            self._all_external_keys = keys
            return keys

    def items(self) -> ItemsView[Key, Any]:
        return self.to_dict().items()

    def values(self) -> ValuesView[Any]:
        return self.to_dict().values()

    def get_all_dependencies(self) -> dict[Key, set[Key]]:
        """Get dependencies of all keys

        This will in most cases materialize all layers, which makes
        it an expensive operation.

        Returns
        -------
        map: Mapping
            A map that maps each key to its dependencies
        """
        all_keys = self.keys()
        # Only compute what is not already known; results are stored in
        # ``self.key_dependencies``, which is also the object returned
        missing_keys = all_keys - self.key_dependencies.keys()
        if missing_keys:
            for layer in self.layers.values():
                for k in missing_keys & layer.keys():
                    self.key_dependencies[k] = layer.get_dependencies(k, all_keys)
        return self.key_dependencies

    @property
    def dependents(self) -> dict[str, set[str]]:
        return reverse_dict(self.dependencies)

    def copy(self) -> HighLevelGraph:
        return HighLevelGraph(
            ensure_dict(self.layers, copy=True),
            ensure_dict(self.dependencies, copy=True),
            self.key_dependencies.copy(),
        )

    @classmethod
    def merge(cls, *graphs: Graph) -> HighLevelGraph:
        """Combine several graphs into a single HighLevelGraph

        Layers and dependencies of HighLevelGraph inputs are merged; any other
        mapping becomes one layer without dependencies, named after its ``id``.

        Raises
        ------
        TypeError
            If an input is neither a HighLevelGraph nor a Mapping
        """
        layers: dict[str, Graph] = {}
        dependencies: dict[str, set[str]] = {}
        for g in graphs:
            if isinstance(g, HighLevelGraph):
                layers.update(g.layers)
                dependencies.update(g.dependencies)
            elif isinstance(g, Mapping):
                layers[str(id(g))] = g
                dependencies[str(id(g))] = set()
            else:
                raise TypeError(g)
        return cls(layers, dependencies)

    def visualize(self, filename="dask-hlg.svg", format=None, **kwargs):
        """
        Visualize this dask high level graph.

        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and the graph is
            rendered in the Jupyter notebook only.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'svg'.
        color : {None, 'layer_type'}, optional (default: None)
            Options to color nodes.
            - None, no colors.
            - layer_type, color nodes based on the layer type.
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.

        Examples
        --------
        >>> x.dask.visualize(filename='dask.svg')  # doctest: +SKIP
        >>> x.dask.visualize(filename='dask.svg', color='layer_type')  # doctest: +SKIP

        Returns
        -------
        result : IPython.display.Image, IPython.display.SVG, or None
            See dask.dot.dot_graph for more information.

        See Also
        --------
        dask.dot.dot_graph
        dask.base.visualize  # low level variant
        """

        from dask.dot import graphviz_to_file

        g = to_graphviz(self, **kwargs)
        graphviz_to_file(g, filename, format)
        return g

    def _toposort_layers(self) -> list[str]:
        """Sort the layers of this high level graph topologically

        Returns
        -------
        sorted: list
            List of layer names sorted topologically, i.e. every layer comes
            after all the layers it depends on
        """
        # ``degree`` counts the not yet emitted dependencies of every layer
        degree = {k: len(v) for k, v in self.dependencies.items()}
        reverse_deps: dict[str, list[str]] = {k: [] for k in self.dependencies}
        ready = []
        for k, v in self.dependencies.items():
            for dep in v:
                reverse_deps[dep].append(k)
            if not v:
                ready.append(k)
        ret = []
        while len(ready) > 0:
            layer = ready.pop()
            ret.append(layer)
            for rdep in reverse_deps[layer]:
                degree[rdep] -= 1
                if degree[rdep] == 0:
                    ready.append(rdep)
        return ret

    def cull(self, keys: Iterable[Key]) -> HighLevelGraph:
        """Return new HighLevelGraph with only the tasks required to calculate keys.

        In other words, remove unnecessary tasks from dask.

        Parameters
        ----------
        keys
            iterable of keys or nested list of keys such as the output of
            ``__dask_keys__()``

        Returns
        -------
        hlg: HighLevelGraph
            Culled high level graph

        Notes
        -----
        Every layer kept in the result is stored under the name
        ``f"{layer_name}-{token}"``, where ``token`` is the token of the
        requested keys; the returned ``dependencies`` use these new names too.
        """
        keys_set = set(flatten(keys))

        # Note: All Layer classes still in existence are of
        # one of these types (or subclasses)
        #
        # - MaterializedLayer
        # - Blockwise
        # - ArrayOverlapLayer (which is basically as good as MaterializedLayer)
        if not any(layer.has_legacy_tasks for layer in self.layers.values()):
            all_ext_keys = set()
        else:
            # FIXME: Technically, we don't need to compute **all** keys but only
            # those of the current layer and all of its dependencies, i.e. if
            # there are legacy layers for IO followed by many blockwise layers,
            # we should still get by without this
            all_ext_keys = self.get_all_external_keys()

        ret_layers: dict = {}
        layer_dependencies = {}
        # Computed before ``keys_set`` is mutated by the loop below
        tok = tokenize(keys_set)
        # Dependents are visited before the layers they depend on, so that the
        # keys they require are known once their dependencies are reached
        for layer_name in reversed(self._toposort_layers()):
            new_layer_name = f"{layer_name}-{tok}"
            layer = self.layers[layer_name]
            if keys_set:
                culled_layer, culled_deps = layer.cull(keys_set, all_ext_keys)
                if not culled_deps:
                    # None of the wanted keys is provided by this layer
                    continue

                # Update `keys` with all layer's external key dependencies,
                # which are all the layer's dependencies (`culled_deps`)
                # excluding the layer's output keys.
                for k, d in culled_deps.items():
                    keys_set |= d
                    keys_set.discard(k)

                # Save the culled layer and its layer dependencies. The
                # dependencies have to be translated to the *new* layer names:
                # with the original names the intersection with the kept layers
                # below would always be empty and the culled graph would lose
                # its whole dependency structure.
                ret_layers[new_layer_name] = culled_layer
                layer_dependencies[new_layer_name] = {
                    f"{dep}-{tok}" for dep in self.dependencies[layer_name]
                }

        # Converting dict_keys to a real set lets Python optimise the set
        # intersection to iterate over the smaller of the two sets.
        ret_layers_keys = set(ret_layers.keys())
        ret_dependencies = {
            layer_name: layer_dependencies[layer_name] & ret_layers_keys
            for layer_name in ret_layers
        }

        return HighLevelGraph(ret_layers, ret_dependencies)

    def cull_layers(self, layers: Iterable[str]) -> HighLevelGraph:
        """Return a new HighLevelGraph with only the given layers and their
        dependencies. Internally, layers are not modified.

        This is a variant of :meth:`HighLevelGraph.cull` which is much faster and does
        not risk creating a collision between two layers with the same name and
        different content when two culled graphs are merged later on.

        Returns
        -------
        hlg: HighLevelGraph
            Culled high level graph
        """
        to_visit = set(layers)
        ret_layers = {}
        ret_dependencies = {}
        while to_visit:
            k = to_visit.pop()
            ret_layers[k] = self.layers[k]
            ret_dependencies[k] = self.dependencies[k]
            # Queue the dependencies that have not been collected yet
            to_visit |= ret_dependencies[k] - ret_dependencies.keys()

        return HighLevelGraph(ret_layers, ret_dependencies)

    def validate(self) -> None:
        """Check that ``dependencies`` is consistent with the layers

        Raises
        ------
        ValueError
            If ``dependencies`` refers to unknown layers, or differs from the
            layer dependencies recomputed from the tasks
        """
        # Check dependencies
        for layer_name, deps in self.dependencies.items():
            if layer_name not in self.layers:
                raise ValueError(
                    f"dependencies[{repr(layer_name)}] not found in layers"
                )
            for dep in deps:
                if dep not in self.dependencies:
                    raise ValueError(f"{repr(dep)} not found in dependencies")

        for layer in self.layers.values():
            assert hasattr(layer, "annotations")

        # Re-calculate all layer dependencies
        dependencies = compute_layer_dependencies(self.layers)

        # Check keys
        dep_key1 = self.dependencies.keys()
        dep_key2 = dependencies.keys()
        if dep_key1 != dep_key2:
            raise ValueError(
                f"incorrect dependencies keys {set(dep_key1)!r} "
                f"expected {set(dep_key2)!r}"
            )

        # Check values
        for k in dep_key1:
            if self.dependencies[k] != dependencies[k]:
                raise ValueError(
                    f"incorrect HLG dependencies[{repr(k)}]: {repr(self.dependencies[k])} "
                    f"expected {repr(dependencies[k])} from task dependencies"
                )

    def __repr__(self) -> str:
        representation = f"{type(self).__name__} with {len(self.layers)} layers.\n"
        representation += f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}>\n"
        for i, layerkey in enumerate(self._toposort_layers()):
            representation += f" {i}. {layerkey}\n"
        return representation

    def _repr_html_(self) -> str:
        return get_template("highlevelgraph.html.j2").render(
            type=type(self).__name__,
            layers=self.layers,
            toposort=self._toposort_layers(),
            layer_dependencies=self.dependencies,
            n_outputs=len(self.get_all_external_keys()),
        )
867
868
def to_graphviz(
    hg,
    data_attributes=None,
    function_attributes=None,
    rankdir="BT",
    graph_attr=None,
    node_attr=None,
    edge_attr=None,
    **kwargs,
):
    """Build a ``graphviz.Digraph`` with one node per layer of ``hg``

    Parameters
    ----------
    hg : HighLevelGraph
        Graph to draw; its ``layers`` and ``dependencies`` are read.
    data_attributes : dict, optional
        Extra node attributes, keyed by layer name (and by ``"Key"`` for the
        legend node). Attributes given here take precedence over the label,
        fontsize, tooltip and fill colour computed by this function.
    function_attributes : dict, optional
        Accepted for signature compatibility; not used by this function.
    rankdir : str
        Value of the ``rankdir`` graph attribute.
    graph_attr, node_attr, edge_attr : dict, optional
        Attributes handed to ``graphviz.Digraph``. The dicts passed in are
        copied, never modified.
    **kwargs
        Added to the graph attributes. ``color="layer_type"`` additionally
        fills every node according to its layer type and adds a legend.

    Returns
    -------
    graphviz.Digraph
    """
    from dask.dot import label, name

    graphviz = import_required(
        "graphviz",
        "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
        "python library and the `graphviz` system library.\n\n"
        "Please either conda or pip install as follows:\n\n"
        " conda install python-graphviz # either conda install\n"
        " python -m pip install graphviz # or pip install and follow installation instructions",
    )

    # Copy the attribute dicts: they are modified below and the caller's
    # objects must stay untouched
    data_attributes = data_attributes or {}
    graph_attr = dict(graph_attr or {})
    node_attr = dict(node_attr or {})
    edge_attr = dict(edge_attr or {})

    graph_attr["rankdir"] = rankdir
    node_attr["shape"] = "box"
    node_attr["fontname"] = "helvetica"

    graph_attr.update(kwargs)
    g = graphviz.Digraph(
        graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr
    )

    n_tasks = {layer: len(hg.layers[layer]) for layer in hg.dependencies}

    # ``default`` keeps a graph without any layer from raising ValueError
    min_tasks = min(n_tasks.values(), default=0)
    max_tasks = max(n_tasks.values(), default=0)

    cache = {}

    color = kwargs.get("color")
    if color == "layer_type":
        # layer type name -> [fill colour, whether the type occurs in the graph]
        layer_colors = {
            "DataFrameIOLayer": ["#CCC7F9", False],  # purple
            "ArrayOverlapLayer": ["#FFD9F2", False],  # pink
            "BroadcastJoinLayer": ["#D9F2FF", False],  # blue
            "Blockwise": ["#D9FFE6", False],  # green
            "BlockwiseLayer": ["#D9FFE6", False],  # green
            "MaterializedLayer": ["#DBDEE5", False],  # gray
        }

    for layer in hg.dependencies:
        layer_name = name(layer)
        attrs = dict(data_attributes.get(layer, {}))

        node_label = label(layer, cache=cache)
        # Font size scales linearly from 20 to 40 with the number of tasks
        node_size = (
            20
            if max_tasks == min_tasks
            else int(20 + ((n_tasks[layer] - min_tasks) / (max_tasks - min_tasks)) * 20)
        )

        layer_type = str(type(hg.layers[layer]).__name__)
        node_tooltips = (
            f"A {layer_type.replace('Layer', '')} Layer with {n_tasks[layer]} Tasks.\n"
        )

        layer_ca = hg.layers[layer].collection_annotations
        if layer_ca:
            if layer_ca.get("type") == "dask.array.core.Array":
                node_tooltips += (
                    f"Array Shape: {layer_ca.get('shape')}\n"
                    f"Data Type: {layer_ca.get('dtype')}\n"
                    f"Chunk Size: {layer_ca.get('chunksize')}\n"
                    f"Chunk Type: {layer_ca.get('chunk_type')}\n"
                )
            elif layer_ca.get("type") == "dask.dataframe.core.DataFrame":
                dftype = {"pandas.core.frame.DataFrame": "pandas"}
                cols = layer_ca.get("columns")
                if cols is None:
                    # A missing "columns" annotation must not break the drawing
                    cols = []

                node_tooltips += (
                    f"Number of Partitions: {layer_ca.get('npartitions')}\n"
                    f"DataFrame Type: {dftype.get(layer_ca.get('dataframe_type'))}\n"
                    f"{len(cols)} DataFrame Columns: {str(cols) if len(str(cols)) <= 40 else '[...]'}\n"
                )

        attrs.setdefault("label", str(node_label))
        attrs.setdefault("fontsize", str(node_size))
        attrs.setdefault("tooltip", str(node_tooltips))

        if color == "layer_type":
            # Layer types missing from the table get a neutral colour (and
            # their own legend entry) instead of raising TypeError
            color_entry = layer_colors.setdefault(layer_type, ["#FFFFFF", False])
            color_entry[1] = True

            attrs.setdefault("fillcolor", str(color_entry[0]))
            attrs.setdefault("style", "filled")

        g.node(layer_name, **attrs)

    for layer, deps in hg.dependencies.items():
        layer_name = name(layer)
        for dep in deps:
            dep_name = name(dep)
            g.edge(dep_name, layer_name)

    if color == "layer_type":
        legend_title = "Key"

        legend_label = (
            '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="5">'
            "<TR><TD><B>Legend: Layer types</B></TD></TR>"
        )

        # Only layer types that actually occur in the graph are listed
        for type_name, (fill, used) in layer_colors.items():
            if used:
                legend_label += f'<TR><TD BGCOLOR="{fill}">{type_name}</TD></TR>'

        legend_label += "</TABLE>>"

        attrs = dict(data_attributes.get(legend_title, {}))
        attrs.setdefault("label", str(legend_label))
        attrs.setdefault("fontsize", "20")
        attrs.setdefault("margin", "0")

        g.node(legend_title, **attrs)

    return g
1002
1003
def _get_some_layer_name(collection) -> str:
    """Pick a layer name for a collection backed by a non-HighLevelGraph mapping

    The single name reported by ``collection.__dask_layers__()`` is used when
    there is exactly one; otherwise the name is derived from ``id(collection)``.
    """
    try:
        layer_names = collection.__dask_layers__()
        (only_name,) = layer_names
    except (AttributeError, ValueError):
        # Either the optional ``__dask_layers__`` method is not defined, or it
        # did not yield exactly one layer name
        return str(id(collection))
    else:
        return only_name
1013
1014
@normalize_token.register(HighLevelGraph)
def register_highlevelgraph(hlg):
    """Tokenize a HighLevelGraph through the names of its layers"""
    # Note: Layer keys are not necessarily identifying HLGs uniquely
    # see https://github.com/dask/dask/issues/9888
    layer_names = list(hlg.layers)
    return normalize_token(layer_names)