1# util/_collections.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Collection classes and helpers."""
10from __future__ import annotations
11
12import operator
13import threading
14import types
15import typing
16from typing import Any
17from typing import Callable
18from typing import cast
19from typing import Container
20from typing import Dict
21from typing import FrozenSet
22from typing import Generic
23from typing import Iterable
24from typing import Iterator
25from typing import List
26from typing import Mapping
27from typing import NoReturn
28from typing import Optional
29from typing import overload
30from typing import Sequence
31from typing import Set
32from typing import Tuple
33from typing import TypeVar
34from typing import Union
35from typing import ValuesView
36import weakref
37
38from ._has_cy import HAS_CYEXTENSION
39from .typing import is_non_string_iterable
40from .typing import Literal
41from .typing import Protocol
42
43if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
44 from ._py_collections import immutabledict as immutabledict
45 from ._py_collections import IdentitySet as IdentitySet
46 from ._py_collections import ReadOnlyContainer as ReadOnlyContainer
47 from ._py_collections import ImmutableDictBase as ImmutableDictBase
48 from ._py_collections import OrderedSet as OrderedSet
49 from ._py_collections import unique_list as unique_list
50else:
51 from sqlalchemy.cyextension.immutabledict import (
52 ReadOnlyContainer as ReadOnlyContainer,
53 )
54 from sqlalchemy.cyextension.immutabledict import (
55 ImmutableDictBase as ImmutableDictBase,
56 )
57 from sqlalchemy.cyextension.immutabledict import (
58 immutabledict as immutabledict,
59 )
60 from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet
61 from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet
62 from sqlalchemy.cyextension.collections import ( # noqa
63 unique_list as unique_list,
64 )
65
66
# Generic type variables shared by the annotations in this module.
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)
_T_co = TypeVar("_T_co", covariant=True)

# Shared immutable set constants: the empty set, and the set whose only
# member is ``None``.
EMPTY_SET: FrozenSet[Any] = frozenset()
NONE_SET: FrozenSet[Any] = frozenset([None])
74
75
def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
    """Merge two lists, keeping the relative ordering of each as much as
    possible.

    The intended use is reconciling ``vars(cls)`` with
    ``cls.__annotations__``.

    Example::

        >>> a = ["__tablename__", "id", "x", "created_at"]
        >>> b = ["id", "name", "data", "y", "created_at"]
        >>> merge_lists_w_ordering(a, b)
        ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']

    The result is not necessarily the order in which the names appeared
    on the class; for the example above the class is::

        class User(Base):
            __tablename__ = "users"

            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str]
            data: Mapped[Optional[str]]
            x = Column(Integer)
            y: Mapped[int]
            created_at: Mapped[datetime.datetime] = mapped_column()

    so the names come out *mostly* ordered.

    Building a partial ordering over both lists and running
    topological_sort() would also work, but costs too much.

    Background on the approach:
    https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae

    """
    # elements present in both lists that have not yet triggered a switch
    pending_shared = set(a).intersection(b)

    merged: List[Any] = []
    active, waiting = iter(a), iter(b)
    exhausted = object()

    while True:
        item = next(active, exhausted)
        if item is exhausted:
            # the active list ran out; whatever the other one still has
            # goes on the end
            merged.extend(waiting)
            return merged

        if item in pending_shared:
            # first sighting of a shared element: do not emit it here,
            # switch to the other list, which will emit it when it
            # reaches the same element
            pending_shared.discard(item)
            active, waiting = waiting, active
        else:
            merged.append(item)
130
131
def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
    """Return ``d`` as an ``immutabledict``.

    Falsy input yields the shared ``EMPTY_DICT``; an existing
    ``immutabledict`` is returned unchanged; anything else is passed to
    the ``immutabledict`` constructor.
    """
    if not d:
        return EMPTY_DICT
    if isinstance(d, immutabledict):
        return d
    return immutabledict(d)
139
140
# Shared empty immutabledict; coerce_to_immutabledict() returns it for
# falsy input.
EMPTY_DICT: immutabledict[Any, Any] = immutabledict()
142
143
class FacadeDict(ImmutableDictBase[_KT, _VT]):
    """A dictionary that is not publicly mutable.

    Entries are written through the private :meth:`_insert_item` hook.
    """

    def __new__(cls, *args: Any) -> FacadeDict[Any, Any]:
        # positional arguments are accepted but not consulted here
        return ImmutableDictBase.__new__(cls)

    def copy(self) -> NoReturn:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError(
            "an immutabledict shouldn't need to be copied.  use dict(d) "
            "if you need a mutable dictionary."
        )

    def __reduce__(self) -> Any:
        # pickle as a call to FacadeDict with a plain-dict snapshot
        contents = dict(self)
        return (FacadeDict, (contents,))

    def _insert_item(self, key: _KT, value: _VT) -> None:
        """Store ``value`` under ``key`` by calling ``dict.__setitem__``
        directly."""
        dict.__setitem__(self, key, value)

    def __repr__(self) -> str:
        plain_repr = dict.__repr__(self)
        return "FacadeDict(%s)" % plain_repr
166
167
# type variable for the ``default`` argument of Properties.get()
_DT = TypeVar("_DT", bound=Any)

# type variable for the element type of the other operand in
# Properties.__add__()
_F = TypeVar("_F", bound=Any)
171
172
class Properties(Generic[_T]):
    """Expose the entries of a dict through attribute access
    (``__getattr__`` / ``__setattr__``) as well as a dict-like API."""

    __slots__ = ("_data",)

    _data: Dict[str, _T]

    def __init__(self, data: Dict[str, _T]):
        # go around our own __setattr__, which writes into the backing
        # dict rather than onto the instance
        object.__setattr__(self, "_data", data)

    def __len__(self) -> int:
        return len(self._data)

    def __iter__(self) -> Iterator[_T]:
        # iterate a snapshot of the values, not the live dict view
        snapshot = list(self._data.values())
        return iter(snapshot)

    def __dir__(self) -> List[str]:
        names = dir(super())
        names.extend(str(k) for k in self._data.keys())
        return names

    def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
        return [*self, *other]

    def __setitem__(self, key: str, obj: _T) -> None:
        self._data[key] = obj

    def __getitem__(self, key: str) -> _T:
        return self._data[key]

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __setattr__(self, key: str, obj: _T) -> None:
        # attribute assignment is stored as a dict entry
        self._data[key] = obj

    def __getstate__(self) -> Dict[str, Any]:
        return dict(_data=self._data)

    def __setstate__(self, state: Dict[str, Any]) -> None:
        object.__setattr__(self, "_data", state["_data"])

    def __getattr__(self, key: str) -> _T:
        # a missing key surfaces as AttributeError, as attribute access
        # is expected to
        try:
            return self._data[key]
        except KeyError:
            raise AttributeError(key)

    def __contains__(self, key: str) -> bool:
        return key in self._data

    def as_readonly(self) -> ReadOnlyProperties[_T]:
        """Return an immutable proxy for this :class:`.Properties`."""

        return ReadOnlyProperties(self._data)

    def update(self, value: Dict[str, _T]) -> None:
        self._data.update(value)

    @overload
    def get(self, key: str) -> Optional[_T]: ...

    @overload
    def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ...

    def get(
        self, key: str, default: Optional[Union[_DT, _T]] = None
    ) -> Optional[Union[_T, _DT]]:
        """Return the value stored under ``key``, or ``default`` when
        the key is absent."""
        return self[key] if key in self else default

    def keys(self) -> List[str]:
        return [*self._data]

    def values(self) -> List[_T]:
        return [*self._data.values()]

    def items(self) -> List[Tuple[str, _T]]:
        return [*self._data.items()]

    def has_key(self, key: str) -> bool:
        return key in self._data

    def clear(self) -> None:
        self._data.clear()
258
259
class OrderedProperties(Properties[_T]):
    """A :class:`.Properties` whose backing store is a freshly created
    ``OrderedDict``."""

    __slots__ = ()

    def __init__(self):
        backing_store = OrderedDict()
        Properties.__init__(self, backing_store)
268
269
class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
    """Provide immutable dict/object attribute access to an underlying
    dictionary.

    NOTE(review): the write protection presumably comes from the
    ``ReadOnlyContainer`` mixin, which is defined outside this module --
    confirm there.
    """

    __slots__ = ()
274
275
def _ordered_dictionary_sort(d, key=None):
    """Reorder the entries of ``d`` in place so that its keys are sorted.

    :param d: the dictionary to reorder.
    :param key: optional sort key function applied to each dictionary
     key, as accepted by ``sorted()``.
    """
    ordered_keys = sorted(d, key=key)
    ordered_pairs = [(name, d[name]) for name in ordered_keys]

    # empty the dictionary, then re-insert the entries in sorted order
    d.clear()
    d.update(ordered_pairs)


# a plain dict is used as the ordered dictionary type
OrderedDict = dict
sort_dictionary = _ordered_dictionary_sort
288
289
class WeakSequence(Sequence[_T]):
    """A sequence that stores only weak references to its elements.

    Each element is wrapped in a ``weakref.ref`` whose callback removes
    that reference from the internal list, so the sequence does not keep
    its elements alive.
    """

    def __init__(self, __elements: Sequence[_T] = ()):
        # adapted from weakref.WeakKeyDictionary, prevent reference
        # cycles in the collection itself
        def _remove(item, selfref=weakref.ref(self)):
            self = selfref()
            if self is not None:
                self._storage.remove(item)

        self._remove = _remove
        self._storage = [
            weakref.ref(element, _remove) for element in __elements
        ]

    def append(self, item):
        """Add a weak reference to ``item`` at the end of the sequence."""
        self._storage.append(weakref.ref(item, self._remove))

    def __len__(self):
        # counts the stored references
        return len(self._storage)

    def __iter__(self):
        # resolve each reference, skipping any whose referent is gone
        return (
            obj for obj in (ref() for ref in self._storage) if obj is not None
        )

    def __getitem__(self, index):
        """Return the object referred to at position ``index``.

        :raises IndexError: if ``index`` is out of range.
        """
        try:
            ref = self._storage[index]
        except IndexError:
            # list indexing raises IndexError, never KeyError; catching
            # the right exception makes this message reachable
            raise IndexError("Index %s out of range" % index) from None
        else:
            return ref()
322
323
class OrderedIdentitySet(IdentitySet):
    """An :class:`IdentitySet` whose ``_members`` store is replaced with
    an ``OrderedDict``.

    NOTE(review): how ``IdentitySet`` uses ``_members`` is defined
    outside this module; the insertion-ordering behaviour presumably
    follows from that -- confirm there.
    """

    def __init__(self, iterable: Optional[Iterable[Any]] = None):
        # initialise the base with no members, then swap in our own
        # member store before anything is added
        IdentitySet.__init__(self)
        self._members = OrderedDict()
        # a falsy ``iterable`` (None or empty) adds nothing
        if iterable:
            for o in iterable:
                self.add(o)
331
332
class PopulateDict(Dict[_KT, _VT]):
    """A dict that fills in missing entries by calling a creation
    function.

    Unlike ``collections.defaultdict``, the creation function receives
    the missing key as its argument.

    """

    def __init__(self, creator: Callable[[_KT], _VT]):
        self.creator = creator

    def __missing__(self, key: Any) -> Any:
        # build the value, remember it under the key, and hand it back
        created = self.creator(key)
        self[key] = created
        return created
347
348
class WeakPopulateDict(Dict[_KT, _VT]):
    """Variant of PopulateDict that is given a bound method and avoids a
    reference cycle by holding the method's function plus a weak
    reference to the object it is bound to.

    """

    def __init__(self, creator_method: types.MethodType):
        unbound = creator_method.__func__
        owner = creator_method.__self__
        self.creator = unbound
        self.weakself = weakref.ref(owner)

    def __missing__(self, key: Any) -> Any:
        # call the plain function with the (dereferenced) owner and key
        created = self.creator(self.weakself(), key)
        self[key] = created
        return created
363
364
# Define collections that are capable of storing
# ColumnElement objects as hashable keys/elements.
# At this point these are mostly historical; things
# used to be more complicated.
column_set = set
column_dict = dict
ordered_column_set = OrderedSet
372
373
class UniqueAppender(Generic[_T]):
    """Append items to a collection while ensuring uniqueness.

    Repeated append() calls for the same object are ignored.  Membership
    is decided by identity (``is a``), not equality (``==``).
    """

    __slots__ = "data", "_data_appender", "_unique"

    data: Union[Iterable[_T], Set[_T], List[_T]]
    _data_appender: Callable[[_T], None]
    _unique: Dict[int, Literal[True]]

    def __init__(
        self,
        data: Union[Iterable[_T], Set[_T], List[_T]],
        via: Optional[str] = None,
    ):
        self.data = data
        self._unique = {}

        if via:
            # the caller named the method to append with
            self._data_appender = getattr(data, via)
            return

        # otherwise use "append" if present, else "add"; when neither
        # exists no appender is set
        for method_name in ("append", "add"):
            if hasattr(data, method_name):
                self._data_appender = getattr(data, method_name)
                return

    def append(self, item: _T) -> None:
        """Append ``item`` unless this exact object was appended
        before."""
        identity = id(item)
        if identity in self._unique:
            return
        self._data_appender(item)
        self._unique[identity] = True

    def __iter__(self) -> Iterator[_T]:
        return iter(self.data)
409
410
def coerce_generator_arg(arg: Any) -> List[Any]:
    """Expand a one-element argument sequence that holds a generator.

    When ``arg`` has exactly one element and that element is a
    generator, return the generator's items as a list; otherwise return
    ``arg`` unchanged.
    """
    if len(arg) != 1:
        return cast("List[Any]", arg)

    sole_item = arg[0]
    if isinstance(sole_item, types.GeneratorType):
        return list(sole_item)
    return cast("List[Any]", arg)
416
417
def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
    """Coerce ``x`` to a list.

    ``None`` yields ``default``; a value that
    ``is_non_string_iterable()`` rejects is wrapped in a one-element
    list; a list is returned as-is; any other iterable is copied into a
    new list.
    """
    if x is None:
        return default  # type: ignore
    if not is_non_string_iterable(x):
        return [x]
    return x if isinstance(x, list) else list(x)
427
428
def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
    r"""Return True if any item of iterable is contained in set\_.

    Items of iterable whose ``__hash__`` is falsy (i.e. unhashable
    items) are skipped, so that ``__hash__`` is never invoked on
    something that does not support it.

    """
    for candidate in iterable:
        if candidate.__hash__ and candidate in set_:
            return True
    return False
437
438
def to_set(x):
    """Coerce ``x`` to a set.

    ``None`` yields a new empty set; a ``set`` is returned as-is;
    anything else is passed through ``to_list()`` and turned into a new
    set.
    """
    if x is None:
        return set()
    if isinstance(x, set):
        return x
    return set(to_list(x))
446
447
def to_column_set(x: Any) -> Set[Any]:
    """Coerce ``x`` to a ``column_set``.

    ``None`` yields a new empty ``column_set``; an existing
    ``column_set`` is returned as-is; anything else is passed through
    ``to_list()`` and turned into a new ``column_set``.
    """
    if x is None:
        return column_set()
    if isinstance(x, column_set):
        return x
    return column_set(to_list(x))
455
456
def update_copy(
    d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any
) -> Dict[Any, Any]:
    """Return a copy of ``d`` updated with the given values.

    :param d: the dictionary to copy; it is not modified.
    :param _new: optional dictionary of entries to apply to the copy.
    :param \\**kw: further entries, applied after ``_new``.
    """

    merged = d.copy()
    if _new:
        merged.update(_new)
    # keyword entries go in last, so they win over ``_new``
    merged.update(**kw)
    return merged
467
468
def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
    """Flatten an iterable whose elements may themselves be iterable
    into a single iterator, descending into nested elements depth-first.

    Strings are yielded whole rather than being descended into.

    """
    # stack of iterators; the innermost one being consumed is on top
    open_iters = [iter(x)]
    while open_iters:
        for elem in open_iters[-1]:
            if not isinstance(elem, str) and hasattr(elem, "__iter__"):
                # descend into the nested iterable; the outer iterator
                # resumes once this one is used up
                open_iters.append(iter(elem))
                break
            yield elem
        else:
            # the innermost iterator is exhausted
            open_iters.pop()
480
481
class LRUCache(typing.MutableMapping[_KT, _VT]):
    """Dictionary with 'squishy' removal of least
    recently used items.

    Use either get() or [] to read from it; it's generally not safe to
    do an "in" check first, as the dictionary can change subsequent to
    that check.

    """

    __slots__ = (
        "capacity",
        "threshold",
        "size_alert",
        "_data",
        "_counter",
        "_mutex",
    )

    capacity: int
    threshold: float
    size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]]

    def __init__(
        self,
        capacity: int = 100,
        threshold: float = 0.5,
        size_alert: Optional[Callable[..., None]] = None,
    ):
        self.capacity = capacity
        self.threshold = threshold
        self.size_alert = size_alert
        self._counter = 0
        self._mutex = threading.Lock()
        # key -> (key, value, [usage counter]); the counter sits in a
        # one-element list so it can be updated in place
        self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}

    def _inc_counter(self):
        # advance the usage counter and return its new value
        self._counter += 1
        return self._counter

    @overload
    def get(self, key: _KT) -> Optional[_VT]: ...

    @overload
    def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...

    def get(
        self, key: _KT, default: Optional[Union[_VT, _T]] = None
    ) -> Optional[Union[_VT, _T]]:
        """Return the value for ``key``, marking it as just used, or
        ``default`` when the key is absent."""
        entry = self._data.get(key)
        if entry is None:
            return default
        entry[2][0] = self._inc_counter()
        return entry[1]

    def __getitem__(self, key: _KT) -> _VT:
        entry = self._data[key]
        # mark the entry as just used
        entry[2][0] = self._inc_counter()
        return entry[1]

    def __iter__(self) -> Iterator[_KT]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def values(self) -> ValuesView[_VT]:
        # view over a snapshot mapping of key -> value
        plain = {key: entry[1] for key, entry in self._data.items()}
        return typing.ValuesView(plain)

    def __setitem__(self, key: _KT, value: _VT) -> None:
        usage = [self._inc_counter()]
        self._data[key] = (key, value, usage)
        self._manage_size()

    def __delitem__(self, __v: _KT) -> None:
        del self._data[__v]

    @property
    def size_threshold(self) -> float:
        """Number of entries above which pruning takes place."""
        return self.capacity + self.capacity * self.threshold

    def _manage_size(self) -> None:
        """Prune the cache back down to ``capacity`` entries once it has
        grown past ``capacity + capacity * threshold``.

        Returns without doing anything if the mutex is already held.
        """
        if not self._mutex.acquire(False):
            return
        try:
            # the alert callable, if any, is invoked at most once per
            # call of this method
            alert_pending = bool(self.size_alert)
            while len(self) > self.capacity + self.capacity * self.threshold:
                if alert_pending:
                    alert_pending = False
                    self.size_alert(self)  # type: ignore
                # highest usage counter first; everything past the first
                # ``capacity`` entries is discarded
                ranked = sorted(
                    self._data.values(),
                    key=operator.itemgetter(2),
                    reverse=True,
                )
                for stale in ranked[self.capacity :]:
                    # the key may already have been deleted elsewhere
                    self._data.pop(stale[0], None)
        finally:
            self._mutex.release()
585
586
class _CreateFuncType(Protocol[_T_co]):
    """Typing protocol for a zero-argument callable returning ``_T_co``;
    used to annotate ``ScopedRegistry.createfunc``."""

    def __call__(self) -> _T_co: ...
589
590
class _ScopeFuncType(Protocol):
    """Typing protocol for a zero-argument callable returning any value;
    used to annotate ``ScopedRegistry.scopefunc``."""

    def __call__(self) -> Any: ...
593
594
class ScopedRegistry(Generic[_T]):
    """A Registry that can store one or multiple instances of a single
    class on the basis of a "scope" function.

    The registry itself is callable and acts as the "getter":
    calling ``myregistry()`` returns the object held for the
    current scope.

    :param createfunc:
      a callable that returns a new object to be placed in the registry

    :param scopefunc:
      a callable that will return a key to store/retrieve an object.
    """

    __slots__ = "createfunc", "scopefunc", "registry"

    createfunc: _CreateFuncType[_T]
    scopefunc: _ScopeFuncType
    registry: Any

    def __init__(
        self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
    ):
        """Construct a new :class:`.ScopedRegistry`.

        :param createfunc:  A creation function that will generate
          a new value for the current scope, if none is present.

        :param scopefunc:  A function that returns a hashable
          token representing the current scope (such as, current
          thread identifier).

        """
        self.createfunc = createfunc
        self.scopefunc = scopefunc
        # scope token -> object held for that scope
        self.registry = {}

    def __call__(self) -> _T:
        """Return the object for the current scope, creating and storing
        one via ``createfunc`` if the scope has none yet."""
        scope = self.scopefunc()
        try:
            return self.registry[scope]  # type: ignore[no-any-return]
        except KeyError:
            created = self.createfunc()
            return self.registry.setdefault(scope, created)  # type: ignore[no-any-return] # noqa: E501

    def has(self) -> bool:
        """Return True if an object is present in the current scope."""

        scope = self.scopefunc()
        return scope in self.registry

    def set(self, obj: _T) -> None:
        """Set the value for the current scope."""

        scope = self.scopefunc()
        self.registry[scope] = obj

    def clear(self) -> None:
        """Clear the current scope, if any."""

        try:
            del self.registry[self.scopefunc()]
        except KeyError:
            # nothing is stored for this scope
            pass
657
658
class ThreadLocalRegistry(ScopedRegistry[_T]):
    """A :class:`.ScopedRegistry` that uses a ``threading.local()``
    variable for storage.

    The object is kept in the ``value`` attribute of the
    ``threading.local()``; no ``scopefunc`` is set or consulted.

    """

    def __init__(self, createfunc: Callable[[], _T]):
        self.createfunc = createfunc
        self.registry = threading.local()

    def __call__(self) -> _T:
        """Return the stored object, creating and storing one via
        ``createfunc`` if none is present."""
        try:
            return self.registry.value  # type: ignore[no-any-return]
        except AttributeError:
            created = self.createfunc()
            self.registry.value = created
            return created

    def has(self) -> bool:
        """Return True if an object is currently stored."""
        return hasattr(self.registry, "value")

    def set(self, obj: _T) -> None:
        """Store ``obj`` as the current value."""
        self.registry.value = obj

    def clear(self) -> None:
        """Discard the stored object, if any."""
        try:
            del self.registry.value
        except AttributeError:
            # nothing was stored
            pass
687
688
def has_dupes(sequence, target):
    """Return True if ``target`` occurs more than once in ``sequence``,
    False if it occurs once or not at all.

    Occurrences are matched by identity (``is``), not equality.

    """
    # a plain scan; compared to the .index based version kept below,
    # this has less function call overhead and is usually the same
    # speed.  At 15000 items (way bigger than a relationship-bound
    # collection in memory usually is) it begins to fall behind that
    # version only by microseconds.
    seen_once = False
    for candidate in sequence:
        if candidate is not target:
            continue
        if seen_once:
            return True
        seen_once = True
    return False
706
707
708# .index version. the two __contains__ calls as well
709# as .index() and isinstance() slow this down.
710# def has_dupes(sequence, target):
711# if target not in sequence:
712# return False
713# elif not isinstance(sequence, collections_abc.Sequence):
714# return False
715#
716# idx = sequence.index(target)
717# return target in sequence[idx + 1:]