Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/_task_spec.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

658 statements  

1from __future__ import annotations 

2 

3""" Task specification for dask 

4 

5This module contains the task specification for dask. It is used to represent 

6runnable (task) and non-runnable (data) nodes in a dask graph. 

7 

8Simple examples of how to express tasks in dask 

9----------------------------------------------- 

10 

11.. code-block:: python 

12 

13 func("a", "b") ~ Task("key", func, "a", "b") 

14 

15 [func("a"), func("b")] ~ [Task("key-1", func, "a"), Task("key-2", func, "b")] 

16 

17 {"a": func("b")} ~ {"a": Task("a", func, "b")} 

18 

19 "literal-string" ~ DataNode("key", "literal-string") 

20 

21 

22Keys, Aliases and TaskRefs 

23------------------------- 

24 

25Keys are used to identify tasks in a dask graph. Every `GraphNode` instance has a 

26key attribute that _should_ reference the key in the dask graph. 

27 

28.. code-block:: python 

29 

30 {"key": Task("key", func, "a")} 

31 

32Referencing other tasks is possible by using either one of `Alias` or a 

33`TaskRef`. 

34 

35.. code-block:: python 

36 

37 # TaskRef can be used to provide the name of the reference explicitly 

38 t = Task("key", func, TaskRef("key-1")) 

39 

40 # If a task is still in scope, the method `ref` can be used for convenience 

41 t2 = Task("key2", func2, t.ref()) 

42 

43 

44Executing a task 

45---------------- 

46 

47A task can be executed by calling it with a dictionary of values. The values 

48should contain the dependencies of the task. 

49 

50.. code-block:: python 

51 

52 t = Task("key", add, TaskRef("a"), TaskRef("b")) 

53 assert t.dependencies == {"a", "b"} 

54 t({"a": 1, "b": 2}) == 3 

55 

56""" 

57import functools 

58import itertools 

59import sys 

60from collections import defaultdict 

61from collections.abc import Callable, Container, Iterable, Mapping, MutableMapping 

62from functools import lru_cache, partial 

63from typing import Any, TypeVar, cast 

64 

65from dask.sizeof import sizeof 

66from dask.typing import Key as KeyType 

67from dask.utils import funcname, is_namedtuple_instance 

68 

69_T = TypeVar("_T") 

70 

71 

72# Ported from more-itertools 

73# https://github.com/more-itertools/more-itertools/blob/c8153e2801ade2527f3a6c8b623afae93f5a1ce1/more_itertools/recipes.py#L944-L973 

74def _batched(iterable, n, *, strict=False): 

75 """Batch data into tuples of length *n*. If the number of items in 

76 *iterable* is not divisible by *n*: 

77 * The last batch will be shorter if *strict* is ``False``. 

78 * :exc:`ValueError` will be raised if *strict* is ``True``. 

79 

80 >>> list(batched('ABCDEFG', 3)) 

81 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)] 

82 

83 On Python 3.13 and above, this is an alias for :func:`itertools.batched`. 

84 """ 

85 if n < 1: 

86 raise ValueError("n must be at least one") 

87 it = iter(iterable) 

88 while batch := tuple(itertools.islice(it, n)): 

89 if strict and len(batch) != n: 

90 raise ValueError("batched(): incomplete batch") 

91 yield batch 

92 

93 

# Use the C-accelerated itertools.batched when it supports the ``strict=``
# keyword (added in the 3.13 pre-releases; 0x30D00A2 corresponds to
# 3.13.0a2); older interpreters use the pure-Python port above.
if sys.hexversion >= 0x30D00A2:

    def batched(iterable, n, *, strict=False):
        return itertools.batched(iterable, n, strict=strict)

else:
    batched = _batched

    batched.__doc__ = _batched.__doc__
# End port

104 

105 

def identity(*args):
    """Return all positional arguments, bundled as a tuple, unchanged."""
    return args

108 

109 

def _identity_cast(*args, typ):
    """Pack the positional arguments into a container of type *typ*."""
    return typ(args)

112 

113 

# Monotonic counter used to build unique keys for anonymous nodes
# (see DataNode.__init__ when no key is provided).
_anom_count = itertools.count()

115 

116 

def parse_input(obj: Any) -> object:
    """Tokenize user input into GraphNode objects

    Note: This is similar to `convert_legacy_task` but does not
    - compare any values to a global set of known keys to infer references/futures
    - parse tuples and interprets them as runnable tasks
    - Deal with SubgraphCallables

    Parameters
    ----------
    obj : Any
        Arbitrary user input. Nested dicts, lists, sets and tuples are
        walked recursively.

    Returns
    -------
    object
        A GraphNode container wrapping *obj* if it (or anything nested
        inside it) contains GraphNodes or dask futures; otherwise *obj*
        is returned unchanged.
    """
    if isinstance(obj, GraphNode):
        return obj

    if _is_dask_future(obj):
        # A future is referenced by its key, just like another task.
        return Alias(obj.key)

    if isinstance(obj, dict):
        parsed_dict = {k: parse_input(v) for k, v in obj.items()}
        if any(isinstance(v, GraphNode) for v in parsed_dict.values()):
            return Dict(parsed_dict)
        # NOTE: a dict without GraphNode values falls through and is returned
        # as-is below; the parsed copy is discarded.

    if isinstance(obj, (list, set, tuple)):
        parsed_collection = tuple(parse_input(o) for o in obj)
        if any(isinstance(o, GraphNode) for o in parsed_collection):
            if isinstance(obj, list):
                return List(*parsed_collection)
            if isinstance(obj, set):
                return Set(*parsed_collection)
            if isinstance(obj, tuple):
                if is_namedtuple_instance(obj):
                    # namedtuples must be rebuilt through their own
                    # constructor to preserve the subtype.
                    return _wrap_namedtuple_task(None, obj, parse_input)
                return Tuple(*parsed_collection)

    return obj

159 

160 

def _wrap_namedtuple_task(k, obj, parser):
    # Convert a namedtuple-like instance into a Task that re-instantiates the
    # original type at execution time, with its constructor arguments parsed
    # through *parser*.
    if hasattr(obj, "__getnewargs_ex__"):
        new_args, kwargs = obj.__getnewargs_ex__()
        kwargs = {k: parser(v) for k, v in kwargs.items()}
    elif hasattr(obj, "__getnewargs__"):
        new_args = obj.__getnewargs__()
        kwargs = {}
    # NOTE(review): if *obj* defines neither pickling protocol, new_args and
    # kwargs are unbound below — presumably every namedtuple reaching this
    # point defines __getnewargs__; confirm.

    args_converted = parse_input(type(new_args)(map(parser, new_args)))

    return Task(
        k, partial(_instantiate_named_tuple, type(obj)), args_converted, Dict(kwargs)
    )

174 

175 

176def _instantiate_named_tuple(typ, args, kwargs): 

177 return typ(*args, **kwargs) 

178 

179 

180class _MultiContainer(Container): 

181 container: tuple 

182 __slots__ = ("container",) 

183 

184 def __init__(self, *container): 

185 self.container = container 

186 

187 def __contains__(self, o: object) -> bool: 

188 return any(o in c for c in self.container) 

189 

190 

# NOTE(review): appears unused within this module — presumably kept for
# backwards compatibility of the module namespace; confirm before removing.
SubgraphType = None

192 

193 

def _execute_subgraph(inner_dsk, outkey, inkeys, *dependencies):
    """Run a fused subgraph and return the value of *outkey*.

    External inputs arrive as *dependencies*, one value per key in
    *inkeys*; they are injected into the graph as DataNode literals.
    """
    graph = dict(inner_dsk)
    for name, value in zip(inkeys, dependencies):
        graph[name] = DataNode(None, value)
    results = execute_graph(graph, keys=[outkey])
    return results[outkey]

201 

202 

def convert_legacy_task(
    key: KeyType | None,
    task: _T,
    all_keys: Container,
) -> GraphNode | _T:
    """Convert a single legacy-style task into a GraphNode.

    Tuples whose first element is callable become ``Task`` objects, hashable
    values found in *all_keys* become ``Alias`` references, and nested
    containers are converted recursively. Anything else passes through
    unchanged.
    """
    if isinstance(task, GraphNode):
        return task

    # Legacy runnable task: a tuple of (func, *args)
    if type(task) is tuple and task and callable(task[0]):
        func, args = task[0], task[1:]
        new_args = []
        new: object
        for a in args:
            if isinstance(a, dict):
                new = Dict(a)
            else:
                new = convert_legacy_task(None, a, all_keys)
            new_args.append(new)
        return Task(key, func, *new_args)
    try:
        # Values that match a known key are references to other tasks.
        if isinstance(task, (int, float, str, tuple)):
            if task in all_keys:
                if key is None:
                    return Alias(task)
                else:
                    return Alias(key, target=task)
    except TypeError:
        # Unhashable
        pass

    if isinstance(task, (list, tuple, set, frozenset)):
        if is_namedtuple_instance(task):
            # namedtuples must be rebuilt through their own constructor
            return _wrap_namedtuple_task(
                key,
                task,
                partial(
                    convert_legacy_task,
                    None,
                    all_keys=all_keys,
                ),
            )
        else:
            parsed_args = tuple(convert_legacy_task(None, t, all_keys) for t in task)
            if any(isinstance(a, GraphNode) for a in parsed_args):
                # The container holds tasks: rebuild it at execution time.
                return Task(key, _identity_cast, *parsed_args, typ=type(task))
            else:
                return cast(_T, type(task)(parsed_args))
    elif _is_dask_future(task):
        if key is None:
            return Alias(task.key)  # type: ignore[attr-defined]
        else:
            return Alias(key, target=task.key)  # type: ignore[attr-defined]
    else:
        return task

257 

258 

def convert_legacy_graph(
    dsk: Mapping,
    all_keys: Container | None = None,
):
    """Convert an entire legacy dask graph into GraphNode-based tasks.

    Self-referencing aliases are dropped and plain values are wrapped in
    DataNode instances. *all_keys* defaults to the keys of *dsk*.
    """
    if all_keys is None:
        all_keys = set(dsk)
    converted = {}
    for graph_key, raw in dsk.items():
        node = convert_legacy_task(graph_key, raw, all_keys)
        # An alias pointing at its own key is a no-op; skip it.
        if isinstance(node, Alias) and node.target == graph_key:
            continue
        if not isinstance(node, GraphNode):
            # Plain data: wrap as a literal node.
            node = DataNode(graph_key, node)
        converted[graph_key] = node
    return converted

274 

275 

def resolve_aliases(dsk: dict, keys: set, dependents: dict) -> dict:
    """Remove trivial sequential alias chains

    Example:

        dsk = {'x': 1, 'y': Alias('x'), 'z': Alias('y')}

        resolve_aliases(dsk, {'z'}, {'x': {'y'}, 'y': {'z'}}) == {'z': 1}

    """
    if not keys:
        raise ValueError("No keys provided")
    dsk = dict(dsk)
    # Worklist walk starting from the requested keys.
    work = list(keys)
    seen = set()
    while work:
        k = work.pop()
        if k in seen or k not in dsk:
            continue
        seen.add(k)
        t = dsk[k]
        if isinstance(t, Alias):
            target_key = t.target
            # Rules for when we allow to collapse an alias
            # 1. The target key is not in the keys set. The keys set is what the
            #    user is requesting and by collapsing we'd no longer be able to
            #    return that result.
            # 2. The target key is in fact part of dsk. If it isn't this could
            #    point to a persisted dependency and we cannot collapse it.
            # 3. The target key has only one dependent which is the key we're
            #    currently looking at. This means that there is a one to one
            #    relation between this and the target key in which case we can
            #    collapse them.
            # Note: If target was an alias as well, we could continue with
            # more advanced optimizations but this isn't implemented, yet
            if (
                target_key not in keys
                and target_key in dsk
                # Note: whenever we're performing a collapse, we're not updating
                # the dependents. The length == 1 should still be sufficient for
                # chains of these aliases
                and len(dependents[target_key]) == 1
            ):
                # Replace the alias with a renamed copy of its target.
                tnew = dsk.pop(target_key).copy()

                dsk[k] = tnew
                tnew.key = k
                if isinstance(tnew, Alias):
                    # Still an alias: revisit this key to keep collapsing the chain.
                    work.append(k)
                    seen.discard(k)
                else:
                    work.extend(tnew.dependencies)

        work.extend(t.dependencies)
    return dsk

331 

332 

class TaskRef:
    """A lightweight, hashable reference to a task by its key.

    Equality and hashing are delegated entirely to the key, so two
    references to the same key compare equal.
    """

    # NOTE(review): annotation says ``val`` but the slot is ``key`` —
    # looks like a historical leftover; confirm before changing.
    val: KeyType
    __slots__ = ("key",)

    def __init__(self, key: KeyType):
        self.key = key

    def __str__(self):
        return str(self.key)

    def __repr__(self):
        return f"{type(self).__name__}({self.key!r})"

    def __hash__(self) -> int:
        return hash(self.key)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, TaskRef) and self.key == other.key

    def __reduce__(self):
        return TaskRef, (self.key,)

    def substitute(self, subs: dict, key: KeyType | None = None) -> TaskRef | GraphNode:
        """Return the replacement for this reference per *subs*, or self."""
        try:
            replacement = subs[self.key]
        except KeyError:
            return self
        if isinstance(replacement, GraphNode):
            # Re-key the node so it answers to this reference's key.
            return replacement.substitute({}, key=self.key)
        if isinstance(replacement, TaskRef):
            return replacement
        return TaskRef(replacement)

367 

368 

def _is_dask_future(obj: object) -> bool:
    """Check if obj is a dask Future (TaskRef or duck-typed with __dask_future__).

    This supports both distributed.Future (which inherits from TaskRef) and
    third-party scheduler futures that set __dask_future__ = True.
    """
    if isinstance(obj, TaskRef):
        return True
    # Duck-typed futures advertise themselves via this attribute.
    return getattr(obj, "__dask_future__", False)

376 

377 

class GraphNode:
    """Base class for all nodes in a dask task graph.

    Subclasses must populate ``key`` and ``_dependencies`` and implement
    ``__call__``, ``copy`` and ``substitute``.
    """

    key: KeyType
    _dependencies: frozenset

    # Slots are derived from the class annotations above.
    __slots__ = tuple(__annotations__)

    def ref(self):
        """Return an Alias pointing at this node's key."""
        return Alias(self.key)

    def copy(self):
        raise NotImplementedError

    @property
    def data_producer(self) -> bool:
        # Overridden by nodes that materialize data (see DataNode).
        return False

    @property
    def dependencies(self) -> frozenset:
        return self._dependencies

    @property
    def block_fusion(self) -> bool:
        # When True, optimization passes must not fuse across this node.
        return False

    def _verify_values(self, values: tuple | dict) -> None:
        # Raise if *values* does not cover all of this node's dependencies.
        if not self.dependencies:
            return
        if missing := set(self.dependencies) - set(values):
            raise RuntimeError(f"Not enough arguments provided: missing keys {missing}")

    def __call__(self, values) -> Any:
        raise NotImplementedError("Not implemented")

    def __eq__(self, value: object) -> bool:
        # Equality is token-based: two nodes of the same type with the same
        # token compare equal.
        if type(value) is not type(self):
            return False

        from dask.tokenize import tokenize

        return tokenize(self) == tokenize(value)

    @property
    def is_coro(self) -> bool:
        return False

    def __sizeof__(self) -> int:
        # Sum the sizes of all slot values plus the type itself.
        all_slots = self.get_all_slots()
        return sum(sizeof(getattr(self, sl)) for sl in all_slots) + sys.getsizeof(
            type(self)
        )

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        """Substitute a dependency with a new value. The new value either has to
        be a new valid key or a GraphNode to replace the dependency entirely.

        The GraphNode will not be mutated but instead a shallow copy will be
        returned. The substitution will be performed eagerly.

        Parameters
        ----------
        subs : dict[KeyType, KeyType | GraphNode]
            The mapping describing the substitutions to be made.
        key : KeyType | None, optional
            The key of the new GraphNode object. If None provided, the key of
            the old one will be reused.
        """
        raise NotImplementedError

    @staticmethod
    def fuse(*tasks: GraphNode, key: KeyType | None = None) -> GraphNode:
        """Fuse a set of tasks into a single task.

        The tasks are fused into a single task that will execute the tasks in a
        subgraph. The internal tasks are no longer accessible from the outside.

        All provided tasks must form a valid subgraph that will reduce to a
        single key. If multiple outputs are possible with the provided tasks, an
        exception will be raised.

        The tasks will not be rewritten but instead a new Task will be created
        that will merely reference the old task objects. This way, Task objects
        may be reused in multiple fused tasks.

        Parameters
        ----------
        key : KeyType | None, optional
            The key of the new Task object. If None provided, the key of the
            final task will be used.

        See also
        --------
        GraphNode.substitute : Easier substitution of dependencies
        """
        if any(t.key is None for t in tasks):
            raise ValueError("Cannot fuse tasks with missing keys")
        if len(tasks) == 1:
            return tasks[0].substitute({}, key=key)
        all_keys = set()
        all_deps: set[KeyType] = set()
        for t in tasks:
            all_deps.update(t.dependencies)
            all_keys.add(t.key)
        # Dependencies that must be satisfied from outside the fused subgraph.
        external_deps = tuple(sorted(all_deps - all_keys, key=hash))
        # Keys nothing inside the subgraph depends on; must be exactly one.
        leafs = all_keys - all_deps
        if len(leafs) > 1:
            raise ValueError(f"Cannot fuse tasks with multiple outputs {leafs}")

        outkey = leafs.pop()
        return Task(
            key or outkey,
            _execute_subgraph,
            {t.key: t for t in tasks},
            outkey,
            external_deps,
            *(TaskRef(k) for k in external_deps),
            _data_producer=any(t.data_producer for t in tasks),
        )

    @classmethod
    @lru_cache
    def get_all_slots(cls):
        """Return the sorted union of ``__slots__`` across the MRO (cached)."""
        slots = list()
        for c in cls.mro():
            slots.extend(getattr(c, "__slots__", ()))
        # Interestingly, sorting this causes the nested containers to pickle
        # more efficiently
        return sorted(set(slots))

507 

508 

# Shared empty frozenset so dependency-free nodes avoid a per-instance allocation.
_no_deps: frozenset = frozenset()

510 

511 

class Alias(GraphNode):
    """A node that merely points at another key.

    ``Alias(key, target)`` evaluates to whatever value is stored under
    *target*; with no target it points at its own key.
    """

    target: KeyType
    __slots__ = tuple(__annotations__)

    def __init__(
        self, key: KeyType | TaskRef, target: Alias | TaskRef | KeyType | None = None
    ):
        if isinstance(key, TaskRef):
            key = key.key
        self.key = key
        if target is None:
            target = key
        if isinstance(target, Alias):
            # Collapse alias-of-alias eagerly.
            target = target.target
        if isinstance(target, TaskRef):
            target = target.key
        self.target = target
        self._dependencies = frozenset((self.target,))

    def __reduce__(self):
        return Alias, (self.key, self.target)

    def copy(self):
        return Alias(self.key, self.target)

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        if self.key in subs or self.target in subs:
            sub_key = subs.get(self.key, self.key)
            val = subs.get(self.target, self.target)
            if sub_key == self.key and val == self.target:
                # Substitution is a no-op for this alias.
                return self
            if isinstance(val, (GraphNode, TaskRef)):
                # The target was replaced by a node: re-key it.
                return val.substitute({}, key=key)
            if key is None and isinstance(sub_key, GraphNode):
                raise RuntimeError(
                    f"Invalid substitution encountered {self.key!r} -> {sub_key}"
                )
            return Alias(key or sub_key, val)  # type: ignore [arg-type]
        return self

    def __dask_tokenize__(self):
        return (type(self).__name__, self.key, self.target)

    def __call__(self, values=()):
        self._verify_values(values)
        return values[self.target]

    def __repr__(self):
        if self.key != self.target:
            return f"Alias({self.key!r}->{self.target!r})"
        else:
            return f"Alias({self.key!r})"

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, Alias):
            return False
        if self.key != value.key:
            return False
        return self.target == value.target

573 

574 

class DataNode(GraphNode):
    """A graph node wrapping a literal value.

    Calling the node returns the stored value unchanged; it has no
    dependencies. Anonymous nodes (``key=None``) get a generated key of the
    form ``(type_name, counter)``.
    """

    value: Any
    typ: type
    __slots__ = tuple(__annotations__)

    def __init__(self, key: Any, value: Any):
        if key is None:
            # Generate a unique anonymous key.
            key = (type(value).__name__, next(_anom_count))
        self.key = key
        self.value = value
        self.typ = type(value)
        self._dependencies = _no_deps

    @property
    def data_producer(self) -> bool:
        return True

    def copy(self):
        return DataNode(self.key, self.value)

    def __call__(self, values=()):
        return self.value

    def __repr__(self):
        return f"DataNode({self.value!r})"

    def __reduce__(self):
        return (DataNode, (self.key, self.value))

    def __dask_tokenize__(self):
        from dask.base import tokenize

        return (type(self).__name__, tokenize(self.value))

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> DataNode:
        # Only a re-key produces a new node; the value itself never changes.
        if key is not None and key != self.key:
            return DataNode(key, self.value)
        return self

    def __iter__(self):
        return iter(self.value)

618 

619 

def _get_dependencies(obj: object) -> set | frozenset:
    """Collect the set of task keys that *obj* (recursively) depends on.

    Futures contribute their own key; GraphNodes contribute their declared
    dependencies; containers are walked recursively. Anything else has no
    dependencies.
    """
    if _is_dask_future(obj):
        # BUGFIX: must return a *set* containing the key. Returning the bare
        # key broke the declared contract and, worse, callers below union the
        # result via set().union(*...), which would iterate the key itself
        # (e.g. the characters of a string key).
        return {obj.key}  # type: ignore[attr-defined]
    elif isinstance(obj, GraphNode):
        return obj.dependencies
    elif isinstance(obj, dict):
        if not obj:
            return _no_deps
        return set().union(*map(_get_dependencies, obj.values()))
    elif isinstance(obj, (list, tuple, frozenset, set)):
        if not obj:
            return _no_deps
        return set().union(*map(_get_dependencies, obj))
    return _no_deps

634 

635 

class Task(GraphNode):
    """A runnable graph node: ``func(*args, **kwargs)``.

    ``TaskRef``/``GraphNode`` instances found in *args*/*kwargs* determine
    :attr:`dependencies` and are resolved against the value mapping passed
    to :meth:`__call__`.
    """

    func: Callable
    args: tuple
    kwargs: dict
    _data_producer: bool
    _token: str | None  # memoized token, see _get_token
    _is_coro: bool | None  # memoized coroutine-function check
    _repr: str | None  # memoized repr string

    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        key: Any,
        func: Callable,
        /,
        *args: Any,
        _data_producer: bool = False,
        **kwargs: Any,
    ):
        self.key = key
        self.func = func
        if isinstance(func, Task):
            raise TypeError("Cannot nest tasks")

        self.args = args
        self.kwargs = kwargs
        # Collect dependency keys lazily; avoid allocating a set when there
        # are none.
        _dependencies: set[KeyType] | None = None
        for a in itertools.chain(args, kwargs.values()):
            if isinstance(a, TaskRef):
                if _dependencies is None:
                    _dependencies = {a.key}
                else:
                    _dependencies.add(a.key)
            elif isinstance(a, GraphNode) and a.dependencies:
                if _dependencies is None:
                    _dependencies = set(a.dependencies)
                else:
                    _dependencies.update(a.dependencies)
        if _dependencies:
            self._dependencies = frozenset(_dependencies)
        else:
            self._dependencies = _no_deps
        self._is_coro = None
        self._token = None
        self._repr = None
        self._data_producer = _data_producer

    @property
    def data_producer(self) -> bool:
        return self._data_producer

    def has_subgraph(self) -> bool:
        # Fused tasks use _execute_subgraph as their callable.
        return self.func == _execute_subgraph

    def copy(self):
        return type(self)(
            self.key,
            self.func,
            *self.args,
            **self.kwargs,
        )

    def __hash__(self):
        return hash(self._get_token())

    def _get_token(self) -> str:
        # Memoized; the token only depends on func/args/kwargs, not the key.
        if self._token:
            return self._token
        from dask.base import tokenize

        self._token = tokenize(
            (
                type(self).__name__,
                self.func,
                self.args,
                self.kwargs,
            )
        )
        return self._token

    def __dask_tokenize__(self):
        return self._get_token()

    def __repr__(self) -> str:
        # When `Task` is deserialized the constructor will not run and
        # `self._repr` is thus undefined.
        if not hasattr(self, "_repr") or not self._repr:
            head = funcname(self.func)
            tail = ")"
            label_size = 40
            args = self.args
            kwargs = self.kwargs
            if args or kwargs:
                # Rough per-argument budget for the label.
                label_size2 = int(
                    (label_size - len(head) - len(tail) - len(str(self.key)))
                    // (len(args) + len(kwargs))
                )
            if args:
                if label_size2 > 5:
                    args_repr = ", ".join(repr(t) for t in args)
                else:
                    args_repr = "..."
            else:
                args_repr = ""
            if kwargs:
                if label_size2 > 5:
                    kwargs_repr = ", " + ", ".join(
                        f"{k}={v!r}" for k, v in sorted(kwargs.items())
                    )
                else:
                    kwargs_repr = ", ..."
            else:
                kwargs_repr = ""
            self._repr = f"<Task {self.key!r} {head}({args_repr}{kwargs_repr}{tail}>"
        return self._repr

    def __call__(self, values=()):
        self._verify_values(values)

        def _eval(a):
            # Resolve nested nodes and references against *values*.
            if isinstance(a, GraphNode):
                return a({k: values[k] for k in a.dependencies})
            elif isinstance(a, TaskRef):
                return values[a.key]
            else:
                return a

        new_argspec = tuple(map(_eval, self.args))
        if self.kwargs:
            kwargs = {k: _eval(kw) for k, kw in self.kwargs.items()}
            return self.func(*new_argspec, **kwargs)
        return self.func(*new_argspec)

    def __setstate__(self, state):
        # State is a tuple ordered like get_all_slots(); see __getstate__.
        slots = self.__class__.get_all_slots()
        for sl, val in zip(slots, state):
            setattr(self, sl, val)

    def __getstate__(self):
        slots = self.__class__.get_all_slots()
        return tuple(getattr(self, sl) for sl in slots)

    @property
    def is_coro(self):
        if self._is_coro is None:
            # Note: Can't use cached_property on objects without __dict__
            try:
                from distributed.utils import iscoroutinefunction

                self._is_coro = iscoroutinefunction(self.func)
            except Exception:
                self._is_coro = False
        return self._is_coro

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Task:
        subs_filtered = {
            # subs can be as large as the whole graph; do not iterate over it!
            k: v
            for k in (self.dependencies & subs.keys())
            if (v := subs[k]) != k
        }
        # Preserve subclass-specific constructor arguments (via module-level
        # helper _extra_args, defined later in this file).
        extras = _extra_args(type(self))  # type: ignore[arg-type]
        extra_kwargs = {
            name: getattr(self, name) for name in extras if name not in {"key", "func"}
        }
        if subs_filtered:
            new_args = tuple(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
            new_kwargs = {
                k: (
                    v.substitute(subs_filtered)
                    if isinstance(v, (GraphNode, TaskRef))
                    else v
                )
                for k, v in self.kwargs.items()
            }
            return type(self)(
                key or self.key,
                self.func,
                *new_args,
                **new_kwargs,  # type: ignore[arg-type]
                **extra_kwargs,
            )
        elif key is None or key == self.key:
            return self
        else:
            # Rename
            return type(self)(
                key,
                self.func,
                *self.args,
                **self.kwargs,
                **extra_kwargs,
            )

840 

class NestedContainer(Task, Iterable):
    """A Task that rebuilds a (possibly nested) container at execution time.

    Subclasses set ``klass`` (the container type accepted on input) and
    ``constructor`` (the callable applied to the evaluated arguments).
    """

    constructor: Callable
    klass: type
    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        /,
        *args: Any,
        **kwargs: Any,
    ):
        if len(args) == 1 and isinstance(args[0], self.klass):
            # Allow passing the container itself instead of unpacked elements.
            args = args[0]  # type: ignore[assignment]
        super().__init__(
            None,
            self.to_container,
            *args,
            constructor=self.constructor,
            **kwargs,
        )

    def __getstate__(self):
        state = super().__getstate__()
        state = list(state)
        slots = self.__class__.get_all_slots()
        ix = slots.index("kwargs")
        # The constructor as a kwarg is redundant since this is encoded in the
        # class itself. Serializing the builtin types is not trivial
        # This saves about 15% of overhead
        state[ix] = state[ix].copy()
        state[ix].pop("constructor", None)
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        # Restore the constructor dropped by __getstate__.
        self.kwargs["constructor"] = self.__class__.constructor
        return self

    def __repr__(self):
        return f"{type(self).__name__}({self.args})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> NestedContainer:
        subs_filtered = {
            # subs can be as large as the whole graph; do not iterate over it!
            k: v
            for k in (self.dependencies & subs.keys())
            if (v := subs[k]) != k
        }
        if not subs_filtered:
            return self
        return type(self)(
            *(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
        )

    def __dask_tokenize__(self):
        from dask.tokenize import tokenize

        # Argument tokens are sorted so the token is order-insensitive.
        # (An unreachable `return super().__dask_tokenize__()` that followed
        # this return was dead code and has been removed.)
        return (
            type(self).__name__,
            self.klass,
            sorted(tokenize(a) for a in self.args),
        )

    @staticmethod
    def to_container(*args, constructor):
        return constructor(args)

    def __iter__(self):
        yield from self.args

922 

class List(NestedContainer):
    """NestedContainer materializing as a built-in ``list``."""

    constructor = klass = list

925 

926 

class Tuple(NestedContainer):
    """NestedContainer materializing as a built-in ``tuple``."""

    constructor = klass = tuple

929 

930 

class Set(NestedContainer):
    """NestedContainer materializing as a built-in ``set``."""

    constructor = klass = set

933 

934 

class Dict(NestedContainer, Mapping):
    """NestedContainer materializing as a ``dict``.

    Internally keys and values are stored flattened and interleaved in
    ``self.args``: ``(k0, v0, k1, v1, ...)``. Accepts a dict, a sequence of
    2-element pairs, an already-flattened arg sequence, or keyword args.
    """

    klass = dict

    def __init__(self, /, *args: Any, **kwargs: Any):
        if args:
            assert not kwargs
            if len(args) == 1:
                args = args[0]
                if isinstance(args, dict):  # type: ignore[unreachable]
                    # Flatten the dict to (k0, v0, k1, v1, ...)
                    args = tuple(itertools.chain(*args.items()))  # type: ignore[unreachable]
                elif isinstance(args, (list, tuple)):
                    if all(
                        len(el) == 2 if isinstance(el, (list, tuple)) else False
                        for el in args
                    ):
                        # Sequence of 2-element pairs: flatten likewise.
                        args = tuple(itertools.chain(*args))
                    else:
                        raise ValueError("Invalid argument provided")

            if len(args) % 2 != 0:
                raise ValueError("Invalid number of arguments provided")

        elif kwargs:
            assert not args
            args = tuple(itertools.chain(*kwargs.items()))

        super().__init__(*args)

    def __repr__(self):
        values = ", ".join(f"{k}: {v}" for k, v in batched(self.args, 2, strict=True))
        return f"Dict({values})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Dict:
        subs_filtered = {
            # subs can be as large as the whole graph; do not iterate over it!
            k: v
            for k in (self.dependencies & subs.keys())
            if (v := subs[k]) != k
        }
        if not subs_filtered:
            return self

        new_args = []
        for arg in self.args:
            new_arg = (
                arg.substitute(subs_filtered)
                if isinstance(arg, (GraphNode, TaskRef))
                else arg
            )
            new_args.append(new_arg)
        return type(self)(new_args)

    def __iter__(self):
        # Keys live at even positions of the flattened args.
        yield from self.args[::2]

    def __len__(self):
        return len(self.args) // 2

    def __getitem__(self, key):
        # Linear scan over (key, value) pairs.
        for k, v in batched(self.args, 2, strict=True):
            if k == key:
                return v
        raise KeyError(key)

    @staticmethod
    def constructor(args):
        return dict(batched(args, 2, strict=True))

1005 

class DependenciesMapping(MutableMapping):
    """Lazy mapping from graph key to its set of dependency keys.

    Dependency sets are computed on first access and memoized. Deleting a
    key marks it as removed: the memo is invalidated and the removed key is
    subtracted from subsequently computed dependency sets. Assignment is
    not supported.
    """

    def __init__(self, dsk):
        self.dsk = dsk
        self._removed = set()
        # Set a copy of dsk to avoid dct resizing
        self._cache = dsk.copy()
        self._cache.clear()

    def __getitem__(self, key):
        if (val := self._cache.get(key)) is not None:
            return val
        else:
            v = self.dsk[key]
            try:
                deps = v.dependencies
            except AttributeError:
                # Legacy (non-GraphNode) task: fall back to graph inspection.
                from dask.core import get_dependencies

                deps = get_dependencies(self.dsk, task=v)

            if self._removed:
                # deps is a frozenset but for good measure, let's not use -= since
                # that _may_ perform an inplace mutation
                deps = deps - self._removed
            self._cache[key] = deps
            return deps

    def __iter__(self):
        return iter(self.dsk)

    def __delitem__(self, key: Any) -> None:
        # Invalidate all memoized sets; any of them may contain the key.
        self._cache.clear()
        self._removed.add(key)

    def __setitem__(self, key: Any, value: Any) -> None:
        raise NotImplementedError

    def __len__(self) -> int:
        return len(self.dsk)

1045 

1046 

1047class _DevNullMapping(MutableMapping): 

1048 def __getitem__(self, key): 

1049 raise KeyError(key) 

1050 

1051 def __setitem__(self, key, value): 

1052 pass 

1053 

1054 def __delitem__(self, key): 

1055 pass 

1056 

1057 def __len__(self): 

1058 return 0 

1059 

1060 def __iter__(self): 

1061 return iter(()) 

1062 

1063 

def execute_graph(
    dsk: Iterable[GraphNode] | Mapping[KeyType, GraphNode],
    cache: MutableMapping[KeyType, object] | None = None,
    keys: Container[KeyType] | None = None,
) -> MutableMapping[KeyType, object]:
    """Execute a given graph.

    The graph is executed in topological order as defined by dask.order until
    all leaf nodes, i.e. nodes without any dependents, are reached. The returned
    dictionary contains the results of the leaf nodes.

    If keys are required that are not part of the graph, they can be provided in the `cache` argument.

    If `keys` is provided, the result will contain only values that are part of the `keys` set.

    """
    if isinstance(dsk, (list, tuple, set, frozenset)):
        dsk = {t.key: t for t in dsk}
    else:
        assert isinstance(dsk, dict)

    # Reference counting lets us free intermediate results as soon as their
    # last dependent has run.
    refcount: defaultdict[KeyType, int] = defaultdict(int)
    for vals in DependenciesMapping(dsk).values():
        for val in vals:
            refcount[val] += 1

    cache = cache or {}
    from dask.order import order

    priorities = order(dsk)

    # Run nodes in priority order; each node pulls its inputs from `cache`.
    for key, node in sorted(dsk.items(), key=lambda it: priorities[it[0]]):
        cache[key] = node(cache)
        for dep in node.dependencies:
            refcount[dep] -= 1
            if refcount[dep] == 0 and keys and dep not in keys:
                # No remaining dependents and not requested: free it.
                del cache[dep]

    return cache

1103 

1104 

def fuse_linear_task_spec(dsk, keys):
    """
    Fuse linear chains of tasks in ``dsk`` into single fused tasks.

    keys are the keys from the graph that are requested by a computation. We
    can't fuse those together.

    Starting from each unvisited key, the chain is extended in both
    directions: towards the leaves while every node has exactly one
    dependency, and towards the root while every node has exactly one
    dependent. Fused tasks get a renamed key; an ``Alias`` preserves the
    original top key when the name changed.
    """
    from dask.core import reverse_dict
    from dask.optimization import default_fused_keys_renamer

    keys = set(keys)
    dependencies = DependenciesMapping(dsk)
    dependents = reverse_dict(dependencies)

    # Keys already assigned to a chain (or emitted verbatim); each key is
    # considered at most once.
    seen = set()
    result = {}

    for key in dsk:
        if key in seen:
            continue

        seen.add(key)

        deps = dependencies[key]
        dependents_key = dependents[key]

        # NOTE(review): precedence here is ``(A and B) or C`` -- the node is
        # emitted unfused if it has both multiple deps and multiple
        # dependents, or if it blocks fusion. Confirm this grouping is the
        # intended condition.
        if len(deps) != 1 and len(dependents_key) != 1 or dsk[key].block_fusion:
            result[key] = dsk[key]
            continue

        linear_chain = [dsk[key]]
        top_key = key

        # Walk towards the leafs as long as the nodes have a single dependency
        # and a single dependent, we can't fuse two nodes of an intermediate node
        # is the source for 2 dependents
        while len(deps) == 1:
            (new_key,) = deps
            if new_key in seen:
                break
            seen.add(new_key)
            if new_key not in dsk:
                # This can happen if a future is in the graph, the dependency mapping
                # adds the key that is referenced by the future as a dependency
                # see test_futures_to_delayed_array
                break
            if (
                len(dependents[new_key]) != 1
                or dsk[new_key].block_fusion
                or new_key in keys
            ):
                # Chain boundary: emit the node unfused and stop extending.
                result[new_key] = dsk[new_key]
                break
            # backwards comp for new names, temporary until is_rootish is removed
            linear_chain.insert(0, dsk[new_key])
            deps = dependencies[new_key]

        # Walk the tree towards the root as long as the nodes have a single dependent
        # and a single dependency, we can't fuse two nodes if node has multiple
        # dependencies
        while len(dependents_key) == 1 and top_key not in keys:
            new_key = dependents_key.pop()
            if new_key in seen:
                break
            seen.add(new_key)
            if len(dependencies[new_key]) != 1 or dsk[new_key].block_fusion:
                # Exit if the dependent has multiple dependencies, triangle
                result[new_key] = dsk[new_key]
                break
            linear_chain.append(dsk[new_key])
            top_key = new_key
            dependents_key = dependents[new_key]

        if len(linear_chain) == 1:
            # Nothing to fuse; keep the single task under its (top) key.
            result[top_key] = linear_chain[0]
        else:
            # Renaming the keys is necessary to preserve the rootish detection for now
            renamed_key = default_fused_keys_renamer([tsk.key for tsk in linear_chain])
            result[renamed_key] = Task.fuse(*linear_chain, key=renamed_key)
            if renamed_key != top_key:
                # Having the same prefixes can result in the same key, i.e. getitem-hash -> getitem-hash
                result[top_key] = Alias(top_key, target=renamed_key)
    return result

1186 

1187 

def cull(
    dsk: dict[KeyType, GraphNode], keys: Iterable[KeyType]
) -> dict[KeyType, GraphNode]:
    """Return the subgraph of ``dsk`` reachable from ``keys``.

    Performs a reverse traversal from the requested keys through each node's
    dependencies and keeps only the nodes encountered. Keys not present in
    ``dsk`` are silently ignored (e.g. external/future keys).

    Raises
    ------
    TypeError
        If ``keys`` is not a list, set or tuple.
    """
    if not isinstance(keys, (list, set, tuple)):
        raise TypeError(
            f"Expected list, set or tuple for keys, got {type(keys).__name__}"
        )
    # Fast path: every key is requested, so nothing can be trimmed.
    if len(keys) == len(dsk):
        return dsk
    pending = set(keys)
    visited: set[KeyType] = set()
    culled: dict[KeyType, GraphNode] = {}
    while pending:
        current = pending.pop()
        if current in visited or current not in dsk:
            continue
        visited.add(current)
        node = dsk[current]
        culled[current] = node
        pending.update(node.dependencies)
    return culled

1211 

1212 

1213@functools.cache 

1214def _extra_args(typ: type) -> set[str]: 

1215 import inspect 

1216 

1217 sig = inspect.signature(typ) 

1218 extras = set() 

1219 for name, param in sig.parameters.items(): 

1220 if param.kind in ( 

1221 inspect.Parameter.VAR_POSITIONAL, 

1222 inspect.Parameter.VAR_KEYWORD, 

1223 ): 

1224 continue 

1225 if name in typ.get_all_slots(): # type: ignore[attr-defined] 

1226 extras.add(name) 

1227 return extras