1# util/_collections.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Collection classes and helpers."""
10from __future__ import annotations
11
12import operator
13import threading
14import types
15import typing
16from typing import Any
17from typing import Callable
18from typing import cast
19from typing import Container
20from typing import Dict
21from typing import FrozenSet
22from typing import Generic
23from typing import Iterable
24from typing import Iterator
25from typing import List
26from typing import Mapping
27from typing import NoReturn
28from typing import Optional
29from typing import overload
30from typing import Sequence
31from typing import Set
32from typing import Tuple
33from typing import TypeVar
34from typing import Union
35from typing import ValuesView
36import weakref
37
38from ._has_cy import HAS_CYEXTENSION
39from .typing import is_non_string_iterable
40from .typing import Literal
41from .typing import Protocol
42
43if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
44 from ._py_collections import immutabledict as immutabledict
45 from ._py_collections import IdentitySet as IdentitySet
46 from ._py_collections import ReadOnlyContainer as ReadOnlyContainer
47 from ._py_collections import ImmutableDictBase as ImmutableDictBase
48 from ._py_collections import OrderedSet as OrderedSet
49 from ._py_collections import unique_list as unique_list
50else:
51 from sqlalchemy.cyextension.immutabledict import (
52 ReadOnlyContainer as ReadOnlyContainer,
53 )
54 from sqlalchemy.cyextension.immutabledict import (
55 ImmutableDictBase as ImmutableDictBase,
56 )
57 from sqlalchemy.cyextension.immutabledict import (
58 immutabledict as immutabledict,
59 )
60 from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet
61 from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet
62 from sqlalchemy.cyextension.collections import ( # noqa
63 unique_list as unique_list,
64 )
65
66
# Generic placeholders used in this module's annotations; each one is
# bound to ``Any``.
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)
# Covariant placeholder; parameterizes the ``_CreateFuncType`` protocol.
_T_co = TypeVar("_T_co", covariant=True)

# Shared immutable constants: an empty frozenset, and a frozenset whose
# only member is ``None``.
EMPTY_SET: FrozenSet[Any] = frozenset()
NONE_SET: FrozenSet[Any] = frozenset([None])
74
75
def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
    """Combine two lists into one, keeping each list's ordering as much
    as possible.

    This is used to reconcile ``vars(cls)`` with ``cls.__annotations__``.

    Elements are consumed from one list until an element that is present
    in both lists is met for the first time; at that point consumption
    switches to the other list, which will emit that shared element when
    it gets to it.  Once the list being consumed runs out, whatever is
    left of the other one is appended.

    Example::

        >>> a = ["__tablename__", "id", "x", "created_at"]
        >>> b = ["id", "name", "data", "y", "created_at"]
        >>> merge_lists_w_ordering(a, b)
        ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']

    This is not necessarily the ordering that things had on the class,
    in this case the class is::

        class User(Base):
            __tablename__ = "users"

            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str]
            data: Mapped[Optional[str]]
            x = Column(Integer)
            y: Mapped[int]
            created_at: Mapped[datetime.datetime] = mapped_column()

    But things are *mostly* ordered.

    The same result could be reached by building a partial ordering for
    all items in both lists and running topological_sort() over it, but
    that is too much overhead.

    Background on how this approach was arrived at:
    https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae

    """
    # shared elements that have not yet triggered a switch
    pending = set(a).intersection(b)

    merged: List[Any] = []

    active, idle = iter(a), iter(b)

    switched = True
    while switched:
        switched = False
        for item in active:
            if item not in pending:
                merged.append(item)
                continue

            # first sighting of a shared element: leave it for the other
            # list to emit and carry on from there
            pending.discard(item)
            active, idle = idle, active
            switched = True
            break

    # the active list ran dry; drain whatever the other one still holds
    merged.extend(idle)
    return merged
130
131
def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
    """Return ``d`` as an ``immutabledict``.

    A falsy mapping yields the shared ``EMPTY_DICT``; an existing
    ``immutabledict`` is returned unchanged; anything else is passed to
    the ``immutabledict`` constructor.

    """
    if d:
        return d if isinstance(d, immutabledict) else immutabledict(d)
    return EMPTY_DICT
139
140
# Shared empty immutabledict; coerce_to_immutabledict() returns this
# object for any falsy mapping.
EMPTY_DICT: immutabledict[Any, Any] = immutabledict()
142
143
class FacadeDict(ImmutableDictBase[_KT, _VT]):
    """A dictionary that is not publicly mutable."""

    def __new__(cls, *args: Any) -> FacadeDict[Any, Any]:
        # positional arguments are accepted here but not consulted
        return ImmutableDictBase.__new__(cls)

    def copy(self) -> NoReturn:
        """Unsupported; always raises ``NotImplementedError``."""
        message = (
            "an immutabledict shouldn't need to be copied. use dict(d) "
            "if you need a mutable dictionary."
        )
        raise NotImplementedError(message)

    def __reduce__(self) -> Any:
        """Reduce to ``FacadeDict`` plus a plain ``dict`` of the contents."""
        contents = dict(self)
        return FacadeDict, (contents,)

    def _insert_item(self, key: _KT, value: _VT) -> None:
        """Private mutation hook: store ``value`` under ``key`` using
        ``dict.__setitem__`` directly."""
        dict.__setitem__(self, key, value)

    def __repr__(self) -> str:
        inner = dict.__repr__(self)
        return "FacadeDict(%s)" % inner
166
167
# Type of the ``default`` argument in the ``Properties.get()`` signatures.
_DT = TypeVar("_DT", bound=Any)

# Element type of the other operand in ``Properties.__add__()``.
_F = TypeVar("_F", bound=Any)
171
172
class Properties(Generic[_T]):
    """Provide a __getattr__/__setattr__ interface over a dict.

    Entries may be read and written either as attributes or with
    dictionary-style subscripting; both forms operate on the same
    dictionary, which is held by reference rather than copied.

    """

    __slots__ = ("_data",)

    _data: Dict[str, _T]

    def __init__(self, data: Dict[str, _T]):
        # go around our own __setattr__, which would otherwise store
        # "_data" as an entry of the dictionary itself
        object.__setattr__(self, "_data", data)

    def __len__(self) -> int:
        return len(self._data)

    def __iter__(self) -> Iterator[_T]:
        """Iterate over a snapshot of the stored *values* (not the keys)."""
        return iter(list(self._data.values()))

    def __dir__(self) -> List[str]:
        """Return the names from ``dir(super())`` plus the stored keys."""
        return dir(super()) + [str(k) for k in self._data.keys()]

    def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
        """Return a list of this collection's values followed by those
        of ``other``."""
        return list(self) + list(other)

    def __setitem__(self, key: str, obj: _T) -> None:
        self._data[key] = obj

    def __getitem__(self, key: str) -> _T:
        return self._data[key]

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __setattr__(self, key: str, obj: _T) -> None:
        # attribute assignment is routed into the dictionary
        self._data[key] = obj

    def __getstate__(self) -> Dict[str, Any]:
        return {"_data": self._data}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # same bypass as __init__
        object.__setattr__(self, "_data", state["_data"])

    def __getattr__(self, key: str) -> _T:
        """Look up ``key`` in the dictionary.

        :raises AttributeError: if ``key`` is not present.

        """
        # __getattr__ only runs after normal lookup has failed, so being
        # asked for "_data" means the slot was never populated (e.g. an
        # instance made with __new__()).  Reading self._data below would
        # re-enter this method endlessly, so fail cleanly instead.
        if key == "_data":
            raise AttributeError(key)
        try:
            return self._data[key]
        except KeyError:
            raise AttributeError(key)

    def __contains__(self, key: str) -> bool:
        return key in self._data

    def as_readonly(self) -> ReadOnlyProperties[_T]:
        """Return an immutable proxy for this :class:`.Properties`."""

        return ReadOnlyProperties(self._data)

    def update(self, value: Dict[str, _T]) -> None:
        """Merge the entries of ``value`` into the dictionary."""
        self._data.update(value)

    @overload
    def get(self, key: str) -> Optional[_T]: ...

    @overload
    def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ...

    def get(
        self, key: str, default: Optional[Union[_DT, _T]] = None
    ) -> Optional[Union[_T, _DT]]:
        """Return the value stored under ``key``, or ``default`` if the
        key is absent."""
        if key in self:
            return self[key]
        else:
            return default

    def keys(self) -> List[str]:
        """Return the keys as a new list."""
        return list(self._data)

    def values(self) -> List[_T]:
        """Return the values as a new list."""
        return list(self._data.values())

    def items(self) -> List[Tuple[str, _T]]:
        """Return ``(key, value)`` pairs as a new list."""
        return list(self._data.items())

    def has_key(self, key: str) -> bool:
        """Return True if ``key`` is present; same test as ``in``."""
        return key in self._data

    def clear(self) -> None:
        """Remove every entry from the dictionary."""
        self._data.clear()
258
259
class OrderedProperties(Properties[_T]):
    """A :class:`.Properties` that always starts out with a fresh, empty
    ``OrderedDict`` as its backing store."""

    __slots__ = ()

    def __init__(self):
        backing = OrderedDict()
        Properties.__init__(self, backing)
268
269
class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
    """Provide immutable dict/object attribute to an underlying dictionary.

    Combines ``ReadOnlyContainer`` with :class:`.Properties`; instances
    are what :meth:`.Properties.as_readonly` returns.

    """

    # no per-instance attributes beyond those of the base classes
    __slots__ = ()
274
275
def _ordered_dictionary_sort(d, key=None):
    """Reorder the entries of ``d`` in place so that its keys are sorted.

    :param d: the dictionary to reorder.
    :param key: optional key function handed to ``sorted()``.

    """
    ordered_keys = sorted(d, key=key)

    # capture the pairs before emptying the dictionary
    snapshot = [(name, d[name]) for name in ordered_keys]

    d.clear()
    d.update(snapshot)
284
285
# ``OrderedDict`` is an alias for the built-in ``dict``.
OrderedDict = dict
# Public name for ``_ordered_dictionary_sort()``.
sort_dictionary = _ordered_dictionary_sort
288
289
class WeakSequence(Sequence[_T]):
    """A sequence that refers to its elements through weak references.

    Each element is stored as a ``weakref.ref`` whose callback removes
    that reference from the internal list.

    """

    def __init__(self, __elements: Sequence[_T] = ()):
        # adapted from weakref.WeakKeyDictionary: the callback only knows
        # this collection through a weak reference, which prevents
        # reference cycles in the collection itself
        owner_ref = weakref.ref(self)

        def _remove(item, selfref=owner_ref):
            owner = selfref()
            if owner is None:
                return
            owner._storage.remove(item)

        self._remove = _remove
        self._storage = [
            weakref.ref(element, _remove) for element in __elements
        ]

    def append(self, item):
        """Add a weak reference to ``item`` at the end of the sequence."""
        ref = weakref.ref(item, self._remove)
        self._storage.append(ref)

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        # dereference lazily, skipping references that yield None
        candidates = (ref() for ref in self._storage)
        return (obj for obj in candidates if obj is not None)

    def __getitem__(self, index):
        """Return the referent of the reference stored at ``index``."""
        ref = self._storage[index]
        return ref()
317
318
class OrderedIdentitySet(IdentitySet):
    """An ``IdentitySet`` whose ``_members`` store is an ``OrderedDict``."""

    def __init__(self, iterable: Optional[Iterable[Any]] = None):
        IdentitySet.__init__(self)
        self._members = OrderedDict()
        if not iterable:
            # nothing to pre-populate with
            return
        for member in iterable:
            self.add(member)
326
327
class PopulateDict(Dict[_KT, _VT]):
    """A dict that fills in missing entries by calling a creation function.

    Unlike ``collections.defaultdict``, the creation function is passed
    the key that was missing.

    """

    def __init__(self, creator: Callable[[_KT], _VT]):
        self.creator = creator

    def __missing__(self, key: Any) -> Any:
        # build the value, remember it under the key, then hand it back
        created = self.creator(key)
        self[key] = created
        return created
342
343
class WeakPopulateDict(Dict[_KT, _VT]):
    """Like PopulateDict, but built from a bound method: the method's
    function is kept directly while its instance is only weakly
    referenced, so that no reference cycle is created.

    """

    def __init__(self, creator_method: types.MethodType):
        # split the bound method into its function and its instance
        self.creator = creator_method.__func__
        self.weakself = weakref.ref(creator_method.__self__)

    def __missing__(self, key: Any) -> Any:
        creator = self.creator
        created = creator(self.weakself(), key)
        self[key] = created
        return created
358
359
# Define collections that are capable of storing
# ColumnElement objects as hashable keys/elements.
# At this point, these are mostly historical, things
# used to be more complicated.
column_set = set  # alias for the built-in set
column_dict = dict  # alias for the built-in dict
ordered_column_set = OrderedSet  # alias for OrderedSet
367
368
class UniqueAppender(Generic[_T]):
    """Appends items to a collection ensuring uniqueness.

    Appending an object that was already appended through this wrapper
    is a no-op.  Membership is determined by identity (``is a``) not
    equality (``==``).
    """

    __slots__ = "data", "_data_appender", "_unique"

    data: Union[Iterable[_T], Set[_T], List[_T]]
    _data_appender: Callable[[_T], None]
    _unique: Dict[int, Literal[True]]

    def __init__(
        self,
        data: Union[Iterable[_T], Set[_T], List[_T]],
        via: Optional[str] = None,
    ):
        self.data = data
        # id() of every object appended so far
        self._unique = {}

        # pick the method used to add to ``data``: an explicitly named
        # one if given, otherwise ``append`` or ``add`` when available
        if via:
            appender = getattr(data, via)
        elif hasattr(data, "append"):
            appender = cast("List[_T]", data).append
        elif hasattr(data, "add"):
            appender = cast("Set[_T]", data).add
        else:
            # no appender could be determined; the attribute stays unset
            return
        self._data_appender = appender

    def append(self, item: _T) -> None:
        """Add ``item`` to the collection unless this same object has
        already been appended."""
        marker = id(item)
        if marker in self._unique:
            return
        self._data_appender(item)
        self._unique[marker] = True

    def __iter__(self) -> Iterator[_T]:
        return iter(self.data)
404
405
def coerce_generator_arg(arg: Any) -> List[Any]:
    """Expand a lone generator argument into a list.

    If ``arg`` holds exactly one element and that element is a
    generator, a list of the generator's items is returned; otherwise
    ``arg`` itself is returned unchanged.

    """
    lone_generator = len(arg) == 1 and isinstance(
        arg[0], types.GeneratorType
    )
    if not lone_generator:
        return cast("List[Any]", arg)
    return list(arg[0])
411
412
def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
    """Coerce ``x`` to a list.

    ``None`` yields ``default``.  A value for which
    ``is_non_string_iterable()`` is false is wrapped in a one-element
    list.  A list is returned as-is and any other iterable is copied
    into a new list.

    """
    if x is None:
        return default  # type: ignore
    if is_non_string_iterable(x):
        return x if isinstance(x, list) else list(x)
    return [x]
422
423
def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
    r"""Report whether any item of ``iterable`` is contained in set\_.

    Items whose ``__hash__`` attribute is falsy are skipped, so that
    ``__hash__`` is not called on items that don't support it.

    """
    for candidate in iterable:
        if candidate.__hash__ and candidate in set_:
            return True
    return False
432
433
def to_set(x):
    """Coerce ``x`` to a set.

    ``None`` gives a new empty set, an existing set is returned as-is,
    and anything else is run through ``to_list()`` and then turned into
    a set.

    """
    if x is None:
        return set()
    if isinstance(x, set):
        return x
    return set(to_list(x))
441
442
def to_column_set(x: Any) -> Set[Any]:
    """Coerce ``x`` to a ``column_set``.

    ``None`` gives a new empty ``column_set``, an existing ``column_set``
    is returned as-is, and anything else is run through ``to_list()``
    and then turned into a ``column_set``.

    """
    if x is None:
        return column_set()
    if isinstance(x, column_set):
        return x
    return column_set(to_list(x))
450
451
def update_copy(
    d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any
) -> Dict[Any, Any]:
    """Return a copy of ``d`` with additional entries applied.

    The entries of ``_new``, when it is given and non-empty, are applied
    first and the keyword arguments after them.  ``d`` itself is left
    untouched.

    """

    merged = d.copy()
    if _new:
        merged.update(_new)
    merged.update(**kw)
    return merged
462
463
def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
    """Yield the elements of ``x``, recursively expanding any element
    that is itself iterable.

    Strings are yielded whole rather than expanded; every other element
    that has an ``__iter__`` attribute is flattened in place.

    """
    entry: _T
    for entry in x:
        if isinstance(entry, str) or not hasattr(entry, "__iter__"):
            yield entry
        else:
            yield from flatten_iterator(entry)
475
476
class LRUCache(typing.MutableMapping[_KT, _VT]):
    """Dictionary with 'squishy' removal of least
    recently used items.

    Note that either get() or [] should be used here, but
    generally its not safe to do an "in" check first as the dictionary
    can change subsequent to that call.

    Each entry is stored as a ``(key, value, [counter])`` tuple.  The
    one-element counter list is overwritten on every read, which lets
    entries be ranked by how recently they were used.  Pruning only
    happens once the size exceeds ``capacity + capacity * threshold``,
    at which point the entries beyond the first ``capacity`` (ranked by
    counter, highest first) are deleted.

    """

    __slots__ = (
        "capacity",
        "threshold",
        "size_alert",
        "_data",
        "_counter",
        "_mutex",
    )

    capacity: int
    threshold: float
    size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]]

    def __init__(
        self,
        capacity: int = 100,
        threshold: float = 0.5,
        size_alert: Optional[Callable[..., None]] = None,
    ):
        """Construct a new :class:`.LRUCache`.

        :param capacity: number of entries kept when the cache is pruned.
        :param threshold: fraction of ``capacity`` by which the cache may
         grow past ``capacity`` before pruning takes place.
        :param size_alert: optional callable that is passed this cache
         when pruning takes place.

        """
        self.capacity = capacity
        self.threshold = threshold
        self.size_alert = size_alert
        # usage clock; only ever incremented, by _inc_counter()
        self._counter = 0
        # acquired only by _manage_size()
        self._mutex = threading.Lock()
        self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}

    def _inc_counter(self):
        """Advance the usage clock and return its new value."""
        self._counter += 1
        return self._counter

    @overload
    def get(self, key: _KT) -> Optional[_VT]: ...

    @overload
    def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...

    def get(
        self, key: _KT, default: Optional[Union[_VT, _T]] = None
    ) -> Optional[Union[_VT, _T]]:
        """Return the value for ``key`` and mark it as just used, or
        return ``default`` if the key is absent."""
        item = self._data.get(key)
        if item is not None:
            # refresh the entry's usage counter in place
            item[2][0] = self._inc_counter()
            return item[1]
        else:
            return default

    def __getitem__(self, key: _KT) -> _VT:
        """Return the value for ``key`` and mark it as just used.

        :raises KeyError: if ``key`` is absent.

        """
        item = self._data[key]
        item[2][0] = self._inc_counter()
        return item[1]

    def __iter__(self) -> Iterator[_KT]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def values(self) -> ValuesView[_VT]:
        """Return a view over a newly built key -> value dictionary;
        usage counters are not touched."""
        return typing.ValuesView({k: i[1] for k, i in self._data.items()})

    def __setitem__(self, key: _KT, value: _VT) -> None:
        """Store ``value`` under ``key`` with a fresh usage counter, then
        prune the cache if it has grown past its threshold."""
        self._data[key] = (key, value, [self._inc_counter()])
        self._manage_size()

    def __delitem__(self, __v: _KT) -> None:
        del self._data[__v]

    @property
    def size_threshold(self) -> float:
        """Size above which pruning takes place:
        ``capacity + capacity * threshold``."""
        return self.capacity + self.capacity * self.threshold

    def _manage_size(self) -> None:
        """Prune the cache back to ``capacity`` entries if it has grown
        past the threshold."""
        # non-blocking acquire: if the lock is already held, skip pruning
        # rather than wait for it
        if not self._mutex.acquire(False):
            return
        try:
            # the alert is raised at most once per call
            size_alert = bool(self.size_alert)
            while len(self) > self.capacity + self.capacity * self.threshold:
                if size_alert:
                    size_alert = False
                    self.size_alert(self)  # type: ignore
                # rank entries by their [counter] list, highest first
                by_counter = sorted(
                    self._data.values(),
                    key=operator.itemgetter(2),
                    reverse=True,
                )
                # delete everything ranked after the first ``capacity``
                for item in by_counter[self.capacity :]:
                    try:
                        del self._data[item[0]]
                    except KeyError:
                        # deleted elsewhere; skip
                        continue
        finally:
            self._mutex.release()
580
581
class _CreateFuncType(Protocol[_T_co]):
    """Protocol for a callable that takes no arguments and returns
    ``_T_co``."""

    def __call__(self) -> _T_co: ...
584
585
class _ScopeFuncType(Protocol):
    """Protocol for a callable that takes no arguments and returns any
    value."""

    def __call__(self) -> Any: ...
588
589
class ScopedRegistry(Generic[_T]):
    """A Registry that can store one or multiple instances of a single
    class on the basis of a "scope" function.

    The object implements ``__call__`` as the "getter", so calling
    ``myregistry()`` returns the contained object for the current scope.

    :param createfunc:
      a callable that returns a new object to be placed in the registry

    :param scopefunc:
      a callable that will return a key to store/retrieve an object.
    """

    __slots__ = "createfunc", "scopefunc", "registry"

    createfunc: _CreateFuncType[_T]
    scopefunc: _ScopeFuncType
    registry: Any

    def __init__(
        self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
    ):
        """Construct a new :class:`.ScopedRegistry`.

        :param createfunc:  A creation function that will generate
          a new value for the current scope, if none is present.

        :param scopefunc:  A function that returns a hashable
          token representing the current scope (such as, current
          thread identifier).

        """
        self.createfunc = createfunc
        self.scopefunc = scopefunc
        # maps scope token -> object
        self.registry = {}

    def __call__(self) -> _T:
        """Return the object for the current scope, creating and storing
        one first if the scope has none."""
        scope = self.scopefunc()
        try:
            return self.registry[scope]  # type: ignore[no-any-return]
        except KeyError:
            created = self.createfunc()
            return self.registry.setdefault(  # type: ignore[no-any-return]
                scope, created
            )

    def has(self) -> bool:
        """Return True if an object is present in the current scope."""

        scope = self.scopefunc()
        return scope in self.registry

    def set(self, obj: _T) -> None:
        """Set the value for the current scope."""

        scope = self.scopefunc()
        self.registry[scope] = obj

    def clear(self) -> None:
        """Clear the current scope, if any."""

        try:
            scope = self.scopefunc()
            del self.registry[scope]
        except KeyError:
            # nothing stored for this scope
            pass
652
653
class ThreadLocalRegistry(ScopedRegistry[_T]):
    """A :class:`.ScopedRegistry` that uses a ``threading.local()``
    variable for storage.

    The object is kept in the ``value`` attribute of that variable; no
    scope function is used.

    """

    def __init__(self, createfunc: Callable[[], _T]):
        self.createfunc = createfunc
        self.registry = threading.local()

    def __call__(self) -> _T:
        """Return the stored object, creating and storing one first if
        none is present."""
        store = self.registry
        try:
            return store.value  # type: ignore[no-any-return]
        except AttributeError:
            created = self.createfunc()
            store.value = created
            return created

    def has(self) -> bool:
        """Return True if an object is currently stored."""
        return hasattr(self.registry, "value")

    def set(self, obj: _T) -> None:
        """Store ``obj`` as the current object."""
        self.registry.value = obj

    def clear(self) -> None:
        """Discard the stored object, if any."""
        try:
            del self.registry.value
        except AttributeError:
            # nothing was stored
            pass
682
683
def has_dupes(sequence, target):
    """Report whether ``target`` occurs more than once in ``sequence``.

    Occurrences are matched by identity (``is``).  Returns True as soon
    as a second occurrence is seen, False if there are zero or one.

    """
    # compared with an approach based on .index(), this plain loop has
    # less function-call overhead and is usually the same speed.  At
    # 15000 items (way bigger than a relationship-bound collection in
    # memory usually is) it begins to fall behind only by microseconds.
    seen_once = False
    for item in sequence:
        if item is not target:
            continue
        if seen_once:
            return True
        seen_once = True
    return False
701
702
703# .index version. the two __contains__ calls as well
704# as .index() and isinstance() slow this down.
705# def has_dupes(sequence, target):
706# if target not in sequence:
707# return False
708# elif not isinstance(sequence, collections_abc.Sequence):
709# return False
710#
711# idx = sequence.index(target)
712# return target in sequence[idx + 1:]