from __future__ import annotations

import functools
import os
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal

import toolz

import dask
from dask._task_spec import Task, convert_legacy_graph
from dask.tokenize import _tokenize_deterministic
from dask.typing import Key
from dask.utils import ensure_dict, funcname, import_required

if TYPE_CHECKING:
    # TODO import from typing (requires Python >=3.10)
    from typing import Any, TypeAlias

    from dask.highlevelgraph import HighLevelGraph

OptimizerStage: TypeAlias = Literal[
    "logical",
    "simplified-logical",
    "tuned-logical",
    "physical",
    "simplified-physical",
    "fused",
]


def _unpack_collections(o):
    from dask.delayed import Delayed

    if isinstance(o, Expr):
        return o

    if hasattr(o, "expr") and not isinstance(o, Delayed):
        return o.expr
    else:
        return o


class Expr:
    _parameters: list[str] = []
    _defaults: dict[str, Any] = {}

    _pickle_functools_cache: bool = True

    operands: list

    _determ_token: str | None

    def __new__(cls, *args, _determ_token=None, **kwargs):
        operands = list(args)
        for parameter in cls._parameters[len(operands) :]:
            try:
                operands.append(kwargs.pop(parameter))
            except KeyError:
                operands.append(cls._defaults[parameter])
        assert not kwargs, kwargs
        inst = object.__new__(cls)

        inst._determ_token = _determ_token
        inst.operands = [_unpack_collections(o) for o in operands]
        # This is typically cached. Make sure the cache is populated by calling
        # it once
        inst._name
        return inst
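
    # A minimal sketch (not part of dask itself) of how a concrete subclass
    # binds operands: positional arguments fill ``_parameters`` in order and
    # missing ones fall back to ``_defaults``. ``ScaledAdd`` is hypothetical.
    #
    #     class ScaledAdd(Expr):
    #         _parameters = ["left", "right", "scale"]
    #         _defaults = {"scale": 1}
    #
    #     expr = ScaledAdd("a", "b")   # scale defaults to 1
    #     expr.scale                   # -> 1, resolved via __getattr__
    #     expr.operand("left")         # -> "a", unambiguous operand access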

    def _tune_down(self):
        return None

    def _tune_up(self, parent):
        return None

    def finalize_compute(self):
        return self

    def _operands_for_repr(self):
        return [
            f"{param}={repr(op)}" for param, op in zip(self._parameters, self.operands)
        ]

    def __str__(self):
        s = ", ".join(self._operands_for_repr())
        return f"{type(self).__name__}({s})"

    def __repr__(self):
        return str(self)

    def _tree_repr_argument_construction(self, i, op, header):
        try:
            param = self._parameters[i]
            default = self._defaults[param]
        except (IndexError, KeyError):
            param = self._parameters[i] if i < len(self._parameters) else ""
            default = "--no-default--"

        if repr(op) != repr(default):
            if param:
                header += f" {param}={repr(op)}"
            else:
                header += repr(op)
        return header

    def _tree_repr_lines(self, indent=0, recursive=True):
        # Return a list of lines so tree_repr/pprint can treat the base case
        # and the list-returning subclass overrides uniformly
        return [" " * indent + repr(self)]

    def tree_repr(self):
        return os.linesep.join(self._tree_repr_lines())

    def analyze(self, filename: str | None = None, format: str | None = None) -> None:
        from dask.dataframe.dask_expr._expr import Expr as DFExpr
        from dask.dataframe.dask_expr.diagnostics import analyze

        if not isinstance(self, DFExpr):
            raise TypeError(
                "analyze is only supported for dask.dataframe.Expr objects."
            )
        return analyze(self, filename=filename, format=format)

    def explain(
        self, stage: OptimizerStage = "fused", format: str | None = None
    ) -> None:
        from dask.dataframe.dask_expr.diagnostics import explain

        return explain(self, stage, format)

    def pprint(self):
        for line in self._tree_repr_lines():
            print(line)

    def __hash__(self):
        return hash(self._name)

    def __dask_tokenize__(self):
        if not self._determ_token:
            # If the subclass does not implement a __dask_tokenize__ we'll want
            # to tokenize all operands.
            # Note how this differs from the implementation of
            # Expr.deterministic_token
            self._determ_token = _tokenize_deterministic(type(self), *self.operands)
        return self._determ_token

    def __dask_keys__(self):
        """The keys for this expression

        This is used to determine the keys of the output collection
        when this expression is computed.

        Returns
        -------
        keys: list
            The keys for this expression
        """
        return [(self._name, i) for i in range(self.npartitions)]

    @staticmethod
    def _reconstruct(*args):
        typ, *operands, token, cache = args
        inst = typ(*operands, _determ_token=token)
        for k, v in cache.items():
            inst.__dict__[k] = v
        return inst

    def __reduce__(self):
        if dask.config.get("dask-expr-no-serialize", False):
            raise RuntimeError(f"Serializing a {type(self)} object")
        cache = {}
        if type(self)._pickle_functools_cache:
            for k, v in type(self).__dict__.items():
                if isinstance(v, functools.cached_property) and k in self.__dict__:
                    cache[k] = getattr(self, k)

        return Expr._reconstruct, tuple(
            [type(self), *self.operands, self.deterministic_token, cache]
        )
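
    # Behavior sketch (assumes a picklable subclass): round-tripping preserves
    # the deterministic token and any populated ``functools.cached_property``
    # values, so ``_name`` does not need to be recomputed after unpickling.
    #
    #     import pickle
    #     expr2 = pickle.loads(pickle.dumps(expr))
    #     expr2._name == expr._name   # -> True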

    def _depth(self, cache=None):
        """Depth of the expression tree

        Returns
        -------
        depth: int
        """
        if cache is None:
            cache = {}
        if not self.dependencies():
            return 1
        else:
            result = []
            for expr in self.dependencies():
                if expr._name in cache:
                    result.append(cache[expr._name])
                else:
                    result.append(expr._depth(cache) + 1)
                    cache[expr._name] = result[-1]
            return max(result)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ["operands", "_determ_token"]:
            object.__setattr__(self, name, value)
            return
        try:
            params = type(self)._parameters
            operands = object.__getattribute__(self, "operands")
            operands[params.index(name)] = value
        except ValueError:
            raise AttributeError(
                f"{type(self).__name__} object has no attribute {name}"
            )

    def operand(self, key):
        # Access an operand unambiguously
        # (e.g. if the key is reserved by a method/property)
        return self.operands[type(self)._parameters.index(key)]

    def dependencies(self):
        # Dependencies are `Expr` operands only
        return [operand for operand in self.operands if isinstance(operand, Expr)]

    def _task(self, key: Key, index: int) -> Task:
        """The task for the i'th partition

        Parameters
        ----------
        key:
            The key of the task, i.e. ``(self._name, index)``
        index:
            The index of the partition of this dataframe

        Examples
        --------
        >>> class Add(Expr):
        ...     def _task(self, key, index):
        ...         return Task(
        ...             key,
        ...             operator.add,
        ...             TaskRef((self.left._name, index)),
        ...             TaskRef((self.right._name, index))
        ...         )

        Returns
        -------
        task:
            The Dask task to compute this partition

        See Also
        --------
        Expr._layer
        """
        raise NotImplementedError(
            "Expressions should define either _layer (full dictionary) or _task"
            f" (single task). This expression {type(self)} defines neither"
        )

    def _layer(self) -> dict:
        """The graph layer added by this expression.

        Simple expressions that apply one task per partition can choose to only
        implement `Expr._task` instead.

        Examples
        --------
        >>> class Add(Expr):
        ...     def _layer(self):
        ...         return {
        ...             name: Task(
        ...                 name,
        ...                 operator.add,
        ...                 TaskRef((self.left._name, i)),
        ...                 TaskRef((self.right._name, i))
        ...             )
        ...             for i, name in enumerate(self.__dask_keys__())
        ...         }

        Returns
        -------
        layer: dict
            The Dask task graph added by this expression

        See Also
        --------
        Expr._task
        Expr.__dask_graph__
        """

        return {
            (self._name, i): self._task((self._name, i), i)
            for i in range(self.npartitions)
        }
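
    # A sketch (not part of dask) tying the pieces together: the hypothetical
    # ``Constant`` emits one task per partition via ``_task`` and inherits the
    # default ``_layer``/``__dask_graph__`` traversal. ``Task`` is the class
    # imported at the top of this module.
    #
    #     class Constant(Expr):
    #         _parameters = ["value", "npartitions"]
    #
    #         def _task(self, key, index):
    #             # Each partition simply materializes ``value``.
    #             return Task(key, lambda v: v, self.value)
    #
    #     c = Constant(42, 2)
    #     c.__dask_graph__()   # {(c._name, 0): Task(...), (c._name, 1): Task(...)}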

    def rewrite(self, kind: str, rewritten):
        """Rewrite an expression

        This leverages the ``._{kind}_down`` and ``._{kind}_up``
        methods defined on each class

        Returns
        -------
        expr:
            output expression
        """
        if self._name in rewritten:
            return rewritten[self._name]

        expr = self
        down_name = f"_{kind}_down"
        up_name = f"_{kind}_up"
        while True:
            _continue = False

            # Rewrite this node
            out = getattr(expr, down_name)()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out
                continue

            # Allow children to rewrite their parents
            for child in expr.dependencies():
                out = getattr(child, up_name)(expr)
                if out is None:
                    out = expr
                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    _continue = True
                    break

            if _continue:
                continue

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    new = operand.rewrite(kind=kind, rewritten=rewritten)
                    rewritten[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)
                continue
            else:
                break

        return expr
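
    # Usage sketch: ``kind`` selects the hook pair, so the optimizer's tuning
    # pass is spelled as below and dispatches to ``_tune_down``/``_tune_up``
    # on every node until a fixpoint is reached.
    #
    #     tuned = expr.rewrite(kind="tune", rewritten={})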

    def simplify_once(self, dependents: defaultdict, simplified: dict):
        """Simplify an expression

        This leverages the ``._simplify_down`` and ``._simplify_up``
        methods defined on each class

        Parameters
        ----------
        dependents: defaultdict[list]
            The dependents for every node.
        simplified: dict
            Cache of simplified expressions for these dependents.

        Returns
        -------
        expr:
            output expression
        """
        # Check if we've already simplified for these dependents
        if self._name in simplified:
            return simplified[self._name]

        expr = self

        while True:
            out = expr._simplify_down()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out

            # Allow children to simplify their parents
            for child in expr.dependencies():
                out = child._simplify_up(expr, dependents)
                if out is None:
                    out = expr

                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    break

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    # Bandaid for now, waiting for Singleton
                    dependents[operand._name].append(weakref.ref(expr))
                    new = operand.simplify_once(
                        dependents=dependents, simplified=simplified
                    )
                    simplified[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)

            break

        return expr

    def optimize(self, fuse: bool = False) -> Expr:
        stage: OptimizerStage = "fused" if fuse else "simplified-physical"

        return optimize_until(self, stage)

    def fuse(self) -> Expr:
        return self

    def simplify(self) -> Expr:
        expr = self
        seen = set()
        while True:
            dependents = collect_dependents(expr)
            new = expr.simplify_once(dependents=dependents, simplified={})
            if new._name == expr._name:
                break
            if new._name in seen:
                raise RuntimeError(
                    f"Optimizer does not converge. {expr!r} simplified to {new!r} which was already seen. "
                    "Please report this issue on the dask issue tracker with a minimal reproducer."
                )
            seen.add(new._name)
            expr = new
        return expr
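
    # Behavior sketch: ``simplify`` runs ``simplify_once`` to a fixpoint and
    # raises if a rewrite oscillates, so re-simplifying is a no-op.
    #
    #     new = expr.simplify()
    #     new.simplify()._name == new._name   # -> True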

    def _simplify_down(self):
        return

    def _simplify_up(self, parent, dependents):
        return

    def lower_once(self, lowered: dict):
        # Check for a cached result
        try:
            return lowered[self._name]
        except KeyError:
            pass

        expr = self

        # Lower this node
        out = expr._lower()
        if out is None:
            out = expr
        if not isinstance(out, Expr):
            return out

        # Lower all children
        new_operands = []
        changed = False
        for operand in out.operands:
            if isinstance(operand, Expr):
                new = operand.lower_once(lowered)
                if new._name != operand._name:
                    changed = True
            else:
                new = operand
            new_operands.append(new)

        if changed:
            out = type(out)(*new_operands)

        # Cache the result and return
        return lowered.setdefault(self._name, out)

    def lower_completely(self) -> Expr:
        """Lower an expression completely

        This calls the ``lower_once`` method in a loop
        until nothing changes. This function does not
        apply any other optimizations (like ``simplify``).

        Returns
        -------
        expr:
            output expression

        See Also
        --------
        Expr.lower_once
        Expr._lower
        """
        # Lower until nothing changes
        expr = self
        lowered: dict = {}
        while True:
            new = expr.lower_once(lowered)
            if new._name == expr._name:
                break
            expr = new
        return expr

    def _lower(self):
        return

    @functools.cached_property
    def _funcname(self) -> str:
        return funcname(type(self)).lower()

    @property
    def deterministic_token(self):
        if not self._determ_token:
            # Just tokenize self to fall back on __dask_tokenize__
            # Note how this differs from the implementation of __dask_tokenize__
            self._determ_token = self.__dask_tokenize__()
        return self._determ_token

    @functools.cached_property
    def _name(self) -> str:
        return self._funcname + "-" + self.deterministic_token
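
    # Naming sketch: ``_name`` is "<lowercased class name>-<token>", e.g. the
    # hypothetical ``ScaledAdd`` above would be named "scaledadd-<token>".
    # Equal names are treated as equal (deduplicated) expressions throughout
    # this module.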

    @property
    def _meta(self):
        raise NotImplementedError()

    @classmethod
    def _annotations_tombstone(cls) -> _AnnotationsTombstone:
        return _AnnotationsTombstone()

    def __dask_annotations__(self):
        return {}

    def __dask_graph__(self):
        """Traverse expression tree, collect layers

        Subclasses generally do not want to override this method unless custom
        logic is required to treat (e.g. ignore) specific operands during graph
        generation.

        See Also
        --------
        Expr._layer
        Expr._task
        """
        stack = [self]
        seen = set()
        layers = []
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            layers.append(expr._layer())
            for operand in expr.dependencies():
                stack.append(operand)

        return toolz.merge(layers)

    @property
    def dask(self):
        return self.__dask_graph__()

    def substitute(self, old, new) -> Expr:
        """Substitute a specific term within the expression

        Note that replacing non-`Expr` terms may produce
        unexpected results, and is not recommended.
        Substituting boolean values is not allowed.

        Parameters
        ----------
        old:
            Old term to find and replace.
        new:
            New term to replace instances of `old` with.

        Examples
        --------
        >>> (df + 10).substitute(10, 20)  # doctest: +SKIP
        df + 20
        """
        return self._substitute(old, new, _seen=set())

    def _substitute(self, old, new, _seen):
        if self._name in _seen:
            return self
        # Check if we are replacing a literal
        if isinstance(old, Expr):
            substitute_literal = False
            if self._name == old._name:
                return new
        else:
            substitute_literal = True
            if isinstance(old, bool):
                raise TypeError("Arguments to `substitute` cannot be bool.")

        new_exprs = []
        update = False
        for operand in self.operands:
            if isinstance(operand, Expr):
                val = operand._substitute(old, new, _seen)
                if operand._name != val._name:
                    update = True
                new_exprs.append(val)
            elif (
                "Fused" in type(self).__name__
                and isinstance(operand, list)
                and all(isinstance(op, Expr) for op in operand)
            ):
                # Special handling for `Fused`.
                # We make no promise to dive through a
                # list operand in general, but NEED to
                # do so for the `Fused.exprs` operand.
                val = []
                for op in operand:
                    val.append(op._substitute(old, new, _seen))
                    if val[-1]._name != op._name:
                        update = True
                new_exprs.append(val)
            elif (
                substitute_literal
                and not isinstance(operand, bool)
                and isinstance(operand, type(old))
                and operand == old
            ):
                new_exprs.append(new)
                update = True
            else:
                new_exprs.append(operand)

        if update:  # Only recreate if something changed
            return type(self)(*new_exprs)
        else:
            _seen.add(self._name)
            return self

    def substitute_parameters(self, substitutions: dict) -> Expr:
        """Substitute specific `Expr` parameters

        Parameters
        ----------
        substitutions:
            Mapping of parameter keys to new values. Keys that
            are not found in ``self._parameters`` will be ignored.
        """
        if not substitutions:
            return self

        changed = False
        new_operands = []
        for i, operand in enumerate(self.operands):
            if i < len(self._parameters) and self._parameters[i] in substitutions:
                new_operands.append(substitutions[self._parameters[i]])
                changed = True
            else:
                new_operands.append(operand)
        if changed:
            return type(self)(*new_operands)
        return self
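
    # Usage sketch (with the hypothetical ``ScaledAdd`` from above): the
    # expression is only rebuilt when a substitution actually applies.
    #
    #     expr = ScaledAdd("a", "b", scale=2)
    #     expr.substitute_parameters({"scale": 3}).scale   # -> 3
    #     expr.substitute_parameters({}) is expr           # -> True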

    def _node_label_args(self):
        """Operands to include in the node label by `visualize`"""
        return self.dependencies()

    def _to_graphviz(
        self,
        rankdir="BT",
        graph_attr=None,
        node_attr=None,
        edge_attr=None,
        **kwargs,
    ):
        from dask.dot import label, name

        graphviz = import_required(
            "graphviz",
            "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
            "python library and the `graphviz` system library.\n\n"
            "Please either conda or pip install as follows:\n\n"
            "  conda install python-graphviz     # either conda install\n"
            "  python -m pip install graphviz    # or pip install and follow installation instructions",
        )

        graph_attr = graph_attr or {}
        node_attr = node_attr or {}
        edge_attr = edge_attr or {}

        graph_attr["rankdir"] = rankdir
        node_attr["shape"] = "box"
        node_attr["fontname"] = "helvetica"

        graph_attr.update(kwargs)
        g = graphviz.Digraph(
            graph_attr=graph_attr,
            node_attr=node_attr,
            edge_attr=edge_attr,
        )

        stack = [self]
        seen = set()
        dependencies = {}
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            dependencies[expr] = set(expr.dependencies())
            for dep in expr.dependencies():
                stack.append(dep)

        cache = {}
        for expr in dependencies:
            expr_name = name(expr)
            attrs = {}

            # Make node label
            deps = [
                funcname(type(dep)) if isinstance(dep, Expr) else str(dep)
                for dep in expr._node_label_args()
            ]
            _label = funcname(type(expr))
            if deps:
                _label = f"{_label}({', '.join(deps)})"
            node_label = label(_label, cache=cache)

            attrs.setdefault("label", str(node_label))
            attrs.setdefault("fontsize", "20")
            g.node(expr_name, **attrs)

        for expr, deps in dependencies.items():
            expr_name = name(expr)
            for dep in deps:
                dep_name = name(dep)
                g.edge(dep_name, expr_name)

        return g

    def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
        """
        Visualize the expression graph.
        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.svg' will be used by default.
            If `filename` is None, no file will be written, and the graph is
            rendered in the Jupyter notebook only.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'svg'.
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.
        """
        from dask.dot import graphviz_to_file

        g = self._to_graphviz(**kwargs)
        graphviz_to_file(g, filename, format)
        return g

    def walk(self) -> Generator[Expr]:
        """Iterate through all expressions in the tree

        Returns
        -------
        nodes
            Generator of Expr instances in the graph.
            Ordering is a depth-first search of the expression tree
        """
        stack = [self]
        seen = set()
        while stack:
            node = stack.pop()
            if node._name in seen:
                continue
            seen.add(node._name)

            for dep in node.dependencies():
                stack.append(dep)

            yield node

    def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
        """Search the expression graph for a specific operation type

        Parameters
        ----------
        operation
            The operation type to search for.

        Returns
        -------
        nodes
            Generator of `operation` instances. Ordering corresponds
            to a depth-first search of the expression graph.
        """
        assert (
            isinstance(operation, tuple)
            and all(issubclass(e, Expr) for e in operation)
            or issubclass(operation, Expr)  # type: ignore
        ), "`operation` must be an `Expr` subclass"
        return (expr for expr in self.walk() if isinstance(expr, operation))
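
    # Usage sketch (``ScaledAdd`` and ``Constant`` are the hypothetical
    # classes sketched above):
    #
    #     tree = ScaledAdd(Constant(1, 2), Constant(2, 2))
    #     [type(e).__name__ for e in tree.walk()]
    #     # -> ['ScaledAdd', 'Constant', 'Constant'] (depth-first)
    #     list(tree.find_operations(Constant))   # just the two leaves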

    def __getattr__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError as err:
            if key.startswith("_meta"):
                # Avoid a recursive loop if/when `self._meta*`
                # produces an `AttributeError`
                raise RuntimeError(
                    f"Failed to generate metadata for {self}. "
                    "This operation may not be supported by the current backend."
                )

            # Allow operands to be accessed as attributes
            # as long as the keys are not already reserved
            # by existing methods/properties
            _parameters = type(self)._parameters
            if key in _parameters:
                idx = _parameters.index(key)
                return self.operands[idx]

            raise AttributeError(
                f"{err}\n\n"
                "This often means that you are attempting to use an unsupported "
                "API function."
            )


class SingletonExpr(Expr):
    """A singleton Expr class

    This is used to treat the subclassed expression as a singleton. Singletons
    are deduplicated by expr._name, which is typically based on the dask.tokenize
    output.

    This is a crucial performance optimization for expressions that walk through
    an optimizer and are recreated repeatedly, but it isn't safe for objects that
    cannot be reliably or quickly tokenized.
    """

    _instances: weakref.WeakValueDictionary[str, SingletonExpr]

    def __new__(cls, *args, _determ_token=None, **kwargs):
        if not hasattr(cls, "_instances"):
            cls._instances = weakref.WeakValueDictionary()
        inst = super().__new__(cls, *args, _determ_token=_determ_token, **kwargs)
        _name = inst._name
        if _name in cls._instances and cls.__init__ == object.__init__:
            return cls._instances[_name]

        cls._instances[_name] = inst
        return inst
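
# Behavior sketch (``SConst`` is hypothetical): structurally identical
# instances collapse to one object, keyed by ``_name``.
#
#     class SConst(SingletonExpr):
#         _parameters = ["value"]
#
#     SConst(1) is SConst(1)   # -> True
#     SConst(1) is SConst(2)   # -> False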


def collect_dependents(expr) -> defaultdict:
    dependents = defaultdict(list)
    stack = [expr]
    seen = set()
    while stack:
        node = stack.pop()
        if node._name in seen:
            continue
        seen.add(node._name)

        for dep in node.dependencies():
            stack.append(dep)
            dependents[dep._name].append(weakref.ref(node))
    return dependents


def optimize(expr: Expr, fuse: bool = True) -> Expr:
    """High level query optimization

    This leverages two optimization passes:

    1. Class based simplification using the ``_simplify`` function and methods
    2. Blockwise fusion

    Parameters
    ----------
    expr:
        Input expression to optimize
    fuse:
        whether or not to turn on blockwise fusion

    See Also
    --------
    simplify
    optimize_blockwise_fusion
    """
    stage: OptimizerStage = "fused" if fuse else "simplified-physical"

    return optimize_until(expr, stage)


def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
    result = expr
    if stage == "logical":
        return result

    # Simplify
    expr = result.simplify()
    if stage == "simplified-logical":
        return expr

    # Manipulate Expression to make it more efficient
    expr = expr.rewrite(kind="tune", rewritten={})
    if stage == "tuned-logical":
        return expr

    # Lower
    expr = expr.lower_completely()
    if stage == "physical":
        return expr

    # Simplify again
    expr = expr.simplify()
    if stage == "simplified-physical":
        return expr

    # Final graph-specific optimizations
    expr = expr.fuse()
    if stage == "fused":
        return expr

    raise ValueError(f"Stage {stage!r} not supported.")
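
# Usage sketch: the stages form a pipeline (see ``OptimizerStage``), and
# ``optimize_until`` stops after the requested one.
#
#     optimize(expr, fuse=True)                    # == optimize_until(expr, "fused")
#     optimize_until(expr, "simplified-logical")   # simplify only, no lowering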


class LLGExpr(Expr):
    """Low Level Graph Expression"""

    _parameters = ["dsk"]

    def __dask_keys__(self):
        return list(self.operand("dsk"))

    def _layer(self) -> dict:
        return ensure_dict(self.operand("dsk"))


class HLGExpr(Expr):
    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @property
    def hlg(self):
        return self.operand("dsk")

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when going
        # through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )
        if optimize_graph and not hasattr(collection, "__dask_optimize__"):
            warnings.warn(
                f"Collection {type(collection)} does not define a "
                "`__dask_optimize__` method. In the future this will raise. "
                "If no optimization is desired, please set this to `None`.",
                PendingDeprecationWarning,
            )
            low_level_optimizer = None
        else:
            low_level_optimizer = (
                collection.__dask_optimize__ if optimize_graph else None
            )
        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__(),
        )

    def finalize_compute(self):
        return HLGFinalizeCompute(
            self,
            low_level_optimizer=self.low_level_optimizer,
            output_keys=self.output_keys,
            postcompute=self.postcompute,
        )

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annotations_by_type[annot_type].update(
                        {k: (value(k) if callable(value) else value) for k in layer}
                    )
        return dict(annotations_by_type)

    def __dask_keys__(self):
        if (keys := self.operand("output_keys")) is not None:
            return keys
        dsk = self.hlg
        # Note: This will materialize
        dependencies = dsk.get_all_dependencies()
        leafs = set(dependencies)
        for val in dependencies.values():
            leafs -= val
        self.output_keys = list(leafs)
        return self.output_keys

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        keys = self.__dask_keys__()
        dsk = self.hlg
        if (optimizer := self.low_level_optimizer) is not None:
            dsk = optimizer(dsk, keys)
        return HighLevelGraph.merge(dsk)

    @property
    def deterministic_token(self):
        if not self._determ_token:
            self._determ_token = uuid.uuid4().hex
        return self._determ_token

    def _layer(self) -> dict:
        dsk = self._optimized_dsk
        return ensure_dict(dsk)


class _HLGExprGroup(HLGExpr):
    # Identical to HLGExpr
    # Used internally to determine how output keys are supposed to be returned
    pass


class _HLGExprSequence(Expr):

    def __getitem__(self, other):
        return self.operands[other]

    def _operands_for_repr(self):
        return [
            f"name={self.operand('name')!r}",
            f"dsk={self.operand('dsk')!r}",
        ]

    def _tree_repr_lines(self, indent=0, recursive=True):
        return self._operands_for_repr()

    def finalize_compute(self):
        return _HLGExprSequence(*[op.finalize_compute() for op in self.operands])

    def _tune_down(self):
        if len(self.operands) == 1:
            return None
        from dask.highlevelgraph import HighLevelGraph

        groups = toolz.groupby(
            lambda x: x.low_level_optimizer if isinstance(x, HLGExpr) else None,
            self.operands,
        )
        exprs = []
        changed = False
        for optimizer, group in groups.items():
            if len(group) > 1:
                graphs = [expr.hlg for expr in group]

                changed = True
                dsk = HighLevelGraph.merge(*graphs)
                hlg_group = _HLGExprGroup(
                    dsk=dsk,
                    low_level_optimizer=optimizer,
                    output_keys=[v.__dask_keys__() for v in group],
                    postcompute=[g.postcompute for g in group],
                )
                exprs.append(hlg_group)
            else:
                exprs.append(group[0])
        if not changed:
            return None
        return _HLGExprSequence(*exprs)

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        hlgexpr: HLGExpr
        graphs = []
        # _tune_down ensures there is only one HLGExpr per optimizer/finalizer
        for hlgexpr in self.operands:
            keys = hlgexpr.__dask_keys__()
            dsk = hlgexpr.hlg
            if (optimizer := hlgexpr.low_level_optimizer) is not None:
                dsk = optimizer(dsk, keys)
            graphs.append(dsk)

        return HighLevelGraph.merge(*graphs)

    def __dask_graph__(self):
        # This class has to override this and not just _layer to ensure the HLGs
        # are not optimized individually
        return ensure_dict(self._optimized_dsk)

    _layer = __dask_graph__

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annots = list(
                        (k, (value(k) if callable(value) else value)) for k in layer
                    )
                    annotations_by_type[annot_type].update(
                        {
                            k: v
                            for k, v in annots
                            if not isinstance(v, _AnnotationsTombstone)
                        }
                    )
                    if not annotations_by_type[annot_type]:
                        del annotations_by_type[annot_type]
        return dict(annotations_by_type)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            if isinstance(op, _HLGExprGroup):
                all_keys.extend(op.__dask_keys__())
            else:
                all_keys.append(op.__dask_keys__())
        return all_keys


class _ExprSequence(Expr):
    """A sequence of expressions

    This is used to be able to optimize multiple collections combined, e.g. when
    being computed simultaneously with ``dask.compute((Expr1, Expr2))``.
    """

    def __getitem__(self, other):
        return self.operands[other]

    def _layer(self) -> dict:
        return toolz.merge(op._layer() for op in self.operands)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            all_keys.append(list(op.__dask_keys__()))
        return all_keys

    def __repr__(self):
        return "ExprSequence(" + ", ".join(map(repr, self.operands)) + ")"

    __str__ = __repr__

    def finalize_compute(self):
        return _ExprSequence(
            *(op.finalize_compute() for op in self.operands),
        )

    def __dask_annotations__(self):
        annotations_by_type = {}
        for op in self.operands:
            for k, v in op.__dask_annotations__().items():
                annotations_by_type.setdefault(k, {}).update(v)
        return annotations_by_type

    def __len__(self):
        return len(self.operands)

    def __iter__(self):
        return iter(self.operands)

    def _simplify_down(self):
        from dask.highlevelgraph import HighLevelGraph

        issue_warning = False
        hlgs = []
        for op in self.operands:
            if isinstance(op, (HLGExpr, HLGFinalizeCompute)):
                hlgs.append(op)
            elif isinstance(op, dict):
                hlgs.append(
                    HLGExpr(
                        dsk=HighLevelGraph.from_collections(
                            str(id(op)), op, dependencies=()
                        )
                    )
                )
            elif hlgs:
                issue_warning = True
                opt = op.optimize()
                hlgs.append(
                    HLGExpr(
                        dsk=HighLevelGraph.from_collections(
                            opt._name, opt.__dask_graph__(), dependencies=()
                        )
                    )
                )
        if issue_warning:
            warnings.warn(
                "Computing mixed collections that are backed by "
                "HighLevelGraphs/dicts and Expressions. "
                "This forces Expressions to be materialized. "
                "It is recommended to use only one type and separate the "
                "dask.compute calls if necessary.",
                UserWarning,
            )
        if not hlgs:
            return None
        return _HLGExprSequence(*hlgs)


class _AnnotationsTombstone: ...


class FinalizeCompute(Expr):
    _parameters = ["expr"]

    def _simplify_down(self):
        return self.expr.finalize_compute()


def _convert_dask_keys(keys):
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    new_keys = []
    for key in keys:
        if isinstance(key, list):
            new_keys.append(_convert_dask_keys(key))
        else:
            new_keys.append(TaskRef(key))
    return List(*new_keys)
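
# Behavior sketch: nested key lists become nested ``List``s of ``TaskRef``s,
# preserving structure so a finalize task can reference its inputs, e.g.
#
#     _convert_dask_keys([("x", 0), [("y", 0), ("y", 1)]])
#     # -> List(TaskRef(("x", 0)), List(TaskRef(("y", 0)), TaskRef(("y", 1))))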


class HLGFinalizeCompute(HLGExpr):

    def _simplify_down(self):
        if not self.postcompute:
            return self.dsk

        from dask.delayed import Delayed

        # Skip finalization for Delayed
        if self.dsk.postcompute == Delayed.__dask_postcompute__(self.dsk):
            return self.dsk
        return self

    @property
    def _name(self):
        return f"finalize-{super()._name}"

    def __dask_graph__(self):
        # The baseclass __dask_graph__ will not just materialize this layer but
        # also that of its dependencies, i.e. it will render the finalized and
        # the non-finalized graph and combine them. We only want the finalized
        # one, so we're overriding this.
        # This is an artifact generated since the wrapped expression is
        # identified automatically as a dependency but HLG expressions are not
        # working in this layered way.
        return self._layer()

    @property
    def hlg(self):
        expr = self.operand("dsk")
        layers = expr.dsk.layers.copy()
        deps = expr.dsk.dependencies.copy()
        keys = expr.__dask_keys__()
        if isinstance(expr.postcompute, list):
            postcomputes = expr.postcompute
        else:
            postcomputes = [expr.postcompute]
        tasks = [
            Task(self._name, func, _convert_dask_keys(keys), *extra_args)
            for func, extra_args in postcomputes
        ]
        from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

        leafs = set(deps)
        for val in deps.values():
            leafs -= val
        for t in tasks:
            layers[t.key] = MaterializedLayer({t.key: t})
            deps[t.key] = leafs
        return HighLevelGraph(layers, dependencies=deps)

    def __dask_keys__(self):
        return [self._name]


class ProhibitReuse(Expr):
    """
    An expression that guarantees that all keys are suffixed with a unique id.
    This can be used to break a common subexpression apart.
    """

    _parameters = ["expr"]
    _ALLOWED_TYPES = [HLGExpr, LLGExpr, HLGFinalizeCompute, _HLGExprSequence]

    def __dask_keys__(self):
        return self._modify_keys(self.expr.__dask_keys__())

    @staticmethod
    def _identity(obj):
        return obj

    @functools.cached_property
    def _suffix(self):
        return uuid.uuid4().hex

    def _modify_keys(self, k):
        if isinstance(k, list):
            return [self._modify_keys(kk) for kk in k]
        elif isinstance(k, tuple):
            return (self._modify_keys(k[0]),) + k[1:]
        elif isinstance(k, (int, float)):
            k = str(k)
        return f"{k}-{self._suffix}"
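
    # Behavior sketch: tuple keys keep their partition indices; only the
    # leading name gets the suffix, e.g. ("sum", 0) -> ("sum-<uuid>", 0).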

    def _simplify_down(self):
        # FIXME: Shuffling cannot be rewritten since the barrier key is
        # hardcoded. Skipping this here should do the trick most of the time
        if not isinstance(
            self.expr,
            tuple(self._ALLOWED_TYPES),
        ):
            return self.expr

    def __dask_graph__(self):
        try:
            from distributed.shuffle._core import P2PBarrierTask
        except ModuleNotFoundError:
            P2PBarrierTask = type(None)
        dsk = convert_legacy_graph(self.expr.__dask_graph__())

        subs = {old_key: self._modify_keys(old_key) for old_key in dsk}
        dsk2 = {}
        for old_key, new_key in subs.items():
            t = dsk[old_key]
            if isinstance(t, P2PBarrierTask):
                warnings.warn(
                    "Cannot block reusing for graphs including a "
                    "P2PBarrierTask. This may cause unexpected results. "
                    "This typically happens when converting a dask "
                    "DataFrame to delayed objects.",
                    UserWarning,
                )
                return dsk
            dsk2[new_key] = Task(
                new_key,
                ProhibitReuse._identity,
                t.substitute(subs),
            )

        dsk2.update(dsk)
        return dsk2

    _layer = __dask_graph__