Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/_task_spec.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

661 statements  

1from __future__ import annotations 

2 

3""" Task specification for dask 

4 

5This module contains the task specification for dask. It is used to represent 

6runnable (task) and non-runnable (data) nodes in a dask graph. 

7 

8Simple examples of how to express tasks in dask 

9----------------------------------------------- 

10 

11.. code-block:: python 

12 

13 func("a", "b") ~ Task("key", func, "a", "b") 

14 

15 [func("a"), func("b")] ~ [Task("key-1", func, "a"), Task("key-2", func, "b")] 

16 

17 {"a": func("b")} ~ {"a": Task("a", func, "b")} 

18 

19 "literal-string" ~ DataNode("key", "literal-string") 

20 

21 

22Keys, Aliases and TaskRefs 

23------------------------- 

24 

25Keys are used to identify tasks in a dask graph. Every `GraphNode` instance has a 

26key attribute that _should_ reference the key in the dask graph. 

27 

28.. code-block:: python 

29 

30 {"key": Task("key", func, "a")} 

31 

32Referencing other tasks is possible by using either one of `Alias` or a 

33`TaskRef`. 

34 

35.. code-block:: python 

36 

37 # TaskRef can be used to provide the name of the reference explicitly 

38 t = Task("key", func, TaskRef("key-1")) 

39 

40 # If a task is still in scope, the method `ref` can be used for convenience 

41 t2 = Task("key2", func2, t.ref()) 

42 

43 

44Executing a task 

45---------------- 

46 

47A task can be executed by calling it with a dictionary of values. The values 

48should contain the dependencies of the task. 

49 

50.. code-block:: python 

51 

52 t = Task("key", add, TaskRef("a"), TaskRef("b")) 

53 assert t.dependencies == {"a", "b"} 

54 t({"a": 1, "b": 2}) == 3 

55 

56""" 

57import functools 

58import itertools 

59import sys 

60from collections import defaultdict 

61from collections.abc import Callable, Container, Iterable, Mapping, MutableMapping 

62from functools import lru_cache, partial 

63from typing import Any, TypeVar, cast 

64 

65from dask.sizeof import sizeof 

66from dask.typing import Key as KeyType 

67from dask.utils import funcname, is_namedtuple_instance 

68 

69_T = TypeVar("_T") 

70 

71 

72# Ported from more-itertools 

73# https://github.com/more-itertools/more-itertools/blob/c8153e2801ade2527f3a6c8b623afae93f5a1ce1/more_itertools/recipes.py#L944-L973 

74def _batched(iterable, n, *, strict=False): 

75 """Batch data into tuples of length *n*. If the number of items in 

76 *iterable* is not divisible by *n*: 

77 * The last batch will be shorter if *strict* is ``False``. 

78 * :exc:`ValueError` will be raised if *strict* is ``True``. 

79 

80 >>> list(batched('ABCDEFG', 3)) 

81 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)] 

82 

83 On Python 3.13 and above, this is an alias for :func:`itertools.batched`. 

84 """ 

85 if n < 1: 

86 raise ValueError("n must be at least one") 

87 it = iter(iterable) 

88 while batch := tuple(itertools.islice(it, n)): 

89 if strict and len(batch) != n: 

90 raise ValueError("batched(): incomplete batch") 

91 yield batch 

92 

93 

# Prefer the C-accelerated itertools.batched when it accepts ``strict``
# (added in Python 3.13; 0x30D00A2 is a 3.13 alpha hexversion). Older
# interpreters use the pure-Python port ``_batched`` above.
if sys.hexversion >= 0x30D00A2:

    def batched(iterable, n, *, strict=False):
        return itertools.batched(iterable, n, strict=strict)

else:
    batched = _batched

# Share the docstring so both implementations document identically.
batched.__doc__ = _batched.__doc__
# End port

104 

105 

def identity(*args):
    """Return all positional arguments unchanged, packed into a tuple."""
    return args

108 

109 

110def _identity_cast(*args, typ): 

111 return typ(args) 

112 

113 

# Monotonic counter used to build unique keys for anonymous nodes
# (see DataNode.__init__ when no key is supplied).
_anom_count = itertools.count()

115 

116 

def parse_input(obj: Any) -> object:
    """Tokenize user input into GraphNode objects.

    Note: This is similar to `convert_legacy_task` but does not
    - compare any values to a global set of known keys to infer references/futures
    - parse tuples and interprets them as runnable tasks
    - Deal with SubgraphCallables

    Parameters
    ----------
    obj : Any
        Arbitrary user input: GraphNode instances and TaskRefs are converted
        directly; dicts/lists/sets/tuples are recursed into and wrapped in
        the matching container node only if they contain GraphNodes.

    Returns
    -------
    object
        A GraphNode-based representation, or *obj* unchanged if it holds no
        task references.
    """
    if isinstance(obj, GraphNode):
        return obj

    if isinstance(obj, TaskRef):
        return Alias(obj.key)

    if isinstance(obj, dict):
        parsed_dict = {k: parse_input(v) for k, v in obj.items()}
        # Only wrap when something inside is runnable; a plain data dict
        # falls through and is returned as-is below.
        if any(isinstance(v, GraphNode) for v in parsed_dict.values()):
            return Dict(parsed_dict)

    if isinstance(obj, (list, set, tuple)):
        parsed_collection = tuple(parse_input(o) for o in obj)
        if any(isinstance(o, GraphNode) for o in parsed_collection):
            if isinstance(obj, list):
                return List(*parsed_collection)
            if isinstance(obj, set):
                return Set(*parsed_collection)
            if isinstance(obj, tuple):
                # Namedtuples need reconstruction via their own class.
                if is_namedtuple_instance(obj):
                    return _wrap_namedtuple_task(None, obj, parse_input)
                return Tuple(*parsed_collection)

    return obj

159 

160 

def _wrap_namedtuple_task(k, obj, parser):
    """Wrap a namedtuple instance *obj* into a Task keyed by *k*.

    Uses the pickle reconstruction protocol (``__getnewargs_ex__`` /
    ``__getnewargs__``) to obtain constructor arguments, runs *parser* over
    them, and produces a Task that rebuilds the namedtuple at execution time.
    """
    if hasattr(obj, "__getnewargs_ex__"):
        new_args, kwargs = obj.__getnewargs_ex__()
        kwargs = {k: parser(v) for k, v in kwargs.items()}
    elif hasattr(obj, "__getnewargs__"):
        new_args = obj.__getnewargs__()
        kwargs = {}
    # NOTE(review): if *obj* implements neither protocol, ``new_args`` is
    # unbound and the next line raises NameError — presumably namedtuple
    # instances always provide one of the two; confirm with callers.

    args_converted = parse_input(type(new_args)(map(parser, new_args)))

    return Task(
        k, partial(_instantiate_named_tuple, type(obj)), args_converted, Dict(kwargs)
    )

174 

175 

176def _instantiate_named_tuple(typ, args, kwargs): 

177 return typ(*args, **kwargs) 

178 

179 

180class _MultiContainer(Container): 

181 container: tuple 

182 __slots__ = ("container",) 

183 

184 def __init__(self, *container): 

185 self.container = container 

186 

187 def __contains__(self, o: object) -> bool: 

188 for c in self.container: 

189 if o in c: 

190 return True 

191 return False 

192 

193 

# Module-level placeholder; no concrete subgraph type is assigned here.
SubgraphType = None

195 

196 

def _execute_subgraph(inner_dsk, outkey, inkeys, *dependencies):
    """Run a fused subgraph and return the value of *outkey*.

    *inkeys* name the subgraph's external dependencies; *dependencies* are
    their already-computed values (zipped positionally) and are injected as
    DataNodes before execution.
    """
    final = {}
    final.update(inner_dsk)
    for k, v in zip(inkeys, dependencies):
        final[k] = DataNode(None, v)
    res = execute_graph(final, keys=[outkey])
    return res[outkey]

204 

205 

def convert_legacy_task(
    key: KeyType | None,
    task: _T,
    all_keys: Container,
) -> GraphNode | _T:
    """Convert a legacy dask task (tuple spec) into a GraphNode.

    Parameters
    ----------
    key : KeyType | None
        Key the resulting node should carry; None for nested positions.
    task : _T
        Legacy task: a ``(callable, *args)`` tuple, a key reference, a
        TaskRef, a (possibly nested) collection, or plain data.
    all_keys : Container
        Known graph keys; hashable values found in here become Aliases.

    Returns
    -------
    GraphNode | _T
        A GraphNode, or *task* unchanged when it is plain data.
    """
    if isinstance(task, GraphNode):
        return task

    # Legacy runnable task: tuple whose first element is callable.
    if type(task) is tuple and task and callable(task[0]):
        func, args = task[0], task[1:]
        new_args = []
        new: object
        for a in args:
            if isinstance(a, dict):
                new = Dict(a)
            else:
                new = convert_legacy_task(None, a, all_keys)
            new_args.append(new)
        return Task(key, func, *new_args)
    try:
        # Hashable scalars/tuples matching a known key become references.
        if isinstance(task, (int, float, str, tuple)):
            if task in all_keys:
                if key is None:
                    return Alias(task)
                else:
                    return Alias(key, target=task)
    except TypeError:
        # Unhashable
        pass

    if isinstance(task, (list, tuple, set, frozenset)):
        if is_namedtuple_instance(task):
            return _wrap_namedtuple_task(
                key,
                task,
                partial(
                    convert_legacy_task,
                    None,
                    all_keys=all_keys,
                ),
            )
        else:
            parsed_args = tuple(convert_legacy_task(None, t, all_keys) for t in task)
            if any(isinstance(a, GraphNode) for a in parsed_args):
                # Rebuild the original container type at runtime.
                return Task(key, _identity_cast, *parsed_args, typ=type(task))
            else:
                return cast(_T, type(task)(parsed_args))
    elif isinstance(task, TaskRef):
        if key is None:
            return Alias(task.key)
        else:
            return Alias(key, target=task.key)
    else:
        return task

260 

261 

def convert_legacy_graph(
    dsk: Mapping,
    all_keys: Container | None = None,
):
    """Convert a legacy dask graph mapping into one of GraphNode values.

    Self-referencing aliases (``k -> k``) are dropped; plain data values are
    wrapped in DataNode. *all_keys* defaults to the keys of *dsk*.
    """
    if all_keys is None:
        all_keys = set(dsk)
    new_dsk = {}
    for k, arg in dsk.items():
        t = convert_legacy_task(k, arg, all_keys)
        if isinstance(t, Alias) and t.target == k:
            # A key aliasing itself carries no information; skip it.
            continue
        elif not isinstance(t, GraphNode):
            t = DataNode(k, t)
        new_dsk[k] = t
    return new_dsk

277 

278 

def resolve_aliases(dsk: dict, keys: set, dependents: dict) -> dict:
    """Remove trivial sequential alias chains

    Example:

        dsk = {'x': 1, 'y': Alias('x'), 'z': Alias('y')}

        resolve_aliases(dsk, {'z'}, {'x': {'y'}, 'y': {'z'}}) == {'z': 1}

    """
    if not keys:
        raise ValueError("No keys provided")
    dsk = dict(dsk)
    work = list(keys)
    seen = set()
    while work:
        k = work.pop()
        if k in seen or k not in dsk:
            continue
        seen.add(k)
        t = dsk[k]
        if isinstance(t, Alias):
            target_key = t.target
            # Rules for when we allow to collapse an alias
            # 1. The target key is not in the keys set. The keys set is what the
            #    user is requesting and by collapsing we'd no longer be able to
            #    return that result.
            # 2. The target key is in fact part of dsk. If it isn't, this could
            #    point to a persisted dependency and we cannot collapse it.
            # 3. The target key has only one dependent which is the key we're
            #    currently looking at. This means that there is a one to one
            #    relation between this and the target key in which case we can
            #    collapse them.
            # Note: If target was an alias as well, we could continue with
            # more advanced optimizations but this isn't implemented, yet
            if (
                target_key not in keys
                and target_key in dsk
                # Note: whenever we're performing a collapse, we're not updating
                # the dependents. The length == 1 should still be sufficient for
                # chains of these aliases
                and len(dependents[target_key]) == 1
            ):
                # Collapse: the target node takes over this key.
                tnew = dsk.pop(target_key).copy()

                dsk[k] = tnew
                tnew.key = k
                if isinstance(tnew, Alias):
                    # Chain of aliases: revisit this key to keep collapsing.
                    work.append(k)
                    seen.discard(k)
                else:
                    work.extend(tnew.dependencies)

        work.extend(t.dependencies)
    return dsk

334 

335 

336class TaskRef: 

337 val: KeyType 

338 __slots__ = ("key",) 

339 

340 def __init__(self, key: KeyType): 

341 self.key = key 

342 

343 def __str__(self): 

344 return str(self.key) 

345 

346 def __repr__(self): 

347 return f"{type(self).__name__}({self.key!r})" 

348 

349 def __hash__(self) -> int: 

350 return hash(self.key) 

351 

352 def __eq__(self, value: object) -> bool: 

353 if not isinstance(value, TaskRef): 

354 return False 

355 return self.key == value.key 

356 

357 def __reduce__(self): 

358 return TaskRef, (self.key,) 

359 

360 def substitute(self, subs: dict, key: KeyType | None = None) -> TaskRef | GraphNode: 

361 if self.key in subs: 

362 val = subs[self.key] 

363 if isinstance(val, GraphNode): 

364 return val.substitute({}, key=self.key) 

365 elif isinstance(val, TaskRef): 

366 return val 

367 else: 

368 return TaskRef(val) 

369 return self 

370 

371 

class GraphNode:
    """Base class for all nodes in a dask graph.

    Subclasses: Alias (reference), DataNode (literal data), Task (runnable).
    ``__slots__`` is derived from the class annotations, so subclasses follow
    the same pattern and must not add stray annotations.
    """

    key: KeyType
    _dependencies: frozenset

    __slots__ = tuple(__annotations__)

    def ref(self):
        """Return an Alias pointing at this node's key."""
        return Alias(self.key)

    def copy(self):
        raise NotImplementedError

    @property
    def data_producer(self) -> bool:
        # Overridden by DataNode/Task where appropriate.
        return False

    @property
    def dependencies(self) -> frozenset:
        return self._dependencies

    @property
    def block_fusion(self) -> bool:
        # Nodes returning True are never merged by fuse_linear_task_spec.
        return False

    def _verify_values(self, values: tuple | dict) -> None:
        # Raise if the provided values do not cover all dependencies.
        if not self.dependencies:
            return
        if missing := set(self.dependencies) - set(values):
            raise RuntimeError(f"Not enough arguments provided: missing keys {missing}")

    def __call__(self, values) -> Any:
        raise NotImplementedError("Not implemented")

    def __eq__(self, value: object) -> bool:
        # Equality is token-based; only instances of the exact same class
        # can compare equal.
        if type(value) is not type(self):
            return False

        from dask.tokenize import tokenize

        return tokenize(self) == tokenize(value)

    @property
    def is_coro(self) -> bool:
        return False

    def __sizeof__(self) -> int:
        all_slots = self.get_all_slots()
        return sum(sizeof(getattr(self, sl)) for sl in all_slots) + sys.getsizeof(
            type(self)
        )

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        """Substitute a dependency with a new value. The new value either has to
        be a new valid key or a GraphNode to replace the dependency entirely.

        The GraphNode will not be mutated but instead a shallow copy will be
        returned. The substitution will be performed eagerly.

        Parameters
        ----------
        subs : dict[KeyType, KeyType | GraphNode]
            The mapping describing the substitutions to be made.
        key : KeyType | None, optional
            The key of the new GraphNode object. If None provided, the key of
            the old one will be reused.
        """
        raise NotImplementedError

    @staticmethod
    def fuse(*tasks: GraphNode, key: KeyType | None = None) -> GraphNode:
        """Fuse a set of tasks into a single task.

        The tasks are fused into a single task that will execute the tasks in a
        subgraph. The internal tasks are no longer accessible from the outside.

        All provided tasks must form a valid subgraph that will reduce to a
        single key. If multiple outputs are possible with the provided tasks, an
        exception will be raised.

        The tasks will not be rewritten but instead a new Task will be created
        that will merely reference the old task objects. This way, Task objects
        may be reused in multiple fused tasks.

        Parameters
        ----------
        key : KeyType | None, optional
            The key of the new Task object. If None provided, the key of the
            final task will be used.

        See also
        --------
        GraphNode.substitute : Easier substitution of dependencies
        """
        if any(t.key is None for t in tasks):
            raise ValueError("Cannot fuse tasks with missing keys")
        if len(tasks) == 1:
            return tasks[0].substitute({}, key=key)
        all_keys = set()
        all_deps: set[KeyType] = set()
        for t in tasks:
            all_deps.update(t.dependencies)
            all_keys.add(t.key)
        # Dependencies satisfied outside the fused group; sorted by hash for
        # a deterministic argument order.
        external_deps = tuple(sorted(all_deps - all_keys, key=hash))
        leafs = all_keys - all_deps
        if len(leafs) > 1:
            raise ValueError(f"Cannot fuse tasks with multiple outputs {leafs}")

        outkey = leafs.pop()
        return Task(
            key or outkey,
            _execute_subgraph,
            {t.key: t for t in tasks},
            outkey,
            external_deps,
            *(TaskRef(k) for k in external_deps),
            _data_producer=any(t.data_producer for t in tasks),
        )

    @classmethod
    @lru_cache
    def get_all_slots(cls):
        """Return the sorted union of ``__slots__`` across the MRO."""
        slots = list()
        for c in cls.mro():
            slots.extend(getattr(c, "__slots__", ()))
        # Interestingly, sorting this causes the nested containers to pickle
        # more efficiently
        return sorted(set(slots))

501 

502 

# Shared empty frozenset for dependency-free nodes; avoids one allocation
# per node.
_no_deps: frozenset = frozenset()

504 

505 

class Alias(GraphNode):
    """A node that resolves to the value of another key (``key -> target``).

    With no explicit target, it aliases its own key (useful as a reference).
    """

    target: KeyType
    __slots__ = tuple(__annotations__)

    def __init__(
        self, key: KeyType | TaskRef, target: Alias | TaskRef | KeyType | None = None
    ):
        if isinstance(key, TaskRef):
            key = key.key
        self.key = key
        if target is None:
            target = key
        # Flatten nested Alias/TaskRef targets down to a plain key.
        if isinstance(target, Alias):
            target = target.target
        if isinstance(target, TaskRef):
            target = target.key
        self.target = target
        self._dependencies = frozenset((self.target,))

    def __reduce__(self):
        return Alias, (self.key, self.target)

    def copy(self):
        return Alias(self.key, self.target)

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        """Apply *subs* to both our key and our target; see GraphNode.substitute."""
        if self.key in subs or self.target in subs:
            sub_key = subs.get(self.key, self.key)
            val = subs.get(self.target, self.target)
            if sub_key == self.key and val == self.target:
                return self
            if isinstance(val, (GraphNode, TaskRef)):
                return val.substitute({}, key=key)
            if key is None and isinstance(sub_key, GraphNode):
                raise RuntimeError(
                    f"Invalid substitution encountered {self.key!r} -> {sub_key}"
                )
            return Alias(key or sub_key, val)  # type: ignore
        return self

    def __dask_tokenize__(self):
        return (type(self).__name__, self.key, self.target)

    def __call__(self, values=()):
        # Resolving an alias is just a lookup of its target's value.
        self._verify_values(values)
        return values[self.target]

    def __repr__(self):
        if self.key != self.target:
            return f"Alias({self.key!r}->{self.target!r})"
        else:
            return f"Alias({self.key!r})"

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, Alias):
            return False
        if self.key != value.key:
            return False
        if self.target != value.target:
            return False
        return True

569 

570 

class DataNode(GraphNode):
    """A non-runnable node holding literal data.

    Calling it returns the stored value unchanged. With ``key=None`` a unique
    anonymous key is generated from the value's type name and a counter.
    """

    value: Any
    typ: type
    __slots__ = tuple(__annotations__)

    def __init__(self, key: Any, value: Any):
        if key is None:
            key = (type(value).__name__, next(_anom_count))
        self.key = key
        self.value = value
        self.typ = type(value)
        self._dependencies = _no_deps

    @property
    def data_producer(self) -> bool:
        return True

    def copy(self):
        return DataNode(self.key, self.value)

    def __call__(self, values=()):
        return self.value

    def __repr__(self):
        return f"DataNode({self.value!r})"

    def __reduce__(self):
        return (DataNode, (self.key, self.value))

    def __dask_tokenize__(self):
        from dask.base import tokenize

        return (type(self).__name__, tokenize(self.value))

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> DataNode:
        # Data has no dependencies; substitution can only rename the node.
        if key is not None and key != self.key:
            return DataNode(key, self.value)
        return self

    def __iter__(self):
        return iter(self.value)

614 

615 

def _get_dependencies(obj: object) -> set | frozenset:
    """Recursively collect task keys referenced by *obj*.

    TaskRefs contribute their key, GraphNodes their dependencies, and
    containers are traversed; anything else contributes nothing.
    """
    if isinstance(obj, TaskRef):
        return {obj.key}
    elif isinstance(obj, GraphNode):
        return obj.dependencies
    elif isinstance(obj, dict):
        if not obj:
            return _no_deps
        return set().union(*map(_get_dependencies, obj.values()))
    elif isinstance(obj, (list, tuple, frozenset, set)):
        if not obj:
            return _no_deps
        return set().union(*map(_get_dependencies, obj))
    return _no_deps

630 

631 

class Task(GraphNode):
    """A runnable node: ``func(*args, **kwargs)`` with TaskRef/GraphNode
    arguments resolved against a value mapping at call time.

    Dependencies are computed eagerly in ``__init__`` from TaskRef and
    GraphNode arguments. Token, repr and coroutine-ness are cached lazily.
    """

    func: Callable
    args: tuple
    kwargs: dict
    _data_producer: bool
    _token: str | None
    _is_coro: bool | None
    _repr: str | None

    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        key: Any,
        func: Callable,
        /,
        *args: Any,
        _data_producer: bool = False,
        **kwargs: Any,
    ):
        self.key = key
        self.func = func
        if isinstance(func, Task):
            raise TypeError("Cannot nest tasks")

        self.args = args
        self.kwargs = kwargs
        # Gather dependency keys from direct TaskRefs and nested GraphNodes.
        # Built lazily (None until first hit) to avoid a set allocation for
        # dependency-free tasks.
        _dependencies: set[KeyType] | None = None
        for a in itertools.chain(args, kwargs.values()):
            if isinstance(a, TaskRef):
                if _dependencies is None:
                    _dependencies = {a.key}
                else:
                    _dependencies.add(a.key)
            elif isinstance(a, GraphNode) and a.dependencies:
                if _dependencies is None:
                    _dependencies = set(a.dependencies)
                else:
                    _dependencies.update(a.dependencies)
        if _dependencies:
            self._dependencies = frozenset(_dependencies)
        else:
            self._dependencies = _no_deps
        self._is_coro = None
        self._token = None
        self._repr = None
        self._data_producer = _data_producer

    @property
    def data_producer(self) -> bool:
        return self._data_producer

    def has_subgraph(self) -> bool:
        # True for tasks produced by GraphNode.fuse.
        return self.func == _execute_subgraph

    def copy(self):
        return type(self)(
            self.key,
            self.func,
            *self.args,
            **self.kwargs,
        )

    def __hash__(self):
        return hash(self._get_token())

    def _get_token(self) -> str:
        # Token is computed once and cached on the instance.
        if self._token:
            return self._token
        from dask.base import tokenize

        self._token = tokenize(
            (
                type(self).__name__,
                self.func,
                self.args,
                self.kwargs,
            )
        )
        return self._token

    def __dask_tokenize__(self):
        return self._get_token()

    def __repr__(self) -> str:
        # When `Task` is deserialized the constructor will not run and
        # `self._repr` is thus undefined.
        if not hasattr(self, "_repr") or not self._repr:
            head = funcname(self.func)
            tail = ")"
            label_size = 40
            args = self.args
            kwargs = self.kwargs
            if args or kwargs:
                # Budget per argument; below ~5 chars we elide with "...".
                label_size2 = int(
                    (label_size - len(head) - len(tail) - len(str(self.key)))
                    // (len(args) + len(kwargs))
                )
            if args:
                if label_size2 > 5:
                    args_repr = ", ".join(repr(t) for t in args)
                else:
                    args_repr = "..."
            else:
                args_repr = ""
            if kwargs:
                if label_size2 > 5:
                    kwargs_repr = ", " + ", ".join(
                        f"{k}={repr(v)}" for k, v in sorted(kwargs.items())
                    )
                else:
                    kwargs_repr = ", ..."
            else:
                kwargs_repr = ""
            self._repr = f"<Task {self.key!r} {head}({args_repr}{kwargs_repr}{tail}>"
        return self._repr

    def __call__(self, values=()):
        """Execute the task given a mapping of dependency key -> value."""
        self._verify_values(values)

        def _eval(a):
            if isinstance(a, GraphNode):
                return a({k: values[k] for k in a.dependencies})
            elif isinstance(a, TaskRef):
                return values[a.key]
            else:
                return a

        new_argspec = tuple(map(_eval, self.args))
        if self.kwargs:
            kwargs = {k: _eval(kw) for k, kw in self.kwargs.items()}
            return self.func(*new_argspec, **kwargs)
        return self.func(*new_argspec)

    def __setstate__(self, state):
        slots = self.__class__.get_all_slots()
        for sl, val in zip(slots, state):
            setattr(self, sl, val)

    def __getstate__(self):
        slots = self.__class__.get_all_slots()
        return tuple(getattr(self, sl) for sl in slots)

    @property
    def is_coro(self):
        if self._is_coro is None:
            # Note: Can't use cached_property on objects without __dict__
            try:
                from distributed.utils import iscoroutinefunction

                self._is_coro = iscoroutinefunction(self.func)
            except Exception:
                self._is_coro = False
        return self._is_coro

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Task:
        """Replace dependencies per *subs*; see GraphNode.substitute."""
        # Only substitutions that actually touch our dependencies matter.
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        # Extra constructor arguments declared by subclasses (slots that are
        # also __init__ parameters) must be forwarded on reconstruction.
        extras = _extra_args(type(self))  # type: ignore
        extra_kwargs = {
            name: getattr(self, name) for name in extras if name not in {"key", "func"}
        }
        if subs_filtered:
            new_args = tuple(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
            new_kwargs = {
                k: (
                    v.substitute(subs_filtered)
                    if isinstance(v, (GraphNode, TaskRef))
                    else v
                )
                for k, v in self.kwargs.items()
            }
            return type(self)(
                key or self.key,
                self.func,
                *new_args,
                **new_kwargs,  # type: ignore[arg-type]
                **extra_kwargs,
            )
        elif key is None or key == self.key:
            return self
        else:
            # Rename
            return type(self)(
                key,
                self.func,
                *self.args,
                **self.kwargs,
                **extra_kwargs,
            )

832 

833 

class NestedContainer(Task, Iterable):
    """Base for runnable container nodes (List, Tuple, Set, Dict).

    A NestedContainer is a Task whose function rebuilds ``klass`` from its
    arguments; subclasses set ``constructor`` and ``klass``.
    """

    constructor: Callable
    klass: type
    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        /,
        *args: Any,
        **kwargs: Any,
    ):
        # Accept a single instance of the target container as shorthand for
        # its elements.
        if len(args) == 1 and isinstance(args[0], self.klass):
            args = args[0]  # type: ignore
        super().__init__(
            None,
            self.to_container,
            *args,
            constructor=self.constructor,
            **kwargs,
        )

    def __getstate__(self):
        state = super().__getstate__()
        state = list(state)
        slots = self.__class__.get_all_slots()
        ix = slots.index("kwargs")
        # The constructor as a kwarg is redundant since this is encoded in the
        # class itself. Serializing the builtin types is not trivial
        # This saves about 15% of overhead
        state[ix] = state[ix].copy()
        state[ix].pop("constructor", None)
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        # Restore the constructor dropped in __getstate__.
        self.kwargs["constructor"] = self.__class__.constructor
        return self

    def __repr__(self):
        return f"{type(self).__name__}({self.args})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> NestedContainer:
        """Replace dependencies per *subs*; see GraphNode.substitute."""
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        if not subs_filtered:
            return self
        return type(self)(
            *(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
        )

    def __dask_tokenize__(self):
        from dask.tokenize import tokenize

        # Fixed: the original had an unreachable
        # ``return super().__dask_tokenize__()`` after this return; removed.
        return (
            type(self).__name__,
            self.klass,
            sorted(tokenize(a) for a in self.args),
        )

    @staticmethod
    def to_container(*args, constructor):
        return constructor(args)

    def __iter__(self):
        yield from self.args

911 

912 

class List(NestedContainer):
    """NestedContainer that materializes as a built-in ``list``."""

    constructor = klass = list

915 

916 

class Tuple(NestedContainer):
    """NestedContainer that materializes as a built-in ``tuple``."""

    constructor = klass = tuple

919 

920 

class Set(NestedContainer):
    """NestedContainer that materializes as a built-in ``set``."""

    constructor = klass = set

923 

924 

class Dict(NestedContainer, Mapping):
    """NestedContainer that materializes as a built-in ``dict``.

    Internally stores arguments as a flat alternating sequence
    ``(k1, v1, k2, v2, ...)``. Accepts a dict, a sequence of 2-element
    pairs, an already-flat sequence, or keyword arguments.
    """

    klass = dict

    def __init__(self, /, *args: Any, **kwargs: Any):
        if args:
            assert not kwargs
            if len(args) == 1:
                args = args[0]
                if isinstance(args, dict):  # type: ignore
                    # Flatten {k: v, ...} -> (k, v, k, v, ...).
                    args = tuple(itertools.chain(*args.items()))  # type: ignore
                elif isinstance(args, (list, tuple)):
                    if all(
                        len(el) == 2 if isinstance(el, (list, tuple)) else False
                        for el in args
                    ):
                        # Sequence of (k, v) pairs.
                        args = tuple(itertools.chain(*args))
                    else:
                        raise ValueError("Invalid argument provided")

            if len(args) % 2 != 0:
                raise ValueError("Invalid number of arguments provided")

        elif kwargs:
            assert not args
            args = tuple(itertools.chain(*kwargs.items()))

        super().__init__(*args)

    def __repr__(self):
        values = ", ".join(f"{k}: {v}" for k, v in batched(self.args, 2, strict=True))
        return f"Dict({values})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Dict:
        """Replace dependencies per *subs*; see GraphNode.substitute."""
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        if not subs_filtered:
            return self

        new_args = []
        for arg in self.args:
            new_arg = (
                arg.substitute(subs_filtered)
                if isinstance(arg, (GraphNode, TaskRef))
                else arg
            )
            new_args.append(new_arg)
        return type(self)(new_args)

    def __iter__(self):
        # Keys occupy the even positions of the flat args sequence.
        yield from self.args[::2]

    def __len__(self):
        return len(self.args) // 2

    def __getitem__(self, key):
        # Linear scan over the flat (k, v) pairs.
        for k, v in batched(self.args, 2, strict=True):
            if k == key:
                return v
        raise KeyError(key)

    @staticmethod
    def constructor(args):
        return dict(batched(args, 2, strict=True))

991 

992 

class DependenciesMapping(MutableMapping):
    """Lazy, cached view of ``key -> dependency keys`` for a graph.

    Works for both GraphNode values (``.dependencies``) and legacy tasks
    (via ``dask.core.get_dependencies``). ``del`` marks a key as removed so
    it is filtered out of subsequently computed dependency sets; setting
    items is not supported.
    """

    def __init__(self, dsk):
        self.dsk = dsk
        self._removed = set()
        # Set a copy of dsk to avoid dct resizing
        self._cache = dsk.copy()
        self._cache.clear()

    def __getitem__(self, key):
        if (val := self._cache.get(key)) is not None:
            return val
        else:
            v = self.dsk[key]
            try:
                deps = v.dependencies
            except AttributeError:
                # Legacy (non-GraphNode) task value.
                from dask.core import get_dependencies

                deps = get_dependencies(self.dsk, task=v)

            if self._removed:
                # deps is a frozenset but for good measure, let's not use -= since
                # that _may_ perform an inplace mutation
                deps = deps - self._removed
            self._cache[key] = deps
            return deps

    def __iter__(self):
        return iter(self.dsk)

    def __delitem__(self, key: Any) -> None:
        # Invalidate everything: cached sets may contain the removed key.
        self._cache.clear()
        self._removed.add(key)

    def __setitem__(self, key: Any, value: Any) -> None:
        raise NotImplementedError

    def __len__(self) -> int:
        return len(self.dsk)

1032 

1033 

1034class _DevNullMapping(MutableMapping): 

1035 def __getitem__(self, key): 

1036 raise KeyError(key) 

1037 

1038 def __setitem__(self, key, value): 

1039 pass 

1040 

1041 def __delitem__(self, key): 

1042 pass 

1043 

1044 def __len__(self): 

1045 return 0 

1046 

1047 def __iter__(self): 

1048 return iter(()) 

1049 

1050 

def execute_graph(
    dsk: Iterable[GraphNode] | Mapping[KeyType, GraphNode],
    cache: MutableMapping[KeyType, object] | None = None,
    keys: Container[KeyType] | None = None,
) -> MutableMapping[KeyType, object]:
    """Execute a given graph.

    The graph is executed in topological order as defined by dask.order until
    all leaf nodes, i.e. nodes without any dependents, are reached. The returned
    dictionary contains the results of the leaf nodes.

    If keys are required that are not part of the graph, they can be provided in the `cache` argument.

    If `keys` is provided, the result will contain only values that are part of the `keys` set.

    """
    if isinstance(dsk, (list, tuple, set, frozenset)):
        dsk = {t.key: t for t in dsk}
    else:
        assert isinstance(dsk, dict)

    # Count how many nodes depend on each key so intermediate results can be
    # released as soon as their last dependent has run.
    refcount: defaultdict[KeyType, int] = defaultdict(int)
    for vals in DependenciesMapping(dsk).values():
        for val in vals:
            refcount[val] += 1

    cache = cache or {}
    from dask.order import order

    priorities = order(dsk)

    for key, node in sorted(dsk.items(), key=lambda it: priorities[it[0]]):
        cache[key] = node(cache)
        for dep in node.dependencies:
            refcount[dep] -= 1
            # Free intermediates unless explicitly requested via `keys`.
            if refcount[dep] == 0 and keys and dep not in keys:
                del cache[dep]

    return cache

1090 

1091 

def fuse_linear_task_spec(dsk, keys):
    """
    keys are the keys from the graph that are requested by a computation. We
    can't fuse those together.
    """
    from dask.core import reverse_dict
    from dask.optimization import default_fused_keys_renamer

    keys = set(keys)
    dependencies = DependenciesMapping(dsk)
    dependents = reverse_dict(dependencies)

    seen = set()
    result = {}

    for key in dsk:
        if key in seen:
            continue

        seen.add(key)

        deps = dependencies[key]
        dependents_key = dependents[key]

        # Nodes with fan-in and fan-out, or that explicitly block fusion,
        # are copied through unchanged.
        if len(deps) != 1 and len(dependents_key) != 1 or dsk[key].block_fusion:
            result[key] = dsk[key]
            continue

        linear_chain = [dsk[key]]
        top_key = key

        # Walk towards the leafs as long as the nodes have a single dependency
        # and a single dependent, we can't fuse two nodes of an intermediate node
        # is the source for 2 dependents
        while len(deps) == 1:
            (new_key,) = deps
            if new_key in seen:
                break
            seen.add(new_key)
            if new_key not in dsk:
                # This can happen if a future is in the graph, the dependency mapping
                # adds the key that is referenced by the future as a dependency
                # see test_futures_to_delayed_array
                break
            if (
                len(dependents[new_key]) != 1
                or dsk[new_key].block_fusion
                or new_key in keys
            ):
                result[new_key] = dsk[new_key]
                break
            # backwards comp for new names, temporary until is_rootish is removed
            linear_chain.insert(0, dsk[new_key])
            deps = dependencies[new_key]

        # Walk the tree towards the root as long as the nodes have a single dependent
        # and a single dependency, we can't fuse two nodes if node has multiple
        # dependencies
        while len(dependents_key) == 1 and top_key not in keys:
            new_key = dependents_key.pop()
            if new_key in seen:
                break
            seen.add(new_key)
            if len(dependencies[new_key]) != 1 or dsk[new_key].block_fusion:
                # Exit if the dependent has multiple dependencies, triangle
                result[new_key] = dsk[new_key]
                break
            linear_chain.append(dsk[new_key])
            top_key = new_key
            dependents_key = dependents[new_key]

        if len(linear_chain) == 1:
            result[top_key] = linear_chain[0]
        else:
            # Renaming the keys is necessary to preserve the rootish detection for now
            renamed_key = default_fused_keys_renamer([tsk.key for tsk in linear_chain])
            result[renamed_key] = Task.fuse(*linear_chain, key=renamed_key)
            if renamed_key != top_key:
                # Having the same prefixes can result in the same key, i.e. getitem-hash -> getitem-hash
                result[top_key] = Alias(top_key, target=renamed_key)
    return result

1173 

1174 

def cull(
    dsk: dict[KeyType, GraphNode], keys: Iterable[KeyType]
) -> dict[KeyType, GraphNode]:
    """Return the subgraph of *dsk* reachable from *keys*.

    If every key of *dsk* is requested, *dsk* itself is returned unchanged.
    """
    if not isinstance(keys, (list, set, tuple)):
        raise TypeError(
            f"Expected list, set or tuple for keys, got {type(keys).__name__}"
        )
    if len(keys) == len(dsk):
        # Nothing can be culled.
        return dsk
    pending = set(keys)
    visited: set[KeyType] = set()
    reachable = {}
    # Breadth-insensitive graph walk: follow dependencies from the requested
    # keys, skipping anything already handled or absent from the graph.
    while pending:
        current = pending.pop()
        if current in visited or current not in dsk:
            continue
        visited.add(current)
        node = dsk[current]
        reachable[current] = node
        pending.update(node.dependencies)
    return reachable

1198 

1199 

1200@functools.cache 

1201def _extra_args(typ: type) -> set[str]: 

1202 import inspect 

1203 

1204 sig = inspect.signature(typ) 

1205 extras = set() 

1206 for name, param in sig.parameters.items(): 

1207 if param.kind in ( 

1208 inspect.Parameter.VAR_POSITIONAL, 

1209 inspect.Parameter.VAR_KEYWORD, 

1210 ): 

1211 continue 

1212 if name in typ.get_all_slots(): # type: ignore 

1213 extras.add(name) 

1214 return extras