# Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/base.py: 18%

# 389 statements

from __future__ import annotations

import dataclasses
import inspect
import uuid
import warnings
from collections import OrderedDict
from collections.abc import Hashable, Iterable, Iterator, Mapping
from concurrent.futures import Executor
from contextlib import contextmanager, suppress
from contextvars import ContextVar
from functools import partial
from numbers import Integral, Number
from operator import getitem
from typing import TYPE_CHECKING, Any, Literal, TypeVar

from tlz import merge

from dask import config, local
from dask._compatibility import EMSCRIPTEN
from dask._task_spec import DataNode, Dict, List, Task, TaskRef
from dask.core import flatten
from dask.core import get as simple_get
from dask.system import CPU_COUNT
from dask.typing import Key, SchedulerGetCallable
from dask.utils import is_namedtuple_instance, key_split, shorten_traceback

if TYPE_CHECKING:
    from dask._expr import Expr

_DistributedClient = None
_get_distributed_client = None
_DISTRIBUTED_AVAILABLE = None


def _distributed_available() -> bool:
    # Lazy import in get_scheduler can be expensive
    global _DistributedClient, _get_distributed_client, _DISTRIBUTED_AVAILABLE
    if _DISTRIBUTED_AVAILABLE is not None:
        return _DISTRIBUTED_AVAILABLE  # type: ignore[unreachable]
    try:
        from distributed import Client as _DistributedClient
        from distributed.worker import get_client as _get_distributed_client

        _DISTRIBUTED_AVAILABLE = True
    except ImportError:
        _DISTRIBUTED_AVAILABLE = False
    return _DISTRIBUTED_AVAILABLE


__all__ = (
    "DaskMethodsMixin",
    "annotate",
    "get_annotations",
    "is_dask_collection",
    "compute",
    "persist",
    "optimize",
    "visualize",
    "tokenize",
    "normalize_token",
    "get_collection_names",
    "get_name_from_key",
    "replace_name_in_key",
    "clone_key",
)

# Backwards compat
from dask.tokenize import TokenizationError, normalize_token, tokenize  # noqa: F401

_annotations: ContextVar[dict[str, Any] | None] = ContextVar(
    "annotations", default=None
)


def get_annotations() -> dict[str, Any]:
    """Get current annotations.

    Returns
    -------
    Dict of all current annotations

    See Also
    --------
    annotate
    """
    return _annotations.get() or {}


@contextmanager
def annotate(**annotations: Any) -> Iterator[None]:
    """Context manager for setting HighLevelGraph Layer annotations.

    Annotations are metadata or soft constraints associated with
    tasks that dask schedulers may choose to respect: They signal intent
    without enforcing hard constraints. As such, they are
    primarily designed for use with the distributed scheduler.

    Almost any object can serve as an annotation, but small Python objects
    are preferred, while large objects such as NumPy arrays are discouraged.

    Callables supplied as an annotation should take a single *key* argument and
    produce the appropriate annotation. Individual task keys in the annotated collection
    are supplied to the callable.

    Parameters
    ----------
    **annotations : key-value pairs

    Examples
    --------

    All tasks within array A should have priority 100 and be retried 3 times
    on failure.

    >>> import dask
    >>> import dask.array as da
    >>> with dask.annotate(priority=100, retries=3):
    ...     A = da.ones((10000, 10000))

    Prioritise tasks within Array A on flattened block ID.

    >>> nblocks = (10, 10)
    >>> with dask.annotate(priority=lambda k: k[1]*nblocks[1] + k[2]):
    ...     A = da.ones((1000, 1000), chunks=(100, 100))

    Annotations may be nested.

    >>> with dask.annotate(priority=1):
    ...     with dask.annotate(retries=3):
    ...         A = da.ones((1000, 1000))
    ...     B = A + 1

    See Also
    --------
    get_annotations
    """

    # Sanity check annotations used in place of
    # legacy distributed Client.{submit, persist, compute} keywords
    if "workers" in annotations:
        if isinstance(annotations["workers"], (list, set, tuple)):
            annotations["workers"] = list(annotations["workers"])
        elif isinstance(annotations["workers"], str):
            annotations["workers"] = [annotations["workers"]]
        elif callable(annotations["workers"]):
            pass
        else:
            raise TypeError(
                "'workers' annotation must be a sequence of str, a str or a callable, but got {}.".format(
                    annotations["workers"]
                )
            )

    if (
        "priority" in annotations
        and not isinstance(annotations["priority"], Number)
        and not callable(annotations["priority"])
    ):
        raise TypeError(
            "'priority' annotation must be a Number or a callable, but got {}".format(
                annotations["priority"]
            )
        )

    if (
        "retries" in annotations
        and not isinstance(annotations["retries"], Number)
        and not callable(annotations["retries"])
    ):
        raise TypeError(
            "'retries' annotation must be a Number or a callable, but got {}".format(
                annotations["retries"]
            )
        )

    if (
        "resources" in annotations
        and not isinstance(annotations["resources"], dict)
        and not callable(annotations["resources"])
    ):
        raise TypeError(
            "'resources' annotation must be a dict, but got {}".format(
                annotations["resources"]
            )
        )

    if (
        "allow_other_workers" in annotations
        and not isinstance(annotations["allow_other_workers"], bool)
        and not callable(annotations["allow_other_workers"])
    ):
        raise TypeError(
            "'allow_other_workers' annotation must be a bool or a callable, but got {}".format(
                annotations["allow_other_workers"]
            )
        )
    ctx_annot = _annotations.get()
    if ctx_annot is None:
        ctx_annot = {}
    token = _annotations.set(merge(ctx_annot, annotations))
    try:
        yield
    finally:
        _annotations.reset(token)
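

# A minimal usage sketch (editor's addition, not part of dask): ``annotate``
# merges its keywords into the ``_annotations`` ContextVar for the duration of
# the block, and ``get_annotations`` reads them back. The helper name
# ``_example_annotations_roundtrip`` is hypothetical and never called.
def _example_annotations_roundtrip():
    with annotate(priority=2):
        with annotate(retries=3):
            # Nested contexts merge with the enclosing ones.
            assert get_annotations() == {"priority": 2, "retries": 3}
        # Exiting a block resets the ContextVar to its previous state.
        assert get_annotations() == {"priority": 2}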


def is_dask_collection(x) -> bool:
    """Returns ``True`` if ``x`` is a dask collection.

    Parameters
    ----------
    x : Any
        Object to test.

    Returns
    -------
    result : bool
        ``True`` if `x` is a Dask collection.

    Notes
    -----
    The DaskCollection typing.Protocol implementation defines a Dask
    collection as a class that returns a Mapping from the
    ``__dask_graph__`` method. This helper function existed before the
    implementation of the protocol.

    """
    if (
        isinstance(x, type)
        or not hasattr(x, "__dask_graph__")
        or not callable(x.__dask_graph__)
    ):
        return False

    pkg_name = getattr(type(x), "__module__", "")
    if pkg_name.split(".")[0] in ("dask_cudf",):
        # Temporary hack to avoid graph materialization. Note that this won't work with
        # dask_expr.array objects wrapped by xarray or pint. By the time dask_expr.array
        # is published, we hope to be able to rewrite this method completely.
        # Read: https://github.com/dask/dask/pull/10676
        return True
    elif pkg_name.startswith("dask.dataframe.dask_expr"):
        return True
    elif pkg_name.startswith("dask.array._array_expr"):
        return True

    # xarray, pint, and possibly other wrappers always define a __dask_graph__ method,
    # but it may return None if they wrap around a non-dask object.
    # In all known dask collections other than dask-expr,
    # calling __dask_graph__ is cheap.
    return x.__dask_graph__() is not None
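

# Editor's illustration (not part of dask): ``is_dask_collection`` only checks
# for a callable ``__dask_graph__`` returning a non-None graph, so a minimal
# stub qualifies while classes and plain values do not. The helper name
# ``_example_is_dask_collection`` is hypothetical and never called.
def _example_is_dask_collection():
    class Stub:
        def __dask_graph__(self):
            return {"x": 1}  # any non-None Mapping counts as a graph

    assert is_dask_collection(Stub())  # instance with a graph: True
    assert not is_dask_collection(Stub)  # classes are explicitly rejected
    assert not is_dask_collection(42)  # no __dask_graph__ at all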


class DaskMethodsMixin:
    """A mixin adding standard dask collection methods"""

    __slots__ = ("__weakref__",)

    def visualize(self, filename="mydask", format=None, optimize_graph=False, **kwargs):
        """Render the computation of this object's task graph using graphviz.

        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and we communicate
            with dot using only pipes.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'png'.
        optimize_graph : bool, optional
            If True, the graph is optimized before rendering. Otherwise,
            the graph is displayed as is. Default is False.
        color : {None, 'order'}, optional
            Options to color nodes. Provide ``cmap=`` keyword for additional
            colormap
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.

        Examples
        --------
        >>> x.visualize(filename='dask.pdf')  # doctest: +SKIP
        >>> x.visualize(filename='dask.pdf', color='order')  # doctest: +SKIP

        Returns
        -------
        result : IPython.display.Image, IPython.display.SVG, or None
            See dask.dot.dot_graph for more information.

        See Also
        --------
        dask.visualize
        dask.dot.dot_graph

        Notes
        -----
        For more information on optimization see here:

        https://docs.dask.org/en/latest/optimize.html
        """
        return visualize(
            self,
            filename=filename,
            format=format,
            optimize_graph=optimize_graph,
            **kwargs,
        )

    def persist(self, **kwargs):
        """Persist this dask collection into memory

        This turns a lazy Dask collection into a Dask collection with the same
        metadata, but now with the results fully computed or actively computing
        in the background.

        The action of this function differs significantly depending on the active
        task scheduler. If the task scheduler supports asynchronous computing,
        such as is the case of the dask.distributed scheduler, then persist
        will return *immediately* and the return value's task graph will
        contain Dask Future objects. However if the task scheduler only
        supports blocking computation then the call to persist will *block*
        and the return value's task graph will contain concrete Python results.

        This function is particularly useful when using distributed systems,
        because the results will be kept in distributed memory, rather than
        returned to the local process as with compute.

        Parameters
        ----------
        scheduler : string, optional
            Which scheduler to use like "threads", "synchronous" or "processes".
            If not provided, the default is to check the global settings first,
            and then fall back to the collection defaults.
        optimize_graph : bool, optional
            If True [default], the graph is optimized before computation.
            Otherwise the graph is run as is. This can be useful for debugging.
        **kwargs
            Extra keywords to forward to the scheduler function.

        Returns
        -------
        New dask collections backed by in-memory data

        See Also
        --------
        dask.persist
        """
        (result,) = persist(self, traverse=False, **kwargs)
        return result

    def compute(self, **kwargs):
        """Compute this dask collection

        This turns a lazy Dask collection into its in-memory equivalent.
        For example a Dask array turns into a NumPy array and a Dask dataframe
        turns into a Pandas dataframe. The entire dataset must fit into memory
        before calling this operation.

        Parameters
        ----------
        scheduler : string, optional
            Which scheduler to use like "threads", "synchronous" or "processes".
            If not provided, the default is to check the global settings first,
            and then fall back to the collection defaults.
        optimize_graph : bool, optional
            If True [default], the graph is optimized before computation.
            Otherwise the graph is run as is. This can be useful for debugging.
        kwargs
            Extra keywords to forward to the scheduler function.

        See Also
        --------
        dask.compute
        """
        (result,) = compute(self, traverse=False, **kwargs)
        return result

    def __await__(self):
        try:
            from distributed import futures_of, wait
        except ImportError as e:
            raise ImportError(
                "Using async/await with dask requires the `distributed` package"
            ) from e

        async def f():
            if futures_of(self):
                await wait(self)
            return self

        return f().__await__()
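

# Editor's illustration (not part of dask): ``DaskMethodsMixin`` only supplies
# the user-facing ``compute``/``persist``/``visualize``/``__await__`` methods;
# the graph protocol itself comes from the collection. ``dask.delayed.Delayed``
# is one concrete subclass. The helper name ``_example_daskmethodsmixin`` is
# hypothetical and never called.
def _example_daskmethodsmixin():
    from dask.delayed import delayed  # Delayed subclasses DaskMethodsMixin

    total = delayed(sum)([1, 2, 3])
    assert isinstance(total, DaskMethodsMixin)
    assert is_dask_collection(total)
    # ``.compute()`` delegates to the module-level ``compute`` defined below.
    assert total.compute(scheduler="sync") == 6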


def compute_as_if_collection(cls, dsk, keys, scheduler=None, get=None, **kwargs):
    """Compute a graph as if it were of type cls.

    Allows for applying the same optimizations and default scheduler."""
    schedule = get_scheduler(scheduler=scheduler, cls=cls, get=get)
    dsk2 = optimization_function(cls)(dsk, keys, **kwargs)
    return schedule(dsk2, keys, **kwargs)


def dont_optimize(dsk, keys, **kwargs):
    return dsk


def optimization_function(x):
    return getattr(x, "__dask_optimize__", dont_optimize)
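

# Editor's illustration (not part of dask): ``optimization_function`` looks up
# ``__dask_optimize__`` on the collection (or collection type) and falls back
# to the no-op ``dont_optimize``. The helper name is hypothetical, never called.
def _example_optimization_function():
    class NoOpt:
        pass

    assert optimization_function(NoOpt()) is dont_optimize

    class WithOpt:
        @staticmethod
        def __dask_optimize__(dsk, keys, **kwargs):
            return dsk  # a real collection would rewrite the graph here

    assert optimization_function(WithOpt())({"x": 1}, ["x"]) == {"x": 1}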


def collections_to_expr(
    collections: Iterable,
    optimize_graph: bool = True,
) -> Expr:
    """
    Convert many collections into a single dask expression.

    Typically, users should not be required to interact with this function.

    Parameters
    ----------
    collections : Iterable
        An iterable of dask collections to be combined.
    optimize_graph : bool, optional
        If this is True and collections are encountered which are backed by
        legacy HighLevelGraph objects, the returned Expression will run a low
        level task optimization during materialization.
    """
    is_iterable = False
    if isinstance(collections, (tuple, list, set)):
        is_iterable = True
    else:
        collections = [collections]
    if not collections:
        raise ValueError("No collections provided")
    from dask._expr import HLGExpr, _ExprSequence

    graphs = []
    for coll in collections:
        from dask.delayed import Delayed

        if isinstance(coll, Delayed) or not hasattr(coll, "expr"):
            graphs.append(HLGExpr.from_collection(coll, optimize_graph=optimize_graph))
        else:
            graphs.append(coll.expr)

    if len(graphs) > 1 or is_iterable:
        return _ExprSequence(*graphs)
    else:
        return graphs[0]


def unpack_collections(*args, traverse=True):
    """Extract collections in preparation for compute/persist/etc...

    Intended use is to find all collections in a set of (possibly nested)
    python objects, do something to them (compute, etc...), then repackage them
    in equivalent python objects.

    Parameters
    ----------
    *args
        Any number of objects. If it is a dask collection, it's extracted and
        added to the list of collections returned. By default, python builtin
        collections are also traversed to look for dask collections (for more
        information see the ``traverse`` keyword).
    traverse : bool, optional
        If True (default), builtin python collections are traversed looking for
        any dask collections they might contain.

    Returns
    -------
    collections : list
        A list of all dask collections contained in ``args``
    repack : callable
        A function to call on the transformed collections to repackage them as
        they were in the original ``args``.
    """

    collections = []
    repack_dsk = {}

    collections_token = uuid.uuid4().hex

    def _unpack(expr):
        if is_dask_collection(expr):
            tok = tokenize(expr)
            if tok not in repack_dsk:
                repack_dsk[tok] = Task(
                    tok, getitem, TaskRef(collections_token), len(collections)
                )
                collections.append(expr)
            return TaskRef(tok)

        tok = uuid.uuid4().hex
        tsk: DataNode | Task  # type: ignore
        if not traverse:
            tsk = DataNode(None, expr)
        else:
            # Treat iterators like lists
            typ = list if isinstance(expr, Iterator) else type(expr)
            if typ in (list, tuple, set):
                tsk = Task(tok, typ, List(*[_unpack(i) for i in expr]))
            elif typ in (dict, OrderedDict):
                tsk = Task(
                    tok, typ, Dict({_unpack(k): _unpack(v) for k, v in expr.items()})
                )
            elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
                tsk = Task(
                    tok,
                    typ,
                    *[_unpack(getattr(expr, f.name)) for f in dataclasses.fields(expr)],
                )
            elif is_namedtuple_instance(expr):
                tsk = Task(tok, typ, *[_unpack(i) for i in expr])
            else:
                return expr

        repack_dsk[tok] = tsk
        return TaskRef(tok)

    out = uuid.uuid4().hex
    repack_dsk[out] = Task(out, tuple, List(*[_unpack(i) for i in args]))

    def repack(results):
        dsk = repack_dsk.copy()
        dsk[collections_token] = DataNode(collections_token, results)
        return simple_get(dsk, out)

    # The original `collections` list is kept alive by the closure, so the
    # collections are freed only by the garbage collector
    collections2 = list(collections)
    collections.clear()
    return collections2, repack
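

# Editor's illustration (not part of dask): ``unpack_collections`` pulls dask
# objects out of nested builtins and ``repack`` rebuilds the original shape
# around substituted results. The helper name is hypothetical, never called.
def _example_unpack_collections():
    # With no dask objects, nothing is extracted and repack([]) is an identity.
    collections, repack = unpack_collections({"x": 1, "y": (2, 3)})
    assert collections == []
    assert repack([]) == ({"x": 1, "y": (2, 3)},)

    from dask.delayed import delayed

    d = delayed(1)
    collections, repack = unpack_collections([d, 10])
    assert len(collections) == 1 and collections[0] is d
    # Substitute a computed result where the collection used to be:
    assert repack([99]) == ([99, 10],)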


def optimize(*args, traverse=True, **kwargs):
    """Optimize several dask collections at once.

    Returns equivalent dask collections that all share the same merged and
    optimized underlying graph. This can be useful if converting multiple
    collections to delayed objects, or to manually apply the optimizations at
    strategic points.

    Note that in most cases you shouldn't need to call this function directly.

    .. warning::

       This function triggers a materialization of the collections and loses
       any annotations attached to HLG layers.

    Parameters
    ----------
    *args : objects
        Any number of objects. If a dask object, its graph is optimized and
        merged with all those of all other dask objects before returning an
        equivalent dask collection. Non-dask arguments are passed through
        unchanged.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``optimize``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimizations : list of callables, optional
        Additional optimization passes to perform.
    **kwargs
        Extra keyword arguments to forward to the optimization passes.

    Examples
    --------
    >>> import dask
    >>> import dask.array as da
    >>> a = da.arange(10, chunks=2).sum()
    >>> b = da.arange(10, chunks=2).mean()
    >>> a2, b2 = dask.optimize(a, b)

    >>> a2.compute() == a.compute()
    np.True_
    >>> b2.compute() == b.compute()
    np.True_
    """
    # TODO: This API is problematic. The approach to using postpersist forces us
    # to materialize the graph. Most low level optimizations will materialize as
    # well
    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    dsk = collections_to_expr(collections)

    postpersists = []
    for a in collections:
        r, s = a.__dask_postpersist__()
        postpersists.append(r(dsk.__dask_graph__(), *s))

    return repack(postpersists)


def compute(
    *args,
    traverse=True,
    optimize_graph=True,
    scheduler=None,
    get=None,
    **kwargs,
):
    """Compute several dask collections at once.

    Parameters
    ----------
    args : object
        Any number of objects. If it is a dask object, it's computed and the
        result is returned. By default, python builtin collections are also
        traversed to look for dask objects (for more information see the
        ``traverse`` keyword). Non-dask arguments are passed through unchanged.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``compute``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    scheduler : string, optional
        Which scheduler to use like "threads", "synchronous" or "processes".
        If not provided, the default is to check the global settings first,
        and then fall back to the collection defaults.
    optimize_graph : bool, optional
        If True [default], the optimizations for each collection are applied
        before computation. Otherwise the graph is run as is. This can be
        useful for debugging.
    get : ``None``
        Should be left to ``None``. The ``get=`` keyword has been removed.
    kwargs
        Extra keywords to forward to the scheduler function.

    Examples
    --------
    >>> import dask
    >>> import dask.array as da
    >>> a = da.arange(10, chunks=2).sum()
    >>> b = da.arange(10, chunks=2).mean()
    >>> dask.compute(a, b)
    (np.int64(45), np.float64(4.5))

    By default, dask objects inside python collections will also be computed:

    >>> dask.compute({'a': a, 'b': b, 'c': 1})
    ({'a': np.int64(45), 'b': np.float64(4.5), 'c': 1},)
    """

    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    schedule = get_scheduler(
        scheduler=scheduler,
        collections=collections,
        get=get,
    )
    from dask._expr import FinalizeCompute

    expr = collections_to_expr(collections, optimize_graph)
    expr = FinalizeCompute(expr)

    with shorten_traceback():
        # The high level optimize will have to be called client side (for now)
        # The optimize can internally already trigger a computation
        # (e.g. parquet is reading some statistics). To move this to the
        # scheduler we'd need some sort of scheduler-client to trigger a
        # computation from inside the scheduler and continue with optimization
        # once the results are in. An alternative could be to introduce a
        # pre-optimize step for the Expressions that handles steps like these as
        # a dedicated computation

        # Another caveat is that optimize will only lock in the expression names
        # after optimization. Names are determined using tokenize and tokenize
        # is not cross-interpreter (let alone cross-host) stable such that we
        # have to lock this in before sending stuff (otherwise we'd need to
        # change the graph submission to a handshake which introduces all sorts
        # of concurrency control issues)

        expr = expr.optimize()
        keys = list(flatten(expr.__dask_keys__()))

        results = schedule(expr, keys, **kwargs)

    return repack(results)


def visualize(
    *args,
    filename="mydask",
    traverse=True,
    optimize_graph=False,
    maxval=None,
    engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None,
    **kwargs,
):
    """
    Visualize several dask graphs simultaneously.

    Requires ``graphviz`` to be installed. All options that are not the dask
    graph(s) should be passed as keyword arguments.

    Parameters
    ----------
    args : object
        Any number of objects. If it is a dask collection (for example, a
        dask DataFrame, Array, Bag, or Delayed), its associated graph
        will be included in the output of visualize. By default, python builtin
        collections are also traversed to look for dask objects (for more
        information see the ``traverse`` keyword). Arguments lacking an
        associated graph will be ignored.
    filename : str or None, optional
        The name of the file to write to disk. If the provided `filename`
        doesn't include an extension, '.png' will be used by default.
        If `filename` is None, no file will be written, and we communicate
        with dot using only pipes.
    format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
        Format in which to write output file. Default is 'png'.
    traverse : bool, optional
        By default, dask traverses builtin python collections looking for dask
        objects passed to ``visualize``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimize_graph : bool, optional
        If True, the graph is optimized before rendering. Otherwise,
        the graph is displayed as is. Default is False.
    color : {None, 'order', 'ages', 'freed', 'memoryincreases', 'memorydecreases', 'memorypressure'}, optional
        Options to color nodes:

        - None, the default, no colors.
        - 'order', colors the nodes' border based on the order they appear in the graph.
        - 'ages', how long the data of a node is held.
        - 'freed', the number of dependencies released after running a node.
        - 'memoryincreases', how many more outputs are held after the lifetime of a node.
          Large values may indicate nodes that should have run later.
        - 'memorydecreases', how many fewer outputs are held after the lifetime of a node.
          Large values may indicate nodes that should have run sooner.
        - 'memorypressure', the number of data held when the node is run (circle), or
          the data is released (rectangle).
    maxval : {int, float}, optional
        Maximum value for colormap to normalize from 0 to 1.0. Default is ``None``,
        which uses the maximum of the values.
    collapse_outputs : bool, optional
        Whether to collapse output boxes, which often have empty labels.
        Default is False.
    verbose : bool, optional
        Whether to label output and input boxes even if the data aren't chunked.
        Beware: these labels can get very long. Default is False.
    engine : {"graphviz", "ipycytoscape", "cytoscape"}, optional
        The visualization engine to use. If not provided, this checks the dask config
        value "visualization.engine". If that is not set, it tries to import ``graphviz``
        and ``ipycytoscape``, using the first one to succeed.
    **kwargs
        Additional keyword arguments to forward to the visualization engine.

    Examples
    --------
    >>> x.visualize(filename='dask.pdf')  # doctest: +SKIP
    >>> x.visualize(filename='dask.pdf', color='order')  # doctest: +SKIP

    Returns
    -------
    result : IPython.display.Image, IPython.display.SVG, or None
        See dask.dot.dot_graph for more information.

    See Also
    --------
    dask.dot.dot_graph

    Notes
    -----
    For more information on optimization see here:

    https://docs.dask.org/en/latest/optimize.html
    """
    args, _ = unpack_collections(*args, traverse=traverse)

    dsk = collections_to_expr(args, optimize_graph=optimize_graph).__dask_graph__()

    return visualize_dsk(
        dsk=dsk,
        filename=filename,
        traverse=traverse,
        optimize_graph=optimize_graph,
        maxval=maxval,
        engine=engine,
        **kwargs,
    )


def visualize_dsk(
    dsk,
    filename="mydask",
    traverse=True,
    optimize_graph=False,
    maxval=None,
    o=None,
    engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None,
    limit=None,
    **kwargs,
):
    color = kwargs.get("color")
    from dask.order import diagnostics, order

    if color in {
        "order",
        "order-age",
        "order-freed",
        "order-memoryincreases",
        "order-memorydecreases",
        "order-memorypressure",
        "age",
        "freed",
        "memoryincreases",
        "memorydecreases",
        "memorypressure",
        "critical",
        "cpath",
    }:
        import matplotlib.pyplot as plt

        if o is None:
            o_stats = order(dsk, return_stats=True)
            o = {k: v.priority for k, v in o_stats.items()}
        elif isinstance(next(iter(o.values())), int):
            o_stats = order(dsk, return_stats=True)
        else:
            o_stats = o
            o = {k: v.priority for k, v in o.items()}

        try:
            cmap = kwargs.pop("cmap")
        except KeyError:
            cmap = plt.cm.plasma
        if isinstance(cmap, str):
            import matplotlib.pyplot as plt

            cmap = getattr(plt.cm, cmap)

        def label(x):
            return str(values[x])

        data_values = None
        if color != "order":
            info = diagnostics(dsk, o)[0]
            if color.endswith("age"):
                values = {key: val.age for key, val in info.items()}
            elif color.endswith("freed"):
                values = {key: val.num_dependencies_freed for key, val in info.items()}
            elif color.endswith("memorypressure"):
                values = {key: val.num_data_when_run for key, val in info.items()}
                data_values = {
                    key: val.num_data_when_released for key, val in info.items()
                }
            elif color.endswith("memoryincreases"):
                values = {
                    key: max(0, val.num_data_when_released - val.num_data_when_run)
                    for key, val in info.items()
                }
            elif color.endswith("memorydecreases"):
                values = {
                    key: max(0, val.num_data_when_run - val.num_data_when_released)
                    for key, val in info.items()
                }
            elif color.split("-")[-1] in {"critical", "cpath"}:
                values = {key: val.critical_path for key, val in o_stats.items()}
            else:
                raise NotImplementedError(color)

            if color.startswith("order-"):

                def label(x):
                    return str(o[x]) + "-" + str(values[x])

        else:
            values = o
        if maxval is None:
            maxval = max(1, *values.values())
        colors = {
            k: _colorize(tuple(map(int, cmap(v / maxval, bytes=True))))
            for k, v in values.items()
        }
        if data_values is None:
            data_colors = colors
        else:
            data_colors = {
                k: _colorize(tuple(map(int, cmap(v / maxval, bytes=True))))
                for k, v in data_values.items()
            }

        kwargs["function_attributes"] = {
            k: {"color": v, "label": label(k)} for k, v in colors.items()
        }
        kwargs["data_attributes"] = {k: {"color": v} for k, v in data_colors.items()}
    elif color:
        raise NotImplementedError(f"Unknown value color={color}")

    # Determine which engine to dispatch to, first checking the kwarg, then config,
    # then whichever of graphviz or ipycytoscape are installed, in that order.
    engine = engine or config.get("visualization.engine", None)

    if not engine:
        try:
            import graphviz  # noqa: F401

            engine = "graphviz"
        except ImportError:
            try:
                import ipycytoscape  # noqa: F401

                engine = "cytoscape"
            except ImportError:
                pass
    if engine == "graphviz":
        from dask.dot import dot_graph

        return dot_graph(dsk, filename=filename, **kwargs)
    elif engine in ("cytoscape", "ipycytoscape"):
        from dask.dot import cytoscape_graph

        return cytoscape_graph(dsk, filename=filename, **kwargs)
    elif engine is None:
        raise RuntimeError(
            "No visualization engine detected, please install graphviz or ipycytoscape"
        )
    else:
        raise ValueError(f"Visualization engine {engine} not recognized")


def persist(*args, traverse=True, optimize_graph=True, scheduler=None, **kwargs):
    """Persist multiple Dask collections into memory

    This turns lazy Dask collections into Dask collections with the same
    metadata, but now with their results fully computed or actively computing
    in the background.

    For example a lazy dask.array built up from many lazy calls will now be a
    dask.array of the same shape, dtype, chunks, etc., but now with all of
    those previously lazy tasks either computed in memory as many small :class:`numpy.array`
    (in the single-machine case) or asynchronously running in the
    background on a cluster (in the distributed case).

    This function operates differently if a ``dask.distributed.Client`` exists
    and is connected to a distributed scheduler. In this case this function
    will return as soon as the task graph has been submitted to the cluster,
    but before the computations have completed. Computations will continue
    asynchronously in the background. When using this function with the single
    machine scheduler it blocks until the computations have finished.

    When using Dask on a single machine you should ensure that the dataset fits
    entirely within memory.

    Examples
    --------
    >>> df = dd.read_csv('/path/to/*.csv')  # doctest: +SKIP
    >>> df = df[df.name == 'Alice']  # doctest: +SKIP
    >>> df['in-debt'] = df.balance < 0  # doctest: +SKIP
    >>> df = df.persist()  # triggers computation  # doctest: +SKIP

    >>> df.value.min()  # future computations are now fast  # doctest: +SKIP
    -10
    >>> df.value.max()  # doctest: +SKIP
    100

    >>> from dask import persist  # use persist function on multiple collections
    >>> a, b = persist(a, b)  # doctest: +SKIP

    Parameters
    ----------
    *args: Dask collections
    scheduler : string, optional
        Which scheduler to use like "threads", "synchronous" or "processes".
        If not provided, the default is to check the global settings first,
        and then fall back to the collection defaults.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``persist``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimize_graph : bool, optional
        If True [default], the graph is optimized before computation.
        Otherwise the graph is run as is. This can be useful for debugging.
    **kwargs
        Extra keywords to forward to the scheduler function.

    Returns
    -------
    New dask collections backed by in-memory data
    """
    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    schedule = get_scheduler(scheduler=scheduler, collections=collections)

    if inspect.ismethod(schedule):
        try:
            from distributed.client import default_client
        except ImportError:
            pass
        else:
            try:
                client = default_client()
            except ValueError:
                pass
            else:
                if client.get == schedule:
                    results = client.persist(
                        collections, optimize_graph=optimize_graph, **kwargs
                    )
                    return repack(results)

    expr = collections_to_expr(collections, optimize_graph)
    expr = expr.optimize()
    keys, postpersists = [], []
    for a, akeys in zip(collections, expr.__dask_keys__(), strict=True):
        a_keys = list(flatten(akeys))
        rebuild, state = a.__dask_postpersist__()
        keys.extend(a_keys)
        postpersists.append((rebuild, a_keys, state))

    with shorten_traceback():
        results = schedule(expr, keys, **kwargs)

    d = dict(zip(keys, results))
    results2 = [r({k: d[k] for k in ks}, *s) for r, ks, s in postpersists]
    return repack(results2)


def _colorize(t):
    """Convert (r, g, b) triple to "#RRGGBB" string

    For use with ``visualize(color=...)``

    Examples
    --------
    >>> _colorize((255, 255, 255))
    '#FFFFFF'
    >>> _colorize((0, 32, 128))
    '#002080'
    """
    t = t[:3]
    i = sum(v * 256 ** (len(t) - i - 1) for i, v in enumerate(t))
    h = hex(int(i))[2:].upper()
    h = "0" * (6 - len(h)) + h
    return "#" + h


named_schedulers: dict[str, SchedulerGetCallable] = {
    "sync": local.get_sync,
    "synchronous": local.get_sync,
    "single-threaded": local.get_sync,
}

if not EMSCRIPTEN:
    from dask import threaded

    named_schedulers.update(
        {
            "threads": threaded.get,
            "threading": threaded.get,
        }
    )

    from dask import multiprocessing as dask_multiprocessing

    named_schedulers.update(
        {
            "processes": dask_multiprocessing.get,
            "multiprocessing": dask_multiprocessing.get,
        }
    )
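

# Editor's note (not part of dask): the keys above are exactly the names that
# ``scheduler=`` accepts per call (``x.compute(scheduler="threads")``) or
# globally (``dask.config.set(scheduler="processes")``). The helper name is
# hypothetical, never called.
def _example_named_schedulers():
    # All three aliases point at the same synchronous scheduler function.
    assert named_schedulers["sync"] is local.get_sync
    assert named_schedulers["synchronous"] is named_schedulers["single-threaded"]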


get_err_msg = """
The get= keyword has been removed.

Please use the scheduler= keyword instead with the name of
the desired scheduler like 'threads' or 'processes'

    x.compute(scheduler='single-threaded')
    x.compute(scheduler='threads')
    x.compute(scheduler='processes')

or with a function that takes the graph and keys

    x.compute(scheduler=my_scheduler_function)

or with a Dask client

    x.compute(scheduler=client)
""".strip()


def _ensure_not_async(client):
    if client.asynchronous:
        if fallback := config.get("admin.async-client-fallback", None):
            warnings.warn(
                "Distributed Client detected but Client instance is "
                f"asynchronous. Falling back to `{fallback}` scheduler. "
                "To use an asynchronous Client, please use "
                "``Client.compute`` and ``Client.gather`` "
                "instead of the top level ``dask.compute``",
                UserWarning,
            )
            return get_scheduler(scheduler=fallback)
        else:
            raise RuntimeError(
                "Attempting to use an asynchronous "
                "Client in a synchronous context of `dask.compute`"
            )
    return client.get


def get_scheduler(get=None, scheduler=None, collections=None, cls=None):
    """Get scheduler function

    There are various ways to specify the scheduler to use:

    1. Passing in scheduler= parameters
    2. Passing these into global configuration
    3. Using a dask.distributed default Client
    4. Using defaults of a dask collection

    This function centralizes the logic to determine the right scheduler to use
    from those many options
    """
    if get:
        raise TypeError(get_err_msg)

    if scheduler is not None:
        if callable(scheduler):
            return scheduler
        elif "Client" in type(scheduler).__name__ and hasattr(scheduler, "get"):
            return _ensure_not_async(scheduler)
        elif isinstance(scheduler, str):
            scheduler = scheduler.lower()

            client_available = False
            if _distributed_available():
                assert _DistributedClient is not None
                with suppress(ValueError):
                    _DistributedClient.current(allow_global=True)
                    client_available = True
            if scheduler in named_schedulers:
                return named_schedulers[scheduler]
            elif scheduler in ("dask.distributed", "distributed"):
                if not client_available:
                    raise RuntimeError(
                        f"Requested {scheduler} scheduler but no Client active."
                    )
                assert _get_distributed_client is not None
                client = _get_distributed_client()
                return _ensure_not_async(client)
            else:
                raise ValueError(
                    "Expected one of [distributed, {}]".format(
                        ", ".join(sorted(named_schedulers))
                    )
                )
        elif isinstance(scheduler, Executor):
            # Get `num_workers` from `Executor`'s `_max_workers` attribute.
            # If undefined, fallback to `config` or worst case CPU_COUNT.
            num_workers = getattr(scheduler, "_max_workers", None)
            if num_workers is None:
                num_workers = config.get("num_workers", CPU_COUNT)
            assert isinstance(num_workers, Integral) and num_workers > 0
            return partial(local.get_async, scheduler.submit, num_workers)
        else:
            raise ValueError(f"Unexpected scheduler: {scheduler!r}")
        # else:  # try to connect to remote scheduler with this name
        #     return get_client(scheduler).get

    if config.get("scheduler", None):
        return get_scheduler(scheduler=config.get("scheduler", None))

    if config.get("get", None):
        raise ValueError(get_err_msg)

    try:
        from distributed import get_client

        return _ensure_not_async(get_client())
    except (ImportError, ValueError):
        pass

    if cls is not None:
        return cls.__dask_scheduler__

    if collections:
        collections = [c for c in collections if c is not None]
    if collections:
        get = collections[0].__dask_scheduler__
        if not all(c.__dask_scheduler__ == get for c in collections):
            raise ValueError(
                "Compute called on multiple collections with "
                "differing default schedulers. Please specify a "
                "`scheduler=` parameter explicitly in compute or "
                "globally with `dask.config.set`."
            )
        return get

    return None
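

# Editor's illustration (not part of dask): scheduler resolution by name and
# by executor. The helper name ``_example_get_scheduler`` is hypothetical and
# never called.
def _example_get_scheduler():
    from concurrent.futures import ThreadPoolExecutor

    # A known name resolves straight to the scheduler function.
    assert get_scheduler(scheduler="sync") is local.get_sync

    # A concurrent.futures.Executor is adapted via ``local.get_async``.
    with ThreadPoolExecutor(2) as pool:
        assert callable(get_scheduler(scheduler=pool))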


def wait(x, timeout=None, return_when="ALL_COMPLETED"):
    """Wait until computation has finished.

    This is a compatibility alias for ``dask.distributed.wait``.
    If it is applied to Dask collections without Dask Futures, or if Dask
    distributed is not installed, then it is a no-op.
    """
    try:
        from distributed import wait

        return wait(x, timeout=timeout, return_when=return_when)
    except (ImportError, ValueError):
        return x


def get_collection_names(collection) -> set[str]:
    """Infer the collection names from the dask keys, under the assumption that
    all keys are either tuples with a matching first element which is a string,
    or that there is exactly one key and it is a string.

    Examples
    --------
    >>> a.__dask_keys__()  # doctest: +SKIP
    ["foo", "bar"]
    >>> get_collection_names(a)  # doctest: +SKIP
    {"foo", "bar"}
    >>> b.__dask_keys__()  # doctest: +SKIP
    [[("foo-123", 0, 0), ("foo-123", 0, 1)], [("foo-123", 1, 0), ("foo-123", 1, 1)]]
    >>> get_collection_names(b)  # doctest: +SKIP
    {"foo-123"}
    """
    if not is_dask_collection(collection):
        raise TypeError(f"Expected Dask collection; got {type(collection)}")
    return {get_name_from_key(k) for k in flatten(collection.__dask_keys__())}


def get_name_from_key(key: Key) -> str:
    """Given a dask collection's key, extract the collection name.

    Parameters
    ----------
    key: string or tuple
        Dask collection's key, which must be either a single string or a tuple whose
        first element is a string (commonly referred to as a collection's 'name').

    Examples
    --------
    >>> get_name_from_key("foo")
    'foo'
    >>> get_name_from_key(("foo-123", 1, 2))
    'foo-123'
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return key[0]
    if isinstance(key, str):
        return key
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")


KeyOrStrT = TypeVar("KeyOrStrT", Key, str)


def replace_name_in_key(key: KeyOrStrT, rename: Mapping[str, str]) -> KeyOrStrT:
    """Given a dask collection's key, replace the collection name with a new one.

    Parameters
    ----------
    key: string or tuple
        Dask collection's key, which must be either a single string or a tuple whose
        first element is a string (commonly referred to as a collection's 'name').
    rename:
        Mapping of zero or more old collection names to new ones. Extraneous names
        will be ignored; names not found in this mapping won't be replaced.

    Examples
    --------
    >>> replace_name_in_key("foo", {})
    'foo'
    >>> replace_name_in_key("foo", {"foo": "bar"})
    'bar'
    >>> replace_name_in_key(("foo-123", 1, 2), {"foo-123": "bar-456"})
    ('bar-456', 1, 2)
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return (rename.get(key[0], key[0]),) + key[1:]
    if isinstance(key, str):
        return rename.get(key, key)
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")


def clone_key(key: KeyOrStrT, seed: Hashable) -> KeyOrStrT:
    """Clone a key from a Dask collection, producing a new key with the same prefix and
    indices and a token which is a deterministic function of the previous key and seed.

    Examples
    --------
    >>> clone_key("x", 123)  # doctest: +SKIP
    'x-c4fb64ccca807af85082413d7ef01721'
    >>> clone_key("inc-cbb1eca3bafafbb3e8b2419c4eebb387", 123)  # doctest: +SKIP
    'inc-bc629c23014a4472e18b575fdaf29ee7'
    >>> clone_key(("sum-cbb1eca3bafafbb3e8b2419c4eebb387", 4, 3), 123)  # doctest: +SKIP
    ('sum-c053f3774e09bd0f7de6044dbc40e71d', 4, 3)
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return (clone_key(key[0], seed),) + key[1:]
    if isinstance(key, str):
        prefix = key_split(key)
        return prefix + "-" + tokenize(key, seed)
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")