1# sql/default_comparator.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""Default implementation of SQL comparison operations.
9"""
10
11from __future__ import annotations
12
13import typing
14from typing import Any
15from typing import Callable
16from typing import Dict
17from typing import NoReturn
18from typing import Optional
19from typing import Tuple
20from typing import Type
21from typing import Union
22
23from . import coercions
24from . import operators
25from . import roles
26from . import type_api
27from .elements import and_
28from .elements import BinaryExpression
29from .elements import ClauseElement
30from .elements import CollationClause
31from .elements import CollectionAggregate
32from .elements import ExpressionClauseList
33from .elements import False_
34from .elements import Null
35from .elements import OperatorExpression
36from .elements import or_
37from .elements import True_
38from .elements import UnaryExpression
39from .operators import OperatorType
40from .. import exc
41from .. import util
42
# Type variable for the Python-side type parameter of a TypeEngine; used by
# _binary_operate so the returned OperatorExpression is parameterized by
# the same type as its ``result_type``.
_T = typing.TypeVar("_T", bound=Any)

if typing.TYPE_CHECKING:
    # Imported only for annotations; with ``from __future__ import
    # annotations`` these names are never needed at runtime.
    from .elements import ColumnElement
    from .operators import custom_op
    from .type_api import TypeEngine
49
50
def _boolean_compare(
    expr: ColumnElement[Any],
    op: OperatorType,
    obj: Any,
    *,
    negate_op: Optional[OperatorType] = None,
    reverse: bool = False,
    _python_is_types: Tuple[Type[Any], ...] = (type(None), bool),
    result_type: Optional[TypeEngine[bool]] = None,
    **kwargs: Any,
) -> OperatorExpression[bool]:
    """Produce a boolean comparison expression between ``expr`` and ``obj``.

    :param expr: left-hand column expression.
    :param op: the comparison operator to apply.
    :param obj: right-hand operand.  ``None``, ``True``/``False`` and the
     ``Null`` / ``True_`` / ``False_`` constructs receive special handling.
    :param negate_op: operator stored as the negation of the result.
    :param reverse: when true, ``obj`` becomes the left operand.  This is
     honored only on the general path, not on the literal or IS / IS NOT
     paths.
    :param result_type: type of the resulting expression; defaults to
     ``type_api.BOOLEANTYPE``.
    :param kwargs: passed through as ``modifiers`` (not on the IS / IS NOT
     path).
    :raises exc.ArgumentError: when a None/True/False operand is combined
     with an operator other than ``=``, ``!=``, ``is_()``, ``is_not()`` or
     the ``is_[not_]distinct_from()`` operators.
    """
    bool_type = type_api.BOOLEANTYPE if result_type is None else result_type

    if not isinstance(obj, _python_is_types + (Null, True_, False_)):
        # ordinary operand: coerce it as the right side of a binary op
        other = coercions.expect(
            roles.BinaryElementRole, element=obj, operator=op, expr=expr
        )
    else:
        # x ==/!= True/False is kept as a literal comparison; this comes
        # out to "== / != true/false" or "1/0" where those constants are
        # unsupported.  IS [NOT] DISTINCT FROM is likewise kept as-is.
        is_bool_equality = op in (operators.eq, operators.ne) and isinstance(
            obj, (bool, True_, False_)
        )
        is_distinct_check = op in (
            operators.is_distinct_from,
            operators.is_not_distinct_from,
        )
        if is_bool_equality or is_distinct_check:
            return OperatorExpression._construct_for_op(
                expr,
                coercions.expect(roles.ConstExprRole, obj),
                op,
                type_=bool_type,
                negate=negate_op,
                modifiers=kwargs,
            )

        if not expr._is_collection_aggregate:
            # every other None/True/False comparison becomes IS / IS NOT
            if op in (operators.eq, operators.is_):
                chosen_op, inverse_op = operators.is_, operators.is_not
            elif op in (operators.ne, operators.is_not):
                chosen_op, inverse_op = operators.is_not, operators.is_
            else:
                raise exc.ArgumentError(
                    "Only '=', '!=', 'is_()', 'is_not()', "
                    "'is_distinct_from()', 'is_not_distinct_from()' "
                    "operators can be used with None/True/False"
                )
            return OperatorExpression._construct_for_op(
                expr,
                coercions.expect(roles.ConstExprRole, obj),
                chosen_op,
                negate=inverse_op,
                type_=bool_type,
            )

        # collection-aggregate left side: keep the requested operator and
        # fall through to the generic construction below
        other = coercions.expect(
            roles.ConstExprRole, element=obj, operator=op, expr=expr
        )

    left, right = (other, expr) if reverse else (expr, other)
    return OperatorExpression._construct_for_op(
        left,
        right,
        op,
        type_=bool_type,
        negate=negate_op,
        modifiers=kwargs,
    )
143
144
def _custom_op_operate(
    expr: ColumnElement[Any],
    op: custom_op[Any],
    obj: Any,
    reverse: bool = False,
    result_type: Optional[TypeEngine[Any]] = None,
    **kw: Any,
) -> ColumnElement[Any]:
    """Apply a ``custom_op`` as a binary operation.

    When no explicit ``result_type`` is given, the operator's own
    ``return_type`` is used if it is set.  Otherwise a comparison operator
    (``op.is_comparison``) yields ``type_api.BOOLEANTYPE``.  If neither
    applies, the type is left as ``None`` for :func:`_binary_operate` to
    derive.
    """
    if result_type is None:
        result_type = op.return_type or (
            type_api.BOOLEANTYPE if op.is_comparison else None
        )
    return _binary_operate(
        expr, op, obj, reverse=reverse, result_type=result_type, **kw
    )
162
163
def _binary_operate(
    expr: ColumnElement[Any],
    op: OperatorType,
    obj: roles.BinaryElementRole[Any],
    *,
    reverse: bool = False,
    result_type: Optional[TypeEngine[_T]] = None,
    **kw: Any,
) -> OperatorExpression[_T]:
    """Build a binary operator expression between ``expr`` and ``obj``.

    ``obj`` is coerced via ``roles.BinaryElementRole``.  When ``reverse``
    is true, the coerced operand is placed on the left.  Without an
    explicit ``result_type``, the left operand's comparator adapts both
    the operator and the result type against the right operand's
    comparator.  Extra keyword arguments become the expression's
    ``modifiers``.
    """
    other = coercions.expect(
        roles.BinaryElementRole, obj, expr=expr, operator=op
    )
    left, right = (other, expr) if reverse else (expr, other)

    final_op, final_type = op, result_type
    if final_type is None:
        final_op, final_type = left.comparator._adapt_expression(
            op, right.comparator
        )

    return OperatorExpression._construct_for_op(
        left, right, final_op, type_=final_type, modifiers=kw
    )
190
191
def _conjunction_operate(
    expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
) -> ColumnElement[Any]:
    """Combine ``expr`` and ``other`` with ``and_()`` or ``or_()``.

    :raises NotImplementedError: if ``op`` is neither ``operators.and_``
     nor ``operators.or_``.
    """
    if op is operators.and_:
        return and_(expr, other)
    if op is operators.or_:
        return or_(expr, other)
    raise NotImplementedError()
201
202
def _scalar(
    expr: ColumnElement[Any],
    op: OperatorType,
    fn: Callable[[ColumnElement[Any]], ColumnElement[Any]],
    **kw: Any,
) -> ColumnElement[Any]:
    """Apply the single-argument constructor ``fn`` to ``expr``.

    ``fn`` is supplied through the ``operator_lookup`` table (for example
    ``UnaryExpression._create_desc``).  ``op`` and any extra keyword
    arguments are accepted for dispatch compatibility and ignored.
    """
    wrapped = fn(expr)
    return wrapped
210
211
def _in_impl(
    expr: ColumnElement[Any],
    op: OperatorType,
    seq_or_selectable: ClauseElement,
    negate_op: OperatorType,
    **kw: Any,
) -> ColumnElement[Any]:
    """Build an IN / NOT IN style comparison.

    ``seq_or_selectable`` is coerced via ``roles.InElementRole``.  If the
    coerced element carries an ``"in_ops"`` annotation, that annotation
    supplies the ``(op, negate_op)`` pair used instead of the ones passed
    in.  The comparison itself is built by :func:`_boolean_compare`.
    """
    coerced = coercions.expect(
        roles.InElementRole, seq_or_selectable, expr=expr, operator=op
    )
    annotations = coerced._annotations
    if "in_ops" in annotations:
        op, negate_op = annotations["in_ops"]
    return _boolean_compare(expr, op, coerced, negate_op=negate_op, **kw)
228
229
def _getitem_impl(
    expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
) -> ColumnElement[Any]:
    """Implement index access on ``expr``.

    The operation is supported when ``expr.type`` is an ``INDEXABLE`` type,
    or a ``TypeDecorator`` whose ``impl_instance`` is one.  The index is
    coerced via ``roles.BinaryElementRole`` and passed to
    :func:`_binary_operate`.

    :raises NotImplementedError: via :func:`_unsupported_impl` for any
     other type.
    """
    col_type = expr.type
    is_indexable = isinstance(col_type, type_api.INDEXABLE) or (
        isinstance(col_type, type_api.TypeDecorator)
        and isinstance(col_type.impl_instance, type_api.INDEXABLE)
    )
    if not is_indexable:
        return _unsupported_impl(expr, op, other, **kw)

    index = coercions.expect(
        roles.BinaryElementRole, other, expr=expr, operator=op
    )
    return _binary_operate(expr, op, index, **kw)
244
245
def _unsupported_impl(
    expr: ColumnElement[Any], op: OperatorType, *arg: Any, **kw: Any
) -> NoReturn:
    """Reject an operator that this expression does not support.

    :raises NotImplementedError: always; the message names ``op`` by its
     ``__name__``.
    """
    op_name = op.__name__
    raise NotImplementedError(
        f"Operator '{op_name}' is not supported on this expression"
    )
252
253
def _inv_impl(
    expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
    """See :meth:`.ColumnOperators.__inv__`.

    Returns ``expr.negation_clause`` when the element has that attribute
    (even if its value is ``None``).  Otherwise returns ``expr._negate()``.
    """
    # ``negation_clause`` is an undocumented element currently used by the
    # ORM for relationship.contains()
    missing = object()
    clause = getattr(expr, "negation_clause", missing)
    if clause is missing:
        return expr._negate()
    return clause
265
266
def _neg_impl(
    expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
    """See :meth:`.ColumnOperators.__neg__`.

    Wraps ``expr`` in a ``UnaryExpression`` using ``operators.neg`` and
    keeps ``expr``'s type.
    """
    same_type = expr.type
    return UnaryExpression(expr, operator=operators.neg, type_=same_type)
272
273
def _bitwise_not_impl(
    expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
    """See :meth:`.ColumnOperators.bitwise_not`.

    Wraps ``expr`` in a ``UnaryExpression`` using
    ``operators.bitwise_not_op`` and keeps ``expr``'s type.
    """
    unary_op = operators.bitwise_not_op
    return UnaryExpression(expr, operator=unary_op, type_=expr.type)
282
283
def _match_impl(
    expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
) -> ColumnElement[Any]:
    """See :meth:`.ColumnOperators.match`.

    Builds a comparison typed as ``type_api.MATCHTYPE`` through
    :func:`_boolean_compare`.  ``match_op`` and ``not_match_op`` are each
    other's negation:

    * ``op is operators.not_match_op`` -> operator ``not_match_op``,
      negation ``match_op``;
    * anything else (normally ``match_op``) -> operator ``match_op``,
      negation ``not_match_op``.

    Previously a ``not_match_op`` request produced a plain ``match_op``
    expression whose negation was also ``match_op``, so the NOT was lost
    and ``~`` did not invert it.  This now mirrors :func:`_between_impl`.
    """
    pattern = coercions.expect(
        roles.BinaryElementRole,
        other,
        expr=expr,
        operator=operators.match_op,
    )

    if op is operators.not_match_op:
        positive, inverse = operators.not_match_op, operators.match_op
    else:
        positive, inverse = operators.match_op, operators.not_match_op

    return _boolean_compare(
        expr,
        positive,
        pattern,
        result_type=type_api.MATCHTYPE,
        negate_op=inverse,
        **kw,
    )
306
307
def _distinct_impl(
    expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
    """See :meth:`.ColumnOperators.distinct`.

    Wraps ``expr`` in a ``UnaryExpression`` using ``operators.distinct_op``
    and keeps ``expr``'s type.
    """
    unary_op = operators.distinct_op
    return UnaryExpression(expr, operator=unary_op, type_=expr.type)
315
316
def _between_impl(
    expr: ColumnElement[Any],
    op: OperatorType,
    cleft: Any,
    cright: Any,
    **kw: Any,
) -> ColumnElement[Any]:
    """See :meth:`.ColumnOperators.between`.

    Builds ``expr <op> (cleft AND cright)``.  Both bounds are coerced via
    ``roles.BinaryElementRole`` and joined in an ungrouped ``and_`` clause
    list.  The stored negation flips between ``between_op`` and
    ``not_between_op``.  Extra keyword arguments become ``modifiers``.
    """
    lower = coercions.expect(
        roles.BinaryElementRole,
        cleft,
        expr=expr,
        operator=operators.and_,
    )
    upper = coercions.expect(
        roles.BinaryElementRole,
        cright,
        expr=expr,
        operator=operators.and_,
    )
    bounds = ExpressionClauseList._construct_for_list(
        operators.and_, type_api.NULLTYPE, lower, upper, group=False
    )
    inverse = (
        operators.between_op
        if op is not operators.between_op
        else operators.not_between_op
    )
    return BinaryExpression(expr, bounds, op, negate=inverse, modifiers=kw)
352
353
def _collate_impl(
    expr: ColumnElement[str], op: OperatorType, collation: str, **kw: Any
) -> ColumnElement[str]:
    """Attach ``collation`` to ``expr``.

    Delegates to ``CollationClause._create_collation_expression``.  ``op``
    and extra keyword arguments are not used.
    """
    build = CollationClause._create_collation_expression
    return build(expr, collation)
358
359
def _regexp_match_impl(
    expr: ColumnElement[str],
    op: OperatorType,
    pattern: Any,
    flags: Optional[str],
    **kw: Any,
) -> ColumnElement[Any]:
    """Build ``expr <op> pattern`` for the regexp-match operators.

    ``pattern`` is coerced via ``roles.BinaryElementRole`` and ``flags``
    is stored under the ``"flags"`` modifier.  The stored negation is the
    opposite regexp operator of ``op``, as in :func:`_between_impl`:
    ``regexp_match_op`` negates to ``not_regexp_match_op`` and vice versa.

    Previously the negation was always ``not_regexp_match_op``, so
    inverting a ``not_regexp_match_op`` expression produced the same
    operator again instead of a positive match.
    """
    coerced_pattern = coercions.expect(
        roles.BinaryElementRole,
        pattern,
        expr=expr,
        operator=operators.comma_op,
    )
    inverse = (
        operators.not_regexp_match_op
        if op is operators.regexp_match_op
        else operators.regexp_match_op
    )
    return BinaryExpression(
        expr,
        coerced_pattern,
        op,
        negate=inverse,
        modifiers={"flags": flags},
    )
379
380
def _regexp_replace_impl(
    expr: ColumnElement[Any],
    op: OperatorType,
    pattern: Any,
    replacement: Any,
    flags: Optional[str],
    **kw: Any,
) -> ColumnElement[Any]:
    """Build a regexp-replace expression for ``expr``.

    ``pattern`` and ``replacement`` are each coerced via
    ``roles.BinaryElementRole`` and combined in an ungrouped ``comma_op``
    clause list.  ``flags`` is stored under the ``"flags"`` modifier.  No
    negation is set, and extra keyword arguments are not used.
    """
    pattern_elem = coercions.expect(
        roles.BinaryElementRole,
        pattern,
        expr=expr,
        operator=operators.comma_op,
    )
    replacement_elem = coercions.expect(
        roles.BinaryElementRole,
        replacement,
        expr=expr,
        operator=operators.comma_op,
    )
    arguments = ExpressionClauseList._construct_for_list(
        operators.comma_op,
        type_api.NULLTYPE,
        pattern_elem,
        replacement_elem,
        group=False,
    )
    return BinaryExpression(expr, arguments, op, modifiers={"flags": flags})
411
412
# a mapping of operator names (e.g. "eq", "like_op") to the function that
# implements them, along with additional keyword arguments to be passed to
# that function.  For the _boolean_compare / _in_impl entries,
# "negate_op" must be the logical inverse of the keyed operator; it is
# stored as the negation of the resulting expression.
operator_lookup: Dict[
    str,
    Tuple[
        Callable[..., ColumnElement[Any]],
        util.immutabledict[
            str, Union[OperatorType, Callable[..., ColumnElement[Any]]]
        ],
    ],
] = {
    "and_": (_conjunction_operate, util.EMPTY_DICT),
    "or_": (_conjunction_operate, util.EMPTY_DICT),
    "inv": (_inv_impl, util.EMPTY_DICT),
    "add": (_binary_operate, util.EMPTY_DICT),
    "mul": (_binary_operate, util.EMPTY_DICT),
    "sub": (_binary_operate, util.EMPTY_DICT),
    "div": (_binary_operate, util.EMPTY_DICT),
    "mod": (_binary_operate, util.EMPTY_DICT),
    "bitwise_xor_op": (_binary_operate, util.EMPTY_DICT),
    "bitwise_or_op": (_binary_operate, util.EMPTY_DICT),
    "bitwise_and_op": (_binary_operate, util.EMPTY_DICT),
    "bitwise_not_op": (_bitwise_not_impl, util.EMPTY_DICT),
    "bitwise_lshift_op": (_binary_operate, util.EMPTY_DICT),
    "bitwise_rshift_op": (_binary_operate, util.EMPTY_DICT),
    "truediv": (_binary_operate, util.EMPTY_DICT),
    "floordiv": (_binary_operate, util.EMPTY_DICT),
    "custom_op": (_custom_op_operate, util.EMPTY_DICT),
    "json_path_getitem_op": (_binary_operate, util.EMPTY_DICT),
    "json_getitem_op": (_binary_operate, util.EMPTY_DICT),
    "concat_op": (_binary_operate, util.EMPTY_DICT),
    "any_op": (
        _scalar,
        util.immutabledict({"fn": CollectionAggregate._create_any}),
    ),
    "all_op": (
        _scalar,
        util.immutabledict({"fn": CollectionAggregate._create_all}),
    ),
    "lt": (_boolean_compare, util.immutabledict({"negate_op": operators.ge})),
    "le": (_boolean_compare, util.immutabledict({"negate_op": operators.gt})),
    "ne": (_boolean_compare, util.immutabledict({"negate_op": operators.eq})),
    "gt": (_boolean_compare, util.immutabledict({"negate_op": operators.le})),
    "ge": (_boolean_compare, util.immutabledict({"negate_op": operators.lt})),
    "eq": (_boolean_compare, util.immutabledict({"negate_op": operators.ne})),
    "is_distinct_from": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.is_not_distinct_from}),
    ),
    "is_not_distinct_from": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.is_distinct_from}),
    ),
    "like_op": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.not_like_op}),
    ),
    "ilike_op": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.not_ilike_op}),
    ),
    "not_like_op": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.like_op}),
    ),
    "not_ilike_op": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.ilike_op}),
    ),
    "contains_op": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.not_contains_op}),
    ),
    "icontains_op": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.not_icontains_op}),
    ),
    "startswith_op": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.not_startswith_op}),
    ),
    "istartswith_op": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.not_istartswith_op}),
    ),
    "endswith_op": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.not_endswith_op}),
    ),
    "iendswith_op": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.not_iendswith_op}),
    ),
    "desc_op": (
        _scalar,
        util.immutabledict({"fn": UnaryExpression._create_desc}),
    ),
    "asc_op": (
        _scalar,
        util.immutabledict({"fn": UnaryExpression._create_asc}),
    ),
    "nulls_first_op": (
        _scalar,
        util.immutabledict({"fn": UnaryExpression._create_nulls_first}),
    ),
    "nulls_last_op": (
        _scalar,
        util.immutabledict({"fn": UnaryExpression._create_nulls_last}),
    ),
    "in_op": (
        _in_impl,
        util.immutabledict({"negate_op": operators.not_in_op}),
    ),
    "not_in_op": (
        _in_impl,
        util.immutabledict({"negate_op": operators.in_op}),
    ),
    # IS / IS NOT must negate to each other.  _boolean_compare hard-codes
    # this pair for None/True/False operands, but uses negate_op for all
    # other operands (e.g. col.is_(other_col)) and for collection
    # aggregates; mapping is_ -> is_ there made ~(a IS b) stay "a IS b".
    "is_": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.is_not}),
    ),
    "is_not": (
        _boolean_compare,
        util.immutabledict({"negate_op": operators.is_}),
    ),
    "collate": (_collate_impl, util.EMPTY_DICT),
    "match_op": (_match_impl, util.EMPTY_DICT),
    "not_match_op": (_match_impl, util.EMPTY_DICT),
    "distinct_op": (_distinct_impl, util.EMPTY_DICT),
    "between_op": (_between_impl, util.EMPTY_DICT),
    "not_between_op": (_between_impl, util.EMPTY_DICT),
    "neg": (_neg_impl, util.EMPTY_DICT),
    "getitem": (_getitem_impl, util.EMPTY_DICT),
    "lshift": (_unsupported_impl, util.EMPTY_DICT),
    "rshift": (_unsupported_impl, util.EMPTY_DICT),
    "contains": (_unsupported_impl, util.EMPTY_DICT),
    "regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT),
    "not_regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT),
    "regexp_replace_op": (_regexp_replace_impl, util.EMPTY_DICT),
}