1# util/_collections.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Collection classes and helpers."""
10from __future__ import annotations
11
12import operator
13import threading
14import types
15import typing
16from typing import Any
17from typing import Callable
18from typing import cast
19from typing import Container
20from typing import Dict
21from typing import FrozenSet
22from typing import Generic
23from typing import Iterable
24from typing import Iterator
25from typing import List
26from typing import Literal
27from typing import Mapping
28from typing import NoReturn
29from typing import Optional
30from typing import overload
31from typing import Protocol
32from typing import Sequence
33from typing import Set
34from typing import Tuple
35from typing import TypeVar
36from typing import Union
37from typing import ValuesView
38import weakref
39
40from ._collections_cy import IdentitySet as IdentitySet
41from ._collections_cy import OrderedSet as OrderedSet
42from ._collections_cy import unique_list as unique_list # noqa: F401
43from ._immutabledict_cy import immutabledict as immutabledict
44from ._immutabledict_cy import ImmutableDictBase as ImmutableDictBase
45from ._immutabledict_cy import ReadOnlyContainer as ReadOnlyContainer
46from .typing import is_non_string_iterable
47
48
# Generic type variables shared by the collection helpers in this module.
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)
# covariant variant, used by the callable Protocol declared in this module
_T_co = TypeVar("_T_co", covariant=True)

# Shared immutable sets: an empty frozenset, and a frozenset whose only
# member is ``None``.
EMPTY_SET: FrozenSet[Any] = frozenset()
NONE_SET: FrozenSet[Any] = frozenset([None])
56
57
def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
    """Combine two lists into one, keeping each list's order where possible.

    This exists to reconcile ``vars(cls)`` with ``cls.__annotations__``.

    Example::

        >>> a = ["__tablename__", "id", "x", "created_at"]
        >>> b = ["id", "name", "data", "y", "created_at"]
        >>> merge_lists_w_ordering(a, b)
        ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']

    The result is not necessarily the order in which the names appeared
    on the class; for the example above the class is::

        class User(Base):
            __tablename__ = "users"

            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str]
            data: Mapped[Optional[str]]
            x = Column(Integer)
            y: Mapped[int]
            created_at: Mapped[datetime.datetime] = mapped_column()

    The outcome is nevertheless *mostly* ordered.

    A partial ordering over all items of both lists fed into
    topological_sort() would also work, but carries too much overhead.

    Background on how the approach was arrived at:
    https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae

    :param a: first list; consumed first.
    :param b: second list.
    :return: a new list containing the merged elements.

    """
    # elements present in both lists; each one acts as a "switch point"
    # the first time it is encountered
    shared = set(a).intersection(b)

    merged: List[Any] = []
    exhausted = object()
    active, waiting = iter(a), iter(b)

    while True:
        item = next(active, exhausted)

        if item is exhausted:
            # the active list ran out: whatever remains of the other
            # list is appended as-is
            merged.extend(waiting)
            return merged

        if item in shared:
            # first sighting of a shared element: don't emit it yet,
            # switch to the other list, which will emit it when it is
            # reached there
            shared.discard(item)
            active, waiting = waiting, active
        else:
            merged.append(item)
112
113
def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
    """Return ``d`` as an :class:`.immutabledict`.

    :param d: any mapping, possibly empty or ``None``-like (falsy).
    :return: the shared ``EMPTY_DICT`` for a falsy argument, ``d`` itself
     when it already is an ``immutabledict``, otherwise a new
     ``immutabledict`` built from ``d``.

    """
    if not d:
        return EMPTY_DICT

    return d if isinstance(d, immutabledict) else immutabledict(d)
121
122
# Shared empty immutabledict; coerce_to_immutabledict() returns this object
# for any falsy argument.
EMPTY_DICT: immutabledict[Any, Any] = immutabledict()
124
125
class FacadeDict(ImmutableDictBase[_KT, _VT]):
    """A dictionary that is not publicly mutable.

    Entries are added through the private :meth:`_insert_item` hook,
    which writes to the underlying ``dict`` storage directly.
    """

    def __new__(cls, *args: Any) -> FacadeDict[Any, Any]:
        # positional arguments are accepted but not consumed by __new__
        instance: FacadeDict[Any, Any] = ImmutableDictBase.__new__(cls)
        return instance

    def copy(self) -> NoReturn:
        """Always raise ``NotImplementedError``; copying is unsupported."""
        message = (
            "an immutabledict shouldn't need to be copied. use dict(d) "
            "if you need a mutable dictionary."
        )
        raise NotImplementedError(message)

    def __reduce__(self) -> Any:
        # pickle as a FacadeDict rebuilt from a plain-dict snapshot
        contents = dict(self)
        return FacadeDict, (contents,)

    def _insert_item(self, key: _KT, value: _VT) -> None:
        """Store ``value`` under ``key`` via ``dict.__setitem__`` directly."""
        dict.__setitem__(self, key, value)

    def __repr__(self) -> str:
        inner = dict.__repr__(self)
        return "FacadeDict(%s)" % inner
148
149
# type variable for the ``default`` argument of Properties.get()
_DT = TypeVar("_DT", bound=Any)

# type variable for the element type of the other operand of
# Properties.__add__()
_F = TypeVar("_F", bound=Any)
153
154
155class Properties(Generic[_T]):
156 """Provide a __getattr__/__setattr__ interface over a dict."""
157
158 __slots__ = ("_data",)
159
160 _data: Dict[str, _T]
161
162 def __init__(self, data: Dict[str, _T]):
163 object.__setattr__(self, "_data", data)
164
165 def __len__(self) -> int:
166 return len(self._data)
167
168 def __iter__(self) -> Iterator[_T]:
169 return iter(list(self._data.values()))
170
171 def __dir__(self) -> List[str]:
172 return dir(super()) + [str(k) for k in self._data.keys()]
173
174 def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
175 return list(self) + list(other)
176
177 def __setitem__(self, key: str, obj: _T) -> None:
178 self._data[key] = obj
179
180 def __getitem__(self, key: str) -> _T:
181 return self._data[key]
182
183 def __delitem__(self, key: str) -> None:
184 del self._data[key]
185
186 def __setattr__(self, key: str, obj: _T) -> None:
187 self._data[key] = obj
188
189 def __getstate__(self) -> Dict[str, Any]:
190 return {"_data": self._data}
191
192 def __setstate__(self, state: Dict[str, Any]) -> None:
193 object.__setattr__(self, "_data", state["_data"])
194
195 def __getattr__(self, key: str) -> _T:
196 try:
197 return self._data[key]
198 except KeyError:
199 raise AttributeError(key)
200
201 def __contains__(self, key: str) -> bool:
202 return key in self._data
203
204 def as_readonly(self) -> ReadOnlyProperties[_T]:
205 """Return an immutable proxy for this :class:`.Properties`."""
206
207 return ReadOnlyProperties(self._data)
208
209 def update(self, value: Dict[str, _T]) -> None:
210 self._data.update(value)
211
212 @overload
213 def get(self, key: str) -> Optional[_T]: ...
214
215 @overload
216 def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ...
217
218 def get(
219 self, key: str, default: Optional[Union[_DT, _T]] = None
220 ) -> Optional[Union[_T, _DT]]:
221 if key in self:
222 return self[key]
223 else:
224 return default
225
226 def keys(self) -> List[str]:
227 return list(self._data)
228
229 def values(self) -> List[_T]:
230 return list(self._data.values())
231
232 def items(self) -> List[Tuple[str, _T]]:
233 return list(self._data.items())
234
235 def has_key(self, key: str) -> bool:
236 return key in self._data
237
238 def clear(self) -> None:
239 self._data.clear()
240
241
class OrderedProperties(Properties[_T]):
    """A :class:`.Properties` that starts out empty, backed by a new
    ``OrderedDict`` instance."""

    __slots__ = ()

    def __init__(self):
        backing_store = OrderedDict()
        Properties.__init__(self, backing_store)
250
251
class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
    """Provide immutable dict/object attribute access to an underlying
    dictionary.

    Combines ``ReadOnlyContainer`` with :class:`.Properties`; no methods
    are defined here.  NOTE(review): the read-only behaviour presumably
    comes from ``ReadOnlyContainer``, which is defined outside this
    module - confirm there.
    """

    __slots__ = ()
256
257
def _ordered_dictionary_sort(d, key=None):
    """Reorder the entries of the dictionary ``d`` in place.

    :param d: the dictionary to reorder; it is emptied and refilled.
    :param key: optional key function applied to the dictionary keys,
     passed through to ``sorted()``.
    """

    ordered_keys = sorted(d, key=key)
    pairs = [(name, d[name]) for name in ordered_keys]

    # refill the same dict object so that existing references observe
    # the new ordering
    d.clear()
    d.update(pairs)


OrderedDict = dict
sort_dictionary = _ordered_dictionary_sort
270
271
class WeakSequence(Sequence[_T]):
    """A sequence that holds its elements through weak references.

    A stored reference is removed from the sequence when its referent is
    garbage collected.
    """

    def __init__(self, __elements: Sequence[_T] = (), /):
        # adapted from weakref.WeakKeyDictionary: the callback reaches the
        # collection through a weak reference so that the collection does
        # not take part in a reference cycle
        def _remove(item, selfref=weakref.ref(self)):
            owner = selfref()
            if owner is not None:
                owner._storage.remove(item)

        self._remove = _remove
        self._storage = [
            weakref.ref(member, _remove) for member in __elements
        ]

    def append(self, item):
        """Add a weak reference to ``item`` at the end of the sequence."""
        reference = weakref.ref(item, self._remove)
        self._storage.append(reference)

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        # dereference lazily, skipping referents that are already gone
        referents = (ref() for ref in self._storage)
        return (obj for obj in referents if obj is not None)

    def __getitem__(self, index):
        reference = self._storage[index]
        return reference()
299
300
# alias: OrderedIdentitySet refers to the same class as IdentitySet
OrderedIdentitySet = IdentitySet
302
303
class PopulateDict(Dict[_KT, _VT]):
    """A dict that fills in missing entries by calling a creation function.

    Unlike ``collections.defaultdict``, the creation function receives
    the missing key as its argument.

    """

    def __init__(self, creator: Callable[[_KT], _VT]):
        self.creator = creator

    def __missing__(self, key: Any) -> Any:
        # build the value on first access, store it, and hand it back
        created = self.creator(key)
        self[key] = created
        return created
318
319
class WeakPopulateDict(Dict[_KT, _VT]):
    """Variant of PopulateDict whose creator is a bound method.

    The method's function and a weak reference to the instance it is
    bound to are stored separately, so the dict does not create a
    reference cycle with that instance.

    """

    def __init__(self, creator_method: types.MethodType):
        self.creator = creator_method.__func__
        bound_instance = creator_method.__self__
        self.weakself = weakref.ref(bound_instance)

    def __missing__(self, key: Any) -> Any:
        # the weak reference is dereferenced at call time and its result
        # is passed to the function as the "self" argument
        created = self.creator(self.weakself(), key)
        self[key] = created
        return created
334
335
# Collections capable of storing ColumnElement objects as hashable
# keys/elements.  These are now mostly historical aliases; things used
# to be more complicated.
column_set = set
column_dict = dict
ordered_column_set = OrderedSet
343
344
class UniqueAppender(Generic[_T]):
    """Append items to a wrapped collection, skipping repeats.

    An object that was already appended through this wrapper is ignored
    on later ``append()`` calls.  Objects are tracked by ``id()``, i.e.
    by identity, not by equality (``==``).
    """

    __slots__ = ("data", "_data_appender", "_unique")

    data: Union[Iterable[_T], Set[_T], List[_T]]
    _data_appender: Callable[[_T], None]
    _unique: Dict[int, Literal[True]]

    def __init__(
        self,
        data: Union[Iterable[_T], Set[_T], List[_T]],
        via: Optional[str] = None,
    ):
        """Wrap ``data``.

        :param data: the target collection.
        :param via: optional name of the method on ``data`` used to add
         an item.  When not given, ``data.append`` is used if it exists,
         otherwise ``data.add`` if it exists; if neither exists no
         appender is set.
        """
        self.data = data
        self._unique = {}

        if via:
            appender = getattr(data, via)
        elif hasattr(data, "append"):
            appender = cast("List[_T]", data).append
        elif hasattr(data, "add"):
            appender = cast("Set[_T]", data).add
        else:
            return

        self._data_appender = appender

    def append(self, item: _T) -> None:
        """Add ``item`` to the collection unless it was added before."""
        marker = id(item)
        if marker in self._unique:
            return

        # record the id only after the append itself succeeded
        self._data_appender(item)
        self._unique[marker] = True

    def __iter__(self) -> Iterator[_T]:
        return iter(self.data)
380
381
def coerce_generator_arg(arg: Any) -> List[Any]:
    """Expand a lone generator argument into a list.

    :param arg: a sized, indexable collection of arguments.
    :return: ``list(arg[0])`` when ``arg`` holds exactly one element and
     that element is a generator; otherwise ``arg`` itself, unchanged.
    """
    wraps_single_generator = len(arg) == 1 and isinstance(
        arg[0], types.GeneratorType
    )
    if wraps_single_generator:
        return list(arg[0])

    return cast("List[Any]", arg)
387
388
def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
    """Coerce ``x`` to a list.

    :param x: the value to coerce.
    :param default: returned as-is when ``x`` is ``None``.
    :return: ``default`` for ``None``; ``[x]`` when
     ``is_non_string_iterable(x)`` is false; ``x`` itself when it is
     already a list; otherwise ``list(x)``.
    """
    if x is None:
        return default  # type: ignore

    if is_non_string_iterable(x):
        return x if isinstance(x, list) else list(x)

    return [x]
398
399
def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
    """Return True if any item of ``iterable`` is contained in ``set_``.

    Items whose ``__hash__`` attribute is falsy (i.e. unhashable items)
    are skipped before the membership test, so that hashing is never
    attempted on them.

    """
    for candidate in iterable:
        if candidate.__hash__ and candidate in set_:
            return True

    return False
408
409
def to_set(x):
    """Coerce ``x`` to a ``set``.

    Returns a new empty set for ``None``, ``x`` itself when it already
    is a ``set``, and otherwise ``set(to_list(x))``.
    """
    if x is None:
        return set()

    return x if isinstance(x, set) else set(to_list(x))
417
418
def to_column_set(x: Any) -> Set[Any]:
    """Coerce ``x`` to a ``column_set``.

    Returns a new empty ``column_set`` for ``None``, ``x`` itself when it
    already is a ``column_set``, and otherwise
    ``column_set(to_list(x))``.
    """
    if x is None:
        return column_set()

    return x if isinstance(x, column_set) else column_set(to_list(x))
426
427
def update_copy(
    d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any
) -> Dict[Any, Any]:
    """Return a copy of ``d`` updated with the given values.

    :param d: the source dictionary; it is left unmodified.
    :param _new: optional dictionary applied to the copy first.
    :param kw: keyword items applied to the copy after ``_new``.
    :return: the updated copy, as produced by ``d.copy()``.
    """

    merged = d.copy()

    if _new:
        merged.update(_new)

    merged.update(**kw)
    return merged
438
439
def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
    """Yield the elements of ``x``, recursively expanding any element
    that is itself iterable.

    An element is expanded when it has an ``__iter__`` attribute and is
    not a ``str``; every other element, strings included, is yielded
    unchanged.

    """
    elem: _T
    for elem in x:
        if isinstance(elem, str) or not hasattr(elem, "__iter__"):
            yield elem
        else:
            yield from flatten_iterator(elem)
451
452
class LRUCache(typing.MutableMapping[_KT, _VT]):
    """Mapping with 'squishy' eviction of its least recently used items.

    Entries are only pruned once the size exceeds
    ``capacity + capacity * threshold``; pruning then brings the size
    back down to ``capacity``, discarding the entries with the oldest
    usage counters.

    Use either get() or [] to read; an "in" check beforehand is generally
    not safe, because the dictionary can change after that call.

    """

    __slots__ = (
        "capacity",
        "threshold",
        "size_alert",
        "_data",
        "_counter",
        "_mutex",
    )

    capacity: int
    threshold: float
    size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]]

    def __init__(
        self,
        capacity: int = 100,
        threshold: float = 0.5,
        size_alert: Optional[Callable[..., None]] = None,
    ):
        """Construct a new cache.

        :param capacity: number of entries kept after a pruning pass.
        :param threshold: fraction of ``capacity`` by which the size may
         exceed ``capacity`` before pruning starts.
        :param size_alert: optional callable; invoked with this cache, at
         most once per pruning run, when the size limit is exceeded.
        """
        self.capacity = capacity
        self.threshold = threshold
        self.size_alert = size_alert
        self._counter = 0
        self._mutex = threading.Lock()
        # each record is (key, value, [usage counter]); the counter sits
        # in a one-element list so it can be updated in place
        self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}

    def _inc_counter(self):
        # advance and return the usage counter
        self._counter += 1
        return self._counter

    @overload
    def get(self, key: _KT) -> Optional[_VT]: ...

    @overload
    def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...

    def get(
        self, key: _KT, default: Optional[Union[_VT, _T]] = None
    ) -> Optional[Union[_VT, _T]]:
        """Return the value for ``key``, marking it as just used; return
        ``default`` when the key is absent."""
        record = self._data.get(key)
        if record is None:
            return default

        record[2][0] = self._inc_counter()
        return record[1]

    def __getitem__(self, key: _KT) -> _VT:
        record = self._data[key]
        # a successful read refreshes the record's usage counter
        record[2][0] = self._inc_counter()
        return record[1]

    def __iter__(self) -> Iterator[_KT]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def values(self) -> ValuesView[_VT]:
        # view over a fresh key -> value dict, without the bookkeeping
        plain = {key: record[1] for key, record in self._data.items()}
        return typing.ValuesView(plain)

    def __setitem__(self, key: _KT, value: _VT) -> None:
        stamp = [self._inc_counter()]
        self._data[key] = (key, value, stamp)
        self._manage_size()

    def __delitem__(self, __v: _KT) -> None:
        del self._data[__v]

    @property
    def size_threshold(self) -> float:
        """The size above which entries get pruned."""
        return self.capacity + self.capacity * self.threshold

    def _manage_size(self) -> None:
        # non-blocking acquire: if a pruning run is already in progress
        # elsewhere, do nothing
        if not self._mutex.acquire(False):
            return

        try:
            alert_pending = bool(self.size_alert)

            while len(self) > self.capacity + self.capacity * self.threshold:
                if alert_pending:
                    alert_pending = False
                    self.size_alert(self)  # type: ignore

                # most recently used first; everything past ``capacity``
                # gets dropped
                newest_first = sorted(
                    self._data.values(),
                    key=operator.itemgetter(2),
                    reverse=True,
                )
                for stale in newest_first[self.capacity :]:
                    # the key may have been deleted elsewhere already;
                    # that is not an error
                    self._data.pop(stale[0], None)
        finally:
            self._mutex.release()
556
557
class _CreateFuncType(Protocol[_T_co]):
    """Typing protocol: a callable taking no arguments and returning
    ``_T_co``."""

    def __call__(self) -> _T_co: ...
560
561
class _ScopeFuncType(Protocol):
    """Typing protocol: a callable taking no arguments and returning any
    value."""

    def __call__(self) -> Any: ...
564
565
class ScopedRegistry(Generic[_T]):
    """A registry holding one object per "scope", where the current scope
    is identified by the return value of a "scope" function.

    Calling the registry itself (``myregistry()``) is the "getter": it
    returns the object for the current scope, creating it first when the
    scope has none yet.

    :param createfunc:
      a callable that returns a new object to be placed in the registry

    :param scopefunc:
      a callable that will return a key to store/retrieve an object.
    """

    __slots__ = ("createfunc", "scopefunc", "registry")

    createfunc: _CreateFuncType[_T]
    scopefunc: _ScopeFuncType
    registry: Any

    def __init__(
        self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
    ):
        """Construct a new :class:`.ScopedRegistry`.

        :param createfunc:  A creation function that will generate
          a new value for the current scope, if none is present.

        :param scopefunc:  A function that returns a hashable
          token representing the current scope (such as, current
          thread identifier).

        """
        self.createfunc = createfunc
        self.scopefunc = scopefunc
        # maps scope token -> object
        self.registry = {}

    def __call__(self) -> _T:
        scope_key = self.scopefunc()
        store = self.registry
        try:
            existing = store[scope_key]
        except KeyError:
            # nothing registered for this scope yet; setdefault() keeps
            # an entry that appeared in the meantime
            return store.setdefault(scope_key, self.createfunc())  # type: ignore[no-any-return] # noqa: E501
        else:
            return existing  # type: ignore[no-any-return]

    def has(self) -> bool:
        """Return True if an object is present in the current scope."""

        scope_key = self.scopefunc()
        return scope_key in self.registry

    def set(self, obj: _T) -> None:
        """Set the value for the current scope."""

        scope_key = self.scopefunc()
        self.registry[scope_key] = obj

    def clear(self) -> None:
        """Clear the current scope, if any."""

        try:
            del self.registry[self.scopefunc()]
        except KeyError:
            # nothing was registered for this scope
            pass
628
629
class ThreadLocalRegistry(ScopedRegistry[_T]):
    """A :class:`.ScopedRegistry` whose storage is a ``threading.local()``
    object; the current object is kept in its ``value`` attribute.

    """

    def __init__(self, createfunc: Callable[[], _T]):
        """Construct a new :class:`.ThreadLocalRegistry`.

        :param createfunc: called with no arguments to produce the
         object when none is stored yet.
        """
        self.createfunc = createfunc
        self.registry = threading.local()

    def __call__(self) -> _T:
        try:
            return self.registry.value  # type: ignore[no-any-return]
        except AttributeError:
            # nothing stored yet: create the object and store it
            created = self.createfunc()
            self.registry.value = created
            return created

    def has(self) -> bool:
        """Return True if a ``value`` is currently stored."""
        return hasattr(self.registry, "value")

    def set(self, obj: _T) -> None:
        """Store ``obj`` as the current value."""
        self.registry.value = obj

    def clear(self) -> None:
        """Remove the stored value, if any."""
        try:
            del self.registry.value
        except AttributeError:
            # nothing was stored
            pass
658
659
def has_dupes(sequence, target):
    """Return True if ``target`` occurs in ``sequence`` more than once,
    False if it occurs zero times or once.

    Occurrences are matched by identity (``is``), not by equality.

    """
    # A plain loop is used on purpose.  Per the original author's notes,
    # it has less function overhead than the .index()-based variant kept
    # in comments below and is usually the same speed; at 15000 items (way
    # bigger than a relationship-bound collection in memory usually is)
    # it begins to fall behind that variant only by microseconds.
    seen_once = False
    for item in sequence:
        if item is not target:
            continue
        if seen_once:
            return True
        seen_once = True
    return False
677
678
# .index()-based version, kept for reference.  The two __contains__ calls,
# as well as .index() and isinstance(), slow this down.
681# def has_dupes(sequence, target):
682# if target not in sequence:
683# return False
684# elif not isinstance(sequence, collections_abc.Sequence):
685# return False
686#
687# idx = sequence.index(target)
688# return target in sequence[idx + 1:]