# dask/_expr.py

from __future__ import annotations

import functools
import os
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal

import toolz

import dask
from dask._task_spec import Task, convert_legacy_graph
from dask.tokenize import _tokenize_deterministic
from dask.typing import Key
from dask.utils import ensure_dict, funcname, import_required

if TYPE_CHECKING:
    # TODO import from typing (requires Python >=3.10)
    from typing import Any, TypeAlias

    from dask.highlevelgraph import HighLevelGraph

OptimizerStage: TypeAlias = Literal[
    "logical",
    "simplified-logical",
    "tuned-logical",
    "physical",
    "simplified-physical",
    "fused",
]


def _unpack_collections(o):
    from dask.delayed import Delayed

    if isinstance(o, Expr):
        return o

    if hasattr(o, "expr") and not isinstance(o, Delayed):
        return o.expr
    else:
        return o
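

# Illustrative sketch, not part of the module proper: expressions pass through
# ``_unpack_collections`` unchanged, non-Delayed objects exposing ``.expr`` are
# unwrapped, and everything else is returned as-is. ``LLGExpr`` is defined
# further down in this module.
def _unpack_collections_demo():
    llg = LLGExpr({"a": 1})
    assert _unpack_collections(llg) is llg
    assert _unpack_collections(42) == 42
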

class Expr:
    _parameters: list[str] = []
    _defaults: dict[str, Any] = {}

    _pickle_functools_cache: bool = True

    operands: list

    _determ_token: str | None

    def __new__(cls, *args, _determ_token=None, **kwargs):
        operands = list(args)
        for parameter in cls._parameters[len(operands) :]:
            try:
                operands.append(kwargs.pop(parameter))
            except KeyError:
                operands.append(cls._defaults[parameter])
        assert not kwargs, kwargs
        inst = object.__new__(cls)

        inst._determ_token = _determ_token
        inst.operands = [_unpack_collections(o) for o in operands]
        # This is typically cached. Make sure the cache is populated by calling
        # it once
        inst._name
        return inst

    def _tune_down(self):
        return None

    def _tune_up(self, parent):
        return None

    def finalize_compute(self):
        return self

    def _operands_for_repr(self):
        return [f"{param}={op!r}" for param, op in zip(self._parameters, self.operands)]

    def __str__(self):
        s = ", ".join(self._operands_for_repr())
        return f"{type(self).__name__}({s})"

    def __repr__(self):
        return str(self)

    def _tree_repr_argument_construction(self, i, op, header):
        try:
            param = self._parameters[i]
            default = self._defaults[param]
        except (IndexError, KeyError):
            param = self._parameters[i] if i < len(self._parameters) else ""
            default = "--no-default--"

        if repr(op) != repr(default):
            if param:
                header += f" {param}={op!r}"
            else:
                header += repr(op)
        return header

    def _tree_repr_lines(self, indent=0, recursive=True):
        # Return a list of lines so ``tree_repr`` can join them and ``pprint``
        # can iterate line-wise rather than character-wise.
        return [" " * indent + repr(self)]

    def tree_repr(self):
        return os.linesep.join(self._tree_repr_lines())

    def analyze(self, filename: str | None = None, format: str | None = None) -> None:
        from dask.dataframe.dask_expr._expr import Expr as DFExpr
        from dask.dataframe.dask_expr.diagnostics import analyze

        if not isinstance(self, DFExpr):
            raise TypeError(
                "analyze is only supported for dask.dataframe.Expr objects."
            )
        return analyze(self, filename=filename, format=format)

    def explain(
        self, stage: OptimizerStage = "fused", format: str | None = None
    ) -> None:
        from dask.dataframe.dask_expr.diagnostics import explain

        return explain(self, stage, format)

    def pprint(self):
        for line in self._tree_repr_lines():
            print(line)

    def __hash__(self):
        return hash(self._name)

    def __dask_tokenize__(self):
        if not self._determ_token:
            # If the subclass does not implement a __dask_tokenize__ we'll want
            # to tokenize all operands.
            # Note how this differs from the implementation of
            # Expr.deterministic_token
            self._determ_token = _tokenize_deterministic(type(self), *self.operands)
        return self._determ_token

    def __dask_keys__(self):
        """The keys for this expression

        This is used to determine the keys of the output collection
        when this expression is computed.

        Returns
        -------
        keys: list
            The keys for this expression
        """
        return [(self._name, i) for i in range(self.npartitions)]

    @staticmethod
    def _reconstruct(*args):
        typ, *operands, token, cache = args
        inst = typ(*operands, _determ_token=token)
        for k, v in cache.items():
            inst.__dict__[k] = v
        return inst

    def __reduce__(self):
        if dask.config.get("dask-expr-no-serialize", False):
            raise RuntimeError(f"Serializing a {type(self)} object")
        cache = {}
        if type(self)._pickle_functools_cache:
            for k, v in type(self).__dict__.items():
                if isinstance(v, functools.cached_property) and k in self.__dict__:
                    cache[k] = getattr(self, k)

        return Expr._reconstruct, (
            type(self),
            *self.operands,
            self.deterministic_token,
            cache,
        )

    def _depth(self, cache=None):
        """Depth of the expression tree

        Returns
        -------
        depth: int
        """
        if cache is None:
            cache = {}
        if not self.dependencies():
            return 1
        else:
            result = []
            for expr in self.dependencies():
                if expr._name in cache:
                    result.append(cache[expr._name])
                else:
                    result.append(expr._depth(cache) + 1)
                    cache[expr._name] = result[-1]
            return max(result)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ["operands", "_determ_token"]:
            object.__setattr__(self, name, value)
            return
        try:
            params = type(self)._parameters
            operands = object.__getattribute__(self, "operands")
            operands[params.index(name)] = value
        except ValueError:
            raise AttributeError(
                f"{type(self).__name__} object has no attribute {name}"
            )

    def operand(self, key):
        # Access an operand unambiguously
        # (e.g. if the key is reserved by a method/property)
        return self.operands[type(self)._parameters.index(key)]

    def dependencies(self):
        # Dependencies are `Expr` operands only
        return [operand for operand in self.operands if isinstance(operand, Expr)]

    def _task(self, key: Key, index: int) -> Task:
        """The task for the i'th partition

        Parameters
        ----------
        key:
            The key of the task, i.e. ``self.__dask_keys__()[index]``
        index:
            The index of the partition of this dataframe

        Examples
        --------
        >>> class Add(Expr):
        ...     def _task(self, key, index):
        ...         return Task(
        ...             key,
        ...             operator.add,
        ...             TaskRef((self.left._name, index)),
        ...             TaskRef((self.right._name, index)),
        ...         )

        Returns
        -------
        task:
            The Dask task to compute this partition

        See Also
        --------
        Expr._layer
        """
        raise NotImplementedError(
            "Expressions should define either _layer (full dictionary) or _task"
            f" (single task). This expression {type(self)} defines neither"
        )

    def _layer(self) -> dict:
        """The graph layer added by this expression.

        Simple expressions that apply one task per partition can choose to only
        implement `Expr._task` instead.

        Examples
        --------
        >>> class Add(Expr):
        ...     def _layer(self):
        ...         return {
        ...             name: Task(
        ...                 name,
        ...                 operator.add,
        ...                 TaskRef((self.left._name, i)),
        ...                 TaskRef((self.right._name, i)),
        ...             )
        ...             for i, name in enumerate(self.__dask_keys__())
        ...         }

        Returns
        -------
        layer: dict
            The Dask task graph added by this expression

        See Also
        --------
        Expr._task
        Expr.__dask_graph__
        """

        return {
            (self._name, i): self._task((self._name, i), i)
            for i in range(self.npartitions)
        }

    def rewrite(self, kind: str, rewritten):
        """Rewrite an expression

        This leverages the ``._{kind}_down`` and ``._{kind}_up``
        methods defined on each class

        Returns
        -------
        expr:
            output expression
        """

        if self._name in rewritten:
            return rewritten[self._name]

        expr = self
        down_name = f"_{kind}_down"
        up_name = f"_{kind}_up"
        while True:
            _continue = False

            # Rewrite this node
            out = getattr(expr, down_name)()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out
                continue

            # Allow children to rewrite their parents
            for child in expr.dependencies():
                out = getattr(child, up_name)(expr)
                if out is None:
                    out = expr
                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    _continue = True
                    break

            if _continue:
                continue

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    new = operand.rewrite(kind=kind, rewritten=rewritten)
                    rewritten[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)
                continue
            else:
                break

        return expr

    def simplify_once(self, dependents: defaultdict, simplified: dict):
        """Simplify an expression

        This leverages the ``._simplify_down`` and ``._simplify_up``
        methods defined on each class

        Parameters
        ----------
        dependents: defaultdict[list]
            The dependents for every node.
        simplified: dict
            Cache of simplified expressions for these dependents.

        Returns
        -------
        expr:
            output expression
        """
        # Check if we've already simplified for these dependents
        if self._name in simplified:
            return simplified[self._name]

        expr = self

        while True:
            out = expr._simplify_down()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out

            # Allow children to simplify their parents
            for child in expr.dependencies():
                out = child._simplify_up(expr, dependents)
                if out is None:
                    out = expr

                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    break

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    # Bandaid for now, waiting for Singleton
                    dependents[operand._name].append(weakref.ref(expr))
                    new = operand.simplify_once(
                        dependents=dependents, simplified=simplified
                    )
                    simplified[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)

            break

        return expr

    def optimize(self, fuse: bool = False) -> Expr:
        stage: OptimizerStage = "fused" if fuse else "simplified-physical"

        return optimize_until(self, stage)

    def fuse(self) -> Expr:
        return self

    def simplify(self) -> Expr:
        expr = self
        seen = set()
        while True:
            dependents = collect_dependents(expr)
            new = expr.simplify_once(dependents=dependents, simplified={})
            if new._name == expr._name:
                break
            if new._name in seen:
                raise RuntimeError(
                    f"Optimizer does not converge. {expr!r} simplified to {new!r} "
                    "which was already seen. Please report this issue on the dask "
                    "issue tracker with a minimal reproducer."
                )
            seen.add(new._name)
            expr = new
        return expr

    def _simplify_down(self):
        return

    def _simplify_up(self, parent, dependents):
        return

    def lower_once(self, lowered: dict):
        # Check for a cached result
        try:
            return lowered[self._name]
        except KeyError:
            pass

        expr = self

        # Lower this node
        out = expr._lower()
        if out is None:
            out = expr
        if not isinstance(out, Expr):
            return out

        # Lower all children
        new_operands = []
        changed = False
        for operand in out.operands:
            if isinstance(operand, Expr):
                new = operand.lower_once(lowered)
                if new._name != operand._name:
                    changed = True
            else:
                new = operand
            new_operands.append(new)

        if changed:
            out = type(out)(*new_operands)

        # Cache the result and return
        return lowered.setdefault(self._name, out)

    def lower_completely(self) -> Expr:
        """Lower an expression completely

        This calls the ``lower_once`` method in a loop
        until nothing changes. This function does not
        apply any other optimizations (like ``simplify``).

        Returns
        -------
        expr:
            output expression

        See Also
        --------
        Expr.lower_once
        Expr._lower
        """
        # Lower until nothing changes
        expr = self
        lowered: dict = {}
        while True:
            new = expr.lower_once(lowered)
            if new._name == expr._name:
                break
            expr = new
        return expr

    def _lower(self):
        return

    @functools.cached_property
    def _funcname(self) -> str:
        return funcname(type(self)).lower()

    @property
    def deterministic_token(self):
        if not self._determ_token:
            # Just tokenize self to fall back on __dask_tokenize__
            # Note how this differs from the implementation of __dask_tokenize__
            self._determ_token = self.__dask_tokenize__()
        return self._determ_token

    @functools.cached_property
    def _name(self) -> str:
        return f"{self._funcname}-{self.deterministic_token}"

    @property
    def _meta(self):
        raise NotImplementedError()

    @classmethod
    def _annotations_tombstone(cls) -> _AnnotationsTombstone:
        return _AnnotationsTombstone()

    def __dask_annotations__(self):
        return {}

    def __dask_graph__(self):
        """Traverse expression tree, collect layers

        Subclasses generally do not want to override this method unless custom
        logic is required to treat (e.g. ignore) specific operands during graph
        generation.

        See also
        --------
        Expr._layer
        Expr._task
        """
        stack = [self]
        seen = set()
        layers = []
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            layers.append(expr._layer())
            for operand in expr.dependencies():
                stack.append(operand)

        return toolz.merge(layers)

    @property
    def dask(self):
        return self.__dask_graph__()

    def substitute(self, old, new) -> Expr:
        """Substitute a specific term within the expression

        Note that replacing non-`Expr` terms may produce
        unexpected results, and is not recommended.
        Substituting boolean values is not allowed.

        Parameters
        ----------
        old:
            Old term to find and replace.
        new:
            New term to replace instances of `old` with.

        Examples
        --------
        >>> (df + 10).substitute(10, 20)  # doctest: +SKIP
        df + 20
        """
        return self._substitute(old, new, _seen=set())

    def _substitute(self, old, new, _seen):
        if self._name in _seen:
            return self
        # Check if we are replacing a literal
        if isinstance(old, Expr):
            substitute_literal = False
            if self._name == old._name:
                return new
        else:
            substitute_literal = True
            if isinstance(old, bool):
                raise TypeError("Arguments to `substitute` cannot be bool.")

        new_exprs = []
        update = False
        for operand in self.operands:
            if isinstance(operand, Expr):
                val = operand._substitute(old, new, _seen)
                if operand._name != val._name:
                    update = True
                new_exprs.append(val)
            elif (
                "Fused" in type(self).__name__
                and isinstance(operand, list)
                and all(isinstance(op, Expr) for op in operand)
            ):
                # Special handling for `Fused`.
                # We make no promise to dive through a
                # list operand in general, but NEED to
                # do so for the `Fused.exprs` operand.
                val = []
                for op in operand:
                    val.append(op._substitute(old, new, _seen))
                    if val[-1]._name != op._name:
                        update = True
                new_exprs.append(val)
            elif (
                substitute_literal
                and not isinstance(operand, bool)
                and isinstance(operand, type(old))
                and operand == old
            ):
                new_exprs.append(new)
                update = True
            else:
                new_exprs.append(operand)

        if update:  # Only recreate if something changed
            return type(self)(*new_exprs)
        else:
            _seen.add(self._name)
        return self

    def substitute_parameters(self, substitutions: dict) -> Expr:
        """Substitute specific `Expr` parameters

        Parameters
        ----------
        substitutions:
            Mapping of parameter keys to new values. Keys that
            are not found in ``self._parameters`` will be ignored.
        """
        if not substitutions:
            return self

        changed = False
        new_operands = []
        for i, operand in enumerate(self.operands):
            if i < len(self._parameters) and self._parameters[i] in substitutions:
                new_operands.append(substitutions[self._parameters[i]])
                changed = True
            else:
                new_operands.append(operand)
        if changed:
            return type(self)(*new_operands)
        return self

    def _node_label_args(self):
        """Operands to include in the node label by `visualize`"""
        return self.dependencies()

    def _to_graphviz(
        self,
        rankdir="BT",
        graph_attr=None,
        node_attr=None,
        edge_attr=None,
        **kwargs,
    ):
        from dask.dot import label, name

        graphviz = import_required(
            "graphviz",
            "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
            "python library and the `graphviz` system library.\n\n"
            "Please either conda or pip install as follows:\n\n"
            "  conda install python-graphviz   # either conda install\n"
            "  python -m pip install graphviz  # or pip install and follow installation instructions",
        )

        graph_attr = graph_attr or {}
        node_attr = node_attr or {}
        edge_attr = edge_attr or {}

        graph_attr["rankdir"] = rankdir
        node_attr["shape"] = "box"
        node_attr["fontname"] = "helvetica"

        graph_attr.update(kwargs)
        g = graphviz.Digraph(
            graph_attr=graph_attr,
            node_attr=node_attr,
            edge_attr=edge_attr,
        )

        stack = [self]
        seen = set()
        dependencies = {}
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            dependencies[expr] = set(expr.dependencies())
            for dep in expr.dependencies():
                stack.append(dep)

        cache = {}
        for expr in dependencies:
            expr_name = name(expr)
            attrs = {}

            # Make node label
            deps = [
                funcname(type(dep)) if isinstance(dep, Expr) else str(dep)
                for dep in expr._node_label_args()
            ]
            _label = funcname(type(expr))
            if deps:
                _label = f"{_label}({', '.join(deps)})"
            node_label = label(_label, cache=cache)

            attrs.setdefault("label", str(node_label))
            attrs.setdefault("fontsize", "20")
            g.node(expr_name, **attrs)

        for expr, deps in dependencies.items():
            expr_name = name(expr)
            for dep in deps:
                dep_name = name(dep)
                g.edge(dep_name, expr_name)

        return g

    def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
        """
        Visualize the expression graph.
        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and the graph is
            rendered in the Jupyter notebook only.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'svg'.
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.
        """
        from dask.dot import graphviz_to_file

        g = self._to_graphviz(**kwargs)
        graphviz_to_file(g, filename, format)
        return g

    def walk(self) -> Generator[Expr]:
        """Iterate through all expressions in the tree

        Returns
        -------
        nodes
            Generator of Expr instances in the graph.
            Ordering is a depth-first search of the expression tree
        """
        stack = [self]
        seen = set()
        while stack:
            node = stack.pop()
            if node._name in seen:
                continue
            seen.add(node._name)

            for dep in node.dependencies():
                stack.append(dep)

            yield node

    def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
        """Search the expression graph for a specific operation type

        Parameters
        ----------
        operation
            The operation type to search for.

        Returns
        -------
        nodes
            Generator of `operation` instances. Ordering corresponds
            to a depth-first search of the expression graph.
        """
        assert (
            isinstance(operation, tuple)
            and all(issubclass(e, Expr) for e in operation)
            or issubclass(operation, Expr)  # type: ignore[arg-type]
        ), "`operation` must be an `Expr` subclass"
        return (expr for expr in self.walk() if isinstance(expr, operation))

    def __getattr__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError as err:
            if key.startswith("_meta"):
                # Avoid a recursive loop if/when `self._meta*`
                # produces an `AttributeError`
                raise RuntimeError(
                    f"Failed to generate metadata for {self}. "
                    "This operation may not be supported by the current backend."
                )

            # Allow operands to be accessed as attributes
            # as long as the keys are not already reserved
            # by existing methods/properties
            _parameters = type(self)._parameters
            if key in _parameters:
                idx = _parameters.index(key)
                return self.operands[idx]

            raise AttributeError(
                f"{err}\n\n"
                "This often means that you are attempting to use an unsupported "
                "API function."
            )
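

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the module proper: a minimal pair of
# ``Expr`` subclasses showing how ``_parameters``, ``__dask_keys__`` and
# ``_task`` fit together. The names ``_ExampleLiteral`` and ``_ExampleAdd``
# are hypothetical.
# ---------------------------------------------------------------------------
import operator

from dask._task_spec import TaskRef


class _ExampleLiteral(Expr):
    # A leaf expression holding a single in-memory partition.
    _parameters = ["value"]
    npartitions = 1

    def _task(self, key, index):
        # ``lambda v: v`` simply materializes the stored operand.
        return Task(key, lambda v: v, self.value)


class _ExampleAdd(Expr):
    # Blockwise addition of two expressions with aligned partitioning.
    _parameters = ["left", "right"]

    @property
    def npartitions(self):
        return self.left.npartitions

    def _task(self, key, index):
        # Inputs are referenced by key (TaskRef) rather than inlined.
        return Task(
            key,
            operator.add,
            TaskRef((self.left._name, index)),
            TaskRef((self.right._name, index)),
        )


def _example_graph():
    # Constructing the expression also populates the ``_name`` cache; the
    # resulting graph has one task per partition for every node in the tree.
    expr = _ExampleAdd(_ExampleLiteral(1), _ExampleLiteral(2))
    return expr.__dask_graph__()
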

class SingletonExpr(Expr):
    """A singleton Expr class

    This is used to treat the subclassed expression as a singleton. Singletons
    are deduplicated by expr._name which is typically based on the dask.tokenize
    output.

    This is a crucial performance optimization for expressions that walk through
    an optimizer and are recreated repeatedly but isn't safe for objects that
    cannot be reliably or quickly tokenized.
    """

    _instances: weakref.WeakValueDictionary[str, SingletonExpr]

    def __new__(cls, *args, _determ_token=None, **kwargs):
        if not hasattr(cls, "_instances"):
            cls._instances = weakref.WeakValueDictionary()
        inst = super().__new__(cls, *args, _determ_token=_determ_token, **kwargs)
        _name = inst._name
        if _name in cls._instances and cls.__init__ == object.__init__:
            return cls._instances[_name]

        cls._instances[_name] = inst
        return inst
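

# Hedged sketch: two structurally identical instances of a SingletonExpr
# subclass share a deterministic name, so the weak-value cache hands back the
# first instance. ``_ExampleSingleton`` is a hypothetical name.
def _singleton_demo():
    class _ExampleSingleton(SingletonExpr):
        _parameters = ["value"]

    a = _ExampleSingleton(1)
    b = _ExampleSingleton(1)
    assert a is b
    return a
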

def collect_dependents(expr) -> defaultdict:
    dependents = defaultdict(list)
    stack = [expr]
    seen = set()
    while stack:
        node = stack.pop()
        if node._name in seen:
            continue
        seen.add(node._name)

        for dep in node.dependencies():
            stack.append(dep)
            dependents[dep._name].append(weakref.ref(node))
    return dependents


def optimize(expr: Expr, fuse: bool = True) -> Expr:
    """High level query optimization

    This leverages two optimization passes:

    1.  Class based simplification using the ``_simplify`` function and methods
    2.  Blockwise fusion

    Parameters
    ----------
    expr:
        Input expression to optimize
    fuse:
        whether or not to turn on blockwise fusion

    See Also
    --------
    simplify
    optimize_blockwise_fusion
    """
    stage: OptimizerStage = "fused" if fuse else "simplified-physical"

    return optimize_until(expr, stage)


def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
    result = expr
    if stage == "logical":
        return result

    # Simplify
    expr = result.simplify()
    if stage == "simplified-logical":
        return expr

    # Manipulate Expression to make it more efficient
    if dask.config.get("optimization.tune.active", True):
        expr = expr.rewrite(kind="tune", rewritten={})
    if stage == "tuned-logical":
        return expr

    # Lower
    expr = expr.lower_completely()
    if stage == "physical":
        return expr

    # Simplify again
    expr = expr.simplify()
    if stage == "simplified-physical":
        return expr

    # Final graph-specific optimizations
    expr = expr.fuse()
    if stage == "fused":
        return expr

    raise ValueError(f"Stage {stage!r} not supported.")
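

# Hedged usage sketch: each name in ``OptimizerStage`` corresponds to a prefix
# of the pipeline above, so an expression can be inspected at any intermediate
# stage.
def _stages_demo(expr: Expr) -> None:
    for stage in (
        "logical",
        "simplified-logical",
        "tuned-logical",
        "physical",
        "simplified-physical",
        "fused",
    ):
        print(stage, optimize_until(expr, stage)._name)
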

class LLGExpr(Expr):
    """Low Level Graph Expression"""

    _parameters = ["dsk"]

    def __dask_keys__(self):
        return list(self.operand("dsk"))

    def _layer(self) -> dict:
        return ensure_dict(self.operand("dsk"))
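

# Minimal sketch: an LLGExpr wraps an already-materialized task dict; its keys
# are the dict's keys and its layer is the dict itself.
def _llg_demo():
    expr = LLGExpr({"a": 1, "b": 2})
    assert expr.__dask_keys__() == ["a", "b"]
    return expr.__dask_graph__()
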

class HLGExpr(Expr):
    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @property
    def hlg(self):
        return self.operand("dsk")

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when going
        # through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )
        if optimize_graph and not hasattr(collection, "__dask_optimize__"):
            warnings.warn(
                f"Collection {type(collection)} does not define a "
                "`__dask_optimize__` method. In the future this will raise. "
                "If no optimization is desired, please set this to `None`.",
                PendingDeprecationWarning,
            )
            low_level_optimizer = None
        else:
            low_level_optimizer = (
                collection.__dask_optimize__ if optimize_graph else None
            )
        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__(),
        )

    def finalize_compute(self):
        return HLGFinalizeCompute(
            self,
            low_level_optimizer=self.low_level_optimizer,
            output_keys=self.output_keys,
            postcompute=self.postcompute,
        )

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annotations_by_type[annot_type].update(
                        {k: (value(k) if callable(value) else value) for k in layer}
                    )
        return dict(annotations_by_type)

    def __dask_keys__(self):
        if (keys := self.operand("output_keys")) is not None:
            return keys
        dsk = self.hlg
        # Note: This will materialize
        dependencies = dsk.get_all_dependencies()
        leafs = set(dependencies)
        for val in dependencies.values():
            leafs -= val
        self.output_keys = list(leafs)
        return self.output_keys

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        keys = self.__dask_keys__()
        dsk = self.hlg
        if (optimizer := self.low_level_optimizer) is not None:
            dsk = optimizer(dsk, keys)
        return HighLevelGraph.merge(dsk)

    @property
    def deterministic_token(self):
        if not self._determ_token:
            self._determ_token = uuid.uuid4().hex
        return self._determ_token

    def _layer(self) -> dict:
        dsk = self._optimized_dsk
        return ensure_dict(dsk)
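

# Hedged sketch, assuming ``dask.delayed`` is available: adapting an existing
# high-level-graph collection so it can travel through the expression
# machinery. The computation itself is arbitrary.
def _hlg_demo():
    d = dask.delayed(sum)([1, 2, 3])
    expr = HLGExpr.from_collection(d)
    # The wrapped collection keeps its output keys and low-level optimizer.
    return expr.__dask_keys__()
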

class _HLGExprGroup(HLGExpr):
    # Identical to HLGExpr
    # Used internally to determine how output keys are supposed to be returned
    pass


class _HLGExprSequence(Expr):
    def __getitem__(self, other):
        return self.operands[other]

    def _operands_for_repr(self):
        return [
            f"name={self.operand('name')!r}",
            f"dsk={self.operand('dsk')!r}",
        ]

    def _tree_repr_lines(self, indent=0, recursive=True):
        return self._operands_for_repr()

    def finalize_compute(self):
        return _HLGExprSequence(*[op.finalize_compute() for op in self.operands])

    def _tune_down(self):
        if len(self.operands) == 1:
            return None
        from dask.highlevelgraph import HighLevelGraph

        groups = toolz.groupby(
            lambda x: x.low_level_optimizer if isinstance(x, HLGExpr) else None,
            self.operands,
        )
        exprs = []
        changed = False
        for optimizer, group in groups.items():
            if len(group) > 1:
                graphs = [expr.hlg for expr in group]

                changed = True
                dsk = HighLevelGraph.merge(*graphs)
                hlg_group = _HLGExprGroup(
                    dsk=dsk,
                    low_level_optimizer=optimizer,
                    output_keys=[v.__dask_keys__() for v in group],
                    postcompute=[g.postcompute for g in group],
                )
                exprs.append(hlg_group)
            else:
                exprs.append(group[0])
        if not changed:
            return None
        return _HLGExprSequence(*exprs)

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        hlgexpr: HLGExpr
        graphs = []
        # _tune_down ensures there is only one HLGExpr per optimizer/finalizer
        for hlgexpr in self.operands:
            keys = hlgexpr.__dask_keys__()
            dsk = hlgexpr.hlg
            if (optimizer := hlgexpr.low_level_optimizer) is not None:
                dsk = optimizer(dsk, keys)
            graphs.append(dsk)

        return HighLevelGraph.merge(*graphs)

    def __dask_graph__(self):
        # This class has to override this and not just _layer to ensure the HLGs
        # are not optimized individually
        return ensure_dict(self._optimized_dsk)

    _layer = __dask_graph__

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annots = list(
                        (k, (value(k) if callable(value) else value)) for k in layer
                    )
                    annotations_by_type[annot_type].update(
                        {
                            k: v
                            for k, v in annots
                            if not isinstance(v, _AnnotationsTombstone)
                        }
                    )
                    if not annotations_by_type[annot_type]:
                        del annotations_by_type[annot_type]
        return dict(annotations_by_type)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            if isinstance(op, _HLGExprGroup):
                all_keys.extend(op.__dask_keys__())
            else:
                all_keys.append(op.__dask_keys__())
        return all_keys


class _ExprSequence(Expr):
    """A sequence of expressions

    This is used to be able to optimize multiple collections combined, e.g. when
    being computed simultaneously with ``dask.compute((Expr1, Expr2))``.
    """

    def __getitem__(self, other):
        return self.operands[other]

    def _layer(self) -> dict:
        return toolz.merge(op._layer() for op in self.operands)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            all_keys.append(list(op.__dask_keys__()))
        return all_keys

    def __repr__(self):
        return "ExprSequence(" + ", ".join(map(repr, self.operands)) + ")"

    __str__ = __repr__

    def finalize_compute(self):
        return _ExprSequence(
            *(op.finalize_compute() for op in self.operands),
        )

    def __dask_annotations__(self):
        annotations_by_type = {}
        for op in self.operands:
            for k, v in op.__dask_annotations__().items():
                annotations_by_type.setdefault(k, {}).update(v)
        return annotations_by_type

    def __len__(self):
        return len(self.operands)

    def __iter__(self):
        return iter(self.operands)

    def _simplify_down(self):
        from dask.highlevelgraph import HighLevelGraph

        issue_warning = False
        hlgs = []
        if any(
            isinstance(op, (HLGExpr, HLGFinalizeCompute, dict)) for op in self.operands
        ):
            for op in self.operands:
                if isinstance(op, (HLGExpr, HLGFinalizeCompute)):
                    hlgs.append(op)
                elif isinstance(op, dict):
                    hlgs.append(
                        HLGExpr(
                            dsk=HighLevelGraph.from_collections(
                                str(id(op)), op, dependencies=()
                            )
                        )
                    )
                else:
                    issue_warning = True
                    opt = op.optimize()
                    hlgs.append(
                        HLGExpr(
                            dsk=HighLevelGraph.from_collections(
                                opt._name, opt.__dask_graph__(), dependencies=()
                            )
                        )
                    )
        if issue_warning:
            warnings.warn(
                "Computing mixed collections that are backed by "
                "HighlevelGraphs/dicts and Expressions. "
                "This forces Expressions to be materialized. "
                "It is recommended to use only one type and separate the "
                "dask.compute calls if necessary.",
                UserWarning,
            )
        if not hlgs:
            return None
        return _HLGExprSequence(*hlgs)
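

# Illustrative sketch: an _ExprSequence groups several expressions so they can
# be optimized and computed as one unit; its keys nest per operand.
def _sequence_demo():
    seq = _ExprSequence(LLGExpr({"a": 1}), LLGExpr({"b": 2}))
    assert seq.__dask_keys__() == [["a"], ["b"]]
    return seq.__dask_graph__()
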

class _AnnotationsTombstone: ...


class FinalizeCompute(Expr):
    _parameters = ["expr"]

    def _simplify_down(self):
        return self.expr.finalize_compute()
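

# Sketch: FinalizeCompute is a marker that simplifies into whatever the
# wrapped expression's ``finalize_compute`` returns; the base implementation
# is the identity, so plain expressions simply unwrap again.
def _finalize_demo():
    inner = LLGExpr({"a": 1})
    assert FinalizeCompute(inner)._simplify_down() is inner
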

def _convert_dask_keys(keys):
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    new_keys = []
    for key in keys:
        if isinstance(key, list):
            new_keys.append(_convert_dask_keys(key))
        else:
            new_keys.append(TaskRef(key))
    return List(*new_keys)
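

# Sketch: nested output-key lists keep their shape while each concrete key
# becomes a TaskRef, so a finalize task depends on the data by reference, e.g.
# [("x", 0), [("y", 0)]] -> List(TaskRef(("x", 0)), List(TaskRef(("y", 0)))).
def _convert_keys_demo():
    return _convert_dask_keys([("x", 0), [("y", 0), ("y", 1)]])
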

class HLGFinalizeCompute(HLGExpr):
    def _simplify_down(self):
        if not self.postcompute:
            return self.dsk

        from dask.delayed import Delayed

        # Skip finalization for Delayed
        if self.dsk.postcompute == Delayed.__dask_postcompute__(self.dsk):
            return self.dsk
        return self

    @property
    def _name(self):
        return f"finalize-{super()._name}"

    def __dask_graph__(self):
        # The baseclass __dask_graph__ will not just materialize this layer but
        # also that of its dependencies, i.e. it will render the finalized and
        # the non-finalized graph and combine them. We only want the finalized
        # one, so we're overriding this.
        # This is an artifact generated since the wrapped expression is
        # identified automatically as a dependency but HLG expressions are not
        # working in this layered way.
        return self._layer()

    @property
    def hlg(self):
        expr = self.operand("dsk")
        layers = expr.dsk.layers.copy()
        deps = expr.dsk.dependencies.copy()
        keys = expr.__dask_keys__()
        if isinstance(expr.postcompute, list):
            postcomputes = expr.postcompute
        else:
            postcomputes = [expr.postcompute]
        tasks = [
            Task(self._name, func, _convert_dask_keys(keys), *extra_args)
            for func, extra_args in postcomputes
        ]
        from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

        leafs = set(deps)
        for val in deps.values():
            leafs -= val
        for t in tasks:
            layers[t.key] = MaterializedLayer({t.key: t})
            deps[t.key] = leafs
        return HighLevelGraph(layers, dependencies=deps)

    def __dask_keys__(self):
        return [self._name]


class ProhibitReuse(Expr):
    """
    An expression that guarantees that all keys are suffixed with a unique id.
    This can be used to break a common subexpression apart.
    """

    _parameters = ["expr"]
    _ALLOWED_TYPES = [HLGExpr, LLGExpr, HLGFinalizeCompute, _HLGExprSequence]

    def __dask_keys__(self):
        return self._modify_keys(self.expr.__dask_keys__())

    @staticmethod
    def _identity(obj):
        return obj

    @functools.cached_property
    def _suffix(self):
        return uuid.uuid4().hex

    def _modify_keys(self, k):
        if isinstance(k, list):
            return [self._modify_keys(kk) for kk in k]
        elif isinstance(k, tuple):
            return (self._modify_keys(k[0]),) + k[1:]
        elif isinstance(k, (int, float)):
            k = str(k)
        return f"{k}-{self._suffix}"

    def _simplify_down(self):
        # FIXME: Shuffling cannot be rewritten since the barrier key is
        # hardcoded. Skipping this here should do the trick most of the time
        if not isinstance(
            self.expr,
            tuple(self._ALLOWED_TYPES),
        ):
            return self.expr

    def __dask_graph__(self):
        try:
            from distributed.shuffle._core import P2PBarrierTask
        except ModuleNotFoundError:
            P2PBarrierTask = type(None)
        dsk = convert_legacy_graph(self.expr.__dask_graph__())

        subs = {old_key: self._modify_keys(old_key) for old_key in dsk}
        dsk2 = {}
        for old_key, new_key in subs.items():
            t = dsk[old_key]
            if isinstance(t, P2PBarrierTask):
                warnings.warn(
                    "Cannot block reusing for graphs including a "
                    "P2PBarrierTask. This may cause unexpected results. "
                    "This typically happens when converting a dask "
                    "DataFrame to delayed objects.",
                    UserWarning,
                )
                return dsk
            dsk2[new_key] = Task(
                new_key,
                ProhibitReuse._identity,
                t.substitute(subs),
            )

        dsk2.update(dsk)
        return dsk2

    _layer = __dask_graph__
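

# Hedged sketch: wrapping an expression in ProhibitReuse rewrites every key
# with a fresh uuid suffix, so a shared subexpression can be deliberately
# duplicated instead of reused.
def _prohibit_reuse_demo():
    wrapped = ProhibitReuse(LLGExpr({"a": 1}))
    # Keys gain the suffix, e.g. "a-<hex>".
    return wrapped.__dask_keys__()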