Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/_task_spec.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

656 statements  

1from __future__ import annotations 

2 

3""" Task specification for dask 

4 

5This module contains the task specification for dask. It is used to represent 

6runnable (task) and non-runnable (data) nodes in a dask graph. 

7 

8Simple examples of how to express tasks in dask 

9----------------------------------------------- 

10 

11.. code-block:: python 

12 

13 func("a", "b") ~ Task("key", func, "a", "b") 

14 

15 [func("a"), func("b")] ~ [Task("key-1", func, "a"), Task("key-2", func, "b")] 

16 

17 {"a": func("b")} ~ {"a": Task("a", func, "b")} 

18 

19 "literal-string" ~ DataNode("key", "literal-string") 

20 

21 

22Keys, Aliases and TaskRefs 

23------------------------- 

24 

25Keys are used to identify tasks in a dask graph. Every `GraphNode` instance has a 

26key attribute that _should_ reference the key in the dask graph. 

27 

28.. code-block:: python 

29 

30 {"key": Task("key", func, "a")} 

31 

32Referencing other tasks is possible by using either one of `Alias` or a 

33`TaskRef`. 

34 

35.. code-block:: python 

36 

37 # TaskRef can be used to provide the name of the reference explicitly 

38 t = Task("key", func, TaskRef("key-1")) 

39 

40 # If a task is still in scope, the method `ref` can be used for convenience 

41 t2 = Task("key2", func2, t.ref()) 

42 

43 

44Executing a task 

45---------------- 

46 

47A task can be executed by calling it with a dictionary of values. The values 

48should contain the dependencies of the task. 

49 

50.. code-block:: python 

51 

52 t = Task("key", add, TaskRef("a"), TaskRef("b")) 

53 assert t.dependencies == {"a", "b"} 

54 t({"a": 1, "b": 2}) == 3 

55 

56""" 

57import functools 

58import itertools 

59import sys 

60from collections import defaultdict 

61from collections.abc import Callable, Container, Iterable, Mapping, MutableMapping 

62from functools import lru_cache, partial 

63from typing import Any, TypeVar, cast 

64 

65from dask.sizeof import sizeof 

66from dask.typing import Key as KeyType 

67from dask.utils import funcname, is_namedtuple_instance 

68 

69_T = TypeVar("_T") 

70 

71 

72# Ported from more-itertools 

73# https://github.com/more-itertools/more-itertools/blob/c8153e2801ade2527f3a6c8b623afae93f5a1ce1/more_itertools/recipes.py#L944-L973 

74def _batched(iterable, n, *, strict=False): 

75 """Batch data into tuples of length *n*. If the number of items in 

76 *iterable* is not divisible by *n*: 

77 * The last batch will be shorter if *strict* is ``False``. 

78 * :exc:`ValueError` will be raised if *strict* is ``True``. 

79 

80 >>> list(batched('ABCDEFG', 3)) 

81 [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)] 

82 

83 On Python 3.13 and above, this is an alias for :func:`itertools.batched`. 

84 """ 

85 if n < 1: 

86 raise ValueError("n must be at least one") 

87 it = iter(iterable) 

88 while batch := tuple(itertools.islice(it, n)): 

89 if strict and len(batch) != n: 

90 raise ValueError("batched(): incomplete batch") 

91 yield batch 

92 

93 

# 0x30D00A2 is the hexversion of CPython 3.13.0a2, where itertools.batched
# gained the ``strict`` keyword; on such interpreters delegate to the
# C implementation, otherwise fall back to the pure-Python port above.
if sys.hexversion >= 0x30D00A2:

    def batched(iterable, n, *, strict=False):
        return itertools.batched(iterable, n, strict=strict)

else:
    batched = _batched

    batched.__doc__ = _batched.__doc__
# End port

104 

105 

def identity(*args):
    """Return all positional arguments unchanged, bundled as a tuple."""
    return args

108 

109 

110def _identity_cast(*args, typ): 

111 return typ(args) 

112 

113 

# Monotonic counter used to synthesize unique keys for anonymous nodes
# (e.g. DataNode instances created without an explicit key).
_anom_count = itertools.count()

115 

116 

def parse_input(obj: Any) -> object:
    """Tokenize user input into GraphNode objects

    Note: This is similar to `convert_legacy_task` but does not
    - compare any values to a global set of known keys to infer references/futures
    - parse tuples and interprets them as runnable tasks
    - Deal with SubgraphCallables

    Parameters
    ----------
    obj : Any
        Arbitrary user input. Containers (dict/list/set/tuple) are walked
        recursively.

    Returns
    -------
    object
        ``obj`` unchanged when it contains no graph references; otherwise a
        GraphNode (Alias, Dict, List, Set, Tuple, or a namedtuple-wrapping
        Task) mirroring the input structure.
    """
    if isinstance(obj, GraphNode):
        return obj

    # A TaskRef is an explicit by-key reference; normalize it to an Alias.
    if isinstance(obj, TaskRef):
        return Alias(obj.key)

    if isinstance(obj, dict):
        parsed_dict = {k: parse_input(v) for k, v in obj.items()}
        # Only wrap when at least one value became a graph node; otherwise
        # fall through and return the original object untouched.
        if any(isinstance(v, GraphNode) for v in parsed_dict.values()):
            return Dict(parsed_dict)

    if isinstance(obj, (list, set, tuple)):
        parsed_collection = tuple(parse_input(o) for o in obj)
        if any(isinstance(o, GraphNode) for o in parsed_collection):
            if isinstance(obj, list):
                return List(*parsed_collection)
            if isinstance(obj, set):
                return Set(*parsed_collection)
            if isinstance(obj, tuple):
                # namedtuples must be rebuilt with their own constructor.
                if is_namedtuple_instance(obj):
                    return _wrap_namedtuple_task(None, obj, parse_input)
                return Tuple(*parsed_collection)

    return obj

159 

160 

def _wrap_namedtuple_task(k, obj, parser):
    # Wrap a namedtuple instance ``obj`` into a Task (keyed ``k``) that
    # rebuilds it at execution time after running ``parser`` over its fields.
    if hasattr(obj, "__getnewargs_ex__"):
        new_args, kwargs = obj.__getnewargs_ex__()
        kwargs = {k: parser(v) for k, v in kwargs.items()}
    elif hasattr(obj, "__getnewargs__"):
        new_args = obj.__getnewargs__()
        kwargs = {}

    # NOTE(review): if neither __getnewargs_ex__ nor __getnewargs__ exists,
    # ``new_args``/``kwargs`` are unbound here and this raises
    # UnboundLocalError — presumably namedtuples always define one of them;
    # confirm against callers.
    args_converted = parse_input(type(new_args)(map(parser, new_args)))

    return Task(
        k, partial(_instantiate_named_tuple, type(obj)), args_converted, Dict(kwargs)
    )

174 

175 

176def _instantiate_named_tuple(typ, args, kwargs): 

177 return typ(*args, **kwargs) 

178 

179 

180class _MultiContainer(Container): 

181 container: tuple 

182 __slots__ = ("container",) 

183 

184 def __init__(self, *container): 

185 self.container = container 

186 

187 def __contains__(self, o: object) -> bool: 

188 return any(o in c for c in self.container) 

189 

190 

# Placeholder; appears unused within this module — presumably kept for
# backwards compatibility with external importers. TODO confirm.
SubgraphType = None

192 

193 

def _execute_subgraph(inner_dsk, outkey, inkeys, *dependencies):
    """Run a fused subgraph and return the value computed for *outkey*.

    ``inkeys`` names the external dependencies of the subgraph and
    ``dependencies`` carries their already-computed values in matching order;
    each pair is injected as a DataNode before execution.
    """
    graph = dict(inner_dsk)
    for name, value in zip(inkeys, dependencies):
        graph[name] = DataNode(None, value)
    return execute_graph(graph, keys=[outkey])[outkey]

201 

202 

def convert_legacy_task(
    key: KeyType | None,
    task: _T,
    all_keys: Container,
) -> GraphNode | _T:
    """Convert a single legacy (tuple-based) task into a GraphNode.

    Parameters
    ----------
    key : KeyType | None
        Key the converted node will be stored under; ``None`` for values
        nested inside another task.
    task : Any
        Legacy task: a ``(callable, *args)`` tuple, a literal, a key
        reference, or a (possibly nested) container of these.
    all_keys : Container
        All keys of the graph. Hashable literals found in here are treated
        as references and become Aliases.
    """
    if isinstance(task, GraphNode):
        return task

    # Classic runnable dask task: non-empty tuple whose head is callable.
    if type(task) is tuple and task and callable(task[0]):
        func, args = task[0], task[1:]
        new_args = []
        new: object
        for a in args:
            if isinstance(a, dict):
                # Legacy semantics: a dict argument stays a dict, though its
                # values may themselves contain references.
                new = Dict(a)
            else:
                new = convert_legacy_task(None, a, all_keys)
            new_args.append(new)
        return Task(key, func, *new_args)
    try:
        # A hashable literal that matches a known key is a reference.
        if isinstance(task, (int, float, str, tuple)):
            if task in all_keys:
                if key is None:
                    return Alias(task)
                else:
                    return Alias(key, target=task)
    except TypeError:
        # Unhashable
        pass

    if isinstance(task, (list, tuple, set, frozenset)):
        if is_namedtuple_instance(task):
            # namedtuples must be rebuilt with their own constructor.
            return _wrap_namedtuple_task(
                key,
                task,
                partial(
                    convert_legacy_task,
                    None,
                    all_keys=all_keys,
                ),
            )
        else:
            parsed_args = tuple(convert_legacy_task(None, t, all_keys) for t in task)
            if any(isinstance(a, GraphNode) for a in parsed_args):
                # Container holds references; rebuild it at execution time.
                return Task(key, _identity_cast, *parsed_args, typ=type(task))
            else:
                # Pure literals: reconstruct the original container type.
                return cast(_T, type(task)(parsed_args))
    elif isinstance(task, TaskRef):
        if key is None:
            return Alias(task.key)
        else:
            return Alias(key, target=task.key)
    else:
        return task

257 

258 

def convert_legacy_graph(
    dsk: Mapping,
    all_keys: Container | None = None,
):
    """Convert a legacy dask graph into a dict of GraphNode objects.

    Parameters
    ----------
    dsk : Mapping
        The legacy graph (key -> task).
    all_keys : Container, optional
        Keys considered part of the graph; defaults to ``set(dsk)``.
    """
    if all_keys is None:
        all_keys = set(dsk)
    new_dsk = {}
    for k, arg in dsk.items():
        t = convert_legacy_task(k, arg, all_keys)
        # Drop self-referencing aliases (k -> k); they add nothing.
        if isinstance(t, Alias) and t.target == k:
            continue
        elif not isinstance(t, GraphNode):
            # Wrap plain literals so every graph value is a GraphNode.
            t = DataNode(k, t)
        new_dsk[k] = t
    return new_dsk

274 

275 

def resolve_aliases(dsk: dict, keys: set, dependents: dict) -> dict:
    """Remove trivial sequential alias chains

    Example:

        dsk = {'x': 1, 'y': Alias('x'), 'z': Alias('y')}

        resolve_aliases(dsk, {'z'}, {'x': {'y'}, 'y': {'z'}}) == {'z': 1}

    Parameters
    ----------
    dsk : dict
        Graph of key -> GraphNode.
    keys : set
        Keys requested by the user; these must survive collapsing.
    dependents : dict
        Reverse dependency mapping (key -> set of dependent keys).
    """
    if not keys:
        raise ValueError("No keys provided")
    dsk = dict(dsk)
    work = list(keys)
    seen = set()
    while work:
        k = work.pop()
        if k in seen or k not in dsk:
            continue
        seen.add(k)
        t = dsk[k]
        if isinstance(t, Alias):
            target_key = t.target
            # Rules for when we allow to collapse an alias
            # 1. The target key is not in the keys set. The keys set is what the
            #    user is requesting and by collapsing we'd no longer be able to
            #    return that result.
            # 2. The target key is in fact part of dsk. If it isn't this could
            #    point to a persisted dependency and we cannot collapse it.
            # 3. The target key has only one dependent which is the key we're
            #    currently looking at. This means that there is a one to one
            #    relation between this and the target key in which case we can
            #    collapse them.
            # Note: If target was an alias as well, we could continue with
            # more advanced optimizations but this isn't implemented, yet
            if (
                target_key not in keys
                and target_key in dsk
                # Note: whenever we're performing a collapse, we're not updating
                # the dependents. The length == 1 should still be sufficient for
                # chains of these aliases
                and len(dependents[target_key]) == 1
            ):
                # Pull the target node up under this key (copy so the original
                # node object is not mutated).
                tnew = dsk.pop(target_key).copy()

                dsk[k] = tnew
                tnew.key = k
                if isinstance(tnew, Alias):
                    # Collapsed into another alias: revisit this key to keep
                    # walking the chain.
                    work.append(k)
                    seen.discard(k)
                else:
                    work.extend(tnew.dependencies)

        work.extend(t.dependencies)
    return dsk

331 

332 

333class TaskRef: 

334 val: KeyType 

335 __slots__ = ("key",) 

336 

337 def __init__(self, key: KeyType): 

338 self.key = key 

339 

340 def __str__(self): 

341 return str(self.key) 

342 

343 def __repr__(self): 

344 return f"{type(self).__name__}({self.key!r})" 

345 

346 def __hash__(self) -> int: 

347 return hash(self.key) 

348 

349 def __eq__(self, value: object) -> bool: 

350 if not isinstance(value, TaskRef): 

351 return False 

352 return self.key == value.key 

353 

354 def __reduce__(self): 

355 return TaskRef, (self.key,) 

356 

357 def substitute(self, subs: dict, key: KeyType | None = None) -> TaskRef | GraphNode: 

358 if self.key in subs: 

359 val = subs[self.key] 

360 if isinstance(val, GraphNode): 

361 return val.substitute({}, key=self.key) 

362 elif isinstance(val, TaskRef): 

363 return val 

364 else: 

365 return TaskRef(val) 

366 return self 

367 

368 

class GraphNode:
    """Base class for all nodes in a dask task graph.

    Concrete subclasses in this module: ``Task`` (runnable), ``DataNode``
    (literal data), ``Alias`` (reference to another key) and
    ``NestedContainer`` (containers of nodes).
    """

    # Key under which this node is stored in the graph.
    key: KeyType
    # Keys of all nodes this node depends on.
    _dependencies: frozenset

    # Slots are derived from the annotations above; do not add annotations
    # without intending to add a slot.
    __slots__ = tuple(__annotations__)

    def ref(self):
        """Return an Alias pointing at this node's key."""
        return Alias(self.key)

    def copy(self):
        raise NotImplementedError

    @property
    def data_producer(self) -> bool:
        # Overridden by nodes that materialize data (e.g. DataNode).
        return False

    @property
    def dependencies(self) -> frozenset:
        return self._dependencies

    @property
    def block_fusion(self) -> bool:
        # When True, optimization passes must not fuse this node.
        return False

    def _verify_values(self, values: tuple | dict) -> None:
        # Raise if ``values`` does not cover every dependency of this node.
        if not self.dependencies:
            return
        if missing := set(self.dependencies) - set(values):
            raise RuntimeError(f"Not enough arguments provided: missing keys {missing}")

    def __call__(self, values) -> Any:
        raise NotImplementedError("Not implemented")

    def __eq__(self, value: object) -> bool:
        # Equality is token based: two nodes compare equal iff they are of
        # the same concrete type and tokenize identically.
        if type(value) is not type(self):
            return False

        from dask.tokenize import tokenize

        return tokenize(self) == tokenize(value)

    @property
    def is_coro(self) -> bool:
        return False

    def __sizeof__(self) -> int:
        # Sum the sizes of all slot values plus the type object itself.
        all_slots = self.get_all_slots()
        return sum(sizeof(getattr(self, sl)) for sl in all_slots) + sys.getsizeof(
            type(self)
        )

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        """Substitute a dependency with a new value. The new value either has to
        be a new valid key or a GraphNode to replace the dependency entirely.

        The GraphNode will not be mutated but instead a shallow copy will be
        returned. The substitution will be performed eagerly.

        Parameters
        ----------
        subs : dict[KeyType, KeyType | GraphNode]
            The mapping describing the substitutions to be made.
        key : KeyType | None, optional
            The key of the new GraphNode object. If None provided, the key of
            the old one will be reused.
        """
        raise NotImplementedError

    @staticmethod
    def fuse(*tasks: GraphNode, key: KeyType | None = None) -> GraphNode:
        """Fuse a set of tasks into a single task.

        The tasks are fused into a single task that will execute the tasks in a
        subgraph. The internal tasks are no longer accessible from the outside.

        All provided tasks must form a valid subgraph that will reduce to a
        single key. If multiple outputs are possible with the provided tasks, an
        exception will be raised.

        The tasks will not be rewritten but instead a new Task will be created
        that will merely reference the old task objects. This way, Task objects
        may be reused in multiple fused tasks.

        Parameters
        ----------
        key : KeyType | None, optional
            The key of the new Task object. If None provided, the key of the
            final task will be used.

        See also
        --------
        GraphNode.substitute : Easier substitution of dependencies
        """
        if any(t.key is None for t in tasks):
            raise ValueError("Cannot fuse tasks with missing keys")
        if len(tasks) == 1:
            return tasks[0].substitute({}, key=key)
        all_keys = set()
        all_deps: set[KeyType] = set()
        for t in tasks:
            all_deps.update(t.dependencies)
            all_keys.add(t.key)
        # Dependencies not produced inside the fused group must be supplied
        # from outside; sort by hash for a deterministic ordering.
        external_deps = tuple(sorted(all_deps - all_keys, key=hash))
        # Leafs are keys nothing inside the group depends on — the outputs.
        leafs = all_keys - all_deps
        if len(leafs) > 1:
            raise ValueError(f"Cannot fuse tasks with multiple outputs {leafs}")

        outkey = leafs.pop()
        return Task(
            key or outkey,
            _execute_subgraph,
            {t.key: t for t in tasks},
            outkey,
            external_deps,
            *(TaskRef(k) for k in external_deps),
            _data_producer=any(t.data_producer for t in tasks),
        )

    @classmethod
    @lru_cache
    def get_all_slots(cls):
        # Collect __slots__ across the full MRO (cached per class).
        slots = list()
        for c in cls.mro():
            slots.extend(getattr(c, "__slots__", ()))
        # Interestingly, sorting this causes the nested containers to pickle
        # more efficiently
        return sorted(set(slots))

498 

499 

# Shared empty dependency set; reused to avoid allocating one per node.
_no_deps: frozenset = frozenset()

501 

502 

class Alias(GraphNode):
    """A node that resolves to the value of another key at execution time."""

    # Key of the node this alias points at.
    target: KeyType
    __slots__ = tuple(__annotations__)

    def __init__(
        self, key: KeyType | TaskRef, target: Alias | TaskRef | KeyType | None = None
    ):
        if isinstance(key, TaskRef):
            key = key.key
        self.key = key
        if target is None:
            # No explicit target: the alias points at its own key.
            target = key
        # Unwrap Alias/TaskRef targets down to a plain key.
        if isinstance(target, Alias):
            target = target.target
        if isinstance(target, TaskRef):
            target = target.key
        self.target = target
        self._dependencies = frozenset((self.target,))

    def __reduce__(self):
        return Alias, (self.key, self.target)

    def copy(self):
        return Alias(self.key, self.target)

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> GraphNode:
        """Replace key/target according to *subs*; see GraphNode.substitute."""
        if self.key in subs or self.target in subs:
            sub_key = subs.get(self.key, self.key)
            val = subs.get(self.target, self.target)
            if sub_key == self.key and val == self.target:
                # The substitution is a no-op for this alias.
                return self
            if isinstance(val, (GraphNode, TaskRef)):
                # Target replaced by an actual node: collapse into it.
                return val.substitute({}, key=key)
            if key is None and isinstance(sub_key, GraphNode):
                raise RuntimeError(
                    f"Invalid substitution encountered {self.key!r} -> {sub_key}"
                )
            return Alias(key or sub_key, val)  # type: ignore
        return self

    def __dask_tokenize__(self):
        return (type(self).__name__, self.key, self.target)

    def __call__(self, values=()):
        # Resolving an alias is a plain lookup of its target's value.
        self._verify_values(values)
        return values[self.target]

    def __repr__(self):
        if self.key != self.target:
            return f"Alias({self.key!r}->{self.target!r})"
        else:
            return f"Alias({self.key!r})"

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, Alias):
            return False
        if self.key != value.key:
            return False
        return self.target == value.target

564 

565 

class DataNode(GraphNode):
    """A graph node wrapping a literal value (non-runnable data)."""

    # The wrapped literal value.
    value: Any
    # Type of ``value`` captured at construction time.
    typ: type
    __slots__ = tuple(__annotations__)

    def __init__(self, key: Any, value: Any):
        if key is None:
            # Anonymous data gets a unique synthetic key.
            key = (type(value).__name__, next(_anom_count))
        self.key = key
        self.value = value
        self.typ = type(value)
        self._dependencies = _no_deps

    @property
    def data_producer(self) -> bool:
        return True

    def copy(self):
        return DataNode(self.key, self.value)

    def __call__(self, values=()):
        # Executing a data node just returns the stored value.
        return self.value

    def __repr__(self):
        return f"DataNode({self.value!r})"

    def __reduce__(self):
        return (DataNode, (self.key, self.value))

    def __dask_tokenize__(self):
        from dask.base import tokenize

        # The token depends on the value only, not on the key.
        return (type(self).__name__, tokenize(self.value))

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> DataNode:
        """Data has no dependencies; substitution can only rename the node."""
        if key is not None and key != self.key:
            return DataNode(key, self.value)
        return self

    def __iter__(self):
        return iter(self.value)

609 

610 

def _get_dependencies(obj: object) -> set | frozenset:
    """Collect the task keys referenced (recursively) by *obj*."""
    if isinstance(obj, TaskRef):
        return {obj.key}
    if isinstance(obj, GraphNode):
        return obj.dependencies
    if isinstance(obj, dict):
        if not obj:
            return _no_deps
        return set().union(*map(_get_dependencies, obj.values()))
    if isinstance(obj, (list, tuple, frozenset, set)):
        if not obj:
            return _no_deps
        return set().union(*map(_get_dependencies, obj))
    # Plain literals carry no dependencies.
    return _no_deps

625 

626 

class Task(GraphNode):
    """A runnable graph node: ``func(*args, **kwargs)``.

    TaskRef/GraphNode values inside ``args``/``kwargs`` are resolved against
    the supplied values mapping when the task is called.
    """

    # Callable executed when the task runs.
    func: Callable
    # Positional arguments; may contain TaskRef/GraphNode placeholders.
    args: tuple
    # Keyword arguments; may contain TaskRef/GraphNode placeholders.
    kwargs: dict
    # True when this task materializes data (propagated through fuse()).
    _data_producer: bool
    # Lazily computed caches: token, coroutine flag and repr string.
    _token: str | None
    _is_coro: bool | None
    _repr: str | None

    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        key: Any,
        func: Callable,
        /,
        *args: Any,
        _data_producer: bool = False,
        **kwargs: Any,
    ):
        self.key = key
        self.func = func
        if isinstance(func, Task):
            raise TypeError("Cannot nest tasks")

        self.args = args
        self.kwargs = kwargs
        # Collect dependencies in one pass; the set is allocated lazily since
        # many tasks have no dependencies at all.
        _dependencies: set[KeyType] | None = None
        for a in itertools.chain(args, kwargs.values()):
            if isinstance(a, TaskRef):
                if _dependencies is None:
                    _dependencies = {a.key}
                else:
                    _dependencies.add(a.key)
            elif isinstance(a, GraphNode) and a.dependencies:
                if _dependencies is None:
                    _dependencies = set(a.dependencies)
                else:
                    _dependencies.update(a.dependencies)
        if _dependencies:
            self._dependencies = frozenset(_dependencies)
        else:
            self._dependencies = _no_deps
        self._is_coro = None
        self._token = None
        self._repr = None
        self._data_producer = _data_producer

    @property
    def data_producer(self) -> bool:
        return self._data_producer

    def has_subgraph(self) -> bool:
        # True when this task was produced by GraphNode.fuse.
        return self.func == _execute_subgraph

    def copy(self):
        return type(self)(
            self.key,
            self.func,
            *self.args,
            **self.kwargs,
        )

    def __hash__(self):
        return hash(self._get_token())

    def _get_token(self) -> str:
        # Cached: computing the token requires a full tokenize pass.
        if self._token:
            return self._token
        from dask.base import tokenize

        self._token = tokenize(
            (
                type(self).__name__,
                self.func,
                self.args,
                self.kwargs,
            )
        )
        return self._token

    def __dask_tokenize__(self):
        return self._get_token()

    def __repr__(self) -> str:
        # When `Task` is deserialized the constructor will not run and
        # `self._repr` is thus undefined.
        if not hasattr(self, "_repr") or not self._repr:
            head = funcname(self.func)
            tail = ")"
            label_size = 40
            args = self.args
            kwargs = self.kwargs
            if args or kwargs:
                # Character budget per argument; below 5 we elide to "..."
                label_size2 = int(
                    (label_size - len(head) - len(tail) - len(str(self.key)))
                    // (len(args) + len(kwargs))
                )
            if args:
                if label_size2 > 5:
                    args_repr = ", ".join(repr(t) for t in args)
                else:
                    args_repr = "..."
            else:
                args_repr = ""
            if kwargs:
                if label_size2 > 5:
                    kwargs_repr = ", " + ", ".join(
                        f"{k}={repr(v)}" for k, v in sorted(kwargs.items())
                    )
                else:
                    kwargs_repr = ", ..."
            else:
                kwargs_repr = ""
            self._repr = f"<Task {self.key!r} {head}({args_repr}{kwargs_repr}{tail}>"
        return self._repr

    def __call__(self, values=()):
        self._verify_values(values)

        def _eval(a):
            if isinstance(a, GraphNode):
                # Nested nodes are executed with just their own dependencies.
                return a({k: values[k] for k in a.dependencies})
            elif isinstance(a, TaskRef):
                return values[a.key]
            else:
                return a

        new_argspec = tuple(map(_eval, self.args))
        if self.kwargs:
            kwargs = {k: _eval(kw) for k, kw in self.kwargs.items()}
            return self.func(*new_argspec, **kwargs)
        return self.func(*new_argspec)

    def __setstate__(self, state):
        # State is a tuple ordered like get_all_slots(); see __getstate__.
        slots = self.__class__.get_all_slots()
        for sl, val in zip(slots, state):
            setattr(self, sl, val)

    def __getstate__(self):
        slots = self.__class__.get_all_slots()
        return tuple(getattr(self, sl) for sl in slots)

    @property
    def is_coro(self):
        if self._is_coro is None:
            # Note: Can't use cached_property on objects without __dict__
            try:
                from distributed.utils import iscoroutinefunction

                self._is_coro = iscoroutinefunction(self.func)
            except Exception:
                self._is_coro = False
        return self._is_coro

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Task:
        """Replace dependencies per *subs*; see GraphNode.substitute."""
        # Ignore substitutions that do not touch this task's dependencies.
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        # Subclass-specific constructor parameters (beyond key/func) that
        # must be carried over into the copy.
        extras = _extra_args(type(self))  # type: ignore
        extra_kwargs = {
            name: getattr(self, name) for name in extras if name not in {"key", "func"}
        }
        if subs_filtered:
            new_args = tuple(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
            new_kwargs = {
                k: (
                    v.substitute(subs_filtered)
                    if isinstance(v, (GraphNode, TaskRef))
                    else v
                )
                for k, v in self.kwargs.items()
            }
            return type(self)(
                key or self.key,
                self.func,
                *new_args,
                **new_kwargs,  # type: ignore[arg-type]
                **extra_kwargs,
            )
        elif key is None or key == self.key:
            return self
        else:
            # Rename
            return type(self)(
                key,
                self.func,
                *self.args,
                **self.kwargs,
                **extra_kwargs,
            )

827 

828 

class NestedContainer(Task, Iterable):
    """Base class for nodes modeling a container whose elements may include
    other graph nodes.

    Subclasses set ``klass`` (the Python container type being modeled) and
    ``constructor`` (the callable rebuilding it from a flat argument tuple).
    """

    constructor: Callable
    klass: type
    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        /,
        *args: Any,
        **kwargs: Any,
    ):
        if len(args) == 1 and isinstance(args[0], self.klass):
            # Accept a single pre-built container and splat its elements.
            args = args[0]  # type: ignore
        super().__init__(
            None,
            self.to_container,
            *args,
            constructor=self.constructor,
            **kwargs,
        )

    def __getstate__(self):
        state = super().__getstate__()
        state = list(state)
        slots = self.__class__.get_all_slots()
        ix = slots.index("kwargs")
        # The constructor as a kwarg is redundant since this is encoded in the
        # class itself. Serializing the builtin types is not trivial
        # This saves about 15% of overhead
        state[ix] = state[ix].copy()
        state[ix].pop("constructor", None)
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        # Restore the constructor dropped in __getstate__.
        self.kwargs["constructor"] = self.__class__.constructor
        return self

    def __repr__(self):
        return f"{type(self).__name__}({self.args})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> NestedContainer:
        """Replace dependencies per *subs*; see GraphNode.substitute."""
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        if not subs_filtered:
            return self
        return type(self)(
            *(
                (
                    a.substitute(subs_filtered)
                    if isinstance(a, (GraphNode, TaskRef))
                    else a
                )
                for a in self.args
            )
        )

    def __dask_tokenize__(self):
        from dask.tokenize import tokenize

        # Tokenize order-independently via sorted element tokens.
        # NOTE: removed an unreachable ``return super().__dask_tokenize__()``
        # that followed this return (dead code).
        return (
            type(self).__name__,
            self.klass,
            sorted(tokenize(a) for a in self.args),
        )

    @staticmethod
    def to_container(*args, constructor):
        return constructor(args)

    def __iter__(self):
        yield from self.args

907 

class List(NestedContainer):
    """NestedContainer materializing as a ``list``."""

    klass = list
    constructor = list

910 

911 

class Tuple(NestedContainer):
    """NestedContainer materializing as a ``tuple``."""

    klass = tuple
    constructor = tuple

914 

915 

class Set(NestedContainer):
    """NestedContainer materializing as a ``set``."""

    klass = set
    constructor = set

918 

919 

class Dict(NestedContainer, Mapping):
    """A graph node modeling a dict.

    Internally stored as a flat ``(key, value, key, value, ...)`` argument
    tuple so the generic NestedContainer machinery applies.
    """

    klass = dict

    def __init__(self, /, *args: Any, **kwargs: Any):
        if args:
            assert not kwargs
            if len(args) == 1:
                args = args[0]
                if isinstance(args, dict):  # type: ignore
                    # Flatten {"k": v, ...} into ("k", v, ...).
                    args = tuple(itertools.chain(*args.items()))  # type: ignore
                elif isinstance(args, (list, tuple)):
                    # Accept a sequence of (key, value) pairs.
                    if all(
                        len(el) == 2 if isinstance(el, (list, tuple)) else False
                        for el in args
                    ):
                        args = tuple(itertools.chain(*args))
                    else:
                        raise ValueError("Invalid argument provided")

            # Flat form must pair up evenly.
            if len(args) % 2 != 0:
                raise ValueError("Invalid number of arguments provided")

        elif kwargs:
            assert not args
            args = tuple(itertools.chain(*kwargs.items()))

        super().__init__(*args)

    def __repr__(self):
        values = ", ".join(f"{k}: {v}" for k, v in batched(self.args, 2, strict=True))
        return f"Dict({values})"

    def substitute(
        self, subs: dict[KeyType, KeyType | GraphNode], key: KeyType | None = None
    ) -> Dict:
        """Replace dependencies per *subs*; see GraphNode.substitute."""
        subs_filtered = {
            k: v for k, v in subs.items() if k in self.dependencies and k != v
        }
        if not subs_filtered:
            return self

        new_args = []
        for arg in self.args:
            new_arg = (
                arg.substitute(subs_filtered)
                if isinstance(arg, (GraphNode, TaskRef))
                else arg
            )
            new_args.append(new_arg)
        return type(self)(new_args)

    def __iter__(self):
        # Keys live at the even positions of the flat args tuple.
        yield from self.args[::2]

    def __len__(self):
        return len(self.args) // 2

    def __getitem__(self, key):
        # Linear scan over the (key, value) pairs.
        for k, v in batched(self.args, 2, strict=True):
            if k == key:
                return v
        raise KeyError(key)

    @staticmethod
    def constructor(args):
        return dict(batched(args, 2, strict=True))

986 

987 

class DependenciesMapping(MutableMapping):
    """Lazy mapping of key -> dependency set for a (possibly legacy) graph.

    Computed dependency sets are cached. Deleting a key hides it from every
    subsequently computed dependency set; writes are not supported.
    """

    def __init__(self, dsk):
        self.dsk = dsk
        # Keys removed via __delitem__; subtracted from computed results.
        self._removed = set()
        # Set a copy of dsk to avoid dct resizing
        self._cache = dsk.copy()
        self._cache.clear()

    def __getitem__(self, key):
        if (val := self._cache.get(key)) is not None:
            return val
        else:
            v = self.dsk[key]
            try:
                deps = v.dependencies
            except AttributeError:
                # Legacy (non-GraphNode) value: fall back to graph inspection.
                from dask.core import get_dependencies

                deps = get_dependencies(self.dsk, task=v)

            if self._removed:
                # deps is a frozenset but for good measure, let's not use -= since
                # that _may_ perform an inplace mutation
                deps = deps - self._removed
            self._cache[key] = deps
            return deps

    def __iter__(self):
        return iter(self.dsk)

    def __delitem__(self, key: Any) -> None:
        # Invalidate everything: previously cached sets may contain ``key``.
        self._cache.clear()
        self._removed.add(key)

    def __setitem__(self, key: Any, value: Any) -> None:
        raise NotImplementedError

    def __len__(self) -> int:
        return len(self.dsk)

1027 

1028 

1029class _DevNullMapping(MutableMapping): 

1030 def __getitem__(self, key): 

1031 raise KeyError(key) 

1032 

1033 def __setitem__(self, key, value): 

1034 pass 

1035 

1036 def __delitem__(self, key): 

1037 pass 

1038 

1039 def __len__(self): 

1040 return 0 

1041 

1042 def __iter__(self): 

1043 return iter(()) 

1044 

1045 

def execute_graph(
    dsk: Iterable[GraphNode] | Mapping[KeyType, GraphNode],
    cache: MutableMapping[KeyType, object] | None = None,
    keys: Container[KeyType] | None = None,
) -> MutableMapping[KeyType, object]:
    """Execute a given graph.

    The graph is executed in topological order as defined by dask.order until
    all leaf nodes, i.e. nodes without any dependents, are reached. The returned
    dictionary contains the results of the leaf nodes.

    If keys are required that are not part of the graph, they can be provided in the `cache` argument.

    If `keys` is provided, the result will contain only values that are part of the `keys` set.

    """
    # Normalize an iterable of nodes into a key -> node mapping.
    if isinstance(dsk, (list, tuple, set, frozenset)):
        dsk = {t.key: t for t in dsk}
    else:
        assert isinstance(dsk, dict)

    # Reference counting: how many pending tasks still need each key. When a
    # count hits zero, the value can be evicted (unless requested via keys).
    refcount: defaultdict[KeyType, int] = defaultdict(int)
    for vals in DependenciesMapping(dsk).values():
        for val in vals:
            refcount[val] += 1

    cache = cache or {}
    from dask.order import order

    priorities = order(dsk)

    # Run nodes in priority order; every node reads its inputs from cache.
    for key, node in sorted(dsk.items(), key=lambda it: priorities[it[0]]):
        cache[key] = node(cache)
        for dep in node.dependencies:
            refcount[dep] -= 1
            if refcount[dep] == 0 and keys and dep not in keys:
                del cache[dep]

    return cache

1085 

1086 

def fuse_linear_task_spec(dsk, keys):
    """Fuse linear chains of tasks in *dsk* into single subgraph tasks.

    keys are the keys from the graph that are requested by a computation. We
    can't fuse those together.

    A chain is a run of nodes where each link has exactly one dependency and
    one dependent; the chain is replaced by one fused task (plus an Alias if
    the fused key differs from the chain's top key).
    """
    from dask.core import reverse_dict
    from dask.optimization import default_fused_keys_renamer

    keys = set(keys)
    dependencies = DependenciesMapping(dsk)
    dependents = reverse_dict(dependencies)

    seen = set()
    result = {}

    for key in dsk:
        if key in seen:
            continue

        seen.add(key)

        deps = dependencies[key]
        dependents_key = dependents[key]

        # Nodes branching on both sides, or fusion-blocked, stay unchanged.
        if len(deps) != 1 and len(dependents_key) != 1 or dsk[key].block_fusion:
            result[key] = dsk[key]
            continue

        linear_chain = [dsk[key]]
        top_key = key

        # Walk towards the leafs as long as the nodes have a single dependency
        # and a single dependent, we can't fuse two nodes of an intermediate node
        # is the source for 2 dependents
        while len(deps) == 1:
            (new_key,) = deps
            if new_key in seen:
                break
            seen.add(new_key)
            if new_key not in dsk:
                # This can happen if a future is in the graph, the dependency mapping
                # adds the key that is referenced by the future as a dependency
                # see test_futures_to_delayed_array
                break
            if (
                len(dependents[new_key]) != 1
                or dsk[new_key].block_fusion
                or new_key in keys
            ):
                result[new_key] = dsk[new_key]
                break
            # backwards comp for new names, temporary until is_rootish is removed
            linear_chain.insert(0, dsk[new_key])
            deps = dependencies[new_key]

        # Walk the tree towards the root as long as the nodes have a single dependent
        # and a single dependency, we can't fuse two nodes if node has multiple
        # dependencies
        while len(dependents_key) == 1 and top_key not in keys:
            new_key = dependents_key.pop()
            if new_key in seen:
                break
            seen.add(new_key)
            if len(dependencies[new_key]) != 1 or dsk[new_key].block_fusion:
                # Exit if the dependent has multiple dependencies, triangle
                result[new_key] = dsk[new_key]
                break
            linear_chain.append(dsk[new_key])
            top_key = new_key
            dependents_key = dependents[new_key]

        if len(linear_chain) == 1:
            result[top_key] = linear_chain[0]
        else:
            # Renaming the keys is necessary to preserve the rootish detection for now
            renamed_key = default_fused_keys_renamer([tsk.key for tsk in linear_chain])
            result[renamed_key] = Task.fuse(*linear_chain, key=renamed_key)
            if renamed_key != top_key:
                # Having the same prefixes can result in the same key, i.e. getitem-hash -> getitem-hash
                result[top_key] = Alias(top_key, target=renamed_key)
    return result

1168 

1169 

def cull(
    dsk: dict[KeyType, GraphNode], keys: Iterable[KeyType]
) -> dict[KeyType, GraphNode]:
    """Return the subgraph of *dsk* reachable from *keys*.

    Performs a breadth-less worklist traversal over dependencies; keys not
    present in *dsk* are skipped silently (e.g. externally persisted data).
    """
    if not isinstance(keys, (list, set, tuple)):
        raise TypeError(
            f"Expected list, set or tuple for keys, got {type(keys).__name__}"
        )
    # Fast path: as many keys requested as there are nodes — nothing to prune.
    if len(keys) == len(dsk):
        return dsk
    pending = set(keys)
    visited: set[KeyType] = set()
    reachable = {}
    while pending:
        current = pending.pop()
        if current in visited or current not in dsk:
            continue
        visited.add(current)
        node = dsk[current]
        reachable[current] = node
        pending.update(node.dependencies)
    return reachable

1193 

1194 

1195@functools.cache 

1196def _extra_args(typ: type) -> set[str]: 

1197 import inspect 

1198 

1199 sig = inspect.signature(typ) 

1200 extras = set() 

1201 for name, param in sig.parameters.items(): 

1202 if param.kind in ( 

1203 inspect.Parameter.VAR_POSITIONAL, 

1204 inspect.Parameter.VAR_KEYWORD, 

1205 ): 

1206 continue 

1207 if name in typ.get_all_slots(): # type: ignore 

1208 extras.add(name) 

1209 return extras