Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/core.py: 33%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

220 statements  

1from __future__ import annotations 

2 

3from collections import defaultdict 

4from collections.abc import Collection, Iterable, Mapping, MutableMapping 

5from typing import Any, Literal, TypeVar, cast, overload 

6 

7import toolz 

8 

9from dask._task_spec import ( 

10 DependenciesMapping, 

11 TaskRef, 

12 convert_legacy_graph, 

13 execute_graph, 

14) 

15from dask.typing import Graph, Key, NoDefault, no_default 

16 

17 

def ishashable(x):
    """Report whether ``hash(x)`` succeeds.

    Parameters
    ----------
    x : object
        Any object.

    Returns
    -------
    bool
        ``True`` when ``hash(x)`` does not raise ``TypeError``, ``False``
        otherwise.  Any other exception raised by ``hash`` propagates.

    Examples
    --------

    >>> ishashable(1)
    True
    >>> ishashable([1])
    False

    See Also
    --------
    iskey
    """
    try:
        hash(x)
    except TypeError:
        return False
    else:
        return True

38 

39 

def istask(x):
    """Tell whether ``x`` is a runnable task.

    Two representations are recognised:

    * a ``GraphNode`` instance that is not a ``DataNode``;
    * a legacy task, i.e. an object whose exact type is ``tuple``, that is
      non-empty and whose first element is callable.

    Note that the result is only guaranteed to be truthy/falsy: an empty
    tuple is returned as-is (``()``) rather than ``False``.

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> istask((inc, 1))
    True
    >>> istask(1)
    False
    """
    # Imported lazily, at call time, exactly like the original code does.
    from dask._task_spec import DataNode, GraphNode

    if isinstance(x, GraphNode):
        return not isinstance(x, DataNode)
    if type(x) is not tuple:
        return False
    # Exact ``tuple`` from here on; an empty tuple short-circuits to itself.
    return x and callable(x[0])

59 

60 

def preorder_traversal(task):
    """Yield the elements of ``task`` in pre-order.

    Nested tasks (as decided by ``istask``) are expanded in place.  A nested
    list is announced by yielding the ``list`` type itself and is then
    expanded as well.  Every other element is yielded unchanged.
    """
    for element in task:
        if istask(element):
            yield from preorder_traversal(element)
            continue
        if isinstance(element, list):
            # Marker so consumers can tell a list started here.
            yield list
            yield from preorder_traversal(element)
            continue
        yield element

72 

73 

def lists_to_tuples(res, keys):
    """Convert ``res`` to nested tuples wherever ``keys`` is a nested list.

    ``keys`` acts as the structural template: at each level where ``keys``
    is a ``list``, ``res`` is paired element-wise with it and rebuilt as a
    ``tuple``; where ``keys`` is anything else, ``res`` is returned untouched.
    Pairing stops at the shorter of the two sequences.
    """
    if not isinstance(keys, list):
        return res
    return tuple(map(lists_to_tuples, res, keys))

78 

79 

80def _pack_result(result: Mapping, keys: list | Key) -> Any: 

81 if isinstance(keys, list): 

82 return tuple(_pack_result(result, k) for k in keys) 

83 return result[keys] 

84 

85 

def get(dsk: Mapping, out: list | Key, cache: MutableMapping | None = None) -> Any:
    """Get value from Dask

    Parameters
    ----------
    dsk : Mapping
        The graph, mapping keys to values or tasks.
    out : list or Key
        A single key, or a (possibly nested) list of keys, to fetch.
    cache : MutableMapping, optional
        Passed through to ``execute_graph``; its keys are also reported to
        ``convert_legacy_graph`` as known keys.  A fresh ``dict`` is used
        when omitted.

    Returns
    -------
    Any
        The value for a single key, or nested tuples mirroring the nesting
        of ``out`` when ``out`` is a list.

    Raises
    ------
    KeyError
        If any requested key is not present in ``dsk``.

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> d = {'x': 1, 'y': (inc, 'x')}

    >>> get(d, 'x')
    1
    >>> get(d, 'y')
    2
    """
    # Wrap ``out`` in a list before flattening so that a single key is treated
    # as ONE key.  Flattening ``out`` directly would iterate over the
    # components of a tuple key such as ``('x', 0)`` (spurious KeyError) and
    # fail outright on a non-iterable key such as an int.
    requested = list(flatten([out]))
    for k in requested:
        if k not in dsk:
            raise KeyError(f"{k} is not a key in the graph")
    if cache is None:
        cache = {}

    dsk2 = convert_legacy_graph(dsk, all_keys=set(dsk) | set(cache))
    result = execute_graph(dsk2, cache, keys=set(requested))
    return _pack_result(result, out)

109 

110 

def keys_in_tasks(keys: Collection[Key], tasks: Iterable[Any], as_list: bool = False):
    """Returns the keys in `keys` that are also in `tasks`

    ``tasks`` is explored level by level.  Legacy tasks (exact ``tuple`` with
    a callable head) contribute their arguments, lists their elements, dicts
    their values, ``GraphNode`` objects their ``dependencies`` and
    ``TaskRef`` objects their ``key``.  Anything else is a leaf and is kept
    when it is a member of ``keys``; leaves for which the membership test
    raises ``TypeError`` are ignored.

    Parameters
    ----------
    keys : Collection[Key]
        The candidate keys (typically the graph itself).
    tasks : Iterable
        The tasks / arguments to search.
    as_list : bool, optional
        When true, return a list (one entry per occurrence found); otherwise
        return a set.

    Examples
    --------
    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> dsk = {'x': 1,
    ...        'y': (inc, 'x'),
    ...        'z': (add, 'x', 'y'),
    ...        'w': (inc, 'z'),
    ...        'a': (add, (inc, 'x'), 1)}

    >>> keys_in_tasks(dsk, ['x', 'y', 'j'])  # doctest: +SKIP
    {'x', 'y'}
    """
    from dask._task_spec import GraphNode

    found: list[Key] = []
    pending = tasks
    while pending:
        next_level = []
        for item in pending:
            kind = type(item)
            if kind is tuple and item and callable(item[0]):  # istask(item)
                # Skip the callable; only the arguments can reference keys.
                next_level.extend(item[1:])
            elif kind is list:
                next_level.extend(item)
            elif kind is dict:
                next_level.extend(item.values())
            elif isinstance(item, GraphNode):
                next_level.extend(item.dependencies)
            elif isinstance(item, TaskRef):
                next_level.append(item.key)
            else:
                try:
                    if item in keys:
                        found.append(item)
                except TypeError:  # not hashable
                    pass
        pending = next_level

    if as_list:
        return found
    return set(found)

152 

153 

def iskey(key: object) -> bool:
    """Return True if the given object is a potential dask key; False otherwise.

    A key is an object whose exact type is ``str``, ``int`` or ``float``, or
    an exact ``tuple`` all of whose elements are themselves keys (checked
    recursively).  Subclasses, including ``bool``, do not qualify.

    See Also
    --------
    ishashable
    validate_key
    dask.typing.Key
    """
    kind = type(key)
    if kind is not tuple:
        return kind in {int, float, str}
    return all(map(iskey, cast(tuple, key)))

170 

171 

def validate_key(key: object) -> None:
    """Validate the format of a dask key.

    Returns ``None`` when ``iskey(key)`` holds.

    Raises
    ------
    TypeError
        When ``key`` is not a valid key.  For a tuple, the message names the
        index of the offending element and the error raised for that element
        is attached as the cause.

    See Also
    --------
    iskey
    """
    if iskey(key):
        return None

    typ = type(key)
    if typ is tuple:
        # ``index`` is read by the f-string below, so it must outlive the loop.
        index = None
        try:
            for index, component in enumerate(cast(tuple, key)):  # noqa: B007
                validate_key(component)
        except TypeError as err:
            raise TypeError(
                f"Composite key contains unexpected key type at {index=} ({key=!r})"
            ) from err
    raise TypeError(f"Unexpected key type {typ} ({key=!r})")

193 

194 

@overload
def get_dependencies(
    dsk: Graph,
    key: Key | None = ...,
    task: Key | NoDefault = ...,
    as_list: Literal[False] = ...,
) -> set[Key]: ...


@overload
def get_dependencies(
    dsk: Graph,
    key: Key | None,
    task: Key | NoDefault,
    as_list: Literal[True],
) -> list[Key]: ...


def get_dependencies(
    dsk: Graph,
    key: Key | None = None,
    task: Key | NoDefault = no_default,
    as_list: bool = False,
) -> set[Key] | list[Key]:
    """Get the immediate tasks on which this task depends

    Parameters
    ----------
    dsk : Graph
        The graph; also the set of candidate keys.
    key : Key, optional
        Key whose task (``dsk[key]``) is inspected.  Takes precedence over
        ``task`` when it is not ``None``.
    task : optional
        A task to inspect directly, used only when ``key`` is ``None``.
    as_list : bool, optional
        Return a list instead of a set.

    Raises
    ------
    ValueError
        If neither ``key`` nor ``task`` is supplied.

    Examples
    --------
    >>> inc = lambda x: x + 1
    >>> add = lambda x, y: x + y
    >>> dsk = {'x': 1,
    ...        'y': (inc, 'x'),
    ...        'z': (add, 'x', 'y'),
    ...        'w': (inc, 'z'),
    ...        'a': (add, (inc, 'x'), 1)}

    >>> get_dependencies(dsk, 'x')
    set()

    >>> get_dependencies(dsk, 'y')
    {'x'}

    >>> get_dependencies(dsk, 'z')  # doctest: +SKIP
    {'x', 'y'}

    >>> get_dependencies(dsk, 'w')  # Only direct dependencies
    {'z'}

    >>> get_dependencies(dsk, 'a')  # Ignore non-keys
    {'x'}

    >>> get_dependencies(dsk, task=(inc, 'x'))  # provide tasks directly
    {'x'}
    """
    if key is None:
        if task is no_default:
            raise ValueError("Provide either key or task")
        arg = task
    else:
        arg = dsk[key]

    return keys_in_tasks(dsk, [arg], as_list=as_list)

257 

258 

def get_deps(dsk: Graph) -> tuple[dict[Key, set[Key]], dict[Key, set[Key]]]:
    """Get dependencies and dependents from a dask graph

    Returns
    -------
    dependencies : dict
        For every key in ``dsk``, the result of ``get_dependencies`` on its
        task.
    dependents : dict
        The reverse mapping, built with ``reverse_dict``.

    >>> inc = lambda x: x + 1
    >>> dsk = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')}
    >>> dependencies, dependents = get_deps(dsk)
    >>> dependencies
    {'a': set(), 'b': {'a'}, 'c': {'b'}}
    >>> dependents  # doctest: +SKIP
    {'a': {'b'}, 'b': {'c'}, 'c': set()}
    """
    dependencies = {}
    for name, task in dsk.items():
        dependencies[name] = get_dependencies(dsk, task=task)
    return dependencies, reverse_dict(dependencies)

273 

274 

def flatten(seq, container=list):
    """Lazily flatten nested ``container`` instances found in ``seq``.

    Only elements that are instances of ``container`` (``list`` by default)
    are descended into; everything else is yielded as-is.  A string passed as
    ``seq`` is yielded whole rather than character by character.

    >>> list(flatten([1]))
    [1]

    >>> list(flatten([[1, 2], [1, 2]]))
    [1, 2, 1, 2]

    >>> list(flatten([[[1], [2]], [[1], [2]]]))
    [1, 2, 1, 2]

    >>> list(flatten(((1, 2), (1, 2))))  # Don't flatten tuples
    [(1, 2), (1, 2)]

    >>> list(flatten((1, 2, [3, 4])))  # support heterogeneous
    [1, 2, 3, 4]
    """
    if isinstance(seq, str):
        yield seq
        return

    for element in seq:
        if not isinstance(element, container):
            yield element
        else:
            yield from flatten(element, container=container)

301 

302 

T_ = TypeVar("T_")


def reverse_dict(d: Mapping[T_, Iterable[T_]]) -> dict[T_, set[T_]]:
    """Invert a mapping of ``key -> values`` into ``value -> {keys}``.

    Every key of ``d`` appears in the result, even when nothing maps to it
    (its entry is then an empty set).

    >>> a, b, c = 'abc'
    >>> d = {a: [b, c], b: [c]}
    >>> reverse_dict(d)  # doctest: +SKIP
    {'a': set(), 'b': {'a'}, 'c': {'a', 'b'}}
    """
    inverted: defaultdict[T_, set[T_]] = defaultdict(set)
    for source, targets in d.items():
        # Touch the entry so that ``source`` is present even without dependents.
        inverted[source]
        for target in targets:
            inverted[target].add(source)
    return dict(inverted)

321 

322 

def subs(task, key, val):
    """Perform a substitution on a task

    Occurrences of ``key`` inside ``task`` are replaced by ``val``.  Nested
    legacy tasks and lists are processed recursively.

    * When ``task`` is not a legacy task, it is replaced only if it has the
      same exact type as ``key`` and compares equal to it; any exception
      raised by that comparison is ignored.
    * Arguments of a legacy task are replaced when they hash and compare equal
      to ``key``; unhashable arguments are left as they are.

    Examples
    --------
    >>> def inc(x):
    ...     return x + 1

    >>> subs((inc, 'x'), 'x', 1)  # doctest: +ELLIPSIS
    (<function inc at ...>, 1)
    """
    kind = type(task)
    is_task = kind is tuple and task and callable(task[0])  # istask(task)
    if not is_task:
        try:
            if kind is type(key) and task == key:
                return val
        except Exception:
            # Best effort: some objects raise on ``==`` or on truth-testing
            # its result; treat those as "no match".
            pass
        if kind is list:
            return [subs(element, key, val) for element in task]
        return task

    substituted = []
    key_set = {key}
    for arg in task[1:]:
        arg_kind = type(arg)
        if arg_kind is tuple and arg and callable(arg[0]):  # istask(arg)
            arg = subs(arg, key, val)
        elif arg_kind is list:
            arg = [subs(element, key, val) for element in arg]
        else:
            try:
                if arg in key_set:  # Hash and equality match
                    arg = val
            except TypeError:  # not hashable
                pass
        substituted.append(arg)
    # Keep the callable head, swap in the rewritten arguments.
    return task[:1] + tuple(substituted)

360 

361 

362def _toposort(dsk, keys=None, returncycle=False, dependencies=None): 

363 

364 # Stack-based depth-first search traversal. This is based on Tarjan's 

365 # method for topological sorting (see wikipedia for pseudocode) 

366 if keys is None: 

367 keys = dsk 

368 elif not isinstance(keys, list): 

369 keys = [keys] 

370 if not returncycle: 

371 ordered = [] 

372 

373 # Nodes whose descendents have been completely explored. 

374 # These nodes are guaranteed to not be part of a cycle. 

375 completed = set() 

376 

377 # All nodes that have been visited in the current traversal. Because 

378 # we are doing depth-first search, going "deeper" should never result 

379 # in visiting a node that has already been seen. The `seen` and 

380 # `completed` sets are mutually exclusive; it is okay to visit a node 

381 # that has already been added to `completed`. 

382 seen = set() 

383 

384 if dependencies is None: 

385 

386 dependencies = DependenciesMapping(dsk) 

387 

388 for key in keys: 

389 if key in completed: 

390 continue 

391 nodes = [key] 

392 while nodes: 

393 # Keep current node on the stack until all descendants are visited 

394 cur = nodes[-1] 

395 if cur in completed: 

396 # Already fully traversed descendants of cur 

397 nodes.pop() 

398 continue 

399 seen.add(cur) 

400 

401 # Add direct descendants of cur to nodes stack 

402 next_nodes = [] 

403 for nxt in dependencies[cur]: 

404 if nxt not in completed: 

405 if nxt in seen: 

406 # Cycle detected! 

407 # Let's report only the nodes that directly participate in the cycle. 

408 # We use `priorities` below to greedily construct a short cycle. 

409 # Shorter cycles may exist. 

410 priorities = {} 

411 prev = nodes[-1] 

412 # Give priority to nodes that were seen earlier. 

413 while nodes[-1] != nxt: 

414 priorities[nodes.pop()] = -len(priorities) 

415 priorities[nxt] = -len(priorities) 

416 # We're going to get the cycle by walking backwards along dependents, 

417 # so calculate dependents only for the nodes in play. 

418 inplay = set(priorities) 

419 dependents = reverse_dict( 

420 {k: inplay.intersection(dependencies[k]) for k in inplay} 

421 ) 

422 # Begin with the node that was seen twice and the node `prev` from 

423 # which we detected the cycle. 

424 cycle = [nodes.pop()] 

425 cycle.append(prev) 

426 while prev != cycle[0]: 

427 # Greedily take a step that takes us closest to completing the cycle. 

428 # This may not give us the shortest cycle, but we get *a* short cycle. 

429 deps = dependents[cycle[-1]] 

430 prev = min(deps, key=priorities.__getitem__) 

431 cycle.append(prev) 

432 cycle.reverse() 

433 

434 if returncycle: 

435 return cycle 

436 else: 

437 cycle = "->".join(str(x) for x in cycle) 

438 raise RuntimeError("Cycle detected in Dask: %s" % cycle) 

439 next_nodes.append(nxt) 

440 

441 if next_nodes: 

442 nodes.extend(next_nodes) 

443 else: 

444 # cur has no more descendants to explore, so we're done with it 

445 if not returncycle: 

446 ordered.append(cur) 

447 completed.add(cur) 

448 seen.remove(cur) 

449 nodes.pop() 

450 if returncycle: 

451 return [] 

452 return ordered 

453 

454 

def toposort(dsk, dependencies=None):
    """Return a list of keys of dask sorted in topological order.

    Thin wrapper around ``_toposort``; ``dependencies`` is forwarded
    unchanged.
    """
    ordering = _toposort(dsk, dependencies=dependencies)
    return ordering

458 

459 

def getcycle(d, keys):
    """Return a list of nodes that form a cycle if Dask is not a DAG.

    Returns an empty list if no cycle is found.

    ``keys`` may be a single key or list of keys.

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> d = {'x': (inc, 'z'), 'y': (inc, 'x'), 'z': (inc, 'y')}
    >>> getcycle(d, 'x')
    ['x', 'z', 'y', 'x']

    See Also
    --------
    isdag
    """
    cycle = _toposort(d, keys=keys, returncycle=True)
    return cycle

480 

481 

def isdag(d, keys):
    """Does Dask form a directed acyclic graph when calculating keys?

    ``keys`` may be a single key or list of keys.

    Examples
    --------

    >>> inc = lambda x: x + 1
    >>> isdag({'x': 0, 'y': (inc, 'x')}, 'y')
    True
    >>> isdag({'x': (inc, 'y'), 'y': (inc, 'x')}, 'y')
    False

    See Also
    --------
    getcycle
    """
    # An empty cycle list means no cycle was found.
    cycle = getcycle(d, keys)
    return not cycle

502 

503 

class literal:
    """A small serializable object to wrap literal values without copying

    The wrapped value is stored by reference in ``data`` and handed back,
    unchanged, when the instance is called with no arguments.
    """

    __slots__ = ("data",)

    def __init__(self, data):
        # Keep a reference only; the value is never copied.
        self.data = data

    def __call__(self):
        """Return the wrapped value."""
        return self.data

    def __reduce__(self):
        """Describe how to rebuild the instance: ``literal(self.data)``."""
        return (literal, (self.data,))

    def __repr__(self):
        # Only the type name is shown, never the (possibly large) value.
        return "literal<type=%s>" % type(self.data).__name__

520 

521 

def quote(x):
    """Ensure that this value remains this value in a dask graph

    Some values in dask graph take on special meaning. Sometimes we want to
    ensure that our data is not interpreted but remains literal.

    Tasks (per ``istask``), exact ``list`` objects and exact ``dict`` objects
    are wrapped as ``(literal(x),)``; everything else is returned unchanged.

    >>> add = lambda x, y: x + y
    >>> quote((add, 1, 2))
    (literal<type=tuple>,)
    """
    needs_wrapping = istask(x) or type(x) is list or type(x) is dict
    if not needs_wrapping:
        return x
    return (literal(x),)

535 

536 

def reshapelist(shape, seq):
    """Reshape a flat sequence into nested lists of the given shape

    Parameters
    ----------
    shape : sequence of int
        Target shape; must have at least one entry.
    seq : iterable
        The flat data.  For a one-dimensional ``shape`` any iterable works;
        for more dimensions ``seq`` must support ``len``.

    Returns
    -------
    list
        Nested lists.  With more than one dimension, ``seq`` is cut into
        chunks of ``len(seq) // shape[0]`` elements using
        ``toolz.partition``, which drops a trailing incomplete chunk.

    >>> reshapelist((2, 3), range(6))
    [[0, 1, 2], [3, 4, 5]]
    """
    if len(shape) == 1:
        return list(seq)

    # Integer floor division: exact for any length, whereas going through
    # float division and ``int`` can lose precision on very large lengths.
    n = len(seq) // shape[0]
    return [reshapelist(shape[1:], part) for part in toolz.partition(n, seq)]