Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/highlevelgraph.py: 19%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

408 statements  

1from __future__ import annotations 

2 

3import abc 

4import copy 

5import functools 

6import html 

7from collections.abc import ( 

8 Collection, 

9 Hashable, 

10 ItemsView, 

11 Iterable, 

12 Iterator, 

13 KeysView, 

14 Mapping, 

15 Sequence, 

16 Set, 

17 ValuesView, 

18) 

19from typing import Any 

20 

21import tlz as toolz 

22 

23import dask 

24from dask import config 

25from dask._task_spec import GraphNode 

26from dask.base import clone_key, flatten, is_dask_collection 

27from dask.core import keys_in_tasks, reverse_dict 

28from dask.tokenize import normalize_token, tokenize 

29from dask.typing import DaskCollection, Graph, Key 

30from dask.utils import ensure_dict, import_required, key_split 

31from dask.widgets import get_template 

32 

33 

def compute_layer_dependencies(layers):
    """Return the dependencies between layers.

    Parameters
    ----------
    layers : Mapping
        Mapping of layer name to layer, where each layer is itself a mapping
        of key to task.

    Returns
    -------
    dict
        Maps every layer name to the set of names of the layers that hold
        the keys referenced by that layer's tasks.

    Raises
    ------
    RuntimeError
        If a referenced key cannot be attributed to any layer.
    """
    # Index each key by the first layer that contains it. Building the index
    # once avoids re-scanning every layer for every referenced key.
    key_to_layer = {}
    for layer_name, layer in layers.items():
        for key in layer:
            key_to_layer.setdefault(key, layer_name)

    all_keys = set(key_to_layer)
    ret = {layer_name: set() for layer_name in layers}
    for layer_name, layer in layers.items():
        # Only keys living outside of this layer count as layer dependencies.
        for key in keys_in_tasks(all_keys - layer.keys(), layer.values()):
            try:
                owner = key_to_layer[key]
            except KeyError:
                raise RuntimeError(f"{repr(key)} not found") from None
            ret[layer_name].add(owner)
    return ret

49 

50 

class Layer(Graph):
    """High level graph layer

    This abstract class establishes a protocol for high level graph layers.

    The main motivation of a layer is to represent a collection of tasks
    symbolically in order to speedup a series of operations significantly.
    Ideally, a layer should stay in this symbolic state until execution
    but in practice some operations will force the layer to generate all
    its internal tasks. We say that the layer has been materialized.

    Most of the default implementations in this class will materialize the
    layer. It is up to derived classes to implement non-materializing
    implementations.
    """

    annotations: Mapping[str, Any] | None
    collection_annotations: Mapping[str, Any] | None

    def __init__(
        self,
        annotations: Mapping[str, Any] | None = None,
        collection_annotations: Mapping[str, Any] | None = None,
    ):
        """Initialize Layer object.

        Parameters
        ----------
        annotations : Mapping[str, Any], optional
            By default, None.
            Annotations are metadata or soft constraints associated with tasks
            that dask schedulers may choose to respect:
            They signal intent without enforcing hard constraints.
            As such, they are primarily designed for use with the distributed
            scheduler. See the dask.annotate function for more information.
        collection_annotations : Mapping[str, Any], optional. By default, None.
            Experimental, intended to assist with visualizing the performance
            characteristics of Dask computations.
            These annotations are *not* passed to the distributed scheduler.
        """
        # Fall back to a copy of the currently active annotations; an empty
        # result is normalized to None.
        self.annotations = annotations or dask.get_annotations().copy() or None
        # Fall back to a copy of the "collection_annotations" config value.
        self.collection_annotations = collection_annotations or copy.copy(
            config.get("collection_annotations", None)
        )

    @functools.cached_property
    def has_legacy_tasks(self):
        """Check if the layer has legacy tasks

        A task counts as legacy here when it is not a ``GraphNode`` instance.

        Evaluating this property iterates over ``self.values()``; the result
        is cached on the instance after the first access.
        """
        return any(not isinstance(v, GraphNode) for v in self.values())

    @abc.abstractmethod
    def is_materialized(self) -> bool:
        """Return whether the layer is materialized or not"""
        return True

    @abc.abstractmethod
    def get_output_keys(self) -> Set[Key]:
        """Return a set of all output keys

        Output keys are all keys in the layer that might be referenced by
        other layers.

        Classes overriding this implementation should not cause the layer
        to be materialized.

        Returns
        -------
        keys: Set
            All output keys
        """
        return self.keys()  # this implementation will materialize the graph

    def cull(
        self, keys: set[Key], all_hlg_keys: Collection[Key]
    ) -> tuple[Layer, Mapping[Key, set[Key]]]:
        """Remove unnecessary tasks from the layer

        In other words, return a new Layer with only the tasks required to
        calculate `keys` and a map of external key dependencies.

        Examples
        --------
        >>> inc = lambda x: x + 1
        >>> add = lambda x, y: x + y
        >>> d = MaterializedLayer({'x': 1, 'y': (inc, 'x'), 'out': (add, 'x', 10)})
        >>> _, deps = d.cull({'out'}, d.keys())
        >>> deps
        {'out': {'x'}, 'x': set()}

        Parameters
        ----------
        keys : set
            Keys to preserve. Keys that are not part of this layer are
            ignored.
        all_hlg_keys : Collection
            All keys in the high level graph; only consulted for layers
            that contain legacy tasks.

        Returns
        -------
        layer: Layer
            Culled layer
        deps: Map
            Map of external key dependencies
        """

        if self.has_legacy_tasks:
            if len(keys) == len(self):
                # Nothing to cull if preserving all existing keys
                return (
                    self,
                    {k: self.get_dependencies(k, all_hlg_keys) for k in self.keys()},
                )
            ret_deps = {}
            out = {}
            work = keys.copy()
            while work:
                k = work.pop()
                if k not in self:
                    continue
                out[k] = self[k]
                ret_deps[k] = self.get_dependencies(k, all_hlg_keys)
                for d in ret_deps[k]:
                    # Keys already present in ``out`` have been handled;
                    # re-queueing them would only recompute their
                    # dependencies with the same result.
                    if d not in out and d in self:
                        work.add(d)

            return MaterializedLayer(out, annotations=self.annotations), ret_deps
        else:
            from dask._task_spec import cull

            out = cull(dict(self), keys)
            return MaterializedLayer(out, annotations=self.annotations), {
                k: set(v.dependencies) for k, v in out.items()
            }

    def get_dependencies(self, key: Key, all_hlg_keys: Collection[Key]) -> set:
        """Get dependencies of `key` in the layer

        Parameters
        ----------
        key:
            The key to find dependencies of
        all_hlg_keys:
            All keys in the high level graph.

        Returns
        -------
        deps: set
            A set of dependencies
        """
        return keys_in_tasks(all_hlg_keys, [self[key]])

    def clone(
        self,
        keys: set,
        seed: Hashable,
        bind_to: Key | None = None,
    ) -> tuple[Layer, bool]:
        """Clone selected keys in the layer, as well as references to keys in other
        layers

        Parameters
        ----------
        keys
            Keys to be replaced. This never includes keys not listed by
            :meth:`get_output_keys`. It must also include any keys that are outside
            of this layer that may be referenced by it.
        seed
            Common hashable used to alter the keys; see :func:`dask.base.clone_key`
        bind_to
            Optional key to bind the leaf nodes to. A leaf node here is one that does
            not reference any replaced keys; in other words it's a node where the
            replacement graph traversal stops; it may still have dependencies on
            non-replaced nodes.
            A bound node will not be computed until after ``bind_to`` has been computed.

        Returns
        -------
        - New layer
        - True if the ``bind_to`` key was injected anywhere; False otherwise

        Notes
        -----
        This method should be overridden by subclasses to avoid materializing the layer.
        """
        from dask.graph_manipulation import chunks

        # Set to True before each task is rewritten; ``clone_value`` flips it
        # to False as soon as it replaces a reference to a cloned key.
        is_leaf: bool

        def clone_value(o):
            """Variant of distributed.utils_comm.subs_multiple, which allows injecting
            bind_to
            """
            nonlocal is_leaf

            typ = type(o)
            if typ is tuple and o and callable(o[0]):
                return (o[0],) + tuple(clone_value(i) for i in o[1:])
            elif typ is list:
                return [clone_value(i) for i in o]
            elif typ is dict:
                return {k: clone_value(v) for k, v in o.items()}
            else:
                try:
                    if o not in keys:
                        return o
                except TypeError:
                    # Unhashable values cannot be keys; leave them untouched.
                    return o
                is_leaf = False
                return clone_key(o, seed)

        dsk_new = {}
        bound = False

        for key, value in self.items():
            if key in keys:
                key = clone_key(key, seed)
                is_leaf = True
                value = clone_value(value)
                if bind_to is not None and is_leaf:
                    value = (chunks.bind, value, bind_to)
                    bound = True

            dsk_new[key] = value

        return MaterializedLayer(dsk_new), bound

    def __copy__(self):
        """Default shallow copy implementation"""
        obj = type(self).__new__(self.__class__)
        obj.__dict__.update(self.__dict__)
        return obj

    def _repr_html_(self, layer_index="", highlevelgraph_key="", dependencies=()):
        """Render the layer with the ``highlevelgraph_layer.html.j2`` template."""
        # Prefer the key supplied by the caller, then the layer's own ``name``
        # attribute (if any), then the class name.
        if highlevelgraph_key != "":
            shortname = key_split(highlevelgraph_key)
        elif hasattr(self, "name"):
            shortname = key_split(self.name)
        else:
            shortname = self.__class__.__name__

        svg_repr = ""
        if (
            self.collection_annotations
            and self.collection_annotations.get("type") == "dask.array.core.Array"
        ):
            chunks = self.collection_annotations.get("chunks")
            if chunks:
                from dask.array.svg import svg

                svg_repr = svg(chunks)

        return get_template("highlevelgraph_layer.html.j2").render(
            materialized=self.is_materialized(),
            shortname=shortname,
            layer_index=layer_index,
            highlevelgraph_key=highlevelgraph_key,
            info=self.layer_info_dict(),
            dependencies=dependencies,
            svg_repr=svg_repr,
        )

    def layer_info_dict(self):
        """Return a dict describing the layer for the HTML representation.

        Annotation values are converted with ``str`` and HTML-escaped.
        """
        info = {
            "layer_type": type(self).__name__,
            "is_materialized": self.is_materialized(),
            "number of outputs": f"{len(self.get_output_keys())}",
        }
        if self.annotations is not None:
            for key, val in self.annotations.items():
                info[key] = html.escape(str(val))
        if self.collection_annotations is not None:
            for key, val in self.collection_annotations.items():
                # Hide verbose chunk details from the HTML table
                if key != "chunks":
                    info[key] = html.escape(str(val))
        return info

326 

327 

class MaterializedLayer(Layer):
    """A `Layer` backed by an explicit mapping of keys to tasks

    Parameters
    ----------
    mapping: Mapping
        The mapping between keys and tasks, typically a dask graph.
    annotations, collection_annotations
        Forwarded unchanged to ``Layer.__init__``.
    """

    def __init__(self, mapping: Mapping, annotations=None, collection_annotations=None):
        super().__init__(
            annotations=annotations, collection_annotations=collection_annotations
        )
        if isinstance(mapping, Mapping):
            self.mapping = mapping
        else:
            raise TypeError(f"mapping must be a Mapping. Instead got {type(mapping)}")

    # The Mapping protocol is delegated straight to the wrapped mapping.

    def __contains__(self, k):
        return k in self.mapping

    def __getitem__(self, k):
        return self.mapping[k]

    def __iter__(self):
        return iter(self.mapping)

    def __len__(self):
        return len(self.mapping)

    def is_materialized(self):
        """Always True for this layer type."""
        return True

    def get_output_keys(self):
        """Return ``self.keys()``."""
        return self.keys()

362 

363 

class HighLevelGraph(Graph):
    """Task graph composed of layers of dependent subgraphs

    This object encodes a Dask task graph that is composed of layers of
    dependent subgraphs, such as commonly occurs when building task graphs
    using high level collections like Dask array, bag, or dataframe.

    Typically each high level array, bag, or dataframe operation takes the task
    graphs of the input collections, merges them, and then adds one or more new
    layers of tasks for the new operation. These layers typically have at
    least as many tasks as there are partitions or chunks in the collection.
    The HighLevelGraph object stores the subgraphs for each operation
    separately in sub-graphs, and also stores the dependency structure between
    them.

    Parameters
    ----------
    layers : Mapping[str, Mapping]
        The subgraph layers, keyed by a unique name
    dependencies : Mapping[str, set[str]]
        The set of layers on which each layer depends
    key_dependencies : dict[Key, set], optional
        Mapping (some) keys in the high level graph to their dependencies. If
        a key is missing, its dependencies will be calculated on-the-fly.

    Examples
    --------
    Here is an idealized example that shows the internal state of a
    HighLevelGraph

    >>> import dask.dataframe as dd

    >>> df = dd.read_csv('myfile.*.csv')  # doctest: +SKIP
    >>> df = df + 100  # doctest: +SKIP
    >>> df = df[df.name == 'Alice']  # doctest: +SKIP

    >>> graph = df.__dask_graph__()  # doctest: +SKIP
    >>> graph.layers  # doctest: +SKIP
    {
     'read-csv': {('read-csv', 0): (pandas.read_csv, 'myfile.0.csv'),
                  ('read-csv', 1): (pandas.read_csv, 'myfile.1.csv'),
                  ('read-csv', 2): (pandas.read_csv, 'myfile.2.csv'),
                  ('read-csv', 3): (pandas.read_csv, 'myfile.3.csv')},
     'add': {('add', 0): (operator.add, ('read-csv', 0), 100),
             ('add', 1): (operator.add, ('read-csv', 1), 100),
             ('add', 2): (operator.add, ('read-csv', 2), 100),
             ('add', 3): (operator.add, ('read-csv', 3), 100)},
     'filter': {('filter', 0): (lambda part: part[part.name == 'Alice'], ('add', 0)),
                ('filter', 1): (lambda part: part[part.name == 'Alice'], ('add', 1)),
                ('filter', 2): (lambda part: part[part.name == 'Alice'], ('add', 2)),
                ('filter', 3): (lambda part: part[part.name == 'Alice'], ('add', 3))}
    }

    >>> graph.dependencies  # doctest: +SKIP
    {
     'read-csv': set(),
     'add': {'read-csv'},
     'filter': {'add'}
    }

    See Also
    --------
    HighLevelGraph.from_collections :
        typically used by developers to make new HighLevelGraphs
    """

    layers: Mapping[str, Layer]
    dependencies: Mapping[str, set[str]]
    key_dependencies: dict[Key, set[Key]]
    # Lazily populated caches; see ``to_dict`` and ``get_all_external_keys``.
    _to_dict: dict
    _all_external_keys: set

    def __init__(
        self,
        layers: Mapping[str, Graph],
        dependencies: Mapping[str, set[str]],
        key_dependencies: dict[Key, set[Key]] | None = None,
    ):
        self.dependencies = dependencies
        self.key_dependencies = key_dependencies or {}
        # Makes sure that all layers are `Layer`
        self.layers = {
            k: v if isinstance(v, Layer) else MaterializedLayer(v)
            for k, v in layers.items()
        }

    @classmethod
    def _from_collection(cls, name, layer, collection):
        """`from_collections` optimized for a single collection"""
        if not is_dask_collection(collection):
            raise TypeError(type(collection))

        graph = collection.__dask_graph__()
        if isinstance(graph, HighLevelGraph):
            layers = ensure_dict(graph.layers, copy=True)
            layers[name] = layer
            deps = ensure_dict(graph.dependencies, copy=True)
            deps[name] = set(collection.__dask_layers__())
        else:
            key = _get_some_layer_name(collection)
            layers = {name: layer, key: graph}
            deps = {name: {key}, key: set()}

        return cls(layers, deps)

    @classmethod
    def from_collections(
        cls,
        name: str,
        layer: Graph,
        dependencies: Sequence[DaskCollection] = (),
    ) -> HighLevelGraph:
        """Construct a HighLevelGraph from a new layer and a set of collections

        This constructs a HighLevelGraph in the common case where we have a single
        new layer and a set of old collections on which we want to depend.

        This pulls out the ``__dask_layers__()`` method of the collections if
        they exist, and adds them to the dependencies for this new layer. It
        also merges all of the layers from all of the dependent collections
        together into the new layers for this graph.

        Parameters
        ----------
        name : str
            The name of the new layer
        layer : Mapping
            The graph layer itself
        dependencies : List of Dask collections
            A list of other dask collections (like arrays or dataframes) that
            have graphs themselves

        Raises
        ------
        TypeError
            If an entry of ``dependencies`` is not a dask collection.

        Examples
        --------

        In typical usage we make a new task layer, and then pass that layer
        along with all dependent collections to this method.

        >>> def add(self, other):
        ...     name = 'add-' + tokenize(self, other)
        ...     layer = {(name, i): (add, input_key, other)
        ...              for i, input_key in enumerate(self.__dask_keys__())}
        ...     graph = HighLevelGraph.from_collections(name, layer, dependencies=[self])
        ...     return new_collection(name, graph)
        """
        if len(dependencies) == 1:
            return cls._from_collection(name, layer, dependencies[0])
        layers = {name: layer}
        name_dep: set[str] = set()
        deps: dict[str, set[str]] = {name: name_dep}
        # De-duplicate by identity so the same collection is merged only once.
        for collection in toolz.unique(dependencies, key=id):
            if is_dask_collection(collection):
                graph = collection.__dask_graph__()
                if isinstance(graph, HighLevelGraph):
                    layers.update(graph.layers)
                    deps.update(graph.dependencies)
                    name_dep |= set(collection.__dask_layers__())
                else:
                    key = _get_some_layer_name(collection)
                    layers[key] = graph
                    name_dep.add(key)
                    deps[key] = set()
            else:
                raise TypeError(type(collection))

        return cls(layers, deps)

    def __getitem__(self, key: Key) -> Any:
        # Attempt O(1) direct access first, under the assumption that layer names match
        # either the keys (Scalar, Item, Delayed) or the first element of the key tuples
        # (Array, Bag, DataFrame, Series). This assumption is not always true.
        try:
            return self.layers[key][key]  # type: ignore
        except KeyError:
            pass
        try:
            return self.layers[key[0]][key]  # type: ignore
        except (KeyError, IndexError, TypeError):
            pass

        # Fall back to O(n) access
        for d in self.layers.values():
            try:
                return d[key]
            except KeyError:
                pass

        raise KeyError(key)

    def __len__(self) -> int:
        # NOTE: this will double-count keys that are duplicated between layers, so it's
        # possible that `len(hlg) > len(hlg.to_dict())`. However, duplicate keys should
        # not occur through normal use, and their existence would usually be a bug.
        # So we ignore this case in favor of better performance.
        # https://github.com/dask/dask/issues/7271
        return sum(len(layer) for layer in self.layers.values())

    def __iter__(self) -> Iterator[Key]:
        return iter(self.to_dict())

    def to_dict(self) -> dict[Key, Any]:
        """Efficiently convert to plain dict. This method is faster than dict(self)."""
        # The converted dict is cached on the instance after the first call.
        try:
            return self._to_dict
        except AttributeError:
            out = self._to_dict = ensure_dict(self)
            return out

    def keys(self) -> KeysView:
        """Get all keys of all the layers.

        This will in many cases materialize layers, which makes it a relatively
        expensive operation. See :meth:`get_all_external_keys` for a faster alternative.
        """
        return self.to_dict().keys()

    def get_all_external_keys(self) -> set[Key]:
        """Get all output keys of all layers

        This will in most cases _not_ materialize any layers, which makes
        it a relative cheap operation.

        Returns
        -------
        keys: set
            A set of all external keys
        """
        # The resulting set is cached on the instance after the first call.
        try:
            return self._all_external_keys
        except AttributeError:
            keys: set = set()
            for layer in self.layers.values():
                # Note: don't use `keys |= ...`, because the RHS is a
                # collections.abc.Set rather than a real set, and this will
                # cause a whole new set to be constructed.
                keys.update(layer.get_output_keys())
            self._all_external_keys = keys
            return keys

    def items(self) -> ItemsView[Key, Any]:
        return self.to_dict().items()

    def values(self) -> ValuesView[Any]:
        return self.to_dict().values()

    def get_all_dependencies(self) -> dict[Key, set[Key]]:
        """Get dependencies of all keys

        This will in most cases materialize all layers, which makes
        it an expensive operation.

        Returns
        -------
        map: Mapping
            A map that maps each key to its dependencies
        """
        all_keys = self.keys()
        missing_keys = all_keys - self.key_dependencies.keys()
        if missing_keys:
            # Fill in ``self.key_dependencies`` in place for the missing keys.
            for layer in self.layers.values():
                for k in missing_keys & layer.keys():
                    self.key_dependencies[k] = layer.get_dependencies(k, all_keys)
        return self.key_dependencies

    @property
    def dependents(self) -> dict[str, set[str]]:
        return reverse_dict(self.dependencies)

    def copy(self) -> HighLevelGraph:
        return HighLevelGraph(
            ensure_dict(self.layers, copy=True),
            ensure_dict(self.dependencies, copy=True),
            self.key_dependencies.copy(),
        )

    @classmethod
    def merge(cls, *graphs: Graph) -> HighLevelGraph:
        """Combine several graphs into one HighLevelGraph.

        Layers and dependencies of ``HighLevelGraph`` inputs are merged; any
        other ``Mapping`` becomes a single layer without dependencies, named
        after ``str(id(g))``.

        Raises
        ------
        TypeError
            If an input is neither a ``HighLevelGraph`` nor a ``Mapping``.
        """
        layers: dict[str, Graph] = {}
        dependencies: dict[str, set[str]] = {}
        for g in graphs:
            if isinstance(g, HighLevelGraph):
                layers.update(g.layers)
                dependencies.update(g.dependencies)
            elif isinstance(g, Mapping):
                layer_name = str(id(g))
                layers[layer_name] = g
                dependencies[layer_name] = set()
            else:
                raise TypeError(g)
        return cls(layers, dependencies)

    def visualize(self, filename="dask-hlg.svg", format=None, **kwargs):
        """
        Visualize this dask high level graph.

        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and the graph is
            rendered in the Jupyter notebook only.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'svg'.
        color : {None, 'layer_type'}, optional (default: None)
            Options to color nodes. Passed through ``**kwargs``.
            - None, no colors.
            - layer_type, color nodes based on the layer type.
        **kwargs
           Additional keyword arguments to forward to ``to_graphviz``.

        Examples
        --------
        >>> x.dask.visualize(filename='dask.svg')  # doctest: +SKIP
        >>> x.dask.visualize(filename='dask.svg', color='layer_type')  # doctest: +SKIP

        Returns
        -------
        result : IPython.display.Image, IPython.display.SVG, or None
            See dask.dot.dot_graph for more information.

        See Also
        --------
        dask.dot.dot_graph
        dask.base.visualize # low level variant
        """

        from dask.dot import graphviz_to_file

        g = to_graphviz(self, **kwargs)
        graphviz_to_file(g, filename, format)
        return g

    def _toposort_layers(self) -> list[str]:
        """Sort the layers in a high level graph topologically

        The order is derived from ``self.dependencies`` only.

        Returns
        -------
        sorted: list
            List of layer names sorted topologically
        """
        # Number of not-yet-emitted dependencies per layer.
        degree = {k: len(v) for k, v in self.dependencies.items()}
        reverse_deps: dict[str, list[str]] = {k: [] for k in self.dependencies}
        ready = []
        for k, v in self.dependencies.items():
            for dep in v:
                reverse_deps[dep].append(k)
            if not v:
                ready.append(k)
        ret = []
        while ready:
            layer = ready.pop()
            ret.append(layer)
            for rdep in reverse_deps[layer]:
                degree[rdep] -= 1
                if degree[rdep] == 0:
                    ready.append(rdep)
        return ret

    def cull(self, keys: Iterable[Key]) -> HighLevelGraph:
        """Return new HighLevelGraph with only the tasks required to calculate keys.

        In other words, remove unnecessary tasks from dask.

        Layers of the returned graph are named ``"{layer_name}-{token}"``
        where the token is derived from the requested keys.

        Parameters
        ----------
        keys
            iterable of keys or nested list of keys such as the output of
            ``__dask_keys__()``

        Returns
        -------
        hlg: HighLevelGraph
            Culled high level graph
        """
        keys_set = set(flatten(keys))

        # Note: All Layer classes still in existence are of
        # one of these types (or subclasses)
        #
        # - MaterializedLayer
        # - Blockwise
        # - ArrayOverlapLayer (which is basically as good as MaterializedLayer)
        if not any(layer.has_legacy_tasks for layer in self.layers.values()):
            all_ext_keys = set()
        else:
            # FIXME: Technically, we don't need to compute **all** keys but only
            # those of the current layer and all of its dependencies, i.e. if
            # there are legacy layers for IO followed by many blockwise layers,
            # we should still get by without this
            all_ext_keys = self.get_all_external_keys()

        ret_layers: dict = {}
        layer_dependencies = {}
        tok = tokenize(keys_set)
        for layer_name in reversed(self._toposort_layers()):
            new_layer_name = f"{layer_name}-{tok}"
            layer = self.layers[layer_name]
            if keys_set:
                culled_layer, culled_deps = layer.cull(keys_set, all_ext_keys)
                if not culled_deps:
                    continue

                # Update `keys` with all layer's external key dependencies,
                # which are all the layer's dependencies (`culled_deps`)
                # excluding the layer's output keys.
                for k, d in culled_deps.items():
                    keys_set |= d
                    keys_set.discard(k)
                layer = culled_layer
                # Save the culled layer and its key dependencies
                ret_layers[new_layer_name] = layer
                # The saved layers are stored under their renamed keys, so the
                # dependency names must be renamed the same way; otherwise the
                # intersection below can never match and every dependency of
                # the culled graph would be dropped.
                layer_dependencies[new_layer_name] = {
                    f"{dep}-{tok}" for dep in self.dependencies[layer_name]
                }

        # Using a real set lets Python optimise the set intersection to
        # iterate over the smaller of the two sets.
        ret_layers_keys = set(ret_layers)
        ret_dependencies = {
            layer_name: layer_dependencies[layer_name] & ret_layers_keys
            for layer_name in ret_layers
        }

        return HighLevelGraph(ret_layers, ret_dependencies)

    def cull_layers(self, layers: Iterable[str]) -> HighLevelGraph:
        """Return a new HighLevelGraph with only the given layers and their
        dependencies. Internally, layers are not modified.

        This is a variant of :meth:`HighLevelGraph.cull` which is much faster and does
        not risk creating a collision between two layers with the same name and
        different content when two culled graphs are merged later on.

        Returns
        -------
        hlg: HighLevelGraph
            Culled high level graph
        """
        to_visit = set(layers)
        ret_layers = {}
        ret_dependencies = {}
        while to_visit:
            k = to_visit.pop()
            ret_layers[k] = self.layers[k]
            ret_dependencies[k] = self.dependencies[k]
            # Only queue dependencies that have not been collected yet.
            to_visit |= ret_dependencies[k] - ret_dependencies.keys()

        return HighLevelGraph(ret_layers, ret_dependencies)

    def validate(self) -> None:
        """Check that ``self.dependencies`` is consistent with the layers.

        Raises
        ------
        ValueError
            If a name in ``dependencies`` is unknown, or if the stored
            dependencies differ from those recomputed from the tasks.
        """
        # Check dependencies
        for layer_name, deps in self.dependencies.items():
            if layer_name not in self.layers:
                raise ValueError(
                    f"dependencies[{repr(layer_name)}] not found in layers"
                )
            for dep in deps:
                if dep not in self.dependencies:
                    raise ValueError(f"{repr(dep)} not found in dependencies")

        for layer in self.layers.values():
            assert hasattr(layer, "annotations")

        # Re-calculate all layer dependencies
        dependencies = compute_layer_dependencies(self.layers)

        # Check keys
        dep_key1 = self.dependencies.keys()
        dep_key2 = dependencies.keys()
        if dep_key1 != dep_key2:
            raise ValueError(
                f"incorrect dependencies keys {set(dep_key1)!r} "
                f"expected {set(dep_key2)!r}"
            )

        # Check values
        for k in dep_key1:
            if self.dependencies[k] != dependencies[k]:
                raise ValueError(
                    f"incorrect HLG dependencies[{repr(k)}]: {repr(self.dependencies[k])} "
                    f"expected {repr(dependencies[k])} from task dependencies"
                )

    def __repr__(self) -> str:
        lines = [
            f"{type(self).__name__} with {len(self.layers)} layers.",
            f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}>",
        ]
        lines.extend(
            f" {i}. {layerkey}" for i, layerkey in enumerate(self._toposort_layers())
        )
        # Every line, including the last one, is newline-terminated.
        return "\n".join(lines) + "\n"

    def _repr_html_(self) -> str:
        return get_template("highlevelgraph.html.j2").render(
            type=type(self).__name__,
            layers=self.layers,
            toposort=self._toposort_layers(),
            layer_dependencies=self.dependencies,
            n_outputs=len(self.get_all_external_keys()),
        )

867 

868 

def to_graphviz(
    hg,
    data_attributes=None,
    function_attributes=None,
    rankdir="BT",
    graph_attr=None,
    node_attr=None,
    edge_attr=None,
    **kwargs,
):
    """Build a ``graphviz.Digraph`` with one node per layer of ``hg``.

    Parameters
    ----------
    hg : HighLevelGraph
        Graph to draw; its ``dependencies`` and ``layers`` are read.
    data_attributes : dict, optional
        Per-layer node attributes, keyed by layer name. Values computed
        here only fill in attributes that are not already given.
    function_attributes : dict, optional
        Accepted for signature compatibility; not used by this function.
    rankdir : str, optional
        Stored as the ``rankdir`` graph attribute.
    graph_attr, node_attr, edge_attr : dict, optional
        Attributes handed to ``graphviz.Digraph``.
    **kwargs
        Merged into the graph attributes. ``color="layer_type"`` additionally
        fills nodes by layer type and adds a legend node.

    Returns
    -------
    graphviz.Digraph
    """
    from dask.dot import label, name

    graphviz = import_required(
        "graphviz",
        "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
        "python library and the `graphviz` system library.\n\n"
        "Please either conda or pip install as follows:\n\n"
        " conda install python-graphviz # either conda install\n"
        " python -m pip install graphviz # or pip install and follow installation instructions",
    )

    data_attributes = data_attributes or {}
    # Work on copies so that dicts supplied by the caller are not mutated.
    graph_attr = dict(graph_attr or {})
    node_attr = dict(node_attr or {})
    edge_attr = edge_attr or {}

    graph_attr["rankdir"] = rankdir
    node_attr["shape"] = "box"
    node_attr["fontname"] = "helvetica"

    graph_attr.update(kwargs)
    g = graphviz.Digraph(
        graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr
    )

    n_tasks = {layer: len(hg.layers[layer]) for layer in hg.dependencies}

    # ``default`` keeps a graph without layers from raising ValueError.
    min_tasks = min(n_tasks.values(), default=0)
    max_tasks = max(n_tasks.values(), default=0)

    cache = {}

    color = kwargs.get("color")
    # Layer type name -> [fill color, whether the type occurs in the graph].
    layer_colors = {}
    if color == "layer_type":
        layer_colors = {
            "DataFrameIOLayer": ["#CCC7F9", False],  # purple
            # Historic misspelling, kept so existing behavior is unchanged.
            "ArrayOverlayLayer": ["#FFD9F2", False],  # pink
            "ArrayOverlapLayer": ["#FFD9F2", False],  # pink
            "BroadcastJoinLayer": ["#D9F2FF", False],  # blue
            "Blockwise": ["#D9FFE6", False],  # green
            "BlockwiseLayer": ["#D9FFE6", False],  # green
            "MaterializedLayer": ["#DBDEE5", False],  # gray
        }

    for layer in hg.dependencies:
        layer_name = name(layer)
        attrs = dict(data_attributes.get(layer, {}))

        node_label = label(layer, cache=cache)
        # Scale the font size linearly from 20 to 40 with the task count.
        node_size = (
            20
            if max_tasks == min_tasks
            else int(20 + ((n_tasks[layer] - min_tasks) / (max_tasks - min_tasks)) * 20)
        )

        layer_type = str(type(hg.layers[layer]).__name__)
        node_tooltips = (
            f"A {layer_type.replace('Layer', '')} Layer with {n_tasks[layer]} Tasks.\n"
        )

        layer_ca = hg.layers[layer].collection_annotations
        if layer_ca:
            if layer_ca.get("type") == "dask.array.core.Array":
                node_tooltips += (
                    f"Array Shape: {layer_ca.get('shape')}\n"
                    f"Data Type: {layer_ca.get('dtype')}\n"
                    f"Chunk Size: {layer_ca.get('chunksize')}\n"
                    f"Chunk Type: {layer_ca.get('chunk_type')}\n"
                )

            if layer_ca.get("type") == "dask.dataframe.core.DataFrame":
                dftype = {"pandas.core.frame.DataFrame": "pandas"}
                cols = layer_ca.get("columns")
                # A missing "columns" annotation must not break the tooltip.
                n_cols = len(cols) if cols is not None else 0

                node_tooltips += (
                    f"Number of Partitions: {layer_ca.get('npartitions')}\n"
                    f"DataFrame Type: {dftype.get(layer_ca.get('dataframe_type'))}\n"
                    f"{n_cols} DataFrame Columns: {str(cols) if len(str(cols)) <= 40 else '[...]'}\n"
                )

        attrs.setdefault("label", str(node_label))
        attrs.setdefault("fontsize", str(node_size))
        attrs.setdefault("tooltip", str(node_tooltips))

        if color == "layer_type":
            # Layer types missing from the table are drawn without a fill
            # instead of raising.
            color_entry = layer_colors.get(layer_type)
            if color_entry is not None:
                color_entry[1] = True
                attrs.setdefault("fillcolor", str(color_entry[0]))
                attrs.setdefault("style", "filled")

        g.node(layer_name, **attrs)

    for layer, deps in hg.dependencies.items():
        layer_name = name(layer)
        for dep in deps:
            dep_name = name(dep)
            g.edge(dep_name, layer_name)

    if color == "layer_type":
        legend_title = "Key"

        legend_label = (
            '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="5">'
            "<TR><TD><B>Legend: Layer types</B></TD></TR>"
        )

        # Only list the layer types that actually occur in the graph.
        for legend_type, (fill, used) in layer_colors.items():
            if used:
                legend_label += f'<TR><TD BGCOLOR="{fill}">{legend_type}</TD></TR>'

        legend_label += "</TABLE>>"

        attrs = dict(data_attributes.get(legend_title, {}))
        attrs.setdefault("label", str(legend_label))
        attrs.setdefault("fontsize", "20")
        attrs.setdefault("margin", "0")

        g.node(legend_title, **attrs)

    return g

1002 

1003 

def _get_some_layer_name(collection) -> str:
    """Return a layer name for a collection whose graph is not a HighLevelGraph.

    The name is the single entry returned by ``collection.__dask_layers__()``.
    If that method is missing, or does not yield exactly one entry, the
    result is ``str(id(collection))``.
    """
    try:
        layer_names = collection.__dask_layers__()
        # Unpacking raises ValueError unless there is exactly one name.
        (only_name,) = layer_names
    except (AttributeError, ValueError):
        # collection does not define the optional __dask_layers__ method
        # or it spuriously returns more than one layer
        return str(id(collection))
    else:
        return only_name

1013 

1014 

@normalize_token.register(HighLevelGraph)
def register_highlevelgraph(hlg):
    """Tokenize a HighLevelGraph by the list of its layer names."""
    # Note: Layer keys are not necessarily identifying HLGs uniquely
    # see https://github.com/dask/dask/issues/9888
    layer_names = list(hlg.layers.keys())
    return normalize_token(layer_names)