1# sql/lambdas.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import collections.abc as collections_abc
12import inspect
13import itertools
14import operator
15import threading
16import types
17from types import CodeType
18from typing import Any
19from typing import Callable
20from typing import cast
21from typing import List
22from typing import MutableMapping
23from typing import Optional
24from typing import Tuple
25from typing import Type
26from typing import TYPE_CHECKING
27from typing import TypeVar
28from typing import Union
29import weakref
30
31from . import cache_key as _cache_key
32from . import coercions
33from . import elements
34from . import roles
35from . import schema
36from . import visitors
37from .base import _clone
38from .base import Executable
39from .base import Options
40from .cache_key import CacheConst
41from .operators import ColumnOperators
42from .. import exc
43from .. import inspection
44from .. import util
45from ..util.typing import Literal
46
47if TYPE_CHECKING:
48 from .elements import BindParameter
49 from .elements import ClauseElement
50 from .roles import SQLRole
51 from .visitors import _CloneCallableType
52
# Callable shapes accepted by the lambda element constructors.
_LambdaType = Callable[[], Any]
_AnyLambdaType = Callable[..., Any]
_StmtLambdaType = Callable[[], Any]

# A follow-on statement lambda receives the statement built so far.
_E = TypeVar("_E", bound=Executable)
_StmtLambdaElementType = Callable[[_E], Any]

# Called as (current_fn, tracker_instrumented_fn, result_list) by the
# getters that AnalyzedCode builds for bound parameter extraction.
_BoundParameterGetter = Callable[..., Any]

# Keys are a tuple of lambda code objects followed by the closure cache
# key; values are the analysis records for those lambdas.
_LambdaCacheType = MutableMapping[
    Tuple[Any, ...], Union["NonAnalyzedFunction", "AnalyzedFunction"]
]

# Module-level cache used whenever no ``lambda_cache`` option is supplied
# (presumably holds up to 1000 entries; see util.LRUCache).
_closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000)
69
70
class LambdaOptions(Options):
    """Options that control how a lambda is analyzed for caching.

    Each attribute corresponds to the keyword argument of the same name
    accepted by :func:`.lambda_stmt`.

    """

    # when False, scanning of globals, closure variables and track_on
    # is skipped entirely
    enable_tracking = True
    # when False, closure variables do not contribute to the cache key
    track_closure_variables = True
    # explicit objects to build the cache key from; when set, closure
    # variable tracking is switched off
    track_on: Optional[object] = None
    # statement-wide switch; add_criteria() passes it on to every link
    global_track_bound_values = True
    # per-lambda switch for bound parameter tracking
    track_bound_values = True
    # mapping that stores analyzed lambdas; None selects the module-level
    # default cache
    lambda_cache: Optional[_LambdaCacheType] = None
78
79
def lambda_stmt(
    lmb: _StmtLambdaType,
    enable_tracking: bool = True,
    track_closure_variables: bool = True,
    track_on: Optional[object] = None,
    global_track_bound_values: bool = True,
    track_bound_values: bool = True,
    lambda_cache: Optional[_LambdaCacheType] = None,
) -> StatementLambdaElement:
    """Build a SQL statement that is cached as a lambda.

    The code object of the lambda is inspected for Python literals, which
    become bound parameters, and for closure variables that refer to Core
    or ORM constructs that may vary.  The lambda itself is called only once
    for each distinct set of constructs that is detected.

    E.g.::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: table.select())
        stmt += lambda s: s.where(table.c.id == 5)

        result = connection.execute(stmt)

    The returned object is a :class:`_sql.StatementLambdaElement`.

    .. versionadded:: 1.4

    :param lmb: a Python function, typically a lambda, that accepts no
     arguments and returns a SQL expression construct
    :param enable_tracking: when False, the lambda is not scanned for
     changes in closure variables or bound parameters.  Use for a lambda
     that produces the same result in all cases with no parameterization.
    :param track_closure_variables: when False, the closure variables of
     the lambda are not scanned for changes.  Use for a lambda whose
     closure variables never change the SQL structure it returns.
    :param track_on: optional sequence of objects used to build the
     lambda's cache key instead of its closure variables.  When given,
     closure variable tracking is turned off.  Tuple entries must contain
     elements that can generate a cache key.  Other entries that can
     generate a cache key contribute that key.  Any remaining entry is
     placed into the key as-is.
    :param global_track_bound_values: when False, bound parameter tracking
     is disabled for the whole statement, including links added with
     :meth:`_sql.StatementLambdaElement.add_criteria`.
    :param track_bound_values: when False, bound parameter tracking is
     disabled for this lambda.  Use for a lambda that produces no bound
     values, or whose initial bound values never change.
    :param lambda_cache: a dictionary or other mapping-like object in which
     information about the lambda's Python code and its tracked closure
     variables is stored.  Defaults to a global LRU cache.  This cache is
     separate from the "compiled_cache" used by the
     :class:`_engine.Connection` object.

    .. seealso::

        :ref:`engine_lambda_caching`


    """
    options = LambdaOptions(
        enable_tracking=enable_tracking,
        track_on=track_on,
        track_closure_variables=track_closure_variables,
        global_track_bound_values=global_track_bound_values,
        track_bound_values=track_bound_values,
        lambda_cache=lambda_cache,
    )
    return StatementLambdaElement(lmb, roles.StatementRole, options)
151
152
class LambdaElement(elements.ClauseElement):
    """A SQL construct where the state is stored as an un-invoked lambda.

    The :class:`_sql.LambdaElement` is produced transparently whenever
    passing lambda expressions into SQL constructs, such as::

        stmt = select(table).where(lambda: table.c.col == parameter)

    The :class:`_sql.LambdaElement` is the base of the
    :class:`_sql.StatementLambdaElement` which represents a full statement
    within a lambda.

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    # deferred copy_internals callables recorded by subclasses
    _transforms: Tuple[_CloneCallableType, ...] = ()

    _resolved_bindparams: List[BindParameter[Any]]
    parent_lambda: Optional[StatementLambdaElement] = None
    closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]
    role: Type[SQLRole]
    _rec: Union[AnalyzedFunction, NonAnalyzedFunction]
    fn: _AnyLambdaType
    tracker_key: Tuple[CodeType, ...]

    def __repr__(self):
        return "%s(%r)" % (
            self.__class__.__name__,
            self.fn.__code__,
        )

    def __init__(
        self,
        fn: _LambdaType,
        role: Type[SQLRole],
        opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
        apply_propagate_attrs: Optional[ClauseElement] = None,
    ):
        self.fn = fn
        self.role = role
        self.tracker_key = (fn.__code__,)
        self.opts = opts

        # a full statement receives propagated attributes onto itself
        # unless an explicit target was passed
        if apply_propagate_attrs is None and (role is roles.StatementRole):
            apply_propagate_attrs = self

        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)

        if apply_propagate_attrs is not None:
            # NonAnalyzedFunction records (produced when the lambda is not
            # cacheable) have no ``propagate_attrs`` attribute; reading it
            # unconditionally raised AttributeError for such lambdas.
            propagate_attrs = getattr(rec, "propagate_attrs", None)
            if propagate_attrs:
                apply_propagate_attrs._propagate_attrs = propagate_attrs

    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
        """Locate or build the analysis record for this lambda.

        Computes the closure cache key from the code-level
        ``closure_trackers`` of :class:`.AnalyzedCode`, prefixed by the
        parent lambda's key.  That key, combined with ``tracker_key``, is
        looked up in ``opts.lambda_cache`` (or the module default cache).

        * On a miss, a new :class:`.AnalyzedFunction` is built under
          ``AnalyzedCode._generation_mutex`` and stored.
        * On a hit, including one that occurs because another thread stored
          the record while we waited for the mutex, the bound parameters
          collected from the current closure are re-expressed as copies of
          the cached record's ``closure_bindparams`` carrying the new
          values.
        * When the key is uncacheable, the lambda is invoked directly and
          wrapped in a :class:`.NonAnalyzedFunction`.

        For cacheable lambdas, the parent's resolved bind parameters are
        prepended, and then every element in the parent chain runs its
        bind parameter trackers to append literal values.

        Sets ``self._resolved_bindparams``, ``self.closure_cache_key`` and
        ``self._rec``.

        :return: this element's record (the same object as ``self._rec``).

        """
        lambda_cache = opts.lambda_cache
        if lambda_cache is None:
            lambda_cache = _closure_per_cache_key

        tracker_key = self.tracker_key

        fn = self.fn
        closure = fn.__closure__
        tracker = AnalyzedCode.get(
            fn,
            self,
            opts,
        )

        bindparams: List[BindParameter[Any]]
        self._resolved_bindparams = bindparams = []

        if self.parent_lambda is not None:
            parent_closure_cache_key = self.parent_lambda.closure_cache_key
        else:
            parent_closure_cache_key = ()

        cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]

        if parent_closure_cache_key is not _cache_key.NO_CACHE:
            anon_map = visitors.anon_map()
            cache_key = tuple(
                [
                    getter(closure, opts, anon_map, bindparams)
                    for getter in tracker.closure_trackers
                ]
            )

            if _cache_key.NO_CACHE not in anon_map:
                cache_key = parent_closure_cache_key + cache_key
                rec = lambda_cache.get(tracker_key + cache_key)
            else:
                cache_key = _cache_key.NO_CACHE
                rec = None

        else:
            cache_key = _cache_key.NO_CACHE
            rec = None

        self.closure_cache_key = cache_key

        if cache_key is _cache_key.NO_CACHE:
            # uncacheable: run the lambda directly without analysis
            rec = NonAnalyzedFunction(self._invoke_user_fn(fn))
        else:
            built_here = False
            if rec is None:
                with AnalyzedCode._generation_mutex:
                    key = tracker_key + cache_key
                    if key not in lambda_cache:
                        rec = AnalyzedFunction(
                            tracker, self, apply_propagate_attrs, fn
                        )
                        rec.closure_bindparams = list(bindparams)
                        lambda_cache[key] = rec
                        built_here = True
                    else:
                        # another thread stored the record after our first
                        # lookup; treat it exactly like a cache hit
                        rec = lambda_cache[key]

            if not built_here:
                # the cached expression contains the record's original
                # binds; carry the current closure values over to them
                bindparams[:] = [
                    orig_bind._with_value(new_bind.value, maintain_key=True)
                    for orig_bind, new_bind in zip(
                        rec.closure_bindparams, bindparams
                    )
                ]

        self._rec = rec

        if cache_key is not _cache_key.NO_CACHE:
            if self.parent_lambda is not None:
                bindparams[:0] = self.parent_lambda._resolved_bindparams

            lambda_element: Optional[LambdaElement] = self
            while lambda_element is not None:
                # use a separate name so ``rec`` keeps referring to this
                # element's own record
                element_rec = lambda_element._rec
                if element_rec.bindparam_trackers:
                    tracker_instrumented_fn = (
                        element_rec.tracker_instrumented_fn  # type: ignore [union-attr] # noqa: E501
                    )
                    for bind_tracker in element_rec.bindparam_trackers:
                        bind_tracker(
                            lambda_element.fn,
                            tracker_instrumented_fn,
                            bindparams,
                        )
                lambda_element = lambda_element.parent_lambda

        return rec

    def __getattr__(self, key):
        # attributes not found normally are read from the resolved element
        return getattr(self._resolved, key)

    @property
    def _is_sequence(self):
        return self._rec.is_sequence

    @property
    def _select_iterable(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._select_iterable for element in self._resolved]
            )

        else:
            return self._resolved._select_iterable

    @property
    def _from_objects(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._from_objects for element in self._resolved]
            )

        else:
            return self._resolved._from_objects

    def _param_dict(self):
        return {b.key: b.value for b in self._resolved_bindparams}

    def _setup_binds_for_tracked_expr(self, expr):
        """Replace bind parameters in ``expr`` whose key matches one of
        ``self._resolved_bindparams`` with that resolved bind.

        The resolved bind takes over the ``expanding``/``expand_op``
        setting and the type of the bind it replaces.

        """
        bindparam_lookup = {b.key: b for b in self._resolved_bindparams}

        def replace(
            element: Optional[visitors.ExternallyTraversible], **kw: Any
        ) -> Optional[visitors.ExternallyTraversible]:
            if isinstance(element, elements.BindParameter):
                if element.key in bindparam_lookup:
                    bind = bindparam_lookup[element.key]
                    if element.expanding:
                        bind.expanding = True
                        bind.expand_op = element.expand_op
                        bind.type = element.type
                    return bind

            return None

        if self._rec.is_sequence:
            expr = [
                visitors.replacement_traverse(sub_expr, {}, replace)
                for sub_expr in expr
            ]
        elif getattr(expr, "is_clause_element", False):
            expr = visitors.replacement_traverse(expr, {}, replace)

        return expr

    def _copy_internals(
        self,
        clone: _CloneCallableType = _clone,
        deferred_copy_internals: Optional[_CloneCallableType] = None,
        **kw: Any,
    ) -> None:
        # TODO: this needs A LOT of tests
        self._resolved = clone(
            self._resolved,
            deferred_copy_internals=deferred_copy_internals,
            **kw,
        )

    @util.memoized_property
    def _resolved(self):
        # the record's expression, with this lambda's current bind values
        # substituted in when there are any
        expr = self._rec.expected_expr

        if self._resolved_bindparams:
            expr = self._setup_binds_for_tracked_expr(expr)

        return expr

    def _gen_cache_key(self, anon_map, bindparams):
        """Build the cache key from the code object and closure key of this
        lambda and of each parent; returns None and flags ``anon_map`` as
        uncacheable when this lambda has no closure cache key.

        """
        if self.closure_cache_key is _cache_key.NO_CACHE:
            anon_map[_cache_key.NO_CACHE] = True
            return None

        cache_key = (
            self.fn.__code__,
            self.__class__,
        ) + self.closure_cache_key

        parent = self.parent_lambda

        while parent is not None:
            assert parent.closure_cache_key is not CacheConst.NO_CACHE
            parent_closure_cache_key: Tuple[Any, ...] = (
                parent.closure_cache_key
            )

            cache_key = (
                (parent.fn.__code__,) + parent_closure_cache_key + cache_key
            )

            parent = parent.parent_lambda

        if self._resolved_bindparams:
            bindparams.extend(self._resolved_bindparams)
        return cache_key

    def _invoke_user_fn(self, fn: _AnyLambdaType, *arg: Any) -> ClauseElement:
        return fn()  # type: ignore[no-any-return]
424
425
class DeferredLambdaElement(LambdaElement):
    """A :class:`.LambdaElement` whose lambda takes arguments and is
    called during compilation with context-specific values.

    Outside of compilation it does not normally produce its real SQL
    expression.  During construction it is called with a fixed set of
    initial arguments so that a sample expression can be generated.

    """

    def __init__(
        self,
        fn: _AnyLambdaType,
        role: Type[roles.SQLRole],
        opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
        lambda_args: Tuple[Any, ...] = (),
    ):
        # must exist before the base constructor runs, since the initial
        # analysis calls _invoke_user_fn(), which reads it
        self.lambda_args = lambda_args
        super().__init__(fn, role, opts)

    def _invoke_user_fn(self, fn, *arg):
        return fn(*self.lambda_args)

    def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement:
        """Produce the expression for the given arguments.

        Calls the record's instrumented function, coerces the result to
        ``self.role``, substitutes the tracked bind parameters and then
        applies any transforms recorded by ``_copy_internals``.

        """
        assert isinstance(self._rec, AnalyzedFunction)
        instrumented = self._rec.tracker_instrumented_fn
        produced = instrumented(*lambda_args)

        coerced = coercions.expect(self.role, produced)
        result = self._setup_binds_for_tracked_expr(coerced)

        # NOTE: a check that this call produces the same required bound
        # parameters as the original run (#5767) was prototyped here.  It
        # gave false positives when the base lambda uses an unnamed column,
        # which is common with mixins, because the parameter name differs.
        # No validation is therefore performed.

        # TODO: TEST TEST TEST, this is very out there
        for transform in self._transforms:
            result = transform(result)

        return result  # type: ignore

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # remaining keyword arguments reach the base as one "opts" keyword
        super()._copy_internals(
            clone=clone,
            deferred_copy_internals=deferred_copy_internals,
            opts=kw,
        )

        # TODO: A LOT A LOT of tests.  The final expression only exists once
        # _resolve_with_args() runs, so keep the transform to apply there.
        if deferred_copy_internals:
            self._transforms = self._transforms + (deferred_copy_internals,)
498
499
class StatementLambdaElement(
    roles.AllowsLambdaRole, LambdaElement, Executable
):
    """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.

    The :class:`_sql.StatementLambdaElement` is constructed using the
    :func:`_sql.lambda_stmt` function::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: select(table))

    Once constructed, additional criteria can be built onto the statement
    by adding subsequent lambdas, which accept the existing statement
    object as a single parameter::

        stmt += lambda s: s.where(table.c.col == parameter)

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    if TYPE_CHECKING:

        def __init__(
            self,
            fn: _StmtLambdaType,
            role: Type[SQLRole],
            opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
            apply_propagate_attrs: Optional[ClauseElement] = None,
        ): ...

    def __add__(
        self, other: _StmtLambdaElementType[Any]
    ) -> StatementLambdaElement:
        """``stmt + fn`` is shorthand for ``stmt.add_criteria(fn)``."""
        return self.add_criteria(other)

    def add_criteria(
        self,
        other: _StmtLambdaElementType[Any],
        enable_tracking: bool = True,
        track_on: Optional[Any] = None,
        track_closure_variables: bool = True,
        track_bound_values: bool = True,
    ) -> StatementLambdaElement:
        """Add new criteria to this :class:`_sql.StatementLambdaElement`.

        The given callable receives the statement produced so far as its
        single argument, and returns the new statement.

        E.g.::

            >>> def my_stmt(parameter):
            ...     stmt = lambda_stmt(
            ...         lambda: select(table.c.x, table.c.y),
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(table.c.x > parameter)
            ...     )
            ...     return stmt

        The :meth:`_sql.StatementLambdaElement.add_criteria` method is
        equivalent to using the Python addition operator to add a new
        lambda, except that additional arguments may be added including
        ``track_closure_variables`` and ``track_on``::

            >>> def my_stmt(self, foo):
            ...     stmt = lambda_stmt(
            ...         lambda: select(func.max(foo.x, foo.y)),
            ...         track_closure_variables=False,
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(self.where_criteria), track_on=[self]
            ...     )
            ...     return stmt

        See :func:`_sql.lambda_stmt` for a description of the parameters
        accepted.  ``global_track_bound_values`` is carried over from this
        element's options.

        :return: a new :class:`.LinkedLambdaElement` whose parent is this
         element.

        """  # noqa: E501

        opts = self.opts + dict(
            enable_tracking=enable_tracking,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=self.opts.global_track_bound_values,
            track_on=track_on,
            track_bound_values=track_bound_values,
        )

        return LinkedLambdaElement(other, parent_lambda=self, opts=opts)

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        if TYPE_CHECKING:
            assert isinstance(self._rec.expected_expr, ClauseElement)
        if self._rec.expected_expr.supports_execution:
            return connection._execute_clauseelement(
                self, distilled_params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)

    @property
    def _proxied(self) -> Any:
        # The statement held by this lambda's record.  This previously read
        # ``self._rec_expected_expr``, an attribute defined nowhere.  That
        # lookup fell through __getattr__ to the resolved statement and
        # raised AttributeError, breaking every property below.
        return self._rec.expected_expr

    @property
    def _with_options(self):  # type: ignore[override]
        return self._proxied._with_options

    @property
    def _effective_plugin_target(self):
        return self._proxied._effective_plugin_target

    @property
    def _execution_options(self):  # type: ignore[override]
        return self._proxied._execution_options

    @property
    def _all_selected_columns(self):
        return self._proxied._all_selected_columns

    @property
    def is_select(self):  # type: ignore[override]
        return self._proxied.is_select

    @property
    def is_update(self):  # type: ignore[override]
        return self._proxied.is_update

    @property
    def is_insert(self):  # type: ignore[override]
        return self._proxied.is_insert

    @property
    def is_text(self):  # type: ignore[override]
        return self._proxied.is_text

    @property
    def is_delete(self):  # type: ignore[override]
        return self._proxied.is_delete

    @property
    def is_dml(self):  # type: ignore[override]
        return self._proxied.is_dml

    def spoil(self) -> NullLambdaStatement:
        """Return a new :class:`.NullLambdaStatement` that will run
        all lambdas unconditionally each time.

        The new statement wraps the result of calling this element's
        lambda with no arguments.

        """
        return NullLambdaStatement(self.fn())
651
652
class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
    """Stand-in for :class:`.StatementLambdaElement` that neither caches
    nor analyzes lambdas.

    Every lambda is called right away against the current statement.  It
    exists to help isolate problems that may arise when using lambda
    statements.

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    def __init__(self, statement):
        self._resolved = statement
        self._propagate_attrs = statement._propagate_attrs

    def __getattr__(self, key):
        # only consulted for attributes not found normally; defer to the
        # wrapped statement
        wrapped = self._resolved
        return getattr(wrapped, key)

    def add_criteria(self, other, **kw):
        # tracking keywords are accepted for API parity and ignored
        return NullLambdaStatement(other(self._resolved))

    def __add__(self, other):
        return self.add_criteria(other)

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        if not self._resolved.supports_execution:
            raise exc.ObjectNotExecutableError(self)
        return connection._execute_clauseelement(
            self, distilled_params, execution_options
        )
698
699
class LinkedLambdaElement(StatementLambdaElement):
    """A follow-on link in a chain of :class:`.StatementLambdaElement`
    objects, created by ``+`` or
    :meth:`.StatementLambdaElement.add_criteria`.

    """

    parent_lambda: StatementLambdaElement

    def __init__(
        self,
        fn: _StmtLambdaElementType[Any],
        parent_lambda: StatementLambdaElement,
        opts: Union[Type[LambdaOptions], LambdaOptions],
    ):
        # everything read by _retrieve_tracker_rec() is assigned first
        self.parent_lambda = parent_lambda
        self.fn = fn
        self.opts = opts
        # identify the link by every code object from the chain's root
        self.tracker_key = (*parent_lambda.tracker_key, fn.__code__)

        self._retrieve_tracker_rec(fn, self, opts)
        self._propagate_attrs = parent_lambda._propagate_attrs

    def _invoke_user_fn(self, fn, *arg):
        # each link is handed the statement resolved by its parent
        parent_statement = self.parent_lambda._resolved
        return fn(parent_statement)
721
722
class AnalyzedCode:
    """Analysis of a lambda's ``__code__`` object, computed once per code
    object and shared by every lambda created from it.

    Three collections are populated:

    * ``closure_trackers`` - callables ``(closure, opts, anon_map,
      bindparams)`` whose results form the lambda's closure cache key.
    * ``bindparam_trackers`` - callables ``(current_fn,
      tracker_instrumented_fn, result)`` that append bound parameter
      values taken from literal globals / closure variables.
    * ``build_py_wrappers`` - ``(name, closure_index)`` pairs naming the
      literal globals (``closure_index`` is ``None``) and closure variables
      that :class:`.AnalyzedFunction` consults when instrumenting the
      lambda.

    """

    __slots__ = (
        "track_closure_variables",
        "track_bound_values",
        "bindparam_trackers",
        "closure_trackers",
        "build_py_wrappers",
    )
    _fns: weakref.WeakKeyDictionary[CodeType, AnalyzedCode] = (
        weakref.WeakKeyDictionary()
    )

    # guards creation of entries in _fns; LambdaElement also holds it while
    # populating the lambda cache
    _generation_mutex = threading.RLock()

    @classmethod
    def get(cls, fn, lambda_element, lambda_kw, **kw):
        """Return the :class:`.AnalyzedCode` for ``fn.__code__``, creating
        it on first use.

        Entries are keyed by code object in a weak-keyed dictionary.  The
        options of the first lambda seen for a code object determine the
        analysis; later calls do not re-check them.

        """
        try:
            # TODO: validate kw haven't changed?
            return cls._fns[fn.__code__]
        except KeyError:
            pass

        with cls._generation_mutex:
            # check for other thread already created object
            if fn.__code__ in cls._fns:
                return cls._fns[fn.__code__]

            analyzed: AnalyzedCode
            cls._fns[fn.__code__] = analyzed = AnalyzedCode(
                fn, lambda_element, lambda_kw, **kw
            )
            return analyzed

    def __init__(self, fn, lambda_element, opts):
        """Scan ``fn`` according to ``opts`` and build the trackers.

        :raises exc.ArgumentError: if ``fn`` is a bound method.

        """
        if inspect.ismethod(fn):
            raise exc.ArgumentError(
                "Method %s may not be passed as a SQL expression" % fn
            )
        closure = fn.__closure__

        self.track_bound_values = (
            opts.track_bound_values and opts.global_track_bound_values
        )
        enable_tracking = opts.enable_tracking
        track_on = opts.track_on
        track_closure_variables = opts.track_closure_variables

        # an explicit track_on replaces closure variable tracking
        self.track_closure_variables = track_closure_variables and not track_on

        # a list of callables generated from _bound_parameter_getter_*
        # functions. Each of these uses a PyWrapper object to retrieve
        # a parameter value
        self.bindparam_trackers = []

        # a list of callables generated from _cache_key_getter_* functions
        # these callables work to generate a cache key for the lambda
        # based on what's inside its closure variables.
        self.closure_trackers = []

        self.build_py_wrappers = []

        if enable_tracking:
            if track_on:
                self._init_track_on(track_on)

            self._init_globals(fn)

            if closure:
                self._init_closure(fn)

        self._setup_additional_closure_trackers(fn, lambda_element, opts)

    def _init_track_on(self, track_on):
        """Add one cache key getter per element of ``track_on``."""
        self.closure_trackers.extend(
            self._cache_key_getter_track_on(idx, elem)
            for idx, elem in enumerate(track_on)
        )

    def _init_globals(self, fn):
        """Register globals referenced by ``fn`` that hold literal values."""
        build_py_wrappers = self.build_py_wrappers
        bindparam_trackers = self.bindparam_trackers
        track_bound_values = self.track_bound_values

        for name in fn.__code__.co_names:
            if name not in fn.__globals__:
                continue

            _bound_value = self._roll_down_to_literal(fn.__globals__[name])

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((name, None))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_globals(name)
                    )

    def _init_closure(self, fn):
        """Classify each closure cell of ``fn`` as a literal (bound value)
        or a construct that contributes to the cache key.

        """
        build_py_wrappers = self.build_py_wrappers
        closure = fn.__closure__

        track_bound_values = self.track_bound_values
        track_closure_variables = self.track_closure_variables
        bindparam_trackers = self.bindparam_trackers
        closure_trackers = self.closure_trackers

        for closure_index, (fv, cell) in enumerate(
            zip(fn.__code__.co_freevars, closure)
        ):
            _bound_value = self._roll_down_to_literal(cell.cell_contents)

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((fv, closure_index))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_closure(
                            fv, closure_index
                        )
                    )
            else:
                # for normal cell contents, add them to a list that
                # we can compare later when we get new lambdas. if
                # any identities have changed, then we will
                # recalculate the whole lambda and run it again.

                if track_closure_variables:
                    closure_trackers.append(
                        self._cache_key_getter_closure_variable(
                            fn, fv, closure_index, cell.cell_contents
                        )
                    )

    def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
        # an additional step is to actually run the function, then
        # go through the PyWrapper objects that were set up to catch a bound
        # parameter. then if they *didn't* make a param, oh they're another
        # object in the closure we have to track for our cache key. so
        # create trackers to catch those.

        analyzed_function = AnalyzedFunction(
            self,
            lambda_element,
            None,
            fn,
        )

        closure_trackers = self.closure_trackers

        for pywrapper in analyzed_function.closure_pywrappers:
            if not pywrapper._sa__has_param:
                closure_trackers.append(
                    self._cache_key_getter_tracked_literal(fn, pywrapper)
                )

    @classmethod
    def _roll_down_to_literal(cls, element):
        """Reduce ``element`` to the value used for the literal check.

        Objects with ``__clause_element__`` are unwrapped until a
        ClauseElement, SchemaItem or class is reached.  Other objects are
        passed through ``inspection.inspect()`` when that succeeds.

        """
        is_clause_element = hasattr(element, "__clause_element__")

        if is_clause_element:
            while not isinstance(
                element, (elements.ClauseElement, schema.SchemaItem, type)
            ):
                try:
                    element = element.__clause_element__()
                except AttributeError:
                    break

        if not is_clause_element:
            insp = inspection.inspect(element, raiseerr=False)
            if insp is not None:
                try:
                    return insp.__clause_element__()
                except AttributeError:
                    return insp

            # TODO: should we coerce consts None/True/False here?
            return element
        else:
            return element

    def _bound_parameter_getter_func_globals(self, name):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__globals__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__globals__[name]
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__globals__[name], result
            )

        return extract_parameter_value

    def _bound_parameter_getter_func_closure(self, name, closure_index):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__closure__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__closure__[
                closure_index
            ].cell_contents
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__closure__[closure_index].cell_contents, result
            )

        return extract_parameter_value

    def _cache_key_getter_track_on(self, idx, elem):
        """Return a getter that will extend a cache key with new entries
        from the "track_on" parameter passed to a :class:`.LambdaElement`.

        """

        if isinstance(elem, tuple):
            # tuple must contain hascachekey elements
            def get(closure, opts, anon_map, bindparams):
                return tuple(
                    tup_elem._gen_cache_key(anon_map, bindparams)
                    for tup_elem in opts.track_on[idx]
                )

        elif isinstance(elem, _cache_key.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)

        else:

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]

        return get

    @staticmethod
    def _unwrap_clause_element(obj):
        """Follow ``__clause_element__()`` until reaching an object that
        reports ``is_clause_element`` or no longer has
        ``__clause_element__``, and return that object.

        Stopping once ``is_clause_element`` is true prevents an endless
        loop for objects that expose ``__clause_element__`` while already
        being clause elements.

        """
        while hasattr(obj, "__clause_element__") and not getattr(
            obj, "is_clause_element", False
        ):
            obj = obj.__clause_element__()
        return obj

    def _cache_key_getter_closure_variable(
        self,
        fn,
        variable_name,
        idx,
        cell_contents,
        use_clause_element=False,
        use_inspect=False,
    ):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        :raises exc.InvalidRequestError: if the closure variable cannot
         contribute to a cache key.

        """

        if isinstance(cell_contents, _cache_key.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                obj = closure[idx].cell_contents
                if use_inspect:
                    obj = inspection.inspect(obj)
                elif use_clause_element:
                    # previously this loop had no exit for objects that
                    # were already clause elements but still exposed
                    # __clause_element__, so it could spin forever
                    obj = self._unwrap_clause_element(obj)

                return obj._gen_cache_key(anon_map, bindparams)

        elif isinstance(cell_contents, types.FunctionType):

            def get(closure, opts, anon_map, bindparams):
                return closure[idx].cell_contents.__code__

        elif isinstance(cell_contents, collections_abc.Sequence):

            def get(closure, opts, anon_map, bindparams):
                contents = closure[idx].cell_contents

                try:
                    return tuple(
                        elem._gen_cache_key(anon_map, bindparams)
                        for elem in contents
                    )
                except AttributeError as ae:
                    self._raise_for_uncacheable_closure_variable(
                        variable_name, fn, from_=ae
                    )

        else:
            # if the object is a mapped class or aliased class, or some
            # other object in the ORM realm of things like that, imitate
            # the logic used in coercions.expect() to roll it down to the
            # SQL element
            is_clause_element = hasattr(cell_contents, "__clause_element__")
            element = self._unwrap_clause_element(cell_contents)

            if not is_clause_element:
                insp = inspection.inspect(element, raiseerr=False)
                if insp is not None:
                    return self._cache_key_getter_closure_variable(
                        fn, variable_name, idx, insp, use_inspect=True
                    )
            else:
                return self._cache_key_getter_closure_variable(
                    fn, variable_name, idx, element, use_clause_element=True
                )

            self._raise_for_uncacheable_closure_variable(variable_name, fn)

        return get

    def _raise_for_uncacheable_closure_variable(
        self, variable_name, fn, from_=None
    ):
        raise exc.InvalidRequestError(
            "Closure variable named '%s' inside of lambda callable %s "
            "does not refer to a cacheable SQL element, and also does not "
            "appear to be serving as a SQL literal bound value based on "
            "the default "
            "SQL expression returned by the function. This variable "
            "needs to remain outside the scope of a SQL-generating lambda "
            "so that a proper cache key may be generated from the "
            "lambda's state. Evaluate this variable outside of the "
            "lambda, set track_on=[<elements>] to explicitly select "
            "closure elements to track, or set "
            "track_closure_variables=False to exclude "
            "closure variables from being part of the cache key."
            % (variable_name, fn.__code__),
        ) from from_

    def _cache_key_getter_tracked_literal(self, fn, pytracker):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        this getter differs from _cache_key_getter_closure_variable
        in that these are detected after the function is run, and PyWrapper
        objects have recorded that a particular literal value is in fact
        not being interpreted as a bound parameter.

        """

        elem = pytracker._sa__to_evaluate
        closure_index = pytracker._sa__closure_index
        variable_name = pytracker._sa__name

        return self._cache_key_getter_closure_variable(
            fn, variable_name, closure_index, elem
        )
1077
1078
class NonAnalyzedFunction:
    """Lightweight counterpart to :class:`.AnalyzedFunction` that simply
    wraps an already-built expression.

    Only ``expr`` is stored per instance.  The other attributes are
    class-level constants: no closure bound parameters, no bound parameter
    trackers, and the expression is never treated as a sequence.

    """

    __slots__ = ("expr",)

    # class-level defaults shared by every instance
    is_sequence = False
    bindparam_trackers: Optional[List[_BoundParameterGetter]] = None
    closure_bindparams: Optional[List[BindParameter[Any]]] = None

    expr: ClauseElement

    def __init__(self, expr: ClauseElement):
        self.expr = expr

    @property
    def expected_expr(self) -> ClauseElement:
        """The wrapped expression, returned unchanged (no coercion)."""
        wrapped = self.expr
        return wrapped
1095
1096
class AnalyzedFunction:
    """The outcome of running a lambda once, with its literals instrumented.

    Instances are stored as values of the lambda cache
    (``_LambdaCacheType``).  Construction invokes the user's lambda, where
    closure/global values named by the analysis have been replaced with
    :class:`.PyWrapper` objects.  It then coerces the resulting expression
    per the lambda element's role into ``expected_expr``.

    """

    __slots__ = (
        "analyzed_code",
        "fn",
        "closure_pywrappers",
        "tracker_instrumented_fn",
        "expr",
        "bindparam_trackers",
        "expected_expr",
        "is_sequence",
        "propagate_attrs",
        "closure_bindparams",
    )

    closure_bindparams: Optional[List[BindParameter[Any]]]
    expected_expr: Union[ClauseElement, List[ClauseElement]]
    bindparam_trackers: Optional[List[_BoundParameterGetter]]

    def __init__(
        self,
        analyzed_code,
        lambda_element,
        apply_propagate_attrs,
        fn,
    ):
        """
        :param analyzed_code: analysis object for ``fn.__code__``
         (presumably an ``AnalyzedCode``).  It supplies
         ``bindparam_trackers``, ``build_py_wrappers``,
         ``track_closure_variables`` and ``track_bound_values``.
        :param lambda_element: the lambda element.  It supplies
         ``_invoke_user_fn``, ``parent_lambda`` and ``role``.
        :param apply_propagate_attrs: optional element whose
         ``_propagate_attrs`` are passed to coercion and recorded in
         ``propagate_attrs``.
        :param fn: the user's lambda callable.

        """
        self.analyzed_code = analyzed_code
        self.fn = fn

        self.bindparam_trackers = analyzed_code.bindparam_trackers

        self._instrument_and_run_function(lambda_element)

        self._coerce_expression(lambda_element, apply_propagate_attrs)

    def _instrument_and_run_function(self, lambda_element):
        """Invoke the user's lambda, with literal values wrapped if needed.

        If ``analyzed_code.build_py_wrappers`` is empty, the lambda is
        invoked as-is.  Otherwise a copy of the function is built in which
        each closure cell / global named in ``build_py_wrappers`` is
        replaced by a :class:`.PyWrapper`, and that copy is invoked.

        Sets ``closure_pywrappers``, ``tracker_instrumented_fn`` and
        ``expr``.

        """
        analyzed_code = self.analyzed_code
        fn = self.fn
        self.closure_pywrappers = closure_pywrappers = []

        build_py_wrappers = analyzed_code.build_py_wrappers

        if not build_py_wrappers:
            self.tracker_instrumented_fn = fn
            self.expr = lambda_element._invoke_user_fn(fn)
            return

        track_closure_variables = analyzed_code.track_closure_variables
        track_bound_values = analyzed_code.track_bound_values
        closure = fn.__closure__
        freevars = fn.__code__.co_freevars

        # will form the __closure__ of the function when we rebuild it
        if closure:
            new_closure = {
                fv: cell.cell_contents for fv, cell in zip(freevars, closure)
            }
        else:
            new_closure = {}

        # will form the __globals__ of the function when we rebuild it
        new_globals = fn.__globals__.copy()

        for name, closure_index in build_py_wrappers:
            if closure_index is not None:
                bind = PyWrapper(
                    fn,
                    name,
                    closure[closure_index].cell_contents,
                    closure_index=closure_index,
                    track_bound_values=track_bound_values,
                )
                new_closure[name] = bind
                if track_closure_variables:
                    closure_pywrappers.append(bind)
            else:
                # globals honor track_bound_values the same way closure
                # variables do.  Otherwise calling a global function inside
                # the lambda raises even when the lambda was created with
                # track_bound_values=False.
                new_globals[name] = PyWrapper(
                    fn,
                    name,
                    fn.__globals__[name],
                    track_bound_values=track_bound_values,
                )

        # rewrite the original fn. things that look like they will
        # become bound parameters are wrapped in a PyWrapper.
        self.tracker_instrumented_fn = tracker_instrumented_fn = (
            self._rewrite_code_obj(
                fn,
                [new_closure[name] for name in freevars],
                new_globals,
            )
        )

        # now invoke the function. This will give us a new SQL
        # expression, but all the places that there would be a bound
        # parameter, the PyWrapper in its place will give us a bind
        # with a predictable name we can match up later.

        # additionally, each PyWrapper will log that it did in fact
        # create a parameter, otherwise, it's some kind of Python
        # object in the closure and we want to track that, to make
        # sure it doesn't change to something else, or if it does,
        # that we create a different tracked function with that
        # variable.
        self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)

    def _coerce_expression(self, lambda_element, apply_propagate_attrs):
        """Run the tracker-generated expression through coercion rules.

        After the user-defined lambda has been invoked to produce a statement
        for reuse, run it through coercion rules to both check that it's the
        correct type of object and also to coerce it to its useful form.

        Top-level lambdas (no ``parent_lambda``) are coerced according to
        ``lambda_element.role``; a sequence result is coerced element by
        element and ``is_sequence`` is set.  Nested lambdas keep ``expr``
        as-is.  Sets ``expected_expr``, ``is_sequence`` and
        ``propagate_attrs``.

        """

        parent_lambda = lambda_element.parent_lambda
        expr = self.expr

        if parent_lambda is None:
            if isinstance(expr, collections_abc.Sequence):
                self.expected_expr = [
                    cast(
                        "ClauseElement",
                        coercions.expect(
                            lambda_element.role,
                            sub_expr,
                            apply_propagate_attrs=apply_propagate_attrs,
                        ),
                    )
                    for sub_expr in expr
                ]
                self.is_sequence = True
            else:
                self.expected_expr = cast(
                    "ClauseElement",
                    coercions.expect(
                        lambda_element.role,
                        expr,
                        apply_propagate_attrs=apply_propagate_attrs,
                    ),
                )
                self.is_sequence = False
        else:
            self.expected_expr = expr
            self.is_sequence = False

        if apply_propagate_attrs is not None:
            self.propagate_attrs = apply_propagate_attrs._propagate_attrs
        else:
            self.propagate_attrs = util.EMPTY_DICT

    def _rewrite_code_obj(self, f, cell_values, globals_):
        """Return a copy of f, with a new closure and new globals

        yes it works in pypy :P

        :param f: the function to copy; its ``__code__`` is reused as-is.
        :param cell_values: values for the new closure cells, in the same
         order as ``f.__code__.co_freevars``.
        :param globals_: mapping used as the new function's ``__globals__``.
        :return: a new function of the same type as ``f``.  Its closure is
         ``None`` when ``cell_values`` is empty.

        """

        argrange = range(len(cell_values))

        # Cell objects are produced portably by generating a helper whose
        # inner function closes over one local per value, then returning
        # that inner function's __closure__.
        code = "def make_cells():\n"
        if cell_values:
            code += "    (%s) = (%s)\n" % (
                ", ".join("i%d" % i for i in argrange),
                ", ".join("o%d" % i for i in argrange),
            )
        code += "    def closure():\n"
        code += "        return %s\n" % ", ".join("i%d" % i for i in argrange)
        code += "    return closure.__closure__"
        vars_ = {"o%d" % i: cell_values[i] for i in argrange}
        exec(code, vars_, vars_)
        closure = vars_["make_cells"]()

        func = type(f)(
            f.__code__, globals_, f.__name__, f.__defaults__, closure
        )
        func.__annotations__ = f.__annotations__
        func.__kwdefaults__ = f.__kwdefaults__
        func.__doc__ = f.__doc__
        func.__module__ = f.__module__

        return func
1275
1276
class PyWrapper(ColumnOperators):
    """A wrapper object that is injected into the ``__globals__`` and
    ``__closure__`` of a Python function.

    When the function is instrumented with :class:`.PyWrapper` objects, it is
    then invoked just once in order to set up the wrappers. We look through
    all the :class:`.PyWrapper` objects we made to find the ones that generated
    a :class:`.BindParameter` object, e.g. the expression system interpreted
    something as a literal. Those positions in the globals/closure are then
    ones that we will look at, each time a new lambda comes in that refers to
    the same ``__code__`` object. In this way, we keep a single version of
    the SQL expression that this lambda produced, without calling upon the
    Python function that created it more than once, unless its other closure
    variables have changed. The expression is then transformed to have the
    new bound values embedded into it.

    :meth:`__getattribute__` is overridden so that attribute access on the
    wrapper is forwarded to, or tracked against, the wrapped object.  The
    wrapper's own state is therefore read with ``object.__getattribute__``
    or through the ``_sa_`` prefix: ``self._sa_fn`` returns the ``fn``
    attribute, ``self._sa__name`` returns ``_name``, and so on.

    """

    def __init__(
        self,
        fn,
        name,
        to_evaluate,
        closure_index=None,
        getter=None,
        track_bound_values=True,
    ):
        """
        :param fn: the lambda callable being analyzed.
        :param name: name this wrapper stands for; used as the name of the
         :class:`.BindParameter` it generates.
        :param to_evaluate: the actual Python value being wrapped.
        :param closure_index: index into ``fn.__closure__`` when wrapping a
         closure variable, else ``None``.
        :param getter: for child wrappers, the ``attrgetter`` /
         ``itemgetter`` that extracts this value from the parent's value.
        :param track_bound_values: when true, :meth:`__call__` raises if
         the wrapped callable returns a literal value.

        """
        self.fn = fn
        self._name = name
        self._to_evaluate = to_evaluate
        self._param = None
        self._has_param = False
        self._bind_paths = {}
        self._getter = getter
        self._closure_index = closure_index
        self.track_bound_values = track_bound_values

    def __call__(self, *arg, **kw):
        """Call the wrapped callable and return its result.

        :raises InvalidRequestError: when bound value tracking is enabled
         and the result is a literal value that is not a
         :class:`.HasCacheKey`.  Bound values are normally extracted
         without invoking the lambda, so such a result would not be
         refreshed.

        """
        elem = object.__getattribute__(self, "_to_evaluate")
        value = elem(*arg, **kw)
        if (
            self._sa_track_bound_values
            and coercions._deep_is_literal(value)
            and not isinstance(
                # TODO: coverage where an ORM option or similar is here
                value,
                _cache_key.HasCacheKey,
            )
        ):
            name = object.__getattribute__(self, "_name")
            raise exc.InvalidRequestError(
                "Can't invoke Python callable %s() inside of lambda "
                "expression argument at %s; lambda SQL constructs should "
                "not invoke functions from closure variables to produce "
                "literal values since the "
                "lambda SQL system normally extracts bound values without "
                "actually "
                "invoking the lambda or any functions within it. Call the "
                "function outside of the "
                "lambda and assign to a local variable that is used in the "
                "lambda as a closure variable, or set "
                "track_bound_values=False if the return value of this "
                "function is used in some other way other than a SQL bound "
                "value." % (name, self._sa_fn.__code__)
            )
        else:
            return value

    def operate(self, op, *other, **kwargs):
        """Apply ``op`` with this wrapper's bound parameter on the left."""
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(elem, *other, **kwargs)

    def reverse_operate(self, op, other, **kwargs):
        """Apply ``op`` with this wrapper's bound parameter on the right."""
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(other, elem, **kwargs)

    def _extract_bound_parameters(self, starting_point, result_list):
        """Append bound parameters valued from ``starting_point`` to
        ``result_list``.

        This wrapper's own parameter, if one was generated, receives
        ``starting_point`` as its value.  Each child wrapper then recurses
        on the value its getter extracts from ``starting_point``.

        """
        param = object.__getattribute__(self, "_param")
        if param is not None:
            param = param._with_value(starting_point, maintain_key=True)
            result_list.append(param)
        for pywrapper in object.__getattribute__(self, "_bind_paths").values():
            getter = object.__getattribute__(pywrapper, "_getter")
            element = getter(starting_point)
            pywrapper._sa__extract_bound_parameters(element, result_list)

    def _py_wrapper_literal(self, expr=None, operator=None, **kw):
        """Return a :class:`.BindParameter` carrying the wrapped value.

        On first use, a unique, non-required parameter named after this
        wrapper is created and remembered, and ``_has_param`` is set.  Each
        call returns a copy of it holding the current value.

        """
        param = object.__getattribute__(self, "_param")
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        if param is None:
            name = object.__getattribute__(self, "_name")
            self._param = param = elements.BindParameter(
                name,
                required=False,
                unique=True,
                _compared_to_operator=operator,
                _compared_to_type=expr.type if expr is not None else None,
            )
            self._has_param = True
        return param._with_value(to_evaluate, maintain_key=True)

    def __bool__(self):
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    def __getattribute__(self, key):
        """Route attribute access.

        * ``_sa_<attr>``: this wrapper's own ``<attr>``.
        * a fixed set of protocol names: this wrapper's own attribute.
        * other dunder names: delegated to the wrapped object.
        * anything else: tracked through :meth:`_add_getter`.

        """
        if key.startswith("_sa_"):
            return object.__getattribute__(self, key[4:])
        elif key in (
            "__clause_element__",
            "operate",
            "reverse_operate",
            "_py_wrapper_literal",
            "__class__",
            "__dict__",
        ):
            return object.__getattribute__(self, key)

        if key.startswith("__"):
            elem = object.__getattribute__(self, "_to_evaluate")
            return getattr(elem, key)
        else:
            return self._sa__add_getter(key, operator.attrgetter)

    def __iter__(self):
        elem = object.__getattribute__(self, "_to_evaluate")
        return iter(elem)

    def __getitem__(self, key):
        """Track ``wrapper[key]`` through :meth:`_add_getter`.

        :raises AttributeError: if the wrapped object has no
         ``__getitem__``.
        :raises InvalidRequestError: if ``key`` is itself a PyWrapper.

        """
        elem = object.__getattribute__(self, "_to_evaluate")
        if not hasattr(elem, "__getitem__"):
            raise AttributeError("__getitem__")

        if isinstance(key, PyWrapper):
            # TODO: coverage
            raise exc.InvalidRequestError(
                "Dictionary keys / list indexes inside of a cached "
                "lambda must be Python literals only"
            )
        return self._sa__add_getter(key, operator.itemgetter)

    def _add_getter(self, key, getter_fn):
        """Return a child wrapper, or the raw value, for
        ``getter_fn(key)(wrapped)``.

        If the extracted value rolls down to a literal, it is wrapped in a
        child :class:`.PyWrapper` that is memoized in ``_bind_paths``
        under ``(key, getter_fn)``.  Otherwise the value is returned as-is.

        """
        bind_paths = object.__getattribute__(self, "_bind_paths")

        bind_path_key = (key, getter_fn)
        if bind_path_key in bind_paths:
            return bind_paths[bind_path_key]

        getter = getter_fn(key)
        elem = object.__getattribute__(self, "_to_evaluate")
        value = getter(elem)

        rolled_down_value = AnalyzedCode._roll_down_to_literal(value)

        if coercions._deep_is_literal(rolled_down_value):
            # the child inherits this wrapper's track_bound_values setting.
            # Otherwise e.g. ``obj.method()`` raises in __call__ even when
            # the lambda was created with track_bound_values=False.
            wrapper = PyWrapper(
                self._sa_fn,
                key,
                value,
                getter=getter,
                track_bound_values=self._sa_track_bound_values,
            )
            bind_paths[bind_path_key] = wrapper
            return wrapper
        else:
            return value
1437
1438
@inspection._inspects(LambdaElement)
def insp(lmb):
    """Inspection hook for :class:`.LambdaElement`.

    Inspects the element's resolved form (``lmb._resolved``) rather than
    the lambda wrapper itself.

    """
    resolved_element = lmb._resolved
    return inspection.inspect(resolved_element)