1# orm/evaluator.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9"""Evaluation functions used **INTERNALLY** by ORM DML use cases.
10
11
12This module is **private, for internal use by SQLAlchemy**.
13
14.. versionchanged:: 2.0.4 renamed ``EvaluatorCompiler`` to
15 ``_EvaluatorCompiler``.
16
17"""
18
19
20from __future__ import annotations
21
22from typing import Type
23
24from . import exc as orm_exc
25from .base import LoaderCallableStatus
26from .base import PassiveFlag
27from .. import exc
28from .. import inspect
29from ..sql import and_
30from ..sql import operators
31from ..sql.sqltypes import Concatenable
32from ..sql.sqltypes import Integer
33from ..sql.sqltypes import Numeric
34from ..util import warn_deprecated
35
36
class UnevaluatableError(exc.InvalidRequestError):
    """Raised when a SQL construct cannot be evaluated in Python.

    The evaluator compiler in this module raises it for clause types,
    operators, columns and datatypes it has no Python evaluation for.
    """
39
40
class _NoObject(operators.ColumnOperators):
    """Operand placeholder for which every operation evaluates to ``None``.

    Both operator hooks ignore their arguments entirely, so an
    expression with this object on either side produces ``None``.
    """

    def operate(self, *arg, **kw):
        """Return ``None`` regardless of the operator or operands."""
        return None

    def reverse_operate(self, *arg, **kw):
        """Return ``None`` regardless of the operator or operands."""
        return None
47
48
class _ExpiredObject(operators.ColumnOperators):
    """Operand placeholder that survives every operation applied to it.

    Both operator hooks ignore their arguments and hand back the same
    instance, so the placeholder propagates through an expression.
    """

    def operate(self, *arg, **kw):
        """Return this same object regardless of operator or operands."""
        return self

    def reverse_operate(self, *arg, **kw):
        """Return this same object regardless of operator or operands."""
        return self
55
56
# Sentinel returned by the column evaluator when the object being
# evaluated is None; operators applied to it evaluate to None.
_NO_OBJECT = _NoObject()
# Sentinel returned by the column evaluator when the attribute lookup
# yields LoaderCallableStatus.PASSIVE_NO_RESULT; operators applied to it
# return the sentinel itself, and the compiled evaluators pass it through.
_EXPIRED_OBJECT = _ExpiredObject()
59
60
class _EvaluatorCompiler:
    """Compile SQL expression constructs into Python callables.

    Every ``visit_*`` method takes a clause element and returns an
    "evaluator": a function of one argument (the object to test) that
    computes the value of that clause in Python.

    Evaluators may return ordinary Python values, ``None`` when an
    operand is ``None``, or the ``_EXPIRED_OBJECT`` sentinel when a
    column evaluator could not obtain an attribute value; the sentinel
    is passed upwards unchanged by the enclosing evaluators.

    :param target_cls: optional class.  When given, a column is only
     evaluable if ``target_cls`` is a subclass of the class the column
     is mapped to.
    """

    def __init__(self, target_cls=None):
        self.target_cls = target_cls

    def process(self, clause, *clauses):
        """Return an evaluator for ``clause``.

        Additional positional clauses are combined with the first one
        using ``and_()`` before dispatching.

        :raises UnevaluatableError: if there is no ``visit_<name>``
         method for the clause's ``__visit_name__``.
        """
        if clauses:
            clause = and_(clause, *clauses)

        meth = getattr(self, f"visit_{clause.__visit_name__}", None)
        if not meth:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__}"
            )
        return meth(clause)

    def visit_grouping(self, clause):
        """Evaluate a grouping as its inner element."""
        return self.process(clause.element)

    def visit_null(self, clause):
        """Return an evaluator that always yields ``None``."""
        return lambda obj: None

    def visit_false(self, clause):
        """Return an evaluator that always yields ``False``."""
        return lambda obj: False

    def visit_true(self, clause):
        """Return an evaluator that always yields ``True``."""
        return lambda obj: True

    def visit_column(self, clause):
        """Return an evaluator that reads a mapped column from an object.

        The evaluator returns ``_NO_OBJECT`` when given ``None``,
        ``_EXPIRED_OBJECT`` when the attribute implementation reports
        ``PASSIVE_NO_RESULT``, and the attribute value otherwise.  The
        value is fetched with ``PassiveFlag.PASSIVE_NO_FETCH``.

        :raises UnevaluatableError: if the column has no
         ``parentmapper`` annotation, belongs to a class that
         ``target_cls`` does not derive from, or is not mapped to a
         property on its parent mapper.
        """
        try:
            parentmapper = clause._annotations["parentmapper"]
        except KeyError as ke:
            raise UnevaluatableError(
                f"Cannot evaluate column: {clause}"
            ) from ke

        if self.target_cls and not issubclass(
            self.target_cls, parentmapper.class_
        ):
            raise UnevaluatableError(
                "Can't evaluate criteria against "
                f"alternate class {parentmapper.class_}"
            )

        parentmapper._check_configure()

        # we'd like to use "proxy_key" annotation to get the "key", however
        # in relationship primaryjoin cases proxy_key is sometimes deannotated
        # and sometimes apparently not present in the first place (?).
        # While I can stop it from being deannotated (though need to see if
        # this breaks other things), not sure right now about cases where it's
        # not there in the first place.  can fix at some later point.
        # key = clause._annotations["proxy_key"]

        # for now, use the old way
        try:
            key = parentmapper._columntoproperty[clause].key
        except orm_exc.UnmappedColumnError as err:
            raise UnevaluatableError(
                f"Cannot evaluate expression: {err}"
            ) from err

        # note this used to fall back to a simple `getattr(obj, key)` evaluator
        # if impl was None; as of #8656, we ensure mappers are configured
        # so that impl is available
        impl = parentmapper.class_manager[key].impl

        def get_corresponding_attr(obj):
            if obj is None:
                return _NO_OBJECT
            state = inspect(obj)
            dict_ = state.dict

            value = impl.get(
                state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
            )
            if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
                return _EXPIRED_OBJECT
            return value

        return get_corresponding_attr

    def visit_tuple(self, clause):
        """Evaluate a tuple the same way as a clause list."""
        return self.visit_clauselist(clause)

    def visit_expression_clauselist(self, clause):
        """Evaluate an expression clause list as a clause list."""
        return self.visit_clauselist(clause)

    def visit_clauselist(self, clause):
        """Return an evaluator for a clause list.

        Each member clause is compiled first; the combined evaluator is
        then built by the ``visit_<operator>_clauselist_op`` method named
        after the list's operator (trailing underscores stripped).

        :raises UnevaluatableError: if no such method exists.
        """
        evaluators = [
            self.process(sub_clause) for sub_clause in clause.clauses
        ]

        dispatch = (
            f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op"
        )
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, evaluators, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate clauselist with operator {clause.operator}"
            )

    def visit_binary(self, clause):
        """Return an evaluator for a binary expression.

        Both sides are compiled first; the combined evaluator is then
        built by the ``visit_<operator>_binary_op`` method named after
        the expression's operator (trailing underscores stripped).

        :raises UnevaluatableError: if no such method exists.
        """
        eval_left = self.process(clause.left)
        eval_right = self.process(clause.right)

        dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, eval_left, eval_right, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__} with "
                f"operator {clause.operator}"
            )

    def visit_or_clauselist_op(self, operator, evaluators, clause):
        """Return an evaluator implementing OR over ``evaluators``.

        Members are evaluated in order.  An expired member yields
        ``_EXPIRED_OBJECT`` immediately and a truthy member yields
        ``True`` immediately.  Otherwise the result is ``None`` if any
        member was ``None``, else ``False``.
        """

        def evaluate(obj):
            has_null = False
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value:
                    return True
                has_null = has_null or value is None
            if has_null:
                return None
            return False

        return evaluate

    def visit_and_clauselist_op(self, operator, evaluators, clause):
        """Return an evaluator implementing AND over ``evaluators``.

        Members are evaluated in order.  An expired member yields
        ``_EXPIRED_OBJECT`` immediately; the first falsy member yields
        ``None`` if it is ``None`` (or ``_NO_OBJECT``) and ``False``
        otherwise.  If every member is truthy the result is ``True``.
        """

        def evaluate(obj):
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT

                if not value:
                    if value is None or value is _NO_OBJECT:
                        return None
                    return False
            return True

        return evaluate

    def visit_comma_op_clauselist_op(self, operator, evaluators, clause):
        """Return an evaluator producing a tuple of the member values.

        An expired member yields ``_EXPIRED_OBJECT``; a member that is
        ``None`` or ``_NO_OBJECT`` makes the whole result ``None``.
        """

        def evaluate(obj):
            values = []
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    return None
                values.append(value)
            return tuple(values)

        return evaluate

    def visit_custom_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Return an evaluator for a custom operator.

        :raises UnevaluatableError: if the operator has no
         ``python_impl``.
        """
        if operator.python_impl:
            return self._straight_evaluate(
                operator, eval_left, eval_right, clause
            )
        else:
            raise UnevaluatableError(
                f"Custom operator {operator.opstring!r} can't be evaluated "
                "in Python unless it specifies a callable using "
                "`.python_impl`."
            )

    def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
        """Return an evaluator for IS, compared with Python ``==``.

        Unlike ``_straight_evaluate`` there is no ``None`` short
        circuit, so ``None`` operands are compared directly.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val == right_val

        return evaluate

    def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
        """Return an evaluator for IS NOT, compared with Python ``!=``.

        Unlike ``_straight_evaluate`` there is no ``None`` short
        circuit, so ``None`` operands are compared directly.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val != right_val

        return evaluate

    def _straight_evaluate(self, operator, eval_left, eval_right, clause):
        """Return an evaluator applying ``operator`` to both operands.

        Each operand evaluator is invoked exactly once per call.  If
        either value is ``_EXPIRED_OBJECT`` that sentinel is returned;
        otherwise if either value is ``None`` the result is ``None``;
        otherwise the result is ``operator(left, right)``.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            elif left_val is None or right_val is None:
                return None

            # reuse the values checked above rather than calling the
            # operand evaluators a second time
            return operator(left_val, right_val)

        return evaluate

    def _straight_evaluate_numeric_only(
        self, operator, eval_left, eval_right, clause
    ):
        """Like ``_straight_evaluate``, restricted to numeric operands.

        :raises UnevaluatableError: unless the type affinity of both
         sides of ``clause`` is ``Numeric`` or ``Integer``.
        """
        if clause.left.type._type_affinity not in (
            Numeric,
            Integer,
        ) or clause.right.type._type_affinity not in (Numeric, Integer):
            raise UnevaluatableError(
                f'Cannot evaluate math operator "{operator.__name__}" for '
                f"datatypes {clause.left.type}, {clause.right.type}"
            )

        return self._straight_evaluate(operator, eval_left, eval_right, clause)

    # arithmetic operators are only evaluated for numeric datatypes
    visit_add_binary_op = _straight_evaluate_numeric_only
    visit_mul_binary_op = _straight_evaluate_numeric_only
    visit_sub_binary_op = _straight_evaluate_numeric_only
    visit_mod_binary_op = _straight_evaluate_numeric_only
    visit_truediv_binary_op = _straight_evaluate_numeric_only
    # comparison operators apply the operator directly to both values
    visit_lt_binary_op = _straight_evaluate
    visit_le_binary_op = _straight_evaluate
    visit_ne_binary_op = _straight_evaluate
    visit_gt_binary_op = _straight_evaluate
    visit_ge_binary_op = _straight_evaluate
    visit_eq_binary_op = _straight_evaluate

    def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause):
        """Return an evaluator for IN using Python ``in``.

        A left value of ``_NO_OBJECT`` evaluates to ``None``.
        """
        return self._straight_evaluate(
            lambda a, b: a in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_not_in_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Return an evaluator for NOT IN using Python ``not in``.

        A left value of ``_NO_OBJECT`` evaluates to ``None``.
        """
        return self._straight_evaluate(
            lambda a, b: a not in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_concat_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Return an evaluator for concatenation using Python ``+``.

        :raises UnevaluatableError: unless the type affinity of both
         sides of ``clause`` is a ``Concatenable`` subclass.
        """

        if not issubclass(
            clause.left.type._type_affinity, Concatenable
        ) or not issubclass(clause.right.type._type_affinity, Concatenable):
            raise UnevaluatableError(
                f"Cannot evaluate concatenate operator "
                f'"{operator.__name__}" for '
                f"datatypes {clause.left.type}, {clause.right.type}"
            )

        return self._straight_evaluate(
            lambda a, b: a + b, eval_left, eval_right, clause
        )

    def visit_startswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Return an evaluator calling ``left.startswith(right)``."""
        return self._straight_evaluate(
            lambda a, b: a.startswith(b), eval_left, eval_right, clause
        )

    def visit_endswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Return an evaluator calling ``left.endswith(right)``."""
        return self._straight_evaluate(
            lambda a, b: a.endswith(b), eval_left, eval_right, clause
        )

    def visit_unary(self, clause):
        """Return an evaluator for a unary expression.

        Only ``operators.inv`` is supported: the result is the Python
        ``not`` of the inner value, with ``None`` and ``_EXPIRED_OBJECT``
        passed through unchanged.

        :raises UnevaluatableError: for any other operator.
        """
        eval_inner = self.process(clause.element)
        if clause.operator is operators.inv:

            def evaluate(obj):
                value = eval_inner(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None:
                    return None
                return not value

            return evaluate
        raise UnevaluatableError(
            f"Cannot evaluate {type(clause).__name__} "
            f"with operator {clause.operator}"
        )

    def visit_bindparam(self, clause):
        """Return an evaluator yielding the bound parameter's value.

        The value is resolved once, here: ``clause.callable()`` if a
        callable is present, otherwise ``clause.value``.
        """
        if clause.callable:
            val = clause.callable()
        else:
            val = clause.value
        return lambda obj: val
367
368
def __getattr__(name: str) -> Type[_EvaluatorCompiler]:
    """Resolve the deprecated ``EvaluatorCompiler`` module attribute.

    Emits a deprecation warning and returns ``_EvaluatorCompiler`` when
    ``name`` is ``"EvaluatorCompiler"``.

    :raises AttributeError: for any other name.
    """
    if name != "EvaluatorCompiler":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warn_deprecated(
        "Direct use of 'EvaluatorCompiler' is not supported, and this "
        "name will be removed in a future release. "
        "'_EvaluatorCompiler' is for internal use only",
        "2.0",
    )
    return _EvaluatorCompiler