Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/sqlalchemy/sql/lambdas.py: 24%

533 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# sql/lambdas.py 

2# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7 

8import inspect 

9import itertools 

10import operator 

11import sys 

12import threading 

13import types 

14import weakref 

15 

16from . import coercions 

17from . import elements 

18from . import roles 

19from . import schema 

20from . import traversals 

21from . import type_api 

22from . import visitors 

23from .base import _clone 

24from .base import Options 

25from .operators import ColumnOperators 

26from .. import exc 

27from .. import inspection 

28from .. import util 

29from ..util import collections_abc 

30from ..util import compat 

31 

# Default store for analyzed lambda records, used whenever
# LambdaOptions.lambda_cache is None.  Keys are a lambda's tracker key
# (its __code__ objects) plus its closure cache key.  The argument is
# presumably the LRU capacity -- confirm against util.LRUCache.
_closure_per_cache_key = util.LRUCache(1000)

33 

34 

class LambdaOptions(Options):
    """Options controlling how a SQL lambda is analyzed and cached.

    The defaults here match the keyword defaults of :func:`.lambda_stmt`,
    which builds an instance of this class from its arguments.

    """

    # when False, AnalyzedCode skips the track_on / globals / closure scans
    enable_tracking = True
    # when False, closure variables do not contribute to the cache key
    track_closure_variables = True
    # optional sequence of elements used for the cache key; when set,
    # closure-variable tracking is turned off
    track_on = None
    # statement-wide switch: bound values are tracked only when this AND
    # track_bound_values are both true
    global_track_bound_values = True
    # per-lambda switch for bound parameter tracking
    track_bound_values = True
    # mapping that stores analyzed lambda records; None selects the
    # module-level _closure_per_cache_key LRU cache
    lambda_cache = None

42 

43 

def lambda_stmt(
    lmb,
    enable_tracking=True,
    track_closure_variables=True,
    track_on=None,
    global_track_bound_values=True,
    track_bound_values=True,
    lambda_cache=None,
):
    """Produce a SQL statement that is cached as a lambda.

    The Python code object within the lambda is scanned for both Python
    literals that will become bound parameters as well as closure variables
    that refer to Core or ORM constructs that may vary.  The lambda itself
    will be invoked only once per particular set of constructs detected.

    E.g.::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: table.select())
        stmt += lambda s: s.where(table.c.id == 5)

        result = connection.execute(stmt)

    The object returned is an instance of :class:`_sql.StatementLambdaElement`.

    .. versionadded:: 1.4

    :param lmb: a Python function, typically a lambda, which takes no arguments
     and returns a SQL expression construct
    :param enable_tracking: when False, all scanning of the given lambda for
     changes in closure variables or bound parameters is disabled.  Use for
     a lambda that produces the identical results in all cases with no
     parameterization.
    :param track_closure_variables: when False, changes in closure variables
     within the lambda will not be scanned.   Use for a lambda where the
     state of its closure variables will never change the SQL structure
     returned by the lambda.
    :param track_on: optional sequence of objects that are used to build the
     lambda's cache key in place of its closure variables.  Elements that
     support cache keys (or tuples of such elements) contribute their cache
     keys; any other object is used as-is.  When given, closure variables
     are no longer tracked.
    :param track_bound_values: when False, bound parameter tracking will
     be disabled for the given lambda.  Use for a lambda that either does
     not produce any bound values, or where the initial bound values never
     change.
    :param global_track_bound_values: when False, bound parameter tracking
     will be disabled for the entire statement including additional links
     added via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
    :param lambda_cache: a dictionary or other mapping-like object where
     information about the lambda's Python code as well as the tracked closure
     variables in the lambda itself will be stored.   Defaults
     to a global LRU cache.  This cache is independent of the "compiled_cache"
     used by the :class:`_engine.Connection` object.

    .. seealso::

        :ref:`engine_lambda_caching`


    """

    # gather every tracking switch into one options object up front
    options = LambdaOptions(
        lambda_cache=lambda_cache,
        track_bound_values=track_bound_values,
        global_track_bound_values=global_track_bound_values,
        track_closure_variables=track_closure_variables,
        track_on=track_on,
        enable_tracking=enable_tracking,
    )

    return StatementLambdaElement(lmb, roles.StatementRole, options)

115 

116 

class LambdaElement(elements.ClauseElement):
    """A SQL construct where the state is stored as an un-invoked lambda.

    The :class:`_sql.LambdaElement` is produced transparently whenever
    passing lambda expressions into SQL constructs, such as::

        stmt = select(table).where(lambda: table.c.col == parameter)

    The :class:`_sql.LambdaElement` is the base of the
    :class:`_sql.StatementLambdaElement` which represents a full statement
    within a lambda.

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    # deferred copy transforms; appended to by
    # DeferredLambdaElement._copy_internals
    _transforms = ()

    # the preceding element in a chain of linked lambdas, if any
    parent_lambda = None

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)

    def __init__(
        self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
    ):
        """Analyze ``fn`` (or fetch its cached analysis) for ``role``.

        :param fn: the user-supplied callable.
        :param role: role class the lambda's result is coerced to.
        :param opts: a :class:`.LambdaOptions` instance (the class itself
         supplies the defaults).
        :param apply_propagate_attrs: element that receives the propagate
         attributes of the lambda's result; defaults to ``self`` for
         statement-role lambdas.

        """
        self.fn = fn
        self.role = role
        self.tracker_key = (fn.__code__,)
        self.opts = opts

        if apply_propagate_attrs is None and (role is roles.StatementRole):
            apply_propagate_attrs = self

        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)

        if apply_propagate_attrs is not None:
            propagate_attrs = rec.propagate_attrs
            if propagate_attrs:
                apply_propagate_attrs._propagate_attrs = propagate_attrs

    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
        """Locate or build this lambda's analysis record.

        Computes ``self.closure_cache_key`` from the closure trackers of
        the lambda's :class:`.AnalyzedCode`, prefixed by the parent lambda's
        key for linked lambdas.  The key is ``NO_CACHE`` when the parent or
        any closure element is uncacheable.  Also populates
        ``self._resolved_bindparams`` with the bound parameters for this
        invocation.

        :return: the record assigned to ``self._rec``.

        """
        lambda_cache = opts.lambda_cache
        if lambda_cache is None:
            lambda_cache = _closure_per_cache_key

        tracker_key = self.tracker_key

        fn = self.fn
        closure = fn.__closure__
        analyzed_code = AnalyzedCode.get(
            fn,
            self,
            opts,
        )

        self._resolved_bindparams = bindparams = []

        if self.parent_lambda is not None:
            parent_closure_cache_key = self.parent_lambda.closure_cache_key
        else:
            parent_closure_cache_key = ()

        cache_key = traversals.NO_CACHE
        rec = None

        if parent_closure_cache_key is not traversals.NO_CACHE:
            anon_map = traversals.anon_map()
            # closure getters may also append BindParameters to bindparams
            closure_key = tuple(
                [
                    getter(closure, opts, anon_map, bindparams)
                    for getter in analyzed_code.closure_trackers
                ]
            )

            if traversals.NO_CACHE not in anon_map:
                cache_key = parent_closure_cache_key + closure_key

                try:
                    rec = lambda_cache[tracker_key + cache_key]
                except KeyError:
                    rec = None

        self.closure_cache_key = cache_key

        is_cache_hit = rec is not None

        if rec is None:
            if cache_key is not traversals.NO_CACHE:

                with AnalyzedCode._generation_mutex:
                    key = tracker_key + cache_key
                    if key not in lambda_cache:
                        rec = AnalyzedFunction(
                            analyzed_code, self, apply_propagate_attrs, fn
                        )
                        # Snapshot the closure parameters.  ``bindparams``
                        # is extended in place below with parent and
                        # tracked-literal parameters, but a later cache hit
                        # only gathers closure parameters, which are paired
                        # with this list by position.
                        rec.closure_bindparams = list(bindparams)
                        lambda_cache[key] = rec
                    else:
                        # another thread built the record while we waited
                        # for the mutex; handle it like any cache hit so
                        # this invocation's values are applied
                        rec = lambda_cache[key]
                        is_cache_hit = True
            else:
                rec = NonAnalyzedFunction(self._invoke_user_fn(fn))

        if is_cache_hit:
            # re-key this invocation's closure parameters to the names
            # used by the cached expression
            bindparams[:] = [
                orig_bind._with_value(new_bind.value, maintain_key=True)
                for orig_bind, new_bind in zip(
                    rec.closure_bindparams, bindparams
                )
            ]

        self._rec = rec

        if cache_key is not traversals.NO_CACHE:
            if self.parent_lambda is not None:
                bindparams[:0] = self.parent_lambda._resolved_bindparams

            # collect current literal values from this lambda and each
            # ancestor; distinct names avoid clobbering ``rec``
            lambda_element = self
            while lambda_element is not None:
                element_rec = lambda_element._rec
                if element_rec.bindparam_trackers:
                    tracker_instrumented_fn = (
                        element_rec.tracker_instrumented_fn
                    )
                    for bind_tracker in element_rec.bindparam_trackers:
                        bind_tracker(
                            lambda_element.fn,
                            tracker_instrumented_fn,
                            bindparams,
                        )
                lambda_element = lambda_element.parent_lambda

        return rec

    def __getattr__(self, key):
        if key == "_rec":
            # _rec is assigned by _retrieve_tracker_rec(); when it is absent
            # (e.g. on an instance created without __init__), looking it up
            # below would re-enter __getattr__ without end
            raise AttributeError(key)
        return getattr(self._rec.expected_expr, key)

    @property
    def _is_sequence(self):
        return self._rec.is_sequence

    @property
    def _select_iterable(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._select_iterable for element in self._resolved]
            )

        else:
            return self._resolved._select_iterable

    @property
    def _from_objects(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._from_objects for element in self._resolved]
            )

        else:
            return self._resolved._from_objects

    def _param_dict(self):
        """Return ``{key: value}`` for this invocation's bound parameters."""
        return {b.key: b.value for b in self._resolved_bindparams}

    def _setup_binds_for_tracked_expr(self, expr):
        """Swap bound parameters in ``expr`` for this invocation's values.

        Parameters are matched by ``key``; expanding-IN settings of the
        parameter found in ``expr`` are carried over to the replacement.

        """
        bindparam_lookup = {b.key: b for b in self._resolved_bindparams}

        def replace(thing):
            if isinstance(thing, elements.BindParameter):

                if thing.key in bindparam_lookup:
                    bind = bindparam_lookup[thing.key]
                    if thing.expanding:
                        bind.expanding = True
                        bind.expand_op = thing.expand_op
                        bind.type = thing.type
                    return bind

        if self._rec.is_sequence:
            expr = [
                visitors.replacement_traverse(sub_expr, {}, replace)
                for sub_expr in expr
            ]
        elif getattr(expr, "is_clause_element", False):
            expr = visitors.replacement_traverse(expr, {}, replace)

        return expr

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # TODO: this needs A LOT of tests
        self._resolved = clone(
            self._resolved,
            deferred_copy_internals=deferred_copy_internals,
            **kw
        )

    @util.memoized_property
    def _resolved(self):
        expr = self._rec.expected_expr

        if self._resolved_bindparams:
            expr = self._setup_binds_for_tracked_expr(expr)

        return expr

    def _gen_cache_key(self, anon_map, bindparams):
        """Return the cache key for this lambda chain.

        Returns None and flags ``anon_map`` with ``NO_CACHE`` when the
        lambda is uncacheable.  Otherwise the key is built from this
        lambda's code, class and closure key, prefixed by each ancestor's
        code and closure key, and ``bindparams`` is extended with this
        invocation's parameters.

        """
        if self.closure_cache_key is traversals.NO_CACHE:
            anon_map[traversals.NO_CACHE] = True
            return None

        cache_key = (
            self.fn.__code__,
            self.__class__,
        ) + self.closure_cache_key

        parent = self.parent_lambda
        while parent is not None:
            cache_key = (
                (parent.fn.__code__,) + parent.closure_cache_key + cache_key
            )

            parent = parent.parent_lambda

        if self._resolved_bindparams:
            bindparams.extend(self._resolved_bindparams)
        return cache_key

    def _invoke_user_fn(self, fn, *arg):
        """Call the user's lambda; subclasses supply arguments."""
        return fn()

362 

363 

class DeferredLambdaElement(LambdaElement):
    """A LambdaElement where the lambda accepts arguments and is
    invoked within the compile phase with special context.

    This lambda doesn't normally produce its real SQL expression outside of the
    compile phase.  It is passed a fixed set of initial arguments
    so that it can generate a sample expression.  The real expression is
    produced later by :meth:`._resolve_with_args`.

    """

    def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()):
        # set before the base constructor runs: analysis there may call
        # _invoke_user_fn(), which reads lambda_args
        self.lambda_args = lambda_args
        super(DeferredLambdaElement, self).__init__(fn, role, opts)

    def _invoke_user_fn(self, fn, *arg):
        # the sample invocation uses the fixed initial arguments
        return fn(*self.lambda_args)

    def _resolve_with_args(self, *lambda_args):
        """Produce the real expression by calling the instrumented lambda
        with ``lambda_args``, coercing it to ``self.role``, substituting
        tracked bound parameters and applying any deferred copy transforms.

        """
        # NOTE(review): NonAnalyzedFunction records (uncacheable lambdas)
        # have no tracker_instrumented_fn -- confirm that path can't occur
        tracker_fn = self._rec.tracker_instrumented_fn
        expr = tracker_fn(*lambda_args)

        expr = coercions.expect(self.role, expr)

        expr = self._setup_binds_for_tracked_expr(expr)

        # this validation is getting very close, but not quite, to achieving
        # #5767.  The problem is if the base lambda uses an unnamed column
        # as is very common with mixins, the parameter name is different
        # and it produces a false positive; that is, for the documented case
        # that is exactly what people will be doing, it doesn't work, so
        # I'm not really sure how to handle this right now.
        # expected_binds = [
        #    b._orig_key
        #    for b in self._rec.expr._generate_cache_key()[1]
        #    if b.required
        # ]
        # got_binds = [
        #    b._orig_key for b in expr._generate_cache_key()[1] if b.required
        # ]
        # if expected_binds != got_binds:
        #    raise exc.InvalidRequestError(
        #        "Lambda callable at %s produced a different set of bound "
        #        "parameters than its original run: %s"
        #        % (self.fn.__code__, ", ".join(got_binds))
        #    )

        # TODO: TEST TEST TEST, this is very out there
        # replay copy transforms recorded by _copy_internals, in order
        for deferred_copy_internals in self._transforms:
            expr = deferred_copy_internals(expr)

        return expr

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # NOTE(review): the remaining keyword arguments are passed as one
        # ``opts`` dict rather than unpacked (see the trailing "# **kw");
        # presumably deliberate -- confirm before changing
        super(DeferredLambdaElement, self)._copy_internals(
            clone=clone,
            deferred_copy_internals=deferred_copy_internals,  # **kw
            opts=kw,
        )

        # TODO: A LOT A LOT of tests.   for _resolve_with_args, we don't know
        # our expression yet.   so hold onto the replacement
        if deferred_copy_internals:
            self._transforms += (deferred_copy_internals,)

429 

430 

class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
    """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.

    The :class:`_sql.StatementLambdaElement` is constructed using the
    :func:`_sql.lambda_stmt` function::


        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: select(table))

    Once constructed, additional criteria can be built onto the statement
    by adding subsequent lambdas, which accept the existing statement
    object as a single parameter::

        stmt += lambda s: s.where(table.c.col == parameter)


    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    def __add__(self, other):
        return self.add_criteria(other)

    def add_criteria(
        self,
        other,
        enable_tracking=True,
        track_on=None,
        track_closure_variables=True,
        track_bound_values=True,
    ):
        """Add new criteria to this :class:`_sql.StatementLambdaElement`.

        The lambda receives the statement produced so far as its single
        argument and returns the new statement.  E.g.::

            >>> def my_stmt(parameter):
            ...     stmt = lambda_stmt(
            ...         lambda: select(table.c.x, table.c.y),
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(table.c.x > parameter)
            ...     )
            ...     return stmt

        The :meth:`_sql.StatementLambdaElement.add_criteria` method is
        equivalent to using the Python addition operator to add a new
        lambda, except that additional arguments may be added including
        ``track_closure_variables`` and ``track_on``::

            >>> def my_stmt(self, foo):
            ...     stmt = lambda_stmt(
            ...         lambda: select(func.max(foo.x, foo.y)),
            ...         track_closure_variables=False
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(self.where_criteria),
            ...         track_on=[self]
            ...     )
            ...     return stmt

        See :func:`_sql.lambda_stmt` for a description of the parameters
        accepted.

        """

        # global_track_bound_values is statement-wide and so is inherited
        # unchanged from this element's options
        opts = self.opts + dict(
            enable_tracking=enable_tracking,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=self.opts.global_track_bound_values,
            track_on=track_on,
            track_bound_values=track_bound_values,
        )

        return LinkedLambdaElement(other, parent_lambda=self, opts=opts)

    def _execute_on_connection(
        self, connection, multiparams, params, execution_options
    ):
        if self._rec.expected_expr.supports_execution:
            return connection._execute_clauseelement(
                self, multiparams, params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)

    @property
    def _with_options(self):
        return self._rec.expected_expr._with_options

    @property
    def _effective_plugin_target(self):
        return self._rec.expected_expr._effective_plugin_target

    @property
    def _execution_options(self):
        return self._rec.expected_expr._execution_options

    def spoil(self):
        """Return a new :class:`.NullLambdaStatement` that will run
        all lambdas unconditionally each time.

        For a statement built from several lambdas, the whole chain is
        re-run: the first lambda is called with no arguments and each
        following lambda is applied to the statement produced before it.

        """
        if self.parent_lambda is not None:
            # linked lambdas take the preceding statement as their only
            # argument; calling self.fn() here would raise TypeError
            return self.parent_lambda.spoil().add_criteria(self.fn)
        return NullLambdaStatement(self.fn())

540 

541 

class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
    """Provides the :class:`.StatementLambdaElement` API but does not
    cache or analyze lambdas.

    the lambdas are instead invoked immediately.

    The intended use is to isolate issues that may arise when using
    lambda statements.

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    def __init__(self, statement):
        self._resolved = statement
        self._propagate_attrs = statement._propagate_attrs

    def __getattr__(self, key):
        if key == "_resolved":
            # without this guard a missing _resolved (e.g. on an instance
            # created without __init__, as copy/unpickle do) would make the
            # lookup below re-enter __getattr__ without end
            raise AttributeError(key)
        return getattr(self._resolved, key)

    def __add__(self, other):
        return self.add_criteria(other)

    def add_criteria(self, other, **kw):
        """Apply ``other`` to the current statement right away.

        Tracking keyword arguments are accepted for API compatibility with
        :meth:`.StatementLambdaElement.add_criteria` and ignored.

        """
        statement = other(self._resolved)

        return NullLambdaStatement(statement)

    def _execute_on_connection(
        self, connection, multiparams, params, execution_options
    ):
        if self._resolved.supports_execution:
            return connection._execute_clauseelement(
                self, multiparams, params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)

587 

588 

class LinkedLambdaElement(StatementLambdaElement):
    """Represent subsequent links of a :class:`.StatementLambdaElement`.

    Each link wraps a lambda that is called with the statement produced by
    the preceding link and returns the extended statement.

    """

    role = None

    def __init__(self, fn, parent_lambda, opts):
        # everything _retrieve_tracker_rec() reads must be in place first
        own_code_key = (fn.__code__,)

        self.parent_lambda = parent_lambda
        self.fn = fn
        self.opts = opts
        self.tracker_key = parent_lambda.tracker_key + own_code_key

        self._retrieve_tracker_rec(fn, self, opts)

        # a link shares the propagate attributes of the chain it extends
        self._propagate_attrs = parent_lambda._propagate_attrs

    def _invoke_user_fn(self, fn, *arg):
        preceding_statement = self.parent_lambda._resolved
        return fn(preceding_statement)

605 

606 

class AnalyzedCode(object):
    """Analysis of a SQL lambda, cached per ``__code__`` object.

    Built once for each distinct code object via :meth:`.get` and reused
    for every later function sharing that code.  It records:

    * ``closure_trackers`` - callables computing the cache-key contribution
      of the lambda's closure variables or of explicit ``track_on``
      elements;
    * ``bindparam_trackers`` - callables extracting the current literal
      values of a new instance of the lambda as bound parameters;
    * ``build_py_wrappers`` - ``(name, closure_index)`` pairs naming the
      globals (``closure_index`` is None) and closure cells holding
      literals, which get wrapped in a :class:`.PyWrapper` when the
      function is instrumented.

    """

    __slots__ = (
        "track_closure_variables",
        "track_bound_values",
        "bindparam_trackers",
        "closure_trackers",
        "build_py_wrappers",
    )
    # analyses keyed on the lambda's __code__ object (weakly held)
    _fns = weakref.WeakKeyDictionary()

    # reentrant lock guarding creation of analyses and cache records
    _generation_mutex = threading.RLock()

    @classmethod
    def get(cls, fn, lambda_element, lambda_kw, **kw):
        """Return the analysis for ``fn.__code__``, building it only once.

        The first lookup is unlocked.  On a miss, membership is checked
        again under ``_generation_mutex`` before an analysis is built, so
        concurrent callers share a single instance.  Options passed on
        later calls are not compared with those of the first call.

        """
        try:
            # TODO: validate kw haven't changed?
            return cls._fns[fn.__code__]
        except KeyError:
            pass

        with cls._generation_mutex:
            # check for other thread already created object
            if fn.__code__ in cls._fns:
                return cls._fns[fn.__code__]

            cls._fns[fn.__code__] = analyzed = AnalyzedCode(
                fn, lambda_element, lambda_kw, **kw
            )
            return analyzed

    def __init__(self, fn, lambda_element, opts):
        """Scan ``fn`` and build its trackers according to ``opts``.

        :raises ArgumentError: if ``fn`` is a bound method.

        """
        if inspect.ismethod(fn):
            raise exc.ArgumentError(
                "Method %s may not be passed as a SQL expression" % fn
            )
        closure = fn.__closure__

        self.track_bound_values = (
            opts.track_bound_values and opts.global_track_bound_values
        )
        enable_tracking = opts.enable_tracking
        track_on = opts.track_on
        track_closure_variables = opts.track_closure_variables

        # explicit track_on replaces closure-variable tracking
        self.track_closure_variables = track_closure_variables and not track_on

        # a list of callables generated from _bound_parameter_getter_*
        # functions.  Each of these uses a PyWrapper object to retrieve
        # a parameter value
        self.bindparam_trackers = []

        # a list of callables generated from _cache_key_getter_* functions
        # these callables work to generate a cache key for the lambda
        # based on what's inside its closure variables.
        self.closure_trackers = []

        self.build_py_wrappers = []

        if enable_tracking:
            if track_on:
                self._init_track_on(track_on)

            self._init_globals(fn)

            if closure:
                self._init_closure(fn)

        self._setup_additional_closure_trackers(fn, lambda_element, opts)

    def _init_track_on(self, track_on):
        """Add one cache-key getter per ``track_on`` element."""
        self.closure_trackers.extend(
            self._cache_key_getter_track_on(idx, elem)
            for idx, elem in enumerate(track_on)
        )

    def _init_globals(self, fn):
        """Register literal-valued globals referenced by ``fn``'s code."""
        build_py_wrappers = self.build_py_wrappers
        bindparam_trackers = self.bindparam_trackers
        track_bound_values = self.track_bound_values

        for name in fn.__code__.co_names:
            # co_names also lists attribute names, which aren't globals
            if name not in fn.__globals__:
                continue

            _bound_value = self._roll_down_to_literal(fn.__globals__[name])

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((name, None))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_globals(name)
                    )

    def _init_closure(self, fn):
        """Classify each closure cell as a literal or a tracked object.

        Literal cells become PyWrapper / bound-parameter candidates; other
        cells get a cache-key getter when closure tracking is enabled.

        """
        build_py_wrappers = self.build_py_wrappers
        closure = fn.__closure__

        track_bound_values = self.track_bound_values
        track_closure_variables = self.track_closure_variables
        bindparam_trackers = self.bindparam_trackers
        closure_trackers = self.closure_trackers

        for closure_index, (fv, cell) in enumerate(
            zip(fn.__code__.co_freevars, closure)
        ):
            _bound_value = self._roll_down_to_literal(cell.cell_contents)

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((fv, closure_index))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_closure(
                            fv, closure_index
                        )
                    )
            else:
                # for normal cell contents, add them to a list that
                # we can compare later when we get new lambdas.  if
                # any identities have changed, then we will
                # recalculate the whole lambda and run it again.

                if track_closure_variables:
                    closure_trackers.append(
                        self._cache_key_getter_closure_variable(
                            fn, fv, closure_index, cell.cell_contents
                        )
                    )

    def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
        # an additional step is to actually run the function, then
        # go through the PyWrapper objects that were set up to catch a bound
        # parameter.   then if they *didn't* make a param, oh they're another
        # object in the closure we have to track for our cache key.  so
        # create trackers to catch those.

        analyzed_function = AnalyzedFunction(
            self,
            lambda_element,
            None,
            fn,
        )

        closure_trackers = self.closure_trackers

        for pywrapper in analyzed_function.closure_pywrappers:
            if not pywrapper._sa__has_param:
                closure_trackers.append(
                    self._cache_key_getter_tracked_literal(fn, pywrapper)
                )

    @classmethod
    def _roll_down_to_literal(cls, element):
        """Reduce ``element`` to the object whose literal-ness is tested.

        Objects with ``__clause_element__`` are unwrapped until a
        ClauseElement, SchemaItem or class is reached.  Other objects that
        ``inspection.inspect`` recognizes are replaced by the inspection's
        ``__clause_element__()``, or by the inspection object itself.

        """
        is_clause_element = hasattr(element, "__clause_element__")

        if is_clause_element:
            while not isinstance(
                element, (elements.ClauseElement, schema.SchemaItem, type)
            ):
                try:
                    element = element.__clause_element__()
                except AttributeError:
                    break

        if not is_clause_element:
            insp = inspection.inspect(element, raiseerr=False)
            if insp is not None:
                try:
                    return insp.__clause_element__()
                except AttributeError:
                    return insp

            # TODO: should we coerce consts None/True/False here?
            return element
        else:
            return element

    def _bound_parameter_getter_func_globals(self, name):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__globals__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__globals__[name]
            # object.__getattribute__ is used presumably because PyWrapper
            # intercepts normal attribute access -- see PyWrapper
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__globals__[name], result
            )

        return extract_parameter_value

    def _bound_parameter_getter_func_closure(self, name, closure_index):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__closure__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__closure__[
                closure_index
            ].cell_contents
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__closure__[closure_index].cell_contents, result
            )

        return extract_parameter_value

    def _cache_key_getter_track_on(self, idx, elem):
        """Return a getter that will extend a cache key with new entries
        from the "track_on" parameter passed to a :class:`.LambdaElement`.

        Tuples contribute the cache keys of their members, cache-key
        capable elements their own cache key, and any other object is
        used as-is.

        """

        if isinstance(elem, tuple):
            # tuple must contain hascachekey elements
            def get(closure, opts, anon_map, bindparams):
                return tuple(
                    tup_elem._gen_cache_key(anon_map, bindparams)
                    for tup_elem in opts.track_on[idx]
                )

        elif isinstance(elem, traversals.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)

        else:

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]

        return get

    def _cache_key_getter_closure_variable(
        self,
        fn,
        variable_name,
        idx,
        cell_contents,
        use_clause_element=False,
        use_inspect=False,
    ):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        The getter kind is chosen from ``cell_contents`` as first seen:
        cache-key capable objects contribute their cache key (after
        ``inspect()`` or ``__clause_element__()`` unwrapping when
        ``use_inspect`` / ``use_clause_element`` is set), plain functions
        their ``__code__``, and sequences a tuple of member cache keys.
        Anything else is unwrapped and dispatched again.

        :raises InvalidRequestError: if no cacheable form can be found.

        """

        if isinstance(cell_contents, traversals.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):

                obj = closure[idx].cell_contents
                if use_inspect:
                    obj = inspection.inspect(obj)
                elif use_clause_element:
                    while hasattr(obj, "__clause_element__"):
                        if not getattr(obj, "is_clause_element", False):
                            obj = obj.__clause_element__()
                        else:
                            # already a clause element; without this break
                            # an object that also has __clause_element__
                            # would loop forever
                            break

                return obj._gen_cache_key(anon_map, bindparams)

        elif isinstance(cell_contents, types.FunctionType):

            def get(closure, opts, anon_map, bindparams):
                return closure[idx].cell_contents.__code__

        elif isinstance(cell_contents, collections_abc.Sequence):

            def get(closure, opts, anon_map, bindparams):
                contents = closure[idx].cell_contents

                try:
                    return tuple(
                        elem._gen_cache_key(anon_map, bindparams)
                        for elem in contents
                    )
                except AttributeError as ae:
                    self._raise_for_uncacheable_closure_variable(
                        variable_name, fn, from_=ae
                    )

        else:
            # if the object is a mapped class or aliased class, or some
            # other object in the ORM realm of things like that, imitate
            # the logic used in coercions.expect() to roll it down to the
            # SQL element
            element = cell_contents
            is_clause_element = False
            while hasattr(element, "__clause_element__"):
                is_clause_element = True
                if not getattr(element, "is_clause_element", False):
                    element = element.__clause_element__()
                else:
                    break

            if not is_clause_element:
                insp = inspection.inspect(element, raiseerr=False)
                if insp is not None:
                    return self._cache_key_getter_closure_variable(
                        fn, variable_name, idx, insp, use_inspect=True
                    )
            else:
                return self._cache_key_getter_closure_variable(
                    fn, variable_name, idx, element, use_clause_element=True
                )

            self._raise_for_uncacheable_closure_variable(variable_name, fn)

        return get

    def _raise_for_uncacheable_closure_variable(
        self, variable_name, fn, from_=None
    ):
        """Raise InvalidRequestError for an uncacheable closure variable,
        chained from ``from_`` when given."""
        util.raise_(
            exc.InvalidRequestError(
                "Closure variable named '%s' inside of lambda callable %s "
                "does not refer to a cacheable SQL element, and also does not "
                "appear to be serving as a SQL literal bound value based on "
                "the default "
                "SQL expression returned by the function. This variable "
                "needs to remain outside the scope of a SQL-generating lambda "
                "so that a proper cache key may be generated from the "
                "lambda's state. Evaluate this variable outside of the "
                "lambda, set track_on=[<elements>] to explicitly select "
                "closure elements to track, or set "
                "track_closure_variables=False to exclude "
                "closure variables from being part of the cache key."
                % (variable_name, fn.__code__),
            ),
            from_=from_,
        )

    def _cache_key_getter_tracked_literal(self, fn, pytracker):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        this getter differs from _cache_key_getter_closure_variable
        in that these are detected after the function is run, and PyWrapper
        objects have recorded that a particular literal value is in fact
        not being interpreted as a bound parameter.

        """

        elem = pytracker._sa__to_evaluate
        closure_index = pytracker._sa__closure_index
        variable_name = pytracker._sa__name

        return self._cache_key_getter_closure_variable(
            fn, variable_name, closure_index, elem
        )

962 

963 

class NonAnalyzedFunction(object):
    """Record for a lambda whose result cannot be cached.

    Holds the expression returned by invoking the user's function, with no
    bound-parameter tracking.  It exposes the attributes that
    :class:`.LambdaElement` reads from its record (``expected_expr``,
    ``is_sequence``, ``propagate_attrs``, ``closure_bindparams``,
    ``bindparam_trackers``), so it can stand in for an
    :class:`.AnalyzedFunction`.

    """

    __slots__ = ("expr",)

    closure_bindparams = None
    bindparam_trackers = None

    # read by LambdaElement._is_sequence and _setup_binds_for_tracked_expr;
    # the uncoerced result is treated as a single expression
    is_sequence = False

    def __init__(self, expr):
        self.expr = expr

    @property
    def expected_expr(self):
        return self.expr

    @property
    def propagate_attrs(self):
        """Propagate attributes of the raw result, or None if it has none.

        Read by ``LambdaElement.__init__`` for statement-role lambdas.

        """
        return getattr(self.expr, "_propagate_attrs", None)

976 

977 

class AnalyzedFunction(object):
    """Invoke an analyzed lambda once and keep the resulting expression.

    On construction, the user function ``fn`` is invoked through
    ``lambda_element._invoke_user_fn()``. When ``analyzed_code`` lists
    names in ``build_py_wrappers``, ``fn`` is invoked as a rewritten copy
    whose corresponding globals / closure variables are replaced by
    :class:`.PyWrapper` objects. The produced expression is then run
    through coercion rules (see :meth:`_coerce_expression`).

    """

    __slots__ = (
        "analyzed_code",
        "fn",
        "closure_pywrappers",
        "tracker_instrumented_fn",
        "expr",
        "bindparam_trackers",
        "expected_expr",
        "is_sequence",
        "propagate_attrs",
        "closure_bindparams",
    )

    def __init__(
        self,
        analyzed_code,
        lambda_element,
        apply_propagate_attrs,
        fn,
    ):
        self.analyzed_code = analyzed_code
        self.fn = fn

        self.bindparam_trackers = analyzed_code.bindparam_trackers

        self._instrument_and_run_function(lambda_element)

        self._coerce_expression(lambda_element, apply_propagate_attrs)

    def _instrument_and_run_function(self, lambda_element):
        """Invoke the user function, first wrapping any trackable names.

        Each ``(name, closure_index)`` entry of
        ``analyzed_code.build_py_wrappers`` names either a closure variable
        (``closure_index`` is not None) or a global (``closure_index`` is
        None). Its value is replaced by a :class:`.PyWrapper` in a
        rewritten copy of ``fn``, and that copy is what gets invoked.

        Sets ``closure_pywrappers``, ``tracker_instrumented_fn`` and
        ``expr``.

        """
        analyzed_code = self.analyzed_code

        fn = self.fn
        self.closure_pywrappers = closure_pywrappers = []

        build_py_wrappers = analyzed_code.build_py_wrappers

        if not build_py_wrappers:
            # nothing to instrument; run the user's function as-is
            self.tracker_instrumented_fn = tracker_instrumented_fn = fn
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
        else:
            track_closure_variables = analyzed_code.track_closure_variables
            closure = fn.__closure__

            # will form the __closure__ of the function when we rebuild it
            if closure:
                new_closure = {
                    fv: cell.cell_contents
                    for fv, cell in zip(fn.__code__.co_freevars, closure)
                }
            else:
                new_closure = {}

            # will form the __globals__ of the function when we rebuild it
            new_globals = fn.__globals__.copy()

            for name, closure_index in build_py_wrappers:
                if closure_index is not None:
                    value = closure[closure_index].cell_contents
                    new_closure[name] = bind = PyWrapper(
                        fn,
                        name,
                        value,
                        closure_index=closure_index,
                        track_bound_values=analyzed_code.track_bound_values,
                    )
                    if track_closure_variables:
                        closure_pywrappers.append(bind)
                else:
                    value = fn.__globals__[name]
                    new_globals[name] = PyWrapper(fn, name, value)

            # rewrite the original fn. things that look like they will
            # become bound parameters are wrapped in a PyWrapper.
            self.tracker_instrumented_fn = (
                tracker_instrumented_fn
            ) = self._rewrite_code_obj(
                fn,
                [new_closure[name] for name in fn.__code__.co_freevars],
                new_globals,
            )

            # now invoke the function. This will give us a new SQL
            # expression, but all the places that there would be a bound
            # parameter, the PyWrapper in its place will give us a bind
            # with a predictable name we can match up later.

            # additionally, each PyWrapper will log that it did in fact
            # create a parameter, otherwise, it's some kind of Python
            # object in the closure and we want to track that, to make
            # sure it doesn't change to something else, or if it does,
            # that we create a different tracked function with that
            # variable.
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)

    def _coerce_expression(self, lambda_element, apply_propagate_attrs):
        """Run the tracker-generated expression through coercion rules.

        After the user-defined lambda has been invoked to produce a statement
        for re-use, run it through coercion rules to both check that it's the
        correct type of object and also to coerce it to its useful form.

        Sets ``expected_expr``, ``is_sequence`` and ``propagate_attrs``.
        When ``lambda_element.parent_lambda`` is not None, the expression
        is stored without coercion.

        """

        parent_lambda = lambda_element.parent_lambda
        expr = self.expr

        if parent_lambda is None:
            if isinstance(expr, collections_abc.Sequence):
                # coerce each member of a sequence individually
                self.expected_expr = [
                    coercions.expect(
                        lambda_element.role,
                        sub_expr,
                        apply_propagate_attrs=apply_propagate_attrs,
                    )
                    for sub_expr in expr
                ]
                self.is_sequence = True
            else:
                self.expected_expr = coercions.expect(
                    lambda_element.role,
                    expr,
                    apply_propagate_attrs=apply_propagate_attrs,
                )
                self.is_sequence = False
        else:
            self.expected_expr = expr
            self.is_sequence = False

        if apply_propagate_attrs is not None:
            self.propagate_attrs = apply_propagate_attrs._propagate_attrs
        else:
            self.propagate_attrs = util.EMPTY_DICT

    def _rewrite_code_obj(self, f, cell_values, globals_):
        """Return a copy of ``f`` with a new closure and new globals.

        :param f: the original Python function.
        :param cell_values: values for the new closure, one per name in
         ``f.__code__.co_freevars`` and in that same order.
        :param globals_: dictionary used as the new ``__globals__``.

        """

        def make_cell(value):
            # a function closing over exactly one variable carries exactly
            # one cell, holding that variable's value
            return (lambda: value).__closure__[0]

        # Build the cells one at a time, in the order given. Compiling a
        # single helper that closes over all values (names i0..iN) is not
        # safe: a function's closure follows its co_freevars, and CPython
        # orders those by sorted name ("i10" before "i2"). That
        # misaligns the cells with f's free variables once there are
        # more than ten of them.
        if cell_values:
            closure = tuple(make_cell(value) for value in cell_values)
        else:
            # a function without free variables has __closure__ None
            closure = None

        func = type(f)(
            f.__code__, globals_, f.__name__, f.__defaults__, closure
        )
        if sys.version_info >= (3,):
            func.__annotations__ = f.__annotations__
            func.__kwdefaults__ = f.__kwdefaults__
        func.__doc__ = f.__doc__
        func.__module__ = f.__module__

        return func

1147 

1148 

class PyWrapper(ColumnOperators):
    """A wrapper object that is injected into the ``__globals__`` and
    ``__closure__`` of a Python function.

    When the function is instrumented with :class:`.PyWrapper` objects, it is
    then invoked just once in order to set up the wrappers. We look through
    all the :class:`.PyWrapper` objects we made to find the ones that generated
    a :class:`.BindParameter` object, e.g. the expression system interpreted
    something as a literal. Those positions in the globals/closure are then
    ones that we will look at, each time a new lambda comes in that refers to
    the same ``__code__`` object. In this way, we keep a single version of
    the SQL expression that this lambda produced, without calling upon the
    Python function that created it more than once, unless its other closure
    variables have changed. The expression is then transformed to have the
    new bound values embedded into it.

    Because ``__getattribute__`` is overridden, methods of this class read
    their own state through ``object.__getattribute__`` or the ``_sa_``
    prefix, never through plain ``self.<name>`` access.

    """

    def __init__(
        self,
        fn,
        name,
        to_evaluate,
        closure_index=None,
        getter=None,
        track_bound_values=True,
    ):
        self.fn = fn
        self._name = name
        self._to_evaluate = to_evaluate
        self._param = None
        self._has_param = False
        # (key, getter_fn) -> child PyWrapper created by _add_getter()
        self._bind_paths = {}
        self._getter = getter
        self._closure_index = closure_index
        self.track_bound_values = track_bound_values

    def __call__(self, *arg, **kw):
        """Call the wrapped value.

        Raises InvalidRequestError when bound value tracking is on and the
        call returns a literal that is not a HasCacheKey.

        """
        elem = object.__getattribute__(self, "_to_evaluate")
        value = elem(*arg, **kw)
        if (
            self._sa_track_bound_values
            and coercions._deep_is_literal(value)
            and not isinstance(
                # TODO: coverage where an ORM option or similar is here
                value,
                traversals.HasCacheKey,
            )
        ):
            name = object.__getattribute__(self, "_name")
            raise exc.InvalidRequestError(
                "Can't invoke Python callable %s() inside of lambda "
                "expression argument at %s; lambda SQL constructs should "
                "not invoke functions from closure variables to produce "
                "literal values since the "
                "lambda SQL system normally extracts bound values without "
                "actually "
                "invoking the lambda or any functions within it. Call the "
                "function outside of the "
                "lambda and assign to a local variable that is used in the "
                "lambda as a closure variable, or set "
                "track_bound_values=False if the return value of this "
                "function is used in some other way other than a SQL bound "
                "value." % (name, self._sa_fn.__code__)
            )
        else:
            return value

    def operate(self, op, *other, **kwargs):
        """Apply ``op`` with this wrapper's bound parameter on the left."""
        elem = object.__getattribute__(self, "__clause_element__")()
        return op(elem, *other, **kwargs)

    def reverse_operate(self, op, other, **kwargs):
        """Apply ``op`` with this wrapper's bound parameter on the right."""
        elem = object.__getattribute__(self, "__clause_element__")()
        return op(other, elem, **kwargs)

    def _extract_bound_parameters(self, starting_point, result_list):
        """Append bound parameters carrying values taken from
        ``starting_point`` to ``result_list``.

        Covers this wrapper's own parameter (if one was created), then
        recurses into each child wrapper using that child's getter applied
        to ``starting_point``.

        """
        param = object.__getattribute__(self, "_param")
        if param is not None:
            param = param._with_value(starting_point, maintain_key=True)
            result_list.append(param)
        for pywrapper in object.__getattribute__(self, "_bind_paths").values():
            getter = object.__getattribute__(pywrapper, "_getter")
            element = getter(starting_point)
            pywrapper._sa__extract_bound_parameters(element, result_list)

    def __clause_element__(self):
        """Return a BindParameter carrying the wrapped value.

        The underlying parameter is created lazily on first call, named
        after this wrapper, and is reused afterwards.

        """
        param = object.__getattribute__(self, "_param")
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        if param is None:
            name = object.__getattribute__(self, "_name")
            self._param = param = elements.BindParameter(
                name, required=False, unique=True
            )
            self._has_param = True
            param.type = type_api._resolve_value_to_type(to_evaluate)
        return param._with_value(to_evaluate, maintain_key=True)

    def __bool__(self):
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    def __nonzero__(self):
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    def __getattribute__(self, key):
        """Route attribute access.

        ``_sa_<name>`` reads ``<name>`` from the wrapper itself. A fixed set
        of names is read directly. Other dunder names are delegated to the
        wrapped value, and any remaining name goes through
        :meth:`_add_getter` with ``operator.attrgetter``.

        """
        if key.startswith("_sa_"):
            return object.__getattribute__(self, key[4:])
        elif key in (
            "__clause_element__",
            "operate",
            "reverse_operate",
            "__class__",
            "__dict__",
        ):
            return object.__getattribute__(self, key)

        if key.startswith("__"):
            elem = object.__getattribute__(self, "_to_evaluate")
            return getattr(elem, key)
        else:
            return self._sa__add_getter(key, operator.attrgetter)

    def __iter__(self):
        elem = object.__getattribute__(self, "_to_evaluate")
        return iter(elem)

    def __getitem__(self, key):
        """Subscript the wrapped value via :meth:`_add_getter`.

        Raises AttributeError if the wrapped value has no ``__getitem__``,
        and InvalidRequestError if ``key`` is itself a PyWrapper.

        """
        elem = object.__getattribute__(self, "_to_evaluate")
        if not hasattr(elem, "__getitem__"):
            raise AttributeError("__getitem__")

        if isinstance(key, PyWrapper):
            # TODO: coverage
            raise exc.InvalidRequestError(
                "Dictionary keys / list indexes inside of a cached "
                "lambda must be Python literals only"
            )
        return self._sa__add_getter(key, operator.itemgetter)

    def _add_getter(self, key, getter_fn):
        """Resolve ``getter_fn(key)`` against the wrapped value.

        If the result rolls down to a literal, return a child PyWrapper for
        it, cached in ``_bind_paths`` under ``(key, getter_fn)``.
        Otherwise return the plain result.

        """

        bind_paths = object.__getattribute__(self, "_bind_paths")

        bind_path_key = (key, getter_fn)
        if bind_path_key in bind_paths:
            return bind_paths[bind_path_key]

        getter = getter_fn(key)
        elem = object.__getattribute__(self, "_to_evaluate")
        value = getter(elem)

        rolled_down_value = AnalyzedCode._roll_down_to_literal(value)

        if coercions._deep_is_literal(rolled_down_value):
            # Propagate track_bound_values. Otherwise a child created for
            # e.g. ``obj.method`` would default to True and its __call__
            # would raise even though this lambda disabled tracking.
            wrapper = PyWrapper(
                self._sa_fn,
                key,
                value,
                getter=getter,
                track_bound_values=self._sa_track_bound_values,
            )
            bind_paths[bind_path_key] = wrapper
            return wrapper
        else:
            return value

1310 

1311 

@inspection._inspects(LambdaElement)
def insp(lmb):
    """Inspection hook registered for :class:`.LambdaElement`.

    Delegates to ``inspection.inspect()`` on the element's ``_resolved``
    attribute.

    """
    resolved = lmb._resolved
    return inspection.inspect(resolved)