Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/_task_spec.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

658 statements  

1from __future__ import annotations 

2 

3""" Task specification for dask 

4 

5This module contains the task specification for dask. It is used to represent 

6runnable (task) and non-runnable (data) nodes in a dask graph. 

7 

8Simple examples of how to express tasks in dask 

9----------------------------------------------- 

10 

11.. code-block:: python 

12 

13 func("a", "b") ~ Task("key", func, "a", "b") 

14 

15 [func("a"), func("b")] ~ [Task("key-1", func, "a"), Task("key-2", func, "b")] 

16 

17 {"a": func("b")} ~ {"a": Task("a", func, "b")} 

18 

19 "literal-string" ~ DataNode("key", "literal-string") 

20 

21 

22Keys, Aliases and TaskRefs 

23------------------------- 

24 

25Keys are used to identify tasks in a dask graph. Every `GraphNode` instance has a 

26key attribute that _should_ reference the key in the dask graph. 

27 

28.. code-block:: python 

29 

30 {"key": Task("key", func, "a")} 

31 

32Referencing other tasks is possible by using either one of `Alias` or a 

33`TaskRef`. 

34 

35.. code-block:: python 

36 

37 # TaskRef can be used to provide the name of the reference explicitly 

38 t = Task("key", func, TaskRef("key-1")) 

39 

40 # If a task is still in scope, the method `ref` can be used for convenience 

41 t2 = Task("key2", func2, t.ref()) 

42 

43 

44Executing a task 

45---------------- 

46 

47A task can be executed by calling it with a dictionary of values. The values 

48should contain the dependencies of the task. 

49 

50.. code-block:: python 

51 

52 t = Task("key", add, TaskRef("a"), TaskRef("b")) 

53 assert t.dependencies == {"a", "b"} 

54 t({"a": 1, "b": 2}) == 3 

55 

56""" 

57import functools 

58import itertools 

59import sys 

60from collections import defaultdict 

61from collections.abc import Callable, Container, Iterable, Mapping, MutableMapping 

62from functools import lru_cache, partial 

63from typing import Any, TypeVar, cast 

64 

65from dask.sizeof import sizeof 

66from dask.typing import Key as KeyType 

67from dask.utils import funcname, is_namedtuple_instance 

68 

69_T = TypeVar("_T") 

70 

71 

72# Ported from more-itertools 

73# https://github.com/more-itertools/more-itertools/blob/c8153e2801ade2527f3a6c8b623afae93f5a1ce1/more_itertools/recipes.py#L944-L973 

74def _batched(iterable, n, *, strict=False): 

75 """Batch data into tuples of length *n*. If the number of items in 

76 *iterable* is not divisible by *n*: 

77 * The last batch will be shorter if *strict* is ``False``. 

78 * :exc:`ValueError` will be raised if *strict* is ``True``. 

79 

80 >>> list(batched('ABCDEFG', 3)) 

81 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)] 

82 

83 On Python 3.13 and above, this is an alias for :func:`itertools.batched`. 

84 """ 

85 if n < 1: 

86 raise ValueError("n must be at least one") 

87 it = iter(iterable) 

88 while batch := tuple(itertools.islice(it, n)): 

89 if strict and len(batch) != n: 

90 raise ValueError("batched(): incomplete batch") 

91 yield batch 

92 

93 

94if sys.hexversion >= 0x30D00A2: 

95 

96 def batched(iterable, n, *, strict=False): 

97 return itertools.batched(iterable, n, strict=strict) 

98 

99else: 

100 batched = _batched 

101 

102 batched.__doc__ = _batched.__doc__ 

103# End port 

104 

105 

def identity(*args):
    """Return all positional arguments, bundled as a tuple."""
    return args

108 

109 

110def _identity_cast(*args, typ): 

111 return typ(args) 

112 

113 

114_anom_count = itertools.count() 

115 

116 

def parse_input(obj: Any) -> object:
    """Tokenize user input into GraphNode objects

    Note: This is similar to `convert_legacy_task` but does not
    - compare any values to a global set of known keys to infer references/futures
    - parse tuples and interprets them as runnable tasks
    - Deal with SubgraphCallables

    Parameters
    ----------
    obj : Any
        Arbitrary user input. Dicts, lists, sets and tuples are recursed
        into; futures become Alias nodes; GraphNodes pass through.

    Returns
    -------
    object
        A GraphNode wrapper if *obj* (or anything nested inside it)
        references the graph, otherwise *obj* unchanged.
    """
    if isinstance(obj, GraphNode):
        return obj

    if _is_dask_future(obj):
        return Alias(obj.key)

    if isinstance(obj, dict):
        parsed_dict = {k: parse_input(v) for k, v in obj.items()}
        # Only wrap when a graph reference is actually present; otherwise the
        # original plain dict is returned unchanged below.
        if any(isinstance(v, GraphNode) for v in parsed_dict.values()):
            return Dict(parsed_dict)

    if isinstance(obj, (list, set, tuple)):
        parsed_collection = tuple(parse_input(o) for o in obj)
        if any(isinstance(o, GraphNode) for o in parsed_collection):
            if isinstance(obj, list):
                return List(*parsed_collection)
            if isinstance(obj, set):
                return Set(*parsed_collection)
            if isinstance(obj, tuple):
                # Namedtuples need reconstruction via their own type.
                if is_namedtuple_instance(obj):
                    return _wrap_namedtuple_task(None, obj, parse_input)
                return Tuple(*parsed_collection)

    return obj

159 

160 

def _wrap_namedtuple_task(k, obj, parser):
    """Convert a namedtuple instance *obj* into a Task that rebuilds it.

    Each field is passed through *parser* so nested graph references become
    GraphNodes; the resulting Task reconstructs the namedtuple at runtime.
    """
    if hasattr(obj, "__getnewargs_ex__"):
        new_args, kwargs = obj.__getnewargs_ex__()
        kwargs = {k: parser(v) for k, v in kwargs.items()}
    elif hasattr(obj, "__getnewargs__"):
        new_args = obj.__getnewargs__()
        kwargs = {}
    # NOTE(review): assumes obj implements one of the two pickle hooks above
    # (namedtuples always provide __getnewargs__); otherwise new_args/kwargs
    # would be unbound — confirm callers never pass other types.

    args_converted = parse_input(type(new_args)(map(parser, new_args)))

    return Task(
        k, partial(_instantiate_named_tuple, type(obj)), args_converted, Dict(kwargs)
    )

174 

175 

176def _instantiate_named_tuple(typ, args, kwargs): 

177 return typ(*args, **kwargs) 

178 

179 

180class _MultiContainer(Container): 

181 container: tuple 

182 __slots__ = ("container",) 

183 

184 def __init__(self, *container): 

185 self.container = container 

186 

187 def __contains__(self, o: object) -> bool: 

188 return any(o in c for c in self.container) 

189 

190 

191SubgraphType = None 

192 

193 

def _execute_subgraph(inner_dsk, outkey, inkeys, *dependencies):
    """Run a fused subgraph and return the value computed for ``outkey``.

    The externally supplied *dependencies* values are injected into the
    subgraph as DataNodes keyed by the corresponding *inkeys*.
    """
    graph = dict(inner_dsk)
    for name, value in zip(inkeys, dependencies):
        graph[name] = DataNode(None, value)
    results = execute_graph(graph, keys=[outkey])
    return results[outkey]

201 

202 

def convert_legacy_task(
    key: KeyType | None,
    task: _T,
    all_keys: Container,
) -> GraphNode | _T:
    """Convert a single legacy task definition into a GraphNode.

    Parameters
    ----------
    key : KeyType | None
        The key the task is stored under, or None for values nested inside
        another task.
    task : _T
        Legacy task: a ``(callable, *args)`` tuple, a key reference, a
        nested container, a future, or a literal value.
    all_keys : Container
        All keys of the graph; hashable values contained in it are treated
        as references to other tasks.
    """
    if isinstance(task, GraphNode):
        return task

    # Legacy runnable task: tuple whose first element is a callable.
    if type(task) is tuple and task and callable(task[0]):
        func, args = task[0], task[1:]
        new_args = []
        new: object
        for a in args:
            if isinstance(a, dict):
                new = Dict(a)
            else:
                new = convert_legacy_task(None, a, all_keys)
            new_args.append(new)
        return Task(key, func, *new_args)
    try:
        if isinstance(task, (int, float, str, tuple)):
            # A plain value equal to a known key is a reference to that key.
            if task in all_keys:
                if key is None:
                    return Alias(task)
                else:
                    return Alias(key, target=task)
    except TypeError:
        # Unhashable
        pass

    if isinstance(task, (list, tuple, set, frozenset)):
        if is_namedtuple_instance(task):
            return _wrap_namedtuple_task(
                key,
                task,
                partial(
                    convert_legacy_task,
                    None,
                    all_keys=all_keys,
                ),
            )
        else:
            parsed_args = tuple(convert_legacy_task(None, t, all_keys) for t in task)
            if any(isinstance(a, GraphNode) for a in parsed_args):
                # The container holds graph references; rebuild it at runtime.
                return Task(key, _identity_cast, *parsed_args, typ=type(task))
            else:
                return cast(_T, type(task)(parsed_args))
    elif _is_dask_future(task):
        if key is None:
            return Alias(task.key)  # type: ignore[attr-defined]
        else:
            return Alias(key, target=task.key)  # type: ignore[attr-defined]
    else:
        return task

257 

258 

def convert_legacy_graph(
    dsk: Mapping,
    all_keys: Container | None = None,
):
    """Convert a legacy (tuple-based) dask graph into GraphNode objects.

    Self-referencing aliases are dropped and raw literals are wrapped in
    DataNode instances.
    """
    if all_keys is None:
        all_keys = set(dsk)
    converted = {}
    for key, raw in dsk.items():
        node = convert_legacy_task(key, raw, all_keys)
        # A node that merely aliases itself carries no information; skip it.
        if isinstance(node, Alias) and node.target == key:
            continue
        if not isinstance(node, GraphNode):
            node = DataNode(key, node)
        converted[key] = node
    return converted

274 

275 

def resolve_aliases(dsk: dict, keys: set, dependents: dict) -> dict:
    """Remove trivial sequential alias chains

    Example:

        dsk = {'x': 1, 'y': Alias('x'), 'z': Alias('y')}

        resolve_aliases(dsk, {'z'}, {'x': {'y'}, 'y': {'z'}}) == {'z': 1}

    Parameters
    ----------
    dsk:
        The graph to optimize. Not mutated; a shallow copy is made.
    keys:
        Keys requested by the user; these must survive the optimization.
    dependents:
        Mapping from key to the set of keys depending on it.
    """
    if not keys:
        raise ValueError("No keys provided")
    dsk = dict(dsk)
    work = list(keys)
    seen = set()
    while work:
        k = work.pop()
        if k in seen or k not in dsk:
            continue
        seen.add(k)
        t = dsk[k]
        if isinstance(t, Alias):
            target_key = t.target
            # Rules for when we allow to collapse an alias
            # 1. The target key is not in the keys set. The keys set is what the
            #    user is requesting and by collapsing we'd no longer be able to
            #    return that result.
            # 2. The target key is in fact part of dsk. If it isn't this could
            #    point to a persisted dependency and we cannot collapse it.
            # 3. The target key has only one dependent which is the key we're
            #    currently looking at. This means that there is a one to one
            #    relation between this and the target key in which case we can
            #    collapse them.
            # Note: If target was an alias as well, we could continue with
            # more advanced optimizations but this isn't implemented, yet
            if (
                target_key not in keys
                and target_key in dsk
                # Note: whenever we're performing a collapse, we're not updating
                # the dependents. The length == 1 should still be sufficient for
                # chains of these aliases
                and len(dependents[target_key]) == 1
            ):
                # Replace the alias with (a renamed copy of) its target.
                tnew = dsk.pop(target_key).copy()

                dsk[k] = tnew
                tnew.key = k
                if isinstance(tnew, Alias):
                    # Collapsed into another alias: revisit k to keep
                    # chasing the chain.
                    work.append(k)
                    seen.discard(k)
                else:
                    work.extend(tnew.dependencies)

        work.extend(t.dependencies)
    return dsk

331 

332 

class TaskRef:
    """Lightweight, hashable reference to a task by its key.

    Used inside ``Task`` arguments to declare a dependency on another key
    without embedding the task object itself.
    """

    # Fixed annotation: the slot (and the attribute used throughout this
    # class) is ``key``; the class previously annotated a nonexistent
    # ``val`` attribute.
    key: KeyType
    __slots__ = ("key",)

    def __init__(self, key: KeyType):
        self.key = key

    def __str__(self):
        return str(self.key)

    def __repr__(self):
        return f"{type(self).__name__}({self.key!r})"

    def __hash__(self) -> int:
        # Hash equals the key's hash so refs interoperate with raw keys in
        # hash-based containers.
        return hash(self.key)

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, TaskRef):
            return False
        return self.key == value.key

    def __reduce__(self):
        return TaskRef, (self.key,)

    def substitute(self, subs: dict, key: KeyType | None = None) -> TaskRef | GraphNode:
        """Replace this reference according to *subs*.

        Returns the substituted GraphNode/TaskRef, a new TaskRef for a plain
        key substitution, or ``self`` when no substitution applies.
        """
        if self.key in subs:
            val = subs[self.key]
            if isinstance(val, GraphNode):
                return val.substitute({}, key=self.key)
            elif isinstance(val, TaskRef):
                return val
            else:
                return TaskRef(val)
        return self

367 

368 

def _is_dask_future(obj: object) -> bool:
    """Check if obj is a dask Future (TaskRef or duck-typed with __dask_future__).

    This supports both distributed.Future (which inherits from TaskRef) and
    third-party scheduler futures that set __dask_future__ = True.
    """
    if isinstance(obj, TaskRef):
        return True
    return getattr(obj, "__dask_future__", False)

376 

377 

class GraphNode:
    """Base class for all nodes in a dask graph (runnable and data nodes)."""

    # Key identifying this node in the graph.
    key: KeyType
    # Keys of all nodes this node depends on.
    _dependencies: frozenset

    __slots__ = tuple(__annotations__)

    def ref(self):
        """Return an Alias referencing this node's key."""
        return Alias(self.key)

    def copy(self):
        raise NotImplementedError

    @property
    def data_producer(self) -> bool:
        # Overridden by nodes that materialize data (e.g. DataNode).
        return False

    @property
    def dependencies(self) -> frozenset:
        return self._dependencies

    @property
    def block_fusion(self) -> bool:
        # When True, optimization passes must not fuse this node.
        return False

    def _verify_values(self, values: tuple | dict) -> None:
        # Ensure every dependency key is present in the provided mapping.
        if not self.dependencies:
            return
        if missing := set(self.dependencies) - set(values):
            raise RuntimeError(f"Not enough arguments provided: missing keys {missing}")

    def __call__(self, values) -> Any:
        raise NotImplementedError("Not implemented")

    def __eq__(self, value: object) -> bool:
        # Equality via tokenization; requires the exact same concrete type.
        if type(value) is not type(self):
            return False

        from dask.tokenize import tokenize

        return tokenize(self) == tokenize(value)

    @property
    def is_coro(self) -> bool:
        return False

    def __sizeof__(self) -> int:
        # Approximate size: sum of all slot values plus the type object.
        all_slots = self.get_all_slots()
        return sum(sizeof(getattr(self, sl)) for sl in all_slots) + sys.getsizeof(
            type(self)
        )

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        """Substitute a dependency with a new value. The new value either has to
        be a new valid key or a GraphNode to replace the dependency entirely.

        The GraphNode will not be mutated but instead a shallow copy will be
        returned. The substitution will be performed eagerly.

        Parameters
        ----------
        subs : dict[KeyType, KeyType | GraphNode]
            The mapping describing the substitutions to be made.
        key : KeyType | None, optional
            The key of the new GraphNode object. If None provided, the key of
            the old one will be reused.
        """
        raise NotImplementedError

    @staticmethod
    def fuse(*tasks: GraphNode, key: KeyType | None = None) -> GraphNode:
        """Fuse a set of tasks into a single task.

        The tasks are fused into a single task that will execute the tasks in a
        subgraph. The internal tasks are no longer accessible from the outside.

        All provided tasks must form a valid subgraph that will reduce to a
        single key. If multiple outputs are possible with the provided tasks, an
        exception will be raised.

        The tasks will not be rewritten but instead a new Task will be created
        that will merely reference the old task objects. This way, Task objects
        may be reused in multiple fused tasks.

        Parameters
        ----------
        key : KeyType | None, optional
            The key of the new Task object. If None provided, the key of the
            final task will be used.

        See also
        --------
        GraphNode.substitute : Easier substitution of dependencies
        """
        if any(t.key is None for t in tasks):
            raise ValueError("Cannot fuse tasks with missing keys")
        if len(tasks) == 1:
            return tasks[0].substitute({}, key=key)
        all_keys = set()
        all_deps: set[KeyType] = set()
        for t in tasks:
            all_deps.update(t.dependencies)
            all_keys.add(t.key)
        # Dependencies coming from outside the fused group become the new
        # task's external inputs (sorted by hash for determinism).
        external_deps = tuple(sorted(all_deps - all_keys, key=hash))
        # Leafs are keys nothing inside the group depends on; exactly one
        # must remain — the fused task's output.
        leafs = all_keys - all_deps
        if len(leafs) > 1:
            raise ValueError(f"Cannot fuse tasks with multiple outputs {leafs}")

        outkey = leafs.pop()
        return Task(
            key or outkey,
            _execute_subgraph,
            {t.key: t for t in tasks},
            outkey,
            external_deps,
            *(TaskRef(k) for k in external_deps),
            _data_producer=any(t.data_producer for t in tasks),
        )

    @classmethod
    @lru_cache
    def get_all_slots(cls):
        """Return the sorted union of ``__slots__`` across the full MRO."""
        slots = list()
        for c in cls.mro():
            slots.extend(getattr(c, "__slots__", ()))
        # Interestingly, sorting this causes the nested containers to pickle
        # more efficiently
        return sorted(set(slots))

507 

508 

509_no_deps: frozenset = frozenset() 

510 

511 

class Alias(GraphNode):
    """A node that forwards the value computed for another key."""

    # Key of the node whose value this alias forwards.
    target: KeyType
    __slots__ = tuple(__annotations__)

    def __init__(
        self, key: KeyType | TaskRef, target: Alias | TaskRef | KeyType | None = None
    ):
        if isinstance(key, TaskRef):
            key = key.key
        self.key = key
        if target is None:
            # Self-referencing alias: key and target coincide.
            target = key
        if isinstance(target, Alias):
            target = target.target
        if isinstance(target, TaskRef):
            target = target.key
        self.target = target
        self._dependencies = frozenset((self.target,))

    def __reduce__(self):
        return Alias, (self.key, self.target)

    def copy(self):
        return Alias(self.key, self.target)

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        """Apply *subs* to both key and target; see GraphNode.substitute."""
        if self.key in subs or self.target in subs:
            sub_key = subs.get(self.key, self.key)
            val = subs.get(self.target, self.target)
            if sub_key == self.key and val == self.target:
                # Substitutions are no-ops for this alias; keep it.
                return self
            if isinstance(val, (GraphNode, TaskRef)):
                # Target replaced by a node: the alias dissolves into it.
                return val.substitute({}, key=key)
            if key is None and isinstance(sub_key, GraphNode):
                raise RuntimeError(
                    f"Invalid substitution encountered {self.key!r} -> {sub_key}"
                )
            return Alias(key or sub_key, val)  # type: ignore [arg-type]
        return self

    def __dask_tokenize__(self):
        return (type(self).__name__, self.key, self.target)

    def __call__(self, values=()):
        self._verify_values(values)
        return values[self.target]

    def __repr__(self):
        if self.key != self.target:
            return f"Alias({self.key!r}->{self.target!r})"
        else:
            return f"Alias({self.key!r})"

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, Alias):
            return False
        if self.key != value.key:
            return False
        return self.target == value.target

573 

574 

class DataNode(GraphNode):
    """A non-runnable graph node wrapping a literal payload."""

    value: Any
    typ: type
    __slots__ = tuple(__annotations__)

    def __init__(self, key: Any, value: Any):
        # Anonymous data receives a generated, unique key.
        if key is None:
            key = (type(value).__name__, next(_anom_count))
        self.key = key
        self.value = value
        self.typ = type(value)
        self._dependencies = _no_deps

    @property
    def data_producer(self) -> bool:
        return True

    def copy(self):
        return DataNode(self.key, self.value)

    def __call__(self, values=()):
        # "Executing" a data node simply yields its payload.
        return self.value

    def __repr__(self):
        return f"DataNode({self.value!r})"

    def __reduce__(self):
        return (DataNode, (self.key, self.value))

    def __dask_tokenize__(self):
        from dask.base import tokenize

        return (type(self).__name__, tokenize(self.value))

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> DataNode:
        # There are no dependencies to substitute; only a rename applies.
        if key is None or key == self.key:
            return self
        return DataNode(key, self.value)

    def __iter__(self):
        return iter(self.value)

618 

619 

def _get_dependencies(obj: object) -> set | frozenset:
    """Recursively collect the set of keys that *obj* depends on.

    Futures contribute their key, GraphNodes their dependency set, and
    containers the union of their elements' dependencies.
    """
    if _is_dask_future(obj):
        # Fix: wrap the key in a set. Returning the bare key (previous
        # behavior) broke callers that union the results — e.g.
        # ``set().union(*map(_get_dependencies, ...))`` would iterate the
        # key itself (the characters of a str key) instead of treating it
        # as a single dependency.
        return {obj.key}  # type: ignore[attr-defined]
    elif isinstance(obj, GraphNode):
        return obj.dependencies
    elif isinstance(obj, dict):
        if not obj:
            return _no_deps
        return set().union(*map(_get_dependencies, obj.values()))
    elif isinstance(obj, (list, tuple, frozenset, set)):
        if not obj:
            return _no_deps
        return set().union(*map(_get_dependencies, obj))
    return _no_deps

634 

635 

class Task(GraphNode):
    """A runnable graph node: ``func(*args, **kwargs)`` with dependency tracking."""

    func: Callable
    args: tuple
    kwargs: dict
    _data_producer: bool
    _token: str | None  # cached tokenize() result
    _is_coro: bool | None  # cached coroutine-function check
    _repr: str | None  # cached repr string

    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        key: Any,
        func: Callable,
        /,
        *args: Any,
        _data_producer: bool = False,
        **kwargs: Any,
    ):
        self.key = key
        self.func = func
        if isinstance(func, Task):
            raise TypeError("Cannot nest tasks")

        self.args = args
        self.kwargs = kwargs
        # Collect dependency keys from TaskRef arguments and nested
        # GraphNodes; allocate the set lazily to keep no-dep tasks cheap.
        _dependencies: set[KeyType] | None = None
        for a in itertools.chain(args, kwargs.values()):
            if isinstance(a, TaskRef):
                if _dependencies is None:
                    _dependencies = {a.key}
                else:
                    _dependencies.add(a.key)
            elif isinstance(a, GraphNode) and a.dependencies:
                if _dependencies is None:
                    _dependencies = set(a.dependencies)
                else:
                    _dependencies.update(a.dependencies)
        if _dependencies:
            self._dependencies = frozenset(_dependencies)
        else:
            self._dependencies = _no_deps
        self._is_coro = None
        self._token = None
        self._repr = None
        self._data_producer = _data_producer

    @property
    def data_producer(self) -> bool:
        return self._data_producer

    def has_subgraph(self) -> bool:
        # Fused tasks (see GraphNode.fuse) carry an inner graph.
        return self.func == _execute_subgraph

    def copy(self):
        return type(self)(
            self.key,
            self.func,
            *self.args,
            **self.kwargs,
        )

    def __hash__(self):
        return hash(self._get_token())

    def _get_token(self) -> str:
        # Token is cached; key is deliberately excluded so renamed copies
        # of the same computation tokenize identically.
        if self._token:
            return self._token
        from dask.base import tokenize

        self._token = tokenize(
            (
                type(self).__name__,
                self.func,
                self.args,
                self.kwargs,
            )
        )
        return self._token

    def __dask_tokenize__(self):
        return self._get_token()

    def __repr__(self) -> str:
        # When `Task` is deserialized the constructor will not run and
        # `self._repr` is thus undefined.
        if not hasattr(self, "_repr") or not self._repr:
            head = funcname(self.func)
            tail = ")"
            label_size = 40
            args = self.args
            kwargs = self.kwargs
            if args or kwargs:
                # Budget per argument; below a threshold, elide with "...".
                label_size2 = int(
                    (label_size - len(head) - len(tail) - len(str(self.key)))
                    // (len(args) + len(kwargs))
                )
                if args:
                    if label_size2 > 5:
                        args_repr = ", ".join(repr(t) for t in args)
                    else:
                        args_repr = "..."
                else:
                    args_repr = ""
                if kwargs:
                    if label_size2 > 5:
                        kwargs_repr = ", " + ", ".join(
                            f"{k}={v!r}" for k, v in sorted(kwargs.items())
                        )
                    else:
                        kwargs_repr = ", ..."
                else:
                    kwargs_repr = ""
            self._repr = f"<Task {self.key!r} {head}({args_repr}{kwargs_repr}{tail}>"
        return self._repr

    def __call__(self, values=()):
        """Execute the task, reading dependency values from *values*."""
        self._verify_values(values)

        def _eval(a):
            if isinstance(a, GraphNode):
                # Nested nodes are executed with only their own deps visible.
                return a({k: values[k] for k in a.dependencies})
            elif isinstance(a, TaskRef):
                return values[a.key]
            else:
                return a

        new_argspec = tuple(map(_eval, self.args))
        if self.kwargs:
            kwargs = {k: _eval(kw) for k, kw in self.kwargs.items()}
            return self.func(*new_argspec, **kwargs)
        return self.func(*new_argspec)

    def __setstate__(self, state):
        # State is a tuple aligned with the sorted slot names.
        slots = self.__class__.get_all_slots()
        for sl, val in zip(slots, state):
            setattr(self, sl, val)

    def __getstate__(self):
        slots = self.__class__.get_all_slots()
        return tuple(getattr(self, sl) for sl in slots)

    @property
    def is_coro(self):
        if self._is_coro is None:
            # Note: Can't use cached_property on objects without __dict__
            try:
                from distributed.utils import iscoroutinefunction

                self._is_coro = iscoroutinefunction(self.func)
            except Exception:
                self._is_coro = False
        return self._is_coro

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Task:
        """Substitute dependencies in args/kwargs; see GraphNode.substitute."""
        # Restrict to substitutions that actually affect this task.
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        # Extra constructor parameters (subclass slots) to carry over.
        extras = _extra_args(type(self))  # type: ignore[arg-type]
        extra_kwargs = {
            name: getattr(self, name) for name in extras if name not in {"key", "func"}
        }
        if subs_filtered:
            new_args = tuple(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
            new_kwargs = {
                k: (
                    v.substitute(subs_filtered)
                    if isinstance(v, (GraphNode, TaskRef))
                    else v
                )
                for k, v in self.kwargs.items()
            }
            return type(self)(
                key or self.key,
                self.func,
                *new_args,
                **new_kwargs,  # type: ignore[arg-type]
                **extra_kwargs,
            )
        elif key is None or key == self.key:
            return self
        else:
            # Rename
            return type(self)(
                key,
                self.func,
                *self.args,
                **self.kwargs,
                **extra_kwargs,
            )

836 

837 

class NestedContainer(Task, Iterable):
    """Task that rebuilds a literal container whose elements may hold graph
    references.

    Subclasses set ``klass`` (the container type accepted/produced) and
    ``constructor`` (the callable assembling it from the flat args tuple).
    """

    constructor: Callable
    klass: type
    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        /,
        *args: Any,
        **kwargs: Any,
    ):
        # Accept the container itself as the single argument and splat it.
        if len(args) == 1 and isinstance(args[0], self.klass):
            args = args[0]  # type: ignore[assignment]
        super().__init__(
            None,
            self.to_container,
            *args,
            constructor=self.constructor,
            **kwargs,
        )

    def __getstate__(self):
        state = super().__getstate__()
        state = list(state)
        slots = self.__class__.get_all_slots()
        ix = slots.index("kwargs")
        # The constructor as a kwarg is redundant since this is encoded in the
        # class itself. Serializing the builtin types is not trivial
        # This saves about 15% of overhead
        state[ix] = state[ix].copy()
        state[ix].pop("constructor", None)
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        # Restore the constructor dropped in __getstate__.
        self.kwargs["constructor"] = self.__class__.constructor
        return self

    def __repr__(self):
        return f"{type(self).__name__}({self.args})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> NestedContainer:
        """Substitute dependencies in all elements; see GraphNode.substitute."""
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        if not subs_filtered:
            return self
        return type(self)(
            *(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
        )

    def __dask_tokenize__(self):
        from dask.tokenize import tokenize

        # Element order must not influence the token (e.g. sets have no
        # stable order), hence the sort.
        return (
            type(self).__name__,
            self.klass,
            sorted(tokenize(a) for a in self.args),
        )
        # Fix: removed an unreachable `return super().__dask_tokenize__()`
        # that followed the return above (dead code).

    @staticmethod
    def to_container(*args, constructor):
        return constructor(args)

    def __iter__(self):
        yield from self.args

915 

916 

class List(NestedContainer):
    """NestedContainer rebuilding a ``list`` at execution time."""

    constructor = klass = list

919 

920 

class Tuple(NestedContainer):
    """NestedContainer rebuilding a ``tuple`` at execution time."""

    constructor = klass = tuple

923 

924 

class Set(NestedContainer):
    """NestedContainer rebuilding a ``set`` at execution time."""

    constructor = klass = set

927 

928 

class Dict(NestedContainer, Mapping):
    """NestedContainer rebuilding a ``dict`` at execution time.

    Internally stores the mapping as a flat args tuple of alternating
    key/value pairs: ``(k0, v0, k1, v1, ...)``.
    """

    klass = dict

    def __init__(self, /, *args: Any, **kwargs: Any):
        if args:
            assert not kwargs
            if len(args) == 1:
                # Single argument: a dict or a sequence of 2-element pairs;
                # flatten into the alternating key/value layout.
                args = args[0]
                if isinstance(args, dict):  # type: ignore[unreachable]
                    args = tuple(itertools.chain(*args.items()))  # type: ignore[unreachable]
                elif isinstance(args, (list, tuple)):
                    if all(
                        len(el) == 2 if isinstance(el, (list, tuple)) else False
                        for el in args
                    ):
                        args = tuple(itertools.chain(*args))
                    else:
                        raise ValueError("Invalid argument provided")

            # Flat layout must pair up evenly.
            if len(args) % 2 != 0:
                raise ValueError("Invalid number of arguments provided")

        elif kwargs:
            assert not args
            args = tuple(itertools.chain(*kwargs.items()))

        super().__init__(*args)

    def __repr__(self):
        values = ", ".join(f"{k}: {v}" for k, v in batched(self.args, 2, strict=True))
        return f"Dict({values})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Dict:
        """Substitute dependencies in keys and values; see GraphNode.substitute."""
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        if not subs_filtered:
            return self

        new_args = []
        for arg in self.args:
            new_arg = (
                arg.substitute(subs_filtered)
                if isinstance(arg, (GraphNode, TaskRef))
                else arg
            )
            new_args.append(new_arg)
        return type(self)(new_args)

    def __iter__(self):
        # Keys occupy the even positions of the flat layout.
        yield from self.args[::2]

    def __len__(self):
        return len(self.args) // 2

    def __getitem__(self, key):
        for k, v in batched(self.args, 2, strict=True):
            if k == key:
                return v
        raise KeyError(key)

    @staticmethod
    def constructor(args):
        return dict(batched(args, 2, strict=True))

995 

996 

class DependenciesMapping(MutableMapping):
    """Lazily computed, cached mapping from key to its set of dependencies.

    Wraps a task graph; each node's dependency set is derived on first
    access. Deleting a key removes it from every subsequently computed
    dependency set and invalidates the cache.
    """

    def __init__(self, dsk):
        self.dsk = dsk
        self._removed = set()
        # Set a copy of dsk to avoid dct resizing
        self._cache = dsk.copy()
        self._cache.clear()

    def __getitem__(self, key):
        cached = self._cache.get(key)
        if cached is not None:
            return cached
        node = self.dsk[key]
        try:
            deps = node.dependencies
        except AttributeError:
            # Legacy (non-GraphNode) tasks: derive dependencies the old way.
            from dask.core import get_dependencies

            deps = get_dependencies(self.dsk, task=node)

        if self._removed:
            # deps is a frozenset but for good measure, let's not use -= since
            # that _may_ perform an inplace mutation
            deps = deps - self._removed
        self._cache[key] = deps
        return deps

    def __iter__(self):
        return iter(self.dsk)

    def __delitem__(self, key: Any) -> None:
        # Cached sets may still contain the removed key; recompute lazily.
        self._cache.clear()
        self._removed.add(key)

    def __setitem__(self, key: Any, value: Any) -> None:
        raise NotImplementedError

    def __len__(self) -> int:
        return len(self.dsk)

1036 

1037 

1038class _DevNullMapping(MutableMapping): 

1039 def __getitem__(self, key): 

1040 raise KeyError(key) 

1041 

1042 def __setitem__(self, key, value): 

1043 pass 

1044 

1045 def __delitem__(self, key): 

1046 pass 

1047 

1048 def __len__(self): 

1049 return 0 

1050 

1051 def __iter__(self): 

1052 return iter(()) 

1053 

1054 

def execute_graph(
    dsk: Iterable[GraphNode] | Mapping[KeyType, GraphNode],
    cache: MutableMapping[KeyType, object] | None = None,
    keys: Container[KeyType] | None = None,
) -> MutableMapping[KeyType, object]:
    """Execute a given graph.

    The graph is executed in topological order as defined by dask.order until
    all leaf nodes, i.e. nodes without any dependents, are reached. The returned
    dictionary contains the results of the leaf nodes.

    If keys are required that are not part of the graph, they can be provided in the `cache` argument.

    If `keys` is provided, the result will contain only values that are part of the `keys` set.

    """
    if isinstance(dsk, (list, tuple, set, frozenset)):
        dsk = {t.key: t for t in dsk}
    else:
        assert isinstance(dsk, dict)

    # Reference-count dependencies so intermediates can be freed eagerly.
    refcount: defaultdict[KeyType, int] = defaultdict(int)
    for vals in DependenciesMapping(dsk).values():
        for val in vals:
            refcount[val] += 1

    cache = cache or {}
    from dask.order import order

    priorities = order(dsk)

    # Execute in priority (topological) order; each node reads its inputs
    # from the shared cache mapping.
    for key, node in sorted(dsk.items(), key=lambda it: priorities[it[0]]):
        cache[key] = node(cache)
        for dep in node.dependencies:
            refcount[dep] -= 1
            # Free intermediates no longer needed and not explicitly requested.
            if refcount[dep] == 0 and keys and dep not in keys:
                del cache[dep]

    return cache

1094 

1095 

def fuse_linear_task_spec(dsk, keys):
    """Fuse linear chains of tasks into single fused tasks.

    keys are the keys from the graph that are requested by a computation. We
    can't fuse those together.

    Returns a new graph in which every maximal single-dependency /
    single-dependent chain is replaced by one fused task (plus an Alias from
    the chain's original head key when the fused key differs).
    """
    from dask.core import reverse_dict
    from dask.optimization import default_fused_keys_renamer

    keys = set(keys)
    dependencies = DependenciesMapping(dsk)
    dependents = reverse_dict(dependencies)

    seen = set()
    result = {}

    for key in dsk:
        if key in seen:
            continue

        seen.add(key)

        deps = dependencies[key]
        dependents_key = dependents[key]

        # Not part of any linear chain (or fusion is blocked): keep as-is.
        if len(deps) != 1 and len(dependents_key) != 1 or dsk[key].block_fusion:
            result[key] = dsk[key]
            continue

        linear_chain = [dsk[key]]
        top_key = key

        # Walk towards the leafs as long as the nodes have a single dependency
        # and a single dependent, we can't fuse two nodes of an intermediate node
        # is the source for 2 dependents
        while len(deps) == 1:
            (new_key,) = deps
            if new_key in seen:
                break
            seen.add(new_key)
            if new_key not in dsk:
                # This can happen if a future is in the graph, the dependency mapping
                # adds the key that is referenced by the future as a dependency
                # see test_futures_to_delayed_array
                break
            if (
                len(dependents[new_key]) != 1
                or dsk[new_key].block_fusion
                or new_key in keys
            ):
                result[new_key] = dsk[new_key]
                break
            # backwards comp for new names, temporary until is_rootish is removed
            linear_chain.insert(0, dsk[new_key])
            deps = dependencies[new_key]

        # Walk the tree towards the root as long as the nodes have a single dependent
        # and a single dependency, we can't fuse two nodes if node has multiple
        # dependencies
        while len(dependents_key) == 1 and top_key not in keys:
            new_key = dependents_key.pop()
            if new_key in seen:
                break
            seen.add(new_key)
            if len(dependencies[new_key]) != 1 or dsk[new_key].block_fusion:
                # Exit if the dependent has multiple dependencies, triangle
                result[new_key] = dsk[new_key]
                break
            linear_chain.append(dsk[new_key])
            top_key = new_key
            dependents_key = dependents[new_key]

        if len(linear_chain) == 1:
            result[top_key] = linear_chain[0]
        else:
            # Renaming the keys is necessary to preserve the rootish detection for now
            renamed_key = default_fused_keys_renamer([tsk.key for tsk in linear_chain])
            result[renamed_key] = Task.fuse(*linear_chain, key=renamed_key)
            if renamed_key != top_key:
                # Having the same prefixes can result in the same key, i.e. getitem-hash -> getitem-hash
                result[top_key] = Alias(top_key, target=renamed_key)
    return result

1177 

1178 

def cull(
    dsk: dict[KeyType, GraphNode], keys: Iterable[KeyType]
) -> dict[KeyType, GraphNode]:
    """Return the subgraph of *dsk* needed to compute *keys*.

    Walks the dependency closure starting from *keys*. If *keys* already
    covers the whole graph (by length), *dsk* is returned unchanged.
    """
    if not isinstance(keys, (list, set, tuple)):
        raise TypeError(
            f"Expected list, set or tuple for keys, got {type(keys).__name__}"
        )
    if len(keys) == len(dsk):
        return dsk
    pending = set(keys)
    visited: set[KeyType] = set()
    culled = {}
    while pending:
        current = pending.pop()
        # Skip already-processed keys and references outside the graph
        # (e.g. persisted/futures keys).
        if current in visited or current not in dsk:
            continue
        visited.add(current)
        node = dsk[current]
        culled[current] = node
        pending.update(node.dependencies)
    return culled

1202 

1203 

1204@functools.cache 

1205def _extra_args(typ: type) -> set[str]: 

1206 import inspect 

1207 

1208 sig = inspect.signature(typ) 

1209 extras = set() 

1210 for name, param in sig.parameters.items(): 

1211 if param.kind in ( 

1212 inspect.Parameter.VAR_POSITIONAL, 

1213 inspect.Parameter.VAR_KEYWORD, 

1214 ): 

1215 continue 

1216 if name in typ.get_all_slots(): # type: ignore[attr-defined] 

1217 extras.add(name) 

1218 return extras