1# orm/identity.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8from __future__ import annotations
9
10from typing import Any
11from typing import cast
12from typing import Dict
13from typing import Iterable
14from typing import Iterator
15from typing import List
16from typing import NoReturn
17from typing import Optional
18from typing import Set
19from typing import Tuple
20from typing import TYPE_CHECKING
21from typing import TypeVar
22import weakref
23
24from . import util as orm_util
25from .. import exc as sa_exc
26
27if TYPE_CHECKING:
28 from ._typing import _IdentityKeyType
29 from .state import InstanceState
30
31
# Generic type variable with no effective bound (bound=Any).  Not
# referenced in this section; presumably used elsewhere — TODO confirm.
_T = TypeVar("_T", bound=Any)

# Type of the mapped object carried through identity-key and
# InstanceState annotations, e.g. ``_IdentityKeyType[_O] -> _O``.
_O = TypeVar("_O", bound=object)
35
36
class IdentityMap:
    """Base class for the identity map owned by a Session.

    Stores identity key -> entry pairs in ``_dict`` and keeps ``_modified``,
    a set of states: a state is added to it when registered while its
    ``modified`` flag is set, and discarded from it on removal under the
    same condition.  Most of the mapping API is left to subclasses and
    raises :class:`NotImplementedError` here.
    """

    # Weak reference to this map; assigned to ``state._instance_dict`` for
    # every state registered through ``_manage_incoming_state``.
    _wr: weakref.ref[IdentityMap]

    # identity key -> entry; the entry type is chosen by the subclass.
    _dict: Dict[_IdentityKeyType[Any], Any]
    # Registered states that were flagged as modified.
    _modified: Set[InstanceState[Any]]

    def __init__(self) -> None:
        self._wr = weakref.ref(self)
        self._modified = set()
        self._dict = {}

    def _kill(self) -> None:
        """Invalidate the map so that later ``_add_unpresent`` calls raise.

        The replacement is the module-level ``_killed`` function, stored on
        the instance, so ``self._add_unpresent(state, key)`` calls it with
        exactly ``(state, key)``.
        """
        self._add_unpresent = _killed  # type: ignore

    # -- API implemented by subclasses --------------------------------------

    def all_states(self) -> List[InstanceState[Any]]:
        """Return every state held by the map (subclass responsibility)."""
        raise NotImplementedError()

    def contains_state(self, state: InstanceState[Any]) -> bool:
        """Return True if *state* is the entry stored under its own key."""
        raise NotImplementedError()

    def __contains__(self, key: _IdentityKeyType[Any]) -> bool:
        """Return True if *key* maps to a live entry."""
        raise NotImplementedError()

    def safe_discard(self, state: InstanceState[Any]) -> None:
        """Remove *state* if present; never raise when it is absent."""
        raise NotImplementedError()

    def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
        """Return the object stored under *key*."""
        raise NotImplementedError()

    def get(
        self, key: _IdentityKeyType[_O], default: Optional[_O] = None
    ) -> Optional[_O]:
        """Return the object stored under *key*, else *default*."""
        raise NotImplementedError()

    def fast_get_state(
        self, key: _IdentityKeyType[_O]
    ) -> Optional[InstanceState[_O]]:
        """Return the raw state stored under *key*, or None."""
        raise NotImplementedError()

    def values(self) -> Iterable[object]:
        """Return the objects held by the map."""
        raise NotImplementedError()

    def replace(self, state: InstanceState[_O]) -> Optional[InstanceState[_O]]:
        """Store *state* under its key, returning any state it displaced."""
        raise NotImplementedError()

    def add(self, state: InstanceState[Any]) -> bool:
        """Register *state*; return whether it was newly stored."""
        raise NotImplementedError()

    def _fast_discard(self, state: InstanceState[Any]) -> None:
        """Low-overhead removal path for *state*."""
        raise NotImplementedError()

    # -- concrete helpers ---------------------------------------------------

    def keys(self) -> Iterable[_IdentityKeyType[Any]]:
        """Return a live view of every key in ``_dict``."""
        return self._dict.keys()

    def _add_unpresent(
        self, state: InstanceState[Any], key: _IdentityKeyType[Any]
    ) -> None:
        """optional inlined form of add() which can assume item isn't present
        in the map"""
        # The generic fallback simply defers to add(); subclasses may
        # override this with a faster path that skips the presence check.
        self.add(state)

    def _manage_incoming_state(self, state: InstanceState[Any]) -> None:
        """Point *state* back at this map and track it if modified."""
        state._instance_dict = self._wr
        if not state.modified:
            return
        self._modified.add(state)

    def _manage_removed_state(self, state: InstanceState[Any]) -> None:
        """Detach *state* from this map and stop tracking it if modified."""
        del state._instance_dict
        if not state.modified:
            return
        self._modified.discard(state)

    def _dirty_states(self) -> Set[InstanceState[Any]]:
        """Return the tracked-modified set itself (not a copy)."""
        return self._modified

    def check_modified(self) -> bool:
        """return True if any InstanceStates present have been marked
        as 'modified'.

        """
        return len(self._modified) > 0

    def has_key(self, key: _IdentityKeyType[Any]) -> bool:
        """Legacy spelling of ``key in self``."""
        return key in self

    def __len__(self) -> int:
        # Counts every key in ``_dict``, without checking whether its
        # entry is still live.
        return len(self._dict)
124
125
class WeakInstanceDict(IdentityMap):
    """Identity map whose entries are :class:`.InstanceState` objects.

    Lookups return the object obtained from ``state.obj()``.  An entry whose
    ``state.obj()`` returns ``None`` is treated as absent by
    ``__getitem__``, ``__contains__``, ``get``, ``items`` and ``values``.

    Several methods test ``key in self._dict`` and then index the dict inside
    ``try``/``except KeyError``, because garbage collection may remove the
    entry between those two steps.
    """

    # identity key -> InstanceState holding the mapped object.
    _dict: Dict[_IdentityKeyType[Any], InstanceState[Any]]

    def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
        """Return the object mapped to *key*.

        :raises KeyError: if *key* is not present, or if its state's
            ``obj()`` returns ``None``.
        """
        state = cast("InstanceState[_O]", self._dict[key])
        o = state.obj()
        if o is None:
            raise KeyError(key)
        return o

    def __contains__(self, key: _IdentityKeyType[Any]) -> bool:
        """Return True if *key* is present and its object is still alive."""
        try:
            if key in self._dict:
                state = self._dict[key]
                o = state.obj()
            else:
                return False
        except KeyError:
            # gc removed the key after the membership check
            return False
        else:
            return o is not None

    def contains_state(self, state: InstanceState[Any]) -> bool:
        """Return True if *state* is exactly the entry stored under its key."""
        if state.key in self._dict:
            if TYPE_CHECKING:
                assert state.key is not None
            try:
                return self._dict[state.key] is state
            except KeyError:
                # gc removed the key after the membership check
                return False
        else:
            return False

    def replace(
        self, state: InstanceState[Any]
    ) -> Optional[InstanceState[Any]]:
        """Store *state* under its key, displacing any different state.

        A displaced state is passed to ``_manage_removed_state``; *state*
        is then registered via ``_manage_incoming_state``.

        :returns: the displaced state, or ``None`` when nothing was
            displaced.  If *state* is already the stored entry, nothing
            changes and ``None`` is returned.
        """
        assert state.key is not None
        if state.key in self._dict:
            try:
                existing = existing_non_none = self._dict[state.key]
            except KeyError:
                # catch gc removed the key after we just checked for it
                existing = None
            else:
                if existing_non_none is not state:
                    self._manage_removed_state(existing_non_none)
                else:
                    return None
        else:
            existing = None

        self._dict[state.key] = state
        self._manage_incoming_state(state)
        return existing

    def add(self, state: InstanceState[Any]) -> bool:
        """Register *state* under its key.

        A different state already stored under the key is silently
        overwritten when its ``obj()`` returns ``None``.

        :returns: ``True`` if *state* was stored, ``False`` if *state* was
            already the stored entry.
        :raises sa_exc.InvalidRequestError: if a different state with a
            live object is already stored under the same key.
        """
        key = state.key
        assert key is not None
        # inline of self.__contains__
        if key in self._dict:
            try:
                existing_state = self._dict[key]
            except KeyError:
                # catch gc removed the key after we just checked for it
                pass
            else:
                if existing_state is not state:
                    o = existing_state.obj()
                    if o is not None:
                        raise sa_exc.InvalidRequestError(
                            "Can't attach instance "
                            "%s; another instance with key %s is already "
                            "present in this session."
                            % (orm_util.state_str(state), state.key)
                        )
                else:
                    return False
        self._dict[key] = state
        self._manage_incoming_state(state)
        return True

    def _add_unpresent(
        self, state: InstanceState[Any], key: _IdentityKeyType[Any]
    ) -> None:
        """Store *state* under *key*, assuming the key is not present.

        Unlike ``add()``, this performs no presence check and does not
        register the state in ``_modified``; it only stores the entry and
        points ``state._instance_dict`` at this map.
        """
        # inlined form of add() called by loading.py
        self._dict[key] = state
        state._instance_dict = self._wr

    def fast_get_state(
        self, key: _IdentityKeyType[_O]
    ) -> Optional[InstanceState[_O]]:
        """Return the state stored under *key*, or ``None``.

        The state is returned whether or not its object is still alive.
        """
        return self._dict.get(key)

    def get(
        self, key: _IdentityKeyType[_O], default: Optional[_O] = None
    ) -> Optional[_O]:
        """Return the live object mapped to *key*, else *default*."""
        if key not in self._dict:
            return default
        try:
            state = cast("InstanceState[_O]", self._dict[key])
        except KeyError:
            # catch gc removed the key after we just checked for it
            return default
        else:
            o = state.obj()
            if o is None:
                return default
            return o

    def items(self) -> List[Tuple[_IdentityKeyType[Any], object]]:
        """Return ``(key, object)`` pairs for every entry whose object is alive.

        The second element of each pair is the mapped object returned by
        ``state.obj()``, not the InstanceState.  The previous annotation
        claimed InstanceState values, which did not match the returned data.
        """
        result: List[Tuple[_IdentityKeyType[Any], object]] = []
        for state in self.all_states():
            value = state.obj()
            key = state.key
            assert key is not None
            if value is not None:
                result.append((key, value))
        return result

    def values(self) -> List[object]:
        """Return every mapped object that is still alive."""
        # all_states() takes a snapshot list before any obj() call is made.
        objects = (state.obj() for state in self.all_states())
        return [obj for obj in objects if obj is not None]

    def __iter__(self) -> Iterator[_IdentityKeyType[Any]]:
        """Iterate over all stored keys.

        Keys whose objects are no longer alive are included, because no
        ``obj()`` check is made here.
        """
        return iter(self.keys())

    def all_states(self) -> List[InstanceState[Any]]:
        """Return a snapshot list of every stored state."""
        return list(self._dict.values())

    def _fast_discard(self, state: InstanceState[Any]) -> None:
        """Remove *state* from ``_dict`` if it is the entry under its key.

        Unlike ``safe_discard()``, this does NOT call
        ``_manage_removed_state``, so ``state._instance_dict`` and
        ``_modified`` are left untouched.
        """
        # used by InstanceState for state being GC'ed
        key = state.key
        assert key is not None
        try:
            st = self._dict[key]
        except KeyError:
            # key is not present; nothing to discard
            pass
        else:
            if st is state:
                self._dict.pop(key, None)

    def discard(self, state: InstanceState[Any]) -> None:
        """Alias for :meth:`safe_discard`."""
        self.safe_discard(state)

    def safe_discard(self, state: InstanceState[Any]) -> None:
        """Remove *state* if it is the entry stored under its key.

        On removal the state is passed to ``_manage_removed_state``.  When
        the key is absent or maps to a different state, nothing happens.
        """
        key = state.key
        if key in self._dict:
            assert key is not None
            try:
                st = self._dict[key]
            except KeyError:
                # catch gc removed the key after we just checked for it
                pass
            else:
                if st is state:
                    self._dict.pop(key, None)
                    self._manage_removed_state(state)
292
293
def _killed(state: InstanceState[Any], key: _IdentityKeyType[Any]) -> NoReturn:
    """Stand-in for ``IdentityMap._add_unpresent`` once ``_kill()`` has run.

    :param state: the state that was being added; described in the message.
    :param key: unused; accepted so the ``_add_unpresent(state, key)`` call
        shape still works.
    :raises sa_exc.InvalidRequestError: always, with code ``"lkrp"``.
    """
    # Kept at module level (not a bound method or closure) so that storing
    # it on the IdentityMap instance does not create a reference cycle.
    described = orm_util.state_str(state)
    message = (
        "Object %s cannot be converted to 'persistent' state, as this "
        "identity map is no longer valid. Has the owning Session "
        "been closed?" % described
    )
    raise sa_exc.InvalidRequestError(message, code="lkrp")