1# orm/evaluator.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9"""Evaluation functions used **INTERNALLY** by ORM DML use cases.
10
11
12This module is **private, for internal use by SQLAlchemy**.
13
14.. versionchanged:: 2.0.4 renamed ``EvaluatorCompiler`` to
15 ``_EvaluatorCompiler``.
16
17"""
18
19from __future__ import annotations
20
21from typing import Type
22
23from . import exc as orm_exc
24from .base import LoaderCallableStatus
25from .base import PassiveFlag
26from .. import exc
27from .. import inspect
28from ..sql import and_
29from ..sql import operators
30from ..sql.sqltypes import Concatenable
31from ..sql.sqltypes import Integer
32from ..sql.sqltypes import Numeric
33from ..util import warn_deprecated
34
35
class UnevaluatableError(exc.InvalidRequestError):
    """Raised by :class:`._EvaluatorCompiler` when a SQL expression (or
    one of its sub-expressions) can't be evaluated in Python."""
38
39
class _NoObject(operators.ColumnOperators):
    """Placeholder operand returned when a column is evaluated against
    ``obj=None``.

    Both ``operate()`` and ``reverse_operate()`` return ``None``, so an
    operator expression routed through them yields ``None`` (SQL NULL).
    """

    def operate(self, *args, **kwargs):
        return None

    def reverse_operate(self, *args, **kwargs):
        return None
46
47
class _ExpiredObject(operators.ColumnOperators):
    """Placeholder for an attribute value that is not currently loaded on
    the instance.

    Both ``operate()`` and ``reverse_operate()`` return the placeholder
    itself, so it propagates through operator expressions it takes part in.
    """

    def operate(self, *args, **kwargs):
        return self

    def reverse_operate(self, *args, **kwargs):
        return self
54
55
# Module-level singletons; the evaluators below recognize them by identity
# (``is``), never by equality.
_NO_OBJECT, _EXPIRED_OBJECT = _NoObject(), _ExpiredObject()
58
59
class _EvaluatorCompiler:
    """Compile a SQL expression into a Python callable ``evaluate(obj)``.

    :meth:`process` walks a SQL expression construct and returns a callable
    that computes the expression against the in-Python state of a mapped
    instance ``obj``.  Dispatch is by name: ``visit_<__visit_name__>`` for
    each construct, then ``visit_<opname>_binary_op`` /
    ``visit_<opname>_clauselist_op`` for the operator of a binary expression
    or clause list.  Anything without a handler raises
    :class:`.UnevaluatableError`.

    The generated callables use these special results:

    * ``None`` stands for SQL NULL;
    * ``_EXPIRED_OBJECT`` means an attribute needed by the expression could
      not be read without a fetch; it is propagated outward unchanged;
    * ``_NO_OBJECT`` is what a column evaluator returns for ``obj=None``;
      boolean/tuple combinators treat it like NULL.
    """

    def __init__(self, target_cls=None):
        # When set, every column must belong to a mapper whose class is
        # ``target_cls`` or a base class of it; see visit_column().
        self.target_cls = target_cls

    def process(self, clause, *clauses):
        """Return an ``evaluate(obj)`` callable for the given expression.

        :param clause: SQL expression construct to compile.
        :param clauses: optional additional expressions; when present they
         are combined with ``clause`` via :func:`.and_`.
        :raises UnevaluatableError: when the construct, or any construct
         nested in it, has no Python evaluation.
        """
        if clauses:
            clause = and_(clause, *clauses)

        meth = getattr(self, f"visit_{clause.__visit_name__}", None)
        if not meth:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__}"
            )
        return meth(clause)

    def visit_grouping(self, clause):
        # parenthesization has no effect on the Python value
        return self.process(clause.element)

    def visit_null(self, clause):
        return lambda obj: None

    def visit_false(self, clause):
        return lambda obj: False

    def visit_true(self, clause):
        return lambda obj: True

    def visit_column(self, clause):
        """Return an evaluator that reads the mapped attribute for a column.

        The column must carry a ``"parentmapper"`` annotation, must belong
        to ``target_cls`` (or one of its bases) when ``target_cls`` is set,
        and must map to a property of that mapper.  The evaluator returns
        ``_NO_OBJECT`` for ``obj=None`` and ``_EXPIRED_OBJECT`` when the
        attribute can't be read under ``PASSIVE_NO_FETCH``.

        :raises UnevaluatableError: if any of the above requirements fail.
        """
        try:
            parentmapper = clause._annotations["parentmapper"]
        except KeyError as ke:
            raise UnevaluatableError(
                f"Cannot evaluate column: {clause}"
            ) from ke

        if self.target_cls and not issubclass(
            self.target_cls, parentmapper.class_
        ):
            raise UnevaluatableError(
                "Can't evaluate criteria against "
                f"alternate class {parentmapper.class_}"
            )

        parentmapper._check_configure()

        # NOTE: ideally the attribute key would come from the "proxy_key"
        # annotation (clause._annotations["proxy_key"]).  However, in
        # relationship primaryjoin cases that annotation is sometimes
        # deannotated and sometimes apparently never present.  Until that
        # is sorted out, resolve the key through the mapper's
        # column -> property map.
        try:
            key = parentmapper._columntoproperty[clause].key
        except orm_exc.UnmappedColumnError as err:
            raise UnevaluatableError(
                f"Cannot evaluate expression: {err}"
            ) from err

        # note this used to fall back to a simple `getattr(obj, key)` evaluator
        # if impl was None; as of #8656, we ensure mappers are configured
        # so that impl is available
        impl = parentmapper.class_manager[key].impl

        def get_corresponding_attr(obj):
            if obj is None:
                return _NO_OBJECT
            state = inspect(obj)
            dict_ = state.dict

            # never emit SQL here; an unavailable value is reported as
            # _EXPIRED_OBJECT instead
            value = impl.get(
                state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
            )
            if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
                return _EXPIRED_OBJECT
            return value

        return get_corresponding_attr

    def visit_tuple(self, clause):
        return self.visit_clauselist(clause)

    def visit_expression_clauselist(self, clause):
        return self.visit_clauselist(clause)

    def visit_clauselist(self, clause):
        """Compile each element, then dispatch on the list's operator to
        ``visit_<opname>_clauselist_op``."""
        evaluators = [self.process(sub) for sub in clause.clauses]

        dispatch = (
            f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op"
        )
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, evaluators, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate clauselist with operator {clause.operator}"
            )

    def visit_binary(self, clause):
        """Compile both operands, then dispatch on the operator to
        ``visit_<opname>_binary_op``."""
        eval_left = self.process(clause.left)
        eval_right = self.process(clause.right)

        dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, eval_left, eval_right, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__} with "
                f"operator {clause.operator}"
            )

    def visit_or_clauselist_op(self, operator, evaluators, clause):
        """OR: True on the first truthy element; otherwise None if any
        element was NULL (or ``_NO_OBJECT``), else False."""

        def evaluate(obj):
            has_null = False
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    # check the sentinel by identity before truthiness;
                    # don't rely on how _NO_OBJECT behaves in bool()
                    has_null = True
                elif value:
                    return True
            if has_null:
                return None
            return False

        return evaluate

    def visit_and_clauselist_op(self, operator, evaluators, clause):
        """AND: None as soon as an element is NULL (or ``_NO_OBJECT``),
        False on the first other falsy element, else True."""

        def evaluate(obj):
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    # identity check must come before the truthiness test;
                    # nested under `not value` it would be skipped for a
                    # truthy _NO_OBJECT.
                    # NOTE(review): unlike SQL three-valued AND, a NULL
                    # followed by a FALSE yields None rather than False;
                    # both are falsy.
                    return None
                elif not value:
                    return False
            return True

        return evaluate

    def visit_comma_op_clauselist_op(self, operator, evaluators, clause):
        """Tuple of element values; None if any element is NULL or
        ``_NO_OBJECT``."""

        def evaluate(obj):
            values = []
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    return None
                values.append(value)
            return tuple(values)

        return evaluate

    def visit_custom_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Custom operators are evaluable only when they supply
        ``python_impl``."""
        if operator.python_impl:
            return self._straight_evaluate(
                operator, eval_left, eval_right, clause
            )
        else:
            raise UnevaluatableError(
                f"Custom operator {operator.opstring!r} can't be evaluated "
                "in Python unless it specifies a callable using "
                "`.python_impl`."
            )

    def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
        # IS compares with ==, so None IS None evaluates to True
        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val == right_val

        return evaluate

    def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
        # IS NOT compares with !=, so None IS NOT None evaluates to False
        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val != right_val

        return evaluate

    def _straight_evaluate(self, operator, eval_left, eval_right, clause):
        """Return an evaluator applying ``operator(left, right)``.

        ``_EXPIRED_OBJECT`` on either side is propagated; ``None`` on either
        side yields ``None`` (NULL) without calling ``operator``.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            elif left_val is None or right_val is None:
                return None

            # use the values already computed; calling eval_left /
            # eval_right again would read each operand a second time
            return operator(left_val, right_val)

        return evaluate

    def _straight_evaluate_numeric_only(
        self, operator, eval_left, eval_right, clause
    ):
        """Like :meth:`_straight_evaluate`, but only for operands whose
        type affinity is Numeric or Integer.

        :raises UnevaluatableError: for any other operand types.
        """
        if clause.left.type._type_affinity not in (
            Numeric,
            Integer,
        ) or clause.right.type._type_affinity not in (Numeric, Integer):
            raise UnevaluatableError(
                f'Cannot evaluate math operator "{operator.__name__}" for '
                f"datatypes {clause.left.type}, {clause.right.type}"
            )

        return self._straight_evaluate(operator, eval_left, eval_right, clause)

    visit_add_binary_op = _straight_evaluate_numeric_only
    visit_mul_binary_op = _straight_evaluate_numeric_only
    visit_sub_binary_op = _straight_evaluate_numeric_only
    visit_mod_binary_op = _straight_evaluate_numeric_only
    visit_truediv_binary_op = _straight_evaluate_numeric_only
    visit_lt_binary_op = _straight_evaluate
    visit_le_binary_op = _straight_evaluate
    visit_ne_binary_op = _straight_evaluate
    visit_gt_binary_op = _straight_evaluate
    visit_ge_binary_op = _straight_evaluate
    visit_eq_binary_op = _straight_evaluate

    def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause):
        return self._straight_evaluate(
            lambda a, b: a in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_not_in_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        return self._straight_evaluate(
            lambda a, b: a not in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_concat_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """String concatenation via ``+``; both operand types must have
        a Concatenable affinity."""
        if not issubclass(
            clause.left.type._type_affinity, Concatenable
        ) or not issubclass(clause.right.type._type_affinity, Concatenable):
            raise UnevaluatableError(
                f"Cannot evaluate concatenate operator "
                f'"{operator.__name__}" for '
                f"datatypes {clause.left.type}, {clause.right.type}"
            )

        return self._straight_evaluate(
            lambda a, b: a + b, eval_left, eval_right, clause
        )

    def visit_startswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        return self._straight_evaluate(
            lambda a, b: a.startswith(b), eval_left, eval_right, clause
        )

    def visit_endswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        return self._straight_evaluate(
            lambda a, b: a.endswith(b), eval_left, eval_right, clause
        )

    def visit_unary(self, clause):
        """Only NOT (``operators.inv``) is supported among unary operators.

        :raises UnevaluatableError: for any other unary operator.
        """
        eval_inner = self.process(clause.element)
        if clause.operator is operators.inv:

            def evaluate(obj):
                value = eval_inner(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    # NOT NULL is NULL; _NO_OBJECT is treated like NULL,
                    # consistent with the AND / OR / tuple evaluators
                    return None
                return not value

            return evaluate
        raise UnevaluatableError(
            f"Cannot evaluate {type(clause).__name__} "
            f"with operator {clause.operator}"
        )

    def visit_bindparam(self, clause):
        # the bound value (or callable result) is resolved once, when the
        # evaluator is built, not on each evaluate(obj) call
        if clause.callable:
            val = clause.callable()
        else:
            val = clause.value
        return lambda obj: val
366
367
def __getattr__(name: str) -> Type[_EvaluatorCompiler]:
    """Module attribute hook (PEP 562) serving the deprecated alias.

    ``EvaluatorCompiler`` emits a deprecation warning and resolves to
    :class:`._EvaluatorCompiler`; any other unknown attribute raises
    :class:`AttributeError`.
    """
    if name != "EvaluatorCompiler":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warn_deprecated(
        "Direct use of 'EvaluatorCompiler' is not supported, and this "
        "name will be removed in a future release. "
        "'_EvaluatorCompiler' is for internal use only",
        "2.0",
    )
    return _EvaluatorCompiler