1# orm/identity.py 
    2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: https://www.opensource.org/licenses/mit-license.php 
    7 
    8from __future__ import annotations 
    9 
    10from typing import Any 
    11from typing import cast 
    12from typing import Dict 
    13from typing import Iterable 
    14from typing import Iterator 
    15from typing import List 
    16from typing import NoReturn 
    17from typing import Optional 
    18from typing import Set 
    19from typing import Tuple 
    20from typing import TYPE_CHECKING 
    21from typing import TypeVar 
    22import weakref 
    23 
    24from . import util as orm_util 
    25from .. import exc as sa_exc 
    26 
    27if TYPE_CHECKING: 
    28    from ._typing import _IdentityKeyType 
    29    from .state import InstanceState 
    30 
    31 
    32_T = TypeVar("_T", bound=Any) 
    33 
    34_O = TypeVar("_O", bound=object) 
    35 
    36 
    37class IdentityMap: 
    38    _wr: weakref.ref[IdentityMap] 
    39 
    40    _dict: Dict[_IdentityKeyType[Any], Any] 
    41    _modified: Set[InstanceState[Any]] 
    42 
    43    def __init__(self) -> None: 
    44        self._dict = {} 
    45        self._modified = set() 
    46        self._wr = weakref.ref(self) 
    47 
    48    def _kill(self) -> None: 
    49        self._add_unpresent = _killed  # type: ignore 
    50 
    51    def all_states(self) -> List[InstanceState[Any]]: 
    52        raise NotImplementedError() 
    53 
    54    def contains_state(self, state: InstanceState[Any]) -> bool: 
    55        raise NotImplementedError() 
    56 
    57    def __contains__(self, key: _IdentityKeyType[Any]) -> bool: 
    58        raise NotImplementedError() 
    59 
    60    def safe_discard(self, state: InstanceState[Any]) -> None: 
    61        raise NotImplementedError() 
    62 
    63    def __getitem__(self, key: _IdentityKeyType[_O]) -> _O: 
    64        raise NotImplementedError() 
    65 
    66    def get( 
    67        self, key: _IdentityKeyType[_O], default: Optional[_O] = None 
    68    ) -> Optional[_O]: 
    69        raise NotImplementedError() 
    70 
    71    def fast_get_state( 
    72        self, key: _IdentityKeyType[_O] 
    73    ) -> Optional[InstanceState[_O]]: 
    74        raise NotImplementedError() 
    75 
    76    def keys(self) -> Iterable[_IdentityKeyType[Any]]: 
    77        return self._dict.keys() 
    78 
    79    def values(self) -> Iterable[object]: 
    80        raise NotImplementedError() 
    81 
    82    def replace(self, state: InstanceState[_O]) -> Optional[InstanceState[_O]]: 
    83        raise NotImplementedError() 
    84 
    85    def add(self, state: InstanceState[Any]) -> bool: 
    86        raise NotImplementedError() 
    87 
    88    def _fast_discard(self, state: InstanceState[Any]) -> None: 
    89        raise NotImplementedError() 
    90 
    91    def _add_unpresent( 
    92        self, state: InstanceState[Any], key: _IdentityKeyType[Any] 
    93    ) -> None: 
    94        """optional inlined form of add() which can assume item isn't present 
    95        in the map""" 
    96        self.add(state) 
    97 
    98    def _manage_incoming_state(self, state: InstanceState[Any]) -> None: 
    99        state._instance_dict = self._wr 
    100 
    101        if state.modified: 
    102            self._modified.add(state) 
    103 
    104    def _manage_removed_state(self, state: InstanceState[Any]) -> None: 
    105        del state._instance_dict 
    106        if state.modified: 
    107            self._modified.discard(state) 
    108 
    109    def _dirty_states(self) -> Set[InstanceState[Any]]: 
    110        return self._modified 
    111 
    112    def check_modified(self) -> bool: 
    113        """return True if any InstanceStates present have been marked 
    114        as 'modified'. 
    115 
    116        """ 
    117        return bool(self._modified) 
    118 
    119    def has_key(self, key: _IdentityKeyType[Any]) -> bool: 
    120        return key in self 
    121 
    122    def __len__(self) -> int: 
    123        return len(self._dict) 
    124 
    125 
    126class _WeakInstanceDict(IdentityMap): 
    127    _dict: Dict[_IdentityKeyType[Any], InstanceState[Any]] 
    128 
    129    def __getitem__(self, key: _IdentityKeyType[_O]) -> _O: 
    130        state = cast("InstanceState[_O]", self._dict[key]) 
    131        o = state.obj() 
    132        if o is None: 
    133            raise KeyError(key) 
    134        return o 
    135 
    136    def __contains__(self, key: _IdentityKeyType[Any]) -> bool: 
    137        try: 
    138            if key in self._dict: 
    139                state = self._dict[key] 
    140                o = state.obj() 
    141            else: 
    142                return False 
    143        except KeyError: 
    144            return False 
    145        else: 
    146            return o is not None 
    147 
    148    def contains_state(self, state: InstanceState[Any]) -> bool: 
    149        if state.key in self._dict: 
    150            if TYPE_CHECKING: 
    151                assert state.key is not None 
    152            try: 
    153                return self._dict[state.key] is state 
    154            except KeyError: 
    155                return False 
    156        else: 
    157            return False 
    158 
    159    def replace( 
    160        self, state: InstanceState[Any] 
    161    ) -> Optional[InstanceState[Any]]: 
    162        assert state.key is not None 
    163        if state.key in self._dict: 
    164            try: 
    165                existing = existing_non_none = self._dict[state.key] 
    166            except KeyError: 
    167                # catch gc removed the key after we just checked for it 
    168                existing = None 
    169            else: 
    170                if existing_non_none is not state: 
    171                    self._manage_removed_state(existing_non_none) 
    172                else: 
    173                    return None 
    174        else: 
    175            existing = None 
    176 
    177        self._dict[state.key] = state 
    178        self._manage_incoming_state(state) 
    179        return existing 
    180 
    181    def add(self, state: InstanceState[Any]) -> bool: 
    182        key = state.key 
    183        assert key is not None 
    184        # inline of self.__contains__ 
    185        if key in self._dict: 
    186            try: 
    187                existing_state = self._dict[key] 
    188            except KeyError: 
    189                # catch gc removed the key after we just checked for it 
    190                pass 
    191            else: 
    192                if existing_state is not state: 
    193                    o = existing_state.obj() 
    194                    if o is not None: 
    195                        raise sa_exc.InvalidRequestError( 
    196                            "Can't attach instance " 
    197                            "%s; another instance with key %s is already " 
    198                            "present in this session." 
    199                            % (orm_util.state_str(state), state.key) 
    200                        ) 
    201                else: 
    202                    return False 
    203        self._dict[key] = state 
    204        self._manage_incoming_state(state) 
    205        return True 
    206 
    207    def _add_unpresent( 
    208        self, state: InstanceState[Any], key: _IdentityKeyType[Any] 
    209    ) -> None: 
    210        # inlined form of add() called by loading.py 
    211        self._dict[key] = state 
    212        state._instance_dict = self._wr 
    213 
    214    def fast_get_state( 
    215        self, key: _IdentityKeyType[_O] 
    216    ) -> Optional[InstanceState[_O]]: 
    217        return self._dict.get(key) 
    218 
    219    def get( 
    220        self, key: _IdentityKeyType[_O], default: Optional[_O] = None 
    221    ) -> Optional[_O]: 
    222        if key not in self._dict: 
    223            return default 
    224        try: 
    225            state = cast("InstanceState[_O]", self._dict[key]) 
    226        except KeyError: 
    227            # catch gc removed the key after we just checked for it 
    228            return default 
    229        else: 
    230            o = state.obj() 
    231            if o is None: 
    232                return default 
    233            return o 
    234 
    235    def items(self) -> List[Tuple[_IdentityKeyType[Any], InstanceState[Any]]]: 
    236        values = self.all_states() 
    237        result = [] 
    238        for state in values: 
    239            value = state.obj() 
    240            key = state.key 
    241            assert key is not None 
    242            if value is not None: 
    243                result.append((key, value)) 
    244        return result 
    245 
    246    def values(self) -> List[object]: 
    247        values = self.all_states() 
    248        result = [] 
    249        for state in values: 
    250            value = state.obj() 
    251            if value is not None: 
    252                result.append(value) 
    253 
    254        return result 
    255 
    256    def __iter__(self) -> Iterator[_IdentityKeyType[Any]]: 
    257        return iter(self.keys()) 
    258 
    259    def all_states(self) -> List[InstanceState[Any]]: 
    260        return list(self._dict.values()) 
    261 
    262    def _fast_discard(self, state: InstanceState[Any]) -> None: 
    263        # used by InstanceState for state being 
    264        # GC'ed, inlines _managed_removed_state 
    265        key = state.key 
    266        assert key is not None 
    267        try: 
    268            st = self._dict[key] 
    269        except KeyError: 
    270            # catch gc removed the key after we just checked for it 
    271            pass 
    272        else: 
    273            if st is state: 
    274                self._dict.pop(key, None) 
    275 
    276    def discard(self, state: InstanceState[Any]) -> None: 
    277        self.safe_discard(state) 
    278 
    279    def safe_discard(self, state: InstanceState[Any]) -> None: 
    280        key = state.key 
    281        if key in self._dict: 
    282            assert key is not None 
    283            try: 
    284                st = self._dict[key] 
    285            except KeyError: 
    286                # catch gc removed the key after we just checked for it 
    287                pass 
    288            else: 
    289                if st is state: 
    290                    self._dict.pop(key, None) 
    291                    self._manage_removed_state(state) 
    292 
    293 
    294def _killed(state: InstanceState[Any], key: _IdentityKeyType[Any]) -> NoReturn: 
    295    # external function to avoid creating cycles when assigned to 
    296    # the IdentityMap 
    297    raise sa_exc.InvalidRequestError( 
    298        "Object %s cannot be converted to 'persistent' state, as this " 
    299        "identity map is no longer valid.  Has the owning Session " 
    300        "been closed?" % orm_util.state_str(state), 
    301        code="lkrp", 
    302    )