Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/SQLAlchemy-1.3.25.dev0-py3.11-linux-x86_64.egg/sqlalchemy/sql/util.py: 19%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

454 statements  

1# sql/util.py 

2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: http://www.opensource.org/licenses/mit-license.php 

7 

8"""High level utilities which build upon other modules here. 

9 

10""" 

11 

12from collections import deque 

13from itertools import chain 

14 

15from . import operators 

16from . import visitors 

17from .annotation import _deep_annotate # noqa 

18from .annotation import _deep_deannotate # noqa 

19from .annotation import _shallow_annotate # noqa 

20from .base import _from_objects 

21from .base import ColumnSet 

22from .ddl import sort_tables # noqa 

23from .elements import _expand_cloned 

24from .elements import _find_columns # noqa 

25from .elements import _label_reference 

26from .elements import _textual_label_reference 

27from .elements import BindParameter 

28from .elements import ColumnClause 

29from .elements import ColumnElement 

30from .elements import Grouping 

31from .elements import Label 

32from .elements import Null 

33from .elements import UnaryExpression 

34from .schema import Column 

35from .selectable import Alias 

36from .selectable import FromClause 

37from .selectable import FromGrouping 

38from .selectable import Join 

39from .selectable import ScalarSelect 

40from .selectable import SelectBase 

41from .selectable import TableClause 

42from .. import exc 

43from .. import util 

44 

45 

# Module-level alias for ``Join._join_condition``, built through
# ``public_factory``; the second argument is the dotted location string
# handed to that factory.
join_condition = util.langhelpers.public_factory(
    Join._join_condition, ".sql.util.join_condition"
)

49 

50 

def find_join_source(clauses, join_to):
    """Given a list of FROM clauses and a selectable,
    return the indexes of the clauses which can be joined against the
    selectable.

    A clause matches when it is derived from any of the FROM objects
    of ``join_to``.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        # clause1 contains table2, so its position is reported
        find_join_source([clause1, clause2], join_to) == [0]

    :param clauses: sequence of FROM clauses to search.
    :param join_to: selectable whose FROM objects are compared against
     each clause.
    :return: list of integer positions into ``clauses`` in ascending
     order, each position appearing at most once; an empty list if no
     clause matches.

    """

    selectables = list(_from_objects(join_to))
    return [
        i
        for i, f in enumerate(clauses)
        # any() stops at the first matching selectable, so a clause that
        # is derived from several of them is still reported only once
        if any(f.is_derived_from(s) for s in selectables)
    ]

75 

76 

def find_left_clause_that_matches_given(clauses, join_from):
    """Given a list of FROM clauses and a selectable,
    return the indexes of those clauses which are derived from the
    selectable.

    :param clauses: sequence of FROM clauses to search.
    :param join_from: selectable whose FROM objects are matched.
    :return: list of integer positions into ``clauses``.

    """

    targets = list(_from_objects(join_from))

    # liberal pass: a clause qualifies if it is derived from any target.
    # this covers joins containing a table, or an aliased table or
    # select statement matching to a table, including a selectable that
    # is adapted from that table (with Query, a join being made to an
    # adapted entity).
    loose_matches = [
        position
        for position, from_clause in enumerate(clauses)
        if any(from_clause.is_derived_from(target) for target in targets)
    ]

    if len(loose_matches) <= 1:
        return loose_matches

    # in an extremely small set of use cases the target is represented in
    # more than one FROM clause.  make a conservative pass that ignores
    # adaption relationships and only compares surface selectables.
    exact_matches = [
        position
        for position in loose_matches
        if any(
            set(surface_selectables(clauses[position])).intersection(
                surface_selectables(target)
            )
            for target in targets
        )
    ]

    # fall back to the liberal answer if the conservative pass found nothing
    return exact_matches or loose_matches

117 

118 

def find_left_clause_to_join_from(clauses, join_to, onclause):
    """Given a list of FROM clauses, a selectable,
    and optional ON clause, return a list of integer indexes from the
    clauses list indicating the clauses that can be joined from.

    The presence of an "onclause" indicates that at least one clause can
    definitely be joined from; if the list of clauses is of length one
    and the onclause is given, returns that index.   If the list of clauses
    is more than length one, and the onclause is given, attempts to locate
    which clauses contain the same columns.

    :param clauses: sequence of FROM clauses that are join candidates.
    :param join_to: selectable that is the target of the join.
    :param onclause: ON clause expression, or ``None``.
    :return: a ``list`` of integer positions into ``clauses``; always a
     real list, including the "everything matches" fallback.

    """
    idx = []
    selectables = set(_from_objects(join_to))

    # if we are given more than one target clause to join
    # from, use the onclause to provide a more specific answer.
    # otherwise, don't try to limit, after all, "ON TRUE" is a valid
    # on clause
    if len(clauses) > 1 and onclause is not None:
        resolve_ambiguity = True
        cols_in_onclause = _find_columns(onclause)
    else:
        resolve_ambiguity = False
        cols_in_onclause = None

    for i, f in enumerate(clauses):
        # a clause is never tested against itself
        for s in selectables.difference([f]):
            if resolve_ambiguity:
                # the pair (f, s) must supply every column of the onclause
                if set(f.c).union(s.c).issuperset(cols_in_onclause):
                    idx.append(i)
                    break
            elif Join._can_join(f, s) or onclause is not None:
                idx.append(i)
                break

    if len(idx) > 1:
        # this is the same "hide froms" logic from
        # Selectable._get_display_froms
        toremove = set(
            chain(*[_expand_cloned(f._hide_froms) for f in clauses])
        )
        idx = [i for i in idx if clauses[i] not in toremove]

    # onclause was given and none of them resolved, so assume
    # all indexes can match
    if not idx and onclause is not None:
        # materialize the range so the return type matches the other path
        return list(range(len(clauses)))
    else:
        return idx

169 

170 

def visit_binary_product(fn, expr):
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right)

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_(
            (a + b) == q + func.sum(e + f),
            j == r
        )

    would have the traversal::

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    :param fn: callable invoked as ``fn(binary, left, right)``.
    :param expr: expression to traverse.
    :return: ``None``; results are delivered only through ``fn``.

    """
    # innermost comparison currently being expanded sits at index 0
    stack = []

    def visit(element):
        # generator: yields the column-like leaves found under ``element``
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == "binary" and operators.is_comparison(
            element.operator
        ):
            stack.insert(0, element)
            # the right side is re-traversed once per left-side leaf
            for l in visit(element.left):
                for r in visit(element.right):
                    fn(stack[0], l, r)
            stack.pop(0)
            for elem in element.get_children():
                # NOTE(review): ``visit`` is a generator function, so this
                # call only creates a generator that is never iterated;
                # as written the loop has no effect.
                visit(elem)
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                for e in visit(elem):
                    yield e

    # exhaust the generator so that every ``fn`` call actually happens
    list(visit(expr))
    visit = None  # remove gc cycles

232 

233 

def find_tables(
    clause,
    check_columns=False,
    include_aliases=False,
    include_joins=False,
    include_selects=False,
    include_crud=False,
):
    """locate Table objects within the given expression.

    :param clause: expression to traverse.
    :param check_columns: also collect the ``.table`` of each column
     visited.
    :param include_aliases: also collect alias elements.
    :param include_joins: also collect join elements.
    :param include_selects: also collect select and compound select
     elements.
    :param include_crud: also collect the ``.table`` of insert, update
     and delete elements.
    :return: list of collected objects, in traversal order.

    """

    found = []
    collect = found.append

    # elements visited as "table" are always collected
    dispatch = {"table": collect}

    if include_selects:
        dispatch["select"] = collect
        dispatch["compound_select"] = collect

    if include_joins:
        dispatch["join"] = collect

    if include_aliases:
        dispatch["alias"] = collect

    if include_crud:

        def collect_crud_target(stmt):
            found.append(stmt.table)

        for visit_name in ("insert", "update", "delete"):
            dispatch[visit_name] = collect_crud_target

    if check_columns:

        def collect_column_table(column):
            found.append(column.table)

        dispatch["column"] = collect_column_table

    visitors.traverse(clause, {"column_collections": False}, dispatch)
    return found

272 

273 

def unwrap_order_by(clause):
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST

    :param clause: the ORDER BY expression to break up.
    :return: list of the unwrapped expressions, each appearing once, in
     the order they were first encountered.

    """

    # ``cols`` tracks what has been emitted already; ``result`` keeps order
    cols = util.column_set()
    result = []
    stack = deque([clause])
    while stack:
        t = stack.popleft()
        # a column element that is not an ordering modifier (the
        # ``is_ordering_modifier`` test on a UnaryExpression) is a candidate
        # result; anything else is expanded into its children below
        if isinstance(t, ColumnElement) and (
            not isinstance(t, UnaryExpression)
            or not operators.is_ordering_modifier(t.modifier)
        ):
            if isinstance(t, Label) and not isinstance(
                t.element, ScalarSelect
            ):
                # look through the label (and one level of grouping) and
                # re-queue what it wraps instead of recording the label
                t = t.element

                if isinstance(t, Grouping):
                    t = t.element

                stack.append(t)
                continue

            if isinstance(t, _label_reference):
                t = t.element
            if isinstance(t, (_textual_label_reference)):
                # textual label references are dropped from the result
                continue
            if t not in cols:
                cols.add(t)
                result.append(t)
        else:
            for c in t.get_children():
                stack.append(c)
    return result

309 

310 

def unwrap_label_reference(element):
    """Return ``element`` with label references replaced by what they wrap.

    Every ``_label_reference`` / ``_textual_label_reference`` met during
    the replacement traversal is swapped for its ``.element``.

    """
    reference_types = (_label_reference, _textual_label_reference)

    def _unwrap(candidate):
        if isinstance(candidate, reference_types):
            return candidate.element
        # None tells the traversal to leave this element alone
        return None

    return visitors.replacement_traverse(element, {}, _unwrap)

317 

318 

def expand_column_list_from_order_by(collist, order_by):
    """Given the columns clause and ORDER BY of a selectable,
    return a list of column expressions that can be added to the collist
    corresponding to the ORDER BY, without repeating those already
    in the collist.

    :param collist: column expressions of the columns clause.
    :param order_by: iterable of ORDER BY expressions.
    :return: list of unwrapped ORDER BY expressions not present in
     ``collist``.

    """
    present = set()
    for col in collist:
        # a column carrying an order-by label element is compared by the
        # element it wraps
        if col._order_by_label_element is not None:
            present.add(col.element)
        else:
            present.add(col)

    missing = []
    for criterion in order_by:
        for col in unwrap_order_by(criterion):
            if col not in present:
                missing.append(col)
    return missing

336 

337 

def clause_is_present(clause, search):
    """Given a target clause and a second to search within, return True
    if the target is plainly present in the search without any
    subqueries or aliases involved.

    Basically descends through Joins.

    """

    # compare with == rather than ``is`` so that an Annotated element's
    # own comparison is used
    return any(clause == elem for elem in surface_selectables(search))

352 

353 

def tables_from_leftmost(clause):
    """Yield the leaf FROM elements of ``clause`` from left to right.

    Joins are expanded into their left side followed by their right side,
    and FromGrouping wrappers are looked through; anything else is yielded
    as is.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        if isinstance(current, Join):
            # push right first so that the left side is popped first
            pending.append(current.right)
            pending.append(current.left)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
        else:
            yield current

365 

366 

def surface_selectables(clause):
    """Yield ``clause`` and every element reachable from it through joins
    and FROM groupings.

    Each element is yielded before it is expanded; a Join contributes its
    left and right sides, a FromGrouping contributes its inner element.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Join):
            pending.append(current.left)
            pending.append(current.right)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)

376 

377 

def surface_selectables_only(clause):
    """Yield the surface selectables reachable from ``clause``.

    Joins and FromGrouping wrappers are expanded and not yielded
    themselves.  A ColumnClause is replaced by its ``.table`` when it has
    one, and yielded itself otherwise.  Every other non-``None`` element,
    which includes TableClause and Alias objects, is yielded exactly
    once.

    :param clause: element to start from; may be ``None``, in which case
     nothing is yielded.

    """
    stack = [clause]
    while stack:
        elem = stack.pop()
        if isinstance(elem, Join):
            stack.extend((elem.left, elem.right))
        elif isinstance(elem, FromGrouping):
            stack.append(elem.element)
        elif isinstance(elem, ColumnClause):
            if elem.table is not None:
                stack.append(elem.table)
            else:
                yield elem
        elif elem is not None:
            # TableClause / Alias end up here as well; handling them in a
            # separate ``if`` ahead of this chain made them come out twice
            yield elem

395 

396 

def surface_column_elements(clause, include_scalar_selects=True):
    """traverse and yield only outer-exposed column elements, such as would
    be addressable in the WHERE clause of a SELECT if this element were
    in the columns clause.

    :param clause: element to start from; it is always yielded first.
    :param include_scalar_selects: when False, children that are
     ``SelectBase`` instances are skipped as well.

    """

    if include_scalar_selects:
        skipped = (FromGrouping,)
    else:
        skipped = (FromGrouping, SelectBase)

    queue = deque([clause])
    while queue:
        current = queue.popleft()
        yield current
        # breadth-first: children go to the back of the queue
        queue.extend(
            child
            for child in current.get_children()
            if not isinstance(child, skipped)
        )

414 

415 

def selectables_overlap(left, right):
    """Return True if left/right have some overlapping selectable"""

    left_surface = set(surface_selectables(left))
    shared = left_surface.intersection(surface_selectables(right))
    return bool(shared)

422 

423 

def bind_values(clause):
    """Return an ordered list of "bound" values in the given clause.

    E.g.::

        >>> expr = and_(
        ...    table.c.foo==5, table.c.foo==7
        ... )
        >>> bind_values(expr)
        [5, 7]
    """

    values = []

    # record the effective value of each bind parameter as it is visited
    visitors.traverse(
        clause,
        {},
        {"bindparam": lambda bind: values.append(bind.effective_value)},
    )
    return values

443 

444 

def _quote_ddl_expr(element):
    """Render ``element`` as a literal for use in DDL text.

    Strings are wrapped in single quotes with embedded single quotes
    doubled; any other value is rendered with ``repr()``.

    """
    if not isinstance(element, util.string_types):
        return repr(element)

    escaped = element.replace("'", "''")
    return "'%s'" % escaped

451 

452 

453class _repr_base(object): 

454 _LIST = 0 

455 _TUPLE = 1 

456 _DICT = 2 

457 

458 __slots__ = ("max_chars",) 

459 

460 def trunc(self, value): 

461 rep = repr(value) 

462 lenrep = len(rep) 

463 if lenrep > self.max_chars: 

464 segment_length = self.max_chars // 2 

465 rep = ( 

466 rep[0:segment_length] 

467 + ( 

468 " ... (%d characters truncated) ... " 

469 % (lenrep - self.max_chars) 

470 ) 

471 + rep[-segment_length:] 

472 ) 

473 return rep 

474 

475 

class _repr_row(_repr_base):
    """Provide a string view of a row."""

    __slots__ = ("row",)

    def __init__(self, row, max_chars=300):
        self.row = row
        self.max_chars = max_chars

    def __repr__(self):
        shorten = self.trunc
        rendered = ", ".join(shorten(value) for value in self.row)
        # mimic tuple syntax: a single element gets a trailing comma
        trailing_comma = "," if len(self.row) == 1 else ""
        return "(%s%s)" % (rendered, trailing_comma)

491 

492 

class _repr_params(_repr_base):
    """Provide a string view of bound parameters.

    Truncates display to a given number of 'multi' parameter sets,
    as well as long values to a given number of characters.

    """

    __slots__ = "params", "batches", "ismulti"

    def __init__(self, params, batches, max_chars=300, ismulti=None):
        self.params = params
        self.ismulti = ismulti
        self.batches = batches
        self.max_chars = max_chars

    def __repr__(self):
        params = self.params

        # without multi / single information, fall back to a plain
        # truncated repr
        if self.ismulti is None:
            return self.trunc(params)

        if isinstance(params, list):
            typ = self._LIST
        elif isinstance(params, tuple):
            typ = self._TUPLE
        elif isinstance(params, dict):
            typ = self._DICT
        else:
            return self.trunc(params)

        if not self.ismulti:
            return self._repr_params(params, typ)

        if len(params) <= self.batches:
            return self._repr_multi(params, typ)

        # too many parameter sets: show the leading sets and the last two,
        # dropping the closing bracket of the first piece and the opening
        # bracket of the second so that the message sits between them
        leading = self._repr_multi(params[: self.batches - 2], typ)[0:-1]
        trailing = self._repr_multi(params[-2:], typ)[1:]
        msg = " ... displaying %i of %i total bound parameter sets ... "
        return " ".join(
            (leading, msg % (self.batches, len(params)), trailing)
        )

    def _repr_multi(self, multi_params, typ):
        if multi_params:
            # the first parameter set decides how every set is rendered
            first = multi_params[0]
            if isinstance(first, list):
                elem_type = self._LIST
            elif isinstance(first, tuple):
                elem_type = self._TUPLE
            elif isinstance(first, dict):
                elem_type = self._DICT
            else:
                assert False, "Unknown parameter type %s" % (type(first))

            elements = ", ".join(
                self._repr_params(single, elem_type)
                for single in multi_params
            )
        else:
            elements = ""

        wrapper = "[%s]" if typ == self._LIST else "(%s)"
        return wrapper % elements

    def _repr_params(self, params, typ):
        shorten = self.trunc
        if typ is self._DICT:
            pairs = (
                "%r: %s" % (key, shorten(value))
                for key, value in params.items()
            )
            return "{%s}" % ", ".join(pairs)

        if typ is self._TUPLE:
            body = ", ".join(shorten(value) for value in params)
            return "(%s%s)" % (body, "," if len(params) == 1 else "")

        return "[%s]" % ", ".join(shorten(value) for value in params)

579 

580 

def adapt_criterion_to_null(crit, nulls):
    """given criterion containing bind params, convert selected elements
    to IS NULL.

    :param crit: criterion expression; it is traversed as a clone.
    :param nulls: collection of bind parameter identifying keys that
     should become NULL.
    :return: result of the cloned traversal.

    """

    def _is_nulled_bind(operand):
        return (
            isinstance(operand, BindParameter)
            and operand._identifying_key in nulls
        )

    def visit_binary(binary):
        if _is_nulled_bind(binary.left):
            # the NULL goes on the right, so the other operand moves left
            binary.left = binary.right
        elif not _is_nulled_bind(binary.right):
            return

        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.isnot

    return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})

606 

607 

def splice_joins(left, right, stop_on=None):
    """Adapt ``right`` through a ``ClauseAdapter`` built on ``left``.

    Walks down the chain of ``.left`` attributes starting at ``right``.
    Each Join met on the way (other than ``stop_on``) is cloned and its
    ON clause is passed through the adapter; the first element that is
    not such a Join is passed through the adapter itself and ends the
    walk.

    :param left: selectable the adapter is built on; if ``None``,
     ``right`` is returned unchanged.
    :param right: element to process.
    :param stop_on: optional Join that is adapted as a whole instead of
     being descended into.
    :return: the processed counterpart of the original ``right``.

    """
    if left is None:
        return right

    # entries are (element to process, already-cloned parent join or None)
    stack = [(right, None)]

    adapter = ClauseAdapter(left)
    ret = None
    while stack:
        (right, prevright) = stack.pop()
        if isinstance(right, Join) and right is not stop_on:
            # work on a copy so that the join passed in is not mutated
            right = right._clone()
            right._reset_exported()
            right.onclause = adapter.traverse(right.onclause)
            stack.append((right.left, right))
        else:
            right = adapter.traverse(right)
        # attach the processed element as the left side of its parent clone
        if prevright is not None:
            prevright.left = right
        # the first element processed is the outermost one; that is the result
        if ret is None:
            ret = right

    return ret

631 

632 

def reduce_columns(columns, *clauses, **kw):
    r"""given a list of columns, return a 'reduced' set based on natural
    equivalents.

    the set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms', which additionally requires two
    columns to have the same name to be treated as equivalent.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    """
    ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
    only_synonyms = kw.pop("only_synonyms", False)

    columns = util.ordered_column_set(columns)
    omit = util.column_set()

    # raised while resolving ``fk.column`` against a missing column / table
    # TODO: add specific coverage here
    # to test/sql/test_selectable ReduceTest
    unresolved_fk_errors = (
        exc.NoReferencedColumnError,
        exc.NoReferencedTableError,
    )

    for col in columns:
        foreign_keys = chain.from_iterable(
            [proxied.foreign_keys for proxied in col.proxy_set]
        )
        for fk in foreign_keys:
            for other in columns:
                if other is col:
                    continue
                try:
                    fk_col = fk.column
                except unresolved_fk_errors:
                    if not ignore_nonexistent_tables:
                        raise
                    continue
                if fk_col.shares_lineage(other) and (
                    not only_synonyms or other.name == col.name
                ):
                    # ``col`` refers to another column in the set; drop it
                    omit.add(col)
                    break

    def visit_binary(binary):
        if binary.operator != operators.eq:
            return

        remaining = util.column_set(
            chain.from_iterable(
                [c.proxy_set for c in columns.difference(omit)]
            )
        )
        if binary.left not in remaining or binary.right not in remaining:
            return

        # drop the last column in the set sharing lineage with the right side
        for c in reversed(columns):
            if c.shares_lineage(binary.right) and (
                not only_synonyms or c.name == binary.left.name
            ):
                omit.add(c)
                break

    for clause in clauses:
        if clause is not None:
            visitors.traverse(clause, {}, {"binary": visit_binary})

    return ColumnSet(columns.difference(omit))

706 

707 

def criterion_as_pairs(
    expression,
    consider_as_foreign_keys=None,
    consider_as_referenced_keys=None,
    any_operator=False,
):
    """traverse an expression and locate binary criterion pairs.

    :param expression: expression to traverse.
    :param consider_as_foreign_keys: optional collection of columns to
     treat as the foreign key side of a pair.
    :param consider_as_referenced_keys: optional collection of columns to
     treat as the referenced side of a pair.
    :param any_operator: when False, only binaries using the ``eq``
     operator are considered.
    :return: list of 2-tuples; in each tuple the referenced column comes
     first and the foreign key column second.
    :raises ArgumentError: if both ``consider_as_*`` collections are given.

    """

    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError(
            "Can only specify one of "
            "'consider_as_foreign_keys' or "
            "'consider_as_referenced_keys'"
        )

    pairs = []

    def _selects(candidate, other, keys):
        # ``candidate`` must be one of the keys; ``other`` must either
        # compare equal to it or not be one of the keys itself
        return candidate in keys and (
            other.compare(candidate) or other not in keys
        )

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return

        left, right = binary.left, binary.right
        if not (
            isinstance(left, ColumnElement)
            and isinstance(right, ColumnElement)
        ):
            return

        if consider_as_foreign_keys:
            if _selects(left, right, consider_as_foreign_keys):
                pairs.append((right, left))
            elif _selects(right, left, consider_as_foreign_keys):
                pairs.append((left, right))
        elif consider_as_referenced_keys:
            if _selects(left, right, consider_as_referenced_keys):
                pairs.append((left, right))
            elif _selects(right, left, consider_as_referenced_keys):
                pairs.append((right, left))
        elif isinstance(left, Column) and isinstance(right, Column):
            # no hints given: rely on what the columns say they reference
            if left.references(right):
                pairs.append((right, left))
            elif right.references(left):
                pairs.append((left, right))

    visitors.traverse(expression, {}, {"binary": visit_binary})
    return pairs

769 

770 

class ClauseAdapter(visitors.ReplacingCloningVisitor):
    """Clones and modifies clauses based on column correspondence.

    E.g.::

      table1 = Table('sometable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )
      table2 = Table('someothertable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )

      condition = table1.c.col1 == table2.c.col1

    make an alias of table1::

      s = table1.alias('foo')

    calling ``ClauseAdapter(s).traverse(condition)`` converts
    condition to read::

      s.c.col1 == table2.c.col1

    """

    def __init__(
        self,
        selectable,
        equivalents=None,
        include_fn=None,
        exclude_fn=None,
        adapt_on_names=False,
        anonymize_labels=False,
    ):
        self.selectable = selectable
        self.include_fn = include_fn
        self.exclude_fn = exclude_fn
        self.adapt_on_names = adapt_on_names
        self.equivalents = util.column_dict(equivalents or {})
        # per-instance traversal options; the target selectable itself is
        # listed under "stop_on"
        self.__traverse_options__ = {
            "stop_on": [selectable],
            "anonymize_labels": anonymize_labels,
        }

    def _corresponding_column(
        self, col, require_embedded, _seen=util.EMPTY_SET
    ):
        found = self.selectable.corresponding_column(
            col, require_embedded=require_embedded
        )
        if found is not None:
            return found

        # no direct match: try the columns declared equivalent to ``col``;
        # ``_seen`` guards against cycles in the equivalents mapping
        if col in self.equivalents and col not in _seen:
            visited = _seen.union([col])
            for equiv in self.equivalents[col]:
                found = self._corresponding_column(
                    equiv,
                    require_embedded=require_embedded,
                    _seen=visited,
                )
                if found is not None:
                    return found

        # last resort, if enabled: match purely on the column name
        if self.adapt_on_names:
            return self.selectable.c.get(col.name)
        return None

    def replace(self, col):
        if isinstance(col, FromClause) and self.selectable.is_derived_from(
            col
        ):
            return self.selectable

        if not isinstance(col, ColumnElement):
            return None
        if self.include_fn and not self.include_fn(col):
            return None
        if self.exclude_fn and self.exclude_fn(col):
            return None

        return self._corresponding_column(col, True)

849 

850 

class ColumnAdapter(ClauseAdapter):
    """Extends ClauseAdapter with extra utility functions.

    Key aspects of ColumnAdapter include:

    * Expressions that are adapted are stored in a persistent
      .columns collection; so that an expression E adapted into
      an expression E1, will return the same object E1 when adapted
      a second time.   This is important in particular for things like
      Label objects that are anonymized, so that the ColumnAdapter can
      be used to present a consistent "adapted" view of things.

    * Exclusion of items from the persistent collection based on
      include/exclude rules, but also independent of hash identity.
      This because "annotated" items all have the same hash identity as their
      parent.

    * "wrapping" capability is added, so that the replacement of an expression
      E can proceed through a series of adapters.  This differs from the
      visitor's "chaining" feature in that the resulting object is passed
      through all replacing functions unconditionally, rather than stopping
      at the first one that returns non-None.

    * An adapt_required option, used by eager loading to indicate that
      We don't trust a result row column that is not translated.
      This is to prevent a column from being interpreted as that
      of the child row in a self-referential scenario, see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    def __init__(
        self,
        selectable,
        equivalents=None,
        adapt_required=False,
        include_fn=None,
        exclude_fn=None,
        adapt_on_names=False,
        allow_label_resolve=True,
        anonymize_labels=False,
    ):
        ClauseAdapter.__init__(
            self,
            selectable,
            equivalents,
            include_fn=include_fn,
            exclude_fn=exclude_fn,
            adapt_on_names=adapt_on_names,
            anonymize_labels=anonymize_labels,
        )

        # persistent collection of adapted expressions, populated through
        # _locate_col
        self.columns = util.WeakPopulateDict(self._locate_col)
        if self.include_fn or self.exclude_fn:
            # put the include / exclude check in front of the collection
            self.columns = self._IncludeExcludeMapping(self, self.columns)
        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve
        # optional adapter that results are additionally passed through;
        # assigned by wrap()
        self._wrap = None

    class _IncludeExcludeMapping(object):
        """Lookup wrapper applying the parent's include / exclude rules
        before consulting the wrapped ``columns`` collection."""

        def __init__(self, parent, columns):
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            # a key rejected by include_fn / exclude_fn bypasses this
            # adapter's collection: it is handed to the wrapped adapter if
            # there is one, otherwise returned as is
            if (
                self.parent.include_fn and not self.parent.include_fn(key)
            ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
                if self.parent._wrap:
                    return self.parent._wrap.columns[key]
                else:
                    return key
            return self.columns[key]

    def wrap(self, adapter):
        """Return a copy of this adapter that also passes its results
        through ``adapter``; the copy gets a fresh ``columns`` collection."""
        # copy without running __init__, then share the instance attributes
        ac = self.__class__.__new__(self.__class__)
        ac.__dict__.update(self.__dict__)
        ac._wrap = adapter
        # the copied ``columns`` is bound to the original adapter; rebuild it
        ac.columns = util.WeakPopulateDict(ac._locate_col)
        if ac.include_fn or ac.exclude_fn:
            ac.columns = self._IncludeExcludeMapping(ac, ac.columns)

        return ac

    def traverse(self, obj):
        """Return the adapted form of ``obj`` via the ``columns``
        collection."""
        return self.columns[obj]

    # alternate name for traverse()
    adapt_clause = traverse
    # reuses ClauseAdapter's copy_and_process under a second name
    adapt_list = ClauseAdapter.copy_and_process

    def _locate_col(self, col):
        """Adapt ``col``; used to populate the ``columns`` collection.

        Returns ``None`` when ``adapt_required`` is set and the result is
        the very object that was passed in.
        """

        c = ClauseAdapter.traverse(self, col)

        if self._wrap:
            # give the wrapped adapter a chance at the result; a None
            # answer from it is ignored
            c2 = self._wrap._locate_col(c)
            if c2 is not None:
                c = c2

        if self.adapt_required and c is col:
            return None

        c._allow_label_resolve = self.allow_label_resolve

        return c

    def __getstate__(self):
        # the "columns" entry is left out of the state and is rebuilt in
        # __setstate__
        d = self.__dict__.copy()
        del d["columns"]
        return d

    def __setstate__(self, state):
        self.__dict__.update(state)
        # NOTE(review): unlike __init__ and wrap(), this does not re-apply
        # _IncludeExcludeMapping when include_fn / exclude_fn are set;
        # confirm whether that is intended.
        self.columns = util.WeakPopulateDict(self._locate_col)