1# sql/lambdas.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import collections.abc as collections_abc
12import inspect
13import itertools
14import operator
15import threading
16import types
17from types import CodeType
18from typing import Any
19from typing import Callable
20from typing import cast
21from typing import List
22from typing import Literal
23from typing import MutableMapping
24from typing import Optional
25from typing import Tuple
26from typing import Type
27from typing import TYPE_CHECKING
28from typing import TypeVar
29from typing import Union
30import weakref
31
32from . import cache_key as _cache_key
33from . import coercions
34from . import elements
35from . import roles
36from . import schema
37from . import visitors
38from .base import _clone
39from .base import Executable
40from .base import ExecutableStatement
41from .base import Options
42from .cache_key import CacheConst
43from .operators import ColumnOperators
44from .. import exc
45from .. import inspection
46from .. import util
47
48
49if TYPE_CHECKING:
50 from .elements import BindParameter
51 from .elements import ClauseElement
52 from .roles import SQLRole
53 from .visitors import _CloneCallableType
54
# Mapping in which analyzed lambda records are stored.  Keys are the
# lambda's tracker key (tuple of code objects) concatenated with its
# closure cache key; values are the per-lambda records.
_LambdaCacheType = MutableMapping[
    Tuple[Any, ...], Union["NonAnalyzedFunction", "AnalyzedFunction"]
]
# A getter built by AnalyzedCode._bound_parameter_getter_* which, given the
# current lambda, the instrumented lambda and a result list, extends the
# list with the current bound parameter values.
_BoundParameterGetter = Callable[..., Any]

# Default cache used when LambdaOptions.lambda_cache is None; an LRU cache
# constructed with a capacity argument of 1000.
_closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000)


# A lambda that takes no arguments, e.g. ``lambda: table.c.x == 5``.
_LambdaType = Callable[[], Any]

# A lambda of any signature (DeferredLambdaElement passes arguments).
_AnyLambdaType = Callable[..., Any]

# The zero-argument lambda accepted by lambda_stmt().
_StmtLambdaType = Callable[[], Any]

_E = TypeVar("_E", bound=Executable)
# A lambda receiving the current statement and returning a new one, as
# accepted by StatementLambdaElement.add_criteria() and ``+``.
_StmtLambdaElementType = Callable[[_E], Any]
71
72
class LambdaOptions(Options):
    """Options controlling how a lambda is analyzed and cached.

    Instances are built by :func:`.lambda_stmt` and
    :meth:`.StatementLambdaElement.add_criteria`; the class itself serves as
    the default options object for :class:`.LambdaElement`.  See
    :func:`.lambda_stmt` for a description of each option.

    """

    # when False, AnalyzedCode skips scanning of track_on, globals and
    # closure variables
    enable_tracking: bool = True
    # when False, non-literal closure variables are not added to the cache
    # key
    track_closure_variables: bool = True
    # explicit sequence of objects to track for the cache key; when given,
    # closure variables are not tracked
    track_on: Optional[object] = None
    # combined with track_bound_values; add_criteria() carries the parent's
    # value forward so it applies to the whole chain
    global_track_bound_values: bool = True
    # when False, bound parameter tracking is disabled for this lambda
    track_bound_values: bool = True
    # cache mapping for analyzed records; None selects the module-level
    # default LRU cache
    lambda_cache: Optional[_LambdaCacheType] = None
80
81
def lambda_stmt(
    lmb: _StmtLambdaType,
    enable_tracking: bool = True,
    track_closure_variables: bool = True,
    track_on: Optional[object] = None,
    global_track_bound_values: bool = True,
    track_bound_values: bool = True,
    lambda_cache: Optional[_LambdaCacheType] = None,
) -> StatementLambdaElement:
    """Produce a SQL statement that is cached as a lambda.

    The Python code object within the lambda is scanned for both Python
    literals that will become bound parameters as well as closure variables
    that refer to Core or ORM constructs that may vary.  The lambda itself
    will be invoked only once per particular set of constructs detected.

    E.g.::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: table.select())
        stmt += lambda s: s.where(table.c.id == 5)

        result = connection.execute(stmt)

    The object returned is an instance of :class:`_sql.StatementLambdaElement`.

    .. versionadded:: 1.4

    :param lmb: a Python function, typically a lambda, which takes no arguments
     and returns a SQL expression construct
    :param enable_tracking: when False, all scanning of the given lambda for
     changes in closure variables or bound parameters is disabled.  Use for
     a lambda that produces the identical results in all cases with no
     parameterization.
    :param track_closure_variables: when False, changes in closure variables
     within the lambda will not be scanned.  Use for a lambda where the
     state of its closure variables will never change the SQL structure
     returned by the lambda.
    :param track_on: an optional sequence of objects used in place of the
     lambda's closure variables when building its cache key.  An element
     that provides a cache key contributes that key, a tuple element
     contributes a tuple of its members' cache keys (each member must
     provide one), and any other object is used as-is, so it must be usable
     as part of a cache mapping key.  When a non-empty sequence is given,
     closure variables are not tracked.  Use for a lambda whose closure
     refers to objects that cannot produce a cache key themselves.
    :param track_bound_values: when False, bound parameter tracking will
     be disabled for the given lambda.  Use for a lambda that either does
     not produce any bound values, or where the initial bound values never
     change.
    :param global_track_bound_values: when False, bound parameter tracking
     will be disabled for the entire statement including additional links
     added via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
    :param lambda_cache: a dictionary or other mapping-like object where
     information about the lambda's Python code as well as the tracked closure
     variables in the lambda itself will be stored.  Defaults
     to a global LRU cache.  This cache is independent of the "compiled_cache"
     used by the :class:`_engine.Connection` object.

    .. seealso::

        :ref:`engine_lambda_caching`


    """

    return StatementLambdaElement(
        lmb,
        roles.StatementRole,
        LambdaOptions(
            enable_tracking=enable_tracking,
            track_on=track_on,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=global_track_bound_values,
            track_bound_values=track_bound_values,
            lambda_cache=lambda_cache,
        ),
    )
153
154
class LambdaElement(elements.ClauseElement):
    """A SQL construct where the state is stored as an un-invoked lambda.

    The :class:`_sql.LambdaElement` is produced transparently whenever
    passing lambda expressions into SQL constructs, such as::

        stmt = select(table).where(lambda: table.c.col == parameter)

    The :class:`_sql.LambdaElement` is the base of the
    :class:`_sql.StatementLambdaElement` which represents a full statement
    within a lambda.

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    # traversals (including cache key generation) descend into the resolved
    # expression rather than the lambda itself
    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    # ``deferred_copy_internals`` callables recorded for later replay;
    # empty by default
    _transforms: Tuple[_CloneCallableType, ...] = ()

    # bound parameters carrying the current values extracted from this
    # lambda (prefixed with the parent chain's when cacheable)
    _resolved_bindparams: List[BindParameter[Any]]
    # preceding element when this element is a link created by
    # StatementLambdaElement.add_criteria(); None otherwise
    parent_lambda: Optional[StatementLambdaElement] = None
    # closure-derived part of the cache key, or NO_CACHE when the closure
    # (or a parent's closure) could not be cached
    closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]
    # coercion role the produced expression must satisfy
    role: Type[SQLRole]
    # analyzed (cacheable) or non-analyzed record for this lambda
    _rec: Union[AnalyzedFunction, NonAnalyzedFunction]
    fn: _AnyLambdaType
    # code objects of the parent chain followed by this lambda's own; the
    # leading part of the key used in the lambda cache
    tracker_key: Tuple[CodeType, ...]

    def __repr__(self):
        return "%s(%r)" % (
            self.__class__.__name__,
            self.fn.__code__,
        )

    def __init__(
        self,
        fn: _LambdaType,
        role: Type[SQLRole],
        opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
        apply_propagate_attrs: Optional[ClauseElement] = None,
    ):
        """Analyze ``fn`` (or fetch its record from the lambda cache).

        :param fn: the lambda producing the SQL expression.
        :param role: coercion role the produced expression must satisfy.
        :param opts: :class:`.LambdaOptions` class or instance.
        :param apply_propagate_attrs: element whose ``_propagate_attrs`` is
         set from the record's ``propagate_attrs``; for
         ``roles.StatementRole`` it defaults to this element itself.

        """
        self.fn = fn
        self.role = role
        self.tracker_key = (fn.__code__,)
        self.opts = opts

        if apply_propagate_attrs is None and (role is roles.StatementRole):
            apply_propagate_attrs = self

        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)

        if apply_propagate_attrs is not None:
            # NOTE(review): this requires every record type (analyzed and
            # non-analyzed) to provide ``propagate_attrs``; confirm.
            propagate_attrs = rec.propagate_attrs
            if propagate_attrs:
                apply_propagate_attrs._propagate_attrs = propagate_attrs

    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
        """Compute this element's closure cache key and obtain its record.

        Sets ``self.closure_cache_key``, ``self._resolved_bindparams`` and
        ``self._rec`` and returns the record.  When the closure (and the
        parent chain) is cacheable, the record is looked up in, or created
        and stored into, the lambda cache under ``tracker_key + cache_key``;
        otherwise the lambda is invoked directly and wrapped in a
        :class:`.NonAnalyzedFunction`.

        Note that the ``fn`` argument is replaced by ``self.fn`` below.

        """
        lambda_cache = opts.lambda_cache
        if lambda_cache is None:
            lambda_cache = _closure_per_cache_key

        tracker_key = self.tracker_key

        fn = self.fn
        closure = fn.__closure__
        tracker = AnalyzedCode.get(
            fn,
            self,
            opts,
        )

        bindparams: List[BindParameter[Any]]
        self._resolved_bindparams = bindparams = []

        if self.parent_lambda is not None:
            parent_closure_cache_key = self.parent_lambda.closure_cache_key
        else:
            parent_closure_cache_key = ()

        cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]

        if parent_closure_cache_key is not _cache_key.NO_CACHE:
            # each closure tracker contributes one element of the key; the
            # getters may also append bound parameters to ``bindparams``
            anon_map = visitors.anon_map()
            cache_key = tuple(
                [
                    getter(closure, opts, anon_map, bindparams)
                    for getter in tracker.closure_trackers
                ]
            )

            if _cache_key.NO_CACHE not in anon_map:
                cache_key = parent_closure_cache_key + cache_key

                self.closure_cache_key = cache_key

                rec = lambda_cache.get(tracker_key + cache_key)
            else:
                # some element of the closure reported itself uncacheable
                cache_key = _cache_key.NO_CACHE
                rec = None

        else:
            # an uncacheable parent makes the whole chain uncacheable
            cache_key = _cache_key.NO_CACHE
            rec = None

        self.closure_cache_key = cache_key

        if rec is None:
            if cache_key is not _cache_key.NO_CACHE:
                # re-check under the lock so that only one thread creates
                # and stores the record for this key
                with AnalyzedCode._generation_mutex:
                    key = tracker_key + cache_key
                    if key not in lambda_cache:
                        rec = AnalyzedFunction(
                            tracker, self, apply_propagate_attrs, fn
                        )
                        # snapshot of the binds seen on first analysis;
                        # later hits pair these positionally with new ones
                        rec.closure_bindparams = list(bindparams)
                        lambda_cache[key] = rec
                    else:
                        rec = lambda_cache[key]
            else:
                rec = NonAnalyzedFunction(self._invoke_user_fn(fn))

        else:
            # cache hit: keep the original bind parameter keys but take the
            # values just collected from this closure
            bindparams[:] = [
                orig_bind._with_value(new_bind.value, maintain_key=True)
                for orig_bind, new_bind in zip(
                    rec.closure_bindparams, bindparams
                )
            ]

        self._rec = rec

        if cache_key is not _cache_key.NO_CACHE:
            if self.parent_lambda is not None:
                bindparams[:0] = self.parent_lambda._resolved_bindparams

            # walk this element and its parents, letting each record's
            # bind parameter trackers append current literal values
            lambda_element: Optional[LambdaElement] = self
            while lambda_element is not None:
                rec = lambda_element._rec
                if rec.bindparam_trackers:
                    tracker_instrumented_fn = (
                        rec.tracker_instrumented_fn  # type:ignore [union-attr] # noqa: E501
                    )
                    for tracker in rec.bindparam_trackers:
                        tracker(
                            lambda_element.fn,
                            tracker_instrumented_fn,
                            bindparams,
                        )
                lambda_element = lambda_element.parent_lambda

        return rec

    def __getattr__(self, key):
        # attributes not found normally are looked up on the expression
        # produced by the lambda.  Python also routes an AttributeError
        # raised inside a property of this class to this method.
        return getattr(self._rec.expected_expr, key)

    @property
    def _is_sequence(self):
        return self._rec.is_sequence

    @property
    def _select_iterable(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._select_iterable for element in self._resolved]
            )

        else:
            return self._resolved._select_iterable

    @property
    def _from_objects(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._from_objects for element in self._resolved]
            )

        else:
            return self._resolved._from_objects

    def _param_dict(self):
        """Return ``{key: value}`` for the resolved bound parameters."""
        return {b.key: b.value for b in self._resolved_bindparams}

    def _setup_binds_for_tracked_expr(self, expr):
        """Return ``expr`` with its bound parameters replaced, by key, by
        those in ``self._resolved_bindparams``.

        Expanding-ness, ``expand_op`` and type are copied from an expanding
        original onto its replacement.  Sequences are processed element by
        element; non-clause values are returned unchanged.

        """
        bindparam_lookup = {b.key: b for b in self._resolved_bindparams}

        def replace(
            element: Optional[visitors.ExternallyTraversible], **kw: Any
        ) -> Optional[visitors.ExternallyTraversible]:
            if isinstance(element, elements.BindParameter):
                if element.key in bindparam_lookup:
                    bind = bindparam_lookup[element.key]
                    if element.expanding:
                        bind.expanding = True
                        bind.expand_op = element.expand_op
                        bind.type = element.type
                    return bind

            # None tells replacement_traverse to keep traversing as normal
            return None

        if self._rec.is_sequence:
            expr = [
                visitors.replacement_traverse(sub_expr, {}, replace)
                for sub_expr in expr
            ]
        elif getattr(expr, "is_clause_element", False):
            expr = visitors.replacement_traverse(expr, {}, replace)

        return expr

    def _copy_internals(
        self,
        clone: _CloneCallableType = _clone,
        deferred_copy_internals: Optional[_CloneCallableType] = None,
        **kw: Any,
    ) -> None:
        # TODO: this needs A LOT of tests
        self._resolved = clone(
            self._resolved,
            deferred_copy_internals=deferred_copy_internals,
            **kw,
        )

    @util.memoized_property
    def _resolved(self):
        """The record's expected expression, with this element's current
        bound parameters substituted in when there are any."""
        expr = self._rec.expected_expr

        if self._resolved_bindparams:
            expr = self._setup_binds_for_tracked_expr(expr)

        return expr

    def _gen_cache_key(self, anon_map, bindparams):
        """Build the cache key from the code objects and closure cache keys
        of this element and its parent chain.

        Marks ``anon_map`` as NO_CACHE and returns None when this element
        is uncacheable; otherwise appends the resolved bound parameters to
        ``bindparams``.

        """
        if self.closure_cache_key is _cache_key.NO_CACHE:
            anon_map[_cache_key.NO_CACHE] = True
            return None

        cache_key = (
            self.fn.__code__,
            self.__class__,
        ) + self.closure_cache_key

        parent = self.parent_lambda

        while parent is not None:
            # a cacheable child implies a cacheable parent chain
            assert parent.closure_cache_key is not CacheConst.NO_CACHE
            parent_closure_cache_key: Tuple[Any, ...] = (
                parent.closure_cache_key
            )

            cache_key = (
                (parent.fn.__code__,) + parent_closure_cache_key + cache_key
            )

            parent = parent.parent_lambda

        if self._resolved_bindparams:
            bindparams.extend(self._resolved_bindparams)
        return cache_key

    def _invoke_user_fn(self, fn: _AnyLambdaType, *arg: Any) -> ClauseElement:
        """Call the user's lambda; the base form passes no arguments and
        ignores ``arg``.  Subclasses override this to supply arguments."""
        return fn()  # type: ignore[no-any-return]
426
427
class DeferredLambdaElement(LambdaElement):
    """A :class:`.LambdaElement` whose lambda accepts arguments and is
    evaluated for real only during compilation, with compiler-supplied
    context.

    At construction time the lambda is called with a fixed tuple of
    ``lambda_args`` solely so that a sample expression can be produced and
    analyzed; outside the compile phase that sample is all there is.

    """

    def __init__(
        self,
        fn: _AnyLambdaType,
        role: Type[roles.SQLRole],
        opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
        lambda_args: Tuple[Any, ...] = (),
    ):
        # must exist before the base constructor runs, because analysis
        # calls back into _invoke_user_fn()
        self.lambda_args = lambda_args
        super().__init__(fn, role, opts)

    def _invoke_user_fn(self, fn, *arg):
        # the sample arguments are always used; ``arg`` is ignored
        sample_args = self.lambda_args
        return fn(*sample_args)

    def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement:
        """Produce the real expression by running the instrumented lambda
        with ``lambda_args``, coercing it to this element's role, swapping
        in the tracked bound parameters and replaying recorded transforms.
        """
        analyzed = self._rec
        assert isinstance(analyzed, AnalyzedFunction)

        produced = analyzed.tracker_instrumented_fn(*lambda_args)
        coerced = coercions.expect(self.role, produced)
        result = self._setup_binds_for_tracked_expr(coerced)

        # this validation is getting very close, but not quite, to achieving
        # #5767. The problem is if the base lambda uses an unnamed column
        # as is very common with mixins, the parameter name is different
        # and it produces a false positive; that is, for the documented case
        # that is exactly what people will be doing, it doesn't work, so
        # I'm not really sure how to handle this right now.
        # expected_binds = [
        #    b._orig_key
        #    for b in self._rec.expr._generate_cache_key()[1]
        #    if b.required
        # ]
        # got_binds = [
        #    b._orig_key for b in expr._generate_cache_key()[1] if b.required
        # ]
        # if expected_binds != got_binds:
        #    raise exc.InvalidRequestError(
        #        "Lambda callable at %s produced a different set of bound "
        #        "parameters than its original run: %s"
        #        % (self.fn.__code__, ", ".join(got_binds))
        #    )

        # TODO: TEST TEST TEST, this is very out there
        for transform in self._transforms:
            result = transform(result)

        return result  # type: ignore

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # NOTE: the remaining keyword arguments are forwarded as a single
        # ``opts`` dictionary rather than being splatted
        super()._copy_internals(
            clone=clone,
            deferred_copy_internals=deferred_copy_internals,  # **kw
            opts=kw,
        )

        # TODO: A LOT A LOT of tests.  The expression built by
        # _resolve_with_args() does not exist yet, so remember the
        # transformation and apply it there later.
        if deferred_copy_internals:
            self._transforms = self._transforms + (deferred_copy_internals,)
500
501
class StatementLambdaElement(
    roles.AllowsLambdaRole, ExecutableStatement, LambdaElement
):
    """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.

    The :class:`_sql.StatementLambdaElement` is constructed using the
    :func:`_sql.lambda_stmt` function::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: select(table))

    Once constructed, additional criteria can be built onto the statement
    by adding subsequent lambdas, which accept the existing statement
    object as a single parameter::

        stmt += lambda s: s.where(table.c.col == parameter)

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    if TYPE_CHECKING:

        def __init__(
            self,
            fn: _StmtLambdaType,
            role: Type[SQLRole],
            opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
            apply_propagate_attrs: Optional[ClauseElement] = None,
        ): ...

    def __add__(
        self, other: _StmtLambdaElementType[Any]
    ) -> StatementLambdaElement:
        return self.add_criteria(other)

    def add_criteria(
        self,
        other: _StmtLambdaElementType[Any],
        enable_tracking: bool = True,
        track_on: Optional[Any] = None,
        track_closure_variables: bool = True,
        track_bound_values: bool = True,
    ) -> StatementLambdaElement:
        """Add new criteria to this :class:`_sql.StatementLambdaElement`.

        The given lambda receives the statement produced so far as its
        single argument and returns a new statement.

        E.g.::

            >>> def my_stmt(parameter):
            ...     stmt = lambda_stmt(
            ...         lambda: select(table.c.x, table.c.y),
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(table.c.x > parameter)
            ...     )
            ...     return stmt

        The :meth:`_sql.StatementLambdaElement.add_criteria` method is
        equivalent to using the Python addition operator to add a new
        lambda, except that additional arguments may be added including
        ``track_closure_variables`` and ``track_on``::

            >>> def my_stmt(self, foo):
            ...     stmt = lambda_stmt(
            ...         lambda: select(func.max(foo.x, foo.y)),
            ...         track_closure_variables=False,
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(self.where_criteria), track_on=[self]
            ...     )
            ...     return stmt

        See :func:`_sql.lambda_stmt` for a description of the parameters
        accepted.  ``global_track_bound_values`` and the remaining options
        not accepted here are carried over from this element's options.

        """  # noqa: E501

        opts = self.opts + dict(
            enable_tracking=enable_tracking,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=self.opts.global_track_bound_values,
            track_on=track_on,
            track_bound_values=track_bound_values,
        )

        return LinkedLambdaElement(other, parent_lambda=self, opts=opts)

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        if TYPE_CHECKING:
            assert isinstance(self._rec.expected_expr, ClauseElement)
        if self._rec.expected_expr.supports_execution:
            return connection._execute_clauseelement(
                self, distilled_params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)

    @property
    def _proxied(self) -> Any:
        # the statement produced by the analyzed lambda; the properties
        # below delegate to it
        return self._rec.expected_expr

    @property
    def _with_options(self):  # type: ignore[override]
        return self._proxied._with_options

    @property
    def _effective_plugin_target(self):
        return self._proxied._effective_plugin_target

    @property
    def _execution_options(self):  # type: ignore[override]
        return self._proxied._execution_options

    @property
    def _all_selected_columns(self):
        return self._proxied._all_selected_columns

    @property
    def is_select(self):  # type: ignore[override]
        return self._proxied.is_select

    @property
    def is_update(self):  # type: ignore[override]
        return self._proxied.is_update

    @property
    def is_insert(self):  # type: ignore[override]
        return self._proxied.is_insert

    @property
    def is_text(self):  # type: ignore[override]
        return self._proxied.is_text

    @property
    def is_delete(self):  # type: ignore[override]
        return self._proxied.is_delete

    @property
    def is_dml(self):  # type: ignore[override]
        return self._proxied.is_dml

    def spoil(self) -> NullLambdaStatement:
        """Return a new :class:`.NullLambdaStatement` that will run
        all lambdas unconditionally each time.

        For an element created by :meth:`.add_criteria` (or ``+``), the
        parent chain is spoiled first and this element's lambda is then
        applied to that statement.  Such lambdas take the statement as
        their single argument, so they cannot be called with no arguments.

        """
        parent = self.parent_lambda
        if parent is None:
            return NullLambdaStatement(self.fn())
        return parent.spoil().add_criteria(self.fn)
653
654
class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
    """A stand-in exposing the :class:`.StatementLambdaElement` API
    without caching or analyzing any lambda.

    Each lambda handed to this object is called immediately against the
    current statement.  It exists to help isolate problems that may stem
    from lambda statement caching.

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    def __init__(self, statement):
        self._resolved = statement
        self._propagate_attrs = statement._propagate_attrs

    def __getattr__(self, key):
        # anything not defined on the wrapper comes from the statement
        return getattr(self._resolved, key)

    def __add__(self, other):
        return self.add_criteria(other)

    def add_criteria(self, other, **kw):
        # tracking options are accepted for API compatibility but ignored;
        # the lambda is applied right away
        return NullLambdaStatement(other(self._resolved))

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        if not self._resolved.supports_execution:
            raise exc.ObjectNotExecutableError(self)
        return connection._execute_clauseelement(
            self, distilled_params, execution_options
        )
700
701
class LinkedLambdaElement(StatementLambdaElement):
    """A subsequent link of a :class:`.StatementLambdaElement`, created by
    :meth:`.StatementLambdaElement.add_criteria` or the ``+`` operator."""

    parent_lambda: StatementLambdaElement

    def __init__(
        self,
        fn: _StmtLambdaElementType[Any],
        parent_lambda: StatementLambdaElement,
        opts: Union[Type[LambdaOptions], LambdaOptions],
    ):
        self.opts = opts
        self.fn = fn
        self.parent_lambda = parent_lambda

        # parent's code objects followed by this lambda's own code object
        self.tracker_key = (*parent_lambda.tracker_key, fn.__code__)

        # analyze or fetch the record, with this element as the target
        # for propagated attributes
        self._retrieve_tracker_rec(fn, self, opts)

        # propagated attributes are copied from the parent element
        self._propagate_attrs = parent_lambda._propagate_attrs

    def _invoke_user_fn(self, fn, *arg):
        # a link's lambda receives the statement resolved by its parent
        parent_statement = self.parent_lambda._resolved
        return fn(parent_statement)
723
724
class AnalyzedCode:
    """Analysis of a lambda's code object, shared by all lambdas built from
    the same ``__code__``.

    Holds three lists built during analysis:

    * ``closure_trackers`` - getters ``(closure, opts, anon_map,
      bindparams)`` that each produce one element of the lambda's closure
      cache key;
    * ``bindparam_trackers`` - getters ``(current_fn,
      tracker_instrumented_fn, result)`` that extend ``result`` with
      current bound parameter values;
    * ``build_py_wrappers`` - ``(name, closure_index)`` pairs for globals
      (``closure_index`` is ``None``) and closure variables whose values
      look like literals; consumed by AnalyzedFunction when it instruments
      the function.

    """

    __slots__ = (
        "track_closure_variables",
        "track_bound_values",
        "bindparam_trackers",
        "closure_trackers",
        "build_py_wrappers",
    )
    # one instance per code object, held weakly on the code object
    _fns: weakref.WeakKeyDictionary[CodeType, AnalyzedCode] = (
        weakref.WeakKeyDictionary()
    )

    # serializes creation of AnalyzedCode instances (and, in
    # LambdaElement._retrieve_tracker_rec, of lambda cache records)
    _generation_mutex = threading.RLock()

    @classmethod
    def get(cls, fn, lambda_element, lambda_kw, **kw):
        """Return the :class:`.AnalyzedCode` for ``fn.__code__``, creating
        it on first use.

        The unlocked lookup is followed, on a miss, by a re-check under
        ``_generation_mutex`` so concurrent callers share one instance.
        The options used on first creation are kept; later calls with
        different options are not detected.

        """
        try:
            # TODO: validate kw haven't changed?
            return cls._fns[fn.__code__]
        except KeyError:
            pass

        with cls._generation_mutex:
            # check for other thread already created object
            if fn.__code__ in cls._fns:
                return cls._fns[fn.__code__]

            analyzed: AnalyzedCode
            cls._fns[fn.__code__] = analyzed = AnalyzedCode(
                fn, lambda_element, lambda_kw, **kw
            )
            return analyzed

    def __init__(self, fn, lambda_element, opts):
        """Scan ``fn`` according to ``opts`` and build its trackers.

        :param fn: the lambda to analyze.
        :param lambda_element: the :class:`.LambdaElement` on whose behalf
         analysis runs; passed to AnalyzedFunction, which runs ``fn`` once.
        :param opts: :class:`.LambdaOptions` class or instance.
        :raises ~sqlalchemy.exc.ArgumentError: if ``fn`` is a bound method.

        """
        if inspect.ismethod(fn):
            raise exc.ArgumentError(
                "Method %s may not be passed as a SQL expression" % fn
            )
        closure = fn.__closure__

        self.track_bound_values = (
            opts.track_bound_values and opts.global_track_bound_values
        )
        enable_tracking = opts.enable_tracking
        track_on = opts.track_on
        track_closure_variables = opts.track_closure_variables

        # an explicit track_on replaces closure variable tracking
        self.track_closure_variables = track_closure_variables and not track_on

        # a list of callables generated from _bound_parameter_getter_*
        # functions.  Each of these uses a PyWrapper object to retrieve
        # a parameter value
        self.bindparam_trackers = []

        # a list of callables generated from _cache_key_getter_* functions
        # these callables work to generate a cache key for the lambda
        # based on what's inside its closure variables.
        self.closure_trackers = []

        self.build_py_wrappers = []

        if enable_tracking:
            if track_on:
                self._init_track_on(track_on)

            self._init_globals(fn)

            if closure:
                self._init_closure(fn)

        self._setup_additional_closure_trackers(fn, lambda_element, opts)

    def _init_track_on(self, track_on):
        """Add one closure tracker per element of ``track_on``."""
        self.closure_trackers.extend(
            self._cache_key_getter_track_on(idx, elem)
            for idx, elem in enumerate(track_on)
        )

    def _init_globals(self, fn):
        """Register globals referenced by ``fn`` whose values look like
        literals, so they can be wrapped and (optionally) tracked as bound
        parameters."""
        build_py_wrappers = self.build_py_wrappers
        bindparam_trackers = self.bindparam_trackers
        track_bound_values = self.track_bound_values

        for name in fn.__code__.co_names:
            # co_names also contains attribute names and builtins
            if name not in fn.__globals__:
                continue

            _bound_value = self._roll_down_to_literal(fn.__globals__[name])

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((name, None))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_globals(name)
                    )

    def _init_closure(self, fn):
        """Classify each closure cell of ``fn`` as a literal (wrapped and
        optionally tracked as a bound parameter) or as an object that must
        contribute to the cache key."""
        build_py_wrappers = self.build_py_wrappers
        closure = fn.__closure__

        track_bound_values = self.track_bound_values
        track_closure_variables = self.track_closure_variables
        bindparam_trackers = self.bindparam_trackers
        closure_trackers = self.closure_trackers

        for closure_index, (fv, cell) in enumerate(
            zip(fn.__code__.co_freevars, closure)
        ):
            _bound_value = self._roll_down_to_literal(cell.cell_contents)

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((fv, closure_index))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_closure(
                            fv, closure_index
                        )
                    )
            else:
                # for normal cell contents, add them to a list that
                # we can compare later when we get new lambdas.  if
                # any identities have changed, then we will
                # recalculate the whole lambda and run it again.

                if track_closure_variables:
                    closure_trackers.append(
                        self._cache_key_getter_closure_variable(
                            fn, fv, closure_index, cell.cell_contents
                        )
                    )

    def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
        # an additional step is to actually run the function, then
        # go through the PyWrapper objects that were set up to catch a bound
        # parameter.  then if they *didn't* make a param, oh they're another
        # object in the closure we have to track for our cache key.  so
        # create trackers to catch those.

        analyzed_function = AnalyzedFunction(
            self,
            lambda_element,
            None,
            fn,
        )

        closure_trackers = self.closure_trackers

        for pywrapper in analyzed_function.closure_pywrappers:
            if not pywrapper._sa__has_param:
                closure_trackers.append(
                    self._cache_key_getter_tracked_literal(fn, pywrapper)
                )

    @classmethod
    def _roll_down_to_literal(cls, element):
        """Reduce ``element`` to the value to test for literal-ness.

        Objects with ``__clause_element__`` are unwrapped until a
        ClauseElement, SchemaItem or class is reached.  Other objects are
        passed through ``inspection.inspect()``; when that succeeds, the
        inspected object's clause element (or the inspected object itself)
        is returned.  Anything else is returned unchanged.

        """
        is_clause_element = hasattr(element, "__clause_element__")

        if is_clause_element:
            while not isinstance(
                element, (elements.ClauseElement, schema.SchemaItem, type)
            ):
                try:
                    element = element.__clause_element__()
                except AttributeError:
                    break

        if not is_clause_element:
            insp = inspection.inspect(element, raiseerr=False)
            if insp is not None:
                try:
                    return insp.__clause_element__()
                except AttributeError:
                    return insp

            # TODO: should we coerce consts None/True/False here?
            return element
        else:
            return element

    def _bound_parameter_getter_func_globals(self, name):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__globals__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            # the instrumented function's global is the wrapper; the
            # current lambda's global carries the new value
            wrapper = tracker_instrumented_fn.__globals__[name]
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__globals__[name], result
            )

        return extract_parameter_value

    def _bound_parameter_getter_func_closure(self, name, closure_index):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__closure__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            # the instrumented function's cell holds the wrapper; the
            # current lambda's cell carries the new value
            wrapper = tracker_instrumented_fn.__closure__[
                closure_index
            ].cell_contents
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__closure__[closure_index].cell_contents, result
            )

        return extract_parameter_value

    def _cache_key_getter_track_on(self, idx, elem):
        """Return a getter that will extend a cache key with new entries
        from the "track_on" parameter passed to a :class:`.LambdaElement`.

        The getter reads ``opts.track_on[idx]`` at call time; ``elem`` only
        selects which kind of getter is built.

        """

        if isinstance(elem, tuple):
            # tuple must contain hascachekey elements
            def get(closure, opts, anon_map, bindparams):
                return tuple(
                    tup_elem._gen_cache_key(anon_map, bindparams)
                    for tup_elem in opts.track_on[idx]
                )

        elif isinstance(elem, _cache_key.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)

        else:

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]

        return get

    def _cache_key_getter_closure_variable(
        self,
        fn,
        variable_name,
        idx,
        cell_contents,
        use_clause_element=False,
        use_inspect=False,
    ):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        ``cell_contents`` is the analysis-time value of closure cell
        ``idx`` and selects the getter's strategy; the getter itself reads
        the cell of the closure it is given.

        :raises ~sqlalchemy.exc.InvalidRequestError: when the value cannot
         contribute to a cache key.

        """

        if isinstance(cell_contents, _cache_key.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                obj = closure[idx].cell_contents
                if use_inspect:
                    obj = inspection.inspect(obj)
                elif use_clause_element:
                    # stop once the object reports itself as a clause
                    # element, as the analysis-time loop below does;
                    # otherwise an object having both __clause_element__
                    # and is_clause_element=True would loop forever
                    while hasattr(obj, "__clause_element__"):
                        if not getattr(obj, "is_clause_element", False):
                            obj = obj.__clause_element__()
                        else:
                            break

                return obj._gen_cache_key(anon_map, bindparams)

        elif isinstance(cell_contents, types.FunctionType):

            def get(closure, opts, anon_map, bindparams):
                return closure[idx].cell_contents.__code__

        elif isinstance(cell_contents, collections_abc.Sequence):

            def get(closure, opts, anon_map, bindparams):
                contents = closure[idx].cell_contents

                try:
                    return tuple(
                        elem._gen_cache_key(anon_map, bindparams)
                        for elem in contents
                    )
                except AttributeError as ae:
                    self._raise_for_uncacheable_closure_variable(
                        variable_name, fn, from_=ae
                    )

        else:
            # if the object is a mapped class or aliased class, or some
            # other object in the ORM realm of things like that, imitate
            # the logic used in coercions.expect() to roll it down to the
            # SQL element
            element = cell_contents
            is_clause_element = False
            while hasattr(element, "__clause_element__"):
                is_clause_element = True
                if not getattr(element, "is_clause_element", False):
                    element = element.__clause_element__()
                else:
                    break

            if not is_clause_element:
                insp = inspection.inspect(element, raiseerr=False)
                if insp is not None:
                    return self._cache_key_getter_closure_variable(
                        fn, variable_name, idx, insp, use_inspect=True
                    )
            else:
                return self._cache_key_getter_closure_variable(
                    fn, variable_name, idx, element, use_clause_element=True
                )

            self._raise_for_uncacheable_closure_variable(variable_name, fn)

        return get

    def _raise_for_uncacheable_closure_variable(
        self, variable_name, fn, from_=None
    ):
        """Raise InvalidRequestError explaining that closure variable
        ``variable_name`` of ``fn`` cannot be part of a cache key."""
        raise exc.InvalidRequestError(
            "Closure variable named '%s' inside of lambda callable %s "
            "does not refer to a cacheable SQL element, and also does not "
            "appear to be serving as a SQL literal bound value based on "
            "the default "
            "SQL expression returned by the function.  This variable "
            "needs to remain outside the scope of a SQL-generating lambda "
            "so that a proper cache key may be generated from the "
            "lambda's state.  Evaluate this variable outside of the "
            "lambda, set track_on=[<elements>] to explicitly select "
            "closure elements to track, or set "
            "track_closure_variables=False to exclude "
            "closure variables from being part of the cache key."
            % (variable_name, fn.__code__),
        ) from from_

    def _cache_key_getter_tracked_literal(self, fn, pytracker):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        this getter differs from _cache_key_getter_closure_variable
        in that these are detected after the function is run, and PyWrapper
        objects have recorded that a particular literal value is in fact
        not being interpreted as a bound parameter.

        """

        elem = pytracker._sa__to_evaluate
        closure_index = pytracker._sa__closure_index
        variable_name = pytracker._sa__name

        return self._cache_key_getter_closure_variable(
            fn, variable_name, closure_index, elem
        )
1079
1080
class NonAnalyzedFunction:
    """Holds a SQL expression directly, with no closure analysis.

    Unlike :class:`.AnalyzedFunction`, no instrumented run or coercion is
    performed here. The stored expression is handed back unchanged as
    :attr:`.expected_expr`. There are never any closure bind parameters or
    bind parameter trackers, so those class-level attributes stay ``None``.

    """

    __slots__ = ("expr",)

    # the only per-instance state
    expr: ClauseElement

    # read-only class-level defaults shared by all instances
    is_sequence = False
    bindparam_trackers: Optional[List[_BoundParameterGetter]] = None
    closure_bindparams: Optional[List[BindParameter[Any]]] = None

    def __init__(self, expr: ClauseElement):
        self.expr = expr

    @property
    def expected_expr(self) -> ClauseElement:
        """The stored expression, returned as-is."""
        stored = self.expr
        return stored
1097
1098
class AnalyzedFunction:
    """Holds the result of running a lambda once, in instrumented form.

    Construction does two things:

    * :meth:`._instrument_and_run_function` invokes the user's function.
      When ``analyzed_code.build_py_wrappers`` is non-empty, the globals and
      closure variables it lists are first replaced by :class:`.PyWrapper`
      objects.
    * :meth:`._coerce_expression` runs the returned value through
      ``coercions.expect()`` for the lambda element's role.

    """

    __slots__ = (
        "analyzed_code",
        "fn",
        "closure_pywrappers",
        "tracker_instrumented_fn",
        "expr",
        "bindparam_trackers",
        "expected_expr",
        "is_sequence",
        "propagate_attrs",
        "closure_bindparams",
    )

    closure_bindparams: Optional[List[BindParameter[Any]]]
    expected_expr: Union[ClauseElement, List[ClauseElement]]
    bindparam_trackers: Optional[List[_BoundParameterGetter]]

    def __init__(
        self,
        analyzed_code,
        lambda_element,
        apply_propagate_attrs,
        fn,
    ):
        """Run ``fn`` once and record the resulting expression.

        :param analyzed_code: analysis of ``fn``'s code. This class reads
          its ``bindparam_trackers``, ``build_py_wrappers``,
          ``track_closure_variables`` and ``track_bound_values``
          attributes.
        :param lambda_element: provides ``_invoke_user_fn()``,
          ``parent_lambda`` and ``role``.
        :param apply_propagate_attrs: forwarded to ``coercions.expect()``.
          When not None, its ``_propagate_attrs`` become
          ``self.propagate_attrs``.
        :param fn: the user-supplied Python function.

        """
        self.analyzed_code = analyzed_code
        self.fn = fn

        self.bindparam_trackers = analyzed_code.bindparam_trackers

        self._instrument_and_run_function(lambda_element)

        self._coerce_expression(lambda_element, apply_propagate_attrs)

    def _instrument_and_run_function(self, lambda_element):
        """Invoke the user function, substituting PyWrapper objects first.

        Sets ``self.closure_pywrappers``, ``self.tracker_instrumented_fn``
        and ``self.expr``. When ``analyzed_code.build_py_wrappers`` is
        empty, ``fn`` is invoked unmodified.

        """
        analyzed_code = self.analyzed_code

        fn = self.fn
        self.closure_pywrappers = closure_pywrappers = []

        build_py_wrappers = analyzed_code.build_py_wrappers

        if not build_py_wrappers:
            # nothing to instrument; run the function as given
            self.tracker_instrumented_fn = fn
            self.expr = lambda_element._invoke_user_fn(fn)
            return

        track_closure_variables = analyzed_code.track_closure_variables

        # Applied to global *and* closure wrappers, so that
        # track_bound_values=False (the remedy suggested by
        # PyWrapper.__call__'s error message) works for a callable
        # wherever the lambda finds it. Previously global wrappers
        # always defaulted to True.
        track_bound_values = analyzed_code.track_bound_values

        closure = fn.__closure__

        # will form the __closure__ of the function when we rebuild it
        if closure:
            new_closure = {
                fv: cell.cell_contents
                for fv, cell in zip(fn.__code__.co_freevars, closure)
            }
        else:
            new_closure = {}

        # will form the __globals__ of the function when we rebuild it;
        # a copy, so the function's real module globals are left untouched
        new_globals = fn.__globals__.copy()

        for name, closure_index in build_py_wrappers:
            if closure_index is not None:
                value = closure[closure_index].cell_contents
                new_closure[name] = bind = PyWrapper(
                    fn,
                    name,
                    value,
                    closure_index=closure_index,
                    track_bound_values=track_bound_values,
                )
                if track_closure_variables:
                    closure_pywrappers.append(bind)
            else:
                value = fn.__globals__[name]
                new_globals[name] = PyWrapper(
                    fn,
                    name,
                    value,
                    track_bound_values=track_bound_values,
                )

        # rewrite the original fn. things that look like they will
        # become bound parameters are wrapped in a PyWrapper.
        self.tracker_instrumented_fn = tracker_instrumented_fn = (
            self._rewrite_code_obj(
                fn,
                [new_closure[name] for name in fn.__code__.co_freevars],
                new_globals,
            )
        )

        # now invoke the function. This will give us a new SQL
        # expression, but all the places that there would be a bound
        # parameter, the PyWrapper in its place will give us a bind
        # with a predictable name we can match up later.

        # additionally, each PyWrapper will log that it did in fact
        # create a parameter, otherwise, it's some kind of Python
        # object in the closure and we want to track that, to make
        # sure it doesn't change to something else, or if it does,
        # that we create a different tracked function with that
        # variable.
        self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)

    def _coerce_expression(self, lambda_element, apply_propagate_attrs):
        """Run the tracker-generated expression through coercion rules.

        After the user-defined lambda has been invoked to produce a statement
        for re-use, run it through coercion rules to both check that it's the
        correct type of object and also to coerce it to its useful form.

        Sets ``self.expected_expr``, ``self.is_sequence`` and
        ``self.propagate_attrs``. When the lambda element has a
        ``parent_lambda``, the expression is kept as-is without coercion.

        """

        parent_lambda = lambda_element.parent_lambda
        expr = self.expr

        if parent_lambda is None:
            # A str is itself a Sequence, but it is one value rather than
            # a list of expressions. Without this check it would be
            # coerced one character at a time.
            if isinstance(expr, collections_abc.Sequence) and not isinstance(
                expr, str
            ):
                self.expected_expr = [
                    cast(
                        "ClauseElement",
                        coercions.expect(
                            lambda_element.role,
                            sub_expr,
                            apply_propagate_attrs=apply_propagate_attrs,
                        ),
                    )
                    for sub_expr in expr
                ]
                self.is_sequence = True
            else:
                self.expected_expr = cast(
                    "ClauseElement",
                    coercions.expect(
                        lambda_element.role,
                        expr,
                        apply_propagate_attrs=apply_propagate_attrs,
                    ),
                )
                self.is_sequence = False
        else:
            self.expected_expr = expr
            self.is_sequence = False

        if apply_propagate_attrs is not None:
            self.propagate_attrs = apply_propagate_attrs._propagate_attrs
        else:
            self.propagate_attrs = util.EMPTY_DICT

    def _rewrite_code_obj(self, f, cell_values, globals_):
        """Return a copy of f, with a new closure and new globals

        ``cell_values`` supplies the new closure cell contents. They are
        matched positionally to ``f.__code__.co_freevars``.

        New cell objects are obtained by exec'ing a small generated
        ``make_cells()`` function whose inner ``closure()`` references one
        local per value, and returning that inner function's
        ``__closure__``. The generated source contains only the generated
        names ``i<N>`` / ``o<N>``; the values themselves are passed through
        the exec namespace and are never formatted into the source.

        The code object, name, defaults, keyword defaults, annotations,
        docstring and module of ``f`` are carried over to the copy.

        yes it works in pypy :P

        """

        argrange = range(len(cell_values))

        code = "def make_cells():\n"
        if cell_values:
            code += "    (%s) = (%s)\n" % (
                ", ".join("i%d" % i for i in argrange),
                ", ".join("o%d" % i for i in argrange),
            )
        code += "    def closure():\n"
        code += "        return %s\n" % ", ".join("i%d" % i for i in argrange)
        code += "    return closure.__closure__"
        vars_ = {"o%d" % i: cell_values[i] for i in argrange}
        exec(code, vars_, vars_)
        closure = vars_["make_cells"]()

        func = type(f)(
            f.__code__, globals_, f.__name__, f.__defaults__, closure
        )
        func.__annotations__ = f.__annotations__
        func.__kwdefaults__ = f.__kwdefaults__
        func.__doc__ = f.__doc__
        func.__module__ = f.__module__

        return func
1277
1278
class PyWrapper(ColumnOperators):
    """A wrapper object that is injected into the ``__globals__`` and
    ``__closure__`` of a Python function.

    When the function is instrumented with :class:`.PyWrapper` objects, it is
    then invoked just once in order to set up the wrappers. We look through
    all the :class:`.PyWrapper` objects we made to find the ones that generated
    a :class:`.BindParameter` object, e.g. the expression system interpreted
    something as a literal. Those positions in the globals/closure are then
    ones that we will look at, each time a new lambda comes in that refers to
    the same ``__code__`` object. In this way, we keep a single version of
    the SQL expression that this lambda produced, without calling upon the
    Python function that created it more than once, unless its other closure
    variables have changed. The expression is then transformed to have the
    new bound values embedded into it.

    Attribute access is intercepted by :meth:`.__getattribute__`:

    * the wrapper's own state is reached with a ``_sa_`` prefix (for
      example ``wrapper._sa_fn`` reads the ``fn`` attribute);
    * other dunder names are forwarded to the wrapped value;
    * any remaining name goes through :meth:`._add_getter`.

    For this reason, methods here read their own state with
    ``object.__getattribute__``.

    """

    def __init__(
        self,
        fn,
        name,
        to_evaluate,
        closure_index=None,
        getter=None,
        track_bound_values=True,
    ):
        """
        :param fn: the original Python function. It is used in error
          messages and passed on to nested wrappers.
        :param name: the variable name (or, for a nested wrapper, the
          attribute name / item key) this wrapper stands for. It becomes
          the name of the generated :class:`.BindParameter`.
        :param to_evaluate: the actual wrapped Python value.
        :param closure_index: position in ``fn.__closure__`` when wrapping a
          closure variable, else None.
        :param getter: for nested wrappers, the ``attrgetter`` /
          ``itemgetter`` that derives this value from the parent's value.
        :param track_bound_values: when true, calling the wrapper raises if
          the wrapped callable returns a literal value. Nested wrappers
          inherit this setting.

        """
        self.fn = fn
        self._name = name
        self._to_evaluate = to_evaluate
        self._param = None
        self._has_param = False
        # (key, getter_fn) -> nested PyWrapper, filled by _add_getter()
        self._bind_paths = {}
        self._getter = getter
        self._closure_index = closure_index
        self.track_bound_values = track_bound_values

    def __call__(self, *arg, **kw):
        """Call the wrapped callable and return its result.

        :raises InvalidRequestError: if bound value tracking is enabled and
          the call returns a literal value that is not a ``HasCacheKey``.
          The lambda system extracts bound values without re-invoking such
          calls.

        """
        elem = object.__getattribute__(self, "_to_evaluate")
        value = elem(*arg, **kw)
        if (
            self._sa_track_bound_values
            and coercions._deep_is_literal(value)
            and not isinstance(
                # TODO: coverage where an ORM option or similar is here
                value,
                _cache_key.HasCacheKey,
            )
        ):
            name = object.__getattribute__(self, "_name")
            raise exc.InvalidRequestError(
                "Can't invoke Python callable %s() inside of lambda "
                "expression argument at %s; lambda SQL constructs should "
                "not invoke functions from closure variables to produce "
                "literal values since the "
                "lambda SQL system normally extracts bound values without "
                "actually "
                "invoking the lambda or any functions within it. Call the "
                "function outside of the "
                "lambda and assign to a local variable that is used in the "
                "lambda as a closure variable, or set "
                "track_bound_values=False if the return value of this "
                "function is used in some other way other than a SQL bound "
                "value." % (name, self._sa_fn.__code__)
            )
        else:
            return value

    def operate(self, op, *other, **kwargs):
        """Apply ``op`` with this wrapper's bound parameter on the left."""
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(elem, *other, **kwargs)

    def reverse_operate(self, op, other, **kwargs):
        """Apply ``op`` with this wrapper's bound parameter on the right."""
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(other, elem, **kwargs)

    def _extract_bound_parameters(self, starting_point, result_list):
        """Append bound parameters carrying values from ``starting_point``.

        ``starting_point`` is the new value for this wrapper's position.
        If this wrapper produced a parameter, a copy holding
        ``starting_point`` is appended. Each nested wrapper then derives its
        own value from ``starting_point`` with its getter and recurses.

        """
        param = object.__getattribute__(self, "_param")
        if param is not None:
            param = param._with_value(starting_point, maintain_key=True)
            result_list.append(param)
        for pywrapper in object.__getattribute__(self, "_bind_paths").values():
            getter = object.__getattribute__(pywrapper, "_getter")
            element = getter(starting_point)
            pywrapper._sa__extract_bound_parameters(element, result_list)

    def _py_wrapper_literal(self, expr=None, operator=None, **kw):
        """Return a :class:`.BindParameter` holding the wrapped value.

        The parameter is created on first use: it is named after the
        wrapper, uses ``unique=True``, and takes its type from ``expr`` if
        given. It is then remembered and ``_has_param`` is set. Each call
        returns a copy made with ``_with_value(..., maintain_key=True)``.

        """
        param = object.__getattribute__(self, "_param")
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        if param is None:
            name = object.__getattribute__(self, "_name")
            self._param = param = elements.BindParameter(
                name,
                required=False,
                unique=True,
                _compared_to_operator=operator,
                _compared_to_type=expr.type if expr is not None else None,
            )
            self._has_param = True
        return param._with_value(to_evaluate, maintain_key=True)

    def __bool__(self):
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    def __getattribute__(self, key):
        if key.startswith("_sa_"):
            # escape hatch to the wrapper's own attributes
            return object.__getattribute__(self, key[4:])
        elif key in (
            "__clause_element__",
            "operate",
            "reverse_operate",
            "_py_wrapper_literal",
            "__class__",
            "__dict__",
        ):
            # looked up on the wrapper itself, never on the wrapped value
            return object.__getattribute__(self, key)

        if key.startswith("__"):
            elem = object.__getattribute__(self, "_to_evaluate")
            return getattr(elem, key)
        else:
            return self._sa__add_getter(key, operator.attrgetter)

    def __iter__(self):
        elem = object.__getattribute__(self, "_to_evaluate")
        return iter(elem)

    def __getitem__(self, key):
        """Subscript the wrapped value through :meth:`._add_getter`.

        :raises AttributeError: if the wrapped value has no ``__getitem__``.
        :raises InvalidRequestError: if ``key`` is itself a PyWrapper.

        """
        elem = object.__getattribute__(self, "_to_evaluate")
        if not hasattr(elem, "__getitem__"):
            raise AttributeError("__getitem__")

        if isinstance(key, PyWrapper):
            # TODO: coverage
            raise exc.InvalidRequestError(
                "Dictionary keys / list indexes inside of a cached "
                "lambda must be Python literals only"
            )
        return self._sa__add_getter(key, operator.itemgetter)

    def _add_getter(self, key, getter_fn):
        """Return the ``key`` attribute/item of the wrapped value.

        If the value (after ``AnalyzedCode._roll_down_to_literal``) looks
        like a literal, it is returned wrapped in a nested
        :class:`.PyWrapper`, which is cached in ``_bind_paths`` under
        ``(key, getter_fn)``. Otherwise the plain value is returned.

        """
        bind_paths = object.__getattribute__(self, "_bind_paths")

        bind_path_key = (key, getter_fn)
        if bind_path_key in bind_paths:
            return bind_paths[bind_path_key]

        getter = getter_fn(key)
        elem = object.__getattribute__(self, "_to_evaluate")
        value = getter(elem)

        rolled_down_value = AnalyzedCode._roll_down_to_literal(value)

        if coercions._deep_is_literal(rolled_down_value):
            wrapper = PyWrapper(
                self._sa_fn,
                key,
                value,
                getter=getter,
                # Inherit the parent's setting. Previously nested wrappers
                # always defaulted to True, so e.g. ``obj.method()`` still
                # raised under track_bound_values=False.
                track_bound_values=self._sa_track_bound_values,
            )
            bind_paths[bind_path_key] = wrapper
            return wrapper
        else:
            return value
1439
1440
@inspection._inspects(LambdaElement)
def insp(lmb):
    """Inspection hook registered for :class:`.LambdaElement`.

    Delegates to ``inspection.inspect()`` on the element's ``_resolved``
    attribute.

    """
    resolved = lmb._resolved
    return inspection.inspect(resolved)