Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/core.py: 32%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

222 statements  

1from __future__ import annotations 

2 

3from collections import defaultdict 

4from collections.abc import Collection, Iterable, Mapping, MutableMapping 

5from typing import Any, Literal, TypeVar, cast, overload 

6 

7import toolz 

8 

9from dask._task_spec import ( 

10 DependenciesMapping, 

11 TaskRef, 

12 convert_legacy_graph, 

13 execute_graph, 

14) 

15from dask.typing import Graph, Key, NoDefault, no_default 

16 

17 

def ishashable(x):
    """Report whether ``x`` can be passed to :func:`hash`.

    Parameters
    ----------
    x : object
        Any value.

    Returns
    -------
    bool
        ``True`` when ``hash(x)`` succeeds, ``False`` when it raises
        ``TypeError``. Other exceptions raised by ``__hash__`` propagate.

    Examples
    --------

    >>> ishashable(1)
    True
    >>> ishashable([1])
    False

    See Also
    --------
    iskey
    """
    try:
        hash(x)
    except TypeError:
        return False
    else:
        return True

38 

39 

def istask(x):
    """Is x a runnable task?

    A task is either a ``GraphNode`` from ``dask._task_spec`` that is not a
    ``DataNode``, or a legacy plain ``tuple`` whose first element is callable.

    Returns
    -------
    bool
        Always a real ``bool``. In particular an empty tuple yields ``False``
        rather than leaking the falsy ``()`` object back to the caller.

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> istask((inc, 1))
    True
    >>> istask(1)
    False
    >>> istask(())
    False
    """
    from dask._task_spec import DataNode, GraphNode

    if isinstance(x, GraphNode):
        # DataNode wraps a literal value; every other GraphNode is runnable.
        return not isinstance(x, DataNode)
    # Exact ``type(...) is tuple`` on purpose: tuple subclasses (e.g.
    # namedtuples) are not treated as legacy tasks. ``bool(x)`` guards the
    # ``x[0]`` access and keeps the return value a bool.
    return type(x) is tuple and bool(x) and callable(x[0])

61 

62 

def preorder_traversal(task):
    """Lazily walk ``task`` in pre-order.

    Every element of ``task`` is visited in order, including a task's
    leading callable:

    * nested tasks are descended into transparently;
    * a nested ``list`` is announced by yielding the ``list`` type itself
      as a marker, followed by the traversal of its contents;
    * any other element is yielded unchanged.
    """
    for element in task:
        if istask(element):
            yield from preorder_traversal(element)
            continue
        if isinstance(element, list):
            yield list
            yield from preorder_traversal(element)
            continue
        yield element

74 

75 

def lists_to_tuples(res, keys):
    """Mirror the list nesting of ``keys`` onto ``res`` using tuples.

    Wherever ``keys`` is a list, the matching part of ``res`` is rebuilt as
    a tuple, pairing elements positionally (stopping at the shorter of the
    two, like ``zip``) and recursing. Any non-list ``keys`` leaves ``res``
    untouched.

    >>> lists_to_tuples([1, [2, 3]], ['a', ['b', 'c']])
    (1, (2, 3))
    """
    if not isinstance(keys, list):
        return res
    return tuple(map(lists_to_tuples, res, keys))

80 

81 

82def _pack_result(result: Mapping, keys: list | Key) -> Any: 

83 if isinstance(keys, list): 

84 return tuple(_pack_result(result, k) for k in keys) 

85 return result[keys] 

86 

87 

def get(dsk: Mapping, out: list | Key, cache: MutableMapping | None = None) -> Any:
    """Get value from Dask

    Parameters
    ----------
    dsk : Mapping
        Task graph mapping keys to (legacy) tasks or literal values.
    out : list or key
        A single key, or an arbitrarily nested list of keys. A tuple is
        always one composite key, never a collection of keys.
    cache : MutableMapping, optional
        Passed to ``execute_graph``; its keys are also treated as known graph
        keys when converting the legacy graph. A new dict is used if omitted.

    Returns
    -------
    object
        The value for a single key, or a (nested) tuple of values mirroring
        the list nesting of ``out``.

    Raises
    ------
    KeyError
        If any requested key is not present in ``dsk``.

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> d = {'x': 1, 'y': (inc, 'x')}

    >>> get(d, 'x')
    1
    >>> get(d, 'y')
    2
    """
    # Wrap ``out`` in a list before flattening. ``flatten(out)`` on a single
    # tuple key such as ('x', 0) would iterate it element by element (and
    # wrongly report 'x' as missing), and on a bare int key it would raise
    # TypeError. The flattened keys are computed once and reused below.
    requested = list(flatten([out]))
    for k in requested:
        if k not in dsk:
            raise KeyError(f"{k} is not a key in the graph")
    if cache is None:
        cache = {}

    dsk2 = convert_legacy_graph(dsk, all_keys=set(dsk) | set(cache))
    result = execute_graph(dsk2, cache, keys=set(requested))
    return _pack_result(result, out)

111 

112 

def keys_in_tasks(keys: Collection[Key], tasks: Iterable[Any], as_list: bool = False):
    """Return the members of ``keys`` referenced anywhere inside ``tasks``.

    ``tasks`` is explored one nesting level at a time:

    * legacy tuple tasks contribute their arguments (the callable is skipped);
    * lists contribute their elements and dicts their values;
    * ``GraphNode`` objects contribute their ``dependencies``;
    * ``TaskRef`` objects contribute their ``key``;
    * anything else is tested for membership in ``keys``; unhashable values
      are silently ignored.

    Parameters
    ----------
    keys : Collection of keys
        Candidate keys (a graph mapping works: membership tests its keys).
    tasks : Iterable
        Tasks or literal values to search.
    as_list : bool, default False
        If True, return a list in discovery order, which may contain
        duplicates; otherwise return a set.

    Examples
    --------
    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> dsk = {'x': 1,
    ...        'y': (inc, 'x'),
    ...        'z': (add, 'x', 'y'),
    ...        'w': (inc, 'z'),
    ...        'a': (add, (inc, 'x'), 1)}

    >>> keys_in_tasks(dsk, ['x', 'y', 'j'])  # doctest: +SKIP
    {'x', 'y'}
    """
    from dask._task_spec import GraphNode

    found: list[Key] = []
    frontier = tasks
    while frontier:
        next_level = []
        for item in frontier:
            item_type = type(item)
            if item_type is tuple and item and callable(item[0]):  # istask(item)
                next_level.extend(item[1:])
            elif item_type is list:
                next_level.extend(item)
            elif item_type is dict:
                next_level.extend(item.values())
            elif isinstance(item, GraphNode):
                next_level.extend(item.dependencies)
            elif isinstance(item, TaskRef):
                next_level.append(item.key)
            else:
                try:
                    if item in keys:
                        found.append(item)
                except TypeError:  # unhashable value, cannot be a key
                    pass
        frontier = next_level
    if as_list:
        return found
    return set(found)

154 

155 

def iskey(key: object) -> bool:
    """Return True if the given object is a potential dask key; False otherwise.

    A key is an object whose exact type is ``str``, ``int`` or ``float``
    (so ``bool`` and other subclasses are rejected), or a plain ``tuple``
    whose elements are all keys, recursively. The empty tuple qualifies.

    See Also
    --------
    ishashable
    validate_key
    dask.typing.Key
    """
    key_type = type(key)
    if key_type is not tuple:
        return key_type in (int, float, str)
    return all(map(iskey, cast(tuple, key)))

172 

173 

def validate_key(key: object) -> None:
    """Raise ``TypeError`` unless ``key`` is a valid dask key.

    For composite (tuple) keys the error names the position of the first
    offending element and chains the underlying error as ``__cause__``.

    See Also
    --------
    iskey
    """
    if iskey(key):
        return
    if type(key) is not tuple:
        raise TypeError(f"Unexpected key type {type(key)} ({key=!r})")

    # ``index`` and ``key`` must keep these names: they are rendered by the
    # ``{index=}`` / ``{key=!r}`` self-documenting f-string fields.
    index = None
    try:
        for index, part in enumerate(cast(tuple, key)):  # noqa: B007
            validate_key(part)
    except TypeError as e:
        raise TypeError(
            f"Composite key contains unexpected key type at {index=} ({key=!r})"
        ) from e
    raise TypeError(f"Unexpected key type {type(key)} ({key=!r})")

195 

196 

@overload
def get_dependencies(
    dsk: Graph,
    key: Key | None = ...,
    task: Key | NoDefault = ...,
    as_list: Literal[False] = ...,
) -> set[Key]: ...


@overload
def get_dependencies(
    dsk: Graph,
    key: Key | None,
    task: Key | NoDefault,
    as_list: Literal[True],
) -> list[Key]: ...


def get_dependencies(
    dsk: Graph,
    key: Key | None = None,
    task: Key | NoDefault = no_default,
    as_list: bool = False,
) -> set[Key] | list[Key]:
    """Return the keys that one task directly references.

    The task is ``dsk[key]`` when ``key`` is given (it takes precedence);
    otherwise ``task`` is inspected directly. Only references appearing in
    the task itself are reported; referenced keys are not looked up in
    ``dsk`` to follow them further. Non-key values are ignored.

    Parameters
    ----------
    dsk : Graph
        The graph; also the set of keys that count as dependencies.
    key : key, optional
        Key of the task to inspect.
    task : object, optional
        A task to inspect when ``key`` is not given.
    as_list : bool, default False
        Return a list (discovery order, may repeat) instead of a set.

    Raises
    ------
    ValueError
        If neither ``key`` nor ``task`` is provided.

    Examples
    --------
    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> dsk = {'x': 1,
    ...        'y': (inc, 'x'),
    ...        'z': (add, 'x', 'y'),
    ...        'w': (inc, 'z'),
    ...        'a': (add, (inc, 'x'), 1)}

    >>> get_dependencies(dsk, 'x')
    set()

    >>> get_dependencies(dsk, 'y')
    {'x'}

    >>> get_dependencies(dsk, 'z')  # doctest: +SKIP
    {'x', 'y'}

    >>> get_dependencies(dsk, 'w')  # Only direct dependencies
    {'z'}

    >>> get_dependencies(dsk, 'a')  # Ignore non-keys
    {'x'}

    >>> get_dependencies(dsk, task=(inc, 'x'))  # provide tasks directly
    {'x'}
    """
    if key is None:
        if task is no_default:
            raise ValueError("Provide either key or task")
        target = task
    else:
        target = dsk[key]
    return keys_in_tasks(dsk, [target], as_list=as_list)

259 

260 

def get_deps(dsk: Graph) -> tuple[dict[Key, set[Key]], dict[Key, set[Key]]]:
    """Get dependencies and dependents from dask dask graph

    Returns
    -------
    dependencies : dict
        Each key of ``dsk`` mapped to the set of keys its task references.
    dependents : dict
        The reverse mapping: each key mapped to the set of keys whose tasks
        reference it (empty set if none).

    >>> inc = lambda x: x + 1
    >>> dsk = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')}
    >>> dependencies, dependents = get_deps(dsk)
    >>> dependencies
    {'a': set(), 'b': {'a'}, 'c': {'b'}}
    >>> dependents  # doctest: +SKIP
    {'a': {'b'}, 'b': {'c'}, 'c': set()}
    """
    dependencies: dict[Key, set[Key]] = {}
    for graph_key, graph_task in dsk.items():
        dependencies[graph_key] = get_dependencies(dsk, task=graph_task)
    return dependencies, reverse_dict(dependencies)

275 

276 

def flatten(seq, container=list):
    """Lazily flatten nested ``container`` instances inside ``seq``.

    Only items that are instances of ``container`` (``list`` by default)
    are expanded, recursively; everything else, including tuples, is
    yielded as-is. A bare string ``seq`` is yielded whole rather than
    character by character.

    >>> list(flatten([1]))
    [1]

    >>> list(flatten([[1, 2], [1, 2]]))
    [1, 2, 1, 2]

    >>> list(flatten([[[1], [2]], [[1], [2]]]))
    [1, 2, 1, 2]

    >>> list(flatten(((1, 2), (1, 2))))  # Don't flatten tuples
    [(1, 2), (1, 2)]

    >>> list(flatten((1, 2, [3, 4])))  # support heterogeneous
    [1, 2, 3, 4]
    """
    if isinstance(seq, str):
        yield seq
        return
    for element in seq:
        if not isinstance(element, container):
            yield element
        else:
            yield from flatten(element, container=container)

303 

304 

T_ = TypeVar("T_")


def reverse_dict(d: Mapping[T_, Iterable[T_]]) -> dict[T_, set[T_]]:
    """Invert an adjacency mapping.

    Every key of ``d`` and every value listed in ``d`` appears in the
    result, in first-seen order. Each is mapped to the set of keys of ``d``
    that list it (an empty set if none do).

    >>> a, b, c = 'abc'
    >>> d = {a: [b, c], b: [c]}
    >>> reverse_dict(d)  # doctest: +SKIP
    {'a': set(), 'b': {'a'}, 'c': {'a', 'b'}}
    """
    inverted: dict[T_, set[T_]] = {}
    for source, targets in d.items():
        inverted.setdefault(source, set())
        for target in targets:
            inverted.setdefault(target, set()).add(source)
    return inverted

323 

324 

def subs(task, key, val):
    """Perform a substitution on a task

    Returns a copy of ``task`` in which occurrences of ``key`` are replaced
    by ``val``.

    * If ``task`` is not a task: it is replaced when it has exactly the same
      type as ``key`` and compares equal. A list is substituted
      element-wise. Anything else is returned unchanged.
    * If ``task`` is a task: its arguments are rewritten. Nested tasks and
      lists recurse. Other arguments are replaced when they hash and compare
      equal to ``key`` (so ``1.0`` matches ``1`` here). Unhashable arguments
      are kept. ``key`` itself must be hashable on this path.

    Examples
    --------
    >>> def inc(x):
    ...     return x + 1

    >>> subs((inc, 'x'), 'x', 1)  # doctest: +ELLIPSIS
    (<function inc at ...>, 1)
    """
    task_type = type(task)
    is_task = task_type is tuple and task and callable(task[0])
    if not is_task:
        try:
            if task_type is type(key) and task == key:
                return val
        except Exception:
            pass
        if task_type is list:
            return [subs(item, key, val) for item in task]
        return task

    lookup = {key}
    replaced = []
    for arg in task[1:]:
        arg_type = type(arg)
        if arg_type is tuple and arg and callable(arg[0]):  # istask(arg)
            replaced.append(subs(arg, key, val))
        elif arg_type is list:
            replaced.append([subs(item, key, val) for item in arg])
        else:
            try:
                hit = arg in lookup  # hash and equality must both match
            except TypeError:  # not hashable
                hit = False
            replaced.append(val if hit else arg)
    return (task[0], *replaced)

362 

363 

def _toposort(dsk, keys=None, returncycle=False, dependencies=None):
    """Depth-first topological sort of ``dsk``, optionally reporting a cycle.

    Parameters
    ----------
    dsk : Mapping
        The task graph. Its keys are the default start keys, and it is used
        to build ``DependenciesMapping(dsk)`` when ``dependencies`` is None.
    keys : key or list of keys, optional
        Start keys. ``None`` means every key of ``dsk``; a non-list value is
        treated as a single key.
    returncycle : bool, default False
        If True, return a detected cycle (or ``[]``) instead of an ordering,
        and never raise.
    dependencies : Mapping, optional
        Maps each key to an iterable of the keys it depends on.

    Returns
    -------
    list
        With ``returncycle=False``: every key reachable from the start keys,
        each placed after all of its dependencies.
        With ``returncycle=True``: the keys of one detected cycle, starting
        and ending with the same key, or ``[]`` if no cycle was found.

    Raises
    ------
    RuntimeError
        If a cycle is found and ``returncycle`` is False.
    """

    # Stack-based depth-first search traversal. This is based on Tarjan's
    # method for topological sorting (see wikipedia for pseudocode)
    if keys is None:
        keys = dsk
    elif not isinstance(keys, list):
        keys = [keys]
    if not returncycle:
        ordered = []

    # Nodes whose descendants have been completely explored.
    # These nodes are guaranteed to not be part of a cycle.
    completed = set()

    # All nodes that have been visited in the current traversal. Because
    # we are doing depth-first search, going "deeper" should never result
    # in visiting a node that has already been seen. The `seen` and
    # `completed` sets are mutually exclusive; it is okay to visit a node
    # that has already been added to `completed`.
    seen = set()

    if dependencies is None:

        dependencies = DependenciesMapping(dsk)

    for key in keys:
        if key in completed:
            continue
        nodes = [key]
        while nodes:
            # Keep current node on the stack until all descendants are visited
            cur = nodes[-1]
            if cur in completed:
                # Already fully traversed descendants of cur
                nodes.pop()
                continue
            seen.add(cur)

            # Add direct descendants of cur to nodes stack
            next_nodes = []
            for nxt in dependencies[cur]:
                if nxt not in completed:
                    if nxt in seen:
                        # Cycle detected! `nxt` is still below `cur` on the
                        # stack (visited but not completed), so the edge
                        # cur -> nxt closes a loop.
                        # Let's report only the nodes that directly participate in the cycle.
                        # We use `priorities` below to greedily construct a short cycle.
                        # Shorter cycles may exist.
                        priorities = {}
                        prev = nodes[-1]
                        # Give priority to nodes that were seen earlier.
                        # Pops the stack down to (but not including) `nxt`;
                        # entries popped later get more negative values.
                        while nodes[-1] != nxt:
                            priorities[nodes.pop()] = -len(priorities)
                        priorities[nxt] = -len(priorities)
                        # We're going to get the cycle by walking backwards along dependents,
                        # so calculate dependents only for the nodes in play.
                        inplay = set(priorities)
                        dependents = reverse_dict(
                            {k: inplay.intersection(dependencies[k]) for k in inplay}
                        )
                        # Begin with the node that was seen twice and the node `prev` from
                        # which we detected the cycle.
                        cycle = [nodes.pop()]
                        cycle.append(prev)
                        while prev != cycle[0]:
                            # Greedily take a step that takes us closest to completing the cycle.
                            # This may not give us the shortest cycle, but we get *a* short cycle.
                            deps = dependents[cycle[-1]]
                            prev = min(deps, key=priorities.__getitem__)
                            cycle.append(prev)
                        # The walk followed dependents; reverse it into
                        # dependency order.
                        cycle.reverse()

                        if returncycle:
                            return cycle
                        else:
                            cycle = "->".join(str(x) for x in cycle)
                            raise RuntimeError("Cycle detected in Dask: %s" % cycle)
                    next_nodes.append(nxt)

            if next_nodes:
                nodes.extend(next_nodes)
            else:
                # cur has no more descendants to explore, so we're done with it
                if not returncycle:
                    ordered.append(cur)
                completed.add(cur)
                seen.remove(cur)
                nodes.pop()
    if returncycle:
        # Traversal finished without finding a cycle.
        return []
    return ordered

455 

456 

def toposort(dsk, dependencies=None):
    """Return a list of keys of dask sorted in topological order.

    Each key is placed after all of the keys it depends on. ``dependencies``
    optionally supplies a precomputed key -> dependency-keys mapping.
    A ``RuntimeError`` is raised if the graph contains a cycle.
    """
    ordered_keys = _toposort(
        dsk, keys=None, returncycle=False, dependencies=dependencies
    )
    return ordered_keys

460 

461 

def getcycle(d, keys):
    """Return a list of nodes that form a cycle if Dask is not a DAG.

    Returns an empty list if no cycle is found. A returned cycle starts and
    ends with the same key.

    ``keys`` may be a single key or list of keys.

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> d = {'x': (inc, 'z'), 'y': (inc, 'x'), 'z': (inc, 'y')}
    >>> getcycle(d, 'x')
    ['x', 'z', 'y', 'x']

    See Also
    --------
    isdag
    """
    found_cycle = _toposort(d, keys=keys, returncycle=True)
    return found_cycle

482 

483 

def isdag(d, keys):
    """Does Dask form a directed acyclic graph when calculating keys?

    ``keys`` may be a single key or list of keys.

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> isdag({'x': 0, 'y': (inc, 'x')}, 'y')
    True
    >>> isdag({'x': (inc, 'y'), 'y': (inc, 'x')}, 'y')
    False

    See Also
    --------
    getcycle
    """
    # getcycle always returns a list; an empty one means no cycle was found.
    cycle = getcycle(d, keys)
    return len(cycle) == 0

504 

505 

class literal:
    """A small serializable object to wrap literal values without copying

    Calling the instance returns the wrapped object itself (no copy).
    """

    __slots__ = ("data",)

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        type_name = type(self.data).__name__
        return f"literal<type={type_name}>"

    def __reduce__(self):
        # Pickle by re-invoking the constructor with the wrapped value.
        constructor_args = (self.data,)
        return (literal, constructor_args)

    def __call__(self):
        return self.data

522 

523 

def quote(x):
    """Ensure that this value remains this value in a dask graph

    Some values in dask graph take on special meaning. Sometimes we want to
    ensure that our data is not interpreted but remains literal.

    Tasks, plain lists and plain dicts are wrapped as ``(literal(x),)``, a
    task that returns ``x`` when run; every other value is returned as-is.

    >>> add = lambda x, y: x + y
    >>> quote((add, 1, 2))
    (literal<type=tuple>,)
    """
    kind = type(x)
    if not (istask(x) or kind is list or kind is dict):
        return x
    return (literal(x),)

537 

538 

def reshapelist(shape, seq):
    """Reshape iterator to nested shape

    Parameters
    ----------
    shape : tuple of int
        Target shape; ``len(shape)`` is the nesting depth.
    seq : iterable
        Flat data. Sized sequences are used directly; unsized iterables
        (iterators, generators) are materialized into a list first.

    Returns
    -------
    list
        Nested lists; the innermost level is a ``list`` of the elements.
        Outer levels are cut with ``toolz.partition`` (no padding), which per
        the toolz docs drops a trailing chunk that would be incomplete.

    >>> reshapelist((2, 3), range(6))
    [[0, 1, 2], [3, 4, 5]]
    >>> reshapelist((2, 3), iter(range(6)))
    [[0, 1, 2], [3, 4, 5]]
    """
    if len(shape) == 1:
        return list(seq)
    try:
        length = len(seq)
    except TypeError:
        # Iterators/generators have no len(); materialize them once so the
        # chunk size can be computed and the data partitioned.
        seq = list(seq)
        length = len(seq)
    # Integer floor division is exact for any length, unlike int(a / b),
    # which goes through a float. int() keeps non-int shape entries working
    # as before.
    n = int(length // shape[0])
    return [reshapelist(shape[1:], part) for part in toolz.partition(n, seq)]

549 return [reshapelist(shape[1:], part) for part in toolz.partition(n, seq)]