# dask/base.py


from __future__ import annotations

import dataclasses
import inspect
import uuid
import warnings
from collections import OrderedDict
from collections.abc import Hashable, Iterable, Iterator, Mapping
from concurrent.futures import Executor
from contextlib import contextmanager, suppress
from contextvars import ContextVar
from functools import partial
from numbers import Integral, Number
from operator import getitem
from typing import TYPE_CHECKING, Any, Literal, TypeVar

from tlz import merge

from dask import config, local
from dask._compatibility import EMSCRIPTEN
from dask._task_spec import DataNode, Dict, List, Task, TaskRef
from dask.core import flatten
from dask.core import get as simple_get
from dask.system import CPU_COUNT
from dask.typing import Key, SchedulerGetCallable
from dask.utils import is_namedtuple_instance, key_split, shorten_traceback

if TYPE_CHECKING:
    from dask._expr import Expr

_DistributedClient = None
_get_distributed_client = None
_DISTRIBUTED_AVAILABLE = None


def _distributed_available() -> bool:
    # Lazy import in get_scheduler can be expensive
    global _DistributedClient, _get_distributed_client, _DISTRIBUTED_AVAILABLE
    if _DISTRIBUTED_AVAILABLE is not None:
        return _DISTRIBUTED_AVAILABLE  # type: ignore[unreachable]
    try:
        from distributed import Client as _DistributedClient
        from distributed.worker import get_client as _get_distributed_client

        _DISTRIBUTED_AVAILABLE = True
    except ImportError:
        _DISTRIBUTED_AVAILABLE = False
    return _DISTRIBUTED_AVAILABLE
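
# Example (illustrative sketch, not executed at import): the helper above
# memoizes its probe in the module-level flag, so only the first call pays
# the cost of attempting ``import distributed``.
#
# >>> _distributed_available()   # first call tries the import
# >>> _distributed_available()   # later calls return the cached boolean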

__all__ = (
    "DaskMethodsMixin",
    "annotate",
    "get_annotations",
    "is_dask_collection",
    "compute",
    "persist",
    "optimize",
    "visualize",
    "tokenize",
    "normalize_token",
    "get_collection_names",
    "get_name_from_key",
    "replace_name_in_key",
    "clone_key",
)

# Backwards compat
from dask.tokenize import TokenizationError, normalize_token, tokenize  # noqa: F401

_annotations: ContextVar[dict[str, Any] | None] = ContextVar(
    "annotations", default=None
)


def get_annotations() -> dict[str, Any]:
    """Get current annotations.

    Returns
    -------
    Dict of all current annotations

    See Also
    --------
    annotate
    """
    return _annotations.get() or {}


@contextmanager
def annotate(**annotations: Any) -> Iterator[None]:
    """Context manager for setting HighLevelGraph Layer annotations.

    Annotations are metadata or soft constraints associated with tasks
    that dask schedulers may choose to respect: they signal intent
    without enforcing hard constraints. As such, they are primarily
    designed for use with the distributed scheduler.

    Almost any object can serve as an annotation, but small Python objects
    are preferred, while large objects such as NumPy arrays are discouraged.

    Callables supplied as an annotation should take a single *key* argument and
    produce the appropriate annotation. Individual task keys in the annotated
    collection are supplied to the callable.

    Parameters
    ----------
    **annotations : key-value pairs

    Examples
    --------

    All tasks within array A should have priority 100 and be retried 3 times
    on failure.

    >>> import dask
    >>> import dask.array as da
    >>> with dask.annotate(priority=100, retries=3):
    ...     A = da.ones((10000, 10000))

    Prioritise tasks within Array A on flattened block ID.

    >>> nblocks = (10, 10)
    >>> with dask.annotate(priority=lambda k: k[1]*nblocks[1] + k[2]):
    ...     A = da.ones((1000, 1000), chunks=(100, 100))

    Annotations may be nested.

    >>> with dask.annotate(priority=1):
    ...     with dask.annotate(retries=3):
    ...         A = da.ones((1000, 1000))
    ...     B = A + 1

    See Also
    --------
    get_annotations
    """

    # Sanity check annotations used in place of
    # legacy distributed Client.{submit, persist, compute} keywords
    if "workers" in annotations:
        if isinstance(annotations["workers"], (list, set, tuple)):
            annotations["workers"] = list(annotations["workers"])
        elif isinstance(annotations["workers"], str):
            annotations["workers"] = [annotations["workers"]]
        elif callable(annotations["workers"]):
            pass
        else:
            raise TypeError(
                "'workers' annotation must be a sequence of str, a str or a callable, but got %s."
                % annotations["workers"]
            )

    if (
        "priority" in annotations
        and not isinstance(annotations["priority"], Number)
        and not callable(annotations["priority"])
    ):
        raise TypeError(
            "'priority' annotation must be a Number or a callable, but got %s"
            % annotations["priority"]
        )

    if (
        "retries" in annotations
        and not isinstance(annotations["retries"], Number)
        and not callable(annotations["retries"])
    ):
        raise TypeError(
            "'retries' annotation must be a Number or a callable, but got %s"
            % annotations["retries"]
        )

    if (
        "resources" in annotations
        and not isinstance(annotations["resources"], dict)
        and not callable(annotations["resources"])
    ):
        raise TypeError(
            "'resources' annotation must be a dict or a callable, but got %s"
            % annotations["resources"]
        )

    if (
        "allow_other_workers" in annotations
        and not isinstance(annotations["allow_other_workers"], bool)
        and not callable(annotations["allow_other_workers"])
    ):
        raise TypeError(
            "'allow_other_workers' annotation must be a bool or a callable, but got %s"
            % annotations["allow_other_workers"]
        )
    ctx_annot = _annotations.get()
    if ctx_annot is None:
        ctx_annot = {}
    token = _annotations.set(merge(ctx_annot, annotations))
    try:
        yield
    finally:
        _annotations.reset(token)
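
# Example (illustrative sketch, not executed at import): nested ``annotate``
# blocks merge their annotations via the context variable above, which
# ``get_annotations`` makes observable.
#
# >>> import dask
# >>> with dask.annotate(priority=1):
# ...     with dask.annotate(retries=3):
# ...         dask.get_annotations()
# {'priority': 1, 'retries': 3}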

def is_dask_collection(x) -> bool:
    """Returns ``True`` if ``x`` is a dask collection.

    Parameters
    ----------
    x : Any
        Object to test.

    Returns
    -------
    result : bool
        ``True`` if `x` is a Dask collection.

    Notes
    -----
    The DaskCollection typing.Protocol implementation defines a Dask
    collection as a class that returns a Mapping from the
    ``__dask_graph__`` method. This helper function existed before the
    implementation of the protocol.

    """
    if (
        isinstance(x, type)
        or not hasattr(x, "__dask_graph__")
        or not callable(x.__dask_graph__)
    ):
        return False

    pkg_name = getattr(type(x), "__module__", "")
    if pkg_name.split(".")[0] in ("dask_cudf",):
        # Temporary hack to avoid graph materialization. Note that this won't work with
        # dask_expr.array objects wrapped by xarray or pint. By the time dask_expr.array
        # is published, we hope to be able to rewrite this method completely.
        # Read: https://github.com/dask/dask/pull/10676
        return True
    elif pkg_name.startswith("dask.dataframe.dask_expr"):
        return True
    elif pkg_name.startswith("dask.array._array_expr"):
        return True

    # xarray, pint, and possibly other wrappers always define a __dask_graph__ method,
    # but it may return None if they wrap around a non-dask object.
    # In all known dask collections other than dask-expr,
    # calling __dask_graph__ is cheap.
    return x.__dask_graph__() is not None
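
# Example (illustrative sketch, not executed at import): dask collections
# expose a callable ``__dask_graph__`` that returns a graph; plain Python
# objects do not.
#
# >>> import dask.array as da
# >>> is_dask_collection(da.ones(5))
# True
# >>> is_dask_collection([1, 2, 3])
# False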

class DaskMethodsMixin:
    """A mixin adding standard dask collection methods"""

    __slots__ = ("__weakref__",)

    def visualize(self, filename="mydask", format=None, optimize_graph=False, **kwargs):
        """Render the computation of this object's task graph using graphviz.

        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and we communicate
            with dot using only pipes.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'png'.
        optimize_graph : bool, optional
            If True, the graph is optimized before rendering. Otherwise,
            the graph is displayed as is. Default is False.
        color : {None, 'order'}, optional
            Options to color nodes. Provide ``cmap=`` keyword for additional
            colormap.
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.

        Examples
        --------
        >>> x.visualize(filename='dask.pdf')  # doctest: +SKIP
        >>> x.visualize(filename='dask.pdf', color='order')  # doctest: +SKIP

        Returns
        -------
        result : IPython.display.Image, IPython.display.SVG, or None
            See dask.dot.dot_graph for more information.

        See Also
        --------
        dask.visualize
        dask.dot.dot_graph

        Notes
        -----
        For more information on optimization see here:

        https://docs.dask.org/en/latest/optimize.html
        """
        return visualize(
            self,
            filename=filename,
            format=format,
            optimize_graph=optimize_graph,
            **kwargs,
        )

    def persist(self, **kwargs):
        """Persist this dask collection into memory

        This turns a lazy Dask collection into a Dask collection with the same
        metadata, but now with the results fully computed or actively computing
        in the background.

        The action of this function differs significantly depending on the
        active task scheduler. If the task scheduler supports asynchronous
        computing, such as is the case of the dask.distributed scheduler, then
        persist will return *immediately* and the return value's task graph
        will contain Dask Future objects. However, if the task scheduler only
        supports blocking computation, then the call to persist will *block*
        and the return value's task graph will contain concrete Python results.

        This function is particularly useful when using distributed systems,
        because the results will be kept in distributed memory, rather than
        returned to the local process as with compute.

        Parameters
        ----------
        scheduler : string, optional
            Which scheduler to use like "threads", "synchronous" or "processes".
            If not provided, the default is to check the global settings first,
            and then fall back to the collection defaults.
        optimize_graph : bool, optional
            If True [default], the graph is optimized before computation.
            Otherwise the graph is run as is. This can be useful for debugging.
        **kwargs
            Extra keywords to forward to the scheduler function.

        Returns
        -------
        New dask collections backed by in-memory data

        See Also
        --------
        dask.persist
        """
        (result,) = persist(self, traverse=False, **kwargs)
        return result

    def compute(self, **kwargs):
        """Compute this dask collection

        This turns a lazy Dask collection into its in-memory equivalent.
        For example a Dask array turns into a NumPy array and a Dask dataframe
        turns into a Pandas dataframe. The entire dataset must fit into memory
        before calling this operation.

        Parameters
        ----------
        scheduler : string, optional
            Which scheduler to use like "threads", "synchronous" or "processes".
            If not provided, the default is to check the global settings first,
            and then fall back to the collection defaults.
        optimize_graph : bool, optional
            If True [default], the graph is optimized before computation.
            Otherwise the graph is run as is. This can be useful for debugging.
        kwargs
            Extra keywords to forward to the scheduler function.

        See Also
        --------
        dask.compute
        """
        (result,) = compute(self, traverse=False, **kwargs)
        return result

    def __await__(self):
        try:
            from distributed import futures_of, wait
        except ImportError as e:
            raise ImportError(
                "Using async/await with dask requires the `distributed` package"
            ) from e

        async def f():
            if futures_of(self):
                await wait(self)
            return self

        return f().__await__()
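
# Example (illustrative sketch, not executed at import): any collection built
# on the mixin above gains ``compute``, ``persist`` and ``visualize`` methods
# that defer to the module-level functions with ``traverse=False``.
#
# >>> import dask.array as da
# >>> x = da.arange(10, chunks=5).sum()
# >>> x.compute()      # same as dask.compute(x, traverse=False)[0]
# np.int64(45)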

def compute_as_if_collection(cls, dsk, keys, scheduler=None, get=None, **kwargs):
    """Compute a graph as if it were of type cls.

    Allows for applying the same optimizations and default scheduler."""
    schedule = get_scheduler(scheduler=scheduler, cls=cls, get=get)
    dsk2 = optimization_function(cls)(dsk, keys, **kwargs)
    return schedule(dsk2, keys, **kwargs)


def dont_optimize(dsk, keys, **kwargs):
    return dsk


def optimization_function(x):
    return getattr(x, "__dask_optimize__", dont_optimize)
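
# Example (illustrative sketch, not executed at import): collection classes
# advertise their optimizer via a ``__dask_optimize__`` attribute; anything
# else falls back to the no-op ``dont_optimize``.
#
# >>> optimization_function(object()) is dont_optimize
# True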

def collections_to_expr(
    collections: Iterable,
    optimize_graph: bool = True,
) -> Expr:
    """
    Convert many collections into a single dask expression.

    Typically, users should not be required to interact with this function.

    Parameters
    ----------
    collections : Iterable
        An iterable of dask collections to be combined.
    optimize_graph : bool, optional
        If this is True and collections are encountered which are backed by
        legacy HighLevelGraph objects, the returned Expression will run a low
        level task optimization during materialization.
    """
    is_iterable = False
    if isinstance(collections, (tuple, list, set)):
        is_iterable = True
    else:
        collections = [collections]
    if not collections:
        raise ValueError("No collections provided")
    from dask._expr import HLGExpr, _ExprSequence

    graphs = []
    for coll in collections:
        from dask.delayed import Delayed

        if isinstance(coll, Delayed) or not hasattr(coll, "expr"):
            graphs.append(HLGExpr.from_collection(coll, optimize_graph=optimize_graph))
        else:
            graphs.append(coll.expr)

    if len(graphs) > 1 or is_iterable:
        return _ExprSequence(*graphs)
    else:
        return graphs[0]
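
# Example (illustrative sketch, not executed at import; ``dask._expr`` is
# internal API): a bare collection comes back as a single expression, while
# any list/tuple/set is wrapped in an ``_ExprSequence``.
#
# >>> import dask.array as da
# >>> x = da.ones(5)
# >>> expr = collections_to_expr(x)    # single expression
# >>> dsk = expr.__dask_graph__()      # materialize a low-level graph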

def unpack_collections(*args, traverse=True):
    """Extract collections in preparation for compute/persist/etc...

    Intended use is to find all collections in a set of (possibly nested)
    python objects, do something to them (compute, etc...), then repackage them
    in equivalent python objects.

    Parameters
    ----------
    *args
        Any number of objects. If it is a dask collection, it's extracted and
        added to the list of collections returned. By default, python builtin
        collections are also traversed to look for dask collections (for more
        information see the ``traverse`` keyword).
    traverse : bool, optional
        If True (default), builtin python collections are traversed looking for
        any dask collections they might contain.

    Returns
    -------
    collections : list
        A list of all dask collections contained in ``args``
    repack : callable
        A function to call on the transformed collections to repackage them as
        they were in the original ``args``.
    """

    collections = []
    repack_dsk = {}

    collections_token = uuid.uuid4().hex

    def _unpack(expr):
        if is_dask_collection(expr):
            tok = tokenize(expr)
            if tok not in repack_dsk:
                repack_dsk[tok] = Task(
                    tok, getitem, TaskRef(collections_token), len(collections)
                )
                collections.append(expr)
            return TaskRef(tok)

        tok = uuid.uuid4().hex
        tsk: DataNode | Task  # type: ignore
        if not traverse:
            tsk = DataNode(None, expr)
        else:
            # Treat iterators like lists
            typ = list if isinstance(expr, Iterator) else type(expr)
            if typ in (list, tuple, set):
                tsk = Task(tok, typ, List(*[_unpack(i) for i in expr]))
            elif typ in (dict, OrderedDict):
                tsk = Task(
                    tok, typ, Dict({_unpack(k): _unpack(v) for k, v in expr.items()})
                )
            elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
                tsk = Task(
                    tok,
                    typ,
                    *[_unpack(getattr(expr, f.name)) for f in dataclasses.fields(expr)],
                )
            elif is_namedtuple_instance(expr):
                tsk = Task(tok, typ, *[_unpack(i) for i in expr])
            else:
                return expr

        repack_dsk[tok] = tsk
        return TaskRef(tok)

    out = uuid.uuid4().hex
    repack_dsk[out] = Task(out, tuple, List(*[_unpack(i) for i in args]))

    def repack(results):
        dsk = repack_dsk.copy()
        dsk[collections_token] = DataNode(collections_token, results)
        return simple_get(dsk, out)

    # The original `collections` is kept alive by the closure.
    # This causes the collection to be freed only by the garbage collector.
    collections2 = list(collections)
    collections.clear()
    return collections2, repack
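
# Example (illustrative sketch, not executed at import): collections are
# pulled out of nested containers, and ``repack`` rebuilds the original shape
# around whatever results are substituted for them.
#
# >>> import dask
# >>> d = dask.delayed(1) + 1
# >>> collections, repack = unpack_collections({"x": d, "y": [2, 3]})
# >>> len(collections)
# 1
# >>> repack([2])      # substitute a computed result for the collection
# ({'x': 2, 'y': [2, 3]},)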

def optimize(*args, traverse=True, **kwargs):
    """Optimize several dask collections at once.

    Returns equivalent dask collections that all share the same merged and
    optimized underlying graph. This can be useful if converting multiple
    collections to delayed objects, or to manually apply the optimizations at
    strategic points.

    Note that in most cases you shouldn't need to call this function directly.

    .. warning::

        This function triggers a materialization of the collections and loses
        any annotations attached to HLG layers.

    Parameters
    ----------
    *args : objects
        Any number of objects. If a dask object, its graph is optimized and
        merged with the graphs of all other dask objects before returning an
        equivalent dask collection. Non-dask arguments are passed through
        unchanged.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``optimize``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimizations : list of callables, optional
        Additional optimization passes to perform.
    **kwargs
        Extra keyword arguments to forward to the optimization passes.

    Examples
    --------
    >>> import dask
    >>> import dask.array as da
    >>> a = da.arange(10, chunks=2).sum()
    >>> b = da.arange(10, chunks=2).mean()
    >>> a2, b2 = dask.optimize(a, b)

    >>> a2.compute() == a.compute()
    np.True_
    >>> b2.compute() == b.compute()
    np.True_
    """
    # TODO: This API is problematic. The approach to using postpersist forces us
    # to materialize the graph. Most low level optimizations will materialize as
    # well.
    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    dsk = collections_to_expr(collections)

    postpersists = []
    for a in collections:
        r, s = a.__dask_postpersist__()
        postpersists.append(r(dsk.__dask_graph__(), *s))

    return repack(postpersists)

def compute(
    *args,
    traverse=True,
    optimize_graph=True,
    scheduler=None,
    get=None,
    **kwargs,
):
    """Compute several dask collections at once.

    Parameters
    ----------
    args : object
        Any number of objects. If it is a dask object, it's computed and the
        result is returned. By default, python builtin collections are also
        traversed to look for dask objects (for more information see the
        ``traverse`` keyword). Non-dask arguments are passed through unchanged.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``compute``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    scheduler : string, optional
        Which scheduler to use like "threads", "synchronous" or "processes".
        If not provided, the default is to check the global settings first,
        and then fall back to the collection defaults.
    optimize_graph : bool, optional
        If True [default], the optimizations for each collection are applied
        before computation. Otherwise the graph is run as is. This can be
        useful for debugging.
    get : ``None``
        Should be left to ``None``. The get= keyword has been removed.
    kwargs
        Extra keywords to forward to the scheduler function.

    Examples
    --------
    >>> import dask
    >>> import dask.array as da
    >>> a = da.arange(10, chunks=2).sum()
    >>> b = da.arange(10, chunks=2).mean()
    >>> dask.compute(a, b)
    (np.int64(45), np.float64(4.5))

    By default, dask objects inside python collections will also be computed:

    >>> dask.compute({'a': a, 'b': b, 'c': 1})
    ({'a': np.int64(45), 'b': np.float64(4.5), 'c': 1},)
    """

    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    schedule = get_scheduler(
        scheduler=scheduler,
        collections=collections,
        get=get,
    )
    from dask._expr import FinalizeCompute

    expr = collections_to_expr(collections, optimize_graph)
    expr = FinalizeCompute(expr)

    with shorten_traceback():
        # The high level optimize will have to be called client side (for now).
        # The optimize can internally already trigger a computation
        # (e.g. parquet is reading some statistics). To move this to the
        # scheduler we'd need some sort of scheduler-client to trigger a
        # computation from inside the scheduler and continue with optimization
        # once the results are in. An alternative could be to introduce a
        # pre-optimize step for the Expressions that handles steps like these
        # as a dedicated computation.

        # Another caveat is that optimize will only lock in the expression
        # names after optimization. Names are determined using tokenize, and
        # tokenize is not cross-interpreter (let alone cross-host) stable, so
        # we have to lock this in before sending stuff (otherwise we'd need to
        # change the graph submission to a handshake, which introduces all
        # sorts of concurrency control issues).

        expr = expr.optimize()
        keys = list(flatten(expr.__dask_keys__()))

        results = schedule(expr, keys, **kwargs)

    return repack(results)

def visualize(
    *args,
    filename="mydask",
    traverse=True,
    optimize_graph=False,
    maxval=None,
    engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None,
    **kwargs,
):
    """
    Visualize several dask graphs simultaneously.

    Requires ``graphviz`` to be installed. All options that are not the dask
    graph(s) should be passed as keyword arguments.

    Parameters
    ----------
    args : object
        Any number of objects. If it is a dask collection (for example, a
        dask DataFrame, Array, Bag, or Delayed), its associated graph
        will be included in the output of visualize. By default, python builtin
        collections are also traversed to look for dask objects (for more
        information see the ``traverse`` keyword). Arguments lacking an
        associated graph will be ignored.
    filename : str or None, optional
        The name of the file to write to disk. If the provided `filename`
        doesn't include an extension, '.png' will be used by default.
        If `filename` is None, no file will be written, and we communicate
        with dot using only pipes.
    format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
        Format in which to write output file. Default is 'png'.
    traverse : bool, optional
        By default, dask traverses builtin python collections looking for dask
        objects passed to ``visualize``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimize_graph : bool, optional
        If True, the graph is optimized before rendering. Otherwise,
        the graph is displayed as is. Default is False.
    color : {None, 'order', 'ages', 'freed', 'memoryincreases', 'memorydecreases', 'memorypressure'}, optional
        Options to color nodes:

        - None, the default, no colors.
        - 'order', colors the nodes' border based on the order they appear in the graph.
        - 'ages', how long the data of a node is held.
        - 'freed', the number of dependencies released after running a node.
        - 'memoryincreases', how many more outputs are held after the lifetime of a node.
          Large values may indicate nodes that should have run later.
        - 'memorydecreases', how many fewer outputs are held after the lifetime of a node.
          Large values may indicate nodes that should have run sooner.
        - 'memorypressure', the number of data held when the node is run (circle), or
          when the data is released (rectangle).
    maxval : {int, float}, optional
        Maximum value for the colormap to normalize from 0 to 1.0. The default
        of ``None`` uses the maximum of the values.
    collapse_outputs : bool, optional
        Whether to collapse output boxes, which often have empty labels.
        Default is False.
    verbose : bool, optional
        Whether to label output and input boxes even if the data aren't chunked.
        Beware: these labels can get very long. Default is False.
    engine : {"graphviz", "ipycytoscape", "cytoscape"}, optional
        The visualization engine to use. If not provided, this checks the dask
        config value "visualization.engine". If that is not set, it tries to
        import ``graphviz`` and ``ipycytoscape``, using the first one to
        succeed.
    **kwargs
        Additional keyword arguments to forward to the visualization engine.

    Examples
    --------
    >>> x.visualize(filename='dask.pdf')  # doctest: +SKIP
    >>> x.visualize(filename='dask.pdf', color='order')  # doctest: +SKIP

    Returns
    -------
    result : IPython.display.Image, IPython.display.SVG, or None
        See dask.dot.dot_graph for more information.

    See Also
    --------
    dask.dot.dot_graph

    Notes
    -----
    For more information on optimization see here:

    https://docs.dask.org/en/latest/optimize.html
    """
    args, _ = unpack_collections(*args, traverse=traverse)

    dsk = collections_to_expr(args, optimize_graph=optimize_graph).__dask_graph__()

    return visualize_dsk(
        dsk=dsk,
        filename=filename,
        traverse=traverse,
        optimize_graph=optimize_graph,
        maxval=maxval,
        engine=engine,
        **kwargs,
    )

def visualize_dsk(
    dsk,
    filename="mydask",
    traverse=True,
    optimize_graph=False,
    maxval=None,
    o=None,
    engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None,
    limit=None,
    **kwargs,
):
    color = kwargs.get("color")
    from dask.order import diagnostics, order

    if color in {
        "order",
        "order-age",
        "order-freed",
        "order-memoryincreases",
        "order-memorydecreases",
        "order-memorypressure",
        "age",
        "freed",
        "memoryincreases",
        "memorydecreases",
        "memorypressure",
        "critical",
        "cpath",
    }:
        import matplotlib.pyplot as plt

        if o is None:
            o_stats = order(dsk, return_stats=True)
            o = {k: v.priority for k, v in o_stats.items()}
        elif isinstance(next(iter(o.values())), int):
            o_stats = order(dsk, return_stats=True)
        else:
            o_stats = o
            o = {k: v.priority for k, v in o.items()}

        try:
            cmap = kwargs.pop("cmap")
        except KeyError:
            cmap = plt.cm.plasma
        if isinstance(cmap, str):
            import matplotlib.pyplot as plt

            cmap = getattr(plt.cm, cmap)

        def label(x):
            return str(values[x])

        data_values = None
        if color != "order":
            info = diagnostics(dsk, o)[0]
            if color.endswith("age"):
                values = {key: val.age for key, val in info.items()}
            elif color.endswith("freed"):
                values = {key: val.num_dependencies_freed for key, val in info.items()}
            elif color.endswith("memorypressure"):
                values = {key: val.num_data_when_run for key, val in info.items()}
                data_values = {
                    key: val.num_data_when_released for key, val in info.items()
                }
            elif color.endswith("memoryincreases"):
                values = {
                    key: max(0, val.num_data_when_released - val.num_data_when_run)
                    for key, val in info.items()
                }
            elif color.endswith("memorydecreases"):
                values = {
                    key: max(0, val.num_data_when_run - val.num_data_when_released)
                    for key, val in info.items()
                }
            elif color.split("-")[-1] in {"critical", "cpath"}:
                values = {key: val.critical_path for key, val in o_stats.items()}
            else:
                raise NotImplementedError(color)

            if color.startswith("order-"):

                def label(x):
                    return str(o[x]) + "-" + str(values[x])

        else:
            values = o
        if maxval is None:
            maxval = max(1, max(values.values()))
        colors = {
            k: _colorize(tuple(map(int, cmap(v / maxval, bytes=True))))
            for k, v in values.items()
        }
        if data_values is None:
            data_colors = colors
        else:
            data_colors = {
                k: _colorize(tuple(map(int, cmap(v / maxval, bytes=True))))
                for k, v in data_values.items()
            }

        kwargs["function_attributes"] = {
            k: {"color": v, "label": label(k)} for k, v in colors.items()
        }
        kwargs["data_attributes"] = {k: {"color": v} for k, v in data_colors.items()}
    elif color:
        raise NotImplementedError("Unknown value color=%s" % color)

    # Determine which engine to dispatch to, first checking the kwarg, then config,
    # then whichever of graphviz or ipycytoscape are installed, in that order.
    engine = engine or config.get("visualization.engine", None)

    if not engine:
        try:
            import graphviz  # noqa: F401

            engine = "graphviz"
        except ImportError:
            try:
                import ipycytoscape  # noqa: F401

                engine = "cytoscape"
            except ImportError:
                pass
    if engine == "graphviz":
        from dask.dot import dot_graph

        return dot_graph(dsk, filename=filename, **kwargs)
    elif engine in ("cytoscape", "ipycytoscape"):
        from dask.dot import cytoscape_graph

        return cytoscape_graph(dsk, filename=filename, **kwargs)
    elif engine is None:
        raise RuntimeError(
            "No visualization engine detected, please install graphviz or ipycytoscape"
        )
    else:
        raise ValueError(f"Visualization engine {engine} not recognized")

def persist(*args, traverse=True, optimize_graph=True, scheduler=None, **kwargs):
    """Persist multiple Dask collections into memory

    This turns lazy Dask collections into Dask collections with the same
    metadata, but now with their results fully computed or actively computing
    in the background.

    For example a lazy dask.array built up from many lazy calls will now be a
    dask.array of the same shape, dtype, chunks, etc., but now with all of
    those previously lazy tasks either computed in memory as many small
    :class:`numpy.array` (in the single-machine case) or asynchronously running
    in the background on a cluster (in the distributed case).

    This function operates differently if a ``dask.distributed.Client`` exists
    and is connected to a distributed scheduler. In this case this function
    will return as soon as the task graph has been submitted to the cluster,
    but before the computations have completed. Computations will continue
    asynchronously in the background. When using this function with the single
    machine scheduler it blocks until the computations have finished.

    When using Dask on a single machine you should ensure that the dataset fits
    entirely within memory.

    Examples
    --------
    >>> df = dd.read_csv('/path/to/*.csv')  # doctest: +SKIP
    >>> df = df[df.name == 'Alice']  # doctest: +SKIP
    >>> df['in-debt'] = df.balance < 0  # doctest: +SKIP
    >>> df = df.persist()  # triggers computation  # doctest: +SKIP

    >>> df.value().min()  # future computations are now fast  # doctest: +SKIP
    -10
    >>> df.value().max()  # doctest: +SKIP
    100

    >>> from dask import persist  # use persist function on multiple collections
    >>> a, b = persist(a, b)  # doctest: +SKIP

    Parameters
    ----------
    *args: Dask collections
    scheduler : string, optional
        Which scheduler to use like "threads", "synchronous" or "processes".
        If not provided, the default is to check the global settings first,
        and then fall back to the collection defaults.
    traverse : bool, optional
        By default dask traverses builtin python collections looking for dask
        objects passed to ``persist``. For large collections this can be
        expensive. If none of the arguments contain any dask objects, set
        ``traverse=False`` to avoid doing this traversal.
    optimize_graph : bool, optional
        If True [default], the graph is optimized before computation.
        Otherwise the graph is run as is. This can be useful for debugging.
    **kwargs
        Extra keywords to forward to the scheduler function.

    Returns
    -------
    New dask collections backed by in-memory data
    """
    collections, repack = unpack_collections(*args, traverse=traverse)
    if not collections:
        return args

    schedule = get_scheduler(scheduler=scheduler, collections=collections)

    if inspect.ismethod(schedule):
        try:
            from distributed.client import default_client
        except ImportError:
            pass
        else:
            try:
                client = default_client()
            except ValueError:
                pass
            else:
                if client.get == schedule:
                    results = client.persist(
                        collections, optimize_graph=optimize_graph, **kwargs
                    )
                    return repack(results)

    expr = collections_to_expr(collections, optimize_graph)
    expr = expr.optimize()
    keys, postpersists = [], []
    for a, akeys in zip(collections, expr.__dask_keys__(), strict=True):
        a_keys = list(flatten(akeys))
        rebuild, state = a.__dask_postpersist__()
        keys.extend(a_keys)
        postpersists.append((rebuild, a_keys, state))

    with shorten_traceback():
        results = schedule(expr, keys, **kwargs)

    d = dict(zip(keys, results))
    results2 = [r({k: d[k] for k in ks}, *s) for r, ks, s in postpersists]
    return repack(results2)

def _colorize(t):
    """Convert (r, g, b) triple to "#RRGGBB" string

    For use with ``visualize(color=...)``

    Examples
    --------
    >>> _colorize((255, 255, 255))
    '#FFFFFF'
    >>> _colorize((0, 32, 128))
    '#002080'
    """
    t = t[:3]
    i = sum(v * 256 ** (len(t) - i - 1) for i, v in enumerate(t))
    h = hex(int(i))[2:].upper()
    h = "0" * (6 - len(h)) + h
    return "#" + h

named_schedulers: dict[str, SchedulerGetCallable] = {
    "sync": local.get_sync,
    "synchronous": local.get_sync,
    "single-threaded": local.get_sync,
}

if not EMSCRIPTEN:
    from dask import threaded

    named_schedulers.update(
        {
            "threads": threaded.get,
            "threading": threaded.get,
        }
    )

    from dask import multiprocessing as dask_multiprocessing

    named_schedulers.update(
        {
            "processes": dask_multiprocessing.get,
            "multiprocessing": dask_multiprocessing.get,
        }
    )
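
# Example (illustrative sketch, not executed at import): the names registered
# above are the strings accepted by the ``scheduler=`` keyword.
#
# >>> import dask
# >>> x = dask.delayed(sum)([1, 2, 3])
# >>> x.compute(scheduler="sync")      # run in the calling thread
# 6
# >>> x.compute(scheduler="threads")   # thread pool (where available)
# 6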

get_err_msg = """
The get= keyword has been removed.

Please use the scheduler= keyword instead with the name of
the desired scheduler like 'threads' or 'processes'

    x.compute(scheduler='single-threaded')
    x.compute(scheduler='threads')
    x.compute(scheduler='processes')

or with a function that takes the graph and keys

    x.compute(scheduler=my_scheduler_function)

or with a Dask client

    x.compute(scheduler=client)
""".strip()

def _ensure_not_async(client):
    if client.asynchronous:
        if fallback := config.get("admin.async-client-fallback", None):
            warnings.warn(
                "Distributed Client detected but Client instance is "
                f"asynchronous. Falling back to `{fallback}` scheduler. "
                "To use an asynchronous Client, please use "
                "``Client.compute`` and ``Client.gather`` "
                "instead of the top level ``dask.compute``",
                UserWarning,
            )
            return get_scheduler(scheduler=fallback)
        else:
            raise RuntimeError(
                "Attempting to use an asynchronous "
                "Client in a synchronous context of `dask.compute`"
            )
    return client.get
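
# Example (illustrative sketch, assuming only the ``admin.async-client-fallback``
# config key read above): opting into a synchronous fallback scheduler instead
# of a RuntimeError when an asynchronous Client is the default.
#
# >>> import dask
# >>> with dask.config.set({"admin.async-client-fallback": "threads"}):
# ...     ...  # dask.compute(...) warns and uses the threaded scheduler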

def get_scheduler(get=None, scheduler=None, collections=None, cls=None):
    """Get scheduler function

    There are various ways to specify the scheduler to use:

    1. Passing in scheduler= parameters
    2. Passing these into global configuration
    3. Using a dask.distributed default Client
    4. Using defaults of a dask collection

    This function centralizes the logic to determine the right scheduler to use
    from those many options
    """
    if get:
        raise TypeError(get_err_msg)

    if scheduler is not None:
        if callable(scheduler):
            return scheduler
        elif "Client" in type(scheduler).__name__ and hasattr(scheduler, "get"):
            return _ensure_not_async(scheduler)
        elif isinstance(scheduler, str):
            scheduler = scheduler.lower()

            client_available = False
            if _distributed_available():
                assert _DistributedClient is not None
                with suppress(ValueError):
                    _DistributedClient.current(allow_global=True)
                    client_available = True
            if scheduler in named_schedulers:
                return named_schedulers[scheduler]
            elif scheduler in ("dask.distributed", "distributed"):
                if not client_available:
                    raise RuntimeError(
                        f"Requested {scheduler} scheduler but no Client active."
                    )
                assert _get_distributed_client is not None
                client = _get_distributed_client()
                return _ensure_not_async(client)
            else:
                raise ValueError(
                    "Expected one of [distributed, %s]"
                    % ", ".join(sorted(named_schedulers))
                )
        elif isinstance(scheduler, Executor):
            # Get `num_workers` from `Executor`'s `_max_workers` attribute.
            # If undefined, fall back to `config` or worst case CPU_COUNT.
            num_workers = getattr(scheduler, "_max_workers", None)
            if num_workers is None:
                num_workers = config.get("num_workers", CPU_COUNT)
            assert isinstance(num_workers, Integral) and num_workers > 0
            return partial(local.get_async, scheduler.submit, num_workers)
        else:
            raise ValueError("Unexpected scheduler: %s" % repr(scheduler))
        # else:  # try to connect to remote scheduler with this name
        #     return get_client(scheduler).get

    if config.get("scheduler", None):
        return get_scheduler(scheduler=config.get("scheduler", None))

    if config.get("get", None):
        raise ValueError(get_err_msg)

    try:
        from distributed import get_client

        return _ensure_not_async(get_client())
    except (ImportError, ValueError):
        pass

    if cls is not None:
        return cls.__dask_scheduler__

    if collections:
        collections = [c for c in collections if c is not None]
    if collections:
        get = collections[0].__dask_scheduler__
        if not all(c.__dask_scheduler__ == get for c in collections):
            raise ValueError(
                "Compute called on multiple collections with "
                "differing default schedulers. Please specify a "
                "`scheduler=` parameter explicitly in compute or "
                "globally with `dask.config.set`."
            )
        return get

    return None
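
# Example (illustrative sketch, not executed at import): resolution order in
# practice — an explicit name wins, then config, then a default Client, then
# the collection's own default scheduler.
#
# >>> get_scheduler(scheduler="sync") is local.get_sync
# True
# >>> import dask.array as da
# >>> get_scheduler(collections=[da.ones(5)])  # doctest: +SKIP
# <function dask.threaded.get(...)>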

def wait(x, timeout=None, return_when="ALL_COMPLETED"):
    """Wait until computation has finished

    This is a compatibility alias for ``dask.distributed.wait``.
    If it is applied to Dask collections without Dask Futures, or if Dask
    distributed is not installed, it is a no-op.
    """
    try:
        from distributed import wait

        return wait(x, timeout=timeout, return_when=return_when)
    except (ImportError, ValueError):
        return x

def get_collection_names(collection) -> set[str]:
    """Infer the collection names from the dask keys, under the assumption that
    all keys are either tuples with a matching first element, and that element
    is a string, or there is exactly one key and it is a string.

    Examples
    --------
    >>> a.__dask_keys__()  # doctest: +SKIP
    ["foo", "bar"]
    >>> get_collection_names(a)  # doctest: +SKIP
    {"foo", "bar"}
    >>> b.__dask_keys__()  # doctest: +SKIP
    [[("foo-123", 0, 0), ("foo-123", 0, 1)], [("foo-123", 1, 0), ("foo-123", 1, 1)]]
    >>> get_collection_names(b)  # doctest: +SKIP
    {"foo-123"}
    """
    if not is_dask_collection(collection):
        raise TypeError(f"Expected Dask collection; got {type(collection)}")
    return {get_name_from_key(k) for k in flatten(collection.__dask_keys__())}


def get_name_from_key(key: Key) -> str:
    """Given a dask collection's key, extract the collection name.

    Parameters
    ----------
    key: string or tuple
        Dask collection's key, which must be either a single string or a tuple
        whose first element is a string (commonly referred to as a collection's
        'name').

    Examples
    --------
    >>> get_name_from_key("foo")
    'foo'
    >>> get_name_from_key(("foo-123", 1, 2))
    'foo-123'
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return key[0]
    if isinstance(key, str):
        return key
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")


KeyOrStrT = TypeVar("KeyOrStrT", Key, str)


def replace_name_in_key(key: KeyOrStrT, rename: Mapping[str, str]) -> KeyOrStrT:
    """Given a dask collection's key, replace the collection name with a new one.

    Parameters
    ----------
    key: string or tuple
        Dask collection's key, which must be either a single string or a tuple
        whose first element is a string (commonly referred to as a collection's
        'name').
    rename:
        Mapping of zero or more names, from old name to new name. Extraneous
        names will be ignored. Names not found in this mapping won't be
        replaced.

    Examples
    --------
    >>> replace_name_in_key("foo", {})
    'foo'
    >>> replace_name_in_key("foo", {"foo": "bar"})
    'bar'
    >>> replace_name_in_key(("foo-123", 1, 2), {"foo-123": "bar-456"})
    ('bar-456', 1, 2)
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return (rename.get(key[0], key[0]),) + key[1:]
    if isinstance(key, str):
        return rename.get(key, key)
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")


def clone_key(key: KeyOrStrT, seed: Hashable) -> KeyOrStrT:
    """Clone a key from a Dask collection, producing a new key with the same
    prefix and indices and a token which is a deterministic function of the
    previous key and seed.

    Examples
    --------
    >>> clone_key("x", 123)  # doctest: +SKIP
    'x-c4fb64ccca807af85082413d7ef01721'
    >>> clone_key("inc-cbb1eca3bafafbb3e8b2419c4eebb387", 123)  # doctest: +SKIP
    'inc-bc629c23014a4472e18b575fdaf29ee7'
    >>> clone_key(("sum-cbb1eca3bafafbb3e8b2419c4eebb387", 4, 3), 123)  # doctest: +SKIP
    ('sum-c053f3774e09bd0f7de6044dbc40e71d', 4, 3)
    """
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return (clone_key(key[0], seed),) + key[1:]
    if isinstance(key, str):
        prefix = key_split(key)
        return prefix + "-" + tokenize(key, seed)
    raise TypeError(f"Expected str or a tuple starting with str; got {key!r}")