1# sql/lambdas.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import collections.abc as collections_abc
12import inspect
13import itertools
14import operator
15import threading
16import types
17from types import CodeType
18from typing import Any
19from typing import Callable
20from typing import cast
21from typing import List
22from typing import MutableMapping
23from typing import Optional
24from typing import Tuple
25from typing import Type
26from typing import TYPE_CHECKING
27from typing import TypeVar
28from typing import Union
29import weakref
30
31from . import cache_key as _cache_key
32from . import coercions
33from . import elements
34from . import roles
35from . import schema
36from . import visitors
37from .base import _clone
38from .base import Executable
39from .base import Options
40from .cache_key import CacheConst
41from .operators import ColumnOperators
42from .. import exc
43from .. import inspection
44from .. import util
45from ..util.typing import Literal
46
47
48if TYPE_CHECKING:
49 from .elements import BindParameter
50 from .elements import ClauseElement
51 from .roles import SQLRole
52 from .visitors import _CloneCallableType
53
# Mapping used to cache per-lambda analysis records.  Keys are the
# lambda's tracker key (its code object(s)) concatenated with its closure
# cache key; values are the analysis records built in
# LambdaElement._retrieve_tracker_rec.
_LambdaCacheType = MutableMapping[
    Tuple[Any, ...], Union["NonAnalyzedFunction", "AnalyzedFunction"]
]
# Callable that appends bound parameter values extracted from a lambda to
# a result list (see AnalyzedCode._bound_parameter_getter_func_globals /
# _bound_parameter_getter_func_closure).
_BoundParameterGetter = Callable[..., Any]

# Default cache used when LambdaOptions.lambda_cache is None.
_closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000)


# A lambda invoked with no arguments.
_LambdaType = Callable[[], Any]

# A lambda of any signature (e.g. DeferredLambdaElement lambdas take
# arguments).
_AnyLambdaType = Callable[..., Any]

# The zero-argument lambda passed to lambda_stmt().
_StmtLambdaType = Callable[[], Any]

_E = TypeVar("_E", bound=Executable)
# A lambda that receives the statement built so far and returns a new one;
# used for links added via StatementLambdaElement.add_criteria().
_StmtLambdaElementType = Callable[[_E], Any]
70
71
class LambdaOptions(Options):
    """Options controlling how a lambda is analyzed and cached.

    Each attribute corresponds to the keyword argument of the same name
    accepted by :func:`_sql.lambda_stmt`.

    """

    # when False, no track_on / globals / closure scanning is performed
    enable_tracking = True
    # when False, closure variables do not contribute to the cache key
    track_closure_variables = True
    # optional sequence of objects tracked for the cache key; when set,
    # closure variable tracking is turned off
    track_on: Optional[object] = None
    # when False, bound parameter tracking is off for the whole statement,
    # including links added via add_criteria()
    global_track_bound_values = True
    # when False, bound parameter tracking is off for this lambda only
    track_bound_values = True
    # cache mapping to use; None selects the module-level default cache
    lambda_cache: Optional[_LambdaCacheType] = None
79
80
def lambda_stmt(
    lmb: _StmtLambdaType,
    enable_tracking: bool = True,
    track_closure_variables: bool = True,
    track_on: Optional[object] = None,
    global_track_bound_values: bool = True,
    track_bound_values: bool = True,
    lambda_cache: Optional[_LambdaCacheType] = None,
) -> StatementLambdaElement:
    """Produce a SQL statement that is cached as a lambda.

    The code object of the given lambda is scanned both for Python literals,
    which become bound parameters, and for closure variables referring to
    Core or ORM constructs that may vary.  The lambda itself is only invoked
    once for each particular combination of tracked constructs.

    E.g.::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: table.select())
        stmt += lambda s: s.where(table.c.id == 5)

        result = connection.execute(stmt)

    An instance of :class:`_sql.StatementLambdaElement` is returned.

    .. versionadded:: 1.4

    :param lmb: a Python function, typically a lambda, which takes no
     arguments and returns a SQL expression construct
    :param enable_tracking: when False, scanning of the lambda for changes in
     closure variables or bound parameters is disabled entirely.  Use for a
     lambda that produces identical results in all cases with no
     parameterization.
    :param track_closure_variables: when False, changes in the lambda's
     closure variables are not scanned.  Use for a lambda whose closure
     variables never change the structure of the SQL it returns.
    :param track_on: an optional sequence of objects that are tracked for
     the cache key instead of the lambda's closure variables.  Elements that
     are cache-key-generating SQL constructs contribute their own cache key,
     tuples of such constructs contribute a tuple of keys, and any other
     object is used in the cache key as-is.  When given, closure variable
     tracking is disabled.
    :param track_bound_values: when False, bound parameter tracking is
     disabled for this lambda.  Use for a lambda that either produces no
     bound values, or whose initial bound values never change.
    :param global_track_bound_values: when False, bound parameter tracking
     is disabled for the entire statement, including links added later via
     :meth:`_sql.StatementLambdaElement.add_criteria`.
    :param lambda_cache: a dictionary or other mapping-like object in which
     information about the lambda's Python code, as well as the tracked
     closure variables of the lambda itself, is stored.  Defaults to a
     global LRU cache.  This cache is independent of the "compiled_cache"
     used by the :class:`_engine.Connection` object.

    .. seealso::

        :ref:`engine_lambda_caching`


    """

    options = LambdaOptions(
        enable_tracking=enable_tracking,
        track_on=track_on,
        track_closure_variables=track_closure_variables,
        global_track_bound_values=global_track_bound_values,
        track_bound_values=track_bound_values,
        lambda_cache=lambda_cache,
    )
    return StatementLambdaElement(lmb, roles.StatementRole, options)
152
153
class LambdaElement(elements.ClauseElement):
    """A SQL construct where the state is stored as an un-invoked lambda.

    The :class:`_sql.LambdaElement` is produced transparently whenever
    passing lambda expressions into SQL constructs, such as::

        stmt = select(table).where(lambda: table.c.col == parameter)

    The :class:`_sql.LambdaElement` is the base of the
    :class:`_sql.StatementLambdaElement` which represents a full statement
    within a lambda.

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    # traversal visits the resolved (bind-substituted) expression
    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    # deferred copy-internals callables recorded by
    # DeferredLambdaElement._copy_internals
    _transforms: Tuple[_CloneCallableType, ...] = ()

    # bound parameters for this element (and, when cacheable, its parents)
    # carrying the current values extracted from the lambda
    _resolved_bindparams: List[BindParameter[Any]]
    # the previous link when this element was added via add_criteria()
    parent_lambda: Optional[StatementLambdaElement] = None
    # cache key built from the closure trackers, or NO_CACHE
    closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]
    # the SQL role the lambda's result is coerced to
    role: Type[SQLRole]
    # the analysis record for this lambda
    _rec: Union[AnalyzedFunction, NonAnalyzedFunction]
    # the user-supplied lambda
    fn: _AnyLambdaType
    # code object(s) identifying this lambda (and its parent links)
    tracker_key: Tuple[CodeType, ...]

    def __repr__(self):
        return "%s(%r)" % (
            self.__class__.__name__,
            self.fn.__code__,
        )

    def __init__(
        self,
        fn: _LambdaType,
        role: Type[SQLRole],
        opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
        apply_propagate_attrs: Optional[ClauseElement] = None,
    ):
        """Set up the element and retrieve (or build) its analysis record.

        :param fn: the user lambda.  It may not be invoked here at all when
          a cached analysis exists for its code object and closure state.
        :param role: the SQL role for the lambda's result.
        :param opts: a :class:`.LambdaOptions` class or instance.
        :param apply_propagate_attrs: element that receives the record's
          ``propagate_attrs``.  Defaults to ``self`` when ``role`` is
          :class:`.roles.StatementRole`.
        """
        self.fn = fn
        self.role = role
        self.tracker_key = (fn.__code__,)
        self.opts = opts

        if apply_propagate_attrs is None and (role is roles.StatementRole):
            apply_propagate_attrs = self

        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)

        if apply_propagate_attrs is not None:
            # NOTE(review): on the NO_CACHE path ``rec`` is a
            # NonAnalyzedFunction, whose visible definition has no
            # ``propagate_attrs`` attribute (__slots__ is ("expr",)).
            # Confirm this access cannot raise AttributeError for
            # uncacheable statement lambdas.
            propagate_attrs = rec.propagate_attrs
            if propagate_attrs:
                apply_propagate_attrs._propagate_attrs = propagate_attrs

    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
        """Compute the closure cache key and look up or create the analysis
        record for this lambda.

        The steps are:

        * Compute the closure cache key from the per-code-object
          :class:`.AnalyzedCode` trackers, prefixed by the parent's key.
        * On a cache hit, rebind the cached closure bind parameters with
          the newly extracted values.
        * On a miss with a cacheable key, build and store an
          :class:`.AnalyzedFunction` under the generation mutex.
        * For an uncacheable key, invoke the lambda directly into a
          :class:`.NonAnalyzedFunction`.

        Sets ``self.closure_cache_key``, ``self._resolved_bindparams`` and
        ``self._rec``.

        NOTE(review): ``rec`` is rebound inside the bindparam-tracking loop
        below.  When that loop runs, the returned value is therefore the
        record of the root-most lambda in the parent chain, not
        ``self._rec``.  ``__init__`` uses the return value for a lambda with
        no parent; ``LinkedLambdaElement`` ignores it.  Confirm this is
        intended.
        """
        lambda_cache = opts.lambda_cache
        if lambda_cache is None:
            lambda_cache = _closure_per_cache_key

        tracker_key = self.tracker_key

        fn = self.fn
        closure = fn.__closure__
        # analysis shared by all lambdas with the same code object
        tracker = AnalyzedCode.get(
            fn,
            self,
            opts,
        )

        bindparams: List[BindParameter[Any]]
        self._resolved_bindparams = bindparams = []

        if self.parent_lambda is not None:
            parent_closure_cache_key = self.parent_lambda.closure_cache_key
        else:
            parent_closure_cache_key = ()

        cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]

        if parent_closure_cache_key is not _cache_key.NO_CACHE:
            anon_map = visitors.anon_map()
            # each getter yields one element of the key.  Getters may also
            # append bind parameters to ``bindparams``, or mark NO_CACHE in
            # anon_map.
            cache_key = tuple(
                [
                    getter(closure, opts, anon_map, bindparams)
                    for getter in tracker.closure_trackers
                ]
            )

            if _cache_key.NO_CACHE not in anon_map:
                cache_key = parent_closure_cache_key + cache_key

                self.closure_cache_key = cache_key

                rec = lambda_cache.get(tracker_key + cache_key)
            else:
                cache_key = _cache_key.NO_CACHE
                rec = None

        else:
            # an uncacheable parent makes this link uncacheable too
            cache_key = _cache_key.NO_CACHE
            rec = None

        self.closure_cache_key = cache_key

        if rec is None:
            if cache_key is not _cache_key.NO_CACHE:
                with AnalyzedCode._generation_mutex:
                    key = tracker_key + cache_key
                    # re-check under the mutex in case another thread
                    # stored the record in the meantime
                    if key not in lambda_cache:
                        rec = AnalyzedFunction(
                            tracker, self, apply_propagate_attrs, fn
                        )
                        # snapshot; ``bindparams`` is extended further below
                        rec.closure_bindparams = list(bindparams)
                        lambda_cache[key] = rec
                    else:
                        rec = lambda_cache[key]
            else:
                rec = NonAnalyzedFunction(self._invoke_user_fn(fn))

        else:
            # cache hit: keep the cached bind parameters' keys but take the
            # values just extracted from this lambda's closure
            bindparams[:] = [
                orig_bind._with_value(new_bind.value, maintain_key=True)
                for orig_bind, new_bind in zip(
                    rec.closure_bindparams, bindparams
                )
            ]

        self._rec = rec

        if cache_key is not _cache_key.NO_CACHE:
            if self.parent_lambda is not None:
                bindparams[:0] = self.parent_lambda._resolved_bindparams

            # walk self and each parent link, appending the current bound
            # values extracted by their bindparam trackers
            lambda_element: Optional[LambdaElement] = self
            while lambda_element is not None:
                rec = lambda_element._rec
                if rec.bindparam_trackers:
                    tracker_instrumented_fn = (
                        rec.tracker_instrumented_fn  # type:ignore [union-attr] # noqa: E501
                    )
                    for tracker in rec.bindparam_trackers:
                        tracker(
                            lambda_element.fn,
                            tracker_instrumented_fn,
                            bindparams,
                        )
                lambda_element = lambda_element.parent_lambda

        return rec

    def __getattr__(self, key):
        # attributes not found on the element are looked up on the
        # analyzed expression
        return getattr(self._rec.expected_expr, key)

    @property
    def _is_sequence(self):
        return self._rec.is_sequence

    @property
    def _select_iterable(self):
        # a sequence-producing lambda chains the iterables of each element
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._select_iterable for element in self._resolved]
            )

        else:
            return self._resolved._select_iterable

    @property
    def _from_objects(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._from_objects for element in self._resolved]
            )

        else:
            return self._resolved._from_objects

    def _param_dict(self):
        """Return ``{key: value}`` for the resolved bound parameters."""
        return {b.key: b.value for b in self._resolved_bindparams}

    def _setup_binds_for_tracked_expr(self, expr):
        """Return ``expr`` with each BindParameter whose key matches a
        resolved bind parameter replaced by that resolved parameter.

        The resolved parameter is updated in place with the
        ``expanding`` / ``expand_op`` / ``type`` of the element it replaces.
        A list of expressions is handled element-wise when the record is a
        sequence; objects that are not clause elements are returned as-is.
        """
        bindparam_lookup = {b.key: b for b in self._resolved_bindparams}

        def replace(
            element: Optional[visitors.ExternallyTraversible], **kw: Any
        ) -> Optional[visitors.ExternallyTraversible]:
            if isinstance(element, elements.BindParameter):
                if element.key in bindparam_lookup:
                    bind = bindparam_lookup[element.key]
                    if element.expanding:
                        bind.expanding = True
                        bind.expand_op = element.expand_op
                        bind.type = element.type
                    return bind

            # None means "no replacement" to replacement_traverse
            return None

        if self._rec.is_sequence:
            expr = [
                visitors.replacement_traverse(sub_expr, {}, replace)
                for sub_expr in expr
            ]
        elif getattr(expr, "is_clause_element", False):
            expr = visitors.replacement_traverse(expr, {}, replace)

        return expr

    def _copy_internals(
        self,
        clone: _CloneCallableType = _clone,
        deferred_copy_internals: Optional[_CloneCallableType] = None,
        **kw: Any,
    ) -> None:
        # TODO: this needs A LOT of tests
        self._resolved = clone(
            self._resolved,
            deferred_copy_internals=deferred_copy_internals,
            **kw,
        )

    @util.memoized_property
    def _resolved(self):
        # the analyzed expression with this element's current bind values
        expr = self._rec.expected_expr

        if self._resolved_bindparams:
            expr = self._setup_binds_for_tracked_expr(expr)

        return expr

    def _gen_cache_key(self, anon_map, bindparams):
        """Generate this element's cache key.

        The key combines the lambda's code object, its class and its closure
        cache key, prefixed with the code object and closure cache key of
        each parent link.  The resolved bound parameters are appended to
        ``bindparams``.  When the closure cache key is NO_CACHE, NO_CACHE is
        flagged in ``anon_map`` and None is returned.
        """
        if self.closure_cache_key is _cache_key.NO_CACHE:
            anon_map[_cache_key.NO_CACHE] = True
            return None

        cache_key = (
            self.fn.__code__,
            self.__class__,
        ) + self.closure_cache_key

        parent = self.parent_lambda

        while parent is not None:
            assert parent.closure_cache_key is not CacheConst.NO_CACHE
            parent_closure_cache_key: Tuple[Any, ...] = (
                parent.closure_cache_key
            )

            cache_key = (
                (parent.fn.__code__,) + parent_closure_cache_key + cache_key
            )

            parent = parent.parent_lambda

        if self._resolved_bindparams:
            bindparams.extend(self._resolved_bindparams)
        return cache_key

    def _invoke_user_fn(self, fn: _AnyLambdaType, *arg: Any) -> ClauseElement:
        # base lambdas take no arguments; subclasses override this to pass
        # lambda args or the parent statement
        return fn()  # type: ignore[no-any-return]
425
426
class DeferredLambdaElement(LambdaElement):
    """A LambdaElement where the lambda accepts arguments and is
    invoked within the compile phase with special context.

    This lambda doesn't normally produce its real SQL expression outside of the
    compile phase.  It is passed a fixed set of initial arguments
    so that it can generate a sample expression.

    """

    def __init__(
        self,
        fn: _AnyLambdaType,
        role: Type[roles.SQLRole],
        opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
        lambda_args: Tuple[Any, ...] = (),
    ):
        """:param lambda_args: the fixed arguments passed to ``fn`` to
        produce the sample expression during analysis.
        """
        # must be set before super().__init__(), which may invoke the
        # lambda through _invoke_user_fn
        self.lambda_args = lambda_args
        super().__init__(fn, role, opts)

    def _invoke_user_fn(self, fn, *arg):
        return fn(*self.lambda_args)

    def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement:
        """Produce the real expression for the given arguments.

        The instrumented function is invoked with ``lambda_args``.  Its
        result is coerced to ``self.role`` and its bound parameters are
        replaced with the resolved ones.  Any transforms recorded by
        :meth:`._copy_internals` are then applied in order.
        """
        assert isinstance(self._rec, AnalyzedFunction)
        tracker_fn = self._rec.tracker_instrumented_fn
        expr = tracker_fn(*lambda_args)

        expr = coercions.expect(self.role, expr)

        expr = self._setup_binds_for_tracked_expr(expr)

        # this validation is getting very close, but not quite, to achieving
        # #5767.  The problem is if the base lambda uses an unnamed column
        # as is very common with mixins, the parameter name is different
        # and it produces a false positive; that is, for the documented case
        # that is exactly what people will be doing, it doesn't work, so
        # I'm not really sure how to handle this right now.
        # expected_binds = [
        #    b._orig_key
        #    for b in self._rec.expr._generate_cache_key()[1]
        #    if b.required
        # ]
        # got_binds = [
        #    b._orig_key for b in expr._generate_cache_key()[1] if b.required
        # ]
        # if expected_binds != got_binds:
        #    raise exc.InvalidRequestError(
        #        "Lambda callable at %s produced a different set of bound "
        #        "parameters than its original run: %s"
        #        % (self.fn.__code__, ", ".join(got_binds))
        #    )

        # TODO: TEST TEST TEST, this is very out there
        for deferred_copy_internals in self._transforms:
            expr = deferred_copy_internals(expr)

        return expr  # type: ignore

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # NOTE(review): ``kw`` is forwarded as a single keyword argument
        # named ``opts`` rather than unpacked (see the trailing "# **kw"
        # remark); confirm the clone callable expects this.
        super()._copy_internals(
            clone=clone,
            deferred_copy_internals=deferred_copy_internals,  # **kw
            opts=kw,
        )

        # TODO: A LOT A LOT of tests.   for _resolve_with_args, we don't know
        # our expression yet.   so hold onto the replacement
        if deferred_copy_internals:
            self._transforms += (deferred_copy_internals,)
499
500
class StatementLambdaElement(
    roles.AllowsLambdaRole, LambdaElement, Executable
):
    """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.

    The :class:`_sql.StatementLambdaElement` is constructed using the
    :func:`_sql.lambda_stmt` function::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: select(table))

    Once constructed, additional criteria can be built onto the statement
    by adding subsequent lambdas, which accept the existing statement
    object as a single parameter::

        stmt += lambda s: s.where(table.c.col == parameter)

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    if TYPE_CHECKING:

        def __init__(
            self,
            fn: _StmtLambdaType,
            role: Type[SQLRole],
            opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
            apply_propagate_attrs: Optional[ClauseElement] = None,
        ): ...

    def __add__(
        self, other: _StmtLambdaElementType[Any]
    ) -> StatementLambdaElement:
        return self.add_criteria(other)

    def add_criteria(
        self,
        other: _StmtLambdaElementType[Any],
        enable_tracking: bool = True,
        track_on: Optional[Any] = None,
        track_closure_variables: bool = True,
        track_bound_values: bool = True,
    ) -> StatementLambdaElement:
        """Add new criteria to this :class:`_sql.StatementLambdaElement`.

        The given callable receives the statement built so far as its single
        argument and returns the new statement.

        E.g.::

            >>> def my_stmt(parameter):
            ...     stmt = lambda_stmt(
            ...         lambda: select(table.c.x, table.c.y),
            ...     )
            ...     stmt = stmt.add_criteria(lambda s: s.where(table.c.x > parameter))
            ...     return stmt

        The :meth:`_sql.StatementLambdaElement.add_criteria` method is
        equivalent to using the Python addition operator to add a new
        lambda, except that additional arguments may be added including
        ``track_closure_variables`` and ``track_on``::

            >>> def my_stmt(self, foo):
            ...     stmt = lambda_stmt(
            ...         lambda: select(func.max(foo.x, foo.y)),
            ...         track_closure_variables=False,
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(self.where_criteria), track_on=[self]
            ...     )
            ...     return stmt

        See :func:`_sql.lambda_stmt` for a description of the parameters
        accepted.  ``global_track_bound_values`` is always inherited from
        this statement.

        """  # noqa: E501

        opts = self.opts + dict(
            enable_tracking=enable_tracking,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=self.opts.global_track_bound_values,
            track_on=track_on,
            track_bound_values=track_bound_values,
        )

        return LinkedLambdaElement(other, parent_lambda=self, opts=opts)

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        """Execute this element through the connection when the analyzed
        statement supports execution; otherwise raise
        :class:`.exc.ObjectNotExecutableError`.
        """
        if TYPE_CHECKING:
            assert isinstance(self._rec.expected_expr, ClauseElement)
        if self._rec.expected_expr.supports_execution:
            return connection._execute_clauseelement(
                self, distilled_params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)

    # The properties below proxy to the analyzed statement, so that the
    # lambda element can be executed and inspected like the statement.

    @property
    def _proxied(self) -> Any:
        # the analyzed statement; ``_rec`` holds the analysis record
        return self._rec.expected_expr

    @property
    def _with_options(self):  # type: ignore[override]
        return self._proxied._with_options

    @property
    def _effective_plugin_target(self):
        return self._proxied._effective_plugin_target

    @property
    def _execution_options(self):  # type: ignore[override]
        return self._proxied._execution_options

    @property
    def _all_selected_columns(self):
        return self._proxied._all_selected_columns

    @property
    def is_select(self):  # type: ignore[override]
        return self._proxied.is_select

    @property
    def is_update(self):  # type: ignore[override]
        return self._proxied.is_update

    @property
    def is_insert(self):  # type: ignore[override]
        return self._proxied.is_insert

    @property
    def is_text(self):  # type: ignore[override]
        return self._proxied.is_text

    @property
    def is_delete(self):  # type: ignore[override]
        return self._proxied.is_delete

    @property
    def is_dml(self):  # type: ignore[override]
        return self._proxied.is_dml

    def spoil(self) -> NullLambdaStatement:
        """Return a new :class:`.NullLambdaStatement` that will run
        all lambdas unconditionally each time.

        For a statement built from linked lambdas (via ``+`` or
        :meth:`.add_criteria`), the whole chain is spoiled.  The parent
        statement is spoiled first, and this element's lambda is then
        applied to it.

        """
        if self.parent_lambda is not None:
            # a linked lambda takes the statement built so far as its
            # argument; calling self.fn() would raise TypeError and also
            # discard the parent links
            return self.parent_lambda.spoil() + self.fn
        return NullLambdaStatement(self.fn())
652
653
class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
    """Offer the :class:`.StatementLambdaElement` API with no caching or
    analysis of lambdas.

    Every lambda handed to this object is called right away against the
    statement it wraps, and the result is wrapped in a new
    :class:`.NullLambdaStatement`.

    Meant as a diagnostic aid, to isolate problems that may come from
    lambda statement caching.

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    def __init__(self, statement):
        self._resolved = statement
        self._propagate_attrs = statement._propagate_attrs

    def __getattr__(self, key):
        # anything not defined here comes from the wrapped statement
        return getattr(self._resolved, key)

    def __add__(self, other):
        return self.add_criteria(other)

    def add_criteria(self, other, **kw):
        # tracking keywords are accepted for API parity with
        # StatementLambdaElement.add_criteria() and ignored
        return NullLambdaStatement(other(self._resolved))

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        if not self._resolved.supports_execution:
            raise exc.ObjectNotExecutableError(self)
        return connection._execute_clauseelement(
            self, distilled_params, execution_options
        )
699
700
class LinkedLambdaElement(StatementLambdaElement):
    """Represent subsequent links of a :class:`.StatementLambdaElement`."""

    parent_lambda: StatementLambdaElement

    def __init__(
        self,
        fn: _StmtLambdaElementType[Any],
        parent_lambda: StatementLambdaElement,
        opts: Union[Type[LambdaOptions], LambdaOptions],
    ):
        """:param fn: callable receiving the parent's resolved statement.
        :param parent_lambda: the previous link in the chain.
        :param opts: options for this link, as built by
          :meth:`.StatementLambdaElement.add_criteria`.
        """
        # opts, fn, parent_lambda and tracker_key must all be assigned
        # before _retrieve_tracker_rec(), which reads them from self
        self.opts = opts
        self.fn = fn
        self.parent_lambda = parent_lambda

        # this link is identified by the parent's code objects plus its own
        self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)
        self._retrieve_tracker_rec(fn, self, opts)
        self._propagate_attrs = parent_lambda._propagate_attrs

    def _invoke_user_fn(self, fn, *arg):
        # links are invoked with the statement produced by the parent
        return fn(self.parent_lambda._resolved)
722
723
class AnalyzedCode:
    """Analysis of a lambda's code object, shared by all lambdas that have
    the same ``__code__``.

    The analysis produces:

    * ``closure_trackers`` - getters that each yield one element of the
      closure cache key for a particular lambda instance.
    * ``bindparam_trackers`` - getters that append the bound parameter
      values of a particular lambda instance to a result list.
    * ``build_py_wrappers`` - ``(name, closure_index)`` pairs naming the
      globals (``closure_index`` of ``None``) and closure cells whose values
      look like SQL literals; consumed when the function is instrumented.

    """

    __slots__ = (
        "track_closure_variables",
        "track_bound_values",
        "bindparam_trackers",
        "closure_trackers",
        "build_py_wrappers",
    )
    # analyses keyed on the lambda's code object, held with weak keys
    _fns: weakref.WeakKeyDictionary[CodeType, AnalyzedCode] = (
        weakref.WeakKeyDictionary()
    )

    # guards creation of analyses and cached records; an RLock, so the
    # owning thread may re-acquire it
    _generation_mutex = threading.RLock()

    @classmethod
    def get(cls, fn, lambda_element, lambda_kw, **kw):
        """Return the :class:`.AnalyzedCode` for ``fn.__code__``, creating it
        under the generation mutex if not already present.

        ``lambda_kw`` is passed to the constructor as ``opts``.  The registry
        is keyed only on the code object, so the options given on first
        creation are the ones that take effect.
        """
        try:
            # TODO: validate kw haven't changed?
            return cls._fns[fn.__code__]
        except KeyError:
            pass

        with cls._generation_mutex:
            # check for other thread already created object
            if fn.__code__ in cls._fns:
                return cls._fns[fn.__code__]

            analyzed: AnalyzedCode
            cls._fns[fn.__code__] = analyzed = AnalyzedCode(
                fn, lambda_element, lambda_kw, **kw
            )
            return analyzed

    def __init__(self, fn, lambda_element, opts):
        """Analyze ``fn`` according to ``opts``.

        :raises exc.ArgumentError: if ``fn`` is a bound method.
        """
        if inspect.ismethod(fn):
            raise exc.ArgumentError(
                "Method %s may not be passed as a SQL expression" % fn
            )
        closure = fn.__closure__

        self.track_bound_values = (
            opts.track_bound_values and opts.global_track_bound_values
        )
        enable_tracking = opts.enable_tracking
        track_on = opts.track_on
        track_closure_variables = opts.track_closure_variables

        # explicit track_on replaces closure variable tracking
        self.track_closure_variables = track_closure_variables and not track_on

        # a list of callables generated from _bound_parameter_getter_*
        # functions.  Each of these uses a PyWrapper object to retrieve
        # a parameter value
        self.bindparam_trackers = []

        # a list of callables generated from _cache_key_getter_* functions
        # these callables work to generate a cache key for the lambda
        # based on what's inside its closure variables.
        self.closure_trackers = []

        self.build_py_wrappers = []

        if enable_tracking:
            if track_on:
                self._init_track_on(track_on)

            self._init_globals(fn)

            if closure:
                self._init_closure(fn)

        self._setup_additional_closure_trackers(fn, lambda_element, opts)

    def _init_track_on(self, track_on):
        """Add one cache key getter per element of ``track_on``."""
        self.closure_trackers.extend(
            self._cache_key_getter_track_on(idx, elem)
            for idx, elem in enumerate(track_on)
        )

    def _init_globals(self, fn):
        """Register globals referenced by ``fn`` whose values look like
        SQL literals, along with bound parameter getters for them when
        bound value tracking is enabled.
        """
        build_py_wrappers = self.build_py_wrappers
        bindparam_trackers = self.bindparam_trackers
        track_bound_values = self.track_bound_values

        for name in fn.__code__.co_names:
            # co_names also holds attribute names, builtins, etc.
            if name not in fn.__globals__:
                continue

            _bound_value = self._roll_down_to_literal(fn.__globals__[name])

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((name, None))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_globals(name)
                    )

    def _init_closure(self, fn):
        """Classify each closure cell of ``fn``.

        A cell holding a literal-like value gets a py wrapper and, when bound
        value tracking is enabled, a bound parameter getter.  Any other cell
        gets a cache key getter when closure variable tracking is enabled.
        """
        build_py_wrappers = self.build_py_wrappers
        closure = fn.__closure__

        track_bound_values = self.track_bound_values
        track_closure_variables = self.track_closure_variables
        bindparam_trackers = self.bindparam_trackers
        closure_trackers = self.closure_trackers

        for closure_index, (fv, cell) in enumerate(
            zip(fn.__code__.co_freevars, closure)
        ):
            _bound_value = self._roll_down_to_literal(cell.cell_contents)

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((fv, closure_index))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_closure(
                            fv, closure_index
                        )
                    )
            else:
                # for normal cell contents, add them to a list that
                # we can compare later when we get new lambdas.  if
                # any identities have changed, then we will
                # recalculate the whole lambda and run it again.

                if track_closure_variables:
                    closure_trackers.append(
                        self._cache_key_getter_closure_variable(
                            fn, fv, closure_index, cell.cell_contents
                        )
                    )

    def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
        # an additional step is to actually run the function, then
        # go through the PyWrapper objects that were set up to catch a bound
        # parameter.   then if they *didn't* make a param, oh they're another
        # object in the closure we have to track for our cache key.  so
        # create trackers to catch those.

        analyzed_function = AnalyzedFunction(
            self,
            lambda_element,
            None,
            fn,
        )

        closure_trackers = self.closure_trackers

        for pywrapper in analyzed_function.closure_pywrappers:
            if not pywrapper._sa__has_param:
                closure_trackers.append(
                    self._cache_key_getter_tracked_literal(fn, pywrapper)
                )

    @classmethod
    def _roll_down_to_literal(cls, element):
        """Reduce ``element`` to the object that is tested for literal-ness.

        An object with ``__clause_element__`` is rolled down until a
        ClauseElement, SchemaItem or type is reached.  Otherwise, if
        ``inspection.inspect`` yields something, that object (or its
        ``__clause_element__()``) is returned.  Anything else is returned
        unchanged.
        """
        is_clause_element = hasattr(element, "__clause_element__")

        if is_clause_element:
            while not isinstance(
                element, (elements.ClauseElement, schema.SchemaItem, type)
            ):
                try:
                    element = element.__clause_element__()
                except AttributeError:
                    break

        if not is_clause_element:
            insp = inspection.inspect(element, raiseerr=False)
            if insp is not None:
                try:
                    return insp.__clause_element__()
                except AttributeError:
                    return insp

            # TODO: should we coerce consts None/True/False here?
            return element
        else:
            return element

    def _bound_parameter_getter_func_globals(self, name):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__globals__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__globals__[name]
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__globals__[name], result
            )

        return extract_parameter_value

    def _bound_parameter_getter_func_closure(self, name, closure_index):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__closure__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__closure__[
                closure_index
            ].cell_contents
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__closure__[closure_index].cell_contents, result
            )

        return extract_parameter_value

    def _cache_key_getter_track_on(self, idx, elem):
        """Return a getter that will extend a cache key with new entries
        from the "track_on" parameter passed to a :class:`.LambdaElement`.

        A tuple yields a tuple of its elements' cache keys, a
        :class:`.HasCacheKey` yields its own cache key, and any other object
        is used as-is.  The getter reads ``opts.track_on[idx]`` at call time.

        """

        if isinstance(elem, tuple):
            # tuple must contain hascachekey elements
            def get(closure, opts, anon_map, bindparams):
                return tuple(
                    tup_elem._gen_cache_key(anon_map, bindparams)
                    for tup_elem in opts.track_on[idx]
                )

        elif isinstance(elem, _cache_key.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)

        else:

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]

        return get

    def _cache_key_getter_closure_variable(
        self,
        fn,
        variable_name,
        idx,
        cell_contents,
        use_clause_element=False,
        use_inspect=False,
    ):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        The kind of getter is chosen from ``cell_contents`` (the value seen
        at analysis time):

        * A :class:`.HasCacheKey` uses its cache key, optionally first
          inspected or rolled down through ``__clause_element__``.
        * A function uses its ``__code__``.
        * A sequence uses a tuple of its elements' cache keys.
        * Otherwise the value is rolled down or inspected, and this method
          recurses on the result.

        :raises exc.InvalidRequestError: if the value cannot take part in a
          cache key.

        """

        if isinstance(cell_contents, _cache_key.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                obj = closure[idx].cell_contents
                if use_inspect:
                    obj = inspection.inspect(obj)
                elif use_clause_element:
                    # stop at the first object reporting itself as a clause
                    # element, as the roll-down loop below does; without the
                    # break, an object exposing both __clause_element__ and
                    # is_clause_element=True would loop forever
                    while hasattr(obj, "__clause_element__"):
                        if not getattr(obj, "is_clause_element", False):
                            obj = obj.__clause_element__()
                        else:
                            break

                return obj._gen_cache_key(anon_map, bindparams)

        elif isinstance(cell_contents, types.FunctionType):

            def get(closure, opts, anon_map, bindparams):
                return closure[idx].cell_contents.__code__

        elif isinstance(cell_contents, collections_abc.Sequence):

            def get(closure, opts, anon_map, bindparams):
                contents = closure[idx].cell_contents

                try:
                    return tuple(
                        elem._gen_cache_key(anon_map, bindparams)
                        for elem in contents
                    )
                except AttributeError as ae:
                    self._raise_for_uncacheable_closure_variable(
                        variable_name, fn, from_=ae
                    )

        else:
            # if the object is a mapped class or aliased class, or some
            # other object in the ORM realm of things like that, imitate
            # the logic used in coercions.expect() to roll it down to the
            # SQL element
            element = cell_contents
            is_clause_element = False
            while hasattr(element, "__clause_element__"):
                is_clause_element = True
                if not getattr(element, "is_clause_element", False):
                    element = element.__clause_element__()
                else:
                    break

            if not is_clause_element:
                insp = inspection.inspect(element, raiseerr=False)
                if insp is not None:
                    return self._cache_key_getter_closure_variable(
                        fn, variable_name, idx, insp, use_inspect=True
                    )
            else:
                return self._cache_key_getter_closure_variable(
                    fn, variable_name, idx, element, use_clause_element=True
                )

            self._raise_for_uncacheable_closure_variable(variable_name, fn)

        return get

    def _raise_for_uncacheable_closure_variable(
        self, variable_name, fn, from_=None
    ):
        """Raise :class:`.exc.InvalidRequestError` for a closure variable
        that can be used neither in a cache key nor as a bound value.
        """
        raise exc.InvalidRequestError(
            "Closure variable named '%s' inside of lambda callable %s "
            "does not refer to a cacheable SQL element, and also does not "
            "appear to be serving as a SQL literal bound value based on "
            "the default "
            "SQL expression returned by the function.   This variable "
            "needs to remain outside the scope of a SQL-generating lambda "
            "so that a proper cache key may be generated from the "
            "lambda's state.  Evaluate this variable outside of the "
            "lambda, set track_on=[<elements>] to explicitly select "
            "closure elements to track, or set "
            "track_closure_variables=False to exclude "
            "closure variables from being part of the cache key."
            % (variable_name, fn.__code__),
        ) from from_

    def _cache_key_getter_tracked_literal(self, fn, pytracker):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        this getter differs from _cache_key_getter_closure_variable
        in that these are detected after the function is run, and PyWrapper
        objects have recorded that a particular literal value is in fact
        not being interpreted as a bound parameter.

        """

        elem = pytracker._sa__to_evaluate
        closure_index = pytracker._sa__closure_index
        variable_name = pytracker._sa__name

        return self._cache_key_getter_closure_variable(
            fn, variable_name, closure_index, elem
        )
1078
1079
class NonAnalyzedFunction:
    """Stand-in for an analysis result when no analysis was performed.

    Simply wraps an already-produced expression.  It never carries
    closure bind parameters or bind-parameter trackers.  It is never a
    sequence, and its "expected" expression is the wrapped expression
    itself.

    """

    __slots__ = ("expr",)

    expr: ClauseElement

    # class-level constants shared by every instance (not slots)
    is_sequence = False
    bindparam_trackers: Optional[List[_BoundParameterGetter]] = None
    closure_bindparams: Optional[List[BindParameter[Any]]] = None

    def __init__(self, expr: ClauseElement):
        self.expr = expr

    @property
    def expected_expr(self) -> ClauseElement:
        """The wrapped expression, returned as-is with no coercion."""
        wrapped = self.expr
        return wrapped
1096
1097
class AnalyzedFunction:
    """One instrumented run of a lambda, plus the expression it produced.

    Given the code analysis for a lambda's ``__code__``, this object:

    * rebuilds the lambda, substituting :class:`.PyWrapper` objects for the
      closure / global names listed in ``analyzed_code.build_py_wrappers``;
    * invokes the (possibly rebuilt) function once, storing ``expr``;
    * coerces ``expr`` into ``expected_expr`` according to the lambda
      element's role (or stores it as-is when there is a parent lambda).

    """

    __slots__ = (
        "analyzed_code",
        "fn",
        "closure_pywrappers",
        "tracker_instrumented_fn",
        "expr",
        "bindparam_trackers",
        "expected_expr",
        "is_sequence",
        "propagate_attrs",
        "closure_bindparams",
    )

    closure_bindparams: Optional[List[BindParameter[Any]]]
    expected_expr: Union[ClauseElement, List[ClauseElement]]
    bindparam_trackers: Optional[List[_BoundParameterGetter]]

    def __init__(
        self,
        analyzed_code,
        lambda_element,
        apply_propagate_attrs,
        fn,
    ):
        """Instrument ``fn``, run it once, and coerce the result.

        :param analyzed_code: analysis of ``fn.__code__``.  This class reads
         its ``bindparam_trackers``, ``build_py_wrappers``,
         ``track_closure_variables`` and ``track_bound_values`` attributes.
        :param lambda_element: the lambda element being resolved.  It
         supplies ``_invoke_user_fn()``, ``parent_lambda`` and ``role``.
        :param apply_propagate_attrs: object whose ``_propagate_attrs`` are
         recorded on this analysis and passed along to coercion, or
         ``None``.
        :param fn: the user-supplied Python callable.

        """
        self.analyzed_code = analyzed_code
        self.fn = fn

        self.bindparam_trackers = analyzed_code.bindparam_trackers

        self._instrument_and_run_function(lambda_element)

        self._coerce_expression(lambda_element, apply_propagate_attrs)

    def _instrument_and_run_function(self, lambda_element):
        """Invoke the user function, instrumented with PyWrappers if needed.

        Sets ``closure_pywrappers``, ``tracker_instrumented_fn`` and
        ``expr``.  When ``analyzed_code.build_py_wrappers`` is empty, the
        original function is invoked unchanged.

        """
        analyzed_code = self.analyzed_code

        fn = self.fn
        self.closure_pywrappers = closure_pywrappers = []

        build_py_wrappers = analyzed_code.build_py_wrappers

        if not build_py_wrappers:
            self.tracker_instrumented_fn = tracker_instrumented_fn = fn
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
        else:
            track_closure_variables = analyzed_code.track_closure_variables
            closure = fn.__closure__

            # will form the __closure__ of the function when we rebuild it
            if closure:
                new_closure = {
                    fv: cell.cell_contents
                    for fv, cell in zip(fn.__code__.co_freevars, closure)
                }
            else:
                new_closure = {}

            # will form the __globals__ of the function when we rebuild it;
            # a copy, so the function's real globals are left untouched
            new_globals = fn.__globals__.copy()

            for name, closure_index in build_py_wrappers:
                if closure_index is not None:
                    # closure variable: wrap the current cell contents
                    value = closure[closure_index].cell_contents
                    new_closure[name] = bind = PyWrapper(
                        fn,
                        name,
                        value,
                        closure_index=closure_index,
                        track_bound_values=(analyzed_code.track_bound_values),
                    )
                    if track_closure_variables:
                        closure_pywrappers.append(bind)
                else:
                    # global name: wrap the value from fn.__globals__
                    value = fn.__globals__[name]
                    new_globals[name] = PyWrapper(fn, name, value)

            # rewrite the original fn. things that look like they will
            # become bound parameters are wrapped in a PyWrapper.
            self.tracker_instrumented_fn = tracker_instrumented_fn = (
                self._rewrite_code_obj(
                    fn,
                    [new_closure[name] for name in fn.__code__.co_freevars],
                    new_globals,
                )
            )

            # now invoke the function. This will give us a new SQL
            # expression, but all the places that there would be a bound
            # parameter, the PyWrapper in its place will give us a bind
            # with a predictable name we can match up later.

            # additionally, each PyWrapper will log that it did in fact
            # create a parameter, otherwise, it's some kind of Python
            # object in the closure and we want to track that, to make
            # sure it doesn't change to something else, or if it does,
            # that we create a different tracked function with that
            # variable.
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)

    def _coerce_expression(self, lambda_element, apply_propagate_attrs):
        """Run the tracker-generated expression through coercion rules.

        After the user-defined lambda has been invoked to produce a statement
        for re-use, run it through coercion rules to both check that it's the
        correct type of object and also to coerce it to its useful form.

        Sets ``expected_expr``, ``is_sequence`` and ``propagate_attrs``.

        """

        parent_lambda = lambda_element.parent_lambda
        expr = self.expr

        if parent_lambda is None:
            # str / bytes are registered collections.abc.Sequence types, but
            # a lambda returning a plain string produced ONE value, not a
            # list of one-character expressions; coerce it as a whole so
            # coercion sees (and reports on) the full string.
            if isinstance(expr, collections_abc.Sequence) and not isinstance(
                expr, (str, bytes)
            ):
                self.expected_expr = [
                    cast(
                        "ClauseElement",
                        coercions.expect(
                            lambda_element.role,
                            sub_expr,
                            apply_propagate_attrs=apply_propagate_attrs,
                        ),
                    )
                    for sub_expr in expr
                ]
                self.is_sequence = True
            else:
                self.expected_expr = cast(
                    "ClauseElement",
                    coercions.expect(
                        lambda_element.role,
                        expr,
                        apply_propagate_attrs=apply_propagate_attrs,
                    ),
                )
                self.is_sequence = False
        else:
            # with a parent lambda, the raw expression is stored uncoerced
            self.expected_expr = expr
            self.is_sequence = False

        if apply_propagate_attrs is not None:
            self.propagate_attrs = apply_propagate_attrs._propagate_attrs
        else:
            self.propagate_attrs = util.EMPTY_DICT

    def _rewrite_code_obj(self, f, cell_values, globals_):
        """Return a copy of ``f`` with a new closure and new globals.

        :param f: the function to copy.  Its code object, name, positional
         defaults, keyword defaults, annotations, docstring, module and
         qualified name are carried over.
        :param cell_values: the new closure contents, one per entry of
         ``f.__code__.co_freevars`` and in that same order.
        :param globals_: mapping to use as the copy's ``__globals__``.
        :return: a new function of the same type as ``f``.

        """

        def make_cell(value):
            # A one-variable closure has exactly one cell holding ``value``.
            # Cells are made one at a time so the resulting tuple lines up
            # positionally with ``cell_values``.  A single helper closing
            # over i0..iN would get its cells in co_freevars order, which
            # CPython sorts by name (i0, i1, i10, i11, i2, ...).  That
            # misplaces values once there are 11 or more of them.
            return (lambda: value).__closure__[0]

        closure = (
            tuple(make_cell(value) for value in cell_values)
            if cell_values
            else None
        )

        func = type(f)(
            f.__code__, globals_, f.__name__, f.__defaults__, closure
        )
        func.__annotations__ = f.__annotations__
        func.__kwdefaults__ = f.__kwdefaults__
        func.__doc__ = f.__doc__
        func.__module__ = f.__module__
        func.__qualname__ = f.__qualname__

        return func
1276
1277
class PyWrapper(ColumnOperators):
    """A wrapper object that is injected into the ``__globals__`` and
    ``__closure__`` of a Python function.

    When the function is instrumented with :class:`.PyWrapper` objects, it is
    then invoked just once in order to set up the wrappers. We look through
    all the :class:`.PyWrapper` objects we made to find the ones that generated
    a :class:`.BindParameter` object, i.e. the expression system interpreted
    something as a literal. Those positions in the globals/closure are then
    ones that we will look at, each time a new lambda comes in that refers to
    the same ``__code__`` object. In this way, we keep a single version of
    the SQL expression that this lambda produced, without calling upon the
    Python function that created it more than once, unless its other closure
    variables have changed. The expression is then transformed to have the
    new bound values embedded into it.

    Implementation note: ``__getattribute__`` is overridden.  Ordinary
    attribute reads on an instance are either proxied to the wrapped value
    (dunder names) or routed through ``_add_getter`` (other names).  Code
    inside this class therefore reads its own state with
    ``object.__getattribute__(self, name)`` or through the ``_sa_`` prefix
    (e.g. ``self._sa_fn`` reads the ``fn`` attribute).

    """

    def __init__(
        self,
        fn,
        name,
        to_evaluate,
        closure_index=None,
        getter=None,
        track_bound_values=True,
    ):
        # the lambda function this wrapper was created for
        self.fn = fn
        # closure/global variable name, or the attribute name / item key
        # for child wrappers created by _add_getter; also used as the
        # BindParameter name in _py_wrapper_literal
        self._name = name
        # the wrapped value
        self._to_evaluate = to_evaluate
        # BindParameter created lazily by _py_wrapper_literal
        self._param = None
        self._has_param = False
        # child wrappers keyed on (key, getter_fn); filled by _add_getter
        self._bind_paths = {}
        # for child wrappers: attrgetter/itemgetter that extracts this
        # wrapper's value from the parent's value (see
        # _extract_bound_parameters)
        self._getter = getter
        # index into fn.__closure__, or None (e.g. for globals)
        self._closure_index = closure_index
        self.track_bound_values = track_bound_values

    def __call__(self, *arg, **kw):
        """Call the wrapped object and return its result.

        If ``track_bound_values`` is set and the result looks like a literal
        (per ``coercions._deep_is_literal``) and is not a ``HasCacheKey``,
        raise :class:`.InvalidRequestError` instead.  As the error message
        explains, functions inside the lambda are normally not re-invoked,
        so a literal produced this way would not be re-extracted as a bound
        value.
        """
        elem = object.__getattribute__(self, "_to_evaluate")
        value = elem(*arg, **kw)
        if (
            self._sa_track_bound_values
            and coercions._deep_is_literal(value)
            and not isinstance(
                # TODO: coverage where an ORM option or similar is here
                value,
                _cache_key.HasCacheKey,
            )
        ):
            name = object.__getattribute__(self, "_name")
            raise exc.InvalidRequestError(
                "Can't invoke Python callable %s() inside of lambda "
                "expression argument at %s; lambda SQL constructs should "
                "not invoke functions from closure variables to produce "
                "literal values since the "
                "lambda SQL system normally extracts bound values without "
                "actually "
                "invoking the lambda or any functions within it. Call the "
                "function outside of the "
                "lambda and assign to a local variable that is used in the "
                "lambda as a closure variable, or set "
                "track_bound_values=False if the return value of this "
                "function is used in some other way other than a SQL bound "
                "value." % (name, self._sa_fn.__code__)
            )
        else:
            return value

    def operate(self, op, *other, **kwargs):
        """Apply ``op`` with this wrapper replaced by its bound parameter."""
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(elem, *other, **kwargs)

    def reverse_operate(self, op, other, **kwargs):
        """Reflected form of :meth:`operate` (wrapper on the right side)."""
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(other, elem, **kwargs)

    def _extract_bound_parameters(self, starting_point, result_list):
        """Append bound parameters for ``starting_point`` to ``result_list``.

        If this wrapper produced a parameter, a version of it carrying
        ``starting_point`` as its value is appended.  Then each child
        wrapper recurses on the value its stored getter extracts from
        ``starting_point``.  ``result_list`` is mutated in place; nothing
        is returned.
        """
        param = object.__getattribute__(self, "_param")
        if param is not None:
            param = param._with_value(starting_point, maintain_key=True)
            result_list.append(param)
        for pywrapper in object.__getattribute__(self, "_bind_paths").values():
            getter = object.__getattribute__(pywrapper, "_getter")
            element = getter(starting_point)
            pywrapper._sa__extract_bound_parameters(element, result_list)

    def _py_wrapper_literal(self, expr=None, operator=None, **kw):
        """Return a bound parameter carrying the current wrapped value.

        The underlying :class:`.BindParameter` (``unique=True``, named after
        ``_name``) is created on first use and cached on ``_param``.  If
        ``expr`` is given, its ``.type`` is recorded as the compared-to
        type.
        """
        # NOTE: the ``operator`` parameter shadows the ``operator`` module
        # within this method only.
        param = object.__getattribute__(self, "_param")
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        if param is None:
            name = object.__getattribute__(self, "_name")
            self._param = param = elements.BindParameter(
                name,
                required=False,
                unique=True,
                _compared_to_operator=operator,
                _compared_to_type=expr.type if expr is not None else None,
            )
            self._has_param = True
        return param._with_value(to_evaluate, maintain_key=True)

    def __bool__(self):
        """Truthiness of the wrapped value."""
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    def __getattribute__(self, key):
        """Route attribute access.

        * ``_sa_<name>``: the wrapper's own attribute ``<name>``.
        * a small set of names needed by the wrapper machinery: normal
          lookup on the wrapper.
        * other dunder names: looked up on the wrapped value.
        * anything else: ``_add_getter`` with ``operator.attrgetter``,
          which returns a tracked child wrapper or the plain value.
        """
        if key.startswith("_sa_"):
            return object.__getattribute__(self, key[4:])
        elif key in (
            "__clause_element__",
            "operate",
            "reverse_operate",
            "_py_wrapper_literal",
            "__class__",
            "__dict__",
        ):
            return object.__getattribute__(self, key)

        if key.startswith("__"):
            elem = object.__getattribute__(self, "_to_evaluate")
            return getattr(elem, key)
        else:
            return self._sa__add_getter(key, operator.attrgetter)

    def __iter__(self):
        """Iterate the wrapped value directly (no child tracking)."""
        elem = object.__getattribute__(self, "_to_evaluate")
        return iter(elem)

    def __getitem__(self, key):
        """Subscript via ``_add_getter`` using ``operator.itemgetter``.

        Raises ``AttributeError`` if the wrapped value has no
        ``__getitem__``, and :class:`.InvalidRequestError` if ``key`` is
        itself a PyWrapper.
        """
        elem = object.__getattribute__(self, "_to_evaluate")
        if not hasattr(elem, "__getitem__"):
            raise AttributeError("__getitem__")

        if isinstance(key, PyWrapper):
            # TODO: coverage
            raise exc.InvalidRequestError(
                "Dictionary keys / list indexes inside of a cached "
                "lambda must be Python literals only"
            )
        return self._sa__add_getter(key, operator.itemgetter)

    def _add_getter(self, key, getter_fn):
        """Return a child wrapper for ``getter_fn(key)`` of the wrapped value.

        Child wrappers are cached in ``_bind_paths`` keyed on
        ``(key, getter_fn)``.  A child wrapper is created only when the
        extracted value, after ``AnalyzedCode._roll_down_to_literal``, is a
        literal per ``coercions._deep_is_literal``.  Otherwise the extracted
        value itself is returned, untracked.
        """
        bind_paths = object.__getattribute__(self, "_bind_paths")

        # NOTE(review): used as a dict key, so ``key`` must be hashable;
        # e.g. slice objects are unhashable before Python 3.12 — confirm
        # whether slicing inside lambdas needs support.
        bind_path_key = (key, getter_fn)
        if bind_path_key in bind_paths:
            return bind_paths[bind_path_key]

        getter = getter_fn(key)
        elem = object.__getattribute__(self, "_to_evaluate")
        value = getter(elem)

        rolled_down_value = AnalyzedCode._roll_down_to_literal(value)

        if coercions._deep_is_literal(rolled_down_value):
            wrapper = PyWrapper(self._sa_fn, key, value, getter=getter)
            bind_paths[bind_path_key] = wrapper
            return wrapper
        else:
            return value
1438
1439
@inspection._inspects(LambdaElement)
def insp(lmb):
    """Inspection hook registered for :class:`.LambdaElement`.

    Delegates to :func:`.inspection.inspect` on the element's ``_resolved``
    attribute.
    """
    resolved_element = lmb._resolved
    return inspection.inspect(resolved_element)