1# sql/lambdas.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import collections.abc as collections_abc
12import inspect
13import itertools
14import operator
15import threading
16import types
17from types import CodeType
18from typing import Any
19from typing import Callable
20from typing import cast
21from typing import List
22from typing import MutableMapping
23from typing import Optional
24from typing import Tuple
25from typing import Type
26from typing import TYPE_CHECKING
27from typing import TypeVar
28from typing import Union
29import weakref
30
31from . import cache_key as _cache_key
32from . import coercions
33from . import elements
34from . import roles
35from . import schema
36from . import visitors
37from .base import _clone
38from .base import Executable
39from .base import Options
40from .cache_key import CacheConst
41from .operators import ColumnOperators
42from .. import exc
43from .. import inspection
44from .. import util
45from ..util.typing import Literal
46
47
48if TYPE_CHECKING:
49 from .elements import BindParameter
50 from .elements import ClauseElement
51 from .roles import SQLRole
52 from .visitors import _CloneCallableType
53
# Mapping usable as a lambda cache.  Keys are tuples formed from a lambda's
# tracker key plus its closure cache key (see
# LambdaElement._retrieve_tracker_rec); values are the analyzed /
# non-analyzed records for that lambda.
_LambdaCacheType = MutableMapping[
    Tuple[Any, ...], Union["NonAnalyzedFunction", "AnalyzedFunction"]
]
# Callable that extends a list of bound parameters for a lambda; this is the
# annotated element type of the ``bindparam_trackers`` lists.
_BoundParameterGetter = Callable[..., Any]

# Default lambda cache, used whenever LambdaOptions.lambda_cache is None.
# NOTE(review): 1000 is presumably the LRU capacity - confirm against
# util.LRUCache.
_closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000)


# A zero-argument callable, as accepted by LambdaElement.
_LambdaType = Callable[[], Any]

# A callable taking any arguments, e.g. the lambda of a
# DeferredLambdaElement, which is invoked with its ``lambda_args``.
_AnyLambdaType = Callable[..., Any]

# The zero-argument lambda passed to lambda_stmt().
_StmtLambdaType = Callable[[], Any]

_E = TypeVar("_E", bound=Executable)
# A lambda that receives the statement built so far and returns a new
# statement; used by StatementLambdaElement.add_criteria() and "+".
_StmtLambdaElementType = Callable[[_E], Any]
70
71
class LambdaOptions(Options):
    # Options controlling how a lambda is analyzed and cached.  Constructed
    # by lambda_stmt() and extended by StatementLambdaElement.add_criteria();
    # see the lambda_stmt() docstring for a description of each option.

    # when False, the lambda's globals / closure variables are not scanned
    enable_tracking = True
    # when False, non-literal closure variables do not contribute to the
    # cache key
    track_closure_variables = True
    # explicit objects to track for the cache key; when set, automatic
    # closure-variable tracking is turned off
    track_on: Optional[object] = None
    # bound-value tracking switch covering a statement and all of its linked
    # lambdas; combined with track_bound_values using "and"
    global_track_bound_values = True
    # bound-value tracking switch for this lambda only
    track_bound_values = True
    # mapping in which analyzed lambdas are stored; None selects the
    # module-level default cache
    lambda_cache: Optional[_LambdaCacheType] = None
79
80
def lambda_stmt(
    lmb: _StmtLambdaType,
    enable_tracking: bool = True,
    track_closure_variables: bool = True,
    track_on: Optional[object] = None,
    global_track_bound_values: bool = True,
    track_bound_values: bool = True,
    lambda_cache: Optional[_LambdaCacheType] = None,
) -> StatementLambdaElement:
    """Produce a SQL statement that is cached as a lambda.

    The Python code object within the lambda is scanned for Python literals,
    which become bound parameters, and for closure variables that refer to
    Core or ORM constructs that may vary.  The lambda itself is invoked
    only once for each distinct set of constructs detected.

    E.g.::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: table.select())
        stmt += lambda s: s.where(table.c.id == 5)

        result = connection.execute(stmt)

    The object returned is an instance of :class:`_sql.StatementLambdaElement`.

    .. versionadded:: 1.4

    :param lmb: a Python function, typically a lambda, which takes no
     arguments and returns a SQL expression construct.
    :param enable_tracking: when False, all scanning of the given lambda for
     changes in closure variables or bound parameters is disabled.  Use for
     a lambda that produces identical results in all cases with no
     parameterization.
    :param track_closure_variables: when False, changes in closure variables
     within the lambda are not scanned.  Use for a lambda where the state of
     its closure variables never changes the SQL structure it returns.
    :param track_on: optional sequence of objects to track for the cache key
     in place of the lambda's closure variables.  Tuples of cacheable
     elements and cacheable elements contribute their own cache keys; any
     other object is placed into the cache key as-is (and so presumably
     must be hashable).  When given, automatic closure variable tracking is
     turned off.
    :param global_track_bound_values: when False, bound parameter tracking
     is disabled for the entire statement, including additional links added
     via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
    :param track_bound_values: when False, bound parameter tracking is
     disabled for the given lambda.  Use for a lambda that either produces
     no bound values, or whose initial bound values never change.
    :param lambda_cache: a dictionary or other mapping-like object in which
     information about the lambda's Python code, as well as the tracked
     closure variables in the lambda itself, is stored.  Defaults to a
     global LRU cache.  This cache is independent of the "compiled_cache"
     used by the :class:`_engine.Connection` object.

    .. seealso::

        :ref:`engine_lambda_caching`


    """
    options = LambdaOptions(
        enable_tracking=enable_tracking,
        track_on=track_on,
        track_closure_variables=track_closure_variables,
        global_track_bound_values=global_track_bound_values,
        track_bound_values=track_bound_values,
        lambda_cache=lambda_cache,
    )
    return StatementLambdaElement(lmb, roles.StatementRole, options)
152
153
class LambdaElement(elements.ClauseElement):
    """A SQL construct where the state is stored as an un-invoked lambda.

    The :class:`_sql.LambdaElement` is produced transparently whenever
    passing lambda expressions into SQL constructs, such as::

        stmt = select(table).where(lambda: table.c.col == parameter)

    The :class:`_sql.LambdaElement` is the base of the
    :class:`_sql.StatementLambdaElement` which represents a full statement
    within a lambda.

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    # deferred copy-internals callables, applied when a deferred lambda is
    # resolved with arguments
    _transforms: Tuple[_CloneCallableType, ...] = ()

    _resolved_bindparams: List[BindParameter[Any]]
    parent_lambda: Optional[StatementLambdaElement] = None
    closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]
    role: Type[SQLRole]
    _rec: Union[AnalyzedFunction, NonAnalyzedFunction]
    fn: _AnyLambdaType
    tracker_key: Tuple[CodeType, ...]

    def __repr__(self):
        return "%s(%r)" % (
            self.__class__.__name__,
            self.fn.__code__,
        )

    def __init__(
        self,
        fn: _LambdaType,
        role: Type[SQLRole],
        opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
        apply_propagate_attrs: Optional[ClauseElement] = None,
    ):
        # fn: the lambda; role: the SQL role its result is coerced to;
        # opts: tracking options; apply_propagate_attrs: element that
        # receives the record's propagate attrs (this element itself for
        # statement lambdas when not given).
        self.fn = fn
        self.role = role
        self.tracker_key = (fn.__code__,)
        self.opts = opts

        if apply_propagate_attrs is None and (role is roles.StatementRole):
            apply_propagate_attrs = self

        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)

        if apply_propagate_attrs is not None:
            propagate_attrs = rec.propagate_attrs
            if propagate_attrs:
                apply_propagate_attrs._propagate_attrs = propagate_attrs

    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
        """Locate or create the analyzed-function record for this lambda.

        The closure cache key is built from the :class:`.AnalyzedCode`
        closure trackers, prefixed with the parent lambda's key for linked
        lambdas.  That key is used to find an existing
        :class:`.AnalyzedFunction` in the lambda cache, or to create one
        under the generation mutex.  If the key cannot be built (any part is
        uncacheable), the lambda is invoked directly and wrapped in a
        :class:`.NonAnalyzedFunction`.

        Populates ``self.closure_cache_key``, ``self._rec`` and
        ``self._resolved_bindparams``; returns ``self._rec``.

        """
        lambda_cache = opts.lambda_cache
        if lambda_cache is None:
            lambda_cache = _closure_per_cache_key

        tracker_key = self.tracker_key

        fn = self.fn
        closure = fn.__closure__
        tracker = AnalyzedCode.get(
            fn,
            self,
            opts,
        )

        bindparams: List[BindParameter[Any]]
        self._resolved_bindparams = bindparams = []

        if self.parent_lambda is not None:
            parent_closure_cache_key = self.parent_lambda.closure_cache_key
        else:
            parent_closure_cache_key = ()

        cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]

        if parent_closure_cache_key is not _cache_key.NO_CACHE:
            anon_map = visitors.anon_map()
            # one cache key element per closure tracker; bound parameters
            # found while generating it are appended to ``bindparams``
            cache_key = tuple(
                [
                    getter(closure, opts, anon_map, bindparams)
                    for getter in tracker.closure_trackers
                ]
            )

            if _cache_key.NO_CACHE not in anon_map:
                cache_key = parent_closure_cache_key + cache_key

                try:
                    rec = lambda_cache[tracker_key + cache_key]
                except KeyError:
                    rec = None
            else:
                cache_key = _cache_key.NO_CACHE
                rec = None

        else:
            cache_key = _cache_key.NO_CACHE
            rec = None

        self.closure_cache_key = cache_key

        if rec is None:
            if cache_key is not _cache_key.NO_CACHE:
                with AnalyzedCode._generation_mutex:
                    key = tracker_key + cache_key
                    if key not in lambda_cache:
                        rec = AnalyzedFunction(
                            tracker, self, apply_propagate_attrs, fn
                        )
                        # snapshot the closure-derived bind parameters;
                        # ``bindparams`` itself is mutated below (parent
                        # binds prepended, tracker binds appended), and the
                        # cache-hit branch pairs these positionally with
                        # newly gathered closure binds
                        rec.closure_bindparams = list(bindparams)
                        lambda_cache[key] = rec
                    else:
                        rec = lambda_cache[key]
            else:
                rec = NonAnalyzedFunction(self._invoke_user_fn(fn))

        else:
            # cache hit: re-use the cached bind objects (keeping their keys)
            # but carry the values gathered from the current closure
            bindparams[:] = [
                orig_bind._with_value(new_bind.value, maintain_key=True)
                for orig_bind, new_bind in zip(
                    rec.closure_bindparams, bindparams
                )
            ]

        self._rec = rec

        if cache_key is not _cache_key.NO_CACHE:
            if self.parent_lambda is not None:
                bindparams[:0] = self.parent_lambda._resolved_bindparams

            # run bound parameter trackers for this lambda and each of its
            # ancestors; distinct loop names keep ``rec`` bound to this
            # element's own record
            lambda_element: Optional[LambdaElement] = self
            while lambda_element is not None:
                element_rec = lambda_element._rec
                if element_rec.bindparam_trackers:
                    tracker_instrumented_fn = (
                        element_rec.tracker_instrumented_fn
                    )
                    for bindparam_tracker in element_rec.bindparam_trackers:
                        bindparam_tracker(
                            lambda_element.fn,
                            tracker_instrumented_fn,
                            bindparams,
                        )
                lambda_element = lambda_element.parent_lambda

        return rec

    def __getattr__(self, key):
        # attributes not found on the element are read from the record's
        # expected expression
        return getattr(self._rec.expected_expr, key)

    @property
    def _is_sequence(self):
        return self._rec.is_sequence

    @property
    def _select_iterable(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._select_iterable for element in self._resolved]
            )

        else:
            return self._resolved._select_iterable

    @property
    def _from_objects(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._from_objects for element in self._resolved]
            )

        else:
            return self._resolved._from_objects

    def _param_dict(self):
        """Return ``{key: value}`` for the resolved bound parameters."""
        return {b.key: b.value for b in self._resolved_bindparams}

    def _setup_binds_for_tracked_expr(self, expr):
        """Replace bind parameters in ``expr`` whose key matches a resolved
        bound parameter with that resolved parameter.

        Expanding-ness, ``expand_op`` and type are copied from the original
        bind onto the replacement when the original is expanding.
        """
        bindparam_lookup = {b.key: b for b in self._resolved_bindparams}

        def replace(
            element: Optional[visitors.ExternallyTraversible], **kw: Any
        ) -> Optional[visitors.ExternallyTraversible]:
            if isinstance(element, elements.BindParameter):
                if element.key in bindparam_lookup:
                    bind = bindparam_lookup[element.key]
                    if element.expanding:
                        bind.expanding = True
                        bind.expand_op = element.expand_op
                        bind.type = element.type
                    return bind

            return None

        if self._rec.is_sequence:
            expr = [
                visitors.replacement_traverse(sub_expr, {}, replace)
                for sub_expr in expr
            ]
        elif getattr(expr, "is_clause_element", False):
            expr = visitors.replacement_traverse(expr, {}, replace)

        return expr

    def _copy_internals(
        self,
        clone: _CloneCallableType = _clone,
        deferred_copy_internals: Optional[_CloneCallableType] = None,
        **kw: Any,
    ) -> None:
        # TODO: this needs A LOT of tests
        self._resolved = clone(
            self._resolved,
            deferred_copy_internals=deferred_copy_internals,
            **kw,
        )

    @util.memoized_property
    def _resolved(self):
        """The record's expected expression with the resolved bound
        parameters substituted in."""
        expr = self._rec.expected_expr

        if self._resolved_bindparams:
            expr = self._setup_binds_for_tracked_expr(expr)

        return expr

    def _gen_cache_key(self, anon_map, bindparams):
        """Return this lambda's cache key, extending ``bindparams``.

        The key is this lambda's code object, class and closure cache key,
        prefixed by each ancestor's code object and closure cache key.  For
        an uncacheable lambda, ``anon_map`` is marked with ``NO_CACHE`` and
        None is returned.
        """
        if self.closure_cache_key is _cache_key.NO_CACHE:
            anon_map[_cache_key.NO_CACHE] = True
            return None

        cache_key = (
            self.fn.__code__,
            self.__class__,
        ) + self.closure_cache_key

        parent = self.parent_lambda

        while parent is not None:
            assert parent.closure_cache_key is not CacheConst.NO_CACHE
            parent_closure_cache_key: Tuple[Any, ...] = (
                parent.closure_cache_key
            )

            cache_key = (
                (parent.fn.__code__,) + parent_closure_cache_key + cache_key
            )

            parent = parent.parent_lambda

        if self._resolved_bindparams:
            bindparams.extend(self._resolved_bindparams)
        return cache_key

    def _invoke_user_fn(self, fn: _AnyLambdaType, *arg: Any) -> ClauseElement:
        """Invoke the user's lambda; plain lambdas take no arguments."""
        return fn()  # type: ignore[no-any-return]
426
427
class DeferredLambdaElement(LambdaElement):
    """A :class:`.LambdaElement` whose lambda accepts arguments and is
    invoked during the compile phase with special context.

    Outside of the compile phase this lambda does not normally produce its
    real SQL expression.  Instead, it is called once with a fixed set of
    initial arguments so that a sample expression can be generated.

    """

    def __init__(
        self,
        fn: _AnyLambdaType,
        role: Type[roles.SQLRole],
        opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
        lambda_args: Tuple[Any, ...] = (),
    ):
        # must be set before the base constructor runs: analyzing the lambda
        # invokes it through _invoke_user_fn(), which reads lambda_args
        self.lambda_args = lambda_args
        super().__init__(fn, role, opts)

    def _invoke_user_fn(self, fn, *arg):
        """Call the lambda with the stored sample arguments."""
        return fn(*self.lambda_args)

    def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement:
        """Produce the real expression for the given ``lambda_args``.

        Calls the tracker-instrumented lambda, coerces its result to
        ``self.role``, substitutes resolved bound parameters, then applies
        any deferred copy-internals transforms in order.
        """
        assert isinstance(self._rec, AnalyzedFunction)
        instrumented_fn = self._rec.tracker_instrumented_fn
        coerced = coercions.expect(self.role, instrumented_fn(*lambda_args))
        resolved = self._setup_binds_for_tracked_expr(coerced)

        # this validation is getting very close, but not quite, to achieving
        # #5767. The problem is if the base lambda uses an unnamed column
        # as is very common with mixins, the parameter name is different
        # and it produces a false positive; that is, for the documented case
        # that is exactly what people will be doing, it doesn't work, so
        # I'm not really sure how to handle this right now.
        # expected_binds = [
        #     b._orig_key
        #     for b in self._rec.expr._generate_cache_key()[1]
        #     if b.required
        # ]
        # got_binds = [
        #     b._orig_key for b in expr._generate_cache_key()[1] if b.required
        # ]
        # if expected_binds != got_binds:
        #     raise exc.InvalidRequestError(
        #         "Lambda callable at %s produced a different set of bound "
        #         "parameters than its original run: %s"
        #         % (self.fn.__code__, ", ".join(got_binds))
        #     )

        # TODO: TEST TEST TEST, this is very out there
        for transform in self._transforms:
            resolved = transform(resolved)

        return resolved  # type: ignore

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        super()._copy_internals(
            clone=clone,
            deferred_copy_internals=deferred_copy_internals,  # **kw
            opts=kw,
        )

        # TODO: A LOT A LOT of tests.  The expression for _resolve_with_args
        # is not known yet, so keep the replacement to apply later.
        if deferred_copy_internals:
            self._transforms = self._transforms + (deferred_copy_internals,)
500
501
class StatementLambdaElement(
    roles.AllowsLambdaRole, LambdaElement, Executable
):
    """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.

    The :class:`_sql.StatementLambdaElement` is constructed using the
    :func:`_sql.lambda_stmt` function::


        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: select(table))

    Once constructed, additional criteria can be built onto the statement
    by adding subsequent lambdas, which accept the existing statement
    object as a single parameter::

        stmt += lambda s: s.where(table.c.col == parameter)


    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    if TYPE_CHECKING:

        def __init__(
            self,
            fn: _StmtLambdaType,
            role: Type[SQLRole],
            opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
            apply_propagate_attrs: Optional[ClauseElement] = None,
        ): ...

    def __add__(
        self, other: _StmtLambdaElementType[Any]
    ) -> StatementLambdaElement:
        return self.add_criteria(other)

    def add_criteria(
        self,
        other: _StmtLambdaElementType[Any],
        enable_tracking: bool = True,
        track_on: Optional[Any] = None,
        track_closure_variables: bool = True,
        track_bound_values: bool = True,
    ) -> StatementLambdaElement:
        """Add new criteria to this :class:`_sql.StatementLambdaElement`.

        The new lambda receives the statement produced so far as its single
        argument and returns the new statement.

        E.g.::

            >>> def my_stmt(parameter):
            ...     stmt = lambda_stmt(
            ...         lambda: select(table.c.x, table.c.y),
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(table.c.x > parameter)
            ...     )
            ...     return stmt

        The :meth:`_sql.StatementLambdaElement.add_criteria` method is
        equivalent to using the Python addition operator to add a new
        lambda, except that additional arguments may be passed, including
        ``track_closure_variables`` and ``track_on``::

            >>> def my_stmt(self, foo):
            ...     stmt = lambda_stmt(
            ...         lambda: select(func.max(foo.x, foo.y)),
            ...         track_closure_variables=False
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(self.where_criteria),
            ...         track_on=[self]
            ...     )
            ...     return stmt

        See :func:`_sql.lambda_stmt` for a description of the parameters
        accepted.  ``global_track_bound_values`` is inherited from this
        statement's options.

        """

        opts = self.opts + dict(
            enable_tracking=enable_tracking,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=self.opts.global_track_bound_values,
            track_on=track_on,
            track_bound_values=track_bound_values,
        )

        return LinkedLambdaElement(other, parent_lambda=self, opts=opts)

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        if TYPE_CHECKING:
            assert isinstance(self._rec.expected_expr, ClauseElement)
        if self._rec.expected_expr.supports_execution:
            return connection._execute_clauseelement(
                self, distilled_params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)

    @property
    def _proxied(self) -> Any:
        # the statement produced by the analyzed lambda; the properties
        # below forward to it
        return self._rec.expected_expr

    @property
    def _with_options(self):
        return self._proxied._with_options

    @property
    def _effective_plugin_target(self):
        return self._proxied._effective_plugin_target

    @property
    def _execution_options(self):
        return self._proxied._execution_options

    @property
    def _all_selected_columns(self):
        return self._proxied._all_selected_columns

    @property
    def is_select(self):
        return self._proxied.is_select

    @property
    def is_update(self):
        return self._proxied.is_update

    @property
    def is_insert(self):
        return self._proxied.is_insert

    @property
    def is_text(self):
        return self._proxied.is_text

    @property
    def is_delete(self):
        return self._proxied.is_delete

    @property
    def is_dml(self):
        return self._proxied.is_dml

    def spoil(self) -> NullLambdaStatement:
        """Return a new :class:`.StatementLambdaElement` that will run
        all lambdas unconditionally each time.

        For a statement built from several linked lambdas, the parent chain
        is spoiled first and each link is then applied to its result, since
        a linked lambda takes the preceding statement as its argument.

        """
        if self.parent_lambda is not None:
            return self.parent_lambda.spoil() + self.fn
        return NullLambdaStatement(self.fn())
659
660
class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
    """Offer the :class:`.StatementLambdaElement` API without caching or
    analyzing any lambdas.

    Every lambda is called right away, and its result becomes the wrapped
    statement.

    This is intended as a way to isolate problems that may come up when
    using lambda statements.

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    def __init__(self, statement):
        self._resolved = statement
        self._propagate_attrs = statement._propagate_attrs

    def __getattr__(self, key):
        # anything not defined here is read from the wrapped statement
        return getattr(self._resolved, key)

    def __add__(self, other):
        return self.add_criteria(other)

    def add_criteria(self, other, **kw):
        """Apply ``other`` to the wrapped statement immediately; tracking
        keyword arguments are accepted and ignored."""
        return NullLambdaStatement(other(self._resolved))

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        if not self._resolved.supports_execution:
            raise exc.ObjectNotExecutableError(self)
        return connection._execute_clauseelement(
            self, distilled_params, execution_options
        )
706
707
class LinkedLambdaElement(StatementLambdaElement):
    """Represent subsequent links of a :class:`.StatementLambdaElement`."""

    parent_lambda: StatementLambdaElement

    def __init__(
        self,
        fn: _StmtLambdaElementType[Any],
        parent_lambda: StatementLambdaElement,
        opts: Union[Type[LambdaOptions], LambdaOptions],
    ):
        # Unlike LambdaElement.__init__, the tracker key chains this
        # lambda's code object onto the parent's, and propagate attrs are
        # inherited from the parent rather than taken from the record.
        self.opts = opts
        self.fn = fn
        self.parent_lambda = parent_lambda
        self.tracker_key = (*parent_lambda.tracker_key, fn.__code__)

        self._retrieve_tracker_rec(fn, self, opts)

        self._propagate_attrs = parent_lambda._propagate_attrs

    def _invoke_user_fn(self, fn, *arg):
        # a linked lambda is handed the parent's resolved statement
        parent_statement = self.parent_lambda._resolved
        return fn(parent_statement)
729
730
class AnalyzedCode:
    """Analysis of a lambda's code object, shared by every lambda created
    from that code.

    Holds:

    * ``bindparam_trackers`` - callables that extract current bound
      parameter values from a lambda's globals / closure;
    * ``closure_trackers`` - callables that each produce one element of
      the lambda's closure cache key;
    * ``build_py_wrappers`` - ``(name, closure_index)`` pairs naming the
      globals (index ``None``) and closure cells that hold literal values.

    """

    __slots__ = (
        "track_closure_variables",
        "track_bound_values",
        "bindparam_trackers",
        "closure_trackers",
        "build_py_wrappers",
    )
    # one AnalyzedCode per code object, keyed weakly by the code object
    _fns: weakref.WeakKeyDictionary[CodeType, AnalyzedCode] = (
        weakref.WeakKeyDictionary()
    )

    # serializes creation of AnalyzedCode objects; LambdaElement also holds
    # it while populating a lambda cache
    _generation_mutex = threading.RLock()

    @classmethod
    def get(cls, fn, lambda_element, lambda_kw, **kw):
        """Return the :class:`.AnalyzedCode` for ``fn.__code__``, creating
        it on first use.

        The options given on the first call for a code object are the ones
        analyzed; later calls return the existing object as-is.

        """
        try:
            # TODO: validate kw haven't changed?
            return cls._fns[fn.__code__]
        except KeyError:
            pass

        with cls._generation_mutex:
            # check for other thread already created object
            if fn.__code__ in cls._fns:
                return cls._fns[fn.__code__]

            analyzed: AnalyzedCode
            cls._fns[fn.__code__] = analyzed = AnalyzedCode(
                fn, lambda_element, lambda_kw, **kw
            )
            return analyzed

    def __init__(self, fn, lambda_element, opts):
        # bound methods are rejected; only plain functions / lambdas allowed
        if inspect.ismethod(fn):
            raise exc.ArgumentError(
                "Method %s may not be passed as a SQL expression" % fn
            )
        closure = fn.__closure__

        self.track_bound_values = (
            opts.track_bound_values and opts.global_track_bound_values
        )
        enable_tracking = opts.enable_tracking
        track_on = opts.track_on
        track_closure_variables = opts.track_closure_variables

        # an explicit track_on replaces automatic closure-variable tracking
        self.track_closure_variables = track_closure_variables and not track_on

        # a list of callables generated from _bound_parameter_getter_*
        # functions.  Each of these uses a PyWrapper object to retrieve
        # a parameter value
        self.bindparam_trackers = []

        # a list of callables generated from _cache_key_getter_* functions
        # these callables work to generate a cache key for the lambda
        # based on what's inside its closure variables.
        self.closure_trackers = []

        self.build_py_wrappers = []

        if enable_tracking:
            if track_on:
                self._init_track_on(track_on)

            self._init_globals(fn)

            if closure:
                self._init_closure(fn)

        self._setup_additional_closure_trackers(fn, lambda_element, opts)

    def _init_track_on(self, track_on):
        """Add one closure tracker per element of ``track_on``."""
        self.closure_trackers.extend(
            self._cache_key_getter_track_on(idx, elem)
            for idx, elem in enumerate(track_on)
        )

    def _init_globals(self, fn):
        """Register the lambda's referenced globals that hold literal
        values, adding bound parameter trackers when enabled."""
        build_py_wrappers = self.build_py_wrappers
        bindparam_trackers = self.bindparam_trackers
        track_bound_values = self.track_bound_values

        for name in fn.__code__.co_names:
            if name not in fn.__globals__:
                continue

            _bound_value = self._roll_down_to_literal(fn.__globals__[name])

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((name, None))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_globals(name)
                    )

    def _init_closure(self, fn):
        """Classify each closure cell as a literal (bound parameter) or as
        an object contributing to the cache key."""
        build_py_wrappers = self.build_py_wrappers
        closure = fn.__closure__

        track_bound_values = self.track_bound_values
        track_closure_variables = self.track_closure_variables
        bindparam_trackers = self.bindparam_trackers
        closure_trackers = self.closure_trackers

        for closure_index, (fv, cell) in enumerate(
            zip(fn.__code__.co_freevars, closure)
        ):
            _bound_value = self._roll_down_to_literal(cell.cell_contents)

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((fv, closure_index))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_closure(
                            fv, closure_index
                        )
                    )
            else:
                # for normal cell contents, add them to a list that
                # we can compare later when we get new lambdas.  if
                # any identities have changed, then we will
                # recalculate the whole lambda and run it again.

                if track_closure_variables:
                    closure_trackers.append(
                        self._cache_key_getter_closure_variable(
                            fn, fv, closure_index, cell.cell_contents
                        )
                    )

    def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
        # an additional step is to actually run the function, then
        # go through the PyWrapper objects that were set up to catch a bound
        # parameter.  then if they *didn't* make a param, oh they're another
        # object in the closure we have to track for our cache key.  so
        # create trackers to catch those.

        analyzed_function = AnalyzedFunction(
            self,
            lambda_element,
            None,
            fn,
        )

        closure_trackers = self.closure_trackers

        for pywrapper in analyzed_function.closure_pywrappers:
            if not pywrapper._sa__has_param:
                closure_trackers.append(
                    self._cache_key_getter_tracked_literal(fn, pywrapper)
                )

    @classmethod
    def _roll_down_to_literal(cls, element):
        """Reduce ``element`` to the object examined by the literal check.

        Objects with ``__clause_element__`` are unwrapped until a
        ClauseElement, SchemaItem or class is reached.  Other objects are
        passed to ``inspection.inspect()``; when that returns something,
        its ``__clause_element__()`` (or the inspected object itself) is
        returned.  Otherwise ``element`` is returned unchanged.
        """
        is_clause_element = hasattr(element, "__clause_element__")

        if is_clause_element:
            while not isinstance(
                element, (elements.ClauseElement, schema.SchemaItem, type)
            ):
                try:
                    element = element.__clause_element__()
                except AttributeError:
                    break

        if not is_clause_element:
            insp = inspection.inspect(element, raiseerr=False)
            if insp is not None:
                try:
                    return insp.__clause_element__()
                except AttributeError:
                    return insp

            # TODO: should we coerce consts None/True/False here?
            return element
        else:
            return element

    def _bound_parameter_getter_func_globals(self, name):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__globals__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__globals__[name]
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__globals__[name], result
            )

        return extract_parameter_value

    def _bound_parameter_getter_func_closure(self, name, closure_index):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__closure__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__closure__[
                closure_index
            ].cell_contents
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__closure__[closure_index].cell_contents, result
            )

        return extract_parameter_value

    def _cache_key_getter_track_on(self, idx, elem):
        """Return a getter that will extend a cache key with new entries
        from the "track_on" parameter passed to a :class:`.LambdaElement`.

        The getter reads ``opts.track_on[idx]`` at call time; the kind of
        getter is chosen from the type of ``elem`` seen now.

        """

        if isinstance(elem, tuple):
            # tuple must contain hascachekey elements
            def get(closure, opts, anon_map, bindparams):
                return tuple(
                    tup_elem._gen_cache_key(anon_map, bindparams)
                    for tup_elem in opts.track_on[idx]
                )

        elif isinstance(elem, _cache_key.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)

        else:

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]

        return get

    def _cache_key_getter_closure_variable(
        self,
        fn,
        variable_name,
        idx,
        cell_contents,
        use_clause_element=False,
        use_inspect=False,
    ):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        The kind of getter is chosen from ``cell_contents``: cacheable
        elements contribute their cache key, functions their code object,
        and sequences a tuple of their elements' cache keys.  Other objects
        are rolled down via ``__clause_element__`` or inspection and then
        retried.

        :raises exc.InvalidRequestError: if the variable cannot be made
         part of a cache key.

        """

        if isinstance(cell_contents, _cache_key.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                obj = closure[idx].cell_contents
                if use_inspect:
                    obj = inspection.inspect(obj)
                elif use_clause_element:
                    while hasattr(obj, "__clause_element__"):
                        if getattr(obj, "is_clause_element", False):
                            # already a clause element; stop unwrapping,
                            # otherwise this loop never terminates
                            break
                        obj = obj.__clause_element__()

                return obj._gen_cache_key(anon_map, bindparams)

        elif isinstance(cell_contents, types.FunctionType):

            def get(closure, opts, anon_map, bindparams):
                return closure[idx].cell_contents.__code__

        elif isinstance(cell_contents, collections_abc.Sequence):

            def get(closure, opts, anon_map, bindparams):
                contents = closure[idx].cell_contents

                try:
                    return tuple(
                        elem._gen_cache_key(anon_map, bindparams)
                        for elem in contents
                    )
                except AttributeError as ae:
                    self._raise_for_uncacheable_closure_variable(
                        variable_name, fn, from_=ae
                    )

        else:
            # if the object is a mapped class or aliased class, or some
            # other object in the ORM realm of things like that, imitate
            # the logic used in coercions.expect() to roll it down to the
            # SQL element
            element = cell_contents
            is_clause_element = False
            while hasattr(element, "__clause_element__"):
                is_clause_element = True
                if not getattr(element, "is_clause_element", False):
                    element = element.__clause_element__()
                else:
                    break

            if not is_clause_element:
                insp = inspection.inspect(element, raiseerr=False)
                if insp is not None:
                    return self._cache_key_getter_closure_variable(
                        fn, variable_name, idx, insp, use_inspect=True
                    )
            else:
                return self._cache_key_getter_closure_variable(
                    fn, variable_name, idx, element, use_clause_element=True
                )

            self._raise_for_uncacheable_closure_variable(variable_name, fn)

        return get

    def _raise_for_uncacheable_closure_variable(
        self, variable_name, fn, from_=None
    ):
        """Raise InvalidRequestError for a closure variable that cannot be
        part of the lambda's cache key, chained from ``from_``."""
        raise exc.InvalidRequestError(
            "Closure variable named '%s' inside of lambda callable %s "
            "does not refer to a cacheable SQL element, and also does not "
            "appear to be serving as a SQL literal bound value based on "
            "the default "
            "SQL expression returned by the function.   This variable "
            "needs to remain outside the scope of a SQL-generating lambda "
            "so that a proper cache key may be generated from the "
            "lambda's state.  Evaluate this variable outside of the "
            "lambda, set track_on=[<elements>] to explicitly select "
            "closure elements to track, or set "
            "track_closure_variables=False to exclude "
            "closure variables from being part of the cache key."
            % (variable_name, fn.__code__),
        ) from from_

    def _cache_key_getter_tracked_literal(self, fn, pytracker):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        this getter differs from _cache_key_getter_closure_variable
        in that these are detected after the function is run, and PyWrapper
        objects have recorded that a particular literal value is in fact
        not being interpreted as a bound parameter.

        """

        elem = pytracker._sa__to_evaluate
        closure_index = pytracker._sa__closure_index
        variable_name = pytracker._sa__name

        return self._cache_key_getter_closure_variable(
            fn, variable_name, closure_index, elem
        )
1085
1086
class NonAnalyzedFunction:
    """Minimal stand-in for :class:`.AnalyzedFunction`.

    It simply wraps an already-built expression.  No closure bind
    parameters or bind parameter trackers are associated with it (both are
    ``None``), it is never a sequence, and the expected expression is the
    stored expression itself.

    """

    __slots__ = ("expr",)

    # class-level defaults shared by every instance (not per-instance slots)
    is_sequence = False
    bindparam_trackers: Optional[List[_BoundParameterGetter]] = None
    closure_bindparams: Optional[List[BindParameter[Any]]] = None

    expr: ClauseElement

    def __init__(self, expr: ClauseElement):
        self.expr = expr

    @property
    def expected_expr(self) -> ClauseElement:
        """The stored expression, returned without any coercion."""
        wrapped = self.expr
        return wrapped
1103
1104
class AnalyzedFunction:
    """A user lambda that has been instrumented and invoked once.

    On construction the user callable ``fn`` is run, possibly in a
    rewritten form where selected closure cells and globals are replaced
    by :class:`.PyWrapper` objects.  The result is then passed through
    coercion rules for the lambda element's role.

    """

    __slots__ = (
        "analyzed_code",
        "fn",
        "closure_pywrappers",
        "tracker_instrumented_fn",
        "expr",
        "bindparam_trackers",
        "expected_expr",
        "is_sequence",
        "propagate_attrs",
        "closure_bindparams",
    )

    closure_bindparams: Optional[List[BindParameter[Any]]]
    expected_expr: Union[ClauseElement, List[ClauseElement]]
    bindparam_trackers: Optional[List[_BoundParameterGetter]]

    def __init__(
        self,
        analyzed_code,
        lambda_element,
        apply_propagate_attrs,
        fn,
    ):
        """Instrument ``fn``, invoke it, and coerce its result.

        :param analyzed_code: analysis of ``fn``; its ``bindparam_trackers``,
         ``build_py_wrappers``, ``track_closure_variables`` and
         ``track_bound_values`` attributes are read.
        :param lambda_element: the lambda element being built; used to
         invoke the user function and to supply ``role`` /
         ``parent_lambda``.
        :param apply_propagate_attrs: optional object whose
         ``_propagate_attrs`` are applied during coercion.
        :param fn: the user-supplied Python callable.

        """
        self.analyzed_code = analyzed_code
        self.fn = fn

        self.bindparam_trackers = analyzed_code.bindparam_trackers

        self._instrument_and_run_function(lambda_element)

        self._coerce_expression(lambda_element, apply_propagate_attrs)

    def _instrument_and_run_function(self, lambda_element):
        """Invoke the user function, instrumented with PyWrappers if needed.

        Sets ``closure_pywrappers``, ``tracker_instrumented_fn`` and
        ``expr`` on this object.

        """
        analyzed_code = self.analyzed_code

        fn = self.fn
        self.closure_pywrappers = closure_pywrappers = []

        build_py_wrappers = analyzed_code.build_py_wrappers

        if not build_py_wrappers:
            # nothing to instrument; run the original function as-is
            self.tracker_instrumented_fn = tracker_instrumented_fn = fn
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
        else:
            track_closure_variables = analyzed_code.track_closure_variables
            track_bound_values = analyzed_code.track_bound_values
            closure = fn.__closure__

            # will form the __closure__ of the function when we rebuild it
            if closure:
                new_closure = {
                    fv: cell.cell_contents
                    for fv, cell in zip(fn.__code__.co_freevars, closure)
                }
            else:
                new_closure = {}

            # will form the __globals__ of the function when we rebuild it
            new_globals = fn.__globals__.copy()

            for name, closure_index in build_py_wrappers:
                if closure_index is not None:
                    value = closure[closure_index].cell_contents
                    new_closure[name] = bind = PyWrapper(
                        fn,
                        name,
                        value,
                        closure_index=closure_index,
                        track_bound_values=track_bound_values,
                    )
                    if track_closure_variables:
                        closure_pywrappers.append(bind)
                else:
                    value = fn.__globals__[name]
                    # honor track_bound_values for global wrappers as well,
                    # so that track_bound_values=False (which PyWrapper's
                    # error message suggests) also permits invoking global
                    # callables inside the lambda
                    new_globals[name] = PyWrapper(
                        fn,
                        name,
                        value,
                        track_bound_values=track_bound_values,
                    )

            # rewrite the original fn. things that look like they will
            # become bound parameters are wrapped in a PyWrapper.
            self.tracker_instrumented_fn = tracker_instrumented_fn = (
                self._rewrite_code_obj(
                    fn,
                    [new_closure[name] for name in fn.__code__.co_freevars],
                    new_globals,
                )
            )

            # now invoke the function. This will give us a new SQL
            # expression, but all the places that there would be a bound
            # parameter, the PyWrapper in its place will give us a bind
            # with a predictable name we can match up later.

            # additionally, each PyWrapper will log that it did in fact
            # create a parameter, otherwise, it's some kind of Python
            # object in the closure and we want to track that, to make
            # sure it doesn't change to something else, or if it does,
            # that we create a different tracked function with that
            # variable.
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)

    def _coerce_expression(self, lambda_element, apply_propagate_attrs):
        """Run the tracker-generated expression through coercion rules.

        After the user-defined lambda has been invoked to produce a statement
        for re-use, run it through coercion rules to both check that it's the
        correct type of object and also to coerce it to its useful form.

        Sets ``expected_expr``, ``is_sequence`` and ``propagate_attrs``.
        When the lambda element has a ``parent_lambda`` the expression is
        stored without coercion.

        """

        parent_lambda = lambda_element.parent_lambda
        expr = self.expr

        if parent_lambda is None:
            # str / bytes are registered as Sequence, but a string result is
            # a single value, never a collection of sub-expressions
            if isinstance(expr, collections_abc.Sequence) and not isinstance(
                expr, (str, bytes)
            ):
                self.expected_expr = [
                    cast(
                        "ClauseElement",
                        coercions.expect(
                            lambda_element.role,
                            sub_expr,
                            apply_propagate_attrs=apply_propagate_attrs,
                        ),
                    )
                    for sub_expr in expr
                ]
                self.is_sequence = True
            else:
                self.expected_expr = cast(
                    "ClauseElement",
                    coercions.expect(
                        lambda_element.role,
                        expr,
                        apply_propagate_attrs=apply_propagate_attrs,
                    ),
                )
                self.is_sequence = False
        else:
            self.expected_expr = expr
            self.is_sequence = False

        if apply_propagate_attrs is not None:
            self.propagate_attrs = apply_propagate_attrs._propagate_attrs
        else:
            self.propagate_attrs = util.EMPTY_DICT

    def _rewrite_code_obj(self, f, cell_values, globals_):
        """Return a copy of f, with a new closure and new globals

        ``cell_values`` supplies the contents of the new closure cells, in
        the order of ``f.__code__.co_freevars``.  ``globals_`` becomes the
        new function's ``__globals__``.  Defaults, keyword defaults,
        annotations, docstring, module and qualified name are copied from
        ``f``.

        yes it works in pypy :P

        """

        argrange = range(len(cell_values))

        # cells can't be created portably by hand on all supported
        # interpreters, so generate a tiny nested function whose closure
        # captures the desired values and take its __closure__
        code = "def make_cells():\n"
        if cell_values:
            code += "    (%s) = (%s)\n" % (
                ", ".join("i%d" % i for i in argrange),
                ", ".join("o%d" % i for i in argrange),
            )
        code += "    def closure():\n"
        code += "        return %s\n" % ", ".join("i%d" % i for i in argrange)
        code += "    return closure.__closure__"
        vars_ = {"o%d" % i: cell_values[i] for i in argrange}
        exec(code, vars_, vars_)
        closure = vars_["make_cells"]()

        func = type(f)(
            f.__code__, globals_, f.__name__, f.__defaults__, closure
        )
        func.__annotations__ = f.__annotations__
        func.__kwdefaults__ = f.__kwdefaults__
        func.__doc__ = f.__doc__
        func.__module__ = f.__module__
        # before Python 3.11 the qualname would otherwise default to the
        # bare __name__ rather than the original's qualified name
        func.__qualname__ = f.__qualname__

        return func
1283
1284
class PyWrapper(ColumnOperators):
    """A wrapper object that is injected into the ``__globals__`` and
    ``__closure__`` of a Python function.

    When the function is instrumented with :class:`.PyWrapper` objects, it is
    then invoked just once in order to set up the wrappers. We look through
    all the :class:`.PyWrapper` objects we made to find the ones that generated
    a :class:`.BindParameter` object, e.g. the expression system interpreted
    something as a literal. Those positions in the globals/closure are then
    ones that we will look at, each time a new lambda comes in that refers to
    the same ``__code__`` object. In this way, we keep a single version of
    the SQL expression that this lambda produced, without calling upon the
    Python function that created it more than once, unless its other closure
    variables have changed. The expression is then transformed to have the
    new bound values embedded into it.

    Attribute access is intercepted by ``__getattribute__``; the wrapper's
    own attributes are reached internally through the ``_sa_`` prefix
    (e.g. ``self._sa_fn``) or via ``object.__getattribute__``.

    """

    def __init__(
        self,
        fn,
        name,
        to_evaluate,
        closure_index=None,
        getter=None,
        track_bound_values=True,
    ):
        """
        :param fn: the user's lambda; its ``__code__`` appears in error
         messages.
        :param name: variable name (or attribute / item key for child
         wrappers); used as the key of the generated BindParameter.
        :param to_evaluate: the real Python value being wrapped.
        :param closure_index: index into ``fn.__closure__``, or None.
        :param getter: for child wrappers, the callable that extracts this
         wrapper's value from its parent's value.
        :param track_bound_values: when False, invoking a wrapped callable
         that returns a literal does not raise.

        """
        self.fn = fn
        self._name = name
        self._to_evaluate = to_evaluate
        self._param = None
        self._has_param = False
        # child wrappers keyed by (key, attrgetter / itemgetter)
        self._bind_paths = {}
        self._getter = getter
        self._closure_index = closure_index
        self.track_bound_values = track_bound_values

    def __call__(self, *arg, **kw):
        """Invoke the wrapped callable and return its result.

        Raises InvalidRequestError when bound value tracking is enabled and
        the result is a literal that has no cache key, since such a value
        could not be re-extracted on later uses of the cached lambda.

        """
        elem = object.__getattribute__(self, "_to_evaluate")
        value = elem(*arg, **kw)
        if (
            self._sa_track_bound_values
            and coercions._deep_is_literal(value)
            and not isinstance(
                # TODO: coverage where an ORM option or similar is here
                value,
                _cache_key.HasCacheKey,
            )
        ):
            name = object.__getattribute__(self, "_name")
            raise exc.InvalidRequestError(
                "Can't invoke Python callable %s() inside of lambda "
                "expression argument at %s; lambda SQL constructs should "
                "not invoke functions from closure variables to produce "
                "literal values since the "
                "lambda SQL system normally extracts bound values without "
                "actually "
                "invoking the lambda or any functions within it. Call the "
                "function outside of the "
                "lambda and assign to a local variable that is used in the "
                "lambda as a closure variable, or set "
                "track_bound_values=False if the return value of this "
                "function is used in some other way other than a SQL bound "
                "value." % (name, self._sa_fn.__code__)
            )
        else:
            return value

    def operate(self, op, *other, **kwargs):
        """Apply ``op`` with this wrapper's BindParameter on the left."""
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(elem, *other, **kwargs)

    def reverse_operate(self, op, other, **kwargs):
        """Apply ``op`` with this wrapper's BindParameter on the right."""
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(other, elem, **kwargs)

    def _extract_bound_parameters(self, starting_point, result_list):
        """Append BindParameters valued from ``starting_point``.

        If this wrapper produced a parameter, a copy carrying
        ``starting_point`` as its value is appended to ``result_list``.
        Child wrappers are then visited with the value their getter extracts
        from ``starting_point``.

        """
        param = object.__getattribute__(self, "_param")
        if param is not None:
            param = param._with_value(starting_point, maintain_key=True)
            result_list.append(param)
        for pywrapper in object.__getattribute__(self, "_bind_paths").values():
            getter = object.__getattribute__(pywrapper, "_getter")
            element = getter(starting_point)
            pywrapper._sa__extract_bound_parameters(element, result_list)

    def _py_wrapper_literal(self, expr=None, operator=None, **kw):
        """Return a BindParameter holding the wrapped value.

        The named, unique parameter is created on first use (typed after
        ``expr`` when given) and remembered. Every call returns a copy
        carrying the current wrapped value.

        """
        param = object.__getattribute__(self, "_param")
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        if param is None:
            name = object.__getattribute__(self, "_name")
            self._param = param = elements.BindParameter(
                name,
                required=False,
                unique=True,
                _compared_to_operator=operator,
                _compared_to_type=expr.type if expr is not None else None,
            )
            self._has_param = True
        return param._with_value(to_evaluate, maintain_key=True)

    def __bool__(self):
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    def __getattribute__(self, key):
        """Route attribute access.

        ``_sa_<name>`` reads the wrapper's own ``<name>`` attribute; a fixed
        set of names is served by the wrapper itself; other dunder names are
        delegated to the wrapped value; anything else becomes a child
        wrapper (or plain value) via ``_add_getter``.

        """
        if key.startswith("_sa_"):
            return object.__getattribute__(self, key[4:])
        elif key in (
            "__clause_element__",
            "operate",
            "reverse_operate",
            "_py_wrapper_literal",
            "__class__",
            "__dict__",
        ):
            return object.__getattribute__(self, key)

        if key.startswith("__"):
            elem = object.__getattribute__(self, "_to_evaluate")
            return getattr(elem, key)
        else:
            return self._sa__add_getter(key, operator.attrgetter)

    def __iter__(self):
        elem = object.__getattribute__(self, "_to_evaluate")
        return iter(elem)

    def __getitem__(self, key):
        """Index into the wrapped value via a child wrapper.

        Raises AttributeError when the wrapped value has no
        ``__getitem__``, and InvalidRequestError when ``key`` is itself a
        PyWrapper.

        """
        elem = object.__getattribute__(self, "_to_evaluate")
        if not hasattr(elem, "__getitem__"):
            raise AttributeError("__getitem__")

        if isinstance(key, PyWrapper):
            # TODO: coverage
            raise exc.InvalidRequestError(
                "Dictionary keys / list indexes inside of a cached "
                "lambda must be Python literals only"
            )
        return self._sa__add_getter(key, operator.itemgetter)

    def _add_getter(self, key, getter_fn):
        """Return a child wrapper for ``getter_fn(key)`` applied to the value.

        Child wrappers are cached per ``(key, getter_fn)``.  If the extracted
        value does not roll down to a literal, the raw value is returned and
        nothing is cached.

        """
        bind_paths = object.__getattribute__(self, "_bind_paths")

        bind_path_key = (key, getter_fn)
        if bind_path_key in bind_paths:
            return bind_paths[bind_path_key]

        getter = getter_fn(key)
        elem = object.__getattribute__(self, "_to_evaluate")
        value = getter(elem)

        rolled_down_value = AnalyzedCode._roll_down_to_literal(value)

        if coercions._deep_is_literal(rolled_down_value):
            # children inherit track_bound_values so that the setting also
            # applies to callables reached through attribute / item access,
            # e.g. ``lambda: obj.method()``
            wrapper = PyWrapper(
                self._sa_fn,
                key,
                value,
                getter=getter,
                track_bound_values=self._sa_track_bound_values,
            )
            bind_paths[bind_path_key] = wrapper
            return wrapper
        else:
            return value
1445
1446
@inspection._inspects(LambdaElement)
def insp(lmb):
    """Inspection hook for lambda elements.

    ``inspect()`` on a lambda element is delegated to whatever the
    element's ``_resolved`` attribute holds.

    """
    target = lmb._resolved
    return inspection.inspect(target)