1# orm/evaluator.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9"""Evaluation functions used **INTERNALLY** by ORM DML use cases.
10
11
12This module is **private, for internal use by SQLAlchemy**.
13
14.. versionchanged:: 2.0.4 renamed ``EvaluatorCompiler`` to
15 ``_EvaluatorCompiler``.
16
17"""
18
19
20from __future__ import annotations
21
22from typing import Type
23
24from . import exc as orm_exc
25from .base import LoaderCallableStatus
26from .base import PassiveFlag
27from .. import exc
28from .. import inspect
29from ..sql import and_
30from ..sql import operators
31from ..sql.sqltypes import Integer
32from ..sql.sqltypes import Numeric
33from ..util import warn_deprecated
34
35
class UnevaluatableError(exc.InvalidRequestError):
    """Raised when a SQL expression cannot be evaluated in Python.

    ``_EvaluatorCompiler`` raises this when it meets a construct, an
    operator, or a column that it has no Python-side evaluation for.
    """
38
39
class _NoObject(operators.ColumnOperators):
    """Stand-in for a column value read from a ``None`` object.

    Any operator that ``ColumnOperators`` routes through ``operate()`` or
    ``reverse_operate()`` yields ``None`` (SQL NULL), whichever side of
    the operator this sentinel appears on.
    """

    def operate(self, *args, **kwargs):
        # sentinel is the left-hand operand: result is NULL
        return None

    def reverse_operate(self, *args, **kwargs):
        # sentinel is the right-hand operand: result is NULL as well
        return None
46
47
class _ExpiredObject(operators.ColumnOperators):
    """Stand-in for an attribute whose value could not be obtained.

    Any operator that ``ColumnOperators`` routes through ``operate()`` or
    ``reverse_operate()`` hands back this same instance, so the "expired"
    marker survives arithmetic and comparisons unchanged.
    """

    def operate(self, *args, **kwargs):
        # propagate the marker when it is the left-hand operand
        return self

    def reverse_operate(self, *args, **kwargs):
        # ...and when it is the right-hand operand
        return self
54
55
# Shared sentinel singletons; evaluator callables compare against them by
# identity (``is``), so exactly one instance of each must exist.
_NO_OBJECT, _EXPIRED_OBJECT = _NoObject(), _ExpiredObject()
58
59
class _EvaluatorCompiler:
    """Compile SQL expression constructs into Python evaluator callables.

    :meth:`.process` dispatches on a construct's ``__visit_name__`` to a
    ``visit_<name>`` method.  Each visit method returns a callable that
    takes a single argument ``obj`` (a mapped instance, or ``None``) and
    returns the Python-side value of the construct for that object.

    Besides ordinary Python values, the returned callables use these
    markers:

    * ``None`` stands in for SQL ``NULL``.
    * ``_NO_OBJECT`` is produced by a column evaluator when ``obj`` is
      ``None``; it is treated as NULL by the operators below.
    * ``_EXPIRED_OBJECT`` is produced when an attribute's value could not
      be obtained without fetching (``impl.get()`` with
      ``PASSIVE_NO_FETCH`` returned ``PASSIVE_NO_RESULT``); operators
      propagate it instead of computing a result.

    Constructs, operators and columns with no Python-side equivalent raise
    :class:`.UnevaluatableError`.
    """

    def __init__(self, target_cls=None):
        # Optional mapped class the criteria will be evaluated against.
        # When given, column evaluation only accepts columns whose mapper's
        # class is ``target_cls`` itself or one of its base classes.
        self.target_cls = target_cls

    def process(self, clause, *clauses):
        """Return an evaluator callable for the given clause(s).

        Multiple clauses are first combined with ``and_()``.

        :raises UnevaluatableError: if there is no ``visit_<name>`` method
          for the construct's ``__visit_name__``.
        """
        if clauses:
            clause = and_(clause, *clauses)

        meth = getattr(self, f"visit_{clause.__visit_name__}", None)
        if not meth:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__}"
            )
        return meth(clause)

    def visit_grouping(self, clause):
        # parenthesization has no runtime meaning; evaluate the inner element
        return self.process(clause.element)

    def visit_null(self, clause):
        return lambda obj: None

    def visit_false(self, clause):
        return lambda obj: False

    def visit_true(self, clause):
        return lambda obj: True

    def visit_column(self, clause):
        """Return a callable that reads the mapped attribute for ``clause``.

        The callable returns ``_NO_OBJECT`` for a ``None`` object and
        ``_EXPIRED_OBJECT`` when the value is not available without
        fetching.

        :raises UnevaluatableError: if the column has no ``parentmapper``
          annotation, belongs to a class ``target_cls`` does not derive
          from, or is not mapped to a property on its mapper.
        """
        try:
            parentmapper = clause._annotations["parentmapper"]
        except KeyError as ke:
            raise UnevaluatableError(
                f"Cannot evaluate column: {clause}"
            ) from ke

        if self.target_cls and not issubclass(
            self.target_cls, parentmapper.class_
        ):
            raise UnevaluatableError(
                "Can't evaluate criteria against "
                f"alternate class {parentmapper.class_}"
            )

        parentmapper._check_configure()

        # we'd like to use "proxy_key" annotation to get the "key", however
        # in relationship primaryjoin cases proxy_key is sometimes deannotated
        # and sometimes apparently not present in the first place (?).
        # While I can stop it from being deannotated (though need to see if
        # this breaks other things), not sure right now about cases where it's
        # not there in the first place. can fix at some later point.
        # key = clause._annotations["proxy_key"]

        # for now, use the old way
        try:
            key = parentmapper._columntoproperty[clause].key
        except orm_exc.UnmappedColumnError as err:
            raise UnevaluatableError(
                f"Cannot evaluate expression: {err}"
            ) from err

        # note this used to fall back to a simple `getattr(obj, key)` evaluator
        # if impl was None; as of #8656, we ensure mappers are configured
        # so that impl is available
        impl = parentmapper.class_manager[key].impl

        def get_corresponding_attr(obj):
            if obj is None:
                return _NO_OBJECT
            state = inspect(obj)
            dict_ = state.dict

            # read without fetching; an unavailable value comes back as
            # PASSIVE_NO_RESULT and is reported as _EXPIRED_OBJECT
            value = impl.get(
                state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
            )
            if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
                return _EXPIRED_OBJECT
            return value

        return get_corresponding_attr

    def visit_tuple(self, clause):
        return self.visit_clauselist(clause)

    def visit_expression_clauselist(self, clause):
        return self.visit_clauselist(clause)

    def visit_clauselist(self, clause):
        """Evaluate each element, then dispatch on the list's operator.

        The operator's ``__name__`` (trailing underscores stripped) selects
        ``visit_<name>_clauselist_op``.
        """
        # distinct loop name so the ``clause`` parameter isn't shadowed
        evaluators = [self.process(sub_clause) for sub_clause in clause.clauses]

        dispatch = (
            f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op"
        )
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, evaluators, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate clauselist with operator {clause.operator}"
            )

    def visit_binary(self, clause):
        """Evaluate both operands, then dispatch on the binary operator.

        The operator's ``__name__`` (trailing underscores stripped) selects
        ``visit_<name>_binary_op``.
        """
        eval_left = self.process(clause.left)
        eval_right = self.process(clause.right)

        dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, eval_left, eval_right, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__} with "
                f"operator {clause.operator}"
            )

    def visit_or_clauselist_op(self, operator, evaluators, clause):
        """Evaluate ``OR`` with SQL three-valued logic.

        Operands are evaluated left to right.  The first true operand
        yields ``True``.  An expired operand met before that yields
        ``_EXPIRED_OBJECT``.  Otherwise the result is ``None`` if any
        operand was NULL, else ``False``.
        """

        def evaluate(obj):
            has_null = False
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    # NULL can't decide an OR; keep looking for a true one.
                    # Checked before truthiness since _NO_OBJECT is truthy.
                    has_null = True
                elif value:
                    return True
            if has_null:
                return None
            return False

        return evaluate

    def visit_and_clauselist_op(self, operator, evaluators, clause):
        """Evaluate ``AND`` with SQL three-valued logic.

        Operands are evaluated left to right.  The first false operand
        yields ``False``.  An expired operand met before that yields
        ``_EXPIRED_OBJECT``.  Otherwise the result is ``None`` if any
        operand was NULL, else ``True``.
        """

        def evaluate(obj):
            has_null = False
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    # NULL AND FALSE is FALSE, so a NULL operand can't end
                    # the scan; remember it and keep looking for a false one
                    has_null = True
                elif not value:
                    return False
            if has_null:
                return None
            return True

        return evaluate

    def visit_comma_op_clauselist_op(self, operator, evaluators, clause):
        """Evaluate a comma list (e.g. a tuple) into a Python ``tuple``.

        Returns ``_EXPIRED_OBJECT`` or ``None`` as soon as an element is
        expired or NULL, respectively.
        """

        def evaluate(obj):
            values = []
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    return None
                values.append(value)
            return tuple(values)

        return evaluate

    def visit_custom_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        # custom operators are evaluable only when they carry a Python
        # implementation; calling ``operator`` then invokes it
        if operator.python_impl:
            return self._straight_evaluate(
                operator, eval_left, eval_right, clause
            )
        else:
            raise UnevaluatableError(
                f"Custom operator {operator.opstring!r} can't be evaluated "
                "in Python unless it specifies a callable using "
                "`.python_impl`."
            )

    def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
        # ``IS``: compared with ``==`` so that NULL IS NULL is True
        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val == right_val

        return evaluate

    def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
        # ``IS NOT``: compared with ``!=``; NULL operands are not special-cased
        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val != right_val

        return evaluate

    def _straight_evaluate(self, operator, eval_left, eval_right, clause):
        """Apply ``operator`` to both operand values with NULL propagation.

        Returns ``_EXPIRED_OBJECT`` if either side is expired, ``None`` if
        either side is NULL, else ``operator(left, right)``.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            elif left_val is None or right_val is None:
                return None

            # reuse the values already read rather than evaluating each
            # side a second time (each call re-reads mapped attributes)
            return operator(left_val, right_val)

        return evaluate

    def _straight_evaluate_numeric_only(
        self, operator, eval_left, eval_right, clause
    ):
        """Like :meth:`_straight_evaluate`, restricted to numeric operands.

        :raises UnevaluatableError: unless both sides have a ``Numeric`` or
          ``Integer`` type affinity.
        """
        if clause.left.type._type_affinity not in (
            Numeric,
            Integer,
        ) or clause.right.type._type_affinity not in (Numeric, Integer):
            raise UnevaluatableError(
                f'Cannot evaluate math operator "{operator.__name__}" for '
                f"datatypes {clause.left.type}, {clause.right.type}"
            )

        return self._straight_evaluate(operator, eval_left, eval_right, clause)

    visit_add_binary_op = _straight_evaluate_numeric_only
    visit_mul_binary_op = _straight_evaluate_numeric_only
    visit_sub_binary_op = _straight_evaluate_numeric_only
    visit_mod_binary_op = _straight_evaluate_numeric_only
    visit_truediv_binary_op = _straight_evaluate_numeric_only
    visit_lt_binary_op = _straight_evaluate
    visit_le_binary_op = _straight_evaluate
    visit_ne_binary_op = _straight_evaluate
    visit_gt_binary_op = _straight_evaluate
    visit_ge_binary_op = _straight_evaluate
    visit_eq_binary_op = _straight_evaluate

    def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause):
        return self._straight_evaluate(
            lambda a, b: a in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_not_in_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        return self._straight_evaluate(
            lambda a, b: a not in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_concat_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        return self._straight_evaluate(
            lambda a, b: a + b, eval_left, eval_right, clause
        )

    def visit_startswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        return self._straight_evaluate(
            lambda a, b: a.startswith(b), eval_left, eval_right, clause
        )

    def visit_endswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        return self._straight_evaluate(
            lambda a, b: a.endswith(b), eval_left, eval_right, clause
        )

    def visit_unary(self, clause):
        """Evaluate a unary construct; only ``NOT`` (``operators.inv``).

        :raises UnevaluatableError: for any other unary operator.
        """
        eval_inner = self.process(clause.element)
        if clause.operator is operators.inv:

            def evaluate(obj):
                value = eval_inner(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    # NOT NULL is NULL; _NO_OBJECT must not reach
                    # ``not value`` since the sentinel is truthy
                    return None
                return not value

            return evaluate
        raise UnevaluatableError(
            f"Cannot evaluate {type(clause).__name__} "
            f"with operator {clause.operator}"
        )

    def visit_bindparam(self, clause):
        # the bound value is resolved once, when the evaluator is built,
        # not on every call
        if clause.callable:
            val = clause.callable()
        else:
            val = clause.value
        return lambda obj: val
356
357
def __getattr__(name: str) -> Type[_EvaluatorCompiler]:
    """Module-level attribute fallback (PEP 562).

    Resolves the legacy public name ``EvaluatorCompiler`` to
    ``_EvaluatorCompiler`` and emits a deprecation warning.  Any other
    unknown name raises :class:`AttributeError`.
    """
    if name != "EvaluatorCompiler":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warn_deprecated(
        "Direct use of 'EvaluatorCompiler' is not supported, and this "
        "name will be removed in a future release. "
        "'_EvaluatorCompiler' is for internal use only",
        "2.0",
    )
    return _EvaluatorCompiler