# Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/_expr.py: 21% (731 statements)

from __future__ import annotations

import functools
import os
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal

import toolz

import dask
from dask._task_spec import Task, convert_legacy_graph
from dask.tokenize import _tokenize_deterministic
from dask.typing import Key
from dask.utils import ensure_dict, funcname, import_required

if TYPE_CHECKING:
    # TODO import from typing (requires Python >=3.10)
    from typing import Any, TypeAlias

    from dask.highlevelgraph import HighLevelGraph

OptimizerStage: TypeAlias = Literal[
    "logical",
    "simplified-logical",
    "tuned-logical",
    "physical",
    "simplified-physical",
    "fused",
]


def _unpack_collections(o):
    from dask.delayed import Delayed

    if isinstance(o, Expr):
        return o

    if hasattr(o, "expr") and not isinstance(o, Delayed):
        return o.expr
    else:
        return o


class Expr:
    _parameters: list[str] = []
    _defaults: dict[str, Any] = {}

    _pickle_functools_cache: bool = True

    operands: list

    _determ_token: str | None

    def __new__(cls, *args, _determ_token=None, **kwargs):
        operands = list(args)
        for parameter in cls._parameters[len(operands) :]:
            try:
                operands.append(kwargs.pop(parameter))
            except KeyError:
                operands.append(cls._defaults[parameter])
        assert not kwargs, kwargs
        inst = object.__new__(cls)

        inst._determ_token = _determ_token
        inst.operands = [_unpack_collections(o) for o in operands]
        # This is typically cached. Make sure the cache is populated by calling
        # it once
        inst._name
        return inst
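
    # How `_parameters`/`_defaults` drive `__new__`, as a hedged sketch
    # (`Add`, `left`, `right`, `fill_value` are illustrative names, not
    # dask API):
    #
    #     class Add(Expr):
    #         _parameters = ["left", "right", "fill_value"]
    #         _defaults = {"fill_value": None}
    #
    #     Add(a, b).operands        # -> [a, b, None]
    #     Add(a, right=b).operands  # -> [a, b, None]
    #
    # Positional arguments fill the leading parameters; the remaining ones
    # are taken from keyword arguments or `_defaults`, and a missing
    # parameter without a default raises KeyError.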

    def _tune_down(self):
        return None

    def _tune_up(self, parent):
        return None

    def finalize_compute(self):
        return self

    def _operands_for_repr(self):
        return [f"{param}={op!r}" for param, op in zip(self._parameters, self.operands)]

    def __str__(self):
        s = ", ".join(self._operands_for_repr())
        return f"{type(self).__name__}({s})"

    def __repr__(self):
        return str(self)

    def _tree_repr_argument_construction(self, i, op, header):
        try:
            param = self._parameters[i]
            default = self._defaults[param]
        except (IndexError, KeyError):
            param = self._parameters[i] if i < len(self._parameters) else ""
            default = "--no-default--"

        if repr(op) != repr(default):
            if param:
                header += f" {param}={op!r}"
            else:
                header += repr(op)
        return header

    def _tree_repr_lines(self, indent=0, recursive=True):
        return " " * indent + repr(self)

    def tree_repr(self):
        return os.linesep.join(self._tree_repr_lines())

    def analyze(self, filename: str | None = None, format: str | None = None) -> None:
        from dask.dataframe.dask_expr._expr import Expr as DFExpr
        from dask.dataframe.dask_expr.diagnostics import analyze

        if not isinstance(self, DFExpr):
            raise TypeError(
                "analyze is only supported for dask.dataframe.Expr objects."
            )
        return analyze(self, filename=filename, format=format)

    def explain(
        self, stage: OptimizerStage = "fused", format: str | None = None
    ) -> None:
        from dask.dataframe.dask_expr.diagnostics import explain

        return explain(self, stage, format)

    def pprint(self):
        for line in self._tree_repr_lines():
            print(line)

    def __hash__(self):
        return hash(self._name)

    def __dask_tokenize__(self):
        if not self._determ_token:
            # If the subclass does not implement a __dask_tokenize__ we'll want
            # to tokenize all operands.
            # Note how this differs from the implementation of
            # Expr.deterministic_token
            self._determ_token = _tokenize_deterministic(type(self), *self.operands)
        return self._determ_token

    def __dask_keys__(self):
        """The keys for this expression

        This is used to determine the keys of the output collection
        when this expression is computed.

        Returns
        -------
        keys: list
            The keys for this expression
        """
        return [(self._name, i) for i in range(self.npartitions)]

    @staticmethod
    def _reconstruct(*args):
        typ, *operands, token, cache = args
        inst = typ(*operands, _determ_token=token)
        for k, v in cache.items():
            inst.__dict__[k] = v
        return inst

    def __reduce__(self):
        if dask.config.get("dask-expr-no-serialize", False):
            raise RuntimeError(f"Serializing a {type(self)} object")
        cache = {}
        if type(self)._pickle_functools_cache:
            for k, v in type(self).__dict__.items():
                if isinstance(v, functools.cached_property) and k in self.__dict__:
                    cache[k] = getattr(self, k)

        return Expr._reconstruct, (
            type(self),
            *self.operands,
            self.deterministic_token,
            cache,
        )
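
    # Pickle round-trips preserve the deterministic token, so the rebuilt
    # expression keeps the same `_name`; populated `functools.cached_property`
    # values travel along in `cache`. A sketch, assuming `expr` is any
    # picklable Expr instance:
    #
    #     import pickle
    #     expr2 = pickle.loads(pickle.dumps(expr))
    #     assert expr2._name == expr._name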

    def _depth(self, cache=None):
        """Depth of the expression tree

        Returns
        -------
        depth: int
        """
        if cache is None:
            cache = {}
        if not self.dependencies():
            return 1
        else:
            result = []
            for expr in self.dependencies():
                if expr._name in cache:
                    result.append(cache[expr._name])
                else:
                    result.append(expr._depth(cache) + 1)
                    cache[expr._name] = result[-1]
            return max(result)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ["operands", "_determ_token"]:
            object.__setattr__(self, name, value)
            return
        try:
            params = type(self)._parameters
            operands = object.__getattribute__(self, "operands")
            operands[params.index(name)] = value
        except ValueError:
            raise AttributeError(
                f"{type(self).__name__} object has no attribute {name}"
            )

    def operand(self, key):
        # Access an operand unambiguously
        # (e.g. if the key is reserved by a method/property)
        return self.operands[type(self)._parameters.index(key)]

    def dependencies(self):
        # Dependencies are `Expr` operands only
        return [operand for operand in self.operands if isinstance(operand, Expr)]

    def _task(self, key: Key, index: int) -> Task:
        """The task for the i'th partition

        Parameters
        ----------
        key:
            The key of the task, i.e. ``(self._name, index)``
        index:
            The index of the partition of this dataframe

        Examples
        --------
        >>> class Add(Expr):
        ...     def _task(self, key, index):
        ...         return Task(
        ...             key,
        ...             operator.add,
        ...             TaskRef((self.left._name, index)),
        ...             TaskRef((self.right._name, index)),
        ...         )

        Returns
        -------
        task:
            The Dask task to compute this partition

        See Also
        --------
        Expr._layer
        """
        raise NotImplementedError(
            "Expressions should define either _layer (full dictionary) or _task"
            f" (single task). This expression {type(self)} defines neither"
        )

    def _layer(self) -> dict:
        """The graph layer added by this expression.

        Simple expressions that apply one task per partition can choose to only
        implement `Expr._task` instead.

        Examples
        --------
        >>> class Add(Expr):
        ...     def _layer(self):
        ...         return {
        ...             name: Task(
        ...                 name,
        ...                 operator.add,
        ...                 TaskRef((self.left._name, i)),
        ...                 TaskRef((self.right._name, i))
        ...             )
        ...             for i, name in enumerate(self.__dask_keys__())
        ...         }

        Returns
        -------
        layer: dict
            The Dask task graph added by this expression

        See Also
        --------
        Expr._task
        Expr.__dask_graph__
        """

        return {
            (self._name, i): self._task((self._name, i), i)
            for i in range(self.npartitions)
        }
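
    # A concrete subclass typically implements just `_task` and lets the
    # default `_layer` fan it out over partitions. A minimal sketch of an
    # elementwise addition (illustrative, not dask API; assumes `left` and
    # `right` are aligned Expr operands):
    #
    #     import operator
    #     from dask._task_spec import Task, TaskRef
    #
    #     class Add(Expr):
    #         _parameters = ["left", "right"]
    #
    #         @property
    #         def npartitions(self):
    #             return self.left.npartitions
    #
    #         def _task(self, key, index):
    #             return Task(
    #                 key,
    #                 operator.add,
    #                 TaskRef((self.left._name, index)),
    #                 TaskRef((self.right._name, index)),
    #             )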

    def rewrite(self, kind: str, rewritten):
        """Rewrite an expression

        This leverages the ``._{kind}_down`` and ``._{kind}_up``
        methods defined on each class

        Returns
        -------
        expr:
            output expression
        """
        if self._name in rewritten:
            return rewritten[self._name]

        expr = self
        down_name = f"_{kind}_down"
        up_name = f"_{kind}_up"
        while True:
            _continue = False

            # Rewrite this node
            out = getattr(expr, down_name)()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out
                continue

            # Allow children to rewrite their parents
            for child in expr.dependencies():
                out = getattr(child, up_name)(expr)
                if out is None:
                    out = expr
                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    _continue = True
                    break

            if _continue:
                continue

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    new = operand.rewrite(kind=kind, rewritten=rewritten)
                    rewritten[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)
                continue
            else:
                break

        return expr
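
    # `rewrite` dispatches on method names derived from `kind`: with
    # kind="tune" it calls `_tune_down` on every node and `_tune_up` on every
    # child, accepting any rewrite that returns a differently named
    # expression and looping to a fixed point. A sketch of a down-rule
    # (`Double` and `Quadruple` are hypothetical expression types):
    #
    #     class Double(Expr):
    #         _parameters = ["frame"]
    #
    #         def _tune_down(self):
    #             # Collapse Double(Double(x)) into a single Quadruple(x);
    #             # returning None leaves the node unchanged.
    #             if isinstance(self.frame, Double):
    #                 return Quadruple(self.frame.frame)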

    def simplify_once(self, dependents: defaultdict, simplified: dict):
        """Simplify an expression

        This leverages the ``._simplify_down`` and ``._simplify_up``
        methods defined on each class

        Parameters
        ----------
        dependents: defaultdict[list]
            The dependents for every node.
        simplified: dict
            Cache of simplified expressions for these dependents.

        Returns
        -------
        expr:
            output expression
        """
        # Check if we've already simplified for these dependents
        if self._name in simplified:
            return simplified[self._name]

        expr = self

        while True:
            out = expr._simplify_down()
            if out is None:
                out = expr
            if not isinstance(out, Expr):
                return out
            if out._name != expr._name:
                expr = out

            # Allow children to simplify their parents
            for child in expr.dependencies():
                out = child._simplify_up(expr, dependents)
                if out is None:
                    out = expr

                if not isinstance(out, Expr):
                    return out
                if out is not expr and out._name != expr._name:
                    expr = out
                    break

            # Rewrite all of the children
            new_operands = []
            changed = False
            for operand in expr.operands:
                if isinstance(operand, Expr):
                    # Bandaid for now, waiting for Singleton
                    dependents[operand._name].append(weakref.ref(expr))
                    new = operand.simplify_once(
                        dependents=dependents, simplified=simplified
                    )
                    simplified[operand._name] = new
                    if new._name != operand._name:
                        changed = True
                else:
                    new = operand
                new_operands.append(new)

            if changed:
                expr = type(expr)(*new_operands)

            break

        return expr

    def optimize(self, fuse: bool = False) -> Expr:
        stage: OptimizerStage = "fused" if fuse else "simplified-physical"

        return optimize_until(self, stage)

    def fuse(self) -> Expr:
        return self

    def simplify(self) -> Expr:
        expr = self
        seen = set()
        while True:
            dependents = collect_dependents(expr)
            new = expr.simplify_once(dependents=dependents, simplified={})
            if new._name == expr._name:
                break
            if new._name in seen:
                raise RuntimeError(
                    f"Optimizer does not converge. {expr!r} simplified to {new!r} which was already seen. "
                    "Please report this issue on the dask issue tracker with a minimal reproducer."
                )
            seen.add(new._name)
            expr = new
        return expr
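
    # `simplify` drives `simplify_once` to a fixed point, recomputing the
    # dependents mapping on every pass and aborting if a previously seen
    # expression name reappears. A sketch of a `_simplify_down` rule
    # (`Negate` is a hypothetical expression type):
    #
    #     class Negate(Expr):
    #         _parameters = ["frame"]
    #
    #         def _simplify_down(self):
    #             # Negate(Negate(x)) -> x; returning None means "no change".
    #             if isinstance(self.frame, Negate):
    #                 return self.frame.frame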

    def _simplify_down(self):
        return

    def _simplify_up(self, parent, dependents):
        return

    def lower_once(self, lowered: dict):
        # Check for a cached result
        try:
            return lowered[self._name]
        except KeyError:
            pass

        expr = self

        # Lower this node
        out = expr._lower()
        if out is None:
            out = expr
        if not isinstance(out, Expr):
            return out

        # Lower all children
        new_operands = []
        changed = False
        for operand in out.operands:
            if isinstance(operand, Expr):
                new = operand.lower_once(lowered)
                if new._name != operand._name:
                    changed = True
            else:
                new = operand
            new_operands.append(new)

        if changed:
            out = type(out)(*new_operands)

        # Cache the result and return
        return lowered.setdefault(self._name, out)

    def lower_completely(self) -> Expr:
        """Lower an expression completely

        This calls the ``lower_once`` method in a loop
        until nothing changes. This function does not
        apply any other optimizations (like ``simplify``).

        Returns
        -------
        expr:
            output expression

        See Also
        --------
        Expr.lower_once
        Expr._lower
        """
        # Lower until nothing changes
        expr = self
        lowered: dict = {}
        while True:
            new = expr.lower_once(lowered)
            if new._name == expr._name:
                break
            expr = new
        return expr
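
    # Lowering translates logical expressions into physical ones via
    # `_lower`. A sketch, assuming hypothetical `Sum`, `ChunkSum` and
    # `TreeReduce` expression types:
    #
    #     class Sum(Expr):
    #         _parameters = ["frame"]
    #
    #         def _lower(self):
    #             # Expand the logical reduction into a physical plan.
    #             return TreeReduce(ChunkSum(self.frame))
    #
    # `lower_once` applies `_lower` to a node and then recurses into its
    # operands; `lower_completely` repeats this until nothing changes.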

    def _lower(self):
        return

    @functools.cached_property
    def _funcname(self) -> str:
        return funcname(type(self)).lower()

    @property
    def deterministic_token(self):
        if not self._determ_token:
            # Just tokenize self to fall back on __dask_tokenize__
            # Note how this differs from the implementation of __dask_tokenize__
            self._determ_token = self.__dask_tokenize__()
        return self._determ_token

    @functools.cached_property
    def _name(self) -> str:
        return self._funcname + "-" + self.deterministic_token

    @property
    def _meta(self):
        raise NotImplementedError()

    @classmethod
    def _annotations_tombstone(cls) -> _AnnotationsTombstone:
        return _AnnotationsTombstone()

    def __dask_annotations__(self):
        return {}

    def __dask_graph__(self):
        """Traverse expression tree, collect layers

        Subclasses generally do not want to override this method unless custom
        logic is required to treat (e.g. ignore) specific operands during graph
        generation.

        See also
        --------
        Expr._layer
        Expr._task
        """
        stack = [self]
        seen = set()
        layers = []
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            layers.append(expr._layer())
            for operand in expr.dependencies():
                stack.append(operand)

        return toolz.merge(layers)
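
    # The merged layers form a plain `{key: Task}` mapping that a dask
    # scheduler can execute directly. A sketch, assuming `expr` is a
    # concrete, fully lowered expression:
    #
    #     from dask.local import get_sync
    #
    #     graph = expr.__dask_graph__()
    #     results = get_sync(graph, expr.__dask_keys__())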

    @property
    def dask(self):
        return self.__dask_graph__()

    def substitute(self, old, new) -> Expr:
        """Substitute a specific term within the expression

        Note that replacing non-`Expr` terms may produce
        unexpected results, and is not recommended.
        Substituting boolean values is not allowed.

        Parameters
        ----------
        old:
            Old term to find and replace.
        new:
            New term to replace instances of `old` with.

        Examples
        --------
        >>> (df + 10).substitute(10, 20)  # doctest: +SKIP
        df + 20
        """
        return self._substitute(old, new, _seen=set())

    def _substitute(self, old, new, _seen):
        if self._name in _seen:
            return self
        # Check if we are replacing a literal
        if isinstance(old, Expr):
            substitute_literal = False
            if self._name == old._name:
                return new
        else:
            substitute_literal = True
            if isinstance(old, bool):
                raise TypeError("Arguments to `substitute` cannot be bool.")

        new_exprs = []
        update = False
        for operand in self.operands:
            if isinstance(operand, Expr):
                val = operand._substitute(old, new, _seen)
                if operand._name != val._name:
                    update = True
                new_exprs.append(val)
            elif (
                "Fused" in type(self).__name__
                and isinstance(operand, list)
                and all(isinstance(op, Expr) for op in operand)
            ):
                # Special handling for `Fused`.
                # We make no promise to dive through a
                # list operand in general, but NEED to
                # do so for the `Fused.exprs` operand.
                val = []
                for op in operand:
                    val.append(op._substitute(old, new, _seen))
                    if val[-1]._name != op._name:
                        update = True
                new_exprs.append(val)
            elif (
                substitute_literal
                and not isinstance(operand, bool)
                and isinstance(operand, type(old))
                and operand == old
            ):
                new_exprs.append(new)
                update = True
            else:
                new_exprs.append(operand)

        if update:  # Only recreate if something changed
            return type(self)(*new_exprs)
        else:
            _seen.add(self._name)
        return self

    def substitute_parameters(self, substitutions: dict) -> Expr:
        """Substitute specific `Expr` parameters

        Parameters
        ----------
        substitutions:
            Mapping of parameter keys to new values. Keys that
            are not found in ``self._parameters`` will be ignored.
        """
        if not substitutions:
            return self

        changed = False
        new_operands = []
        for i, operand in enumerate(self.operands):
            if i < len(self._parameters) and self._parameters[i] in substitutions:
                new_operands.append(substitutions[self._parameters[i]])
                changed = True
            else:
                new_operands.append(operand)
        if changed:
            return type(self)(*new_operands)
        return self

    def _node_label_args(self):
        """Operands to include in the node label by `visualize`"""
        return self.dependencies()

    def _to_graphviz(
        self,
        rankdir="BT",
        graph_attr=None,
        node_attr=None,
        edge_attr=None,
        **kwargs,
    ):
        from dask.dot import label, name

        graphviz = import_required(
            "graphviz",
            "Drawing dask graphs with the graphviz visualization engine requires the `graphviz` "
            "python library and the `graphviz` system library.\n\n"
            "Please either conda or pip install as follows:\n\n"
            "  conda install python-graphviz     # either conda install\n"
            "  python -m pip install graphviz    # or pip install and follow installation instructions",
        )

        graph_attr = graph_attr or {}
        node_attr = node_attr or {}
        edge_attr = edge_attr or {}

        graph_attr["rankdir"] = rankdir
        node_attr["shape"] = "box"
        node_attr["fontname"] = "helvetica"

        graph_attr.update(kwargs)
        g = graphviz.Digraph(
            graph_attr=graph_attr,
            node_attr=node_attr,
            edge_attr=edge_attr,
        )

        stack = [self]
        seen = set()
        dependencies = {}
        while stack:
            expr = stack.pop()

            if expr._name in seen:
                continue
            seen.add(expr._name)

            dependencies[expr] = set(expr.dependencies())
            for dep in expr.dependencies():
                stack.append(dep)

        cache = {}
        for expr in dependencies:
            expr_name = name(expr)
            attrs = {}

            # Make node label
            deps = [
                funcname(type(dep)) if isinstance(dep, Expr) else str(dep)
                for dep in expr._node_label_args()
            ]
            _label = funcname(type(expr))
            if deps:
                _label = f"{_label}({', '.join(deps)})"
            node_label = label(_label, cache=cache)

            attrs.setdefault("label", str(node_label))
            attrs.setdefault("fontsize", "20")
            g.node(expr_name, **attrs)

        for expr, deps in dependencies.items():
            expr_name = name(expr)
            for dep in deps:
                dep_name = name(dep)
                g.edge(dep_name, expr_name)

        return g

    def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
        """
        Visualize the expression graph.
        Requires ``graphviz`` to be installed.

        Parameters
        ----------
        filename : str or None, optional
            The name of the file to write to disk. If the provided `filename`
            doesn't include an extension, '.png' will be used by default.
            If `filename` is None, no file will be written, and the graph is
            rendered in the Jupyter notebook only.
        format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
            Format in which to write output file. Default is 'svg'.
        **kwargs
            Additional keyword arguments to forward to ``to_graphviz``.
        """
        from dask.dot import graphviz_to_file

        g = self._to_graphviz(**kwargs)
        graphviz_to_file(g, filename, format)
        return g

    def walk(self) -> Generator[Expr]:
        """Iterate through all expressions in the tree

        Returns
        -------
        nodes
            Generator of Expr instances in the graph.
            Ordering is a depth-first search of the expression tree
        """
        stack = [self]
        seen = set()
        while stack:
            node = stack.pop()
            if node._name in seen:
                continue
            seen.add(node._name)

            for dep in node.dependencies():
                stack.append(dep)

            yield node

    def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
        """Search the expression graph for a specific operation type

        Parameters
        ----------
        operation
            The operation type to search for.

        Returns
        -------
        nodes
            Generator of `operation` instances. Ordering corresponds
            to a depth-first search of the expression graph.
        """
        assert (
            isinstance(operation, tuple)
            and all(issubclass(e, Expr) for e in operation)
            or issubclass(operation, Expr)  # type: ignore
        ), "`operation` must be an `Expr` subclass"
        return (expr for expr in self.walk() if isinstance(expr, operation))
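
    # Both traversal helpers deduplicate by `_name`, so a shared
    # subexpression is yielded only once. A sketch (`Add` stands in for any
    # Expr subclass):
    #
    #     all_names = [node._name for node in expr.walk()]
    #     adds = list(expr.find_operations(Add))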

    def __getattr__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError as err:
            if key.startswith("_meta"):
                # Avoid a recursive loop if/when `self._meta*`
                # produces an `AttributeError`
                raise RuntimeError(
                    f"Failed to generate metadata for {self}. "
                    "This operation may not be supported by the current backend."
                )

            # Allow operands to be accessed as attributes
            # as long as the keys are not already reserved
            # by existing methods/properties
            _parameters = type(self)._parameters
            if key in _parameters:
                idx = _parameters.index(key)
                return self.operands[idx]

            raise AttributeError(
                f"{err}\n\n"
                "This often means that you are attempting to use an unsupported "
                "API function."
            )


class SingletonExpr(Expr):
    """A singleton Expr class

    This is used to treat the subclassed expression as a singleton. Singletons
    are deduplicated by expr._name which is typically based on the dask.tokenize
    output.

    This is a crucial performance optimization for expressions that walk through
    an optimizer and are recreated repeatedly but isn't safe for objects that
    cannot be reliably or quickly tokenized.
    """

    _instances: weakref.WeakValueDictionary[str, SingletonExpr]

    def __new__(cls, *args, _determ_token=None, **kwargs):
        if not hasattr(cls, "_instances"):
            cls._instances = weakref.WeakValueDictionary()
        inst = super().__new__(cls, *args, _determ_token=_determ_token, **kwargs)
        _name = inst._name
        if _name in cls._instances and cls.__init__ == object.__init__:
            return cls._instances[_name]

        cls._instances[_name] = inst
        return inst
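
# Because instances are cached by `_name`, structurally identical singleton
# expressions are the same object, not merely equal. A sketch (`Literal` is
# an illustrative subclass, not dask API):
#
#     class Literal(SingletonExpr):
#         _parameters = ["value"]
#
#     assert Literal(1) is Literal(1)   # same token -> same cached instance
#     assert Literal(1) is not Literal(2)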


def collect_dependents(expr) -> defaultdict:
    dependents = defaultdict(list)
    stack = [expr]
    seen = set()
    while stack:
        node = stack.pop()
        if node._name in seen:
            continue
        seen.add(node._name)

        for dep in node.dependencies():
            stack.append(dep)
            dependents[dep._name].append(weakref.ref(node))
    return dependents


def optimize(expr: Expr, fuse: bool = True) -> Expr:
    """High level query optimization

    This leverages two optimization passes:

    1.  Class based simplification using the ``_simplify`` function and methods
    2.  Blockwise fusion

    Parameters
    ----------
    expr:
        Input expression to optimize
    fuse:
        whether or not to turn on blockwise fusion

    See Also
    --------
    simplify
    optimize_blockwise_fusion
    """
    stage: OptimizerStage = "fused" if fuse else "simplified-physical"

    return optimize_until(expr, stage)


def optimize_until(expr: Expr, stage: OptimizerStage) -> Expr:
    result = expr
    if stage == "logical":
        return result

    # Simplify
    expr = result.simplify()
    if stage == "simplified-logical":
        return expr

    # Manipulate Expression to make it more efficient
    expr = expr.rewrite(kind="tune", rewritten={})
    if stage == "tuned-logical":
        return expr

    # Lower
    expr = expr.lower_completely()
    if stage == "physical":
        return expr

    # Simplify again
    expr = expr.simplify()
    if stage == "simplified-physical":
        return expr

    # Final graph-specific optimizations
    expr = expr.fuse()
    if stage == "fused":
        return expr

    raise ValueError(f"Stage {stage!r} not supported.")
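
# The stages form a linear pipeline: "logical" -> simplify ->
# "simplified-logical" -> tune -> "tuned-logical" -> lower -> "physical" ->
# simplify -> "simplified-physical" -> fuse -> "fused". A sketch of
# inspecting an intermediate stage, assuming `expr` is an Expr:
#
#     physical = optimize_until(expr, "physical")
#     physical.pprint()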


class LLGExpr(Expr):
    """Low Level Graph Expression"""

    _parameters = ["dsk"]

    def __dask_keys__(self):
        return list(self.operand("dsk"))

    def _layer(self) -> dict:
        return ensure_dict(self.operand("dsk"))


class HLGExpr(Expr):
    _parameters = [
        "dsk",
        "low_level_optimizer",
        "output_keys",
        "postcompute",
        "_cached_optimized",
    ]
    _defaults = {
        "low_level_optimizer": None,
        "output_keys": None,
        "postcompute": None,
        "_cached_optimized": None,
    }

    @property
    def hlg(self):
        return self.operand("dsk")

    @staticmethod
    def from_collection(collection, optimize_graph=True):
        from dask.highlevelgraph import HighLevelGraph

        if hasattr(collection, "dask"):
            dsk = collection.dask.copy()
        else:
            dsk = collection.__dask_graph__()

        # Delayed objects still ship with low level graphs as `dask` when going
        # through optimize / persist
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(
                str(id(collection)), dsk, dependencies=()
            )
        if optimize_graph and not hasattr(collection, "__dask_optimize__"):
            warnings.warn(
                f"Collection {type(collection)} does not define a "
                "`__dask_optimize__` method. In the future this will raise. "
                "If no optimization is desired, please set this to `None`.",
                PendingDeprecationWarning,
            )
            low_level_optimizer = None
        else:
            low_level_optimizer = (
                collection.__dask_optimize__ if optimize_graph else None
            )
        return HLGExpr(
            dsk=dsk,
            low_level_optimizer=low_level_optimizer,
            output_keys=collection.__dask_keys__(),
            postcompute=collection.__dask_postcompute__(),
        )

    def finalize_compute(self):
        return HLGFinalizeCompute(
            self,
            low_level_optimizer=self.low_level_optimizer,
            output_keys=self.output_keys,
            postcompute=self.postcompute,
        )

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annotations_by_type[annot_type].update(
                        {k: (value(k) if callable(value) else value) for k in layer}
                    )
        return dict(annotations_by_type)

    def __dask_keys__(self):
        if (keys := self.operand("output_keys")) is not None:
            return keys
        dsk = self.hlg
        # Note: This will materialize
        dependencies = dsk.get_all_dependencies()
        leafs = set(dependencies)
        for val in dependencies.values():
            leafs -= val
        self.output_keys = list(leafs)
        return self.output_keys

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        keys = self.__dask_keys__()
        dsk = self.hlg
        if (optimizer := self.low_level_optimizer) is not None:
            dsk = optimizer(dsk, keys)
        return HighLevelGraph.merge(dsk)

    @property
    def deterministic_token(self):
        if not self._determ_token:
            self._determ_token = uuid.uuid4().hex
        return self._determ_token

    def _layer(self) -> dict:
        dsk = self._optimized_dsk
        return ensure_dict(dsk)


class _HLGExprGroup(HLGExpr):
    # Identical to HLGExpr
    # Used internally to determine how output keys are supposed to be returned
    pass


class _HLGExprSequence(Expr):

    def __getitem__(self, other):
        return self.operands[other]

    def _operands_for_repr(self):
        return [
            f"name={self.operand('name')!r}",
            f"dsk={self.operand('dsk')!r}",
        ]

    def _tree_repr_lines(self, indent=0, recursive=True):
        return self._operands_for_repr()

    def finalize_compute(self):
        return _HLGExprSequence(*[op.finalize_compute() for op in self.operands])

    def _tune_down(self):
        if len(self.operands) == 1:
            return None
        from dask.highlevelgraph import HighLevelGraph

        groups = toolz.groupby(
            lambda x: x.low_level_optimizer if isinstance(x, HLGExpr) else None,
            self.operands,
        )
        exprs = []
        changed = False
        for optimizer, group in groups.items():
            if len(group) > 1:
                graphs = [expr.hlg for expr in group]

                changed = True
                dsk = HighLevelGraph.merge(*graphs)
                hlg_group = _HLGExprGroup(
                    dsk=dsk,
                    low_level_optimizer=optimizer,
                    output_keys=[v.__dask_keys__() for v in group],
                    postcompute=[g.postcompute for g in group],
                )
                exprs.append(hlg_group)
            else:
                exprs.append(group[0])
        if not changed:
            return None
        return _HLGExprSequence(*exprs)

    @functools.cached_property
    def _optimized_dsk(self) -> HighLevelGraph:
        from dask.highlevelgraph import HighLevelGraph

        hlgexpr: HLGExpr
        graphs = []
        # _tune_down ensures there is only one HLGExpr per optimizer/finalizer
        for hlgexpr in self.operands:
            keys = hlgexpr.__dask_keys__()
            dsk = hlgexpr.hlg
            if (optimizer := hlgexpr.low_level_optimizer) is not None:
                dsk = optimizer(dsk, keys)
            graphs.append(dsk)

        return HighLevelGraph.merge(*graphs)

    def __dask_graph__(self):
        # This class has to override this and not just _layer to ensure the HLGs
        # are not optimized individually
        return ensure_dict(self._optimized_dsk)

    _layer = __dask_graph__

    def __dask_annotations__(self) -> dict[str, dict[Key, object]]:
        # optimization has to be called (and cached) since blockwise fusion can
        # alter annotations
        # see `dask.blockwise.(_fuse_annotations|_can_fuse_annotations)`
        dsk = self._optimized_dsk
        annotations_by_type: defaultdict[str, dict[Key, object]] = defaultdict(dict)
        for layer in dsk.layers.values():
            if layer.annotations:
                annot = layer.annotations
                for annot_type, value in annot.items():
                    annots = list(
                        (k, (value(k) if callable(value) else value)) for k in layer
                    )
                    annotations_by_type[annot_type].update(
                        {
                            k: v
                            for k, v in annots
                            if not isinstance(v, _AnnotationsTombstone)
                        }
                    )
                    if not annotations_by_type[annot_type]:
                        del annotations_by_type[annot_type]
        return dict(annotations_by_type)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            if isinstance(op, _HLGExprGroup):
                all_keys.extend(op.__dask_keys__())
            else:
                all_keys.append(op.__dask_keys__())
        return all_keys


class _ExprSequence(Expr):
    """A sequence of expressions

    This is used to optimize multiple collections together, e.g. when they are
    computed simultaneously with ``dask.compute((Expr1, Expr2))``.
    """

    def __getitem__(self, other):
        return self.operands[other]

    def _layer(self) -> dict:
        return toolz.merge(op._layer() for op in self.operands)

    def __dask_keys__(self) -> list:
        all_keys = []
        for op in self.operands:
            all_keys.append(list(op.__dask_keys__()))
        return all_keys

    def __repr__(self):
        return "ExprSequence(" + ", ".join(map(repr, self.operands)) + ")"

    __str__ = __repr__

    def finalize_compute(self):
        return _ExprSequence(
            *(op.finalize_compute() for op in self.operands),
        )

    def __dask_annotations__(self):
        annotations_by_type = {}
        for op in self.operands:
            for k, v in op.__dask_annotations__().items():
                annotations_by_type.setdefault(k, {}).update(v)
        return annotations_by_type

    def __len__(self):
        return len(self.operands)

    def __iter__(self):
        return iter(self.operands)

    def _simplify_down(self):
        from dask.highlevelgraph import HighLevelGraph

        issue_warning = False
        hlgs = []
        if any(
            isinstance(op, (HLGExpr, HLGFinalizeCompute, dict)) for op in self.operands
        ):
            for op in self.operands:
                if isinstance(op, (HLGExpr, HLGFinalizeCompute)):
                    hlgs.append(op)
                elif isinstance(op, dict):
                    hlgs.append(
                        HLGExpr(
                            dsk=HighLevelGraph.from_collections(
                                str(id(op)), op, dependencies=()
                            )
                        )
                    )
                else:
                    issue_warning = True
                    opt = op.optimize()
                    hlgs.append(
                        HLGExpr(
                            dsk=HighLevelGraph.from_collections(
                                opt._name, opt.__dask_graph__(), dependencies=()
                            )
                        )
                    )
        if issue_warning:
            warnings.warn(
                "Computing mixed collections that are backed by "
                "HighlevelGraphs/dicts and Expressions. "
                "This forces Expressions to be materialized. "
                "It is recommended to use only one type and separate the "
                "dask.compute calls if necessary.",
                UserWarning,
            )
        if not hlgs:
            return None
        return _HLGExprSequence(*hlgs)


class _AnnotationsTombstone: ...


class FinalizeCompute(Expr):
    _parameters = ["expr"]

    def _simplify_down(self):
        return self.expr.finalize_compute()


def _convert_dask_keys(keys):
    from dask._task_spec import List, TaskRef

    assert isinstance(keys, list)
    new_keys = []
    for key in keys:
        if isinstance(key, list):
            new_keys.append(_convert_dask_keys(key))
        else:
            new_keys.append(TaskRef(key))
    return List(*new_keys)
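
# A sketch of what `_convert_dask_keys` produces for nested key lists
# (key names are illustrative):
#
#     _convert_dask_keys([("a", 0), [("b", 0), ("b", 1)]])
#     # -> List(TaskRef(("a", 0)), List(TaskRef(("b", 0)), TaskRef(("b", 1))))
#
# Wrapping keys in `TaskRef` marks them as runtime dependencies, so the
# finalize `Task` built in `HLGFinalizeCompute.hlg` receives computed values
# rather than raw key tuples.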


class HLGFinalizeCompute(HLGExpr):

    def _simplify_down(self):
        if not self.postcompute:
            return self.dsk

        from dask.delayed import Delayed

        # Skip finalization for Delayed
        if self.dsk.postcompute == Delayed.__dask_postcompute__(self.dsk):
            return self.dsk
        return self

    @property
    def _name(self):
        return f"finalize-{super()._name}"

    def __dask_graph__(self):
        # The baseclass __dask_graph__ would not just materialize this layer
        # but also those of its dependencies, i.e. it would render both the
        # finalized and the non-finalized graph and combine them. We only want
        # the finalized graph, so we override this.
        # This situation arises because the wrapped expression is identified
        # automatically as a dependency, but HLG expressions do not work in
        # this layered way.
        return self._layer()

    @property
    def hlg(self):
        expr = self.operand("dsk")
        layers = expr.dsk.layers.copy()
        deps = expr.dsk.dependencies.copy()
        keys = expr.__dask_keys__()
        if isinstance(expr.postcompute, list):
            postcomputes = expr.postcompute
        else:
            postcomputes = [expr.postcompute]
        tasks = [
            Task(self._name, func, _convert_dask_keys(keys), *extra_args)
            for func, extra_args in postcomputes
        ]
        from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

        leafs = set(deps)
        for val in deps.values():
            leafs -= val
        for t in tasks:
            layers[t.key] = MaterializedLayer({t.key: t})
            deps[t.key] = leafs
        return HighLevelGraph(layers, dependencies=deps)

    def __dask_keys__(self):
        return [self._name]


class ProhibitReuse(Expr):
    """
    An expression that guarantees that all keys are suffixed with a unique id.
    This can be used to break a common subexpression apart.
    """

    _parameters = ["expr"]
    _ALLOWED_TYPES = [HLGExpr, LLGExpr, HLGFinalizeCompute, _HLGExprSequence]

    def __dask_keys__(self):
        return self._modify_keys(self.expr.__dask_keys__())

    @staticmethod
    def _identity(obj):
        return obj

    @functools.cached_property
    def _suffix(self):
        return uuid.uuid4().hex

    def _modify_keys(self, k):
        if isinstance(k, list):
            return [self._modify_keys(kk) for kk in k]
        elif isinstance(k, tuple):
            return (self._modify_keys(k[0]),) + k[1:]
        elif isinstance(k, (int, float)):
            k = str(k)
        return f"{k}-{self._suffix}"

    def _simplify_down(self):
        # FIXME: Shuffling cannot be rewritten since the barrier key is
        # hardcoded. Skipping this here should do the trick most of the time
        if not isinstance(
            self.expr,
            tuple(self._ALLOWED_TYPES),
        ):
            return self.expr

    def __dask_graph__(self):
        try:
            from distributed.shuffle._core import P2PBarrierTask
        except ModuleNotFoundError:
            P2PBarrierTask = type(None)
        dsk = convert_legacy_graph(self.expr.__dask_graph__())

        subs = {old_key: self._modify_keys(old_key) for old_key in dsk}
        dsk2 = {}
        for old_key, new_key in subs.items():
            t = dsk[old_key]
            if isinstance(t, P2PBarrierTask):
                warnings.warn(
                    "Cannot block reuse for graphs that include a "
                    "P2PBarrierTask. This may cause unexpected results. "
                    "This typically happens when converting a dask "
                    "DataFrame to delayed objects.",
                    UserWarning,
                )
                return dsk
            dsk2[new_key] = Task(
                new_key,
                ProhibitReuse._identity,
                t.substitute(subs),
            )

        dsk2.update(dsk)
        return dsk2

    _layer = __dask_graph__