Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/sqlalchemy/sql/lambdas.py: 24%
531 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1# sql/lambdas.py
2# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
8import inspect
9import itertools
10import operator
11import sys
12import threading
13import types
14import weakref
16from . import coercions
17from . import elements
18from . import roles
19from . import schema
20from . import traversals
21from . import visitors
22from .base import _clone
23from .base import Options
24from .operators import ColumnOperators
25from .. import exc
26from .. import inspection
27from .. import util
28from ..util import collections_abc
29from ..util import compat
# Default store for analyzed lambda records, used by
# LambdaElement._retrieve_tracker_rec() whenever LambdaOptions.lambda_cache
# is None.  Entries are keyed on ``tracker_key + closure cache key`` and hold
# AnalyzedFunction objects.  NOTE(review): the argument 1000 is presumably the
# LRU capacity of util.LRUCache -- confirm against sqlalchemy.util.
_closure_per_cache_key = util.LRUCache(1000)
class LambdaOptions(Options):
    # Default tracking / caching settings for lambda elements.  lambda_stmt()
    # builds an instance with per-statement values, add_criteria() derives
    # new ones via ``opts + dict(...)``, and LambdaElement uses this class
    # itself as its default ``opts``.

    # when False, AnalyzedCode skips scanning globals, closure and track_on
    enable_tracking = True
    # when False, non-literal closure variables add no cache-key trackers
    track_closure_variables = True
    # optional sequence of objects to build the closure cache key from;
    # when set, closure-variable tracking is switched off
    track_on = None
    # combined (and) with track_bound_values; add_criteria() carries the
    # root statement's value forward to every link
    global_track_bound_values = True
    # when False, literal values are not extracted as bound parameters
    track_bound_values = True
    # mapping used to cache analyzed records; None selects the module-level
    # _closure_per_cache_key cache
    lambda_cache = None
def lambda_stmt(
    lmb,
    enable_tracking=True,
    track_closure_variables=True,
    track_on=None,
    global_track_bound_values=True,
    track_bound_values=True,
    lambda_cache=None,
):
    """Produce a SQL statement that is cached as a lambda.

    The Python code object within the lambda is scanned for both Python
    literals that will become bound parameters as well as closure variables
    that refer to Core or ORM constructs that may vary.  The lambda itself
    will be invoked only once per particular set of constructs detected.

    E.g.::

        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: table.select())
        stmt += lambda s: s.where(table.c.id == 5)

        result = connection.execute(stmt)

    The object returned is an instance of :class:`_sql.StatementLambdaElement`.

    .. versionadded:: 1.4

    :param lmb: a Python function, typically a lambda, which takes no arguments
     and returns a SQL expression construct
    :param enable_tracking: when False, all scanning of the given lambda for
     changes in closure variables or bound parameters is disabled.  Use for
     a lambda that produces the identical results in all cases with no
     parameterization.
    :param track_closure_variables: when False, changes in closure variables
     within the lambda will not be scanned.   Use for a lambda where the
     state of its closure variables will never change the SQL structure
     returned by the lambda.
    :param track_on: optional sequence of objects that will be used to
     build the cache key of the lambda instead of its closure variables;
     when given, closure variable tracking is disabled.  Elements that
     support SQL cache keys contribute their cache key, tuples contribute
     the cache keys of their members, and any other object is used as-is.
    :param track_bound_values: when False, bound parameter tracking will
     be disabled for the given lambda.  Use for a lambda that either does
     not produce any bound values, or where the initial bound values never
     change.
    :param global_track_bound_values: when False, bound parameter tracking
     will be disabled for the entire statement including additional links
     added via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
    :param lambda_cache: a dictionary or other mapping-like object where
     information about the lambda's Python code as well as the tracked closure
     variables in the lambda itself will be stored.   Defaults
     to a global LRU cache.  This cache is independent of the "compiled_cache"
     used by the :class:`_engine.Connection` object.

    .. seealso::

        :ref:`engine_lambda_caching`


    """
    # gather every tracking / caching knob into a single options object,
    # then hand the lambda over as a full statement
    options = LambdaOptions(
        enable_tracking=enable_tracking,
        track_on=track_on,
        track_closure_variables=track_closure_variables,
        global_track_bound_values=global_track_bound_values,
        track_bound_values=track_bound_values,
        lambda_cache=lambda_cache,
    )
    statement_role = roles.StatementRole
    return StatementLambdaElement(lmb, statement_role, options)
class LambdaElement(elements.ClauseElement):
    """A SQL construct where the state is stored as an un-invoked lambda.

    The :class:`_sql.LambdaElement` is produced transparently whenever
    passing lambda expressions into SQL constructs, such as::

        stmt = select(table).where(lambda: table.c.col == parameter)

    The :class:`_sql.LambdaElement` is the base of the
    :class:`_sql.StatementLambdaElement` which represents a full statement
    within a lambda.

    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    # deferred_copy_internals callables that could not be applied yet
    # (appended by DeferredLambdaElement._copy_internals)
    _transforms = ()

    # previous link of a lambda chain; None for a root element
    parent_lambda = None

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)

    def __init__(
        self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
    ):
        self.fn = fn
        self.role = role
        self.tracker_key = (fn.__code__,)
        self.opts = opts

        # a full statement receives its own propagate attrs
        if apply_propagate_attrs is None and (role is roles.StatementRole):
            apply_propagate_attrs = self

        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)

        if apply_propagate_attrs is not None:
            # records built for uncacheable lambdas (NonAnalyzedFunction)
            # may not define propagate_attrs; treat that as "nothing to
            # propagate" instead of failing with AttributeError
            propagate_attrs = getattr(rec, "propagate_attrs", None)
            if propagate_attrs:
                apply_propagate_attrs._propagate_attrs = propagate_attrs

    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
        """Locate or build the analyzed record for this lambda.

        Computes this element's closure cache key (prefixed by the parent
        link's key), looks up a cached record under
        ``tracker_key + cache_key`` and builds / stores an
        :class:`.AnalyzedFunction` when none exists.  Closure bound
        parameters are re-keyed onto a cached record's parameters, then the
        bound-parameter trackers of this element and every parent link fill
        ``self._resolved_bindparams`` with the current literal values.

        :return: the record assigned to ``self._rec``.

        """
        lambda_cache = opts.lambda_cache
        if lambda_cache is None:
            lambda_cache = _closure_per_cache_key

        tracker_key = self.tracker_key

        fn = self.fn
        closure = fn.__closure__
        tracker = AnalyzedCode.get(
            fn,
            self,
            opts,
        )

        self._resolved_bindparams = bindparams = []

        if self.parent_lambda is not None:
            parent_closure_cache_key = self.parent_lambda.closure_cache_key
        else:
            parent_closure_cache_key = ()

        if parent_closure_cache_key is not traversals.NO_CACHE:
            anon_map = traversals.anon_map()
            cache_key = tuple(
                [
                    getter(closure, opts, anon_map, bindparams)
                    for getter in tracker.closure_trackers
                ]
            )

            if traversals.NO_CACHE not in anon_map:
                cache_key = parent_closure_cache_key + cache_key

                try:
                    rec = lambda_cache[tracker_key + cache_key]
                except KeyError:
                    rec = None
            else:
                cache_key = traversals.NO_CACHE
                rec = None

        else:
            cache_key = traversals.NO_CACHE
            rec = None

        self.closure_cache_key = cache_key

        if rec is None:
            if cache_key is not traversals.NO_CACHE:

                with AnalyzedCode._generation_mutex:
                    key = tracker_key + cache_key
                    if key not in lambda_cache:
                        rec = AnalyzedFunction(
                            tracker, self, apply_propagate_attrs, fn
                        )
                        rec.closure_bindparams = bindparams
                        lambda_cache[key] = rec
                    else:
                        # another thread stored a record for this key after
                        # our unlocked lookup missed; handle it exactly like
                        # a cache hit so our closure bind values are re-keyed
                        # onto the parameters of the cached expression
                        rec = lambda_cache[key]
                        self._align_closure_bindparams(rec, bindparams)
            else:
                rec = NonAnalyzedFunction(self._invoke_user_fn(fn))

        else:
            self._align_closure_bindparams(rec, bindparams)

        self._rec = rec

        if cache_key is not traversals.NO_CACHE:
            if self.parent_lambda is not None:
                bindparams[:0] = self.parent_lambda._resolved_bindparams

            # walk this element and its parent links, letting each record's
            # trackers extract current literal values; a separate name is
            # used so ``rec`` still refers to this element's record below
            lambda_element = self
            while lambda_element is not None:
                element_rec = lambda_element._rec
                if element_rec.bindparam_trackers:
                    tracker_instrumented_fn = (
                        element_rec.tracker_instrumented_fn
                    )
                    for bind_tracker in element_rec.bindparam_trackers:
                        bind_tracker(
                            lambda_element.fn,
                            tracker_instrumented_fn,
                            bindparams,
                        )
                lambda_element = lambda_element.parent_lambda

        return rec

    @staticmethod
    def _align_closure_bindparams(rec, bindparams):
        """Replace ``bindparams`` in place with copies of the cached record's
        closure bind parameters carrying the new values (keys preserved),
        matched up positionally.

        """
        bindparams[:] = [
            orig_bind._with_value(new_bind.value, maintain_key=True)
            for orig_bind, new_bind in zip(rec.closure_bindparams, bindparams)
        ]

    def __getattr__(self, key):
        return getattr(self._rec.expected_expr, key)

    @property
    def _is_sequence(self):
        # records for uncacheable lambdas may lack is_sequence
        return getattr(self._rec, "is_sequence", False)

    @property
    def _select_iterable(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._select_iterable for element in self._resolved]
            )

        else:
            return self._resolved._select_iterable

    @property
    def _from_objects(self):
        if self._is_sequence:
            return itertools.chain.from_iterable(
                [element._from_objects for element in self._resolved]
            )

        else:
            return self._resolved._from_objects

    def _param_dict(self):
        return {b.key: b.value for b in self._resolved_bindparams}

    def _setup_binds_for_tracked_expr(self, expr):
        """Swap bind parameters in ``expr`` for the resolved parameters
        with the same key, keeping the expanding settings and type of the
        parameter being replaced.

        """
        bindparam_lookup = {b.key: b for b in self._resolved_bindparams}

        def replace(thing):
            if isinstance(thing, elements.BindParameter):

                if thing.key in bindparam_lookup:
                    bind = bindparam_lookup[thing.key]
                    if thing.expanding:
                        bind.expanding = True
                        bind.expand_op = thing.expand_op
                        bind.type = thing.type
                    return bind

        if self._is_sequence:
            expr = [
                visitors.replacement_traverse(sub_expr, {}, replace)
                for sub_expr in expr
            ]
        elif getattr(expr, "is_clause_element", False):
            expr = visitors.replacement_traverse(expr, {}, replace)

        return expr

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # TODO: this needs A LOT of tests
        self._resolved = clone(
            self._resolved,
            deferred_copy_internals=deferred_copy_internals,
            **kw
        )

    @util.memoized_property
    def _resolved(self):
        expr = self._rec.expected_expr

        if self._resolved_bindparams:
            expr = self._setup_binds_for_tracked_expr(expr)

        return expr

    def _gen_cache_key(self, anon_map, bindparams):
        if self.closure_cache_key is traversals.NO_CACHE:
            anon_map[traversals.NO_CACHE] = True
            return None

        cache_key = (
            self.fn.__code__,
            self.__class__,
        ) + self.closure_cache_key

        parent = self.parent_lambda
        while parent is not None:
            cache_key = (
                (parent.fn.__code__,) + parent.closure_cache_key + cache_key
            )

            parent = parent.parent_lambda

        if self._resolved_bindparams:
            bindparams.extend(self._resolved_bindparams)
        return cache_key

    def _invoke_user_fn(self, fn, *arg):
        return fn()
class DeferredLambdaElement(LambdaElement):
    """A LambdaElement where the lambda accepts arguments and is
    invoked within the compile phase with special context.

    This lambda doesn't normally produce its real SQL expression outside of the
    compile phase.  It is passed a fixed set of initial arguments
    so that it can generate a sample expression.

    """

    def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()):
        # must be set before the base __init__ runs, because building the
        # sample expression calls _invoke_user_fn(), which reads it
        self.lambda_args = lambda_args
        super(DeferredLambdaElement, self).__init__(fn, role, opts)

    def _invoke_user_fn(self, fn, *arg):
        # the sample expression is produced from the fixed initial arguments
        return fn(*self.lambda_args)

    def _resolve_with_args(self, *lambda_args):
        """Produce the real expression for the given arguments.

        Calls the tracker-instrumented function with ``lambda_args``,
        coerces the result to this element's role, swaps in the resolved
        bound parameters, then applies any deferred copy transforms that
        were recorded by _copy_internals().

        """
        tracker_fn = self._rec.tracker_instrumented_fn
        expr = tracker_fn(*lambda_args)

        expr = coercions.expect(self.role, expr)

        expr = self._setup_binds_for_tracked_expr(expr)

        # this validation is getting very close, but not quite, to achieving
        # #5767.  The problem is if the base lambda uses an unnamed column
        # as is very common with mixins, the parameter name is different
        # and it produces a false positive; that is, for the documented case
        # that is exactly what people will be doing, it doesn't work, so
        # I'm not really sure how to handle this right now.
        # expected_binds = [
        #    b._orig_key
        #    for b in self._rec.expr._generate_cache_key()[1]
        #    if b.required
        # ]
        # got_binds = [
        #    b._orig_key for b in expr._generate_cache_key()[1] if b.required
        # ]
        # if expected_binds != got_binds:
        #    raise exc.InvalidRequestError(
        #        "Lambda callable at %s produced a different set of bound "
        #        "parameters than its original run: %s"
        #        % (self.fn.__code__, ", ".join(got_binds))
        #    )

        # TODO: TEST TEST TEST, this is very out there
        for deferred_copy_internals in self._transforms:
            expr = deferred_copy_internals(expr)

        return expr

    def _copy_internals(
        self, clone=_clone, deferred_copy_internals=None, **kw
    ):
        # NOTE(review): the extra keyword arguments are forwarded to the base
        # implementation as one ``opts`` dict rather than expanded; the
        # "# **kw" marker suggests this was deliberate -- confirm what
        # clone() expects before changing it.
        super(DeferredLambdaElement, self)._copy_internals(
            clone=clone,
            deferred_copy_internals=deferred_copy_internals,  # **kw
            opts=kw,
        )

        # TODO: A LOT A LOT of tests.   for _resolve_with_args, we don't know
        # our expression yet.   so hold onto the replacement
        # (``+=`` on the class-level tuple creates a per-instance tuple)
        if deferred_copy_internals:
            self._transforms += (deferred_copy_internals,)
class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
    """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.

    The :class:`_sql.StatementLambdaElement` is constructed using the
    :func:`_sql.lambda_stmt` function::


        from sqlalchemy import lambda_stmt

        stmt = lambda_stmt(lambda: select(table))

    Once constructed, additional criteria can be built onto the statement
    by adding subsequent lambdas, which accept the existing statement
    object as a single parameter::

        stmt += lambda s: s.where(table.c.col == parameter)


    .. versionadded:: 1.4

    .. seealso::

        :ref:`engine_lambda_caching`

    """

    def __add__(self, other):
        return self.add_criteria(other)

    def add_criteria(
        self,
        other,
        enable_tracking=True,
        track_on=None,
        track_closure_variables=True,
        track_bound_values=True,
    ):
        """Add new criteria to this :class:`_sql.StatementLambdaElement`.

        The given callable receives the statement produced so far as its
        single argument and returns the new statement.

        E.g.::

            >>> def my_stmt(parameter):
            ...     stmt = lambda_stmt(
            ...         lambda: select(table.c.x, table.c.y),
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(table.c.x > parameter)
            ...     )
            ...     return stmt

        The :meth:`_sql.StatementLambdaElement.add_criteria` method is
        equivalent to using the Python addition operator to add a new
        lambda, except that additional arguments may be added including
        ``track_closure_variables`` and ``track_on``::

            >>> def my_stmt(self, foo):
            ...     stmt = lambda_stmt(
            ...         lambda: select(func.max(foo.x, foo.y)),
            ...         track_closure_variables=False
            ...     )
            ...     stmt = stmt.add_criteria(
            ...         lambda s: s.where(self.where_criteria),
            ...         track_on=[self]
            ...     )
            ...     return stmt

        See :func:`_sql.lambda_stmt` for a description of the parameters
        accepted.  ``global_track_bound_values`` is inherited from this
        statement's options.

        """

        opts = self.opts + dict(
            enable_tracking=enable_tracking,
            track_closure_variables=track_closure_variables,
            global_track_bound_values=self.opts.global_track_bound_values,
            track_on=track_on,
            track_bound_values=track_bound_values,
        )

        return LinkedLambdaElement(other, parent_lambda=self, opts=opts)

    def _execute_on_connection(
        self, connection, multiparams, params, execution_options
    ):
        if self._rec.expected_expr.supports_execution:
            return connection._execute_clauseelement(
                self, multiparams, params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self)

    @property
    def _with_options(self):
        return self._rec.expected_expr._with_options

    @property
    def _effective_plugin_target(self):
        return self._rec.expected_expr._effective_plugin_target

    @property
    def _execution_options(self):
        return self._rec.expected_expr._execution_options

    def spoil(self):
        """Return a new :class:`.NullLambdaStatement` that will run
        all lambdas unconditionally each time.

        For a statement built up from several links, the whole chain is
        replayed: the root lambda is called with no arguments and each
        subsequent link is called with the statement produced before it.

        """
        if self.parent_lambda is not None:
            # a link's lambda takes the previous statement as its argument,
            # so it cannot be invoked bare; spoil the chain leading up to it
            # and apply this link on top of that
            return self.parent_lambda.spoil().add_criteria(self.fn)
        return NullLambdaStatement(self.fn())
class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
    """A stand-in for :class:`.StatementLambdaElement` that performs no
    lambda analysis or caching.

    Every lambda handed to it is called right away and the resulting
    statement is held directly.  This is useful for ruling the lambda
    caching system in or out when diagnosing a problem.

    """

    __visit_name__ = "lambda_element"

    _is_lambda_element = True

    _traverse_internals = [
        ("_resolved", visitors.InternalTraversal.dp_clauseelement)
    ]

    def __init__(self, statement):
        self._resolved = statement
        self._propagate_attrs = statement._propagate_attrs

    def __getattr__(self, key):
        # anything not defined here is answered by the wrapped statement
        return getattr(self._resolved, key)

    def __add__(self, other):
        return self.add_criteria(other)

    def add_criteria(self, other, **kw):
        # tracking options have no meaning without caching; they are
        # accepted for API compatibility and ignored
        new_statement = other(self._resolved)
        return NullLambdaStatement(new_statement)

    def _execute_on_connection(
        self, connection, multiparams, params, execution_options
    ):
        if not self._resolved.supports_execution:
            raise exc.ObjectNotExecutableError(self)
        return connection._execute_clauseelement(
            self, multiparams, params, execution_options
        )
class LinkedLambdaElement(StatementLambdaElement):
    """Represent subsequent links of a :class:`.StatementLambdaElement`."""

    role = None

    def __init__(self, fn, parent_lambda, opts):
        # everything _retrieve_tracker_rec() consults (fn, parent link,
        # options, tracker key) has to be in place before it is called
        self.opts = opts
        self.fn = fn
        self.parent_lambda = parent_lambda

        parent_key = parent_lambda.tracker_key
        self.tracker_key = parent_key + (fn.__code__,)

        self._retrieve_tracker_rec(fn, self, opts)

        # a link carries the same propagate attrs as the link before it
        self._propagate_attrs = parent_lambda._propagate_attrs

    def _invoke_user_fn(self, fn, *arg):
        # a link's lambda is handed the statement resolved by its parent
        previous_statement = self.parent_lambda._resolved
        return fn(previous_statement)
class AnalyzedCode(object):
    """Analysis of a lambda's ``__code__`` object, shared by every lambda
    built from the same code.

    Records which globals / closure variables hold literal values
    (``build_py_wrappers``), the callables that extract those values as
    bound parameters (``bindparam_trackers``), and the callables that derive
    a cache key from the remaining closure state (``closure_trackers``).

    """

    __slots__ = (
        "track_closure_variables",
        "track_bound_values",
        "bindparam_trackers",
        "closure_trackers",
        "build_py_wrappers",
    )

    # code object -> AnalyzedCode; weak keys, so an entry does not keep its
    # code object alive
    _fns = weakref.WeakKeyDictionary()

    # guards creation of AnalyzedCode objects here and of cached records in
    # LambdaElement._retrieve_tracker_rec()
    _generation_mutex = threading.RLock()

    @classmethod
    def get(cls, fn, lambda_element, lambda_kw, **kw):
        """Return the shared :class:`.AnalyzedCode` for ``fn.__code__``,
        creating it on first use.

        ``lambda_kw`` is the options object in effect; it is only consulted
        when the analysis is first built.

        """
        try:
            # TODO: validate kw haven't changed?
            return cls._fns[fn.__code__]
        except KeyError:
            pass

        with cls._generation_mutex:
            # check for other thread already created object
            if fn.__code__ in cls._fns:
                return cls._fns[fn.__code__]

            cls._fns[fn.__code__] = analyzed = AnalyzedCode(
                fn, lambda_element, lambda_kw, **kw
            )
            return analyzed

    def __init__(self, fn, lambda_element, opts):
        """Analyze ``fn`` according to ``opts``.

        :raises ArgumentError: if ``fn`` is a bound method.

        """
        if inspect.ismethod(fn):
            raise exc.ArgumentError(
                "Method %s may not be passed as a SQL expression" % fn
            )
        closure = fn.__closure__

        self.track_bound_values = (
            opts.track_bound_values and opts.global_track_bound_values
        )
        enable_tracking = opts.enable_tracking
        track_on = opts.track_on
        track_closure_variables = opts.track_closure_variables

        # an explicit track_on replaces closure-variable tracking
        self.track_closure_variables = track_closure_variables and not track_on

        # a list of callables generated from _bound_parameter_getter_*
        # functions.  Each of these uses a PyWrapper object to retrieve
        # a parameter value
        self.bindparam_trackers = []

        # a list of callables generated from _cache_key_getter_* functions
        # these callables work to generate a cache key for the lambda
        # based on what's inside its closure variables.
        self.closure_trackers = []

        # (name, closure_index or None for globals) of literal-looking values
        self.build_py_wrappers = []

        if enable_tracking:
            if track_on:
                self._init_track_on(track_on)

            self._init_globals(fn)

            if closure:
                self._init_closure(fn)

        self._setup_additional_closure_trackers(fn, lambda_element, opts)

    def _init_track_on(self, track_on):
        """Add one cache-key getter per element of ``track_on``."""
        self.closure_trackers.extend(
            self._cache_key_getter_track_on(idx, elem)
            for idx, elem in enumerate(track_on)
        )

    def _init_globals(self, fn):
        """Register referenced globals whose values look like literals."""
        build_py_wrappers = self.build_py_wrappers
        bindparam_trackers = self.bindparam_trackers
        track_bound_values = self.track_bound_values

        for name in fn.__code__.co_names:
            if name not in fn.__globals__:
                continue

            _bound_value = self._roll_down_to_literal(fn.__globals__[name])

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((name, None))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_globals(name)
                    )

    def _init_closure(self, fn):
        """Register closure variables: literal-looking ones become bound
        parameter candidates, the others contribute to the cache key."""
        build_py_wrappers = self.build_py_wrappers
        closure = fn.__closure__

        track_bound_values = self.track_bound_values
        track_closure_variables = self.track_closure_variables
        bindparam_trackers = self.bindparam_trackers
        closure_trackers = self.closure_trackers

        for closure_index, (fv, cell) in enumerate(
            zip(fn.__code__.co_freevars, closure)
        ):
            _bound_value = self._roll_down_to_literal(cell.cell_contents)

            if coercions._deep_is_literal(_bound_value):
                build_py_wrappers.append((fv, closure_index))
                if track_bound_values:
                    bindparam_trackers.append(
                        self._bound_parameter_getter_func_closure(
                            fv, closure_index
                        )
                    )
            else:
                # for normal cell contents, add them to a list that
                # we can compare later when we get new lambdas.  if
                # any identities have changed, then we will
                # recalculate the whole lambda and run it again.

                if track_closure_variables:
                    closure_trackers.append(
                        self._cache_key_getter_closure_variable(
                            fn, fv, closure_index, cell.cell_contents
                        )
                    )

    def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
        # an additional step is to actually run the function, then
        # go through the PyWrapper objects that were set up to catch a bound
        # parameter.   then if they *didn't* make a param, oh they're another
        # object in the closure we have to track for our cache key.  so
        # create trackers to catch those.

        analyzed_function = AnalyzedFunction(
            self,
            lambda_element,
            None,
            fn,
        )

        closure_trackers = self.closure_trackers

        for pywrapper in analyzed_function.closure_pywrappers:
            if not pywrapper._sa__has_param:
                closure_trackers.append(
                    self._cache_key_getter_tracked_literal(fn, pywrapper)
                )

    @classmethod
    def _roll_down_to_literal(cls, element):
        """Reduce ``element`` to the object whose literal-ness is tested.

        Objects offering ``__clause_element__`` are unwrapped until a
        ClauseElement, SchemaItem or class is reached; other objects are
        replaced by their inspection target's ``__clause_element__()`` (or
        the inspection object itself) when inspection succeeds.

        """
        is_clause_element = hasattr(element, "__clause_element__")

        if is_clause_element:
            while not isinstance(
                element, (elements.ClauseElement, schema.SchemaItem, type)
            ):
                try:
                    element = element.__clause_element__()
                except AttributeError:
                    break

        if not is_clause_element:
            insp = inspection.inspect(element, raiseerr=False)
            if insp is not None:
                try:
                    return insp.__clause_element__()
                except AttributeError:
                    return insp

            # TODO: should we coerce consts None/True/False here?
            return element
        else:
            return element

    def _bound_parameter_getter_func_globals(self, name):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__globals__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__globals__[name]
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__globals__[name], result
            )

        return extract_parameter_value

    def _bound_parameter_getter_func_closure(self, name, closure_index):
        """Return a getter that will extend a list of bound parameters
        with new entries from the ``__closure__`` collection of a particular
        lambda.

        """

        def extract_parameter_value(
            current_fn, tracker_instrumented_fn, result
        ):
            wrapper = tracker_instrumented_fn.__closure__[
                closure_index
            ].cell_contents
            object.__getattribute__(wrapper, "_extract_bound_parameters")(
                current_fn.__closure__[closure_index].cell_contents, result
            )

        return extract_parameter_value

    def _cache_key_getter_track_on(self, idx, elem):
        """Return a getter that will extend a cache key with new entries
        from the "track_on" parameter passed to a :class:`.LambdaElement`.

        """

        if isinstance(elem, tuple):
            # tuple must contain hascachekey elements
            def get(closure, opts, anon_map, bindparams):
                return tuple(
                    tup_elem._gen_cache_key(anon_map, bindparams)
                    for tup_elem in opts.track_on[idx]
                )

        elif isinstance(elem, traversals.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)

        else:

            def get(closure, opts, anon_map, bindparams):
                return opts.track_on[idx]

        return get

    def _cache_key_getter_closure_variable(
        self,
        fn,
        variable_name,
        idx,
        cell_contents,
        use_clause_element=False,
        use_inspect=False,
    ):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        :raises InvalidRequestError: if ``cell_contents`` cannot contribute
         to a cache key.

        """

        if isinstance(cell_contents, traversals.HasCacheKey):

            def get(closure, opts, anon_map, bindparams):

                obj = closure[idx].cell_contents
                if use_inspect:
                    obj = inspection.inspect(obj)
                elif use_clause_element:
                    while hasattr(obj, "__clause_element__"):
                        if not getattr(obj, "is_clause_element", False):
                            obj = obj.__clause_element__()
                        else:
                            # already a clause element; without this break
                            # the loop would never terminate (matches the
                            # roll-down loop further below)
                            break

                return obj._gen_cache_key(anon_map, bindparams)

        elif isinstance(cell_contents, types.FunctionType):

            def get(closure, opts, anon_map, bindparams):
                return closure[idx].cell_contents.__code__

        elif isinstance(cell_contents, collections_abc.Sequence):

            def get(closure, opts, anon_map, bindparams):
                contents = closure[idx].cell_contents

                try:
                    return tuple(
                        elem._gen_cache_key(anon_map, bindparams)
                        for elem in contents
                    )
                except AttributeError as ae:
                    self._raise_for_uncacheable_closure_variable(
                        variable_name, fn, from_=ae
                    )

        else:
            # if the object is a mapped class or aliased class, or some
            # other object in the ORM realm of things like that, imitate
            # the logic used in coercions.expect() to roll it down to the
            # SQL element
            element = cell_contents
            is_clause_element = False
            while hasattr(element, "__clause_element__"):
                is_clause_element = True
                if not getattr(element, "is_clause_element", False):
                    element = element.__clause_element__()
                else:
                    break

            if not is_clause_element:
                insp = inspection.inspect(element, raiseerr=False)
                if insp is not None:
                    return self._cache_key_getter_closure_variable(
                        fn, variable_name, idx, insp, use_inspect=True
                    )
            else:
                return self._cache_key_getter_closure_variable(
                    fn, variable_name, idx, element, use_clause_element=True
                )

            self._raise_for_uncacheable_closure_variable(variable_name, fn)

        return get

    def _raise_for_uncacheable_closure_variable(
        self, variable_name, fn, from_=None
    ):
        """Raise InvalidRequestError for a closure variable that can be
        neither a bound value nor part of a cache key."""
        util.raise_(
            exc.InvalidRequestError(
                "Closure variable named '%s' inside of lambda callable %s "
                "does not refer to a cacheable SQL element, and also does not "
                "appear to be serving as a SQL literal bound value based on "
                "the default "
                "SQL expression returned by the function. This variable "
                "needs to remain outside the scope of a SQL-generating lambda "
                "so that a proper cache key may be generated from the "
                "lambda's state. Evaluate this variable outside of the "
                "lambda, set track_on=[<elements>] to explicitly select "
                "closure elements to track, or set "
                "track_closure_variables=False to exclude "
                "closure variables from being part of the cache key."
                % (variable_name, fn.__code__),
            ),
            from_=from_,
        )

    def _cache_key_getter_tracked_literal(self, fn, pytracker):
        """Return a getter that will extend a cache key with new entries
        from the ``__closure__`` collection of a particular lambda.

        this getter differs from _cache_key_getter_closure_variable
        in that these are detected after the function is run, and PyWrapper
        objects have recorded that a particular literal value is in fact
        not being interpreted as a bound parameter.

        """

        elem = pytracker._sa__to_evaluate
        closure_index = pytracker._sa__closure_index
        variable_name = pytracker._sa__name

        return self._cache_key_getter_closure_variable(
            fn, variable_name, closure_index, elem
        )
class NonAnalyzedFunction(object):
    """Record used in place of :class:`.AnalyzedFunction` when a lambda
    cannot be cached (its closure cache key is ``NO_CACHE``).

    The lambda has simply been invoked once; its return value is held
    as-is, with no bound-parameter tracking.

    """

    __slots__ = ("expr",)

    # nothing is tracked for an un-analyzed lambda
    closure_bindparams = None
    bindparam_trackers = None

    # LambdaElement consults is_sequence / propagate_attrs on whatever
    # record it holds; defining them here avoids an AttributeError for
    # uncacheable lambdas.  The raw return value is not coerced, so it is
    # never treated as a sequence of elements.
    is_sequence = False

    def __init__(self, expr):
        self.expr = expr

    @property
    def expected_expr(self):
        return self.expr

    @property
    def propagate_attrs(self):
        """The ``_propagate_attrs`` of the produced expression, or None
        when it has none (the same source NullLambdaStatement uses)."""
        return getattr(self.expr, "_propagate_attrs", None)
class AnalyzedFunction(object):
    """The result of running one lambda with its literal-looking variables
    replaced by :class:`.PyWrapper` objects.

    Holds the instrumented function, the expression it produced, the
    coerced ``expected_expr`` and the bound-parameter trackers taken from
    the shared :class:`.AnalyzedCode`.

    """

    __slots__ = (
        "analyzed_code",
        "fn",
        "closure_pywrappers",
        "tracker_instrumented_fn",
        "expr",
        "bindparam_trackers",
        "expected_expr",
        "is_sequence",
        "propagate_attrs",
        "closure_bindparams",
    )

    def __init__(
        self,
        analyzed_code,
        lambda_element,
        apply_propagate_attrs,
        fn,
    ):
        # construction invokes the user's function (via
        # _instrument_and_run_function) and then coerces its result
        self.analyzed_code = analyzed_code
        self.fn = fn

        self.bindparam_trackers = analyzed_code.bindparam_trackers

        self._instrument_and_run_function(lambda_element)

        self._coerce_expression(lambda_element, apply_propagate_attrs)

    def _instrument_and_run_function(self, lambda_element):
        """Build ``tracker_instrumented_fn`` and invoke it into ``expr``.

        When the analysis found no literal-looking names, the original
        function is used unchanged.  Otherwise a copy of the function is
        made whose closure cells / globals hold PyWrapper objects for those
        names; closure wrappers are kept in ``closure_pywrappers`` when
        closure variables are tracked.

        """
        analyzed_code = self.analyzed_code

        fn = self.fn
        self.closure_pywrappers = closure_pywrappers = []

        build_py_wrappers = analyzed_code.build_py_wrappers

        if not build_py_wrappers:
            self.tracker_instrumented_fn = tracker_instrumented_fn = fn
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
        else:
            track_closure_variables = analyzed_code.track_closure_variables
            closure = fn.__closure__

            # will form the __closure__ of the function when we rebuild it
            if closure:
                new_closure = {
                    fv: cell.cell_contents
                    for fv, cell in zip(fn.__code__.co_freevars, closure)
                }
            else:
                new_closure = {}

            # will form the __globals__ of the function when we rebuild it
            new_globals = fn.__globals__.copy()

            for name, closure_index in build_py_wrappers:
                if closure_index is not None:
                    value = closure[closure_index].cell_contents
                    new_closure[name] = bind = PyWrapper(
                        fn,
                        name,
                        value,
                        closure_index=closure_index,
                        track_bound_values=(
                            self.analyzed_code.track_bound_values
                        ),
                    )
                    if track_closure_variables:
                        closure_pywrappers.append(bind)
                else:
                    # globals are wrapped without passing track_bound_values
                    # (PyWrapper's default applies)
                    value = fn.__globals__[name]
                    new_globals[name] = bind = PyWrapper(fn, name, value)

            # rewrite the original fn.   things that look like they will
            # become bound parameters are wrapped in a PyWrapper.
            self.tracker_instrumented_fn = (
                tracker_instrumented_fn
            ) = self._rewrite_code_obj(
                fn,
                [new_closure[name] for name in fn.__code__.co_freevars],
                new_globals,
            )

            # now invoke the function.  This will give us a new SQL
            # expression, but all the places that there would be a bound
            # parameter, the PyWrapper in its place will give us a bind
            # with a predictable name we can match up later.

            # additionally, each PyWrapper will log that it did in fact
            # create a parameter, otherwise, it's some kind of Python
            # object in the closure and we want to track that, to make
            # sure it doesn't change to something else, or if it does,
            # that we create a different tracked function with that
            # variable.
            self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)

    def _coerce_expression(self, lambda_element, apply_propagate_attrs):
        """Run the tracker-generated expression through coercion rules.

        After the user-defined lambda has been invoked to produce a statement
        for re-use, run it through coercion rules to both check that it's the
        correct type of object and also to coerce it to its useful form.

        Links of a lambda chain (``parent_lambda`` set) are not coerced.
        A sequence result is coerced element by element and flagged with
        ``is_sequence``.

        """

        parent_lambda = lambda_element.parent_lambda
        expr = self.expr

        if parent_lambda is None:
            if isinstance(expr, collections_abc.Sequence):
                self.expected_expr = [
                    coercions.expect(
                        lambda_element.role,
                        sub_expr,
                        apply_propagate_attrs=apply_propagate_attrs,
                    )
                    for sub_expr in expr
                ]
                self.is_sequence = True
            else:
                self.expected_expr = coercions.expect(
                    lambda_element.role,
                    expr,
                    apply_propagate_attrs=apply_propagate_attrs,
                )
                self.is_sequence = False
        else:
            self.expected_expr = expr
            self.is_sequence = False

        # remember whatever propagate attrs the target ended up with, so
        # later elements that reuse this cached record can copy them
        if apply_propagate_attrs is not None:
            self.propagate_attrs = apply_propagate_attrs._propagate_attrs
        else:
            self.propagate_attrs = util.EMPTY_DICT

    def _rewrite_code_obj(self, f, cell_values, globals_):
        """Return a copy of f, with a new closure and new globals

        yes it works in pypy :P

        The closure cells are obtained by exec-ing a small generated
        function whose inner function closes over ``cell_values`` in
        order, then reading that inner function's ``__closure__``.

        """

        argrange = range(len(cell_values))

        code = "def make_cells():\n"
        if cell_values:
            code += "    (%s) = (%s)\n" % (
                ", ".join("i%d" % i for i in argrange),
                ", ".join("o%d" % i for i in argrange),
            )
        code += "    def closure():\n"
        code += "        return %s\n" % ", ".join("i%d" % i for i in argrange)
        code += "    return closure.__closure__"
        vars_ = {"o%d" % i: cell_values[i] for i in argrange}
        compat.exec_(code, vars_, vars_)
        closure = vars_["make_cells"]()

        func = type(f)(
            f.__code__, globals_, f.__name__, f.__defaults__, closure
        )
        if sys.version_info >= (3,):
            func.__annotations__ = f.__annotations__
            func.__kwdefaults__ = f.__kwdefaults__
        func.__doc__ = f.__doc__
        func.__module__ = f.__module__

        return func
class PyWrapper(ColumnOperators):
    """A wrapper object that is injected into the ``__globals__`` and
    ``__closure__`` of a Python function.

    When the function is instrumented with :class:`.PyWrapper` objects, it is
    then invoked just once in order to set up the wrappers.  We look through
    all the :class:`.PyWrapper` objects we made to find the ones that generated
    a :class:`.BindParameter` object, e.g. the expression system interpreted
    something as a literal.  Those positions in the globals/closure are then
    ones that we will look at, each time a new lambda comes in that refers to
    the same ``__code__`` object.  In this way, we keep a single version of
    the SQL expression that this lambda produced, without calling upon the
    Python function that created it more than once, unless its other closure
    variables have changed.  The expression is then transformed to have the
    new bound values embedded into it.

    Attribute access is intercepted by :meth:`__getattribute__`:

    * names prefixed with ``_sa_`` reach the wrapper's own attributes, with
      the prefix stripped;
    * a fixed set of names reach the wrapper directly;
    * other dunder names are forwarded to the wrapped value;
    * all remaining names go through :meth:`_add_getter`.

    """

    def __init__(
        self,
        fn,
        name,
        to_evaluate,
        closure_index=None,
        getter=None,
        track_bound_values=True,
    ):
        """
        :param fn: the user function being instrumented.
        :param name: name used for this wrapper's bound parameter.
        :param to_evaluate: the real value being wrapped.
        :param closure_index: position in the closure this wrapper stands in
         for, if any.
        :param getter: callable that extracts this wrapper's value from its
         parent's value.  Set for child wrappers created by
         :meth:`_add_getter`.
        :param track_bound_values: when true, :meth:`__call__` refuses to
         return a literal value.
        """
        self.fn = fn
        self._name = name
        self._to_evaluate = to_evaluate
        self._param = None
        self._has_param = False
        # (key, getter_fn) -> child PyWrapper, filled by _add_getter()
        self._bind_paths = {}
        self._getter = getter
        self._closure_index = closure_index
        self.track_bound_values = track_bound_values

    def __call__(self, *arg, **kw):
        """Call the wrapped value with the given arguments.

        When ``track_bound_values`` is set, and the result is a literal
        (per ``coercions._deep_is_literal``) that is not a
        ``traversals.HasCacheKey``, :class:`.InvalidRequestError` is raised.
        Otherwise the result is returned.
        """
        elem = object.__getattribute__(self, "_to_evaluate")
        value = elem(*arg, **kw)
        if (
            self._sa_track_bound_values
            and coercions._deep_is_literal(value)
            and not isinstance(
                # TODO: coverage where an ORM option or similar is here
                value,
                traversals.HasCacheKey,
            )
        ):
            name = object.__getattribute__(self, "_name")
            raise exc.InvalidRequestError(
                "Can't invoke Python callable %s() inside of lambda "
                "expression argument at %s; lambda SQL constructs should "
                "not invoke functions from closure variables to produce "
                "literal values since the "
                "lambda SQL system normally extracts bound values without "
                "actually "
                "invoking the lambda or any functions within it. Call the "
                "function outside of the "
                "lambda and assign to a local variable that is used in the "
                "lambda as a closure variable, or set "
                "track_bound_values=False if the return value of this "
                "function is used in some other way other than a SQL bound "
                "value." % (name, self._sa_fn.__code__)
            )
        else:
            return value

    def operate(self, op, *other, **kwargs):
        """Apply ``op`` with this wrapper's bound parameter on the left."""
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(elem, *other, **kwargs)

    def reverse_operate(self, op, other, **kwargs):
        """Apply ``op`` with this wrapper's bound parameter on the right."""
        elem = object.__getattribute__(self, "_py_wrapper_literal")()
        return op(other, elem, **kwargs)

    def _extract_bound_parameters(self, starting_point, result_list):
        """Collect bound parameters for a new value of the wrapped object.

        If this wrapper has created a bind parameter, a copy carrying
        ``starting_point`` as its value is appended to ``result_list``.
        Each child wrapper then recurses, using its getter to pull its own
        value out of ``starting_point``.
        """
        param = object.__getattribute__(self, "_param")
        if param is not None:
            param = param._with_value(starting_point, maintain_key=True)
            result_list.append(param)
        for pywrapper in object.__getattribute__(self, "_bind_paths").values():
            getter = object.__getattribute__(pywrapper, "_getter")
            element = getter(starting_point)
            pywrapper._sa__extract_bound_parameters(element, result_list)

    def _py_wrapper_literal(self, expr=None, operator=None, **kw):
        """Return this wrapper's bound parameter, valued with the wrapped value.

        The :class:`.BindParameter` is created on first use.  It is named
        after the wrapper, is ``unique=True`` and ``required=False``, and
        records the comparison operator and, when ``expr`` is given,
        ``expr.type``.  It is then cached on the wrapper.
        """
        param = object.__getattribute__(self, "_param")
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        if param is None:
            name = object.__getattribute__(self, "_name")
            self._param = param = elements.BindParameter(
                name,
                required=False,
                unique=True,
                _compared_to_operator=operator,
                _compared_to_type=expr.type if expr is not None else None,
            )
            self._has_param = True
        return param._with_value(to_evaluate, maintain_key=True)

    def __bool__(self):
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    def __nonzero__(self):
        # Python 2 spelling of __bool__
        to_evaluate = object.__getattribute__(self, "_to_evaluate")
        return bool(to_evaluate)

    def __getattribute__(self, key):
        if key.startswith("_sa_"):
            # escape hatch to the wrapper's own attributes: "_sa_fn" -> "fn"
            return object.__getattribute__(self, key[4:])
        elif key in (
            "__clause_element__",
            "operate",
            "reverse_operate",
            "_py_wrapper_literal",
            "__class__",
            "__dict__",
        ):
            return object.__getattribute__(self, key)

        if key.startswith("__"):
            elem = object.__getattribute__(self, "_to_evaluate")
            return getattr(elem, key)
        else:
            return self._sa__add_getter(key, operator.attrgetter)

    def __iter__(self):
        elem = object.__getattribute__(self, "_to_evaluate")
        return iter(elem)

    def __getitem__(self, key):
        elem = object.__getattribute__(self, "_to_evaluate")
        if not hasattr(elem, "__getitem__"):
            raise AttributeError("__getitem__")

        if isinstance(key, PyWrapper):
            # TODO: coverage
            raise exc.InvalidRequestError(
                "Dictionary keys / list indexes inside of a cached "
                "lambda must be Python literals only"
            )
        return self._sa__add_getter(key, operator.itemgetter)

    def _add_getter(self, key, getter_fn):
        """Resolve ``key`` on the wrapped value, wrapping literals.

        ``getter_fn(key)`` (``operator.attrgetter`` or ``itemgetter``) is
        applied to the wrapped value.  If that value, rolled down via
        ``AnalyzedCode._roll_down_to_literal``, counts as a literal, a
        child :class:`.PyWrapper` is created, cached under
        ``(key, getter_fn)`` and returned.  Otherwise the raw value is
        returned and nothing is cached.
        """
        bind_paths = object.__getattribute__(self, "_bind_paths")

        bind_path_key = (key, getter_fn)
        if bind_path_key in bind_paths:
            return bind_paths[bind_path_key]

        getter = getter_fn(key)
        elem = object.__getattribute__(self, "_to_evaluate")
        value = getter(elem)

        rolled_down_value = AnalyzedCode._roll_down_to_literal(value)

        if coercions._deep_is_literal(rolled_down_value):
            # children inherit the parent's track_bound_values setting, so
            # opting out on a closure object also covers calls to its
            # methods, e.g. ``obj.compute()`` inside the lambda
            wrapper = PyWrapper(
                self._sa_fn,
                key,
                value,
                getter=getter,
                track_bound_values=self._sa_track_bound_values,
            )
            bind_paths[bind_path_key] = wrapper
            return wrapper
        else:
            return value
@inspection._inspects(LambdaElement)
def insp(lmb):
    """Inspection hook registered for :class:`.LambdaElement`.

    Delegates to :func:`.inspection.inspect`, applied to the element's
    ``_resolved`` attribute.
    """
    target = lmb._resolved
    return inspection.inspect(target)