Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/sqlalchemy/sql/lambdas.py: 24%
533 statements
Generated by coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1# sql/lambdas.py
2# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
8import inspect
9import itertools
10import operator
11import sys
12import threading
13import types
14import weakref
16from . import coercions
17from . import elements
18from . import roles
19from . import schema
20from . import traversals
21from . import type_api
22from . import visitors
23from .base import _clone
24from .base import Options
25from .operators import ColumnOperators
26from .. import exc
27from .. import inspection
28from .. import util
29from ..util import collections_abc
30from ..util import compat
# Module-wide default lambda cache, used by
# LambdaElement._retrieve_tracker_rec() whenever LambdaOptions.lambda_cache
# is None.  Keys are a lambda's tracker key (tuple of code objects) plus its
# closure cache key; values are AnalyzedFunction records.
# NOTE(review): capacity semantics of util.LRUCache(1000) assumed to be a
# ~1000-entry bound -- confirm against sqlalchemy.util.
_closure_per_cache_key = util.LRUCache(1000)
class LambdaOptions(Options):
    """Options controlling how a lambda-based SQL construct is analyzed
    and cached.

    Instances are built by :func:`.lambda_stmt` and
    :meth:`.StatementLambdaElement.add_criteria`; the class itself is also
    used directly as the default ``opts`` of :class:`.LambdaElement`.
    See :func:`.lambda_stmt` for the meaning of each option.

    """

    # master switch for scanning globals / closure / track_on
    enable_tracking = True
    # include closure variables in the cache key (ignored when track_on
    # is given)
    track_closure_variables = True
    # optional explicit sequence of objects to build the cache key from
    track_on = None
    # statement-wide switch for bound parameter tracking, carried over
    # to every linked lambda
    global_track_bound_values = True
    # per-lambda switch for bound parameter tracking
    track_bound_values = True
    # mapping used to store analysis records; None means the module-wide
    # _closure_per_cache_key
    lambda_cache = None
def lambda_stmt(
    lmb,
    enable_tracking=True,
    track_closure_variables=True,
    track_on=None,
    global_track_bound_values=True,
    track_bound_values=True,
    lambda_cache=None,
):
    """Produce a SQL statement that is cached as a lambda.

    The Python code object within the lambda is scanned for both Python
    literals that will become bound parameters as well as closure variables
    that refer to Core or ORM constructs that may vary. The lambda itself
    will be invoked only once per particular set of constructs detected.

    E.g.::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: table.select())
        stmt += lambda s: s.where(table.c.id == 5)

        result = connection.execute(stmt)

    The object returned is an instance of :class:`_sql.StatementLambdaElement`.

    .. versionadded:: 1.4

    :param lmb: a Python function, typically a lambda, which takes no arguments
     and returns a SQL expression construct
    :param enable_tracking: when False, all scanning of the given lambda for
     changes in closure variables or bound parameters is disabled. Use for
     a lambda that produces the identical results in all cases with no
     parameterization.
    :param track_closure_variables: when False, changes in closure variables
     within the lambda will not be scanned. Use for a lambda where the
     state of its closure variables will never change the SQL structure
     returned by the lambda.
    :param track_on: optional sequence of objects from which the lambda's
     cache key is built, in place of scanning its closure variables; when
     given, closure variables are not tracked. Each element may be a
     construct that has a cache key, a tuple of such constructs, or a plain
     value that is placed into the cache key as-is.
    :param track_bound_values: when False, bound parameter tracking will
     be disabled for the given lambda. Use for a lambda that either does
     not produce any bound values, or where the initial bound values never
     change.
    :param global_track_bound_values: when False, bound parameter tracking
     will be disabled for the entire statement including additional links
     added via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
    :param lambda_cache: a dictionary or other mapping-like object where
     information about the lambda's Python code as well as the tracked closure
     variables in the lambda itself will be stored. Defaults
     to a global LRU cache. This cache is independent of the "compiled_cache"
     used by the :class:`_engine.Connection` object.

    .. seealso::

        :ref:`engine_lambda_caching`


    """

    return StatementLambdaElement(
        lmb,
        roles.StatementRole,
        LambdaOptions(
            enable_tracking=enable_tracking,
            track_on=track_on,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=global_track_bound_values,
            track_bound_values=track_bound_values,
            lambda_cache=lambda_cache,
        ),
    )
class LambdaElement(elements.ClauseElement):
    """A SQL construct where the state is stored as an un-invoked lambda.

    The :class:`_sql.LambdaElement` is produced transparently whenever
    passing lambda expressions into SQL constructs, such as::

        stmt = select(table).where(lambda: table.c.col == parameter)

    The :class:`_sql.LambdaElement` is the base of the
    :class:`_sql.StatementLambdaElement` which represents a full statement
    within a lambda.

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    # transformation callables queued by DeferredLambdaElement's
    # _copy_internals(), applied once a deferred expression is produced
    _transforms = ()

    # previous element in a chain of lambdas (set by LinkedLambdaElement);
    # None for a standalone / first lambda
    parent_lambda = None

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)

    def __init__(
        self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
    ):
        """Analyze ``fn`` (or reuse a cached analysis) for use in ``role``.

        :param fn: the user-supplied lambda.
        :param role: coercion role the lambda's return value must satisfy.
        :param opts: a :class:`.LambdaOptions` instance, or the class itself
         for all defaults.
        :param apply_propagate_attrs: object whose ``_propagate_attrs`` is
         updated from the analysis record when that record has non-empty
         propagate attributes; defaults to ``self`` for statement lambdas.

        """
        self.fn = fn
        self.role = role
        self.tracker_key = (fn.__code__,)
        self.opts = opts

        if apply_propagate_attrs is None and (role is roles.StatementRole):
            apply_propagate_attrs = self

        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)

        if apply_propagate_attrs is not None:
            propagate_attrs = rec.propagate_attrs
            if propagate_attrs:
                apply_propagate_attrs._propagate_attrs = propagate_attrs

    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
        """Locate or build the analysis record for this lambda.

        Computes this lambda's closure cache key (prefixed by the parent's
        key when chained) and looks up an :class:`.AnalyzedFunction` under
        ``tracker_key + cache_key`` in the lambda cache, creating and
        storing one if missing.  When no cache key can be produced, the
        lambda is simply invoked and wrapped in a
        :class:`.NonAnalyzedFunction`.

        Side effects: sets ``self._resolved_bindparams`` (the bound
        parameters carrying this invocation's values), ``self._rec`` and
        ``self.closure_cache_key``.

        :return: the record assigned to ``self._rec``.

        """
        lambda_cache = opts.lambda_cache
        if lambda_cache is None:
            lambda_cache = _closure_per_cache_key

        tracker_key = self.tracker_key

        # NOTE(review): the ``fn`` argument is superseded by self.fn; the
        # callers in this module pass the same function for both.
        fn = self.fn
        closure = fn.__closure__
        tracker = AnalyzedCode.get(
            fn,
            self,
            opts,
        )

        self._resolved_bindparams = bindparams = []

        if self.parent_lambda is not None:
            parent_closure_cache_key = self.parent_lambda.closure_cache_key
        else:
            parent_closure_cache_key = ()

        if parent_closure_cache_key is not traversals.NO_CACHE:
            anon_map = traversals.anon_map()
            cache_key = tuple(
                [
                    getter(closure, opts, anon_map, bindparams)
                    for getter in tracker.closure_trackers
                ]
            )

            if traversals.NO_CACHE not in anon_map:
                cache_key = parent_closure_cache_key + cache_key

                try:
                    rec = lambda_cache[tracker_key + cache_key]
                except KeyError:
                    rec = None
            else:
                cache_key = traversals.NO_CACHE
                rec = None

        else:
            cache_key = traversals.NO_CACHE
            rec = None

        self.closure_cache_key = cache_key

        if rec is None:
            if cache_key is not traversals.NO_CACHE:

                with AnalyzedCode._generation_mutex:
                    key = tracker_key + cache_key
                    # re-check under the lock; another thread may have
                    # created the record in the meantime
                    if key not in lambda_cache:
                        rec = AnalyzedFunction(
                            tracker, self, apply_propagate_attrs, fn
                        )
                        rec.closure_bindparams = bindparams
                        lambda_cache[key] = rec
                    else:
                        rec = lambda_cache[key]
            else:
                rec = NonAnalyzedFunction(self._invoke_user_fn(fn))

        else:
            # cache hit: rebuild the cached record's closure bind
            # parameters, keeping their keys, with the values just
            # extracted from this lambda's closure
            bindparams[:] = [
                orig_bind._with_value(new_bind.value, maintain_key=True)
                for orig_bind, new_bind in zip(
                    rec.closure_bindparams, bindparams
                )
            ]

        self._rec = rec

        if cache_key is not traversals.NO_CACHE:
            if self.parent_lambda is not None:
                bindparams[:0] = self.parent_lambda._resolved_bindparams

            # walk the chain of lambdas, letting each one's trackers pull
            # current literal values out of its function.  A separate
            # local name is used so that ``rec`` keeps referring to this
            # element's own record for the return value.
            lambda_element = self
            while lambda_element is not None:
                link_rec = lambda_element._rec
                if link_rec.bindparam_trackers:
                    tracker_instrumented_fn = link_rec.tracker_instrumented_fn
                    for tracker in link_rec.bindparam_trackers:
                        tracker(
                            lambda_element.fn,
                            tracker_instrumented_fn,
                            bindparams,
                        )
                lambda_element = lambda_element.parent_lambda

        return rec

    def __getattr__(self, key):
        # only reached for attributes not found normally.  If "_rec" itself
        # is missing (e.g. an instance created without running __init__,
        # as copy/pickle machinery does), fail with AttributeError instead
        # of re-entering __getattr__ through self._rec indefinitely.
        if key == "_rec":
            raise AttributeError(key)
        return getattr(self._rec.expected_expr, key)

    @property
    def _is_sequence(self):
        return self._rec.is_sequence

    @property
    def _select_iterable(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._select_iterable for element in self._resolved]
            )

        else:
            return self._resolved._select_iterable

    @property
    def _from_objects(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._from_objects for element in self._resolved]
            )

        else:
            return self._resolved._from_objects

    def _param_dict(self):
        return {b.key: b.value for b in self._resolved_bindparams}

    def _setup_binds_for_tracked_expr(self, expr):
        """Return ``expr`` with bound parameters swapped for the ones in
        ``self._resolved_bindparams`` that share the same key."""
        bindparam_lookup = {b.key: b for b in self._resolved_bindparams}

        def replace(thing):
            if isinstance(thing, elements.BindParameter):

                if thing.key in bindparam_lookup:
                    bind = bindparam_lookup[thing.key]
                    if thing.expanding:
                        bind.expanding = True
                        bind.expand_op = thing.expand_op
                        bind.type = thing.type
                    return bind

        if self._rec.is_sequence:
            expr = [
                visitors.replacement_traverse(sub_expr, {}, replace)
                for sub_expr in expr
            ]
        elif getattr(expr, "is_clause_element", False):
            expr = visitors.replacement_traverse(expr, {}, replace)

        return expr

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # TODO: this needs A LOT of tests
        self._resolved = clone(
            self._resolved,
            deferred_copy_internals=deferred_copy_internals,
            **kw
        )

    @util.memoized_property
    def _resolved(self):
        expr = self._rec.expected_expr

        if self._resolved_bindparams:
            expr = self._setup_binds_for_tracked_expr(expr)

        return expr

    def _gen_cache_key(self, anon_map, bindparams):
        """Cache key: code object(s) plus closure cache key(s) of this
        lambda and every parent in its chain."""
        if self.closure_cache_key is traversals.NO_CACHE:
            anon_map[traversals.NO_CACHE] = True
            return None

        cache_key = (
            self.fn.__code__,
            self.__class__,
        ) + self.closure_cache_key

        parent = self.parent_lambda
        while parent is not None:
            cache_key = (
                (parent.fn.__code__,) + parent.closure_cache_key + cache_key
            )

            parent = parent.parent_lambda

        if self._resolved_bindparams:
            bindparams.extend(self._resolved_bindparams)
        return cache_key

    def _invoke_user_fn(self, fn, *arg):
        return fn()
class DeferredLambdaElement(LambdaElement):
    """A :class:`.LambdaElement` whose lambda takes positional arguments
    and is only truly evaluated during compilation.

    Outside of the compile phase the lambda is run with the fixed
    ``lambda_args`` supplied at construction time, solely to obtain a
    sample expression that can be analyzed and cached.

    """

    def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()):
        # assigned before the base constructor runs, since that
        # constructor may end up calling _invoke_user_fn()
        self.lambda_args = lambda_args
        super(DeferredLambdaElement, self).__init__(fn, role, opts)

    def _invoke_user_fn(self, fn, *arg):
        sample_args = self.lambda_args
        return fn(*sample_args)

    def _resolve_with_args(self, *lambda_args):
        """Run the instrumented lambda with real arguments and return the
        coerced expression, with tracked bound values and any queued
        transformations applied."""
        instrumented = self._rec.tracker_instrumented_fn
        produced = instrumented(*lambda_args)

        produced = coercions.expect(self.role, produced)

        produced = self._setup_binds_for_tracked_expr(produced)

        # A check comparing the required bound parameters of this result
        # against those of the original sample run was considered for
        # #5767, but it gives false positives: when the base lambda uses
        # an unnamed column (common with mixins) the parameter names
        # differ even though the call is correct, so no validation is
        # performed here.

        # TODO: TEST TEST TEST, this is very out there
        for transform in self._transforms:
            produced = transform(produced)

        return produced

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # remaining keyword arguments are passed on bundled as a single
        # ``opts`` dict rather than unpacked
        super(DeferredLambdaElement, self)._copy_internals(
            clone=clone,
            deferred_copy_internals=deferred_copy_internals,
            opts=kw,
        )

        # TODO: A LOT A LOT of tests.  The real expression isn't known
        # until _resolve_with_args(), so hold onto the replacement and
        # apply it there.
        if deferred_copy_internals:
            self._transforms += (deferred_copy_internals,)
class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
    """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.

    The :class:`_sql.StatementLambdaElement` is constructed using the
    :func:`_sql.lambda_stmt` function::


        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: select(table))

    Once constructed, additional criteria can be built onto the statement
    by adding subsequent lambdas, which accept the existing statement
    object as a single parameter::

        stmt += lambda s: s.where(table.c.col == parameter)


    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    def __add__(self, other):
        return self.add_criteria(other)

    def add_criteria(
        self,
        other,
        enable_tracking=True,
        track_on=None,
        track_closure_variables=True,
        track_bound_values=True,
    ):
        """Add new criteria to this :class:`_sql.StatementLambdaElement`.

        The given callable is invoked with the statement built so far as
        its single argument and must return the new statement.  E.g.::

            >>> def my_stmt(parameter):
            ...     stmt = lambda_stmt(
            ...         lambda: select(table.c.x, table.c.y),
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(table.c.x > parameter)
            ...     )
            ...     return stmt

        The :meth:`_sql.StatementLambdaElement.add_criteria` method is
        equivalent to using the Python addition operator to add a new
        lambda, except that additional arguments may be added including
        ``track_closure_variables`` and ``track_on``::

            >>> def my_stmt(self, foo):
            ...     stmt = lambda_stmt(
            ...         lambda: select(func.max(foo.x, foo.y)),
            ...         track_closure_variables=False
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(self.where_criteria),
            ...         track_on=[self]
            ...     )
            ...     return stmt

        See :func:`_sql.lambda_stmt` for a description of the parameters
        accepted.  The new link's options are this statement's options
        combined with the given arguments; ``global_track_bound_values``
        always remains as set on this statement.

        :return: a new :class:`.LinkedLambdaElement` whose parent is this
         element.

        """

        opts = self.opts + dict(
            enable_tracking=enable_tracking,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=self.opts.global_track_bound_values,
            track_on=track_on,
            track_bound_values=track_bound_values,
        )

        return LinkedLambdaElement(other, parent_lambda=self, opts=opts)

    def _execute_on_connection(
        self, connection, multiparams, params, execution_options
    ):
        if self._rec.expected_expr.supports_execution:
            return connection._execute_clauseelement(
                self, multiparams, params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)

    @property
    def _with_options(self):
        return self._rec.expected_expr._with_options

    @property
    def _effective_plugin_target(self):
        return self._rec.expected_expr._effective_plugin_target

    @property
    def _execution_options(self):
        return self._rec.expected_expr._execution_options

    def spoil(self):
        """Return a new :class:`.NullLambdaStatement`, which provides the
        :class:`.StatementLambdaElement` API but runs all lambdas
        unconditionally each time.

        """
        if self.parent_lambda is not None:
            # a linked lambda must be called with the preceding statement,
            # so rebuild the chain from its root, invoking each link on
            # the un-cached statement produced by the one before it
            return self.parent_lambda.spoil().add_criteria(self.fn)
        return NullLambdaStatement(self.fn())
class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
    """Provides the :class:`.StatementLambdaElement` API but does not
    cache or analyze lambdas.

    The lambdas are instead invoked immediately.

    The intended use is to isolate issues that may arise when using
    lambda statements.

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    def __init__(self, statement):
        self._resolved = statement
        self._propagate_attrs = statement._propagate_attrs

    def __getattr__(self, key):
        # only reached for attributes not found normally.  If "_resolved"
        # itself is missing (e.g. an instance created without running
        # __init__, as copy/pickle machinery does), fail with
        # AttributeError instead of recursing through self._resolved.
        if key == "_resolved":
            raise AttributeError(key)
        return getattr(self._resolved, key)

    def __add__(self, other):
        return self.add_criteria(other)

    def add_criteria(self, other, **kw):
        """Invoke ``other`` on the current statement right away and wrap
        the result; tracking-related keyword arguments are ignored."""
        statement = other(self._resolved)

        return NullLambdaStatement(statement)

    def _execute_on_connection(
        self, connection, multiparams, params, execution_options
    ):
        if self._resolved.supports_execution:
            return connection._execute_clauseelement(
                self, multiparams, params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)
class LinkedLambdaElement(StatementLambdaElement):
    """Represent subsequent links of a :class:`.StatementLambdaElement`.

    Each link wraps a callable that receives the statement resolved by its
    ``parent_lambda`` and returns a new statement.

    """

    role = None

    def __init__(self, fn, parent_lambda, opts):
        # LambdaElement.__init__ is not called; the attributes the
        # analysis step reads are assigned here before it runs
        self.fn = fn
        self.opts = opts
        self.parent_lambda = parent_lambda

        # the tracker key accumulates the code object of every link in
        # the chain
        self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)

        self._retrieve_tracker_rec(fn, self, opts)

        # links take over the propagated attributes of the statement they
        # extend
        self._propagate_attrs = parent_lambda._propagate_attrs

    def _invoke_user_fn(self, fn, *arg):
        parent_statement = self.parent_lambda._resolved
        return fn(parent_statement)
class AnalyzedCode(object):
    """Analysis of a lambda's code object, shared by every lambda with the
    same ``__code__``.

    Determines which globals and closure variables look like SQL literals
    (to be wrapped in :class:`.PyWrapper` objects and turned into bound
    parameters) and builds the getter callables used to compute a cache
    key and to extract bound parameter values for each new lambda.

    """

    __slots__ = (
        "track_closure_variables",
        "track_bound_values",
        "bindparam_trackers",
        "closure_trackers",
        "build_py_wrappers",
    )
    # analysis records keyed on the lambda's __code__ object; weak keys so
    # that the registry itself does not keep code objects alive
    _fns = weakref.WeakKeyDictionary()

    # serializes creation of new AnalyzedCode / AnalyzedFunction records
    _generation_mutex = threading.RLock()

    @classmethod
    def get(cls, fn, lambda_element, lambda_kw, **kw):
        """Return the :class:`.AnalyzedCode` for ``fn.__code__``, creating
        it on first use.

        Note the record is keyed only on the code object, so the options
        in effect for the first lambda analyzed are the ones used.

        """
        try:
            # TODO: validate kw haven't changed?
            return cls._fns[fn.__code__]
        except KeyError:
            pass

        with cls._generation_mutex:
            # check for other thread already created object
            if fn.__code__ in cls._fns:
                return cls._fns[fn.__code__]

            cls._fns[fn.__code__] = analyzed = AnalyzedCode(
                fn, lambda_element, lambda_kw, **kw
            )
            return analyzed

    def __init__(self, fn, lambda_element, opts):
        """
        :raises ArgumentError: if ``fn`` is a bound method.
        """
        if inspect.ismethod(fn):
            raise exc.ArgumentError(
                "Method %s may not be passed as a SQL expression" % fn
            )
        closure = fn.__closure__

        self.track_bound_values = (
            opts.track_bound_values and opts.global_track_bound_values
        )
        enable_tracking = opts.enable_tracking
        track_on = opts.track_on
        track_closure_variables = opts.track_closure_variables

        # an explicit track_on replaces closure-variable tracking
        self.track_closure_variables = track_closure_variables and not track_on

        # a list of callables generated from _bound_parameter_getter_*
        # functions. Each of these uses a PyWrapper object to retrieve
        # a parameter value
        self.bindparam_trackers = []

        # a list of callables generated from _cache_key_getter_* functions
        # these callables work to generate a cache key for the lambda
        # based on what's inside its closure variables.
        self.closure_trackers = []

        # (name, closure_index or None for globals) of every variable that
        # will be wrapped in a PyWrapper when the function is instrumented
        self.build_py_wrappers = []

        if enable_tracking:
            if track_on:
                self._init_track_on(track_on)

            self._init_globals(fn)

            if closure:
                self._init_closure(fn)

        self._setup_additional_closure_trackers(fn, lambda_element, opts)

    def _init_track_on(self, track_on):
        self.closure_trackers.extend(
            self._cache_key_getter_track_on(idx, elem)
            for idx, elem in enumerate(track_on)
        )

    def _init_globals(self, fn):
        """Register literal-looking globals referenced by ``fn``."""
        build_py_wrappers = self.build_py_wrappers
        bindparam_trackers = self.bindparam_trackers
        track_bound_values = self.track_bound_values

        for name in fn.__code__.co_names:
            if name not in fn.__globals__:
                continue

            _bound_value = self._roll_down_to_literal(fn.__globals__[name])

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((name, None))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_globals(name)
                    )

    def _init_closure(self, fn):
        """Register closure variables of ``fn``: literal-looking ones get
        PyWrappers, the rest get cache key trackers."""
        build_py_wrappers = self.build_py_wrappers
        closure = fn.__closure__

        track_bound_values = self.track_bound_values
        track_closure_variables = self.track_closure_variables
        bindparam_trackers = self.bindparam_trackers
        closure_trackers = self.closure_trackers

        for closure_index, (fv, cell) in enumerate(
            zip(fn.__code__.co_freevars, closure)
        ):
            _bound_value = self._roll_down_to_literal(cell.cell_contents)

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((fv, closure_index))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_closure(
                            fv, closure_index
                        )
                    )
            else:
                # for normal cell contents, add them to a list that
                # we can compare later when we get new lambdas. if
                # any identities have changed, then we will
                # recalculate the whole lambda and run it again.

                if track_closure_variables:
                    closure_trackers.append(
                        self._cache_key_getter_closure_variable(
                            fn, fv, closure_index, cell.cell_contents
                        )
                    )

    def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
        # an additional step is to actually run the function, then
        # go through the PyWrapper objects that were set up to catch a bound
        # parameter. then if they *didn't* make a param, oh they're another
        # object in the closure we have to track for our cache key. so
        # create trackers to catch those.

        analyzed_function = AnalyzedFunction(
            self,
            lambda_element,
            None,
            fn,
        )

        closure_trackers = self.closure_trackers

        for pywrapper in analyzed_function.closure_pywrappers:
            if not pywrapper._sa__has_param:
                closure_trackers.append(
                    self._cache_key_getter_tracked_literal(fn, pywrapper)
                )

    @classmethod
    def _roll_down_to_literal(cls, element):
        """Unwrap ``element`` via ``__clause_element__()`` or inspection so
        it can be tested for being a literal."""
        is_clause_element = hasattr(element, "__clause_element__")

        if is_clause_element:
            while not isinstance(
                element, (elements.ClauseElement, schema.SchemaItem, type)
            ):
                try:
                    element = element.__clause_element__()
                except AttributeError:
                    break

        if not is_clause_element:
            insp = inspection.inspect(element, raiseerr=False)
            if insp is not None:
                try:
                    return insp.__clause_element__()
                except AttributeError:
                    return insp

            # TODO: should we coerce consts None/True/False here?
            return element
        else:
            return element

    def _bound_parameter_getter_func_globals(self, name):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__globals__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__globals__[name]
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__globals__[name], result
            )

        return extract_parameter_value

    def _bound_parameter_getter_func_closure(self, name, closure_index):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__closure__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__closure__[
                closure_index
            ].cell_contents
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__closure__[closure_index].cell_contents, result
            )

        return extract_parameter_value

    def _cache_key_getter_track_on(self, idx, elem):
        """Return a getter that will extend a cache key with new entries
        from the "track_on" parameter passed to a :class:`.LambdaElement`.

        """

        if isinstance(elem, tuple):
            # tuple must contain hascachekey elements
            def get(closure, opts, anon_map, bindparams):
                return tuple(
                    tup_elem._gen_cache_key(anon_map, bindparams)
                    for tup_elem in opts.track_on[idx]
                )

        elif isinstance(elem, traversals.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)

        else:

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]

        return get

    def _cache_key_getter_closure_variable(
        self,
        fn,
        variable_name,
        idx,
        cell_contents,
        use_clause_element=False,
        use_inspect=False,
    ):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        :raises InvalidRequestError: if ``cell_contents`` cannot contribute
         to a cache key.

        """

        if isinstance(cell_contents, traversals.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):

                obj = closure[idx].cell_contents
                if use_inspect:
                    obj = inspection.inspect(obj)
                elif use_clause_element:
                    while hasattr(obj, "__clause_element__"):
                        if not getattr(obj, "is_clause_element", False):
                            obj = obj.__clause_element__()
                        else:
                            # already a clause element; stop here, as the
                            # analysis-time loop below does, rather than
                            # spinning forever on an object that is both a
                            # clause element and has __clause_element__
                            break

                return obj._gen_cache_key(anon_map, bindparams)

        elif isinstance(cell_contents, types.FunctionType):

            def get(closure, opts, anon_map, bindparams):
                return closure[idx].cell_contents.__code__

        elif isinstance(cell_contents, collections_abc.Sequence):

            def get(closure, opts, anon_map, bindparams):
                contents = closure[idx].cell_contents

                try:
                    return tuple(
                        elem._gen_cache_key(anon_map, bindparams)
                        for elem in contents
                    )
                except AttributeError as ae:
                    self._raise_for_uncacheable_closure_variable(
                        variable_name, fn, from_=ae
                    )

        else:
            # if the object is a mapped class or aliased class, or some
            # other object in the ORM realm of things like that, imitate
            # the logic used in coercions.expect() to roll it down to the
            # SQL element
            element = cell_contents
            is_clause_element = False
            while hasattr(element, "__clause_element__"):
                is_clause_element = True
                if not getattr(element, "is_clause_element", False):
                    element = element.__clause_element__()
                else:
                    break

            if not is_clause_element:
                insp = inspection.inspect(element, raiseerr=False)
                if insp is not None:
                    return self._cache_key_getter_closure_variable(
                        fn, variable_name, idx, insp, use_inspect=True
                    )
            else:
                return self._cache_key_getter_closure_variable(
                    fn, variable_name, idx, element, use_clause_element=True
                )

            self._raise_for_uncacheable_closure_variable(variable_name, fn)

        return get

    def _raise_for_uncacheable_closure_variable(
        self, variable_name, fn, from_=None
    ):
        util.raise_(
            exc.InvalidRequestError(
                "Closure variable named '%s' inside of lambda callable %s "
                "does not refer to a cacheable SQL element, and also does not "
                "appear to be serving as a SQL literal bound value based on "
                "the default "
                "SQL expression returned by the function. This variable "
                "needs to remain outside the scope of a SQL-generating lambda "
                "so that a proper cache key may be generated from the "
                "lambda's state. Evaluate this variable outside of the "
                "lambda, set track_on=[<elements>] to explicitly select "
                "closure elements to track, or set "
                "track_closure_variables=False to exclude "
                "closure variables from being part of the cache key."
                % (variable_name, fn.__code__),
            ),
            from_=from_,
        )

    def _cache_key_getter_tracked_literal(self, fn, pytracker):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        this getter differs from _cache_key_getter_closure_variable
        in that these are detected after the function is run, and PyWrapper
        objects have recorded that a particular literal value is in fact
        not being interpreted as a bound parameter.

        """

        elem = pytracker._sa__to_evaluate
        closure_index = pytracker._sa__closure_index
        variable_name = pytracker._sa__name

        return self._cache_key_getter_closure_variable(
            fn, variable_name, closure_index, elem
        )
class NonAnalyzedFunction(object):
    """Stand-in for :class:`.AnalyzedFunction` used when no cache key can
    be generated for a lambda.

    The lambda has simply been invoked once; its return value is exposed
    unchanged through :attr:`.expected_expr`, with no bound parameter
    tracking.

    """

    __slots__ = ("expr",)

    # no bound parameter tracking is performed for non-analyzed lambdas
    closure_bindparams = None
    bindparam_trackers = None

    # LambdaElement reads ``is_sequence`` from whatever record it holds
    # (_is_sequence, _setup_binds_for_tracked_expr); the raw expression is
    # treated as a single element
    is_sequence = False

    def __init__(self, expr):
        self.expr = expr

    @property
    def expected_expr(self):
        return self.expr

    @property
    def propagate_attrs(self):
        """Propagate attributes of the raw expression.

        LambdaElement.__init__ reads this from the record for statement
        lambdas; like NullLambdaStatement, which also skips analysis, the
        values are taken from the statement itself.

        """
        try:
            return self.expr._propagate_attrs
        except AttributeError:
            return util.EMPTY_DICT
class AnalyzedFunction(object):
    """The result of instrumenting and invoking one particular lambda.

    Holds a copy of the lambda whose literal-looking globals / closure
    cells are replaced by :class:`.PyWrapper` objects, the expression
    produced by invoking that copy once, and the coerced form of that
    expression.  LambdaElement._retrieve_tracker_rec() stores instances in
    the lambda cache; AnalyzedCode._setup_additional_closure_trackers()
    also builds one just to inspect the PyWrappers it created.

    """

    __slots__ = (
        "analyzed_code",
        "fn",
        "closure_pywrappers",
        "tracker_instrumented_fn",
        "expr",
        "bindparam_trackers",
        "expected_expr",
        "is_sequence",
        "propagate_attrs",
        # assigned from outside, by LambdaElement._retrieve_tracker_rec()
        "closure_bindparams",
    )

    def __init__(
        self,
        analyzed_code,
        lambda_element,
        apply_propagate_attrs,
        fn,
    ):
        """Instrument ``fn``, invoke it once and coerce the result.

        :param analyzed_code: the :class:`.AnalyzedCode` for ``fn.__code__``.
        :param lambda_element: element on whose behalf ``fn`` is run; its
         ``_invoke_user_fn()`` performs the call, and its ``role`` and
         ``parent_lambda`` drive coercion.
        :param apply_propagate_attrs: passed through to
         ``coercions.expect()``; may be None.
        :param fn: the user's lambda.

        """
        self.analyzed_code = analyzed_code
        self.fn = fn

        self.bindparam_trackers = analyzed_code.bindparam_trackers

        self._instrument_and_run_function(lambda_element)

        self._coerce_expression(lambda_element, apply_propagate_attrs)

    def _instrument_and_run_function(self, lambda_element):
        """Build ``tracker_instrumented_fn`` and store its return value in
        ``self.expr``; also fills ``self.closure_pywrappers``."""
        analyzed_code = self.analyzed_code

        fn = self.fn
        self.closure_pywrappers = closure_pywrappers = []

        build_py_wrappers = analyzed_code.build_py_wrappers

        if not build_py_wrappers:
            # nothing to wrap: the original function serves as the
            # instrumented one
            self.tracker_instrumented_fn = tracker_instrumented_fn = fn
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
        else:
            track_closure_variables = analyzed_code.track_closure_variables
            closure = fn.__closure__

            # will form the __closure__ of the function when we rebuild it
            if closure:
                new_closure = {
                    fv: cell.cell_contents
                    for fv, cell in zip(fn.__code__.co_freevars, closure)
                }
            else:
                new_closure = {}

            # will form the __globals__ of the function when we rebuild it
            new_globals = fn.__globals__.copy()

            for name, closure_index in build_py_wrappers:
                if closure_index is not None:
                    value = closure[closure_index].cell_contents
                    new_closure[name] = bind = PyWrapper(
                        fn,
                        name,
                        value,
                        closure_index=closure_index,
                        track_bound_values=(
                            self.analyzed_code.track_bound_values
                        ),
                    )
                    # only closure wrappers are collected; they are later
                    # checked for whether they produced a bound parameter
                    if track_closure_variables:
                        closure_pywrappers.append(bind)
                else:
                    # NOTE(review): global wrappers are created without
                    # passing track_bound_values, unlike closure wrappers
                    # above -- confirm this asymmetry is intended.
                    value = fn.__globals__[name]
                    new_globals[name] = bind = PyWrapper(fn, name, value)

            # rewrite the original fn. things that look like they will
            # become bound parameters are wrapped in a PyWrapper.
            self.tracker_instrumented_fn = (
                tracker_instrumented_fn
            ) = self._rewrite_code_obj(
                fn,
                [new_closure[name] for name in fn.__code__.co_freevars],
                new_globals,
            )

            # now invoke the function. This will give us a new SQL
            # expression, but all the places that there would be a bound
            # parameter, the PyWrapper in its place will give us a bind
            # with a predictable name we can match up later.

            # additionally, each PyWrapper will log that it did in fact
            # create a parameter, otherwise, it's some kind of Python
            # object in the closure and we want to track that, to make
            # sure it doesn't change to something else, or if it does,
            # that we create a different tracked function with that
            # variable.
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)

    def _coerce_expression(self, lambda_element, apply_propagate_attrs):
        """Run the tracker-generated expression through coercion rules.

        After the user-defined lambda has been invoked to produce a statement
        for re-use, run it through coercion rules to both check that it's the
        correct type of object and also to coerce it to its useful form.

        Linked lambdas (those with a ``parent_lambda``) are not coerced.
        Sets ``expected_expr``, ``is_sequence`` and ``propagate_attrs``.

        """

        parent_lambda = lambda_element.parent_lambda
        expr = self.expr

        if parent_lambda is None:
            if isinstance(expr, collections_abc.Sequence):
                # each element of a returned sequence is coerced on its own
                self.expected_expr = [
                    coercions.expect(
                        lambda_element.role,
                        sub_expr,
                        apply_propagate_attrs=apply_propagate_attrs,
                    )
                    for sub_expr in expr
                ]
                self.is_sequence = True
            else:
                self.expected_expr = coercions.expect(
                    lambda_element.role,
                    expr,
                    apply_propagate_attrs=apply_propagate_attrs,
                )
                self.is_sequence = False
        else:
            self.expected_expr = expr
            self.is_sequence = False

        if apply_propagate_attrs is not None:
            self.propagate_attrs = apply_propagate_attrs._propagate_attrs
        else:
            self.propagate_attrs = util.EMPTY_DICT

    def _rewrite_code_obj(self, f, cell_values, globals_):
        """Return a copy of f, with a new closure and new globals

        yes it works in pypy :P

        :param f: function to copy; its code object is reused as-is.
        :param cell_values: new values for ``f``'s free variables, in
         ``co_freevars`` order.
        :param globals_: the globals mapping for the new function.

        """

        argrange = range(len(cell_values))

        # generate source for a helper whose inner function closes over
        # locals i0..iN bound to o0..oN; returning that inner function's
        # __closure__ yields real cell objects holding the new values
        code = "def make_cells():\n"
        if cell_values:
            code += "    (%s) = (%s)\n" % (
                ", ".join("i%d" % i for i in argrange),
                ", ".join("o%d" % i for i in argrange),
            )
        code += "    def closure():\n"
        code += "        return %s\n" % ", ".join("i%d" % i for i in argrange)
        code += "    return closure.__closure__"
        vars_ = {"o%d" % i: cell_values[i] for i in argrange}
        compat.exec_(code, vars_, vars_)
        closure = vars_["make_cells"]()

        func = type(f)(
            f.__code__, globals_, f.__name__, f.__defaults__, closure
        )
        # these attributes only exist on Python 3 functions
        if sys.version_info >= (3,):
            func.__annotations__ = f.__annotations__
            func.__kwdefaults__ = f.__kwdefaults__
        func.__doc__ = f.__doc__
        func.__module__ = f.__module__

        return func
class PyWrapper(ColumnOperators):
    """A wrapper object that is injected into the ``__globals__`` and
    ``__closure__`` of a Python function.

    When the function is instrumented with :class:`.PyWrapper` objects, it is
    then invoked just once in order to set up the wrappers. We look through
    all the :class:`.PyWrapper` objects we made to find the ones that generated
    a :class:`.BindParameter` object, e.g. the expression system interpreted
    something as a literal. Those positions in the globals/closure are then
    ones that we will look at, each time a new lambda comes in that refers to
    the same ``__code__`` object. In this way, we keep a single version of
    the SQL expression that this lambda produced, without calling upon the
    Python function that created it more than once, unless its other closure
    variables have changed. The expression is then transformed to have the
    new bound values embedded into it.

    Attribute access is intercepted by :meth:`__getattribute__`:

    * a name prefixed with ``_sa_`` reads the wrapper's own attribute with
      the prefix stripped (e.g. ``self._sa_fn`` reads ``fn``);
    * ``__clause_element__``, ``operate``, ``reverse_operate``,
      ``__class__`` and ``__dict__`` resolve on the wrapper itself;
    * any other dunder name is fetched from the wrapped value;
    * every remaining name goes through :meth:`_add_getter`.

    """

    def __init__(
        self,
        fn,
        name,
        to_evaluate,
        closure_index=None,
        getter=None,
        track_bound_values=True,
    ):
        # the user's lambda; its __code__ appears in error messages and it
        # is handed on to child wrappers
        self.fn = fn
        # name for the BindParameter this wrapper may create; also shown
        # in the "Can't invoke Python callable" error
        self._name = name
        # the real Python value being wrapped
        self._to_evaluate = to_evaluate
        # BindParameter created lazily by __clause_element__
        self._param = None
        self._has_param = False
        # child wrappers created by attribute / item access, keyed by
        # (key, getter factory)
        self._bind_paths = {}
        # for a child wrapper: callable extracting this value from the
        # parent's value
        self._getter = getter
        # not read within this class; presumably the position of the
        # wrapped variable in the lambda's closure -- TODO confirm
        self._closure_index = closure_index
        # when true, calling this wrapper to produce a literal is an error
        self.track_bound_values = track_bound_values

    def __call__(self, *arg, **kw):
        """Invoke the wrapped callable and return its result.

        :raises exc.InvalidRequestError: when ``track_bound_values`` is set
          and the result is a literal (per ``coercions._deep_is_literal``)
          that is not a ``HasCacheKey``.  Such a value could not be
          re-extracted later without invoking the lambda again.
        """
        elem = object.__getattribute__(self, "_to_evaluate")
        value = elem(*arg, **kw)
        if (
            self._sa_track_bound_values
            and coercions._deep_is_literal(value)
            and not isinstance(
                # TODO: coverage where an ORM option or similar is here
                value,
                traversals.HasCacheKey,
            )
        ):
            name = object.__getattribute__(self, "_name")
            raise exc.InvalidRequestError(
                "Can't invoke Python callable %s() inside of lambda "
                "expression argument at %s; lambda SQL constructs should "
                "not invoke functions from closure variables to produce "
                "literal values since the "
                "lambda SQL system normally extracts bound values without "
                "actually "
                "invoking the lambda or any functions within it. Call the "
                "function outside of the "
                "lambda and assign to a local variable that is used in the "
                "lambda as a closure variable, or set "
                "track_bound_values=False if the return value of this "
                "function is used in some other way other than a SQL bound "
                "value." % (name, self._sa_fn.__code__)
            )
        else:
            return value

    def operate(self, op, *other, **kwargs):
        """Apply ``op`` with this wrapper's bind parameter on the left."""
        elem = object.__getattribute__(self, "__clause_element__")()
        return op(elem, *other, **kwargs)

    def reverse_operate(self, op, other, **kwargs):
        """Apply ``op`` with this wrapper's bind parameter on the right."""
        elem = object.__getattribute__(self, "__clause_element__")()
        return op(other, elem, **kwargs)

    def _extract_bound_parameters(self, starting_point, result_list):
        """Append bind parameters carrying values taken from
        ``starting_point``.

        If this wrapper created a bind parameter, a copy holding
        ``starting_point`` as its value (same key) is appended.  Each child
        wrapper is then visited with the value its getter extracts from
        ``starting_point``.
        """
        param = object.__getattribute__(self, "_param")
        if param is not None:
            param = param._with_value(starting_point, maintain_key=True)
            result_list.append(param)
        for pywrapper in object.__getattribute__(self, "_bind_paths").values():
            getter = object.__getattribute__(pywrapper, "_getter")
            element = getter(starting_point)
            pywrapper._sa__extract_bound_parameters(element, result_list)

    def __clause_element__(self):
        """Return this wrapper's bind parameter holding the current value.

        On first use, a unique, non-required ``BindParameter`` named after
        ``_name`` is created, with its type resolved from the wrapped value,
        and ``_has_param`` is set.
        """
        param = object.__getattribute__(self, "_param")
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        if param is None:
            name = object.__getattribute__(self, "_name")
            self._param = param = elements.BindParameter(
                name, required=False, unique=True
            )
            self._has_param = True
            param.type = type_api._resolve_value_to_type(to_evaluate)
        return param._with_value(to_evaluate, maintain_key=True)

    def __bool__(self):
        """Truthiness of the wrapped value."""
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    # Python 2 spelling of __bool__
    __nonzero__ = __bool__

    def __getattribute__(self, key):
        """Route attribute access; see the class docstring for the rules."""
        if key.startswith("_sa_"):
            return object.__getattribute__(self, key[4:])
        elif key in (
            "__clause_element__",
            "operate",
            "reverse_operate",
            "__class__",
            "__dict__",
        ):
            return object.__getattribute__(self, key)

        if key.startswith("__"):
            elem = object.__getattribute__(self, "_to_evaluate")
            return getattr(elem, key)
        else:
            return self._sa__add_getter(key, operator.attrgetter)

    def __iter__(self):
        """Iterate the wrapped value directly (no wrapping)."""
        elem = object.__getattribute__(self, "_to_evaluate")
        return iter(elem)

    def __getitem__(self, key):
        """Item access on the wrapped value, routed through _add_getter.

        :raises AttributeError: if the wrapped value has no ``__getitem__``.
        :raises exc.InvalidRequestError: if ``key`` is itself a PyWrapper,
          i.e. not a plain Python literal.
        """
        elem = object.__getattribute__(self, "_to_evaluate")
        if not hasattr(elem, "__getitem__"):
            raise AttributeError("__getitem__")

        if isinstance(key, PyWrapper):
            # TODO: coverage
            raise exc.InvalidRequestError(
                "Dictionary keys / list indexes inside of a cached "
                "lambda must be Python literals only"
            )
        return self._sa__add_getter(key, operator.itemgetter)

    def _add_getter(self, key, getter_fn):
        """Return ``getter_fn(key)`` applied to the wrapped value.

        If the rolled-down result passes ``coercions._deep_is_literal``, it
        is wrapped in a child :class:`.PyWrapper` that is cached in
        ``_bind_paths`` under ``(key, getter_fn)``.  Later access with the
        same key and getter returns that same child.  Otherwise the raw
        value is returned.
        """
        bind_paths = object.__getattribute__(self, "_bind_paths")

        bind_path_key = (key, getter_fn)
        if bind_path_key in bind_paths:
            return bind_paths[bind_path_key]

        getter = getter_fn(key)
        elem = object.__getattribute__(self, "_to_evaluate")
        value = getter(elem)

        rolled_down_value = AnalyzedCode._roll_down_to_literal(value)

        if coercions._deep_is_literal(rolled_down_value):
            # Pass this wrapper's track_bound_values down, so that a
            # callable reached via attribute / item access (e.g.
            # ``obj.method()``) honors track_bound_values=False just as a
            # callable used directly does.  Previously the child always
            # defaulted to True.
            wrapper = PyWrapper(
                self._sa_fn,
                key,
                value,
                getter=getter,
                track_bound_values=self._sa_track_bound_values,
            )
            bind_paths[bind_path_key] = wrapper
            return wrapper
        else:
            return value
@inspection._inspects(LambdaElement)
def insp(lmb):
    """Inspection hook registered for :class:`.LambdaElement`.

    Inspecting a lambda element delegates to inspecting whatever object is
    stored on its ``_resolved`` attribute (presumably the statement the
    lambda resolved to -- confirm against LambdaElement).
    """
    resolved_target = lmb._resolved
    return inspection.inspect(resolved_target)