Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/sqlalchemy/sql/lambdas.py: 24%

531 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# sql/lambdas.py 

2# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7 

8import inspect 

9import itertools 

10import operator 

11import sys 

12import threading 

13import types 

14import weakref 

15 

16from . import coercions 

17from . import elements 

18from . import roles 

19from . import schema 

20from . import traversals 

21from . import visitors 

22from .base import _clone 

23from .base import Options 

24from .operators import ColumnOperators 

25from .. import exc 

26from .. import inspection 

27from .. import util 

28from ..util import collections_abc 

29from ..util import compat 

30 

# Module-wide default cache of analyzed lambda records.  It is used by
# LambdaElement._retrieve_tracker_rec whenever LambdaOptions.lambda_cache is
# None.  Keys are a lambda's tracker key (its code object(s)) concatenated
# with its closure cache key.
# NOTE(review): util.LRUCache(1000) presumably caps this at ~1000 entries;
# confirm against sqlalchemy.util.LRUCache.
_closure_per_cache_key = util.LRUCache(1000)

32 

33 

class LambdaOptions(Options):
    # Tracking options for a lambda SQL construct.  The attribute names
    # mirror the keyword arguments of lambda_stmt() and
    # StatementLambdaElement.add_criteria(), which construct / extend this.

    # when False, AnalyzedCode installs no global/closure/track_on trackers
    enable_tracking = True
    # when False, non-literal closure variables are not part of the cache key
    track_closure_variables = True
    # optional sequence of objects used for the cache key; when set, it
    # replaces closure-variable tracking
    track_on = None
    # statement-wide switch; bound values are tracked only when both this
    # and track_bound_values are true
    global_track_bound_values = True
    # per-lambda switch for bound value tracking
    track_bound_values = True
    # mapping that stores analyzed records; None means the module-level
    # _closure_per_cache_key LRU cache
    lambda_cache = None

41 

42 

def lambda_stmt(
    lmb,
    enable_tracking=True,
    track_closure_variables=True,
    track_on=None,
    global_track_bound_values=True,
    track_bound_values=True,
    lambda_cache=None,
):
    """Build a SQL statement whose construction is cached as a lambda.

    The code object of ``lmb`` is inspected for Python literals, which
    become bound parameters, and for closure variables referring to Core
    or ORM constructs that may vary from call to call.  The lambda itself
    runs only once for each distinct set of detected constructs.

    E.g.::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: table.select())
        stmt += lambda s: s.where(table.c.id == 5)

        result = connection.execute(stmt)

    The return value is a :class:`_sql.StatementLambdaElement`.

    .. versionadded:: 1.4

    :param lmb: a zero-argument callable, usually a lambda, that returns a
     SQL expression construct.
    :param enable_tracking: pass False to turn off all scanning of the
     lambda for closure variable or bound parameter changes.  Intended for
     a lambda that produces the same result every time, with no
     parameterization.
    :param track_closure_variables: pass False to stop scanning closure
     variables for changes.  Intended for a lambda whose closure variables
     never alter the SQL structure it returns.
    :param track_on: an optional sequence of objects that, when given,
     become the source of the lambda's cache key in place of its closure
     variables.
    :param global_track_bound_values: pass False to disable bound
     parameter tracking for the whole statement, including links added
     later with :meth:`_sql.StatementLambdaElement.add_criteria`.
    :param track_bound_values: pass False to disable bound parameter
     tracking for this lambda only.  Intended for a lambda that produces no
     bound values, or whose initial bound values never change.
    :param lambda_cache: a dictionary or other mapping in which analysis of
     the lambda's code and its tracked closure variables is stored.  When
     omitted, a global LRU cache is used.  This cache is separate from the
     "compiled_cache" used by :class:`_engine.Connection`.

    .. seealso::

        :ref:`engine_lambda_caching`

    """
    statement_options = LambdaOptions(
        enable_tracking=enable_tracking,
        track_on=track_on,
        track_closure_variables=track_closure_variables,
        global_track_bound_values=global_track_bound_values,
        track_bound_values=track_bound_values,
        lambda_cache=lambda_cache,
    )
    return StatementLambdaElement(
        lmb, roles.StatementRole, statement_options
    )

114 

115 

class LambdaElement(elements.ClauseElement):
    """A SQL construct where the state is stored as an un-invoked lambda.

    The :class:`_sql.LambdaElement` is produced transparently whenever
    passing lambda expressions into SQL constructs, such as::

        stmt = select(table).where(lambda: table.c.col == parameter)

    The :class:`_sql.LambdaElement` is the base of the
    :class:`_sql.StatementLambdaElement` which represents a full statement
    within a lambda.

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    # callables applied to the resolved expression after the fact; within
    # this module only DeferredLambdaElement appends to it
    _transforms = ()

    # the preceding element of a lambda chain; None for a root element
    parent_lambda = None

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)

    def __init__(
        self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
    ):
        """Analyze (or fetch cached analysis of) ``fn`` for ``role``.

        :param fn: the user's lambda.
        :param role: the role the lambda's result is coerced to.
        :param opts: a :class:`.LambdaOptions` (class or instance).
        :param apply_propagate_attrs: object that receives the record's
         ``propagate_attrs``; defaults to ``self`` for statement roles.
        """
        self.fn = fn
        self.role = role
        self.tracker_key = (fn.__code__,)
        self.opts = opts

        if apply_propagate_attrs is None and (role is roles.StatementRole):
            apply_propagate_attrs = self

        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)

        if apply_propagate_attrs is not None:
            propagate_attrs = rec.propagate_attrs
            if propagate_attrs:
                apply_propagate_attrs._propagate_attrs = propagate_attrs

    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
        """Look up or build the analyzed record for this lambda.

        Computes ``self.closure_cache_key`` from the closure trackers of the
        lambda's :class:`.AnalyzedCode`, prefixed with the parent lambda's
        key when there is one.  That key plus ``self.tracker_key`` selects a
        cached :class:`.AnalyzedFunction` in ``opts.lambda_cache`` (or the
        module-level LRU cache), which is created under the generation mutex
        on a miss.  When no cache key can be produced, the lambda is simply
        invoked and wrapped in a :class:`.NonAnalyzedFunction`.

        Also fills ``self._resolved_bindparams`` with the bound parameters
        extracted for this lambda (and its parents), and sets ``self._rec``.

        :return: the record assigned to ``self._rec``.
        """
        lambda_cache = opts.lambda_cache
        if lambda_cache is None:
            lambda_cache = _closure_per_cache_key

        tracker_key = self.tracker_key

        fn = self.fn
        closure = fn.__closure__
        tracker = AnalyzedCode.get(
            fn,
            self,
            opts,
        )

        self._resolved_bindparams = bindparams = []

        if self.parent_lambda is not None:
            parent_closure_cache_key = self.parent_lambda.closure_cache_key
        else:
            parent_closure_cache_key = ()

        if parent_closure_cache_key is not traversals.NO_CACHE:
            anon_map = traversals.anon_map()
            cache_key = tuple(
                [
                    getter(closure, opts, anon_map, bindparams)
                    for getter in tracker.closure_trackers
                ]
            )

            if traversals.NO_CACHE not in anon_map:
                cache_key = parent_closure_cache_key + cache_key

                try:
                    rec = lambda_cache[tracker_key + cache_key]
                except KeyError:
                    rec = None
            else:
                cache_key = traversals.NO_CACHE
                rec = None

        else:
            cache_key = traversals.NO_CACHE
            rec = None

        self.closure_cache_key = cache_key

        if rec is None:
            if cache_key is not traversals.NO_CACHE:
                # re-check under the shared mutex so that only one thread
                # builds the AnalyzedFunction for a given key
                with AnalyzedCode._generation_mutex:
                    key = tracker_key + cache_key
                    if key not in lambda_cache:
                        rec = AnalyzedFunction(
                            tracker, self, apply_propagate_attrs, fn
                        )
                        rec.closure_bindparams = bindparams
                        lambda_cache[key] = rec
                    else:
                        rec = lambda_cache[key]
            else:
                rec = NonAnalyzedFunction(self._invoke_user_fn(fn))

        else:
            # cache hit: copy the current values onto the cached record's
            # bind objects, keeping the cached keys
            bindparams[:] = [
                orig_bind._with_value(new_bind.value, maintain_key=True)
                for orig_bind, new_bind in zip(
                    rec.closure_bindparams, bindparams
                )
            ]

        self._rec = rec

        if cache_key is not traversals.NO_CACHE:
            if self.parent_lambda is not None:
                bindparams[:0] = self.parent_lambda._resolved_bindparams

            # walk this element and its parents, letting each link's
            # trackers extract the current bound values into bindparams.
            # A separate name is used so the record returned below is
            # this element's own, not the root link's.
            lambda_element = self
            while lambda_element is not None:
                link_rec = lambda_element._rec
                if link_rec.bindparam_trackers:
                    tracker_instrumented_fn = link_rec.tracker_instrumented_fn
                    for bind_tracker in link_rec.bindparam_trackers:
                        bind_tracker(
                            lambda_element.fn,
                            tracker_instrumented_fn,
                            bindparams,
                        )
                lambda_element = lambda_element.parent_lambda

        return self._rec

    def __getattr__(self, key):
        """Delegate unknown attributes to the record's expected expression.

        ``_rec`` itself is never delegated.  If it has not been assigned
        yet (e.g. on an instance created via ``__new__`` during copy or
        unpickling), looking it up through ``self._rec`` would re-enter
        this method without bound.
        """
        if key == "_rec":
            raise AttributeError(key)
        return getattr(self._rec.expected_expr, key)

    @property
    def _is_sequence(self):
        return self._rec.is_sequence

    @property
    def _select_iterable(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._select_iterable for element in self._resolved]
            )

        else:
            return self._resolved._select_iterable

    @property
    def _from_objects(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._from_objects for element in self._resolved]
            )

        else:
            return self._resolved._from_objects

    def _param_dict(self):
        """Return ``{key: value}`` for the resolved bound parameters."""
        return {b.key: b.value for b in self._resolved_bindparams}

    def _setup_binds_for_tracked_expr(self, expr):
        """Swap binds in ``expr`` for the resolved binds with the same key.

        When the bind being replaced is expanding, its ``expanding``,
        ``expand_op`` and ``type`` are copied onto the resolved bind.
        Sequences (per ``self._rec.is_sequence``) are processed element by
        element; other non-clause values are returned unchanged.
        """
        bindparam_lookup = {b.key: b for b in self._resolved_bindparams}

        def replace(thing):
            if isinstance(thing, elements.BindParameter):

                if thing.key in bindparam_lookup:
                    bind = bindparam_lookup[thing.key]
                    if thing.expanding:
                        bind.expanding = True
                        bind.expand_op = thing.expand_op
                        bind.type = thing.type
                    return bind

        if self._rec.is_sequence:
            expr = [
                visitors.replacement_traverse(sub_expr, {}, replace)
                for sub_expr in expr
            ]
        elif getattr(expr, "is_clause_element", False):
            expr = visitors.replacement_traverse(expr, {}, replace)

        return expr

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # TODO: this needs A LOT of tests
        self._resolved = clone(
            self._resolved,
            deferred_copy_internals=deferred_copy_internals,
            **kw
        )

    @util.memoized_property
    def _resolved(self):
        """The expected expression with current bound values applied."""
        expr = self._rec.expected_expr

        if self._resolved_bindparams:
            expr = self._setup_binds_for_tracked_expr(expr)

        return expr

    def _gen_cache_key(self, anon_map, bindparams):
        """Build the cache key from the code objects and closure keys.

        The key for this element is prefixed with ``(code,) + closure key``
        of each parent in turn.  If no closure key could be produced,
        ``NO_CACHE`` is flagged in ``anon_map`` and None is returned.
        """
        if self.closure_cache_key is traversals.NO_CACHE:
            anon_map[traversals.NO_CACHE] = True
            return None

        cache_key = (
            self.fn.__code__,
            self.__class__,
        ) + self.closure_cache_key

        parent = self.parent_lambda
        while parent is not None:
            cache_key = (
                (parent.fn.__code__,) + parent.closure_cache_key + cache_key
            )

            parent = parent.parent_lambda

        if self._resolved_bindparams:
            bindparams.extend(self._resolved_bindparams)
        return cache_key

    def _invoke_user_fn(self, fn, *arg):
        """Invoke the user's lambda; the base form takes no arguments."""
        return fn()

361 

362 

class DeferredLambdaElement(LambdaElement):
    """A LambdaElement whose lambda takes arguments and is resolved during
    compilation, with context supplied at that point.

    Outside of the compile phase the lambda is only called with a fixed set
    of initial arguments, which is enough to produce a sample expression.

    """

    def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()):
        # must be in place before the base __init__ runs, because analysis
        # invokes the lambda through _invoke_user_fn()
        self.lambda_args = lambda_args
        super(DeferredLambdaElement, self).__init__(fn, role, opts)

    def _invoke_user_fn(self, fn, *arg):
        sample_args = self.lambda_args
        return fn(*sample_args)

    def _resolve_with_args(self, *lambda_args):
        """Run the instrumented lambda with real arguments and return the
        coerced expression, with tracked binds swapped in and any deferred
        transforms applied in order.
        """
        instrumented = self._rec.tracker_instrumented_fn
        produced = coercions.expect(self.role, instrumented(*lambda_args))
        produced = self._setup_binds_for_tracked_expr(produced)

        # Validating that this call produced the same set of required
        # bound parameters as the original run was attempted (#5767) but is
        # not enabled.  When the base lambda uses an unnamed column, which
        # is common with mixins, the parameter name differs and the check
        # gives a false positive for exactly the documented use case.

        # TODO: TEST TEST TEST, this is very out there
        for transform in self._transforms:
            produced = transform(produced)

        return produced

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # NOTE(review): kw is forwarded as a single "opts" keyword rather
        # than expanded; confirm this is intentional.
        super(DeferredLambdaElement, self)._copy_internals(
            clone=clone,
            deferred_copy_internals=deferred_copy_internals,
            opts=kw,
        )

        # TODO: A LOT A LOT of tests.  The real expression is not known
        # until _resolve_with_args(), so keep the replacement for later.
        if deferred_copy_internals:
            self._transforms = self._transforms + (deferred_copy_internals,)

428 

429 

class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
    """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.

    The :class:`_sql.StatementLambdaElement` is constructed using the
    :func:`_sql.lambda_stmt` function::


        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: select(table))

    Once constructed, additional criteria can be built onto the statement
    by adding subsequent lambdas, which accept the existing statement
    object as a single parameter::

        stmt += lambda s: s.where(table.c.col == parameter)


    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    def __add__(self, other):
        return self.add_criteria(other)

    def add_criteria(
        self,
        other,
        enable_tracking=True,
        track_on=None,
        track_closure_variables=True,
        track_bound_values=True,
    ):
        """Add new criteria to this :class:`_sql.StatementLambdaElement`.

        The lambda passed as ``other`` receives the statement produced so
        far as its single argument.

        E.g.::

            >>> def my_stmt(parameter):
            ...     stmt = lambda_stmt(
            ...         lambda: select(table.c.x, table.c.y),
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(table.c.x > parameter)
            ...     )
            ...     return stmt

        The :meth:`_sql.StatementLambdaElement.add_criteria` method is
        equivalent to using the Python addition operator to add a new
        lambda, except that additional arguments may be added including
        ``track_closure_variables`` and ``track_on``::

            >>> def my_stmt(self, foo):
            ...     stmt = lambda_stmt(
            ...         lambda: select(func.max(foo.x, foo.y)),
            ...         track_closure_variables=False
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(self.where_criteria),
            ...         track_on=[self]
            ...     )
            ...     return stmt

        See :func:`_sql.lambda_stmt` for a description of the parameters
        accepted.

        """

        # global_track_bound_values is inherited from the statement so that
        # turning it off applies to every link of the chain
        opts = self.opts + dict(
            enable_tracking=enable_tracking,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=self.opts.global_track_bound_values,
            track_on=track_on,
            track_bound_values=track_bound_values,
        )

        return LinkedLambdaElement(other, parent_lambda=self, opts=opts)

    def _execute_on_connection(
        self, connection, multiparams, params, execution_options
    ):
        if self._rec.expected_expr.supports_execution:
            return connection._execute_clauseelement(
                self, multiparams, params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)

    @property
    def _with_options(self):
        return self._rec.expected_expr._with_options

    @property
    def _effective_plugin_target(self):
        return self._rec.expected_expr._effective_plugin_target

    @property
    def _execution_options(self):
        return self._rec.expected_expr._execution_options

    def spoil(self):
        """Return a new :class:`.StatementLambdaElement` that will run
        all lambdas unconditionally each time.

        For a linked element (one produced by :meth:`.add_criteria` or
        ``+``), the chain is replayed from its root.  Each link's lambda is
        called with the statement produced by the link before it, since
        linked lambdas take that statement as their argument.

        """
        if self.parent_lambda is not None:
            return self.parent_lambda.spoil() + self.fn
        return NullLambdaStatement(self.fn())

539 

540 

class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
    """Provides the :class:`.StatementLambdaElement` API but does not
    cache or analyze lambdas.

    the lambdas are instead invoked immediately.

    The intended use is to isolate issues that may arise when using
    lambda statements.

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    def __init__(self, statement):
        self._resolved = statement
        self._propagate_attrs = statement._propagate_attrs

    def __getattr__(self, key):
        """Delegate unknown attributes to the wrapped statement.

        ``_resolved`` itself is never delegated.  If it is missing (e.g. on
        an instance created via ``__new__`` during copy or unpickling),
        looking it up through ``self._resolved`` would re-enter this method
        without bound instead of raising AttributeError.
        """
        if key == "_resolved":
            raise AttributeError(key)
        return getattr(self._resolved, key)

    def __add__(self, other):
        # invoke the next lambda right away with the current statement
        statement = other(self._resolved)

        return NullLambdaStatement(statement)

    def add_criteria(self, other, **kw):
        # tracking keywords are accepted for API compatibility and ignored
        statement = other(self._resolved)

        return NullLambdaStatement(statement)

    def _execute_on_connection(
        self, connection, multiparams, params, execution_options
    ):
        if self._resolved.supports_execution:
            return connection._execute_clauseelement(
                self, multiparams, params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)

586 

587 

class LinkedLambdaElement(StatementLambdaElement):
    """Represent subsequent links of a :class:`.StatementLambdaElement`."""

    role = None

    def __init__(self, fn, parent_lambda, opts):
        # parent_lambda, fn and opts must all be set before analysis runs,
        # because _retrieve_tracker_rec() reads them from self
        self.parent_lambda = parent_lambda
        self.fn = fn
        self.opts = opts

        own_code_key = (fn.__code__,)
        self.tracker_key = parent_lambda.tracker_key + own_code_key
        self._retrieve_tracker_rec(fn, self, opts)
        # a link shares its parent's propagate attributes
        self._propagate_attrs = parent_lambda._propagate_attrs

    def _invoke_user_fn(self, fn, *arg):
        # a link's lambda receives the statement built so far
        statement_so_far = self.parent_lambda._resolved
        return fn(statement_so_far)

604 

605 

class AnalyzedCode(object):
    """Analysis of a lambda's code object, shared by every lambda that has
    the same ``__code__``.

    Created once per code object by :meth:`.get`.  It records:

    * ``build_py_wrappers``: ``(name, closure_index)`` pairs for globals
      (``closure_index`` is None) and closure cells whose values look like
      SQL literals.  These are wrapped in :class:`.PyWrapper` objects when
      the lambda is instrumented.
    * ``bindparam_trackers``: callables that extract current bound
      parameter values from a lambda's globals or closure.
    * ``closure_trackers``: callables that contribute to a lambda's closure
      cache key.

    """

    __slots__ = (
        "track_closure_variables",
        "track_bound_values",
        "bindparam_trackers",
        "closure_trackers",
        "build_py_wrappers",
    )
    # analyses keyed weakly on the lambda's code object
    _fns = weakref.WeakKeyDictionary()

    # reentrant lock serializing creation of AnalyzedCode records (and, in
    # LambdaElement._retrieve_tracker_rec, of AnalyzedFunction records)
    _generation_mutex = threading.RLock()

    @classmethod
    def get(cls, fn, lambda_element, lambda_kw, **kw):
        """Return the shared :class:`.AnalyzedCode` for ``fn.__code__``,
        creating it on first use.

        The lookup is repeated under :attr:`._generation_mutex` so that
        concurrent first callers share a single analysis.  The result is
        cached per code object only; ``lambda_kw`` given on later calls is
        not consulted.
        """
        try:
            # TODO: validate kw haven't changed?
            return cls._fns[fn.__code__]
        except KeyError:
            pass

        with cls._generation_mutex:
            # check for other thread already created object
            if fn.__code__ in cls._fns:
                return cls._fns[fn.__code__]

            cls._fns[fn.__code__] = analyzed = AnalyzedCode(
                fn, lambda_element, lambda_kw, **kw
            )
            return analyzed

    def __init__(self, fn, lambda_element, opts):
        """Analyze ``fn`` according to the :class:`.LambdaOptions` ``opts``.

        :raises ArgumentError: if ``fn`` is a bound method.
        """
        if inspect.ismethod(fn):
            raise exc.ArgumentError(
                "Method %s may not be passed as a SQL expression" % fn
            )
        closure = fn.__closure__

        # bound values are tracked only if both the per-lambda and the
        # statement-wide flags allow it
        self.track_bound_values = (
            opts.track_bound_values and opts.global_track_bound_values
        )
        enable_tracking = opts.enable_tracking
        track_on = opts.track_on
        track_closure_variables = opts.track_closure_variables

        # an explicit track_on replaces closure-variable tracking
        self.track_closure_variables = track_closure_variables and not track_on

        # a list of callables generated from _bound_parameter_getter_*
        # functions. Each of these uses a PyWrapper object to retrieve
        # a parameter value
        self.bindparam_trackers = []

        # a list of callables generated from _cache_key_getter_* functions
        # these callables work to generate a cache key for the lambda
        # based on what's inside its closure variables.
        self.closure_trackers = []

        # (name, closure_index or None) entries to wrap in PyWrapper
        self.build_py_wrappers = []

        if enable_tracking:
            if track_on:
                self._init_track_on(track_on)

            self._init_globals(fn)

            if closure:
                self._init_closure(fn)

        self._setup_additional_closure_trackers(fn, lambda_element, opts)

    def _init_track_on(self, track_on):
        """Add one cache key getter per element of ``track_on``."""
        self.closure_trackers.extend(
            self._cache_key_getter_track_on(idx, elem)
            for idx, elem in enumerate(track_on)
        )

    def _init_globals(self, fn):
        """Register globals referenced by ``fn`` that look like literals."""
        build_py_wrappers = self.build_py_wrappers
        bindparam_trackers = self.bindparam_trackers
        track_bound_values = self.track_bound_values

        for name in fn.__code__.co_names:
            if name not in fn.__globals__:
                continue

            _bound_value = self._roll_down_to_literal(fn.__globals__[name])

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((name, None))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_globals(name)
                    )

    def _init_closure(self, fn):
        """Classify each closure cell of ``fn``.

        Literal-looking cells are wrapped for bound parameter tracking.
        Other cells contribute to the cache key when closure variables are
        tracked.
        """
        build_py_wrappers = self.build_py_wrappers
        closure = fn.__closure__

        track_bound_values = self.track_bound_values
        track_closure_variables = self.track_closure_variables
        bindparam_trackers = self.bindparam_trackers
        closure_trackers = self.closure_trackers

        for closure_index, (fv, cell) in enumerate(
            zip(fn.__code__.co_freevars, closure)
        ):
            _bound_value = self._roll_down_to_literal(cell.cell_contents)

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((fv, closure_index))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_closure(
                            fv, closure_index
                        )
                    )
            else:
                # for normal cell contents, add them to a list that
                # we can compare later when we get new lambdas. if
                # any identities have changed, then we will
                # recalculate the whole lambda and run it again.

                if track_closure_variables:
                    closure_trackers.append(
                        self._cache_key_getter_closure_variable(
                            fn, fv, closure_index, cell.cell_contents
                        )
                    )

    def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
        # an additional step is to actually run the function, then
        # go through the PyWrapper objects that were set up to catch a bound
        # parameter. then if they *didn't* make a param, oh they're another
        # object in the closure we have to track for our cache key. so
        # create trackers to catch those.

        analyzed_function = AnalyzedFunction(
            self,
            lambda_element,
            None,
            fn,
        )

        closure_trackers = self.closure_trackers

        for pywrapper in analyzed_function.closure_pywrappers:
            if not pywrapper._sa__has_param:
                closure_trackers.append(
                    self._cache_key_getter_tracked_literal(fn, pywrapper)
                )

    @classmethod
    def _roll_down_to_literal(cls, element):
        """Return the object whose literal-ness should be tested.

        Objects with ``__clause_element__`` are unwrapped until a
        ClauseElement, SchemaItem or class is reached.  Other objects are
        replaced by their inspection result (or its ``__clause_element__()``)
        when ``inspection.inspect`` recognizes them.
        """
        is_clause_element = hasattr(element, "__clause_element__")

        if is_clause_element:
            while not isinstance(
                element, (elements.ClauseElement, schema.SchemaItem, type)
            ):
                try:
                    element = element.__clause_element__()
                except AttributeError:
                    break

        if not is_clause_element:
            insp = inspection.inspect(element, raiseerr=False)
            if insp is not None:
                try:
                    return insp.__clause_element__()
                except AttributeError:
                    return insp

            # TODO: should we coerce consts None/True/False here?
            return element
        else:
            return element

    def _bound_parameter_getter_func_globals(self, name):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__globals__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__globals__[name]
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__globals__[name], result
            )

        return extract_parameter_value

    def _bound_parameter_getter_func_closure(self, name, closure_index):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__closure__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__closure__[
                closure_index
            ].cell_contents
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__closure__[closure_index].cell_contents, result
            )

        return extract_parameter_value

    def _cache_key_getter_track_on(self, idx, elem):
        """Return a getter that will extend a cache key with new entries
        from the "track_on" parameter passed to a :class:`.LambdaElement`.

        """

        if isinstance(elem, tuple):
            # tuple must contain hascachekey elements
            def get(closure, opts, anon_map, bindparams):
                return tuple(
                    tup_elem._gen_cache_key(anon_map, bindparams)
                    for tup_elem in opts.track_on[idx]
                )

        elif isinstance(elem, traversals.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)

        else:

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]

        return get

    def _cache_key_getter_closure_variable(
        self,
        fn,
        variable_name,
        idx,
        cell_contents,
        use_clause_element=False,
        use_inspect=False,
    ):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        :raises InvalidRequestError: immediately, if ``cell_contents`` is
         not cacheable and cannot be rolled down to something that is;
         or from the returned getter, if a sequence contains elements
         without ``_gen_cache_key``.
        """

        if isinstance(cell_contents, traversals.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):

                obj = closure[idx].cell_contents
                if use_inspect:
                    obj = inspection.inspect(obj)
                elif use_clause_element:
                    while hasattr(obj, "__clause_element__"):
                        if not getattr(obj, "is_clause_element", False):
                            obj = obj.__clause_element__()
                        else:
                            # already a clause element that also exposes
                            # __clause_element__; without this break the
                            # loop would never terminate
                            break

                return obj._gen_cache_key(anon_map, bindparams)

        elif isinstance(cell_contents, types.FunctionType):

            def get(closure, opts, anon_map, bindparams):
                return closure[idx].cell_contents.__code__

        elif isinstance(cell_contents, collections_abc.Sequence):

            def get(closure, opts, anon_map, bindparams):
                contents = closure[idx].cell_contents

                try:
                    return tuple(
                        elem._gen_cache_key(anon_map, bindparams)
                        for elem in contents
                    )
                except AttributeError as ae:
                    self._raise_for_uncacheable_closure_variable(
                        variable_name, fn, from_=ae
                    )

        else:
            # if the object is a mapped class or aliased class, or some
            # other object in the ORM realm of things like that, imitate
            # the logic used in coercions.expect() to roll it down to the
            # SQL element
            element = cell_contents
            is_clause_element = False
            while hasattr(element, "__clause_element__"):
                is_clause_element = True
                if not getattr(element, "is_clause_element", False):
                    element = element.__clause_element__()
                else:
                    break

            if not is_clause_element:
                insp = inspection.inspect(element, raiseerr=False)
                if insp is not None:
                    return self._cache_key_getter_closure_variable(
                        fn, variable_name, idx, insp, use_inspect=True
                    )
            else:
                return self._cache_key_getter_closure_variable(
                    fn, variable_name, idx, element, use_clause_element=True
                )

            self._raise_for_uncacheable_closure_variable(variable_name, fn)

        return get

    def _raise_for_uncacheable_closure_variable(
        self, variable_name, fn, from_=None
    ):
        """Raise InvalidRequestError for an uncacheable closure variable,
        chained from ``from_`` when given."""
        util.raise_(
            exc.InvalidRequestError(
                "Closure variable named '%s' inside of lambda callable %s "
                "does not refer to a cacheable SQL element, and also does not "
                "appear to be serving as a SQL literal bound value based on "
                "the default "
                "SQL expression returned by the function. This variable "
                "needs to remain outside the scope of a SQL-generating lambda "
                "so that a proper cache key may be generated from the "
                "lambda's state. Evaluate this variable outside of the "
                "lambda, set track_on=[<elements>] to explicitly select "
                "closure elements to track, or set "
                "track_closure_variables=False to exclude "
                "closure variables from being part of the cache key."
                % (variable_name, fn.__code__),
            ),
            from_=from_,
        )

    def _cache_key_getter_tracked_literal(self, fn, pytracker):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        this getter differs from _cache_key_getter_closure_variable
        in that these are detected after the function is run, and PyWrapper
        objects have recorded that a particular literal value is in fact
        not being interpreted as a bound parameter.

        """

        elem = pytracker._sa__to_evaluate
        closure_index = pytracker._sa__closure_index
        variable_name = pytracker._sa__name

        return self._cache_key_getter_closure_variable(
            fn, variable_name, closure_index, elem
        )

961 

962 

class NonAnalyzedFunction(object):
    """Record for a lambda whose cache key could not be generated.

    The lambda is invoked directly and its result is stored as-is.  No
    bound parameter tracking or caching takes place.  This class provides
    the parts of the :class:`.AnalyzedFunction` interface that
    :class:`.LambdaElement` reads from its record (``expected_expr``,
    ``is_sequence``, ``propagate_attrs``, ``closure_bindparams``,
    ``bindparam_trackers``).
    """

    __slots__ = ("expr",)

    closure_bindparams = None
    bindparam_trackers = None

    # read by LambdaElement._is_sequence and _setup_binds_for_tracked_expr;
    # the raw result is never treated as a sequence of expressions
    is_sequence = False

    def __init__(self, expr):
        self.expr = expr

    @property
    def expected_expr(self):
        """The lambda's result, uncoerced."""
        return self.expr

    @property
    def propagate_attrs(self):
        """Propagate attributes for LambdaElement.__init__.

        The expression is not coerced here, so whatever ``_propagate_attrs``
        it already carries is passed along, or None if it has none.
        """
        return getattr(self.expr, "_propagate_attrs", None)

975 

976 

class AnalyzedFunction(object):
    """Invoke a user lambda once, instrumented, and keep the resulting SQL.

    When the analyzed code object lists names to wrap
    (``analyzed_code.build_py_wrappers``), a copy of ``fn`` is built whose
    closure cells and globals for those names hold :class:`.PyWrapper`
    objects instead of the real values.  That copy is invoked through
    ``lambda_element._invoke_user_fn``.  The raw return value is stored in
    ``expr`` and the coerced form in ``expected_expr``.

    """

    __slots__ = (
        "analyzed_code",
        "fn",
        "closure_pywrappers",
        "tracker_instrumented_fn",
        "expr",
        "bindparam_trackers",
        "expected_expr",
        "is_sequence",
        "propagate_attrs",
        # not assigned anywhere in this class; presumably populated by the
        # caller -- TODO confirm
        "closure_bindparams",
    )

    def __init__(
        self,
        analyzed_code,
        lambda_element,
        apply_propagate_attrs,
        fn,
    ):
        """Run ``fn`` (instrumented if needed) and coerce its result.

        :param analyzed_code: analysis result for ``fn``'s code object;
          this class reads its ``bindparam_trackers``,
          ``build_py_wrappers``, ``track_closure_variables`` and
          ``track_bound_values`` attributes.
        :param lambda_element: the element on whose behalf ``fn`` is
          invoked; its ``_invoke_user_fn``, ``parent_lambda`` and ``role``
          are used.
        :param apply_propagate_attrs: forwarded to coercion; when not
          ``None`` its ``_propagate_attrs`` become ``self.propagate_attrs``.
        :param fn: the user-supplied Python function.

        """
        self.analyzed_code = analyzed_code
        self.fn = fn

        self.bindparam_trackers = analyzed_code.bindparam_trackers

        # sets closure_pywrappers, tracker_instrumented_fn and expr
        self._instrument_and_run_function(lambda_element)

        # sets expected_expr, is_sequence and propagate_attrs
        self._coerce_expression(lambda_element, apply_propagate_attrs)

    def _instrument_and_run_function(self, lambda_element):
        """Invoke ``self.fn`` once, wrapping selected names in PyWrappers.

        If ``analyzed_code.build_py_wrappers`` is empty, ``fn`` itself is
        invoked.  Otherwise each ``(name, closure_index)`` pair names either
        a closure variable (``closure_index`` is not ``None``) or a global
        (``closure_index`` is ``None``).  The matching value is replaced by
        a :class:`.PyWrapper` in a copied closure / globals mapping, a new
        function is built from those mappings, and that new function is
        invoked instead.

        """
        analyzed_code = self.analyzed_code

        fn = self.fn
        self.closure_pywrappers = closure_pywrappers = []

        build_py_wrappers = analyzed_code.build_py_wrappers

        if not build_py_wrappers:
            # nothing to instrument; run the user's function as-is
            self.tracker_instrumented_fn = tracker_instrumented_fn = fn
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
        else:
            track_closure_variables = analyzed_code.track_closure_variables
            closure = fn.__closure__

            # will form the __closure__ of the function when we rebuild it;
            # maps free-variable name -> current cell contents
            if closure:
                new_closure = {
                    fv: cell.cell_contents
                    for fv, cell in zip(fn.__code__.co_freevars, closure)
                }
            else:
                new_closure = {}

            # will form the __globals__ of the function when we rebuild it;
            # copied so the user's real module globals are never modified
            new_globals = fn.__globals__.copy()

            for name, closure_index in build_py_wrappers:
                if closure_index is not None:
                    value = closure[closure_index].cell_contents
                    new_closure[name] = bind = PyWrapper(
                        fn,
                        name,
                        value,
                        closure_index=closure_index,
                        track_bound_values=(
                            self.analyzed_code.track_bound_values
                        ),
                    )
                    # only closure wrappers are collected, and only when
                    # closure-variable tracking is enabled
                    if track_closure_variables:
                        closure_pywrappers.append(bind)
                else:
                    value = fn.__globals__[name]
                    # NOTE(review): unlike the closure branch, global
                    # wrappers get PyWrapper's default track_bound_values
                    # rather than analyzed_code.track_bound_values; confirm
                    # this is intentional.
                    new_globals[name] = bind = PyWrapper(fn, name, value)

            # rewrite the original fn. things that look like they will
            # become bound parameters are wrapped in a PyWrapper.  Cell
            # values are listed in co_freevars order, which is the order
            # the function's closure tuple must follow.
            self.tracker_instrumented_fn = (
                tracker_instrumented_fn
            ) = self._rewrite_code_obj(
                fn,
                [new_closure[name] for name in fn.__code__.co_freevars],
                new_globals,
            )

            # now invoke the function. This will give us a new SQL
            # expression, but all the places that there would be a bound
            # parameter, the PyWrapper in its place will give us a bind
            # with a predictable name we can match up later.

            # additionally, each PyWrapper will log that it did in fact
            # create a parameter, otherwise, it's some kind of Python
            # object in the closure and we want to track that, to make
            # sure it doesn't change to something else, or if it does,
            # that we create a different tracked function with that
            # variable.
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)

    def _coerce_expression(self, lambda_element, apply_propagate_attrs):
        """Run the tracker-generated expression through coercion rules.

        After the user-defined lambda has been invoked to produce a statement
        for re-use, run it through coercion rules to both check that it's the
        correct type of object and also to coerce it to its useful form.

        Only a top-level lambda (``parent_lambda is None``) is coerced,
        against ``lambda_element.role``.  A sequence result is coerced
        element by element and flagged with ``is_sequence = True``.  When
        there is a parent lambda, ``expr`` is kept as-is.

        """

        parent_lambda = lambda_element.parent_lambda
        expr = self.expr

        if parent_lambda is None:
            if isinstance(expr, collections_abc.Sequence):
                self.expected_expr = [
                    coercions.expect(
                        lambda_element.role,
                        sub_expr,
                        apply_propagate_attrs=apply_propagate_attrs,
                    )
                    for sub_expr in expr
                ]
                self.is_sequence = True
            else:
                self.expected_expr = coercions.expect(
                    lambda_element.role,
                    expr,
                    apply_propagate_attrs=apply_propagate_attrs,
                )
                self.is_sequence = False
        else:
            self.expected_expr = expr
            self.is_sequence = False

        if apply_propagate_attrs is not None:
            self.propagate_attrs = apply_propagate_attrs._propagate_attrs
        else:
            self.propagate_attrs = util.EMPTY_DICT

    def _rewrite_code_obj(self, f, cell_values, globals_):
        """Return a copy of ``f`` with a new closure and new globals.

        :param f: the function to copy; its ``__code__``, ``__name__`` and
          ``__defaults__`` are reused unchanged.
        :param cell_values: one value per entry of
          ``f.__code__.co_freevars``, in that same order; each value becomes
          the contents of a fresh closure cell.
        :param globals_: mapping used as the new function's ``__globals__``.
        :return: a new function object of the same type as ``f``.

        Each cell is made by closing a throwaway lambda over exactly one
        variable, so the cell order always equals the order of
        ``cell_values``.  Building all cells from one generated function
        that closes over ``i0 .. iN`` is deliberately avoided: CPython
        orders a code object's free variables by name, so that function's
        ``__closure__`` would come out as ``i0, i1, i10, i2, ...`` once there
        are more than ten cells, silently handing values to the wrong
        variables.

        Only plain closures are used, so this also works on PyPy.

        """

        def make_cell(value):
            # a lambda with exactly one free variable has exactly one cell,
            # so there is no ordering ambiguity
            return (lambda: value).__closure__[0]

        if cell_values:
            closure = tuple(make_cell(value) for value in cell_values)
        else:
            # a function with no free variables takes closure=None
            closure = None

        func = type(f)(
            f.__code__, globals_, f.__name__, f.__defaults__, closure
        )
        if sys.version_info >= (3,):
            func.__annotations__ = f.__annotations__
            func.__kwdefaults__ = f.__kwdefaults__
        func.__doc__ = f.__doc__
        func.__module__ = f.__module__

        return func

1146 

1147 

class PyWrapper(ColumnOperators):
    """A wrapper object that is injected into the ``__globals__`` and
    ``__closure__`` of a Python function.

    When the function is instrumented with :class:`.PyWrapper` objects, it is
    then invoked just once in order to set up the wrappers. We look through
    all the :class:`.PyWrapper` objects we made to find the ones that generated
    a :class:`.BindParameter` object, e.g. the expression system interpreted
    something as a literal. Those positions in the globals/closure are then
    ones that we will look at, each time a new lambda comes in that refers to
    the same ``__code__`` object. In this way, we keep a single version of
    the SQL expression that this lambda produced, without calling upon the
    Python function that created it more than once, unless its other closure
    variables have changed. The expression is then transformed to have the
    new bound values embedded into it.

    Attribute access on a wrapper is intercepted by ``__getattribute__``:

    * names prefixed with ``_sa_`` resolve to the wrapper's own attribute
      with the prefix stripped (e.g. ``_sa_fn`` -> ``fn``);
    * a small allowlist (``operate``, ``reverse_operate``,
      ``_py_wrapper_literal``, ``__class__``, ``__dict__``,
      ``__clause_element__``) resolves on the wrapper itself;
    * any other dunder name is forwarded to the wrapped value;
    * every remaining name goes through ``_add_getter``, which may return a
      nested :class:`.PyWrapper`.

    Methods therefore read their own state through
    ``object.__getattribute__`` or the ``_sa_`` prefix, never through plain
    ``self.attr``.

    """

    def __init__(
        self,
        fn,
        name,
        to_evaluate,
        closure_index=None,
        getter=None,
        track_bound_values=True,
    ):
        # attribute assignment is not intercepted; only reads are
        self.fn = fn
        self._name = name
        # the real value being wrapped
        self._to_evaluate = to_evaluate
        # BindParameter created lazily by _py_wrapper_literal
        self._param = None
        self._has_param = False
        # (key, getter_fn) -> nested PyWrapper created by _add_getter
        self._bind_paths = {}
        # for nested wrappers: extracts this value from the parent's value
        self._getter = getter
        self._closure_index = closure_index
        self.track_bound_values = track_bound_values

    def __call__(self, *arg, **kw):
        """Call the wrapped value, rejecting literal results when tracking.

        Raises :class:`.InvalidRequestError` when ``track_bound_values`` is
        set and the call returns a literal value that is not a
        ``HasCacheKey``; otherwise returns the call's result.

        """
        elem = object.__getattribute__(self, "_to_evaluate")
        value = elem(*arg, **kw)
        if (
            self._sa_track_bound_values
            and coercions._deep_is_literal(value)
            and not isinstance(
                # TODO: coverage where an ORM option or similar is here
                value,
                traversals.HasCacheKey,
            )
        ):
            name = object.__getattribute__(self, "_name")
            raise exc.InvalidRequestError(
                "Can't invoke Python callable %s() inside of lambda "
                "expression argument at %s; lambda SQL constructs should "
                "not invoke functions from closure variables to produce "
                "literal values since the "
                "lambda SQL system normally extracts bound values without "
                "actually "
                "invoking the lambda or any functions within it. Call the "
                "function outside of the "
                "lambda and assign to a local variable that is used in the "
                "lambda as a closure variable, or set "
                "track_bound_values=False if the return value of this "
                "function is used in some other way other than a SQL bound "
                "value." % (name, self._sa_fn.__code__)
            )
        else:
            return value

    def operate(self, op, *other, **kwargs):
        # SQL operators act on the wrapper as a BindParameter
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(elem, *other, **kwargs)

    def reverse_operate(self, op, other, **kwargs):
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(other, elem, **kwargs)

    def _extract_bound_parameters(self, starting_point, result_list):
        """Append re-valued bind parameters for ``starting_point``.

        If this wrapper produced a parameter, a copy carrying
        ``starting_point`` as its value is appended.  Each nested wrapper
        then applies its getter to ``starting_point`` and recurses.

        """
        param = object.__getattribute__(self, "_param")
        if param is not None:
            param = param._with_value(starting_point, maintain_key=True)
            result_list.append(param)
        for pywrapper in object.__getattribute__(self, "_bind_paths").values():
            getter = object.__getattribute__(pywrapper, "_getter")
            element = getter(starting_point)
            # "_sa_" prefix routes to the nested wrapper's own method
            pywrapper._sa__extract_bound_parameters(element, result_list)

    def _py_wrapper_literal(self, expr=None, operator=None, **kw):
        """Return a BindParameter carrying the wrapped value.

        The parameter is created on first use, named after this wrapper,
        and cached in ``_param``; each call returns a copy of it valued
        with the wrapped object.

        """
        param = object.__getattribute__(self, "_param")
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        if param is None:
            name = object.__getattribute__(self, "_name")
            self._param = param = elements.BindParameter(
                name,
                required=False,
                unique=True,
                _compared_to_operator=operator,
                _compared_to_type=expr.type if expr is not None else None,
            )
            self._has_param = True
        return param._with_value(to_evaluate, maintain_key=True)

    def __bool__(self):
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    # Python 2 spelling of __bool__
    def __nonzero__(self):
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    def __getattribute__(self, key):
        """Route attribute reads as described in the class docstring."""
        if key.startswith("_sa_"):
            # escape hatch: "_sa_fn" -> the wrapper's own "fn"
            return object.__getattribute__(self, key[4:])
        elif key in (
            "__clause_element__",
            "operate",
            "reverse_operate",
            "_py_wrapper_literal",
            "__class__",
            "__dict__",
        ):
            # looked up on the wrapper itself, never forwarded; no
            # __clause_element__ is defined in this class -- presumably so
            # the lookup fails here instead of reaching the wrapped value
            # (TODO confirm against the ColumnOperators base)
            return object.__getattribute__(self, key)

        if key.startswith("__"):
            elem = object.__getattribute__(self, "_to_evaluate")
            return getattr(elem, key)
        else:
            return self._sa__add_getter(key, operator.attrgetter)

    def __iter__(self):
        elem = object.__getattribute__(self, "_to_evaluate")
        return iter(elem)

    def __getitem__(self, key):
        elem = object.__getattribute__(self, "_to_evaluate")
        if not hasattr(elem, "__getitem__"):
            raise AttributeError("__getitem__")

        if isinstance(key, PyWrapper):
            # TODO: coverage
            raise exc.InvalidRequestError(
                "Dictionary keys / list indexes inside of a cached "
                "lambda must be Python literals only"
            )
        return self._sa__add_getter(key, operator.itemgetter)

    def _add_getter(self, key, getter_fn):
        """Resolve ``key`` on the wrapped value, wrapping literal results.

        ``getter_fn`` is ``operator.attrgetter`` or ``operator.itemgetter``.
        When the resolved value, after ``AnalyzedCode._roll_down_to_literal``,
        is a deep literal, a nested :class:`.PyWrapper` is created, cached
        under ``(key, getter_fn)`` and returned.  Otherwise the raw value
        is returned and nothing is cached.

        """

        bind_paths = object.__getattribute__(self, "_bind_paths")

        bind_path_key = (key, getter_fn)
        if bind_path_key in bind_paths:
            return bind_paths[bind_path_key]

        getter = getter_fn(key)
        elem = object.__getattribute__(self, "_to_evaluate")
        value = getter(elem)

        rolled_down_value = AnalyzedCode._roll_down_to_literal(value)

        if coercions._deep_is_literal(rolled_down_value):
            # the nested wrapper keeps the getter so
            # _extract_bound_parameters can re-derive its value later
            wrapper = PyWrapper(self._sa_fn, key, value, getter=getter)
            bind_paths[bind_path_key] = wrapper
            return wrapper
        else:
            return value

1313 

1314 

@inspection._inspects(LambdaElement)
def insp(lmb):
    """Inspection hook registered for :class:`.LambdaElement`.

    Delegates to :func:`.inspection.inspect` on the element's resolved
    statement (``lmb._resolved``).

    """
    resolved_statement = lmb._resolved
    return inspection.inspect(resolved_statement)