Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/base.py: 18% (378 statements)

from __future__ import annotations

import dataclasses
import uuid
import warnings
from collections import OrderedDict
from collections.abc import Hashable, Iterable, Iterator, Mapping
from concurrent.futures import Executor
from contextlib import contextmanager, suppress
from contextvars import ContextVar
from functools import partial
from numbers import Integral, Number
from operator import getitem
from typing import TYPE_CHECKING, Any, Literal, TypeVar

from tlz import merge

from dask import config, local
from dask._compatibility import EMSCRIPTEN
from dask._task_spec import DataNode, Dict, List, Task, TaskRef
from dask.core import flatten
from dask.core import get as simple_get
from dask.system import CPU_COUNT
from dask.typing import Key, SchedulerGetCallable
from dask.utils import is_namedtuple_instance, key_split, shorten_traceback

if TYPE_CHECKING:
    from dask._expr import Expr

_DistributedClient = None
_get_distributed_client = None
_DISTRIBUTED_AVAILABLE = None


def _distributed_available() -> bool:
    # Lazy import in get_scheduler can be expensive
    global _DistributedClient, _get_distributed_client, _DISTRIBUTED_AVAILABLE
    if _DISTRIBUTED_AVAILABLE is not None:
        return _DISTRIBUTED_AVAILABLE  # type: ignore[unreachable]
    try:
        from distributed import Client as _DistributedClient
        from distributed.worker import get_client as _get_distributed_client

        _DISTRIBUTED_AVAILABLE = True
    except ImportError:
        _DISTRIBUTED_AVAILABLE = False
    return _DISTRIBUTED_AVAILABLE
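

# Illustrative sketch (not part of dask): the same memoized optional-import
# pattern as _distributed_available above, applied to a hypothetical optional
# dependency. "some_optional_dep" is a made-up module name; the point is that
# the potentially expensive import is attempted at most once and the answer is
# cached in a module-level variable.
_SOME_DEP_AVAILABLE: bool | None = None


def _some_dep_available() -> bool:
    global _SOME_DEP_AVAILABLE
    if _SOME_DEP_AVAILABLE is not None:
        return _SOME_DEP_AVAILABLE  # cached result from a previous call
    try:
        import some_optional_dep  # noqa: F401

        _SOME_DEP_AVAILABLE = True
    except ImportError:
        _SOME_DEP_AVAILABLE = False
    return _SOME_DEP_AVAILABLE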


__all__ = (
    "DaskMethodsMixin",
    "annotate",
    "get_annotations",
    "is_dask_collection",
    "compute",
    "persist",
    "optimize",
    "visualize",
    "tokenize",
    "normalize_token",
    "get_collection_names",
    "get_name_from_key",
    "replace_name_in_key",
    "clone_key",
)

# Backwards compat
from dask.tokenize import TokenizationError, normalize_token, tokenize  # noqa: F401

_annotations: ContextVar[dict[str, Any] | None] = ContextVar(
    "annotations", default=None
)


def get_annotations() -> dict[str, Any]:
    """Get current annotations.

    Returns
    -------
    Dict of all current annotations

    See Also
    --------
    annotate
    """
    return _annotations.get() or {}


@contextmanager
def annotate(**annotations: Any) -> Iterator[None]:
    """Context manager for setting HighLevelGraph layer annotations.

    Annotations are metadata or soft constraints associated with tasks
    that dask schedulers may choose to respect: they signal intent
    without enforcing hard constraints. As such, they are primarily
    designed for use with the distributed scheduler.

    Almost any object can serve as an annotation, but small Python objects
    are preferred, while large objects such as NumPy arrays are discouraged.

    Callables supplied as an annotation should take a single *key* argument and
    produce the appropriate annotation. Individual task keys in the annotated
    collection are supplied to the callable.

    Parameters
    ----------
    **annotations : key-value pairs

    Examples
    --------

    All tasks within array A should have priority 100 and be retried 3 times
    on failure.

    >>> import dask
    >>> import dask.array as da
    >>> with dask.annotate(priority=100, retries=3):
    ...     A = da.ones((10000, 10000))

    Prioritise tasks within Array A on flattened block ID.

    >>> nblocks = (10, 10)
    >>> with dask.annotate(priority=lambda k: k[1]*nblocks[1] + k[2]):
    ...     A = da.ones((1000, 1000), chunks=(100, 100))

    Annotations may be nested.

    >>> with dask.annotate(priority=1):
    ...     with dask.annotate(retries=3):
    ...         A = da.ones((1000, 1000))
    ...     B = A + 1

    See Also
    --------
    get_annotations
    """

    # Sanity check annotations used in place of
    # legacy distributed Client.{submit, persist, compute} keywords
    if "workers" in annotations:
        if isinstance(annotations["workers"], (list, set, tuple)):
            annotations["workers"] = list(annotations["workers"])
        elif isinstance(annotations["workers"], str):
            annotations["workers"] = [annotations["workers"]]
        elif callable(annotations["workers"]):
            pass
        else:
            raise TypeError(
                "'workers' annotation must be a sequence of str, a str or a callable, but got {}.".format(
                    annotations["workers"]
                )
            )

    if (
        "priority" in annotations
        and not isinstance(annotations["priority"], Number)
        and not callable(annotations["priority"])
    ):
        raise TypeError(
            "'priority' annotation must be a Number or a callable, but got {}".format(
                annotations["priority"]
            )
        )

    if (
        "retries" in annotations
        and not isinstance(annotations["retries"], Number)
        and not callable(annotations["retries"])
    ):
        raise TypeError(
            "'retries' annotation must be a Number or a callable, but got {}".format(
                annotations["retries"]
            )
        )

    if (
        "resources" in annotations
        and not isinstance(annotations["resources"], dict)
        and not callable(annotations["resources"])
    ):
        raise TypeError(
            "'resources' annotation must be a dict, but got {}".format(
                annotations["resources"]
            )
        )

    if (
        "allow_other_workers" in annotations
        and not isinstance(annotations["allow_other_workers"], bool)
        and not callable(annotations["allow_other_workers"])
    ):
        raise TypeError(
            "'allow_other_workers' annotation must be a bool or a callable, but got {}".format(
                annotations["allow_other_workers"]
            )
        )
    ctx_annot = _annotations.get()
    if ctx_annot is None:
        ctx_annot = {}
    token = _annotations.set(merge(ctx_annot, annotations))
    try:
        yield
    finally:
        _annotations.reset(token)


def is_dask_collection(x) -> bool:
    """Returns ``True`` if ``x`` is a dask collection.

    Parameters
    ----------
    x : Any
        Object to test.

    Returns
    -------
    result : bool
        ``True`` if `x` is a Dask collection.

    Notes
    -----
    The DaskCollection typing.Protocol implementation defines a Dask
    collection as a class that returns a Mapping from the
    ``__dask_graph__`` method. This helper function existed before the
    implementation of the protocol.

    """
    if (
        isinstance(x, type)
        or not hasattr(x, "__dask_graph__")
        or not callable(x.__dask_graph__)
    ):
        return False

    pkg_name = getattr(type(x), "__module__", "")
    if pkg_name.split(".")[0] in ("dask_cudf",):
        # Temporary hack to avoid graph materialization. Note that this won't work with
        # dask_expr.array objects wrapped by xarray or pint. By the time dask_expr.array
        # is published, we hope to be able to rewrite this method completely.
        # Read: https://github.com/dask/dask/pull/10676
        return True
    elif pkg_name.startswith("dask.dataframe.dask_expr"):
        return True
    elif pkg_name.startswith("dask.array._array_expr"):
        return True

    # xarray, pint, and possibly other wrappers always define a __dask_graph__ method,
    # but it may return None if they wrap around a non-dask object.
    # In all known dask collections other than dask-expr,
    # calling __dask_graph__ is cheap.
    return x.__dask_graph__() is not None
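

# Illustrative sketch (not part of dask): a minimal object passing the
# duck-typed check above. is_dask_collection only needs a callable
# __dask_graph__ on an *instance* that returns a non-None Mapping; the
# class below is made up for demonstration.
def _example_is_dask_collection():
    class TinyCollection:
        def __dask_graph__(self):
            return {"x": 1}  # any non-None Mapping passes the final check

    assert is_dask_collection(TinyCollection())
    assert not is_dask_collection(TinyCollection)  # classes are rejected
    assert not is_dask_collection(object())  # no __dask_graph__ at all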


class DaskMethodsMixin:
    """A mixin adding standard dask collection methods"""

    __slots__ = ("__weakref__",)

    def visualize(self, filename="mydask", format=None, optimize_graph=False, **kwargs):
        """Render the computation of this object's task graph using graphviz.

        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and we communicate
            with dot using only pipes.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'png'.
        optimize_graph : bool, optional
            If True, the graph is optimized before rendering. Otherwise,
            the graph is displayed as is. Default is False.
        color : {None, 'order'}, optional
            Options to color nodes. Provide ``cmap=`` keyword for additional
            colormap
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.

        Examples
        --------
        >>> x.visualize(filename='dask.pdf')  # doctest: +SKIP
        >>> x.visualize(filename='dask.pdf', color='order')  # doctest: +SKIP

        Returns
        -------
        result : IPython.display.Image, IPython.display.SVG, or None
            See dask.dot.dot_graph for more information.

        See Also
        --------
        dask.visualize
        dask.dot.dot_graph

        Notes
        -----
        For more information on optimization see here:

        https://docs.dask.org/en/latest/optimize.html
        """
        return visualize(
            self,
            filename=filename,
            format=format,
            optimize_graph=optimize_graph,
            **kwargs,
        )

    def persist(self, **kwargs):
        """Persist this dask collection into memory

        This turns a lazy Dask collection into a Dask collection with the same
        metadata, but now with the results fully computed or actively computing
        in the background.

        The action of this function differs significantly depending on the
        active task scheduler. If the task scheduler supports asynchronous
        computing, as is the case for the dask.distributed scheduler, then
        persist will return *immediately* and the return value's task graph
        will contain Dask Future objects. However, if the task scheduler only
        supports blocking computation then the call to persist will *block*
        and the return value's task graph will contain concrete Python results.

        This function is particularly useful when using distributed systems,
        because the results will be kept in distributed memory, rather than
        returned to the local process as with compute.

        Parameters
        ----------
        scheduler : string, optional
            Which scheduler to use like "threads", "synchronous" or "processes".
            If not provided, the default is to check the global settings first,
            and then fall back to the collection defaults.
        optimize_graph : bool, optional
            If True [default], the graph is optimized before computation.
            Otherwise the graph is run as is. This can be useful for debugging.
        **kwargs
            Extra keywords to forward to the scheduler function.

        Returns
        -------
        New dask collections backed by in-memory data

        See Also
        --------
        dask.persist
        """
        (result,) = persist(self, traverse=False, **kwargs)
        return result

    def compute(self, **kwargs):
        """Compute this dask collection

        This turns a lazy Dask collection into its in-memory equivalent.
        For example a Dask array turns into a NumPy array and a Dask dataframe
        turns into a Pandas dataframe. The entire dataset must fit into memory
        before calling this operation.

        Parameters
        ----------
        scheduler : string, optional
            Which scheduler to use like "threads", "synchronous" or "processes".
            If not provided, the default is to check the global settings first,
            and then fall back to the collection defaults.
        optimize_graph : bool, optional
            If True [default], the graph is optimized before computation.
            Otherwise the graph is run as is. This can be useful for debugging.
        kwargs
            Extra keywords to forward to the scheduler function.

        See Also
        --------
        dask.compute
        """
        (result,) = compute(self, traverse=False, **kwargs)
        return result

    def __await__(self):
        try:
            from distributed import futures_of, wait
        except ImportError as e:
            raise ImportError(
                "Using async/await with dask requires the `distributed` package"
            ) from e

        async def f():
            if futures_of(self):
                await wait(self)
            return self

        return f().__await__()
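

# Illustrative sketch (not part of dask): the mixin methods in use. Assumes
# dask[array] is installed; dask.array.Array inherits DaskMethodsMixin, so
# .compute() and .persist() below are the methods defined above.
def _example_mixin_usage():
    import dask.array as da

    x = da.ones((1000, 1000), chunks=(100, 100)).sum()
    x = x.persist()  # blocks on local schedulers; returns immediately on distributed
    return x.compute()  # a concrete NumPy scalar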


def compute_as_if_collection(cls, dsk, keys, scheduler=None, get=None, **kwargs):
    """Compute a graph as if it were of type cls.

    Allows for applying the same optimizations and default scheduler."""
    schedule = get_scheduler(scheduler=scheduler, cls=cls, get=get)
    dsk2 = optimization_function(cls)(dsk, keys, **kwargs)
    return schedule(dsk2, keys, **kwargs)


def dont_optimize(dsk, keys, **kwargs):
    return dsk


def optimization_function(x):
    return getattr(x, "__dask_optimize__", dont_optimize)
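

# Illustrative sketch (not part of dask): running a hand-written classic graph
# through dask.array's optimization pass and the synchronous scheduler. The
# graph and key are made up, and this assumes dask[array] is installed and that
# da.Array still exposes a classic __dask_optimize__.
def _example_compute_as_if_collection():
    import dask.array as da

    dsk = {("x", 0): (sum, [1, 2, 3])}
    return compute_as_if_collection(da.Array, dsk, [("x", 0)], scheduler="sync")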


def collections_to_expr(
    collections: Iterable,
    optimize_graph: bool = True,
) -> Expr:
    """
    Convert many collections into a single dask expression.

    Typically, users should not be required to interact with this function.

    Parameters
    ----------
    collections : Iterable
        An iterable of dask collections to be combined.
    optimize_graph : bool, optional
        If this is True and collections are encountered which are backed by
        legacy HighLevelGraph objects, the returned expression will run a low
        level task optimization during materialization.
    """
    is_iterable = False
    if isinstance(collections, (tuple, list, set)):
        is_iterable = True
    else:
        collections = [collections]
    if not collections:
        raise ValueError("No collections provided")
    from dask._expr import HLGExpr, _ExprSequence
    from dask.delayed import Delayed

    graphs = []
    for coll in collections:
        if isinstance(coll, Delayed) or not hasattr(coll, "expr"):
            graphs.append(HLGExpr.from_collection(coll, optimize_graph=optimize_graph))
        else:
            graphs.append(coll.expr)

    if len(graphs) > 1 or is_iterable:
        return _ExprSequence(*graphs)
    else:
        return graphs[0]


def unpack_collections(*args, traverse=True):
    """Extract collections in preparation for compute/persist/etc...

    Intended use is to find all collections in a set of (possibly nested)
    python objects, do something to them (compute, etc...), then repackage them
    in equivalent python objects.

    Parameters
    ----------
    *args
        Any number of objects. If it is a dask collection, it's extracted and
        added to the list of collections returned. By default, python builtin
        collections are also traversed to look for dask collections (for more
        information see the ``traverse`` keyword).
    traverse : bool, optional
        If True (default), builtin python collections are traversed looking for
        any dask collections they might contain.

    Returns
    -------
    collections : list
        A list of all dask collections contained in ``args``
    repack : callable
        A function to call on the transformed collections to repackage them as
        they were in the original ``args``.
    """

    collections = []
    repack_dsk = {}

    collections_token = uuid.uuid4().hex

    def _unpack(expr):
        if is_dask_collection(expr):
            tok = tokenize(expr)
            if tok not in repack_dsk:
                repack_dsk[tok] = Task(
                    tok, getitem, TaskRef(collections_token), len(collections)
                )
                collections.append(expr)
            return TaskRef(tok)

        tok = uuid.uuid4().hex
        tsk: DataNode | Task  # type: ignore[annotation-unchecked]
        if not traverse:
            tsk = DataNode(None, expr)
        else:
            # Treat iterators like lists
            typ = list if isinstance(expr, Iterator) else type(expr)
            if typ in (list, tuple, set):
                tsk = Task(tok, typ, List(*[_unpack(i) for i in expr]))
            elif typ in (dict, OrderedDict):
                tsk = Task(
                    tok, typ, Dict({_unpack(k): _unpack(v) for k, v in expr.items()})
                )
            elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
                tsk = Task(
                    tok,
                    typ,
                    *[_unpack(getattr(expr, f.name)) for f in dataclasses.fields(expr)],
                )
            elif is_namedtuple_instance(expr):
                tsk = Task(tok, typ, *[_unpack(i) for i in expr])
            else:
                return expr

        repack_dsk[tok] = tsk
        return TaskRef(tok)

    out = uuid.uuid4().hex
    repack_dsk[out] = Task(out, tuple, List(*[_unpack(i) for i in args]))

    def repack(results):
        dsk = repack_dsk.copy()
        dsk[collections_token] = DataNode(collections_token, results)
        return simple_get(dsk, out)

    # The original `collections` list is kept alive by the `repack` closure,
    # so the collections it holds would only be freed by the garbage collector
    collections2 = list(collections)
    collections.clear()
    return collections2, repack
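

# Illustrative sketch (not part of dask): round-tripping a nested structure
# through unpack_collections. The payload below is made up.
def _example_unpack_collections():
    from dask.delayed import delayed

    a, b = delayed(1), delayed(2)
    collections, repack = unpack_collections({"x": a, "y": [b, 3]})
    # collections == [a, b]; compute them however you like, then hand the
    # concrete results back to repack to rebuild the original structure:
    results = [c.compute() for c in collections]
    return repack(results)  # -> ({'x': 1, 'y': [2, 3]},)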


def optimize(*args, traverse=True, **kwargs):
    """Optimize several dask collections at once.

    Returns equivalent dask collections that all share the same merged and
    optimized underlying graph. This can be useful if converting multiple
    collections to delayed objects, or to manually apply the optimizations at
    strategic points.

    Note that in most cases you shouldn't need to call this function directly.

    .. warning::

        This function triggers a materialization of the collections and loses
        any annotations attached to HLG layers.

    Parameters
    ----------
    *args : objects
        Any number of objects. If a dask object, its graph is optimized and
        merged with all those of all other dask objects before returning an
        equivalent dask collection. Non-dask arguments are passed through
        unchanged.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``optimize``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimizations : list of callables, optional
        Additional optimization passes to perform.
    **kwargs
        Extra keyword arguments to forward to the optimization passes.

    Examples
    --------
    >>> import dask
    >>> import dask.array as da
    >>> a = da.arange(10, chunks=2).sum()
    >>> b = da.arange(10, chunks=2).mean()
    >>> a2, b2 = dask.optimize(a, b)

    >>> a2.compute() == a.compute()
    True
    >>> b2.compute() == b.compute()
    True
    """
    # TODO: This API is problematic. The approach to using postpersist forces us
    # to materialize the graph. Most low level optimizations will materialize as
    # well
    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    dsk = collections_to_expr(collections)

    postpersists = []
    for a in collections:
        r, s = a.__dask_postpersist__()
        postpersists.append(r(dsk.__dask_graph__(), *s))

    return repack(postpersists)


def compute(
    *args,
    traverse=True,
    optimize_graph=True,
    scheduler=None,
    get=None,
    **kwargs,
):
    """Compute several dask collections at once.

    Parameters
    ----------
    args : object
        Any number of objects. If it is a dask object, it's computed and the
        result is returned. By default, python builtin collections are also
        traversed to look for dask objects (for more information see the
        ``traverse`` keyword). Non-dask arguments are passed through unchanged.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``compute``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    scheduler : string, optional
        Which scheduler to use like "threads", "synchronous" or "processes".
        If not provided, the default is to check the global settings first,
        and then fall back to the collection defaults.
    optimize_graph : bool, optional
        If True [default], the optimizations for each collection are applied
        before computation. Otherwise the graph is run as is. This can be
        useful for debugging.
    get : ``None``
        Should be left as ``None``; the ``get=`` keyword has been removed.
    kwargs
        Extra keywords to forward to the scheduler function.

    Examples
    --------
    >>> import dask
    >>> import dask.array as da
    >>> a = da.arange(10, chunks=2).sum()
    >>> b = da.arange(10, chunks=2).mean()
    >>> dask.compute(a, b)
    (45, 4.5)

    By default, dask objects inside python collections will also be computed:

    >>> dask.compute({'a': a, 'b': b, 'c': 1})
    ({'a': 45, 'b': 4.5, 'c': 1},)
    """

    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    schedule = get_scheduler(
        scheduler=scheduler,
        collections=collections,
        get=get,
    )
    from dask._expr import FinalizeCompute

    expr = collections_to_expr(collections, optimize_graph)
    expr = FinalizeCompute(expr)

    with shorten_traceback():
        # The high level optimize will have to be called client side (for now).
        # The optimize step can itself already trigger a computation
        # (e.g. parquet is reading some statistics). To move this to the
        # scheduler we'd need some sort of scheduler-client to trigger a
        # computation from inside the scheduler and continue with optimization
        # once the results are in. An alternative could be to introduce a
        # pre-optimize step for the expressions that handles steps like these
        # as a dedicated computation.

        # Another caveat is that optimize will only lock in the expression
        # names after optimization. Names are determined using tokenize, and
        # tokenize is not cross-interpreter (let alone cross-host) stable, so
        # we have to lock this in before sending anything (otherwise we'd need
        # to change the graph submission to a handshake, which introduces all
        # sorts of concurrency control issues).

        expr = expr.optimize()
        keys = list(flatten(expr.__dask_keys__()))

        results = schedule(expr, keys, **kwargs)

    return repack(results)


def visualize(
    *args,
    filename="mydask",
    traverse=True,
    optimize_graph=False,
    maxval=None,
    engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None,
    **kwargs,
):
    """
    Visualize several dask graphs simultaneously.

    Requires ``graphviz`` to be installed. All options that are not the dask
    graph(s) should be passed as keyword arguments.

    Parameters
    ----------
    args : object
        Any number of objects. If it is a dask collection (for example, a
        dask DataFrame, Array, Bag, or Delayed), its associated graph
        will be included in the output of visualize. By default, python builtin
        collections are also traversed to look for dask objects (for more
        information see the ``traverse`` keyword). Arguments lacking an
        associated graph will be ignored.
    filename : str or None, optional
        The name of the file to write to disk. If the provided `filename`
        doesn't include an extension, '.png' will be used by default.
        If `filename` is None, no file will be written, and we communicate
        with dot using only pipes.
    format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
        Format in which to write output file. Default is 'png'.
    traverse : bool, optional
        By default, dask traverses builtin python collections looking for dask
        objects passed to ``visualize``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimize_graph : bool, optional
        If True, the graph is optimized before rendering. Otherwise,
        the graph is displayed as is. Default is False.
    color : {None, 'order', 'ages', 'freed', 'memoryincreases', 'memorydecreases', 'memorypressure'}, optional
        Options to color nodes:

        - None, the default, no colors.
        - 'order', colors the nodes' border based on the order they appear in the graph.
        - 'ages', how long the data of a node is held.
        - 'freed', the number of dependencies released after running a node.
        - 'memoryincreases', how many more outputs are held after the lifetime of a node.
          Large values may indicate nodes that should have run later.
        - 'memorydecreases', how many fewer outputs are held after the lifetime of a node.
          Large values may indicate nodes that should have run sooner.
        - 'memorypressure', the number of data held when the node is run (circle), or
          the data is released (rectangle).
    maxval : {int, float}, optional
        Maximum value for colormap to normalize from 0 to 1.0. The default of
        ``None`` uses the maximum of the values.
    collapse_outputs : bool, optional
        Whether to collapse output boxes, which often have empty labels.
        Default is False.
    verbose : bool, optional
        Whether to label output and input boxes even if the data aren't chunked.
        Beware: these labels can get very long. Default is False.
    engine : {"graphviz", "ipycytoscape", "cytoscape"}, optional
        The visualization engine to use. If not provided, this checks the dask config
        value "visualization.engine". If that is not set, it tries to import ``graphviz``
        and ``ipycytoscape``, using the first one to succeed.
    **kwargs
        Additional keyword arguments to forward to the visualization engine.

    Examples
    --------
    >>> x.visualize(filename='dask.pdf')  # doctest: +SKIP
    >>> x.visualize(filename='dask.pdf', color='order')  # doctest: +SKIP

    Returns
    -------
    result : IPython.display.Image, IPython.display.SVG, or None
        See dask.dot.dot_graph for more information.

    See Also
    --------
    dask.dot.dot_graph

    Notes
    -----
    For more information on optimization see here:

    https://docs.dask.org/en/latest/optimize.html
    """
    args, _ = unpack_collections(*args, traverse=traverse)

    dsk = collections_to_expr(args, optimize_graph=optimize_graph).__dask_graph__()

    return visualize_dsk(
        dsk=dsk,
        filename=filename,
        traverse=traverse,
        optimize_graph=optimize_graph,
        maxval=maxval,
        engine=engine,
        **kwargs,
    )

793def visualize_dsk( 

794 dsk, 

795 filename="mydask", 

796 traverse=True, 

797 optimize_graph=False, 

798 maxval=None, 

799 o=None, 

800 engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None, 

801 limit=None, 

802 **kwargs, 

803): 

804 color = kwargs.get("color") 

805 from dask.order import diagnostics, order 

806 

807 if color in { 

808 "order", 

809 "order-age", 

810 "order-freed", 

811 "order-memoryincreases", 

812 "order-memorydecreases", 

813 "order-memorypressure", 

814 "age", 

815 "freed", 

816 "memoryincreases", 

817 "memorydecreases", 

818 "memorypressure", 

819 "critical", 

820 "cpath", 

821 }: 

822 import matplotlib.pyplot as plt 

823 

824 if o is None: 

825 o_stats = order(dsk, return_stats=True) 

826 o = {k: v.priority for k, v in o_stats.items()} 

827 elif isinstance(next(iter(o.values())), int): 

828 o_stats = order(dsk, return_stats=True) 

829 else: 

830 o_stats = o 

831 o = {k: v.priority for k, v in o.items()} 

832 

833 try: 

834 cmap = kwargs.pop("cmap") 

835 except KeyError: 

836 cmap = plt.cm.plasma 

837 if isinstance(cmap, str): 

838 import matplotlib.pyplot as plt 

839 

840 cmap = getattr(plt.cm, cmap) 

841 

842 def label(x): 

843 return str(values[x]) 

844 

845 data_values = None 

846 if color != "order": 

847 info = diagnostics(dsk, o)[0] 

848 if color.endswith("age"): 

849 values = {key: val.age for key, val in info.items()} 

850 elif color.endswith("freed"): 

851 values = {key: val.num_dependencies_freed for key, val in info.items()} 

852 elif color.endswith("memorypressure"): 

853 values = {key: val.num_data_when_run for key, val in info.items()} 

854 data_values = { 

855 key: val.num_data_when_released for key, val in info.items() 

856 } 

857 elif color.endswith("memoryincreases"): 

858 values = { 

859 key: max(0, val.num_data_when_released - val.num_data_when_run) 

860 for key, val in info.items() 

861 } 

862 elif color.endswith("memorydecreases"): 

863 values = { 

864 key: max(0, val.num_data_when_run - val.num_data_when_released) 

865 for key, val in info.items() 

866 } 

867 elif color.split("-")[-1] in {"critical", "cpath"}: 

868 values = {key: val.critical_path for key, val in o_stats.items()} 

869 else: 

870 raise NotImplementedError(color) 

871 

872 if color.startswith("order-"): 

873 

874 def label(x): 

875 return f"{o[x]}-{values[x]}" 

876 

877 else: 

878 values = o 

879 if maxval is None: 

880 maxval = max(1, *values.values()) 

881 colors = { 

882 k: _colorize(tuple(map(int, cmap(v / maxval, bytes=True)))) 

883 for k, v in values.items() 

884 } 

885 if data_values is None: 

886 data_colors = colors 

887 else: 

888 data_colors = { 

889 k: _colorize(tuple(map(int, cmap(v / maxval, bytes=True)))) 

890 for k, v in values.items() 

891 } 

892 

893 kwargs["function_attributes"] = { 

894 k: {"color": v, "label": label(k)} for k, v in colors.items() 

895 } 

896 kwargs["data_attributes"] = {k: {"color": v} for k, v in data_colors.items()} 

897 elif color: 

898 raise NotImplementedError(f"Unknown value color={color}") 

899 

900 # Determine which engine to dispatch to, first checking the kwarg, then config, 

901 # then whichever of graphviz or ipycytoscape are installed, in that order. 

902 engine = engine or config.get("visualization.engine", None) 

903 

904 if not engine: 

905 try: 

906 import graphviz # noqa: F401 

907 

908 engine = "graphviz" 

909 except ImportError: 

910 try: 

911 import ipycytoscape # noqa: F401 

912 

913 engine = "cytoscape" 

914 except ImportError: 

915 pass 

916 if engine == "graphviz": 

917 from dask.dot import dot_graph 

918 

919 return dot_graph(dsk, filename=filename, **kwargs) 

920 elif engine in ("cytoscape", "ipycytoscape"): 

921 from dask.dot import cytoscape_graph 

922 

923 return cytoscape_graph(dsk, filename=filename, **kwargs) 

924 elif engine is None: 

925 raise RuntimeError( 

926 "No visualization engine detected, please install graphviz or ipycytoscape" 

927 ) 

928 else: 

929 raise ValueError(f"Visualization engine {engine} not recognized") 

930 
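

# Illustrative sketch (not part of dask): selecting the engine through dask's
# configuration instead of the engine= keyword, using the same
# "visualization.engine" key that visualize_dsk reads above.
def _example_pin_engine():
    import dask

    with dask.config.set({"visualization.engine": "cytoscape"}):
        ...  # visualize() calls in this block dispatch to ipycytoscape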


def persist(*args, traverse=True, optimize_graph=True, scheduler=None, **kwargs):
    """Persist multiple Dask collections into memory

    This turns lazy Dask collections into Dask collections with the same
    metadata, but now with their results fully computed or actively computing
    in the background.

    For example a lazy dask.array built up from many lazy calls will now be a
    dask.array of the same shape, dtype, chunks, etc., but now with all of
    those previously lazy tasks either computed in memory as many small
    :class:`numpy.ndarray` (in the single-machine case) or asynchronously
    running in the background on a cluster (in the distributed case).

    This function operates differently if a ``dask.distributed.Client`` exists
    and is connected to a distributed scheduler. In this case this function
    will return as soon as the task graph has been submitted to the cluster,
    but before the computations have completed. Computations will continue
    asynchronously in the background. When using this function with the single
    machine scheduler it blocks until the computations have finished.

    When using Dask on a single machine you should ensure that the dataset fits
    entirely within memory.

    Examples
    --------
    >>> df = dd.read_csv('/path/to/*.csv')  # doctest: +SKIP
    >>> df = df[df.name == 'Alice']  # doctest: +SKIP
    >>> df['in-debt'] = df.balance < 0  # doctest: +SKIP
    >>> df = df.persist()  # triggers computation  # doctest: +SKIP

    >>> df.value().min()  # future computations are now fast  # doctest: +SKIP
    -10
    >>> df.value().max()  # doctest: +SKIP
    100

    >>> from dask import persist  # use persist function on multiple collections
    >>> a, b = persist(a, b)  # doctest: +SKIP

    Parameters
    ----------
    *args: Dask collections
    scheduler : string, optional
        Which scheduler to use like "threads", "synchronous" or "processes".
        If not provided, the default is to check the global settings first,
        and then fall back to the collection defaults.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``persist``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimize_graph : bool, optional
        If True [default], the graph is optimized before computation.
        Otherwise the graph is run as is. This can be useful for debugging.
    **kwargs
        Extra keywords to forward to the scheduler function.

    Returns
    -------
    New dask collections backed by in-memory data
    """
    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    schedule = get_scheduler(scheduler=scheduler, collections=collections)

    # Protocol: scheduler can provide its own persist method for async behavior.
    # For Client-like objects, get_scheduler returns client.get (a bound method),
    # so we check __self__ for the actual client instance.
    client = getattr(schedule, "__self__", schedule)
    if hasattr(client, "persist") and callable(client.persist):
        results = client.persist(collections, optimize_graph=optimize_graph, **kwargs)
        return repack(results)

    expr = collections_to_expr(collections, optimize_graph)
    expr = expr.optimize()
    keys, postpersists = [], []
    for a, akeys in zip(collections, expr.__dask_keys__(), strict=True):
        a_keys = list(flatten(akeys))
        rebuild, state = a.__dask_postpersist__()
        keys.extend(a_keys)
        postpersists.append((rebuild, a_keys, state))

    with shorten_traceback():
        results = schedule(expr, keys, **kwargs)

    d = dict(zip(keys, results))
    results2 = [r({k: d[k] for k in ks}, *s) for r, ks, s in postpersists]
    return repack(results2)


def _colorize(t):
    """Convert (r, g, b) triple to "#RRGGBB" string

    For use with ``visualize(color=...)``

    Examples
    --------
    >>> _colorize((255, 255, 255))
    '#FFFFFF'
    >>> _colorize((0, 32, 128))
    '#002080'
    """
    t = t[:3]
    i = sum(v << 8 * i for i, v in enumerate(reversed(t)))
    return f"#{i:>06X}"


named_schedulers: dict[str, SchedulerGetCallable] = {
    "sync": local.get_sync,
    "synchronous": local.get_sync,
    "single-threaded": local.get_sync,
}

if not EMSCRIPTEN:
    from dask import threaded

    named_schedulers.update(
        {
            "threads": threaded.get,
            "threading": threaded.get,
        }
    )

    from dask import multiprocessing as dask_multiprocessing

    named_schedulers.update(
        {
            "processes": dask_multiprocessing.get,
            "multiprocessing": dask_multiprocessing.get,
        }
    )


get_err_msg = """
The get= keyword has been removed.

Please use the scheduler= keyword instead with the name of
the desired scheduler like 'threads' or 'processes'

    x.compute(scheduler='single-threaded')
    x.compute(scheduler='threads')
    x.compute(scheduler='processes')

or with a function that takes the graph and keys

    x.compute(scheduler=my_scheduler_function)

or with a Dask client

    x.compute(scheduler=client)
""".strip()


def _ensure_not_async(client):
    if client.asynchronous:
        if fallback := config.get("admin.async-client-fallback", None):
            warnings.warn(
                "Distributed Client detected but Client instance is "
                f"asynchronous. Falling back to `{fallback}` scheduler. "
                "To use an asynchronous Client, please use "
                "``Client.compute`` and ``Client.gather`` "
                "instead of the top level ``dask.compute``",
                UserWarning,
            )
            return get_scheduler(scheduler=fallback)
        else:
            raise RuntimeError(
                "Attempting to use an asynchronous "
                "Client in a synchronous context of `dask.compute`"
            )
    return client.get


def get_scheduler(get=None, scheduler=None, collections=None, cls=None):
    """Get scheduler function

    There are various ways to specify the scheduler to use:

    1. Passing in scheduler= parameters
    2. Passing these into global configuration
    3. Using a dask.distributed default Client
    4. Using defaults of a dask collection

    This function centralizes the logic to determine the right scheduler to use
    from those many options
    """
    if get:
        raise TypeError(get_err_msg)

    if scheduler is not None:
        if callable(scheduler):
            return scheduler
        elif "Client" in type(scheduler).__name__ and hasattr(scheduler, "get"):
            return _ensure_not_async(scheduler)
        elif isinstance(scheduler, str):
            scheduler = scheduler.lower()

            client_available = False
            if _distributed_available():
                assert _DistributedClient is not None
                with suppress(ValueError):
                    _DistributedClient.current(allow_global=True)
                    client_available = True
            if scheduler in named_schedulers:
                return named_schedulers[scheduler]
            elif scheduler in ("dask.distributed", "distributed"):
                if not client_available:
                    raise RuntimeError(
                        f"Requested {scheduler} scheduler but no Client active."
                    )
                assert _get_distributed_client is not None
                client = _get_distributed_client()
                return _ensure_not_async(client)
            else:
                raise ValueError(
                    "Expected one of [distributed, {}]".format(
                        ", ".join(sorted(named_schedulers))
                    )
                )
        elif isinstance(scheduler, Executor):
            # Get `num_workers` from `Executor`'s `_max_workers` attribute.
            # If undefined, fallback to `config` or worst case CPU_COUNT.
            num_workers = getattr(scheduler, "_max_workers", None)
            if num_workers is None:
                num_workers = config.get("num_workers", CPU_COUNT)
            assert isinstance(num_workers, Integral) and num_workers > 0
            return partial(local.get_async, scheduler.submit, num_workers)
        else:
            raise ValueError(f"Unexpected scheduler: {scheduler!r}")
        # else:  # try to connect to remote scheduler with this name
        #     return get_client(scheduler).get

    if config.get("scheduler", None):
        return get_scheduler(scheduler=config.get("scheduler", None))

    if config.get("get", None):
        raise ValueError(get_err_msg)

    try:
        from distributed import get_client

        return _ensure_not_async(get_client())
    except (ImportError, ValueError):
        pass

    if cls is not None:
        return cls.__dask_scheduler__

    if collections:
        collections = [c for c in collections if c is not None]
    if collections:
        get = collections[0].__dask_scheduler__
        if not all(c.__dask_scheduler__ == get for c in collections):
            raise ValueError(
                "Compute called on multiple collections with "
                "differing default schedulers. Please specify a "
                "`scheduler=` parameter explicitly in compute or "
                "globally with `dask.config.set`."
            )
        return get

    return None
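

# Illustrative sketch (not part of dask): the main resolution paths above.
# Note "threads" is only registered when not running under Emscripten; the
# ThreadPoolExecutor call exercises the Executor branch.
def _example_get_scheduler():
    from concurrent.futures import ThreadPoolExecutor

    by_name = get_scheduler(scheduler="threads")  # named_schedulers["threads"]
    sync = get_scheduler(scheduler="sync")  # local.get_sync
    pooled = get_scheduler(scheduler=ThreadPoolExecutor(4))  # Executor branch
    return by_name, sync, pooled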


def wait(x, timeout=None, return_when="ALL_COMPLETED"):
    """Wait until computation has finished

    This is a compatibility alias for ``dask.distributed.wait``.
    If it is applied to Dask collections without Dask futures, or if Dask
    distributed is not installed, then it is a no-op.
    """
    try:
        from distributed import wait

        return wait(x, timeout=timeout, return_when=return_when)
    except (ImportError, ValueError):
        return x


def get_collection_names(collection) -> set[str]:
    """Infer the collection names from the dask keys, under the assumption that all keys
    are either tuples with matching first element, and that element is a string, or
    there is exactly one key and it is a string.

    Examples
    --------
    >>> a.__dask_keys__()  # doctest: +SKIP
    ["foo", "bar"]
    >>> get_collection_names(a)  # doctest: +SKIP
    {"foo", "bar"}
    >>> b.__dask_keys__()  # doctest: +SKIP
    [[("foo-123", 0, 0), ("foo-123", 0, 1)], [("foo-123", 1, 0), ("foo-123", 1, 1)]]
    >>> get_collection_names(b)  # doctest: +SKIP
    {"foo-123"}
    """
    if not is_dask_collection(collection):
        raise TypeError(f"Expected Dask collection; got {type(collection)}")
    return {get_name_from_key(k) for k in flatten(collection.__dask_keys__())}


def get_name_from_key(key: Key) -> str:
    """Given a dask collection's key, extract the collection name.

    Parameters
    ----------
    key: string or tuple
        Dask collection's key, which must be either a single string or a tuple whose
        first element is a string (commonly referred to as a collection's 'name').

    Examples
    --------
    >>> get_name_from_key("foo")
    'foo'
    >>> get_name_from_key(("foo-123", 1, 2))
    'foo-123'
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return key[0]
    if isinstance(key, str):
        return key
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")


KeyOrStrT = TypeVar("KeyOrStrT", Key, str)


def replace_name_in_key(key: KeyOrStrT, rename: Mapping[str, str]) -> KeyOrStrT:
    """Given a dask collection's key, replace the collection name with a new one.

    Parameters
    ----------
    key: string or tuple
        Dask collection's key, which must be either a single string or a tuple whose
        first element is a string (commonly referred to as a collection's 'name').
    rename:
        Mapping of zero or more names, from old to new. Extraneous names will be
        ignored. Names not found in this mapping won't be replaced.

    Examples
    --------
    >>> replace_name_in_key("foo", {})
    'foo'
    >>> replace_name_in_key("foo", {"foo": "bar"})
    'bar'
    >>> replace_name_in_key(("foo-123", 1, 2), {"foo-123": "bar-456"})
    ('bar-456', 1, 2)
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return (rename.get(key[0], key[0]),) + key[1:]
    if isinstance(key, str):
        return rename.get(key, key)
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")


def clone_key(key: KeyOrStrT, seed: Hashable) -> KeyOrStrT:
    """Clone a key from a Dask collection, producing a new key with the same prefix and
    indices and a token which is a deterministic function of the previous key and seed.

    Examples
    --------
    >>> clone_key("x", 123)  # doctest: +SKIP
    'x-c4fb64ccca807af85082413d7ef01721'
    >>> clone_key("inc-cbb1eca3bafafbb3e8b2419c4eebb387", 123)  # doctest: +SKIP
    'inc-bc629c23014a4472e18b575fdaf29ee7'
    >>> clone_key(("sum-cbb1eca3bafafbb3e8b2419c4eebb387", 4, 3), 123)  # doctest: +SKIP
    ('sum-c053f3774e09bd0f7de6044dbc40e71d', 4, 3)
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return (clone_key(key[0], seed),) + key[1:]
    if isinstance(key, str):
        prefix = key_split(key)
        return prefix + "-" + tokenize(key, seed)
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")