1# sql/lambdas.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import collections.abc as collections_abc
12import inspect
13import itertools
14import operator
15import threading
16import types
17from types import CodeType
18from typing import Any
19from typing import Callable
20from typing import cast
21from typing import List
22from typing import MutableMapping
23from typing import Optional
24from typing import Tuple
25from typing import Type
26from typing import TYPE_CHECKING
27from typing import TypeVar
28from typing import Union
29import weakref
30
31from . import cache_key as _cache_key
32from . import coercions
33from . import elements
34from . import roles
35from . import schema
36from . import visitors
37from .base import _clone
38from .base import Executable
39from .base import Options
40from .cache_key import CacheConst
41from .operators import ColumnOperators
42from .. import exc
43from .. import inspection
44from .. import util
45from ..util.typing import Literal
46
47
48if TYPE_CHECKING:
49 from .elements import BindParameter
50 from .elements import ClauseElement
51 from .roles import SQLRole
52 from .visitors import _CloneCallableType
53
# Mapping that stores the analysis record of a lambda.  A key is the
# lambda's tracker key (a tuple of code objects) concatenated with its
# closure cache key.
_LambdaCacheType = MutableMapping[
    Tuple[Any, ...], Union["NonAnalyzedFunction", "AnalyzedFunction"]
]

# Callable that appends the current bound parameter values of a lambda to
# a result list; produced by AnalyzedCode._bound_parameter_getter_*().
_BoundParameterGetter = Callable[..., Any]

# Module-level default cache, used whenever LambdaOptions.lambda_cache
# is None.
_closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000)


# zero-argument lambda returning a SQL expression
_LambdaType = Callable[[], Any]

# lambda of arbitrary signature (e.g. deferred lambdas taking arguments)
_AnyLambdaType = Callable[..., Any]

# zero-argument lambda returning a complete statement
_StmtLambdaType = Callable[[], Any]

# lambda receiving the current statement and returning a new one, as
# accepted by StatementLambdaElement.add_criteria() and ``+``
_E = TypeVar("_E", bound=Executable)
_StmtLambdaElementType = Callable[[_E], Any]
70
71
class LambdaOptions(Options):
    """Tracking and caching options for a :class:`.LambdaElement`.

    Each attribute corresponds to the keyword argument of the same name
    accepted by :func:`.lambda_stmt`; see that function for details.

    """

    # master switch; when False the lambda's globals, closure cells and
    # ``track_on`` elements are not scanned
    enable_tracking = True

    # whether non-literal closure variables contribute to the cache key
    track_closure_variables = True

    # optional explicit sequence of objects to track in place of the
    # closure variables
    track_on: Optional[object] = None

    # bound-value tracking for every link of a statement chain
    global_track_bound_values = True

    # bound-value tracking for this particular lambda
    track_bound_values = True

    # where analysis records are stored; None selects the module-level
    # LRU cache
    lambda_cache: Optional[_LambdaCacheType] = None
79
80
def lambda_stmt(
    lmb: _StmtLambdaType,
    enable_tracking: bool = True,
    track_closure_variables: bool = True,
    track_on: Optional[object] = None,
    global_track_bound_values: bool = True,
    track_bound_values: bool = True,
    lambda_cache: Optional[_LambdaCacheType] = None,
) -> StatementLambdaElement:
    """Produce a SQL statement that is cached as a lambda.

    The Python code object within the lambda is scanned for both Python
    literals that will become bound parameters as well as closure variables
    that refer to Core or ORM constructs that may vary.  The lambda itself
    will be invoked only once per particular set of constructs detected.

    E.g.::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: table.select())
        stmt += lambda s: s.where(table.c.id == 5)

        result = connection.execute(stmt)

    The object returned is an instance of :class:`_sql.StatementLambdaElement`.

    .. versionadded:: 1.4

    :param lmb: a Python function, typically a lambda, which takes no arguments
     and returns a SQL expression construct
    :param enable_tracking: when False, all scanning of the given lambda for
     changes in closure variables or bound parameters is disabled.  Use for
     a lambda that produces the identical results in all cases with no
     parameterization.
    :param track_closure_variables: when False, changes in closure variables
     within the lambda will not be scanned.  Use for a lambda where the
     state of its closure variables will never change the SQL structure
     returned by the lambda.
    :param track_on: optional sequence of objects to track in place of the
     lambda's closure variables; when this is given, closure variables are
     not tracked.  Each element contributes to the lambda's cache key:
     elements that are :class:`.HasCacheKey` objects, or tuples of such
     objects, contribute their generated cache key, while any other object
     is placed into the key as-is (and so presumably must be hashable).
     See :meth:`_sql.StatementLambdaElement.add_criteria` for an example.
    :param global_track_bound_values: when False, bound parameter tracking
     will be disabled for the entire statement including additional links
     added via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
    :param track_bound_values: when False, bound parameter tracking will
     be disabled for the given lambda.  Use for a lambda that either does
     not produce any bound values, or where the initial bound values never
     change.
    :param lambda_cache: a dictionary or other mapping-like object where
     information about the lambda's Python code as well as the tracked closure
     variables in the lambda itself will be stored.  Defaults
     to a global LRU cache.  This cache is independent of the "compiled_cache"
     used by the :class:`_engine.Connection` object.

    .. seealso::

        :ref:`engine_lambda_caching`


    """

    return StatementLambdaElement(
        lmb,
        roles.StatementRole,
        LambdaOptions(
            enable_tracking=enable_tracking,
            track_on=track_on,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=global_track_bound_values,
            track_bound_values=track_bound_values,
            lambda_cache=lambda_cache,
        ),
    )
152
153
class LambdaElement(elements.ClauseElement):
    """A SQL construct where the state is stored as an un-invoked lambda.

    The :class:`_sql.LambdaElement` is produced transparently whenever
    passing lambda expressions into SQL constructs, such as::

        stmt = select(table).where(lambda: table.c.col == parameter)

    The :class:`_sql.LambdaElement` is the base of the
    :class:`_sql.StatementLambdaElement` which represents a full statement
    within a lambda.

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    # traversals see the resolved expression as this element's only child
    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    # deferred_copy_internals callables collected by
    # DeferredLambdaElement._copy_internals() and replayed, in order, by
    # DeferredLambdaElement._resolve_with_args()
    _transforms: Tuple[_CloneCallableType, ...] = ()

    # bound parameters for this particular invocation (parent lambdas'
    # parameters first); populated by _retrieve_tracker_rec()
    _resolved_bindparams: List[BindParameter[Any]]
    # the previous element when this lambda is a link of a statement chain
    parent_lambda: Optional[StatementLambdaElement] = None
    # key built from the closure trackers, prefixed with the parent's key,
    # or NO_CACHE when the lambda can't be cached
    closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]
    # SQL role for the lambda's result (DeferredLambdaElement coerces to it)
    role: Type[SQLRole]
    # analysis record; unknown attributes are delegated to its expected_expr
    _rec: Union[AnalyzedFunction, NonAnalyzedFunction]
    # the user-supplied lambda
    fn: _AnyLambdaType
    # code objects of this lambda and of every parent lambda; the leading
    # part of each lambda-cache key
    tracker_key: Tuple[CodeType, ...]

    def __repr__(self):
        # class name plus the repr of the lambda's code object
        return "%s(%r)" % (
            self.__class__.__name__,
            self.fn.__code__,
        )

    def __init__(
        self,
        fn: _LambdaType,
        role: Type[SQLRole],
        opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
        apply_propagate_attrs: Optional[ClauseElement] = None,
    ):
        """Analyze ``fn`` (or fetch its cached analysis) for ``role``.

        :param fn: the lambda; its ``__code__`` forms the tracker key.
        :param role: SQL role for the lambda's result.
        :param opts: tracking / caching options.
        :param apply_propagate_attrs: element whose ``_propagate_attrs`` is
         set from the analysis record's ``propagate_attrs`` when those are
         non-empty; defaults to this element itself when ``role`` is
         :class:`.roles.StatementRole`.
        """
        self.fn = fn
        self.role = role
        self.tracker_key = (fn.__code__,)
        self.opts = opts

        if apply_propagate_attrs is None and (role is roles.StatementRole):
            apply_propagate_attrs = self

        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)

        if apply_propagate_attrs is not None:
            propagate_attrs = rec.propagate_attrs
            if propagate_attrs:
                apply_propagate_attrs._propagate_attrs = propagate_attrs

    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
        """Locate or build the analysis record for this lambda.

        Computes ``self.closure_cache_key`` from the closure trackers of the
        lambda's :class:`.AnalyzedCode`, then uses it to:

        * fetch a cached :class:`.AnalyzedFunction`, re-valuing its cached
          bound parameters with the values just produced; or
        * build and cache a new :class:`.AnalyzedFunction`; or
        * when the key is NO_CACHE, invoke the lambda directly and wrap the
          result in an uncached :class:`.NonAnalyzedFunction`.

        Sets ``self._rec`` and ``self._resolved_bindparams``.  The ``fn``
        argument is immediately replaced by ``self.fn``.
        """
        lambda_cache = opts.lambda_cache
        if lambda_cache is None:
            # fall back to the module-level LRU cache
            lambda_cache = _closure_per_cache_key

        tracker_key = self.tracker_key

        fn = self.fn
        closure = fn.__closure__
        # analysis is shared by every lambda having the same __code__
        tracker = AnalyzedCode.get(
            fn,
            self,
            opts,
        )

        # the closure trackers below receive this list and may append bound
        # parameters to it while producing the cache key; it is mutated in
        # place for the rest of this method
        bindparams: List[BindParameter[Any]]
        self._resolved_bindparams = bindparams = []

        if self.parent_lambda is not None:
            parent_closure_cache_key = self.parent_lambda.closure_cache_key
        else:
            parent_closure_cache_key = ()

        cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]

        # an uncacheable parent makes this link uncacheable as well
        if parent_closure_cache_key is not _cache_key.NO_CACHE:
            anon_map = visitors.anon_map()
            cache_key = tuple(
                [
                    getter(closure, opts, anon_map, bindparams)
                    for getter in tracker.closure_trackers
                ]
            )

            # uncacheable elements report themselves by putting NO_CACHE
            # into the anon_map (as LambdaElement._gen_cache_key does)
            if _cache_key.NO_CACHE not in anon_map:
                cache_key = parent_closure_cache_key + cache_key

                self.closure_cache_key = cache_key

                rec = lambda_cache.get(tracker_key + cache_key)
            else:
                cache_key = _cache_key.NO_CACHE
                rec = None

        else:
            cache_key = _cache_key.NO_CACHE
            rec = None

        self.closure_cache_key = cache_key

        if rec is None:
            if cache_key is not _cache_key.NO_CACHE:
                # re-check under the mutex so that only one record is built
                # and stored per key when several threads miss at once
                with AnalyzedCode._generation_mutex:
                    key = tracker_key + cache_key
                    if key not in lambda_cache:
                        rec = AnalyzedFunction(
                            tracker, self, apply_propagate_attrs, fn
                        )
                        # snapshot of the closure-derived binds; later cache
                        # hits pair these with freshly produced ones below
                        rec.closure_bindparams = list(bindparams)
                        lambda_cache[key] = rec
                    else:
                        rec = lambda_cache[key]
            else:
                # uncacheable: run the lambda now; nothing is stored
                rec = NonAnalyzedFunction(self._invoke_user_fn(fn))

        else:
            # cache hit: keep the cached bind objects (and their keys) but
            # take the values the closure trackers just produced
            bindparams[:] = [
                orig_bind._with_value(new_bind.value, maintain_key=True)
                for orig_bind, new_bind in zip(
                    rec.closure_bindparams, bindparams
                )
            ]

        self._rec = rec

        if cache_key is not _cache_key.NO_CACHE:
            # the parent chain's bound parameters go first
            if self.parent_lambda is not None:
                bindparams[:0] = self.parent_lambda._resolved_bindparams

            # walk this element and each of its parents, letting every
            # record's bindparam trackers append the current literal values
            # found in that element's lambda
            lambda_element: Optional[LambdaElement] = self
            while lambda_element is not None:
                rec = lambda_element._rec
                if rec.bindparam_trackers:
                    tracker_instrumented_fn = rec.tracker_instrumented_fn
                    for tracker in rec.bindparam_trackers:
                        tracker(
                            lambda_element.fn,
                            tracker_instrumented_fn,
                            bindparams,
                        )
                lambda_element = lambda_element.parent_lambda

        # NOTE(review): ``rec`` (and ``tracker``) are rebound by the loop
        # above, so for a chained element the value returned is the record
        # of the root-most lambda rather than ``self._rec``.  The visible
        # callers are unaffected: __init__ only runs for elements without a
        # parent, and LinkedLambdaElement ignores the return value.
        return rec

    def __getattr__(self, key):
        # delegate unknown attributes to the record's expected expression.
        # NOTE: looking up a missing attribute before ``_rec`` is assigned
        # would recurse through this method, since ``_rec`` has no class
        # level default.
        return getattr(self._rec.expected_expr, key)

    @property
    def _is_sequence(self):
        # True when the lambda returned a sequence of expressions
        return self._rec.is_sequence

    @property
    def _select_iterable(self):
        # for a sequence, chain the _select_iterable of every member
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._select_iterable for element in self._resolved]
            )

        else:
            return self._resolved._select_iterable

    @property
    def _from_objects(self):
        # for a sequence, chain the _from_objects of every member
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._from_objects for element in self._resolved]
            )

        else:
            return self._resolved._from_objects

    def _param_dict(self):
        """Return a ``{key: value}`` dict of the resolved bound parameters."""
        return {b.key: b.value for b in self._resolved_bindparams}

    def _setup_binds_for_tracked_expr(self, expr):
        """Return ``expr`` with its bound parameters swapped for resolved ones.

        Every :class:`.BindParameter` in ``expr`` whose key matches one of
        ``self._resolved_bindparams`` is replaced by that resolved object.
        For an expanding parameter, its ``expanding`` flag, ``expand_op``
        and ``type`` are copied onto (i.e. mutate) the resolved object.
        A sequence is processed member by member; a value that is neither
        a sequence record nor a clause element is returned unchanged.
        """
        bindparam_lookup = {b.key: b for b in self._resolved_bindparams}

        def replace(
            element: Optional[visitors.ExternallyTraversible], **kw: Any
        ) -> Optional[visitors.ExternallyTraversible]:
            if isinstance(element, elements.BindParameter):
                if element.key in bindparam_lookup:
                    bind = bindparam_lookup[element.key]
                    if element.expanding:
                        bind.expanding = True
                        bind.expand_op = element.expand_op
                        bind.type = element.type
                    return bind

            # None tells replacement_traverse to keep the element
            return None

        if self._rec.is_sequence:
            expr = [
                visitors.replacement_traverse(sub_expr, {}, replace)
                for sub_expr in expr
            ]
        elif getattr(expr, "is_clause_element", False):
            expr = visitors.replacement_traverse(expr, {}, replace)

        return expr

    def _copy_internals(
        self,
        clone: _CloneCallableType = _clone,
        deferred_copy_internals: Optional[_CloneCallableType] = None,
        **kw: Any,
    ) -> None:
        # TODO: this needs A LOT of tests
        self._resolved = clone(
            self._resolved,
            deferred_copy_internals=deferred_copy_internals,
            **kw,
        )

    @util.memoized_property
    def _resolved(self):
        # the record's expression with this invocation's bound parameters
        # substituted in; computed once per element
        expr = self._rec.expected_expr

        if self._resolved_bindparams:
            expr = self._setup_binds_for_tracked_expr(expr)

        return expr

    def _gen_cache_key(self, anon_map, bindparams):
        """Return the cache key of this lambda and its parent chain.

        The key is built from each lambda's code object plus its closure
        cache key, parents first.  Resolved bound parameters are appended to
        ``bindparams``.  When uncacheable, NO_CACHE is recorded in
        ``anon_map`` and None is returned.
        """
        if self.closure_cache_key is _cache_key.NO_CACHE:
            anon_map[_cache_key.NO_CACHE] = True
            return None

        cache_key = (
            self.fn.__code__,
            self.__class__,
        ) + self.closure_cache_key

        parent = self.parent_lambda

        while parent is not None:
            # a cacheable child implies a cacheable parent (see
            # _retrieve_tracker_rec)
            assert parent.closure_cache_key is not CacheConst.NO_CACHE
            parent_closure_cache_key: Tuple[Any, ...] = (
                parent.closure_cache_key
            )

            cache_key = (
                (parent.fn.__code__,) + parent_closure_cache_key + cache_key
            )

            parent = parent.parent_lambda

        if self._resolved_bindparams:
            bindparams.extend(self._resolved_bindparams)
        return cache_key

    def _invoke_user_fn(self, fn: _AnyLambdaType, *arg: Any) -> ClauseElement:
        # plain lambdas take no arguments; any extra ``arg`` is ignored.
        # Subclasses override this to supply arguments.
        return fn()  # type: ignore[no-any-return]
423
424
class DeferredLambdaElement(LambdaElement):
    """A :class:`.LambdaElement` whose lambda takes arguments and is
    invoked within the compile phase with special context.

    Outside of compilation the real SQL expression is normally not
    produced.  Instead, the lambda is called once with the fixed
    ``lambda_args`` so that a sample expression can be analyzed; the real
    expression comes from :meth:`._resolve_with_args`.

    """

    def __init__(
        self,
        fn: _AnyLambdaType,
        role: Type[roles.SQLRole],
        opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
        lambda_args: Tuple[Any, ...] = (),
    ):
        # stored before LambdaElement.__init__ runs, because that may call
        # _invoke_user_fn(), which reads it
        self.lambda_args = lambda_args
        super().__init__(fn, role, opts)

    def _invoke_user_fn(self, fn, *arg):
        # the sample invocation always uses the stored arguments; any
        # positional ``arg`` supplied by the caller is ignored
        sample_args = self.lambda_args
        return fn(*sample_args)

    def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement:
        """Produce the real expression by calling the instrumented lambda.

        The result of ``tracker_instrumented_fn(*lambda_args)`` is coerced
        to ``self.role`` and has its bound parameters swapped for the
        resolved ones.  Each transform recorded by :meth:`._copy_internals`
        is then applied in order.
        """
        assert isinstance(self._rec, AnalyzedFunction)
        instrumented = self._rec.tracker_instrumented_fn
        result = instrumented(*lambda_args)
        result = coercions.expect(self.role, result)
        result = self._setup_binds_for_tracked_expr(result)

        # NOTE: checking here that this call produced the same required
        # bound parameters as the original run would nearly achieve #5767.
        # However, when the base lambda uses an unnamed column (very common
        # with mixins) the parameter name differs and the check gives a
        # false positive, i.e. it fails for exactly the documented use case.
        # The sketch below is kept for reference until that is resolved:
        #
        #     expected_binds = [
        #         b._orig_key
        #         for b in self._rec.expr._generate_cache_key()[1]
        #         if b.required
        #     ]
        #     got_binds = [
        #         b._orig_key
        #         for b in expr._generate_cache_key()[1]
        #         if b.required
        #     ]
        #     if expected_binds != got_binds:
        #         raise exc.InvalidRequestError(
        #             "Lambda callable at %s produced a different set of "
        #             "bound parameters than its original run: %s"
        #             % (self.fn.__code__, ", ".join(got_binds))
        #         )

        # TODO: TEST TEST TEST, this is very out there
        for transform in self._transforms:
            result = transform(result)

        return result  # type: ignore

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # NOTE(review): ``kw`` is forwarded as one ``opts`` keyword rather
        # than being unpacked (the original carried a ``# **kw`` remark at
        # this spot); confirm that this is intended.
        super()._copy_internals(
            clone=clone,
            deferred_copy_internals=deferred_copy_internals,
            opts=kw,
        )

        # TODO: A LOT A LOT of tests.  _resolve_with_args() doesn't have our
        # expression yet, so remember the replacement for it to apply later
        if deferred_copy_internals:
            self._transforms = self._transforms + (deferred_copy_internals,)
497
498
class StatementLambdaElement(
    roles.AllowsLambdaRole, LambdaElement, Executable
):
    """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.

    The :class:`_sql.StatementLambdaElement` is constructed using the
    :func:`_sql.lambda_stmt` function::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: select(table))

    Once constructed, additional criteria can be built onto the statement
    by adding subsequent lambdas, which accept the existing statement
    object as a single parameter::

        stmt += lambda s: s.where(table.c.col == parameter)

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    if TYPE_CHECKING:

        def __init__(
            self,
            fn: _StmtLambdaType,
            role: Type[SQLRole],
            opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
            apply_propagate_attrs: Optional[ClauseElement] = None,
        ): ...

    def __add__(
        self, other: _StmtLambdaElementType[Any]
    ) -> StatementLambdaElement:
        # ``stmt + fn`` is add_criteria(fn) with default options
        return self.add_criteria(other)

    def add_criteria(
        self,
        other: _StmtLambdaElementType[Any],
        enable_tracking: bool = True,
        track_on: Optional[Any] = None,
        track_closure_variables: bool = True,
        track_bound_values: bool = True,
    ) -> StatementLambdaElement:
        """Add new criteria to this :class:`_sql.StatementLambdaElement`.

        The given lambda receives the statement produced so far as its
        single argument and returns the new statement.

        E.g.::

            >>> def my_stmt(parameter):
            ...     stmt = lambda_stmt(
            ...         lambda: select(table.c.x, table.c.y),
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(table.c.x > parameter)
            ...     )
            ...     return stmt

        The :meth:`_sql.StatementLambdaElement.add_criteria` method is
        equivalent to using the Python addition operator to add a new
        lambda, except that additional arguments may be added including
        ``track_closure_variables`` and ``track_on``::

            >>> def my_stmt(self, foo):
            ...     stmt = lambda_stmt(
            ...         lambda: select(func.max(foo.x, foo.y)),
            ...         track_closure_variables=False,
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(self.where_criteria), track_on=[self]
            ...     )
            ...     return stmt

        The ``global_track_bound_values`` setting of this statement is
        carried over to the new link.  See :func:`_sql.lambda_stmt` for a
        description of the parameters accepted.

        :return: a new :class:`.LinkedLambdaElement` whose parent is this
         element.

        """  # noqa: E501

        opts = self.opts + dict(
            enable_tracking=enable_tracking,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=self.opts.global_track_bound_values,
            track_on=track_on,
            track_bound_values=track_bound_values,
        )

        return LinkedLambdaElement(other, parent_lambda=self, opts=opts)

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        if TYPE_CHECKING:
            assert isinstance(self._rec.expected_expr, ClauseElement)
        if self._rec.expected_expr.supports_execution:
            return connection._execute_clauseelement(
                self, distilled_params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)

    @property
    def _proxied(self) -> Any:
        # the statement produced by the lambda; the delegating properties
        # below read from it.  (No ``_rec_expected_expr`` attribute exists,
        # so it has to be read from the record directly.)
        return self._rec.expected_expr

    @property
    def _with_options(self):
        return self._proxied._with_options

    @property
    def _effective_plugin_target(self):
        return self._proxied._effective_plugin_target

    @property
    def _execution_options(self):
        return self._proxied._execution_options

    @property
    def _all_selected_columns(self):
        return self._proxied._all_selected_columns

    @property
    def is_select(self):
        return self._proxied.is_select

    @property
    def is_update(self):
        return self._proxied.is_update

    @property
    def is_insert(self):
        return self._proxied.is_insert

    @property
    def is_text(self):
        return self._proxied.is_text

    @property
    def is_delete(self):
        return self._proxied.is_delete

    @property
    def is_dml(self):
        return self._proxied.is_dml

    def spoil(self) -> NullLambdaStatement:
        """Return a new :class:`.NullLambdaStatement` that will run
        all lambdas unconditionally each time.

        For a chained statement, every lambda of the chain is re-run, from
        the first one onwards.

        """
        if self.parent_lambda is not None:
            # a linked lambda takes the preceding statement as its argument,
            # so rebuild the parent chain first and then apply this lambda
            return self.parent_lambda.spoil() + self.fn
        return NullLambdaStatement(self.fn())
650
651
class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
    """Provides the :class:`.StatementLambdaElement` API without any
    caching or analysis of lambdas.

    Each lambda given to it is invoked immediately against the current
    statement.  This is what :meth:`.StatementLambdaElement.spoil`
    returns; it is meant for isolating problems that may arise when using
    lambda statements.

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    def __init__(self, statement):
        # the already-built statement; also the source of propagate attrs
        self._resolved = statement
        self._propagate_attrs = statement._propagate_attrs

    def __getattr__(self, key):
        # anything not defined here is looked up on the wrapped statement
        return getattr(self._resolved, key)

    def add_criteria(self, other, **kw):
        # tracking options have no meaning without caching; they're ignored
        # and the callable is applied right away
        new_statement = other(self._resolved)
        return NullLambdaStatement(new_statement)

    def __add__(self, other):
        return self.add_criteria(other)

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        if not self._resolved.supports_execution:
            raise exc.ObjectNotExecutableError(self)
        return connection._execute_clauseelement(
            self, distilled_params, execution_options
        )
697
698
class LinkedLambdaElement(StatementLambdaElement):
    """A lambda appended to a :class:`.StatementLambdaElement`.

    Produced by :meth:`.StatementLambdaElement.add_criteria` (and thus by
    ``+``).  The wrapped callable receives the statement resolved by
    ``parent_lambda`` as its single argument.

    """

    parent_lambda: StatementLambdaElement

    def __init__(
        self,
        fn: _StmtLambdaElementType[Any],
        parent_lambda: StatementLambdaElement,
        opts: Union[Type[LambdaOptions], LambdaOptions],
    ):
        # state consulted by _retrieve_tracker_rec() has to be in place
        # before it runs
        self.parent_lambda = parent_lambda
        self.fn = fn
        self.opts = opts

        # extend the parent's chain of code objects with this lambda's
        self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)

        self._retrieve_tracker_rec(fn, self, opts)

        # a link shares the propagate attrs of the chain it extends
        self._propagate_attrs = parent_lambda._propagate_attrs

    def _invoke_user_fn(self, fn, *arg):
        parent_statement = self.parent_lambda._resolved
        return fn(parent_statement)
720
721
class AnalyzedCode:
    """Analysis of the code object of a SQL-producing lambda.

    One instance is kept per ``__code__`` object (see :meth:`.get`).  It
    records:

    * ``closure_trackers`` - callables ``(closure, opts, anon_map,
      bindparams)`` whose results together form the closure cache key
    * ``bindparam_trackers`` - callables ``(current_fn,
      tracker_instrumented_fn, result)`` that append the current values of
      tracked literal globals / closure cells to ``result``, by way of the
      wrapper objects found in the instrumented function
    * ``build_py_wrappers`` - ``(name, closure_index)`` pairs naming the
      globals (``closure_index`` is ``None``) and closure cells that hold
      literal values
    """

    __slots__ = (
        "track_closure_variables",
        "track_bound_values",
        "bindparam_trackers",
        "closure_trackers",
        "build_py_wrappers",
    )
    # keyed weakly on the code object
    _fns: weakref.WeakKeyDictionary[CodeType, AnalyzedCode] = (
        weakref.WeakKeyDictionary()
    )

    # guards creation of AnalyzedCode objects and of lambda-cache entries
    _generation_mutex = threading.RLock()

    @classmethod
    def get(cls, fn, lambda_element, lambda_kw, **kw):
        """Return the analysis for ``fn.__code__``, creating it if needed.

        ``lambda_kw`` is the options object, passed on as ``opts``.
        Creation happens under ``_generation_mutex`` with a re-check, so
        only one object is built per code object.  The options given to
        later calls for an already-analyzed code object are not consulted.
        """
        try:
            # TODO: validate kw haven't changed?
            return cls._fns[fn.__code__]
        except KeyError:
            pass

        with cls._generation_mutex:
            # check for other thread already created object
            if fn.__code__ in cls._fns:
                return cls._fns[fn.__code__]

            analyzed: AnalyzedCode
            cls._fns[fn.__code__] = analyzed = AnalyzedCode(
                fn, lambda_element, lambda_kw, **kw
            )
            return analyzed

    def __init__(self, fn, lambda_element, opts):
        """Scan ``fn`` according to ``opts`` and build the tracker lists.

        :raises ArgumentError: if ``fn`` is a bound method.
        """
        if inspect.ismethod(fn):
            raise exc.ArgumentError(
                "Method %s may not be passed as a SQL expression" % fn
            )
        closure = fn.__closure__

        self.track_bound_values = (
            opts.track_bound_values and opts.global_track_bound_values
        )
        enable_tracking = opts.enable_tracking
        track_on = opts.track_on
        track_closure_variables = opts.track_closure_variables

        # an explicit track_on replaces closure variable tracking
        self.track_closure_variables = track_closure_variables and not track_on

        # a list of callables generated from _bound_parameter_getter_*
        # functions.  Each of these uses a PyWrapper object to retrieve
        # a parameter value
        self.bindparam_trackers = []

        # a list of callables generated from _cache_key_getter_* functions
        # these callables work to generate a cache key for the lambda
        # based on what's inside its closure variables.
        self.closure_trackers = []

        self.build_py_wrappers = []

        if enable_tracking:
            if track_on:
                self._init_track_on(track_on)

            self._init_globals(fn)

            if closure:
                self._init_closure(fn)

        self._setup_additional_closure_trackers(fn, lambda_element, opts)

    def _init_track_on(self, track_on):
        # one cache-key getter per track_on element, by position
        self.closure_trackers.extend(
            self._cache_key_getter_track_on(idx, elem)
            for idx, elem in enumerate(track_on)
        )

    def _init_globals(self, fn):
        """Register the literal-valued globals referenced by ``fn``."""
        build_py_wrappers = self.build_py_wrappers
        bindparam_trackers = self.bindparam_trackers
        track_bound_values = self.track_bound_values

        for name in fn.__code__.co_names:
            # co_names also includes attribute names; skip anything that is
            # not actually a global of the function
            if name not in fn.__globals__:
                continue

            _bound_value = self._roll_down_to_literal(fn.__globals__[name])

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((name, None))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_globals(name)
                    )

    def _init_closure(self, fn):
        """Register trackers for each closure cell of ``fn``.

        Literal cells get a py wrapper and, when bound values are tracked,
        a bind tracker.  Other cells get a cache-key tracker when closure
        variables are tracked.
        """
        build_py_wrappers = self.build_py_wrappers
        closure = fn.__closure__

        track_bound_values = self.track_bound_values
        track_closure_variables = self.track_closure_variables
        bindparam_trackers = self.bindparam_trackers
        closure_trackers = self.closure_trackers

        for closure_index, (fv, cell) in enumerate(
            zip(fn.__code__.co_freevars, closure)
        ):
            _bound_value = self._roll_down_to_literal(cell.cell_contents)

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((fv, closure_index))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_closure(
                            fv, closure_index
                        )
                    )
            else:
                # for normal cell contents, add them to a list that
                # we can compare later when we get new lambdas.  if
                # any identities have changed, then we will
                # recalculate the whole lambda and run it again.

                if track_closure_variables:
                    closure_trackers.append(
                        self._cache_key_getter_closure_variable(
                            fn, fv, closure_index, cell.cell_contents
                        )
                    )

    def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
        # an additional step is to actually run the function, then
        # go through the PyWrapper objects that were set up to catch a bound
        # parameter.  then if they *didn't* make a param, oh they're another
        # object in the closure we have to track for our cache key.  so
        # create trackers to catch those.

        analyzed_function = AnalyzedFunction(
            self,
            lambda_element,
            None,
            fn,
        )

        closure_trackers = self.closure_trackers

        for pywrapper in analyzed_function.closure_pywrappers:
            if not pywrapper._sa__has_param:
                closure_trackers.append(
                    self._cache_key_getter_tracked_literal(fn, pywrapper)
                )

    @classmethod
    def _roll_down_to_literal(cls, element):
        """Reduce ``element`` to the value used for the literal test.

        An object with ``__clause_element__`` is unwrapped until a
        ClauseElement, SchemaItem or class is reached (or unwrapping fails).
        Otherwise, if ``inspection.inspect()`` finds an inspection object,
        its ``__clause_element__()`` result (or the inspection object itself)
        is returned.  Anything else is returned unchanged.
        """
        is_clause_element = hasattr(element, "__clause_element__")

        if is_clause_element:
            while not isinstance(
                element, (elements.ClauseElement, schema.SchemaItem, type)
            ):
                try:
                    element = element.__clause_element__()
                except AttributeError:
                    break

        if not is_clause_element:
            insp = inspection.inspect(element, raiseerr=False)
            if insp is not None:
                try:
                    return insp.__clause_element__()
                except AttributeError:
                    return insp

            # TODO: should we coerce consts None/True/False here?
            return element
        else:
            return element

    def _bound_parameter_getter_func_globals(self, name):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__globals__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            # the wrapper is reached with object.__getattribute__ so that the
            # wrapper's own attribute interception is bypassed
            wrapper = tracker_instrumented_fn.__globals__[name]
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__globals__[name], result
            )

        return extract_parameter_value

    def _bound_parameter_getter_func_closure(self, name, closure_index):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__closure__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__closure__[
                closure_index
            ].cell_contents
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__closure__[closure_index].cell_contents, result
            )

        return extract_parameter_value

    def _cache_key_getter_track_on(self, idx, elem):
        """Return a getter that will extend a cache key with new entries
        from the "track_on" parameter passed to a :class:`.LambdaElement`.

        The getter reads ``opts.track_on[idx]`` at call time: a tuple is
        expected to hold HasCacheKey elements, a HasCacheKey element
        contributes its generated cache key, anything else is used as-is.

        """

        if isinstance(elem, tuple):
            # tuple must contain hascachekey elements
            def get(closure, opts, anon_map, bindparams):
                return tuple(
                    tup_elem._gen_cache_key(anon_map, bindparams)
                    for tup_elem in opts.track_on[idx]
                )

        elif isinstance(elem, _cache_key.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)

        else:

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]

        return get

    def _cache_key_getter_closure_variable(
        self,
        fn,
        variable_name,
        idx,
        cell_contents,
        use_clause_element=False,
        use_inspect=False,
    ):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        The kind of getter is chosen from ``cell_contents`` (a sample value
        of closure cell ``idx``): HasCacheKey objects contribute their
        cache key (after inspection or ``__clause_element__`` unwrapping of
        the live cell value, per ``use_inspect`` / ``use_clause_element``);
        functions contribute their code object; sequences contribute the
        cache keys of their members.  ORM-style objects are rolled down and
        retried.

        :raises InvalidRequestError: if the variable can't be part of a
         cache key.

        """

        if isinstance(cell_contents, _cache_key.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                obj = closure[idx].cell_contents
                if use_inspect:
                    obj = inspection.inspect(obj)
                elif use_clause_element:
                    while hasattr(obj, "__clause_element__"):
                        if not getattr(obj, "is_clause_element", False):
                            obj = obj.__clause_element__()
                        else:
                            # already a clause element; without this break
                            # the loop would never terminate (matches the
                            # roll-down loop further below)
                            break

                return obj._gen_cache_key(anon_map, bindparams)

        elif isinstance(cell_contents, types.FunctionType):

            def get(closure, opts, anon_map, bindparams):
                return closure[idx].cell_contents.__code__

        elif isinstance(cell_contents, collections_abc.Sequence):

            def get(closure, opts, anon_map, bindparams):
                contents = closure[idx].cell_contents

                try:
                    return tuple(
                        elem._gen_cache_key(anon_map, bindparams)
                        for elem in contents
                    )
                except AttributeError as ae:
                    self._raise_for_uncacheable_closure_variable(
                        variable_name, fn, from_=ae
                    )

        else:
            # if the object is a mapped class or aliased class, or some
            # other object in the ORM realm of things like that, imitate
            # the logic used in coercions.expect() to roll it down to the
            # SQL element
            element = cell_contents
            is_clause_element = False
            while hasattr(element, "__clause_element__"):
                is_clause_element = True
                if not getattr(element, "is_clause_element", False):
                    element = element.__clause_element__()
                else:
                    break

            if not is_clause_element:
                insp = inspection.inspect(element, raiseerr=False)
                if insp is not None:
                    return self._cache_key_getter_closure_variable(
                        fn, variable_name, idx, insp, use_inspect=True
                    )
            else:
                return self._cache_key_getter_closure_variable(
                    fn, variable_name, idx, element, use_clause_element=True
                )

            self._raise_for_uncacheable_closure_variable(variable_name, fn)

        return get

    def _raise_for_uncacheable_closure_variable(
        self, variable_name, fn, from_=None
    ):
        """Raise InvalidRequestError for a closure variable that can be
        neither a cache key element nor a bound value; always raises."""
        raise exc.InvalidRequestError(
            "Closure variable named '%s' inside of lambda callable %s "
            "does not refer to a cacheable SQL element, and also does not "
            "appear to be serving as a SQL literal bound value based on "
            "the default "
            "SQL expression returned by the function.  This variable "
            "needs to remain outside the scope of a SQL-generating lambda "
            "so that a proper cache key may be generated from the "
            "lambda's state.  Evaluate this variable outside of the "
            "lambda, set track_on=[<elements>] to explicitly select "
            "closure elements to track, or set "
            "track_closure_variables=False to exclude "
            "closure variables from being part of the cache key."
            % (variable_name, fn.__code__),
        ) from from_

    def _cache_key_getter_tracked_literal(self, fn, pytracker):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        this getter differs from _cache_key_getter_closure_variable
        in that these are detected after the function is run, and PyWrapper
        objects have recorded that a particular literal value is in fact
        not being interpreted as a bound parameter.

        """

        elem = pytracker._sa__to_evaluate
        closure_index = pytracker._sa__closure_index
        variable_name = pytracker._sa__name

        return self._cache_key_getter_closure_variable(
            fn, variable_name, closure_index, elem
        )
1076
1077
class NonAnalyzedFunction:
    """Record for a lambda whose result is not cached.

    It simply holds the expression obtained by invoking the lambda; the
    bind-tracking attributes consulted by :class:`.LambdaElement` are
    ``None`` and the result is never treated as a sequence.

    """

    __slots__ = ("expr",)

    is_sequence = False

    # no cached binds and no per-invocation trackers for uncached lambdas
    bindparam_trackers: Optional[List[_BoundParameterGetter]] = None
    closure_bindparams: Optional[List[BindParameter[Any]]] = None

    expr: ClauseElement

    def __init__(self, expr: ClauseElement):
        self.expr = expr

    @property
    def expected_expr(self) -> ClauseElement:
        # the stored expression, unchanged
        return self.expr
1094
1095
class AnalyzedFunction:
    """One instrumented invocation of a lambda, driven by an
    ``analyzed_code`` object.

    The user's lambda is invoked once.  Closure cells and globals listed in
    ``analyzed_code.build_py_wrappers`` are replaced by :class:`.PyWrapper`
    objects for that call.  The resulting expression is then coerced
    according to the lambda element's role.

    """

    __slots__ = (
        "analyzed_code",
        "fn",
        "closure_pywrappers",
        "tracker_instrumented_fn",
        "expr",
        "bindparam_trackers",
        "expected_expr",
        "is_sequence",
        "propagate_attrs",
        "closure_bindparams",
    )

    # NOTE(review): closure_bindparams is never assigned inside this class;
    # presumably the caller sets it after construction - verify.
    closure_bindparams: Optional[List[BindParameter[Any]]]
    expected_expr: Union[ClauseElement, List[ClauseElement]]
    bindparam_trackers: Optional[List[_BoundParameterGetter]]

    def __init__(
        self,
        analyzed_code,
        lambda_element,
        apply_propagate_attrs,
        fn,
    ):
        """Run ``fn`` (instrumented if needed) and coerce its result.

        :param analyzed_code: analysis of ``fn``'s code; supplies
         ``build_py_wrappers``, ``bindparam_trackers``,
         ``track_closure_variables`` and ``track_bound_values``.
        :param lambda_element: supplies ``_invoke_user_fn()``,
         ``parent_lambda`` and ``role``.
        :param apply_propagate_attrs: passed to ``coercions.expect``; when
         not ``None`` its ``_propagate_attrs`` are recorded.
        :param fn: the user's lambda callable.

        """
        self.analyzed_code = analyzed_code
        self.fn = fn

        self.bindparam_trackers = analyzed_code.bindparam_trackers

        self._instrument_and_run_function(lambda_element)

        self._coerce_expression(lambda_element, apply_propagate_attrs)

    def _instrument_and_run_function(self, lambda_element):
        """Invoke the user's lambda, instrumenting it first when needed.

        If ``build_py_wrappers`` is empty, the original function is invoked
        as-is.  Otherwise a copy of the function is built in which each
        listed closure cell / global holds a :class:`.PyWrapper`, and that
        copy is invoked instead.

        Sets ``closure_pywrappers``, ``tracker_instrumented_fn`` and
        ``expr``.

        """
        analyzed_code = self.analyzed_code

        fn = self.fn
        self.closure_pywrappers = closure_pywrappers = []

        build_py_wrappers = analyzed_code.build_py_wrappers

        if not build_py_wrappers:
            self.tracker_instrumented_fn = tracker_instrumented_fn = fn
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
        else:
            track_closure_variables = analyzed_code.track_closure_variables
            track_bound_values = analyzed_code.track_bound_values
            closure = fn.__closure__

            # will form the __closure__ of the function when we rebuild it
            if closure:
                new_closure = {
                    fv: cell.cell_contents
                    for fv, cell in zip(fn.__code__.co_freevars, closure)
                }
            else:
                new_closure = {}

            # will form the __globals__ of the function when we rebuild it
            new_globals = fn.__globals__.copy()

            for name, closure_index in build_py_wrappers:
                if closure_index is not None:
                    value = closure[closure_index].cell_contents
                    new_closure[name] = bind = PyWrapper(
                        fn,
                        name,
                        value,
                        closure_index=closure_index,
                        track_bound_values=track_bound_values,
                    )
                    if track_closure_variables:
                        closure_pywrappers.append(bind)
                else:
                    value = fn.__globals__[name]
                    # honour track_bound_values for globals too, so that
                    # track_bound_values=False suppresses the "can't invoke
                    # Python callable" check here as it does for closures.
                    new_globals[name] = PyWrapper(
                        fn, name, value, track_bound_values=track_bound_values
                    )

            # rewrite the original fn. things that look like they will
            # become bound parameters are wrapped in a PyWrapper.
            self.tracker_instrumented_fn = tracker_instrumented_fn = (
                self._rewrite_code_obj(
                    fn,
                    [new_closure[name] for name in fn.__code__.co_freevars],
                    new_globals,
                )
            )

            # now invoke the function. This will give us a new SQL
            # expression, but all the places that there would be a bound
            # parameter, the PyWrapper in its place will give us a bind
            # with a predictable name we can match up later.

            # additionally, each PyWrapper will log that it did in fact
            # create a parameter, otherwise, it's some kind of Python
            # object in the closure and we want to track that, to make
            # sure it doesn't change to something else, or if it does,
            # that we create a different tracked function with that
            # variable.
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)

    def _coerce_expression(self, lambda_element, apply_propagate_attrs):
        """Run the tracker-generated expression through coercion rules.

        After the user-defined lambda has been invoked to produce a statement
        for re-use, run it through coercion rules to both check that it's the
        correct type of object and also to coerce it to its useful form.

        Top-level lambdas (no ``parent_lambda``) are coerced per
        ``lambda_element.role``; a sequence result is coerced element by
        element and flagged with ``is_sequence``.  Lambdas with a parent
        keep the raw expression.

        """

        parent_lambda = lambda_element.parent_lambda
        expr = self.expr

        if parent_lambda is None:
            if isinstance(expr, collections_abc.Sequence):
                self.expected_expr = [
                    cast(
                        "ClauseElement",
                        coercions.expect(
                            lambda_element.role,
                            sub_expr,
                            apply_propagate_attrs=apply_propagate_attrs,
                        ),
                    )
                    for sub_expr in expr
                ]
                self.is_sequence = True
            else:
                self.expected_expr = cast(
                    "ClauseElement",
                    coercions.expect(
                        lambda_element.role,
                        expr,
                        apply_propagate_attrs=apply_propagate_attrs,
                    ),
                )
                self.is_sequence = False
        else:
            self.expected_expr = expr
            self.is_sequence = False

        if apply_propagate_attrs is not None:
            self.propagate_attrs = apply_propagate_attrs._propagate_attrs
        else:
            self.propagate_attrs = util.EMPTY_DICT

    def _rewrite_code_obj(self, f, cell_values, globals_):
        """Return a copy of ``f`` with a new closure and new globals.

        ``cell_values`` holds the contents for each closure cell, in the
        order of ``f.__code__.co_freevars``.  ``globals_`` becomes the new
        function's ``__globals__``.  The code object, name, defaults,
        keyword defaults, annotations, docstring and module of ``f`` are
        carried over.

        Cells are produced by exec'ing a small generated function whose
        inner function closes over one local per value; this approach also
        works on PyPy.

        """

        argrange = range(len(cell_values))

        code = "def make_cells():\n"
        if cell_values:
            code += "    (%s) = (%s)\n" % (
                ", ".join("i%d" % i for i in argrange),
                ", ".join("o%d" % i for i in argrange),
            )
        code += "    def closure():\n"
        code += "        return %s\n" % ", ".join("i%d" % i for i in argrange)
        code += "    return closure"
        vars_ = {"o%d" % i: cell_values[i] for i in argrange}
        exec(code, vars_, vars_)
        cell_source = vars_["make_cells"]()

        if cell_values:
            # A function's __closure__ follows its co_freevars order, which
            # is not source order: CPython sorts the names, so "i10" comes
            # before "i2".  Re-key the cells by name so that cell N always
            # holds cell_values[N], however many values there are.
            cells_by_name = dict(
                zip(cell_source.__code__.co_freevars, cell_source.__closure__)
            )
            closure = tuple(cells_by_name["i%d" % i] for i in argrange)
        else:
            closure = None

        func = type(f)(
            f.__code__, globals_, f.__name__, f.__defaults__, closure
        )
        func.__annotations__ = f.__annotations__
        func.__kwdefaults__ = f.__kwdefaults__
        func.__doc__ = f.__doc__
        func.__module__ = f.__module__

        return func
1274
1275
class PyWrapper(ColumnOperators):
    """A wrapper object that is injected into the ``__globals__`` and
    ``__closure__`` of a Python function.

    When the function is instrumented with :class:`.PyWrapper` objects, it is
    then invoked just once in order to set up the wrappers. We look through
    all the :class:`.PyWrapper` objects we made to find the ones that generated
    a :class:`.BindParameter` object, e.g. the expression system interpreted
    something as a literal. Those positions in the globals/closure are then
    ones that we will look at, each time a new lambda comes in that refers to
    the same ``__code__`` object. In this way, we keep a single version of
    the SQL expression that this lambda produced, without calling upon the
    Python function that created it more than once, unless its other closure
    variables have changed. The expression is then transformed to have the
    new bound values embedded into it.

    Because :meth:`__getattribute__` intercepts ordinary attribute access,
    the wrapper's own state is read via ``object.__getattribute__`` or via
    ``_sa_``-prefixed names (e.g. ``self._sa_fn`` reads ``fn``).

    """

    def __init__(
        self,
        fn,
        name,
        to_evaluate,
        closure_index=None,
        getter=None,
        track_bound_values=True,
    ):
        """
        :param fn: the lambda being instrumented; quoted in error messages
         and handed to child wrappers.
        :param name: name of the wrapped variable / attribute / item key;
         also used as the name of the generated bind parameter.
        :param to_evaluate: the real value being wrapped.
        :param closure_index: index of the closure cell this wrapper
         replaces, or ``None``.
        :param getter: for child wrappers, the attrgetter/itemgetter that
         extracts this value from the parent's value.
        :param track_bound_values: when true, calling the wrapped object and
         getting a plain literal back raises.

        """
        self.fn = fn
        self._name = name
        self._to_evaluate = to_evaluate
        self._param = None
        self._has_param = False
        self._bind_paths = {}
        self._getter = getter
        self._closure_index = closure_index
        self.track_bound_values = track_bound_values

    def __call__(self, *arg, **kw):
        """Call the wrapped object.

        Raises when ``track_bound_values`` is set and the result is a plain
        literal that is not a :class:`.HasCacheKey`, since such a value
        could not be re-extracted on later runs without calling again.

        """
        elem = object.__getattribute__(self, "_to_evaluate")
        value = elem(*arg, **kw)
        if (
            self._sa_track_bound_values
            and coercions._deep_is_literal(value)
            and not isinstance(
                # TODO: coverage where an ORM option or similar is here
                value,
                _cache_key.HasCacheKey,
            )
        ):
            name = object.__getattribute__(self, "_name")
            raise exc.InvalidRequestError(
                "Can't invoke Python callable %s() inside of lambda "
                "expression argument at %s; lambda SQL constructs should "
                "not invoke functions from closure variables to produce "
                "literal values since the "
                "lambda SQL system normally extracts bound values without "
                "actually "
                "invoking the lambda or any functions within it. Call the "
                "function outside of the "
                "lambda and assign to a local variable that is used in the "
                "lambda as a closure variable, or set "
                "track_bound_values=False if the return value of this "
                "function is used in some other way other than a SQL bound "
                "value." % (name, self._sa_fn.__code__)
            )
        else:
            return value

    def operate(self, op, *other, **kwargs):
        # apply the operator to this wrapper's bind parameter
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(elem, *other, **kwargs)

    def reverse_operate(self, op, other, **kwargs):
        # same as operate(), with this wrapper's bind on the right-hand side
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(other, elem, **kwargs)

    def _extract_bound_parameters(self, starting_point, result_list):
        """Append to ``result_list`` the bind parameters produced by this
        wrapper and its child wrappers, valued from ``starting_point``.

        Child wrappers receive ``getter(starting_point)`` as their own
        starting point.

        """
        param = object.__getattribute__(self, "_param")
        if param is not None:
            param = param._with_value(starting_point, maintain_key=True)
            result_list.append(param)
        for pywrapper in object.__getattribute__(self, "_bind_paths").values():
            getter = object.__getattribute__(pywrapper, "_getter")
            element = getter(starting_point)
            pywrapper._sa__extract_bound_parameters(element, result_list)

    def _py_wrapper_literal(self, expr=None, operator=None, **kw):
        """Return a bind parameter carrying the current wrapped value.

        The underlying :class:`.BindParameter` is created once, named after
        this wrapper, and re-used; each call returns a copy with the value
        set via ``_with_value(..., maintain_key=True)``.

        """
        param = object.__getattribute__(self, "_param")
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        if param is None:
            name = object.__getattribute__(self, "_name")
            self._param = param = elements.BindParameter(
                name,
                required=False,
                unique=True,
                _compared_to_operator=operator,
                _compared_to_type=expr.type if expr is not None else None,
            )
            self._has_param = True
        return param._with_value(to_evaluate, maintain_key=True)

    def __bool__(self):
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    def __getattribute__(self, key):
        # "_sa_<attr>" reads the wrapper's own "<attr>"
        if key.startswith("_sa_"):
            return object.__getattribute__(self, key[4:])
        elif key in (
            "__clause_element__",
            "operate",
            "reverse_operate",
            "_py_wrapper_literal",
            "__class__",
            "__dict__",
        ):
            return object.__getattribute__(self, key)

        if key.startswith("__"):
            # other dunder lookups go to the wrapped object
            elem = object.__getattribute__(self, "_to_evaluate")
            return getattr(elem, key)
        else:
            # ordinary attribute access may yield a tracked child wrapper
            return self._sa__add_getter(key, operator.attrgetter)

    def __iter__(self):
        elem = object.__getattribute__(self, "_to_evaluate")
        return iter(elem)

    def __getitem__(self, key):
        elem = object.__getattribute__(self, "_to_evaluate")
        if not hasattr(elem, "__getitem__"):
            raise AttributeError("__getitem__")

        if isinstance(key, PyWrapper):
            # TODO: coverage
            raise exc.InvalidRequestError(
                "Dictionary keys / list indexes inside of a cached "
                "lambda must be Python literals only"
            )
        return self._sa__add_getter(key, operator.itemgetter)

    def _add_getter(self, key, getter_fn):
        """Return ``getter_fn(key)`` applied to the wrapped value.

        If the value (after ``AnalyzedCode._roll_down_to_literal``) looks
        like a literal, it is returned wrapped in a child
        :class:`.PyWrapper`, cached per ``(key, getter_fn)`` so repeated
        access yields the same child; otherwise the raw value is returned.

        """
        bind_paths = object.__getattribute__(self, "_bind_paths")

        bind_path_key = (key, getter_fn)
        if bind_path_key in bind_paths:
            return bind_paths[bind_path_key]

        getter = getter_fn(key)
        elem = object.__getattribute__(self, "_to_evaluate")
        value = getter(elem)

        rolled_down_value = AnalyzedCode._roll_down_to_literal(value)

        if coercions._deep_is_literal(rolled_down_value):
            # children inherit track_bound_values so that e.g.
            # ``obj.method()`` honours track_bound_values=False the same way
            # a bare closure callable does.
            wrapper = PyWrapper(
                self._sa_fn,
                key,
                value,
                getter=getter,
                track_bound_values=self._sa_track_bound_values,
            )
            bind_paths[bind_path_key] = wrapper
            return wrapper
        else:
            return value
1436
1437
@inspection._inspects(LambdaElement)
def insp(lmb):
    """Inspection hook for :class:`.LambdaElement`.

    Inspects the element's ``_resolved`` object instead of the lambda
    element itself.

    """
    target = lmb._resolved
    return inspection.inspect(target)