1# orm/evaluator.py 
    2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: https://www.opensource.org/licenses/mit-license.php 
    7# mypy: ignore-errors 
    8 
    9"""Evaluation functions used **INTERNALLY** by ORM DML use cases. 
    10 
    11 
    12This module is **private, for internal use by SQLAlchemy**. 
    13 
    14.. versionchanged:: 2.0.4 renamed ``EvaluatorCompiler`` to 
    15   ``_EvaluatorCompiler``. 
    16 
    17""" 
    18 
    19 
    20from __future__ import annotations 
    21 
    22from typing import Type 
    23 
    24from . import exc as orm_exc 
    25from .base import LoaderCallableStatus 
    26from .base import PassiveFlag 
    27from .. import exc 
    28from .. import inspect 
    29from ..sql import and_ 
    30from ..sql import operators 
    31from ..sql.sqltypes import Concatenable 
    32from ..sql.sqltypes import Integer 
    33from ..sql.sqltypes import Numeric 
    34from ..util import warn_deprecated 
    35 
    36 
    37class UnevaluatableError(exc.InvalidRequestError): 
    38    pass 
    39 
    40 
    41class _NoObject(operators.ColumnOperators): 
    42    def operate(self, *arg, **kw): 
    43        return None 
    44 
    45    def reverse_operate(self, *arg, **kw): 
    46        return None 
    47 
    48 
    49class _ExpiredObject(operators.ColumnOperators): 
    50    def operate(self, *arg, **kw): 
    51        return self 
    52 
    53    def reverse_operate(self, *arg, **kw): 
    54        return self 
    55 
    56 
    57_NO_OBJECT = _NoObject() 
    58_EXPIRED_OBJECT = _ExpiredObject() 
    59 
    60 
    61class _EvaluatorCompiler: 
    62    def __init__(self, target_cls=None): 
    63        self.target_cls = target_cls 
    64 
    65    def process(self, clause, *clauses): 
    66        if clauses: 
    67            clause = and_(clause, *clauses) 
    68 
    69        meth = getattr(self, f"visit_{clause.__visit_name__}", None) 
    70        if not meth: 
    71            raise UnevaluatableError( 
    72                f"Cannot evaluate {type(clause).__name__}" 
    73            ) 
    74        return meth(clause) 
    75 
    76    def visit_grouping(self, clause): 
    77        return self.process(clause.element) 
    78 
    79    def visit_null(self, clause): 
    80        return lambda obj: None 
    81 
    82    def visit_false(self, clause): 
    83        return lambda obj: False 
    84 
    85    def visit_true(self, clause): 
    86        return lambda obj: True 
    87 
    88    def visit_column(self, clause): 
    89        try: 
    90            parentmapper = clause._annotations["parentmapper"] 
    91        except KeyError as ke: 
    92            raise UnevaluatableError( 
    93                f"Cannot evaluate column: {clause}" 
    94            ) from ke 
    95 
    96        if self.target_cls and not issubclass( 
    97            self.target_cls, parentmapper.class_ 
    98        ): 
    99            raise UnevaluatableError( 
    100                "Can't evaluate criteria against " 
    101                f"alternate class {parentmapper.class_}" 
    102            ) 
    103 
    104        parentmapper._check_configure() 
    105 
    106        # we'd like to use "proxy_key" annotation to get the "key", however 
    107        # in relationship primaryjoin cases proxy_key is sometimes deannotated 
    108        # and sometimes apparently not present in the first place (?). 
    109        # While I can stop it from being deannotated (though need to see if 
    110        # this breaks other things), not sure right now  about cases where it's 
    111        # not there in the first place.  can fix at some later point. 
    112        # key = clause._annotations["proxy_key"] 
    113 
    114        # for now, use the old way 
    115        try: 
    116            key = parentmapper._columntoproperty[clause].key 
    117        except orm_exc.UnmappedColumnError as err: 
    118            raise UnevaluatableError( 
    119                f"Cannot evaluate expression: {err}" 
    120            ) from err 
    121 
    122        # note this used to fall back to a simple `getattr(obj, key)` evaluator 
    123        # if impl was None; as of #8656, we ensure mappers are configured 
    124        # so that impl is available 
    125        impl = parentmapper.class_manager[key].impl 
    126 
    127        def get_corresponding_attr(obj): 
    128            if obj is None: 
    129                return _NO_OBJECT 
    130            state = inspect(obj) 
    131            dict_ = state.dict 
    132 
    133            value = impl.get( 
    134                state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH 
    135            ) 
    136            if value is LoaderCallableStatus.PASSIVE_NO_RESULT: 
    137                return _EXPIRED_OBJECT 
    138            return value 
    139 
    140        return get_corresponding_attr 
    141 
    142    def visit_tuple(self, clause): 
    143        return self.visit_clauselist(clause) 
    144 
    145    def visit_expression_clauselist(self, clause): 
    146        return self.visit_clauselist(clause) 
    147 
    148    def visit_clauselist(self, clause): 
    149        evaluators = [self.process(clause) for clause in clause.clauses] 
    150 
    151        dispatch = ( 
    152            f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op" 
    153        ) 
    154        meth = getattr(self, dispatch, None) 
    155        if meth: 
    156            return meth(clause.operator, evaluators, clause) 
    157        else: 
    158            raise UnevaluatableError( 
    159                f"Cannot evaluate clauselist with operator {clause.operator}" 
    160            ) 
    161 
    162    def visit_binary(self, clause): 
    163        eval_left = self.process(clause.left) 
    164        eval_right = self.process(clause.right) 
    165 
    166        dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op" 
    167        meth = getattr(self, dispatch, None) 
    168        if meth: 
    169            return meth(clause.operator, eval_left, eval_right, clause) 
    170        else: 
    171            raise UnevaluatableError( 
    172                f"Cannot evaluate {type(clause).__name__} with " 
    173                f"operator {clause.operator}" 
    174            ) 
    175 
    176    def visit_or_clauselist_op(self, operator, evaluators, clause): 
    177        def evaluate(obj): 
    178            has_null = False 
    179            for sub_evaluate in evaluators: 
    180                value = sub_evaluate(obj) 
    181                if value is _EXPIRED_OBJECT: 
    182                    return _EXPIRED_OBJECT 
    183                elif value: 
    184                    return True 
    185                has_null = has_null or value is None 
    186            if has_null: 
    187                return None 
    188            return False 
    189 
    190        return evaluate 
    191 
    192    def visit_and_clauselist_op(self, operator, evaluators, clause): 
    193        def evaluate(obj): 
    194            for sub_evaluate in evaluators: 
    195                value = sub_evaluate(obj) 
    196                if value is _EXPIRED_OBJECT: 
    197                    return _EXPIRED_OBJECT 
    198 
    199                if not value: 
    200                    if value is None or value is _NO_OBJECT: 
    201                        return None 
    202                    return False 
    203            return True 
    204 
    205        return evaluate 
    206 
    207    def visit_comma_op_clauselist_op(self, operator, evaluators, clause): 
    208        def evaluate(obj): 
    209            values = [] 
    210            for sub_evaluate in evaluators: 
    211                value = sub_evaluate(obj) 
    212                if value is _EXPIRED_OBJECT: 
    213                    return _EXPIRED_OBJECT 
    214                elif value is None or value is _NO_OBJECT: 
    215                    return None 
    216                values.append(value) 
    217            return tuple(values) 
    218 
    219        return evaluate 
    220 
    221    def visit_custom_op_binary_op( 
    222        self, operator, eval_left, eval_right, clause 
    223    ): 
    224        if operator.python_impl: 
    225            return self._straight_evaluate( 
    226                operator, eval_left, eval_right, clause 
    227            ) 
    228        else: 
    229            raise UnevaluatableError( 
    230                f"Custom operator {operator.opstring!r} can't be evaluated " 
    231                "in Python unless it specifies a callable using " 
    232                "`.python_impl`." 
    233            ) 
    234 
    235    def visit_is_binary_op(self, operator, eval_left, eval_right, clause): 
    236        def evaluate(obj): 
    237            left_val = eval_left(obj) 
    238            right_val = eval_right(obj) 
    239            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: 
    240                return _EXPIRED_OBJECT 
    241            return left_val == right_val 
    242 
    243        return evaluate 
    244 
    245    def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause): 
    246        def evaluate(obj): 
    247            left_val = eval_left(obj) 
    248            right_val = eval_right(obj) 
    249            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: 
    250                return _EXPIRED_OBJECT 
    251            return left_val != right_val 
    252 
    253        return evaluate 
    254 
    255    def _straight_evaluate(self, operator, eval_left, eval_right, clause): 
    256        def evaluate(obj): 
    257            left_val = eval_left(obj) 
    258            right_val = eval_right(obj) 
    259            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: 
    260                return _EXPIRED_OBJECT 
    261            elif left_val is None or right_val is None: 
    262                return None 
    263 
    264            return operator(eval_left(obj), eval_right(obj)) 
    265 
    266        return evaluate 
    267 
    268    def _straight_evaluate_numeric_only( 
    269        self, operator, eval_left, eval_right, clause 
    270    ): 
    271        if clause.left.type._type_affinity not in ( 
    272            Numeric, 
    273            Integer, 
    274        ) or clause.right.type._type_affinity not in (Numeric, Integer): 
    275            raise UnevaluatableError( 
    276                f'Cannot evaluate math operator "{operator.__name__}" for ' 
    277                f"datatypes {clause.left.type}, {clause.right.type}" 
    278            ) 
    279 
    280        return self._straight_evaluate(operator, eval_left, eval_right, clause) 
    281 
    282    visit_add_binary_op = _straight_evaluate_numeric_only 
    283    visit_mul_binary_op = _straight_evaluate_numeric_only 
    284    visit_sub_binary_op = _straight_evaluate_numeric_only 
    285    visit_mod_binary_op = _straight_evaluate_numeric_only 
    286    visit_truediv_binary_op = _straight_evaluate_numeric_only 
    287    visit_lt_binary_op = _straight_evaluate 
    288    visit_le_binary_op = _straight_evaluate 
    289    visit_ne_binary_op = _straight_evaluate 
    290    visit_gt_binary_op = _straight_evaluate 
    291    visit_ge_binary_op = _straight_evaluate 
    292    visit_eq_binary_op = _straight_evaluate 
    293 
    294    def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause): 
    295        return self._straight_evaluate( 
    296            lambda a, b: a in b if a is not _NO_OBJECT else None, 
    297            eval_left, 
    298            eval_right, 
    299            clause, 
    300        ) 
    301 
    302    def visit_not_in_op_binary_op( 
    303        self, operator, eval_left, eval_right, clause 
    304    ): 
    305        return self._straight_evaluate( 
    306            lambda a, b: a not in b if a is not _NO_OBJECT else None, 
    307            eval_left, 
    308            eval_right, 
    309            clause, 
    310        ) 
    311 
    312    def visit_concat_op_binary_op( 
    313        self, operator, eval_left, eval_right, clause 
    314    ): 
    315 
    316        if not issubclass( 
    317            clause.left.type._type_affinity, Concatenable 
    318        ) or not issubclass(clause.right.type._type_affinity, Concatenable): 
    319            raise UnevaluatableError( 
    320                f"Cannot evaluate concatenate operator " 
    321                f'"{operator.__name__}" for ' 
    322                f"datatypes {clause.left.type}, {clause.right.type}" 
    323            ) 
    324 
    325        return self._straight_evaluate( 
    326            lambda a, b: a + b, eval_left, eval_right, clause 
    327        ) 
    328 
    329    def visit_startswith_op_binary_op( 
    330        self, operator, eval_left, eval_right, clause 
    331    ): 
    332        return self._straight_evaluate( 
    333            lambda a, b: a.startswith(b), eval_left, eval_right, clause 
    334        ) 
    335 
    336    def visit_endswith_op_binary_op( 
    337        self, operator, eval_left, eval_right, clause 
    338    ): 
    339        return self._straight_evaluate( 
    340            lambda a, b: a.endswith(b), eval_left, eval_right, clause 
    341        ) 
    342 
    343    def visit_unary(self, clause): 
    344        eval_inner = self.process(clause.element) 
    345        if clause.operator is operators.inv: 
    346 
    347            def evaluate(obj): 
    348                value = eval_inner(obj) 
    349                if value is _EXPIRED_OBJECT: 
    350                    return _EXPIRED_OBJECT 
    351                elif value is None: 
    352                    return None 
    353                return not value 
    354 
    355            return evaluate 
    356        raise UnevaluatableError( 
    357            f"Cannot evaluate {type(clause).__name__} " 
    358            f"with operator {clause.operator}" 
    359        ) 
    360 
    361    def visit_bindparam(self, clause): 
    362        if clause.callable: 
    363            val = clause.callable() 
    364        else: 
    365            val = clause.value 
    366        return lambda obj: val 
    367 
    368 
    369def __getattr__(name: str) -> Type[_EvaluatorCompiler]: 
    370    if name == "EvaluatorCompiler": 
    371        warn_deprecated( 
    372            "Direct use of 'EvaluatorCompiler' is not supported, and this " 
    373            "name will be removed in a future release.  " 
    374            "'_EvaluatorCompiler' is for internal use only", 
    375            "2.0", 
    376        ) 
    377        return _EvaluatorCompiler 
    378    else: 
    379        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")