Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/_expr.py: 21%

731 statements  

from __future__ import annotations

import functools
import os
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Literal, TypeAlias

import toolz

import dask
from dask._task_spec import Task, convert_legacy_graph
from dask.tokenize import _tokenize_deterministic
from dask.typing import Key
from dask.utils import ensure_dict, funcname, import_required

if TYPE_CHECKING:
    from dask.highlevelgraph import HighLevelGraph

OptimizerStage: TypeAlias = Literal[
    "logical",
    "simplified-logical",
    "tuned-logical",
    "physical",
    "simplified-physical",
    "fused",
]


def _unpack_collections(o):
    from dask.delayed import Delayed

    if isinstance(o, Expr):
        return o

    if hasattr(o, "expr") and not isinstance(o, Delayed):
        return o.expr
    else:
        return o


class Expr:
    _parameters: list[str] = []
    _defaults: dict[str, Any] = {}

    _pickle_functools_cache: bool = True

    operands: list

    _determ_token: str | None

    def __new__(cls, *args, _determ_token=None, **kwargs):
        operands = list(args)
        for parameter in cls._parameters[len(operands) :]:
            try:
                operands.append(kwargs.pop(parameter))
            except KeyError:
                operands.append(cls._defaults[parameter])
        assert not kwargs, kwargs
        inst = object.__new__(cls)

        inst._determ_token = _determ_token
        inst.operands = [_unpack_collections(o) for o in operands]
        # This is typically cached. Make sure the cache is populated by calling
        # it once
        inst._name
        return inst

    def _tune_down(self):
        return None

    def _tune_up(self, parent):
        return None

    def finalize_compute(self):
        return self

    def _operands_for_repr(self):
        return [f"{param}={op!r}" for param, op in zip(self._parameters, self.operands)]

    def __str__(self):
        s = ", ".join(self._operands_for_repr())
        return f"{type(self).__name__}({s})"

    def __repr__(self):
        return str(self)

    def _tree_repr_argument_construction(self, i, op, header):
        try:
            param = self._parameters[i]
            default = self._defaults[param]
        except (IndexError, KeyError):
            param = self._parameters[i] if i < len(self._parameters) else ""
            default = "--no-default--"

        if repr(op) != repr(default):
            if param:
                header += f" {param}={op!r}"
            else:
                header += repr(op)
        return header

    def _tree_repr_lines(self, indent=0, recursive=True):
        # Return a list of lines so that tree_repr/pprint can join and
        # iterate over whole lines rather than individual characters
        return [" " * indent + repr(self)]

108 

109 def tree_repr(self): 

110 return os.linesep.join(self._tree_repr_lines()) 

111 

112 def analyze(self, filename: str | None = None, format: str | None = None) -> None: 

113 from dask.dataframe.dask_expr._expr import Expr as DFExpr 

114 from dask.dataframe.dask_expr.diagnostics import analyze 

115 

116 if not isinstance(self, DFExpr): 

117 raise TypeError( 

118 "analyze is only supported for dask.dataframe.Expr objects." 

119 ) 

120 return analyze(self, filename=filename, format=format) 

121 

122 def explain( 

123 self, stage: OptimizerStage = "fused", format: str | None = None 

124 ) -> None: 

125 from dask.dataframe.dask_expr.diagnostics import explain 

126 

127 return explain(self, stage, format) 

128 

129 def pprint(self): 

130 for line in self._tree_repr_lines(): 

131 print(line) 

132 

133 def __hash__(self): 

134 return hash(self._name) 

135 

    def __dask_tokenize__(self):
        if not self._determ_token:
            # If the subclass does not implement a __dask_tokenize__ we'll want
            # to tokenize all operands.
            # Note how this differs from the implementation of
            # Expr.deterministic_token
            self._determ_token = _tokenize_deterministic(type(self), *self.operands)
        return self._determ_token

    def __dask_keys__(self):
        """The keys for this expression

        This is used to determine the keys of the output collection
        when this expression is computed.

        Returns
        -------
        keys: list
            The keys for this expression
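
        Examples
        --------
        A minimal sketch, assuming a hypothetical two-partition subclass
        whose ``_name`` hashes to ``"myexpr-abc123"``:

        >>> expr.__dask_keys__()  # doctest: +SKIP
        [('myexpr-abc123', 0), ('myexpr-abc123', 1)]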

155 """ 

156 return [(self._name, i) for i in range(self.npartitions)] 

157 

158 @staticmethod 

159 def _reconstruct(*args): 

160 typ, *operands, token, cache = args 

161 inst = typ(*operands, _determ_token=token) 

162 for k, v in cache.items(): 

163 inst.__dict__[k] = v 

164 return inst 

165 

166 def __reduce__(self): 

167 if dask.config.get("dask-expr-no-serialize", False): 

168 raise RuntimeError(f"Serializing a {type(self)} object") 

169 cache = {} 

170 if type(self)._pickle_functools_cache: 

171 for k, v in type(self).__dict__.items(): 

172 if isinstance(v, functools.cached_property) and k in self.__dict__: 

173 cache[k] = getattr(self, k) 

174 

175 return Expr._reconstruct, ( 

176 type(self), 

177 *self.operands, 

178 self.deterministic_token, 

179 cache, 

180 ) 

181 

182 def _depth(self, cache=None): 

183 """Depth of the expression tree 

184 

185 Returns 

186 ------- 

187 depth: int 
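
        Examples
        --------
        A sketch, assuming a hypothetical leaf expression ``a`` and a
        node ``b`` that wraps it:

        >>> a._depth()  # doctest: +SKIP
        1
        >>> b._depth()  # doctest: +SKIP
        2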

188 """ 

189 if cache is None: 

190 cache = {} 

191 if not self.dependencies(): 

192 return 1 

193 else: 

194 result = [] 

195 for expr in self.dependencies(): 

196 if expr._name in cache: 

197 result.append(cache[expr._name]) 

198 else: 

199 result.append(expr._depth(cache) + 1) 

200 cache[expr._name] = result[-1] 

201 return max(result) 

202 

203 def __setattr__(self, name: str, value: Any) -> None: 

204 if name in ["operands", "_determ_token"]: 

205 object.__setattr__(self, name, value) 

206 return 

207 try: 

208 params = type(self)._parameters 

209 operands = object.__getattribute__(self, "operands") 

210 operands[params.index(name)] = value 

211 except ValueError: 

212 raise AttributeError( 

213 f"{type(self).__name__} object has no attribute {name}" 

214 ) 

215 

216 def operand(self, key): 

217 # Access an operand unambiguously 

218 # (e.g. if the key is reserved by a method/property) 

219 return self.operands[type(self)._parameters.index(key)] 

220 

221 def dependencies(self): 

222 # Dependencies are `Expr` operands only 

223 return [operand for operand in self.operands if isinstance(operand, Expr)] 

224 

    def _task(self, key: Key, index: int) -> Task:
        """The task for the i'th partition

        Parameters
        ----------
        key:
            The key of the task in the graph, i.e. ``(self._name, index)``
        index:
            The index of the partition of this dataframe

        Examples
        --------
        >>> class Add(Expr):
        ...     def _task(self, key, index):
        ...         return Task(
        ...             key,
        ...             operator.add,
        ...             TaskRef((self.left._name, index)),
        ...             TaskRef((self.right._name, index))
        ...         )

        Returns
        -------
        task:
            The Dask task to compute this partition

        See Also
        --------
        Expr._layer
        """
        raise NotImplementedError(
            "Expressions should define either _layer (full dictionary) or _task"
            f" (single task). This expression {type(self)} defines neither"
        )

    def _layer(self) -> dict:
        """The graph layer added by this expression.

        Simple expressions that apply one task per partition can choose to only
        implement `Expr._task` instead.

        Examples
        --------
        >>> class Add(Expr):
        ...     def _layer(self):
        ...         return {
        ...             name: Task(
        ...                 name,
        ...                 operator.add,
        ...                 TaskRef((self.left._name, i)),
        ...                 TaskRef((self.right._name, i))
        ...             )
        ...             for i, name in enumerate(self.__dask_keys__())
        ...         }

        Returns
        -------
        layer: dict
            The Dask task graph added by this expression

        See Also
        --------
        Expr._task
        Expr.__dask_graph__
        """

        return {
            (self._name, i): self._task((self._name, i), i)
            for i in range(self.npartitions)
        }

    def rewrite(self, kind: str, rewritten):
        """Rewrite an expression

        This leverages the ``._{kind}_down`` and ``._{kind}_up``
        methods defined on each class

        Returns
        -------
        expr:
            output expression
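
        Examples
        --------
        A minimal sketch, assuming a hypothetical ``Slow`` node whose
        ``_tune_down`` hook swaps it for an equivalent ``Fast`` node:

        >>> class Slow(Expr):
        ...     def _tune_down(self):
        ...         return Fast(*self.operands)
        >>> Slow(x).rewrite(kind="tune", rewritten={})  # doctest: +SKIP
        Fast(x)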

306 """ 

307 if self._name in rewritten: 

308 return rewritten[self._name] 

309 

310 expr = self 

311 down_name = f"_{kind}_down" 

312 up_name = f"_{kind}_up" 

313 while True: 

314 _continue = False 

315 

316 # Rewrite this node 

317 out = getattr(expr, down_name)() 

318 if out is None: 

319 out = expr 

320 if not isinstance(out, Expr): 

321 return out 

322 if out._name != expr._name: 

323 expr = out 

324 continue 

325 

326 # Allow children to rewrite their parents 

327 for child in expr.dependencies(): 

328 out = getattr(child, up_name)(expr) 

329 if out is None: 

330 out = expr 

331 if not isinstance(out, Expr): 

332 return out 

333 if out is not expr and out._name != expr._name: 

334 expr = out 

335 _continue = True 

336 break 

337 

338 if _continue: 

339 continue 

340 

341 # Rewrite all of the children 

342 new_operands = [] 

343 changed = False 

344 for operand in expr.operands: 

345 if isinstance(operand, Expr): 

346 new = operand.rewrite(kind=kind, rewritten=rewritten) 

347 rewritten[operand._name] = new 

348 if new._name != operand._name: 

349 changed = True 

350 else: 

351 new = operand 

352 new_operands.append(new) 

353 

354 if changed: 

355 expr = type(expr)(*new_operands) 

356 continue 

357 else: 

358 break 

359 

360 return expr 

361 

    def simplify_once(self, dependents: defaultdict, simplified: dict):
        """Simplify an expression

        This leverages the ``._simplify_down`` and ``._simplify_up``
        methods defined on each class

        Parameters
        ----------
        dependents: defaultdict[list]
            The dependents for every node.
        simplified: dict
            Cache of simplified expressions for these dependents.

        Returns
        -------
        expr:
            output expression
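
        Examples
        --------
        A sketch of a single simplification pass over a hypothetical
        ``expr``:

        >>> expr.simplify_once(
        ...     dependents=collect_dependents(expr), simplified={}
        ... )  # doctest: +SKIP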

380 """ 

381 # Check if we've already simplified for these dependents 

382 if self._name in simplified: 

383 return simplified[self._name] 

384 

385 expr = self 

386 

387 while True: 

388 out = expr._simplify_down() 

389 if out is None: 

390 out = expr 

391 if not isinstance(out, Expr): 

392 return out 

393 if out._name != expr._name: 

394 expr = out 

395 

396 # Allow children to simplify their parents 

397 for child in expr.dependencies(): 

398 out = child._simplify_up(expr, dependents) 

399 if out is None: 

400 out = expr 

401 

402 if not isinstance(out, Expr): 

403 return out 

404 if out is not expr and out._name != expr._name: 

405 expr = out 

406 break 

407 

408 # Rewrite all of the children 

409 new_operands = [] 

410 changed = False 

411 for operand in expr.operands: 

412 if isinstance(operand, Expr): 

413 # Bandaid for now, waiting for Singleton 

414 dependents[operand._name].append(weakref.ref(expr)) 

415 new = operand.simplify_once( 

416 dependents=dependents, simplified=simplified 

417 ) 

418 simplified[operand._name] = new 

419 if new._name != operand._name: 

420 changed = True 

421 else: 

422 new = operand 

423 new_operands.append(new) 

424 

425 if changed: 

426 expr = type(expr)(*new_operands) 

427 

428 break 

429 

430 return expr 

    def optimize(self, fuse: bool = False) -> Expr:
        stage: OptimizerStage = "fused" if fuse else "simplified-physical"

        return optimize_until(self, stage)

    def fuse(self) -> Expr:
        return self

    def simplify(self) -> Expr:
        expr = self
        seen = set()
        while True:
            dependents = collect_dependents(expr)
            new = expr.simplify_once(dependents=dependents, simplified={})
            if new._name == expr._name:
                break
            if new._name in seen:
                raise RuntimeError(
                    f"Optimizer does not converge. {expr!r} simplified to {new!r} which was already seen. "
                    "Please report this issue on the dask issue tracker with a minimal reproducer."
                )
            seen.add(new._name)
            expr = new
        return expr

    def _simplify_down(self):
        return

    def _simplify_up(self, parent, dependents):
        return

    def lower_once(self, lowered: dict):
        # Check for a cached result
        try:
            return lowered[self._name]
        except KeyError:
            pass

        expr = self

        # Lower this node
        out = expr._lower()
        if out is None:
            out = expr
        if not isinstance(out, Expr):
            return out

        # Lower all children
        new_operands = []
        changed = False
        for operand in out.operands:
            if isinstance(operand, Expr):
                new = operand.lower_once(lowered)
                if new._name != operand._name:
                    changed = True
            else:
                new = operand
            new_operands.append(new)

        if changed:
            out = type(out)(*new_operands)

        # Cache the result and return
        return lowered.setdefault(self._name, out)

    def lower_completely(self) -> Expr:
        """Lower an expression completely

        This calls the ``lower_once`` method in a loop
        until nothing changes. This function does not
        apply any other optimizations (like ``simplify``).

        Returns
        -------
        expr:
            output expression
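
        Examples
        --------
        A sketch, assuming a hypothetical logical ``Sum`` node whose
        ``_lower`` hook returns a physical ``TreeReduce`` node:

        >>> class Sum(Expr):
        ...     def _lower(self):
        ...         return TreeReduce(self.frame, "sum")
        >>> Sum(frame).lower_completely()  # doctest: +SKIP
        TreeReduce(frame, 'sum')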

        See Also
        --------
        Expr.lower_once
        Expr._lower
        """
        # Lower until nothing changes
        expr = self
        lowered: dict = {}
        while True:
            new = expr.lower_once(lowered)
            if new._name == expr._name:
                break
            expr = new
        return expr

    def _lower(self):
        return

    @functools.cached_property
    def _funcname(self) -> str:
        return funcname(type(self)).lower()

    @property
    def deterministic_token(self):
        if not self._determ_token:
            # Just tokenize self to fall back on __dask_tokenize__
            # Note how this differs from the implementation of __dask_tokenize__
            self._determ_token = self.__dask_tokenize__()
        return self._determ_token

    @functools.cached_property
    def _name(self) -> str:
        return f"{self._funcname}-{self.deterministic_token}"

    @property
    def _meta(self):
        raise NotImplementedError()

    @classmethod
    def _annotations_tombstone(cls) -> _AnnotationsTombstone:
        return _AnnotationsTombstone()

    def __dask_annotations__(self):
        return {}

    def __dask_graph__(self):
        """Traverse expression tree, collect layers

        Subclasses generally do not want to override this method unless custom
        logic is required to treat (e.g. ignore) specific operands during graph
        generation.
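
        Examples
        --------
        A sketch, assuming a hypothetical two-partition expression whose
        ``_name`` hashes to ``"myexpr-abc123"``:

        >>> graph = expr.__dask_graph__()  # doctest: +SKIP
        >>> sorted(graph)  # doctest: +SKIP
        [('myexpr-abc123', 0), ('myexpr-abc123', 1)]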

        See also
        --------
        Expr._layer
        Expr._task
        """
        stack = [self]
        seen = set()
        layers = []
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            layers.append(expr._layer())
            for operand in expr.dependencies():
                stack.append(operand)

        return toolz.merge(layers)

    @property
    def dask(self):
        return self.__dask_graph__()

    def substitute(self, old, new) -> Expr:
        """Substitute a specific term within the expression

        Note that replacing non-`Expr` terms may produce
        unexpected results, and is not recommended.
        Substituting boolean values is not allowed.

        Parameters
        ----------
        old:
            Old term to find and replace.
        new:
            New term to replace instances of `old` with.

        Examples
        --------
        >>> (df + 10).substitute(10, 20)  # doctest: +SKIP
        df + 20
        """
        return self._substitute(old, new, _seen=set())

    def _substitute(self, old, new, _seen):
        if self._name in _seen:
            return self
        # Check if we are replacing a literal
        if isinstance(old, Expr):
            substitute_literal = False
            if self._name == old._name:
                return new
        else:
            substitute_literal = True
            if isinstance(old, bool):
                raise TypeError("Arguments to `substitute` cannot be bool.")

        new_exprs = []
        update = False
        for operand in self.operands:
            if isinstance(operand, Expr):
                val = operand._substitute(old, new, _seen)
                if operand._name != val._name:
                    update = True
                new_exprs.append(val)
            elif (
                "Fused" in type(self).__name__
                and isinstance(operand, list)
                and all(isinstance(op, Expr) for op in operand)
            ):
                # Special handling for `Fused`.
                # We make no promise to dive through a
                # list operand in general, but NEED to
                # do so for the `Fused.exprs` operand.
                val = []
                for op in operand:
                    val.append(op._substitute(old, new, _seen))
                    if val[-1]._name != op._name:
                        update = True
                new_exprs.append(val)
            elif (
                substitute_literal
                and not isinstance(operand, bool)
                and isinstance(operand, type(old))
                and operand == old
            ):
                new_exprs.append(new)
                update = True
            else:
                new_exprs.append(operand)

        if update:  # Only recreate if something changed
            return type(self)(*new_exprs)
        else:
            _seen.add(self._name)
            return self

    def substitute_parameters(self, substitutions: dict) -> Expr:
        """Substitute specific `Expr` parameters

        Parameters
        ----------
        substitutions:
            Mapping of parameter keys to new values. Keys that
            are not found in ``self._parameters`` will be ignored.
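
        Examples
        --------
        A sketch, assuming a hypothetical expression with an
        ``npartitions`` parameter:

        >>> expr.substitute_parameters({"npartitions": 4})  # doctest: +SKIP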

668 """ 

669 if not substitutions: 

670 return self 

671 

672 changed = False 

673 new_operands = [] 

674 for i, operand in enumerate(self.operands): 

675 if i < len(self._parameters) and self._parameters[i] in substitutions: 

676 new_operands.append(substitutions[self._parameters[i]]) 

677 changed = True 

678 else: 

679 new_operands.append(operand) 

680 if changed: 

681 return type(self)(*new_operands) 

682 return self 

683 

684 def _node_label_args(self): 

685 """Operands to include in the node label by `visualize`""" 

686 return self.dependencies() 

687 

688 def _to_graphviz( 

689 self, 

690 rankdir="BT", 

691 graph_attr=None, 

692 node_attr=None, 

693 edge_attr=None, 

694 **kwargs, 

695 ): 

696 from dask.dot import label, name 

697 

698 graphviz = import_required( 

699 "graphviz", 

700 "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` " 

701 "python library and the `graphviz` system library.\n\n" 

702 "Please either conda or pip install as follows:\n\n" 

703 " conda install python-graphviz # either conda install\n" 

704 " python -m pip install graphviz # or pip install and follow installation instructions", 

705 ) 

706 

707 graph_attr = graph_attr or {} 

708 node_attr = node_attr or {} 

709 edge_attr = edge_attr or {} 

710 

711 graph_attr["rankdir"] = rankdir 

712 node_attr["shape"] = "box" 

713 node_attr["fontname"] = "helvetica" 

714 

715 graph_attr.update(kwargs) 

716 g = graphviz.Digraph( 

717 graph_attr=graph_attr, 

718 node_attr=node_attr, 

719 edge_attr=edge_attr, 

720 ) 

721 

722 stack = [self] 

723 seen = set() 

724 dependencies = {} 

725 while stack: 

726 expr = stack.pop() 

727 

728 if expr._name in seen: 

729 continue 

730 seen.add(expr._name) 

731 

732 dependencies[expr] = set(expr.dependencies()) 

733 for dep in expr.dependencies(): 

734 stack.append(dep) 

735 

736 cache = {} 

737 for expr in dependencies: 

738 expr_name = name(expr) 

739 attrs = {} 

740 

741 # Make node label 

742 deps = [ 

743 funcname(type(dep)) if isinstance(dep, Expr) else str(dep) 

744 for dep in expr._node_label_args() 

745 ] 

746 _label = funcname(type(expr)) 

747 if deps: 

748 _label = f"{_label}({', '.join(deps)})" if deps else _label 

749 node_label = label(_label, cache=cache) 

750 

751 attrs.setdefault("label", str(node_label)) 

752 attrs.setdefault("fontsize", "20") 

753 g.node(expr_name, **attrs) 

754 

755 for expr, deps in dependencies.items(): 

756 expr_name = name(expr) 

757 for dep in deps: 

758 dep_name = name(dep) 

759 g.edge(dep_name, expr_name) 

760 

761 return g 

762 

763 def visualize(self, filename="dask-expr.svg", format=None, **kwargs): 

764 """ 

765 Visualize the expression graph. 

766 Requires ``graphviz`` to be installed. 

767 

768 Parameters 

769 ---------- 

770 filename : str or None, optional 

771 The name of the file to write to disk. If the provided `filename` 

772 doesn't include an extension, '.png' will be used by default. 

773 If `filename` is None, no file will be written, and the graph is 

774 rendered in the Jupyter notebook only. 

775 format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional 

776 Format in which to write output file. Default is 'svg'. 

777 **kwargs 

778 Additional keyword arguments to forward to ``to_graphviz``. 
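
        Examples
        --------
        A sketch, writing the rendered graph to an SVG file:

        >>> expr.visualize(filename="expr.svg")  # doctest: +SKIP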

779 """ 

780 from dask.dot import graphviz_to_file 

781 

782 g = self._to_graphviz(**kwargs) 

783 graphviz_to_file(g, filename, format) 

784 return g 

785 

786 def walk(self) -> Generator[Expr]: 

787 """Iterate through all expressions in the tree 

788 

789 Returns 

790 ------- 

791 nodes 

792 Generator of Expr instances in the graph. 

793 Ordering is a depth-first search of the expression tree 
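
        Examples
        --------
        A sketch, assuming a hypothetical ``Add`` of two distinct read
        nodes:

        >>> [type(node).__name__ for node in expr.walk()]  # doctest: +SKIP
        ['Add', 'ReadParquet', 'ReadCSV']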

794 """ 

795 stack = [self] 

796 seen = set() 

797 while stack: 

798 node = stack.pop() 

799 if node._name in seen: 

800 continue 

801 seen.add(node._name) 

802 

803 for dep in node.dependencies(): 

804 stack.append(dep) 

805 

806 yield node 

807 

808 def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]: 

809 """Search the expression graph for a specific operation type 

810 

811 Parameters 

812 ---------- 

813 operation 

814 The operation type to search for. 

815 

816 Returns 

817 ------- 

818 nodes 

819 Generator of `operation` instances. Ordering corresponds 

820 to a depth-first search of the expression graph. 
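
        Examples
        --------
        A sketch, assuming a hypothetical ``ReadCSV`` operation type:

        >>> list(expr.find_operations(ReadCSV))  # doctest: +SKIP
        [ReadCSV('data.csv')]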

821 """ 

822 assert ( 

823 isinstance(operation, tuple) 

824 and all(issubclass(e, Expr) for e in operation) 

825 or issubclass(operation, Expr) # type: ignore[arg-type] 

826 ), "`operation` must be`Expr` subclass)" 

827 return (expr for expr in self.walk() if isinstance(expr, operation)) 

828 

829 def __getattr__(self, key): 

830 try: 

831 return object.__getattribute__(self, key) 

832 except AttributeError as err: 

833 if key.startswith("_meta"): 

834 # Avoid a recursive loop if/when `self._meta*` 

835 # produces an `AttributeError` 

836 raise RuntimeError( 

837 f"Failed to generate metadata for {self}. " 

838 "This operation may not be supported by the current backend." 

839 ) 

840 

841 # Allow operands to be accessed as attributes 

842 # as long as the keys are not already reserved 

843 # by existing methods/properties 

844 _parameters = type(self)._parameters 

845 if key in _parameters: 

846 idx = _parameters.index(key) 

847 return self.operands[idx] 

848 

849 raise AttributeError( 

850 f"{err}\n\n" 

851 "This often means that you are attempting to use an unsupported " 

852 f"API function.." 

853 ) 


class SingletonExpr(Expr):
    """A singleton Expr class

    This is used to treat the subclassed expression as a singleton. Singletons
    are deduplicated by ``expr._name``, which is typically based on the
    ``dask.tokenize`` output.

    This is a crucial performance optimization for expressions that walk through
    an optimizer and are recreated repeatedly, but it isn't safe for objects
    that cannot be reliably or quickly tokenized.
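
    Examples
    --------
    A sketch, assuming a hypothetical subclass constructed twice from
    equal operands:

    >>> class MyExpr(SingletonExpr):
    ...     _parameters = ["x"]
    >>> MyExpr(1) is MyExpr(1)  # doctest: +SKIP
    True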

866 """ 

867 

868 _instances: weakref.WeakValueDictionary[str, SingletonExpr] 

869 

870 def __new__(cls, *args, _determ_token=None, **kwargs): 

871 if not hasattr(cls, "_instances"): 

872 cls._instances = weakref.WeakValueDictionary() 

873 inst = super().__new__(cls, *args, _determ_token=_determ_token, **kwargs) 

874 _name = inst._name 

875 if _name in cls._instances and cls.__init__ == object.__init__: 

876 return cls._instances[_name] 

877 

878 cls._instances[_name] = inst 

879 return inst 

880 

881 

882def collect_dependents(expr) -> defaultdict: 

883 dependents = defaultdict(list) 

884 stack = [expr] 

885 seen = set() 

886 while stack: 

887 node = stack.pop() 

888 if node._name in seen: 

889 continue 

890 seen.add(node._name) 

891 

892 for dep in node.dependencies(): 

893 stack.append(dep) 

894 dependents[dep._name].append(weakref.ref(node)) 

895 return dependents 

896 

897 

898def optimize(expr: Expr, fuse: bool = True) -> Expr: 

899 """High level query optimization 

900 

901 This leverages three optimization passes: 

902 

903 1. Class based simplification using the ``_simplify`` function and methods 

904 2. Blockwise fusion 

905 

906 Parameters 

907 ---------- 

908 expr: 

909 Input expression to optimize 

910 fuse: 

911 whether or not to turn on blockwise fusion 
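
    Examples
    --------
    A sketch, assuming a hypothetical expression ``expr``:

    >>> optimize(expr)  # simplify, lower, and fuse  # doctest: +SKIP
    >>> optimize(expr, fuse=False)  # stop at "simplified-physical"  # doctest: +SKIP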

    See Also
    --------
    simplify
    optimize_blockwise_fusion
    """
    stage: OptimizerStage = "fused" if fuse else "simplified-physical"

    return optimize_until(expr, stage)


def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
    result = expr
    if stage == "logical":
        return result

    # Simplify
    expr = result.simplify()
    if stage == "simplified-logical":
        return expr

    # Manipulate Expression to make it more efficient
    if dask.config.get("optimization.tune.active", True):
        expr = expr.rewrite(kind="tune", rewritten={})
    if stage == "tuned-logical":
        return expr

    # Lower
    expr = expr.lower_completely()
    if stage == "physical":
        return expr

    # Simplify again
    expr = expr.simplify()
    if stage == "simplified-physical":
        return expr

    # Final graph-specific optimizations
    expr = expr.fuse()
    if stage == "fused":
        return expr

    raise ValueError(f"Stage {stage!r} not supported.")


class LLGExpr(Expr):
    """Low Level Graph Expression"""

    _parameters = ["dsk"]

    def __dask_keys__(self):
        return list(self.operand("dsk"))

    def _layer(self) -> dict:
        return ensure_dict(self.operand("dsk"))


class HLGExpr(Expr):
    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @property
    def hlg(self):
        return self.operand("dsk")

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when going
        # through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )
        if optimize_graph and not hasattr(collection, "__dask_optimize__"):
            warnings.warn(
                f"Collection {type(collection)} does not define a "
                "`__dask_optimize__` method. In the future this will raise. "
                "If no optimization is desired, please set this to `None`.",
                PendingDeprecationWarning,
            )
            low_level_optimizer = None
        else:
            low_level_optimizer = (
                collection.__dask_optimize__ if optimize_graph else None
            )
        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__(),
        )

    def finalize_compute(self):
        return HLGFinalizeCompute(
            self,
            low_level_optimizer=self.low_level_optimizer,
            output_keys=self.output_keys,
            postcompute=self.postcompute,
        )

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annotations_by_type[annot_type].update(
                        {k: (value(k) if callable(value) else value) for k in layer}
                    )
        return dict(annotations_by_type)

    def __dask_keys__(self):
        if (keys := self.operand("output_keys")) is not None:
            return keys
        dsk = self.hlg
        # Note: This will materialize
        dependencies = dsk.get_all_dependencies()
        leafs = set(dependencies)
        for val in dependencies.values():
            leafs -= val
        self.output_keys = list(leafs)
        return self.output_keys

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        keys = self.__dask_keys__()
        dsk = self.hlg
        if (optimizer := self.low_level_optimizer) is not None:
            dsk = optimizer(dsk, keys)
        return HighLevelGraph.merge(dsk)

    @property
    def deterministic_token(self):
        if not self._determ_token:
            self._determ_token = uuid.uuid4().hex
        return self._determ_token

    def _layer(self) -> dict:
        dsk = self._optimized_dsk
        return ensure_dict(dsk)


class _HLGExprGroup(HLGExpr):
    # Identical to HLGExpr
    # Used internally to determine how output keys are supposed to be returned
    pass


class _HLGExprSequence(Expr):

    def __getitem__(self, other):
        return self.operands[other]

    def _operands_for_repr(self):
        return [
            f"name={self.operand('name')!r}",
            f"dsk={self.operand('dsk')!r}",
        ]

    def _tree_repr_lines(self, indent=0, recursive=True):
        return self._operands_for_repr()

    def finalize_compute(self):
        return _HLGExprSequence(*[op.finalize_compute() for op in self.operands])

    def _tune_down(self):
        if len(self.operands) == 1:
            return None
        from dask.highlevelgraph import HighLevelGraph

        groups = toolz.groupby(
            lambda x: x.low_level_optimizer if isinstance(x, HLGExpr) else None,
            self.operands,
        )
        exprs = []
        changed = False
        for optimizer, group in groups.items():
            if len(group) > 1:
                graphs = [expr.hlg for expr in group]

                changed = True
                dsk = HighLevelGraph.merge(*graphs)
                hlg_group = _HLGExprGroup(
                    dsk=dsk,
                    low_level_optimizer=optimizer,
                    output_keys=[v.__dask_keys__() for v in group],
                    postcompute=[g.postcompute for g in group],
                )
                exprs.append(hlg_group)
            else:
                exprs.append(group[0])
        if not changed:
            return None
        return _HLGExprSequence(*exprs)

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        hlgexpr: HLGExpr
        graphs = []
        # _tune_down ensures there is only one HLGExpr per optimizer/finalizer
        for hlgexpr in self.operands:
            keys = hlgexpr.__dask_keys__()
            dsk = hlgexpr.hlg
            if (optimizer := hlgexpr.low_level_optimizer) is not None:
                dsk = optimizer(dsk, keys)
            graphs.append(dsk)

        return HighLevelGraph.merge(*graphs)

    def __dask_graph__(self):
        # This class has to override this and not just _layer to ensure the HLGs
        # are not optimized individually
        return ensure_dict(self._optimized_dsk)

    _layer = __dask_graph__

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annots = list(
                        (k, (value(k) if callable(value) else value)) for k in layer
                    )
                    annotations_by_type[annot_type].update(
                        {
                            k: v
                            for k, v in annots
                            if not isinstance(v, _AnnotationsTombstone)
                        }
                    )
                    if not annotations_by_type[annot_type]:
                        del annotations_by_type[annot_type]
        return dict(annotations_by_type)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            if isinstance(op, _HLGExprGroup):
                all_keys.extend(op.__dask_keys__())
            else:
                all_keys.append(op.__dask_keys__())
        return all_keys


class _ExprSequence(Expr):
    """A sequence of expressions

    This is used to optimize multiple collections together, e.g. when they
    are computed simultaneously with ``dask.compute((Expr1, Expr2))``.
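
    Examples
    --------
    A sketch, assuming hypothetical expressions ``expr1`` and ``expr2``:

    >>> _ExprSequence(expr1, expr2).__dask_keys__()  # doctest: +SKIP
    [[('a-...', 0)], [('b-...', 0)]]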

1195 """ 

1196 

1197 def __getitem__(self, other): 

1198 return self.operands[other] 

1199 

1200 def _layer(self) -> dict: 

1201 return toolz.merge(op._layer() for op in self.operands) 

1202 

1203 def __dask_keys__(self) -> list: 

1204 all_keys = [] 

1205 for op in self.operands: 

1206 all_keys.append(list(op.__dask_keys__())) 

1207 return all_keys 

1208 

1209 def __repr__(self): 

1210 return "ExprSequence(" + ", ".join(map(repr, self.operands)) + ")" 

1211 

1212 __str__ = __repr__ 

1213 

1214 def finalize_compute(self): 

1215 return _ExprSequence( 

1216 *(op.finalize_compute() for op in self.operands), 

1217 ) 

1218 

1219 def __dask_annotations__(self): 

1220 annotations_by_type = {} 

1221 for op in self.operands: 

1222 for k, v in op.__dask_annotations__().items(): 

1223 annotations_by_type.setdefault(k, {}).update(v) 

1224 return annotations_by_type 

1225 

1226 def __len__(self): 

1227 return len(self.operands) 

1228 

1229 def __iter__(self): 

1230 return iter(self.operands) 

1231 

1232 def _simplify_down(self): 

1233 from dask.highlevelgraph import HighLevelGraph 

1234 

1235 issue_warning = False 

1236 hlgs = [] 

1237 if any( 

1238 isinstance(op, (HLGExpr, HLGFinalizeCompute, dict)) for op in self.operands 

1239 ): 

1240 for op in self.operands: 

1241 if isinstance(op, (HLGExpr, HLGFinalizeCompute)): 

1242 hlgs.append(op) 

1243 elif isinstance(op, dict): 

1244 hlgs.append( 

1245 HLGExpr( 

1246 dsk=HighLevelGraph.from_collections( 

1247 str(id(op)), op, dependencies=() 

1248 ) 

1249 ) 

1250 ) 

1251 else: 

1252 issue_warning = True 

1253 opt = op.optimize() 

1254 hlgs.append( 

1255 HLGExpr( 

1256 dsk=HighLevelGraph.from_collections( 

1257 opt._name, opt.__dask_graph__(), dependencies=() 

1258 ) 

1259 ) 

1260 ) 

1261 if issue_warning: 

1262 warnings.warn( 

1263 "Computing mixed collections that are backed by " 

1264 "HighlevelGraphs/dicts and Expressions. " 

1265 "This forces Expressions to be materialized. " 

1266 "It is recommended to use only one type and separate the dask." 

1267 "compute calls if necessary.", 

1268 UserWarning, 

1269 ) 

1270 if not hlgs: 

1271 return None 

1272 return _HLGExprSequence(*hlgs) 


class _AnnotationsTombstone: ...


class FinalizeCompute(Expr):
    _parameters = ["expr"]

    def _simplify_down(self):
        return self.expr.finalize_compute()


def _convert_dask_keys(keys):
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    new_keys = []
    for key in keys:
        if isinstance(key, list):
            new_keys.append(_convert_dask_keys(key))
        else:
            new_keys.append(TaskRef(key))
    return List(*new_keys)


class HLGFinalizeCompute(HLGExpr):

    def _simplify_down(self):
        if not self.postcompute:
            return self.dsk

        from dask.delayed import Delayed

        # Skip finalization for Delayed
        if self.dsk.postcompute == Delayed.__dask_postcompute__(self.dsk):
            return self.dsk
        return self

    @property
    def _name(self):
        return f"finalize-{super()._name}"

    def __dask_graph__(self):
        # The baseclass __dask_graph__ will not just materialize this layer but
        # also that of its dependencies, i.e. it will render the finalized and
        # the non-finalized graph and combine them. We only want the finalized
        # one, so we're overriding this.
        # This is an artifact generated since the wrapped expression is
        # identified automatically as a dependency but HLG expressions are not
        # working in this layered way.
        return self._layer()

    @property
    def hlg(self):
        expr = self.operand("dsk")
        layers = expr.dsk.layers.copy()
        deps = expr.dsk.dependencies.copy()
        keys = expr.__dask_keys__()
        if isinstance(expr.postcompute, list):
            postcomputes = expr.postcompute
        else:
            postcomputes = [expr.postcompute]
        tasks = [
            Task(self._name, func, _convert_dask_keys(keys), *extra_args)
            for func, extra_args in postcomputes
        ]
        from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

        leafs = set(deps)
        for val in deps.values():
            leafs -= val
        for t in tasks:
            layers[t.key] = MaterializedLayer({t.key: t})
            deps[t.key] = leafs
        return HighLevelGraph(layers, dependencies=deps)

    def __dask_keys__(self):
        return [self._name]


class ProhibitReuse(Expr):
    """
    An expression that guarantees that all keys are suffixed with a unique id.
    This can be used to break a common subexpression apart.
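
    Examples
    --------
    A sketch, assuming a hypothetical expression ``expr``:

    >>> ProhibitReuse(expr).__dask_keys__()  # doctest: +SKIP

    Every key is suffixed with a fresh ``uuid4`` hex string, so two
    ``ProhibitReuse`` wrappers around the same expression never share
    keys.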

1357 """ 

1358 

1359 _parameters = ["expr"] 

1360 _ALLOWED_TYPES = [HLGExpr, LLGExpr, HLGFinalizeCompute, _HLGExprSequence] 

1361 

1362 def __dask_keys__(self): 

1363 return self._modify_keys(self.expr.__dask_keys__()) 

1364 

1365 @staticmethod 

1366 def _identity(obj): 

1367 return obj 

1368 

1369 @functools.cached_property 

1370 def _suffix(self): 

1371 return uuid.uuid4().hex 

1372 

1373 def _modify_keys(self, k): 

1374 if isinstance(k, list): 

1375 return [self._modify_keys(kk) for kk in k] 

1376 elif isinstance(k, tuple): 

1377 return (self._modify_keys(k[0]),) + k[1:] 

1378 elif isinstance(k, (int, float)): 

1379 k = str(k) 

1380 return f"{k}-{self._suffix}" 

1381 

1382 def _simplify_down(self): 

1383 # FIXME: Shuffling cannot be rewritten since the barrier key is 

1384 # hardcoded. Skipping this here should do the trick most of the time 

1385 if not isinstance( 

1386 self.expr, 

1387 tuple(self._ALLOWED_TYPES), 

1388 ): 

1389 return self.expr 

1390 

1391 def __dask_graph__(self): 

1392 try: 

1393 from distributed.shuffle._core import P2PBarrierTask 

1394 except ModuleNotFoundError: 

1395 P2PBarrierTask = type(None) 

1396 dsk = convert_legacy_graph(self.expr.__dask_graph__()) 

1397 

1398 subs = {old_key: self._modify_keys(old_key) for old_key in dsk} 

1399 dsk2 = {} 

1400 for old_key, new_key in subs.items(): 

1401 t = dsk[old_key] 

1402 if isinstance(t, P2PBarrierTask): 

1403 warnings.warn( 

1404 "Cannot block reusing for graphs including a " 

1405 "P2PBarrierTask. This may cause unexpected results. " 

1406 "This typically happens when converting a dask " 

1407 "DataFrame to delayed objects.", 

1408 UserWarning, 

1409 ) 

1410 return dsk 

1411 dsk2[new_key] = Task( 

1412 new_key, 

1413 ProhibitReuse._identity, 

1414 t.substitute(subs), 

1415 ) 

1416 

1417 dsk2.update(dsk) 

1418 return dsk2 

1419 

1420 _layer = __dask_graph__