from __future__ import annotations

import functools
import os
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal

import toolz

import dask
from dask._task_spec import Task, convert_legacy_graph
from dask.tokenize import _tokenize_deterministic
from dask.typing import Key
from dask.utils import ensure_dict, funcname, import_required

if TYPE_CHECKING:
    # TODO import from typing (requires Python >=3.10)
    from typing import Any, TypeAlias

    from dask.highlevelgraph import HighLevelGraph

OptimizerStage: TypeAlias = Literal[
    "logical",
    "simplified-logical",
    "tuned-logical",
    "physical",
    "simplified-physical",
    "fused",
]


def _unpack_collections(o):
    from dask.delayed import Delayed

    if isinstance(o, Expr):
        return o

    if hasattr(o, "expr") and not isinstance(o, Delayed):
        return o.expr
    else:
        return o


class Expr:
    _parameters: list[str] = []
    _defaults: dict[str, Any] = {}

    _pickle_functools_cache: bool = True

    operands: list

    _determ_token: str | None

    def __new__(cls, *args, _determ_token=None, **kwargs):
        operands = list(args)
        for parameter in cls._parameters[len(operands) :]:
            try:
                operands.append(kwargs.pop(parameter))
            except KeyError:
                operands.append(cls._defaults[parameter])
        assert not kwargs, kwargs
        inst = object.__new__(cls)

        inst._determ_token = _determ_token
        inst.operands = [_unpack_collections(o) for o in operands]
        # This is typically cached. Make sure the cache is populated by calling
        # it once
        inst._name
        return inst
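
    # A minimal sketch (hypothetical ``Add`` subclass, not part of dask) of
    # how ``_parameters`` and ``_defaults`` drive ``__new__``: positional and
    # keyword arguments are normalized into ``operands``.
    #
    # >>> class Add(Expr):
    # ...     _parameters = ["left", "right", "fill"]
    # ...     _defaults = {"fill": 0}
    # >>> Add(1, 2).operands                 # doctest: +SKIP
    # [1, 2, 0]
    # >>> Add(1, right=2, fill=9).operands   # doctest: +SKIP
    # [1, 2, 9]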

    def _tune_down(self):
        return None

    def _tune_up(self, parent):
        return None

    def finalize_compute(self):
        return self

    def _operands_for_repr(self):
        return [
            f"{param}={repr(op)}" for param, op in zip(self._parameters, self.operands)
        ]

    def __str__(self):
        s = ", ".join(self._operands_for_repr())
        return f"{type(self).__name__}({s})"

    def __repr__(self):
        return str(self)

    def _tree_repr_argument_construction(self, i, op, header):
        try:
            param = self._parameters[i]
            default = self._defaults[param]
        except (IndexError, KeyError):
            param = self._parameters[i] if i < len(self._parameters) else ""
            default = "--no-default--"

        if repr(op) != repr(default):
            if param:
                header += f" {param}={repr(op)}"
            else:
                header += repr(op)
        return header

    def _tree_repr_lines(self, indent=0, recursive=True):
        return [" " * indent + repr(self)]

    def tree_repr(self):
        return os.linesep.join(self._tree_repr_lines())

    def analyze(self, filename: str | None = None, format: str | None = None) -> None:
        from dask.dataframe.dask_expr._expr import Expr as DFExpr
        from dask.dataframe.dask_expr.diagnostics import analyze

        if not isinstance(self, DFExpr):
            raise TypeError(
                "analyze is only supported for dask.dataframe.Expr objects."
            )
        return analyze(self, filename=filename, format=format)

    def explain(
        self, stage: OptimizerStage = "fused", format: str | None = None
    ) -> None:
        from dask.dataframe.dask_expr.diagnostics import explain

        return explain(self, stage, format)

    def pprint(self):
        for line in self._tree_repr_lines():
            print(line)

    def __hash__(self):
        return hash(self._name)

    def __dask_tokenize__(self):
        if not self._determ_token:
            # If the subclass does not implement a __dask_tokenize__ we'll want
            # to tokenize all operands.
            # Note how this differs from the implementation of
            # Expr.deterministic_token
            self._determ_token = _tokenize_deterministic(type(self), *self.operands)
        return self._determ_token

    def __dask_keys__(self):
        """The keys for this expression

        This is used to determine the keys of the output collection
        when this expression is computed.

        Returns
        -------
        keys: list
            The keys for this expression
        """
        return [(self._name, i) for i in range(self.npartitions)]
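
    # Illustrative sketch (assumes a hypothetical expression ``expr`` whose
    # ``_name`` is "add-abc123" and which has three partitions):
    #
    # >>> expr.__dask_keys__()  # doctest: +SKIP
    # [('add-abc123', 0), ('add-abc123', 1), ('add-abc123', 2)]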

    @staticmethod
    def _reconstruct(*args):
        typ, *operands, token, cache = args
        inst = typ(*operands, _determ_token=token)
        for k, v in cache.items():
            inst.__dict__[k] = v
        return inst

    def __reduce__(self):
        if dask.config.get("dask-expr-no-serialize", False):
            raise RuntimeError(f"Serializing a {type(self)} object")
        cache = {}
        if type(self)._pickle_functools_cache:
            for k, v in type(self).__dict__.items():
                if isinstance(v, functools.cached_property) and k in self.__dict__:
                    cache[k] = getattr(self, k)

        return Expr._reconstruct, (
            type(self),
            *self.operands,
            self.deterministic_token,
            cache,
        )
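
    # A hedged usage sketch: pickling round-trips through ``_reconstruct`` so
    # that the deterministic token and any populated
    # ``functools.cached_property`` values survive (``expr`` is a placeholder
    # instance).
    #
    # >>> import pickle
    # >>> expr2 = pickle.loads(pickle.dumps(expr))  # doctest: +SKIP
    # >>> expr2._name == expr._name                 # doctest: +SKIP
    # True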

    def _depth(self, cache=None):
        """Depth of the expression tree

        Returns
        -------
        depth: int
        """
        if cache is None:
            cache = {}
        if not self.dependencies():
            return 1
        else:
            result = []
            for expr in self.dependencies():
                if expr._name in cache:
                    result.append(cache[expr._name])
                else:
                    result.append(expr._depth(cache) + 1)
                    cache[expr._name] = result[-1]
            return max(result)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ["operands", "_determ_token"]:
            object.__setattr__(self, name, value)
            return
        try:
            params = type(self)._parameters
            operands = object.__getattribute__(self, "operands")
            operands[params.index(name)] = value
        except ValueError:
            raise AttributeError(
                f"{type(self).__name__} object has no attribute {name}"
            )

    def operand(self, key):
        # Access an operand unambiguously
        # (e.g. if the key is reserved by a method/property)
        return self.operands[type(self)._parameters.index(key)]

    def dependencies(self):
        # Dependencies are `Expr` operands only
        return [operand for operand in self.operands if isinstance(operand, Expr)]

    def _task(self, key: Key, index: int) -> Task:
        """The task for the i'th partition

        Parameters
        ----------
        key:
            The key of the task, i.e. ``(self._name, index)``.
        index:
            The index of the partition of this expression

        Examples
        --------
        >>> class Add(Expr):
        ...     def _task(self, key, index):
        ...         return Task(
        ...             key,
        ...             operator.add,
        ...             TaskRef((self.left._name, index)),
        ...             TaskRef((self.right._name, index)),
        ...         )

        Returns
        -------
        task:
            The Dask task to compute this partition

        See Also
        --------
        Expr._layer
        """
        raise NotImplementedError(
            "Expressions should define either _layer (full dictionary) or _task"
            f" (single task). This expression {type(self)} defines neither"
        )

    def _layer(self) -> dict:
        """The graph layer added by this expression.

        Simple expressions that apply one task per partition can choose to only
        implement `Expr._task` instead.

        Examples
        --------
        >>> class Add(Expr):
        ...     def _layer(self):
        ...         return {
        ...             name: Task(
        ...                 name,
        ...                 operator.add,
        ...                 TaskRef((self.left._name, i)),
        ...                 TaskRef((self.right._name, i)),
        ...             )
        ...             for i, name in enumerate(self.__dask_keys__())
        ...         }

        Returns
        -------
        layer: dict
            The Dask task graph added by this expression

        See Also
        --------
        Expr._task
        Expr.__dask_graph__
        """

        return {
            (self._name, i): self._task((self._name, i), i)
            for i in range(self.npartitions)
        }

    def rewrite(self, kind: str, rewritten):
        """Rewrite an expression

        This leverages the ``._{kind}_down`` and ``._{kind}_up``
        methods defined on each class

        Returns
        -------
        expr:
            output expression
        """
        if self._name in rewritten:
            return rewritten[self._name]

        expr = self
        down_name = f"_{kind}_down"
        up_name = f"_{kind}_up"
        while True:
            _continue = False

            # Rewrite this node
            out = getattr(expr, down_name)()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out
                continue

            # Allow children to rewrite their parents
            for child in expr.dependencies():
                out = getattr(child, up_name)(expr)
                if out is None:
                    out = expr
                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    _continue = True
                    break

            if _continue:
                continue

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    new = operand.rewrite(kind=kind, rewritten=rewritten)
                    rewritten[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)
                continue
            else:
                break

        return expr
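
    # A minimal sketch (hypothetical ``Repartition`` class, not part of this
    # module) of the ``rewrite`` protocol: a ``_tune_down`` hook that returns
    # a new node is picked up by ``expr.rewrite(kind="tune", rewritten={})``.
    #
    # >>> class Repartition(Expr):
    # ...     _parameters = ["frame", "n"]
    # ...     def _tune_down(self):
    # ...         # Collapse Repartition(Repartition(x, m), n) -> Repartition(x, n);
    # ...         # returning None leaves the node unchanged.
    # ...         if isinstance(self.frame, Repartition):
    # ...             return Repartition(self.frame.frame, self.n)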

    def simplify_once(self, dependents: defaultdict, simplified: dict):
        """Simplify an expression

        This leverages the ``._simplify_down`` and ``._simplify_up``
        methods defined on each class

        Parameters
        ----------
        dependents: defaultdict[list]
            The dependents for every node.
        simplified: dict
            Cache of simplified expressions for these dependents.

        Returns
        -------
        expr:
            output expression
        """
        # Check if we've already simplified for these dependents
        if self._name in simplified:
            return simplified[self._name]

        expr = self

        while True:
            out = expr._simplify_down()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out

            # Allow children to simplify their parents
            for child in expr.dependencies():
                out = child._simplify_up(expr, dependents)
                if out is None:
                    out = expr

                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    break

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    # Bandaid for now, waiting for Singleton
                    dependents[operand._name].append(weakref.ref(expr))
                    new = operand.simplify_once(
                        dependents=dependents, simplified=simplified
                    )
                    simplified[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)

            break

        return expr

    def optimize(self, fuse: bool = False) -> Expr:
        stage: OptimizerStage = "fused" if fuse else "simplified-physical"

        return optimize_until(self, stage)

    def fuse(self) -> Expr:
        return self

    def simplify(self) -> Expr:
        expr = self
        seen = set()
        while True:
            dependents = collect_dependents(expr)
            new = expr.simplify_once(dependents=dependents, simplified={})
            if new._name == expr._name:
                break
            if new._name in seen:
                raise RuntimeError(
                    f"Optimizer does not converge. {expr!r} simplified to {new!r} which was already seen. "
                    "Please report this issue on the dask issue tracker with a minimal reproducer."
                )
            seen.add(new._name)
            expr = new
        return expr

    def _simplify_down(self):
        return

    def _simplify_up(self, parent, dependents):
        return

    def lower_once(self, lowered: dict):
        # Check for a cached result
        try:
            return lowered[self._name]
        except KeyError:
            pass

        expr = self

        # Lower this node
        out = expr._lower()
        if out is None:
            out = expr
        if not isinstance(out, Expr):
            return out

        # Lower all children
        new_operands = []
        changed = False
        for operand in out.operands:
            if isinstance(operand, Expr):
                new = operand.lower_once(lowered)
                if new._name != operand._name:
                    changed = True
            else:
                new = operand
            new_operands.append(new)

        if changed:
            out = type(out)(*new_operands)

        # Cache the result and return
        return lowered.setdefault(self._name, out)

    def lower_completely(self) -> Expr:
        """Lower an expression completely

        This calls the ``lower_once`` method in a loop
        until nothing changes. This function does not
        apply any other optimizations (like ``simplify``).

        Returns
        -------
        expr:
            output expression

        See Also
        --------
        Expr.lower_once
        Expr._lower
        """
        # Lower until nothing changes
        expr = self
        lowered: dict = {}
        while True:
            new = expr.lower_once(lowered)
            if new._name == expr._name:
                break
            expr = new
        return expr
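
    # A hedged sketch (hypothetical ``Join``/``HashJoin`` classes): ``_lower``
    # lets a logical node replace itself with a physical implementation during
    # ``lower_completely``.
    #
    # >>> class Join(Expr):
    # ...     _parameters = ["left", "right"]
    # ...     def _lower(self):
    # ...         # Returning a new expression swaps this node out;
    # ...         # returning None keeps it as-is.
    # ...         return HashJoin(self.left, self.right)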

    def _lower(self):
        return

    @functools.cached_property
    def _funcname(self) -> str:
        return funcname(type(self)).lower()

    @property
    def deterministic_token(self):
        if not self._determ_token:
            # Just tokenize self to fall back on __dask_tokenize__
            # Note how this differs from the implementation of __dask_tokenize__
            self._determ_token = self.__dask_tokenize__()
        return self._determ_token

    @functools.cached_property
    def _name(self) -> str:
        return self._funcname + "-" + self.deterministic_token

    @property
    def _meta(self):
        raise NotImplementedError()

    @classmethod
    def _annotations_tombstone(cls) -> _AnnotationsTombstone:
        return _AnnotationsTombstone()

    def __dask_annotations__(self):
        return {}

    def __dask_graph__(self):
        """Traverse expression tree, collect layers

        Subclasses generally do not want to override this method unless custom
        logic is required to treat (e.g. ignore) specific operands during graph
        generation.

        See also
        --------
        Expr._layer
        Expr._task
        """
        stack = [self]
        seen = set()
        layers = []
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            layers.append(expr._layer())
            for operand in expr.dependencies():
                stack.append(operand)

        return toolz.merge(layers)

    @property
    def dask(self):
        return self.__dask_graph__()

    def substitute(self, old, new) -> Expr:
        """Substitute a specific term within the expression

        Note that replacing non-`Expr` terms may produce
        unexpected results, and is not recommended.
        Substituting boolean values is not allowed.

        Parameters
        ----------
        old:
            Old term to find and replace.
        new:
            New term to replace instances of `old` with.

        Examples
        --------
        >>> (df + 10).substitute(10, 20)  # doctest: +SKIP
        df + 20
        """
        return self._substitute(old, new, _seen=set())

    def _substitute(self, old, new, _seen):
        if self._name in _seen:
            return self
        # Check if we are replacing a literal
        if isinstance(old, Expr):
            substitute_literal = False
            if self._name == old._name:
                return new
        else:
            substitute_literal = True
            if isinstance(old, bool):
                raise TypeError("Arguments to `substitute` cannot be bool.")

        new_exprs = []
        update = False
        for operand in self.operands:
            if isinstance(operand, Expr):
                val = operand._substitute(old, new, _seen)
                if operand._name != val._name:
                    update = True
                new_exprs.append(val)
            elif (
                "Fused" in type(self).__name__
                and isinstance(operand, list)
                and all(isinstance(op, Expr) for op in operand)
            ):
                # Special handling for `Fused`.
                # We make no promise to dive through a
                # list operand in general, but NEED to
                # do so for the `Fused.exprs` operand.
                val = []
                for op in operand:
                    val.append(op._substitute(old, new, _seen))
                    if val[-1]._name != op._name:
                        update = True
                new_exprs.append(val)
            elif (
                substitute_literal
                and not isinstance(operand, bool)
                and isinstance(operand, type(old))
                and operand == old
            ):
                new_exprs.append(new)
                update = True
            else:
                new_exprs.append(operand)

        if update:  # Only recreate if something changed
            return type(self)(*new_exprs)
        else:
            _seen.add(self._name)
            return self

    def substitute_parameters(self, substitutions: dict) -> Expr:
        """Substitute specific `Expr` parameters

        Parameters
        ----------
        substitutions:
            Mapping of parameter keys to new values. Keys that
            are not found in ``self._parameters`` will be ignored.
        """
        if not substitutions:
            return self

        changed = False
        new_operands = []
        for i, operand in enumerate(self.operands):
            if i < len(self._parameters) and self._parameters[i] in substitutions:
                new_operands.append(substitutions[self._parameters[i]])
                changed = True
            else:
                new_operands.append(operand)
        if changed:
            return type(self)(*new_operands)
        return self
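
    # Illustrative sketch (hypothetical ``Repartition`` class with parameters
    # ``frame`` and ``n``):
    #
    # >>> expr = Repartition(frame, n=2)        # doctest: +SKIP
    # >>> expr.substitute_parameters({"n": 4})  # doctest: +SKIP
    # Repartition(frame=..., n=4)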

    def _node_label_args(self):
        """Operands to include in the node label by `visualize`"""
        return self.dependencies()

    def _to_graphviz(
        self,
        rankdir="BT",
        graph_attr=None,
        node_attr=None,
        edge_attr=None,
        **kwargs,
    ):
        from dask.dot import label, name

        graphviz = import_required(
            "graphviz",
            "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
            "python library and the `graphviz` system library.\n\n"
            "Please either conda or pip install as follows:\n\n"
            "  conda install python-graphviz     # either conda install\n"
            "  python -m pip install graphviz    # or pip install and follow installation instructions",
        )

        graph_attr = graph_attr or {}
        node_attr = node_attr or {}
        edge_attr = edge_attr or {}

        graph_attr["rankdir"] = rankdir
        node_attr["shape"] = "box"
        node_attr["fontname"] = "helvetica"

        graph_attr.update(kwargs)
        g = graphviz.Digraph(
            graph_attr=graph_attr,
            node_attr=node_attr,
            edge_attr=edge_attr,
        )

        stack = [self]
        seen = set()
        dependencies = {}
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            dependencies[expr] = set(expr.dependencies())
            for dep in expr.dependencies():
                stack.append(dep)

        cache = {}
        for expr in dependencies:
            expr_name = name(expr)
            attrs = {}

            # Make node label
            deps = [
                funcname(type(dep)) if isinstance(dep, Expr) else str(dep)
                for dep in expr._node_label_args()
            ]
            _label = funcname(type(expr))
            if deps:
                _label = f"{_label}({', '.join(deps)})"
            node_label = label(_label, cache=cache)

            attrs.setdefault("label", str(node_label))
            attrs.setdefault("fontsize", "20")
            g.node(expr_name, **attrs)

        for expr, deps in dependencies.items():
            expr_name = name(expr)
            for dep in deps:
                dep_name = name(dep)
                g.edge(dep_name, expr_name)

        return g

    def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
        """
        Visualize the expression graph.
        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and the graph is
            rendered in the Jupyter notebook only.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'svg'.
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.
        """
        from dask.dot import graphviz_to_file

        g = self._to_graphviz(**kwargs)
        graphviz_to_file(g, filename, format)
        return g

    def walk(self) -> Generator[Expr]:
        """Iterate through all expressions in the tree

        Returns
        -------
        nodes
            Generator of Expr instances in the graph.
            Ordering is a depth-first search of the expression tree
        """
        stack = [self]
        seen = set()
        while stack:
            node = stack.pop()
            if node._name in seen:
                continue
            seen.add(node._name)

            for dep in node.dependencies():
                stack.append(dep)

            yield node

    def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
        """Search the expression graph for a specific operation type

        Parameters
        ----------
        operation
            The operation type to search for.

        Returns
        -------
        nodes
            Generator of `operation` instances. Ordering corresponds
            to a depth-first search of the expression graph.
        """
        assert (
            isinstance(operation, tuple)
            and all(issubclass(e, Expr) for e in operation)
            or issubclass(operation, Expr)  # type: ignore
        ), "`operation` must be an `Expr` subclass"
        return (expr for expr in self.walk() if isinstance(expr, operation))
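
    # Hedged usage sketch (``expr`` and ``ReadParquet`` are placeholders):
    #
    # >>> for node in expr.find_operations(ReadParquet):  # doctest: +SKIP
    # ...     print(node._name)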

    def __getattr__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError as err:
            if key.startswith("_meta"):
                # Avoid a recursive loop if/when `self._meta*`
                # produces an `AttributeError`
                raise RuntimeError(
                    f"Failed to generate metadata for {self}. "
                    "This operation may not be supported by the current backend."
                )

            # Allow operands to be accessed as attributes
            # as long as the keys are not already reserved
            # by existing methods/properties
            _parameters = type(self)._parameters
            if key in _parameters:
                idx = _parameters.index(key)
                return self.operands[idx]

            raise AttributeError(
                f"{err}\n\n"
                "This often means that you are attempting to use an unsupported "
                "API function."
            )


class SingletonExpr(Expr):
    """A singleton Expr class

    This is used to treat the subclassed expression as a singleton. Singletons
    are deduplicated by expr._name, which is typically based on the dask.tokenize
    output.

    This is a crucial performance optimization for expressions that walk through
    an optimizer and are recreated repeatedly, but it isn't safe for objects that
    cannot be reliably or quickly tokenized.
    """

    _instances: weakref.WeakValueDictionary[str, SingletonExpr]

    def __new__(cls, *args, _determ_token=None, **kwargs):
        if not hasattr(cls, "_instances"):
            cls._instances = weakref.WeakValueDictionary()
        inst = super().__new__(cls, *args, _determ_token=_determ_token, **kwargs)
        _name = inst._name
        if _name in cls._instances and cls.__init__ == object.__init__:
            return cls._instances[_name]

        cls._instances[_name] = inst
        return inst
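
# A hedged sketch (hypothetical subclass): two structurally identical
# SingletonExpr instances deduplicate to the same object via ``_name``.
#
# >>> class MySingleton(SingletonExpr):
# ...     _parameters = ["x"]
# >>> a, b = MySingleton(1), MySingleton(1)
# >>> a is b  # doctest: +SKIP
# True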


def collect_dependents(expr) -> defaultdict:
    dependents = defaultdict(list)
    stack = [expr]
    seen = set()
    while stack:
        node = stack.pop()
        if node._name in seen:
            continue
        seen.add(node._name)

        for dep in node.dependencies():
            stack.append(dep)
            dependents[dep._name].append(weakref.ref(node))
    return dependents
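
# Illustrative note (``expr`` and ``child`` are placeholders): the returned
# mapping goes from a dependency's ``_name`` to weak references of the
# expressions that consume it.
#
# >>> deps = collect_dependents(expr)        # doctest: +SKIP
# >>> [ref() for ref in deps[child._name]]   # doctest: +SKIP
# [Add(left=..., right=...)]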


def optimize(expr: Expr, fuse: bool = True) -> Expr:
    """High level query optimization

    This leverages two optimization passes:

    1. Class based simplification using the ``_simplify`` function and methods
    2. Blockwise fusion

    Parameters
    ----------
    expr:
        Input expression to optimize
    fuse:
        whether or not to turn on blockwise fusion

    See Also
    --------
    simplify
    optimize_blockwise_fusion
    """
    stage: OptimizerStage = "fused" if fuse else "simplified-physical"

    return optimize_until(expr, stage)


def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
    result = expr
    if stage == "logical":
        return result

    # Simplify
    expr = result.simplify()
    if stage == "simplified-logical":
        return expr

    # Manipulate Expression to make it more efficient
    expr = expr.rewrite(kind="tune", rewritten={})
    if stage == "tuned-logical":
        return expr

    # Lower
    expr = expr.lower_completely()
    if stage == "physical":
        return expr

    # Simplify again
    expr = expr.simplify()
    if stage == "simplified-physical":
        return expr

    # Final graph-specific optimizations
    expr = expr.fuse()
    if stage == "fused":
        return expr

    raise ValueError(f"Stage {stage!r} not supported.")
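
# Hedged usage sketch (``expr`` is a placeholder expression): the optimizer
# can be stopped at any stage listed in ``OptimizerStage``.
#
# >>> optimize_until(expr, "tuned-logical")  # doctest: +SKIP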


class LLGExpr(Expr):
    """Low Level Graph Expression"""

    _parameters = ["dsk"]

    def __dask_keys__(self):
        return list(self.operand("dsk"))

    def _layer(self) -> dict:
        return ensure_dict(self.operand("dsk"))


class HLGExpr(Expr):
    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @property
    def hlg(self):
        return self.operand("dsk")

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when going
        # through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )
        if optimize_graph and not hasattr(collection, "__dask_optimize__"):
            warnings.warn(
                f"Collection {type(collection)} does not define a "
                "`__dask_optimize__` method. In the future this will raise. "
                "If no optimization is desired, please set this to `None`.",
                PendingDeprecationWarning,
            )
            low_level_optimizer = None
        else:
            low_level_optimizer = (
                collection.__dask_optimize__ if optimize_graph else None
            )
        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__(),
        )

    def finalize_compute(self):
        return HLGFinalizeCompute(
            self,
            low_level_optimizer=self.low_level_optimizer,
            output_keys=self.output_keys,
            postcompute=self.postcompute,
        )

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annotations_by_type[annot_type].update(
                        {k: (value(k) if callable(value) else value) for k in layer}
                    )
        return dict(annotations_by_type)

    def __dask_keys__(self):
        if (keys := self.operand("output_keys")) is not None:
            return keys
        dsk = self.hlg
        # Note: This will materialize
        dependencies = dsk.get_all_dependencies()
        leafs = set(dependencies)
        for val in dependencies.values():
            leafs -= val
        self.output_keys = list(leafs)
        return self.output_keys

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        keys = self.__dask_keys__()
        dsk = self.hlg
        if (optimizer := self.low_level_optimizer) is not None:
            dsk = optimizer(dsk, keys)
        return HighLevelGraph.merge(dsk)

    @property
    def deterministic_token(self):
        if not self._determ_token:
            self._determ_token = uuid.uuid4().hex
        return self._determ_token

    def _layer(self) -> dict:
        dsk = self._optimized_dsk
        return ensure_dict(dsk)


class _HLGExprGroup(HLGExpr):
    # Identical to HLGExpr
    # Used internally to determine how output keys are supposed to be returned
    pass


class _HLGExprSequence(Expr):

    def __getitem__(self, other):
        return self.operands[other]

    def _operands_for_repr(self):
        return [
            f"name={self.operand('name')!r}",
            f"dsk={self.operand('dsk')!r}",
        ]

    def _tree_repr_lines(self, indent=0, recursive=True):
        return self._operands_for_repr()

    def finalize_compute(self):
        return _HLGExprSequence(*[op.finalize_compute() for op in self.operands])

    def _tune_down(self):
        if len(self.operands) == 1:
            return None
        from dask.highlevelgraph import HighLevelGraph

        groups = toolz.groupby(
            lambda x: x.low_level_optimizer if isinstance(x, HLGExpr) else None,
            self.operands,
        )
        exprs = []
        changed = False
        for optimizer, group in groups.items():
            if len(group) > 1:
                graphs = [expr.hlg for expr in group]

                changed = True
                dsk = HighLevelGraph.merge(*graphs)
                hlg_group = _HLGExprGroup(
                    dsk=dsk,
                    low_level_optimizer=optimizer,
                    output_keys=[v.__dask_keys__() for v in group],
                    postcompute=[g.postcompute for g in group],
                )
                exprs.append(hlg_group)
            else:
                exprs.append(group[0])
        if not changed:
            return None
        return _HLGExprSequence(*exprs)

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        hlgexpr: HLGExpr
        graphs = []
        # simplify_down ensures there is only one HLGExpr per optimizer/finalizer
        for hlgexpr in self.operands:
            keys = hlgexpr.__dask_keys__()
            dsk = hlgexpr.hlg
            if (optimizer := hlgexpr.low_level_optimizer) is not None:
                dsk = optimizer(dsk, keys)
            graphs.append(dsk)

        return HighLevelGraph.merge(*graphs)

    def __dask_graph__(self):
        # This class has to override this and not just _layer to ensure the HLGs
        # are not optimized individually
        return ensure_dict(self._optimized_dsk)

    _layer = __dask_graph__

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annots = list(
                        (k, (value(k) if callable(value) else value)) for k in layer
                    )
                    annotations_by_type[annot_type].update(
                        {
                            k: v
                            for k, v in annots
                            if not isinstance(v, _AnnotationsTombstone)
                        }
                    )
                    if not annotations_by_type[annot_type]:
                        del annotations_by_type[annot_type]
        return dict(annotations_by_type)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            if isinstance(op, _HLGExprGroup):
                all_keys.extend(op.__dask_keys__())
            else:
                all_keys.append(op.__dask_keys__())
        return all_keys


class _ExprSequence(Expr):
    """A sequence of expressions

    This is used to be able to optimize multiple collections combined, e.g. when
    being computed simultaneously with ``dask.compute((Expr1, Expr2))``.
    """

    def __getitem__(self, other):
        return self.operands[other]

    def _layer(self) -> dict:
        return toolz.merge(op._layer() for op in self.operands)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            all_keys.append(list(op.__dask_keys__()))
        return all_keys

    def __repr__(self):
        return "ExprSequence(" + ", ".join(map(repr, self.operands)) + ")"

    __str__ = __repr__

    def finalize_compute(self):
        return _ExprSequence(
            *(op.finalize_compute() for op in self.operands),
        )

    def __dask_annotations__(self):
        annotations_by_type = {}
        for op in self.operands:
            for k, v in op.__dask_annotations__().items():
                annotations_by_type.setdefault(k, {}).update(v)
        return annotations_by_type

    def __len__(self):
        return len(self.operands)

    def __iter__(self):
        return iter(self.operands)

    def _simplify_down(self):
        from dask.highlevelgraph import HighLevelGraph

        issue_warning = False
        hlgs = []
        for op in self.operands:
            if isinstance(op, (HLGExpr, HLGFinalizeCompute)):
                hlgs.append(op)
            elif isinstance(op, dict):
                hlgs.append(
                    HLGExpr(
                        dsk=HighLevelGraph.from_collections(
                            str(id(op)), op, dependencies=()
                        )
                    )
                )
            elif hlgs:
                issue_warning = True
                opt = op.optimize()
                hlgs.append(
                    HLGExpr(
                        dsk=HighLevelGraph.from_collections(
                            opt._name, opt.__dask_graph__(), dependencies=()
                        )
                    )
                )
        if issue_warning:
            warnings.warn(
                "Computing mixed collections that are backed by "
                "HighLevelGraphs/dicts and Expressions. "
                "This forces Expressions to be materialized. "
                "It is recommended to use only one type and separate the "
                "dask.compute calls if necessary.",
                UserWarning,
            )
        if not hlgs:
            return None
        return _HLGExprSequence(*hlgs)


class _AnnotationsTombstone: ...


class FinalizeCompute(Expr):
    _parameters = ["expr"]

    def _simplify_down(self):
        return self.expr.finalize_compute()


def _convert_dask_keys(keys):
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    new_keys = []
    for key in keys:
        if isinstance(key, list):
            new_keys.append(_convert_dask_keys(key))
        else:
            new_keys.append(TaskRef(key))
    return List(*new_keys)
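
# Illustrative sketch of the key conversion (reprs are approximate): nested
# lists of keys become nested ``List``/``TaskRef`` task-spec objects.
#
# >>> _convert_dask_keys([("a", 0), [("b", 0), ("b", 1)]])  # doctest: +SKIP
# List(TaskRef(('a', 0)), List(TaskRef(('b', 0)), TaskRef(('b', 1))))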


class HLGFinalizeCompute(HLGExpr):

    def _simplify_down(self):
        if not self.postcompute:
            return self.dsk

        from dask.delayed import Delayed

        # Skip finalization for Delayed
        if self.dsk.postcompute == Delayed.__dask_postcompute__(self.dsk):
            return self.dsk
        return self

    @property
    def _name(self):
        return f"finalize-{super()._name}"

    def __dask_graph__(self):
        # The baseclass __dask_graph__ will not just materialize this layer but
        # also that of its dependencies, i.e. it will render the finalized and
        # the non-finalized graph and combine them. We only want the finalized
        # graph, so we're overriding this.
        # This is an artifact generated since the wrapped expression is
        # identified automatically as a dependency but HLG expressions are not
        # working in this layered way.
        return self._layer()

    @property
    def hlg(self):
        expr = self.operand("dsk")
        layers = expr.dsk.layers.copy()
        deps = expr.dsk.dependencies.copy()
        keys = expr.__dask_keys__()
        if isinstance(expr.postcompute, list):
            postcomputes = expr.postcompute
        else:
            postcomputes = [expr.postcompute]
        tasks = [
            Task(self._name, func, _convert_dask_keys(keys), *extra_args)
            for func, extra_args in postcomputes
        ]
        from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

        leafs = set(deps)
        for val in deps.values():
            leafs -= val
        for t in tasks:
            layers[t.key] = MaterializedLayer({t.key: t})
            deps[t.key] = leafs
        return HighLevelGraph(layers, dependencies=deps)

    def __dask_keys__(self):
        return [self._name]


class ProhibitReuse(Expr):
    """
    An expression that guarantees that all keys are suffixed with a unique id.
    This can be used to break a common subexpression apart.
    """

    _parameters = ["expr"]
    _ALLOWED_TYPES = [HLGExpr, LLGExpr, HLGFinalizeCompute, _HLGExprSequence]

    def __dask_keys__(self):
        return self._modify_keys(self.expr.__dask_keys__())

    @staticmethod
    def _identity(obj):
        return obj

    @functools.cached_property
    def _suffix(self):
        return uuid.uuid4().hex

    def _modify_keys(self, k):
        if isinstance(k, list):
            return [self._modify_keys(kk) for kk in k]
        elif isinstance(k, tuple):
            return (self._modify_keys(k[0]),) + k[1:]
        elif isinstance(k, (int, float)):
            k = str(k)
        return f"{k}-{self._suffix}"
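
    # Illustrative sketch (the hex suffix below is a placeholder): every key,
    # including the first element of tuple keys, gets the instance's unique
    # suffix appended.
    #
    # >>> pr = ProhibitReuse(expr)                # doctest: +SKIP
    # >>> pr._modify_keys([("sum", 0), "total"])  # doctest: +SKIP
    # [('sum-3f2a...', 0), 'total-3f2a...']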

    def _simplify_down(self):
        # FIXME: Shuffling cannot be rewritten since the barrier key is
        # hardcoded. Skipping this here should do the trick most of the time
        if not isinstance(
            self.expr,
            tuple(self._ALLOWED_TYPES),
        ):
            return self.expr

    def __dask_graph__(self):
        try:
            from distributed.shuffle._core import P2PBarrierTask
        except ModuleNotFoundError:
            P2PBarrierTask = type(None)
        dsk = convert_legacy_graph(self.expr.__dask_graph__())

        subs = {old_key: self._modify_keys(old_key) for old_key in dsk}
        dsk2 = {}
        for old_key, new_key in subs.items():
            t = dsk[old_key]
            if isinstance(t, P2PBarrierTask):
                warnings.warn(
                    "Cannot block reusing for graphs including a "
                    "P2PBarrierTask. This may cause unexpected results. "
                    "This typically happens when converting a dask "
                    "DataFrame to delayed objects.",
                    UserWarning,
                )
                return dsk
            dsk2[new_key] = Task(
                new_key,
                ProhibitReuse._identity,
                t.substitute(subs),
            )

        dsk2.update(dsk)
        return dsk2

    _layer = __dask_graph__