1# util/_collections.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Collection classes and helpers."""
10from __future__ import annotations
11
12import operator
13import threading
14import types
15import typing
16from typing import Any
17from typing import Callable
18from typing import cast
19from typing import Container
20from typing import Dict
21from typing import FrozenSet
22from typing import Generic
23from typing import Iterable
24from typing import Iterator
25from typing import List
26from typing import Mapping
27from typing import NoReturn
28from typing import Optional
29from typing import overload
30from typing import Protocol
31from typing import Sequence
32from typing import Set
33from typing import Tuple
34from typing import TypeVar
35from typing import Union
36from typing import ValuesView
37import weakref
38
39from ._collections_cy import IdentitySet as IdentitySet
40from ._collections_cy import OrderedSet as OrderedSet
41from ._collections_cy import unique_list as unique_list # noqa: F401
42from ._immutabledict_cy import immutabledict as immutabledict
43from ._immutabledict_cy import ImmutableDictBase as ImmutableDictBase
44from ._immutabledict_cy import ReadOnlyContainer as ReadOnlyContainer
45from .typing import is_non_string_iterable
46from .typing import Literal
47
48
# Generic type variables shared by the classes in this module.  The
# ``bound=Any`` variables accept any type; ``_T_co`` is covariant and is
# used by the factory protocol further below.
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)
_T_co = TypeVar("_T_co", covariant=True)

# Shared immutable singleton sets.
EMPTY_SET: FrozenSet[Any] = frozenset(())
NONE_SET: FrozenSet[Any] = frozenset((None,))
56
57
def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
    """merge two lists, maintaining ordering as much as possible.

    this is to reconcile vars(cls) with cls.__annotations__.

    Example::

        >>> a = ['__tablename__', 'id', 'x', 'created_at']
        >>> b = ['id', 'name', 'data', 'y', 'created_at']
        >>> merge_lists_w_ordering(a, b)
        ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']

    This is not necessarily the ordering that things had on the class,
    in this case the class is::

        class User(Base):
            __tablename__ = "users"

            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str]
            data: Mapped[Optional[str]]
            x = Column(Integer)
            y: Mapped[int]
            created_at: Mapped[datetime.datetime] = mapped_column()

    But things are *mostly* ordered.

    The algorithm could also be done by creating a partial ordering for
    all items in both lists and then using topological_sort(), but that
    is too much overhead.

    Background on how I came up with this is at:
    https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae

    """
    # Elements present in both lists act as switch points between the two
    # input iterators.
    overlap = set(a).intersection(b)

    result = []

    # ``current`` is the iterator being consumed; ``other`` is paused where
    # it last stopped.
    current, other = iter(a), iter(b)

    while True:
        for element in current:
            if element in overlap:
                # First sighting of a shared element: it is NOT appended
                # here.  It is removed from ``overlap`` and consumption
                # switches to the other list.  That list also contains the
                # element, which is appended when reached there, because it
                # is no longer in ``overlap``.
                overlap.discard(element)
                other, current = current, other
                break

            result.append(element)
        else:
            # The current list is exhausted (no break): append whatever
            # remains of the other list and finish.
            result.extend(other)
            break

    return result
112
113
# Shared empty immutabledict; every empty input to
# coerce_to_immutabledict() maps to this one object.
EMPTY_DICT: immutabledict[Any, Any] = immutabledict()


def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
    """Return ``d`` as an :class:`.immutabledict`.

    An empty mapping yields the shared ``EMPTY_DICT`` object.  A non-empty
    mapping that is already an immutabledict is returned unchanged;
    anything else is copied into a new immutabledict.
    """
    if d:
        return d if isinstance(d, immutabledict) else immutabledict(d)
    return EMPTY_DICT
124
125
class FacadeDict(ImmutableDictBase[_KT, _VT]):
    """A dictionary that is not publicly mutable.

    The owning code adds entries through :meth:`_insert_item`, which writes
    with ``dict.__setitem__`` directly instead of the public mutators.
    """

    def __new__(cls, *args: Any) -> FacadeDict[Any, Any]:
        # ``*args`` is accepted so the constructor can be called with a
        # source mapping; __new__ itself does not use it.
        instance: FacadeDict[Any, Any] = ImmutableDictBase.__new__(cls)
        return instance

    def copy(self) -> NoReturn:
        """Always raises; build a mutable copy with ``dict(d)`` instead."""
        message = (
            "an immutabledict shouldn't need to be copied. use dict(d) "
            "if you need a mutable dictionary."
        )
        raise NotImplementedError(message)

    def __reduce__(self) -> Any:
        # pickle as a call to FacadeDict() with a plain-dict snapshot
        snapshot = dict(self)
        return (FacadeDict, (snapshot,))

    def _insert_item(self, key: _KT, value: _VT) -> None:
        """insert an item into the dictionary directly."""
        dict.__setitem__(self, key, value)

    def __repr__(self) -> str:
        return f"FacadeDict({dict.__repr__(self)})"
148
149
# Type variable for the ``default`` argument of ``Properties.get()``.
_DT = TypeVar("_DT", bound=Any)

# Type variable for the value type of the right-hand operand of
# ``Properties.__add__``.
_F = TypeVar("_F", bound=Any)
153
154
class Properties(Generic[_T]):
    """Provide a __getattr__/__setattr__ interface over a dict.

    The wrapped dictionary is held by reference, not copied.  Its entries
    can be read and written both as attributes (``props.name``) and as
    items (``props["name"]``), and changes are visible through the original
    dictionary as well.

    Iterating yields the dictionary's *values*; use :meth:`keys` for the
    keys.
    """

    __slots__ = ("_data",)

    _data: Dict[str, _T]

    def __init__(self, data: Dict[str, _T]):
        # object.__setattr__ bypasses our own __setattr__, which would
        # otherwise store "_data" as an entry of the dictionary.
        object.__setattr__(self, "_data", data)

    def __len__(self) -> int:
        return len(self._data)

    def __iter__(self) -> Iterator[_T]:
        # Iterate a snapshot so the dictionary may change during iteration.
        return iter(list(self._data.values()))

    def __dir__(self) -> List[str]:
        # super().__dir__() lists this instance's own attributes (methods,
        # slots).  dir(super()) would instead describe the ``super`` proxy
        # object (``__thisclass__``, ``__self__``, ...) and leave out
        # methods such as keys() and as_readonly().
        return list(super().__dir__()) + [str(k) for k in self._data.keys()]

    def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
        """Return a list of the values of ``self`` followed by ``other``."""
        return list(self) + list(other)

    def __setitem__(self, key: str, obj: _T) -> None:
        self._data[key] = obj

    def __getitem__(self, key: str) -> _T:
        return self._data[key]

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __setattr__(self, key: str, obj: _T) -> None:
        # Attribute assignment writes into the dictionary.
        self._data[key] = obj

    def __getstate__(self) -> Dict[str, Any]:
        # Pickle support: only the backing dictionary is saved.
        return {"_data": self._data}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        object.__setattr__(self, "_data", state["_data"])

    def __getattr__(self, key: str) -> _T:
        # Only invoked when normal attribute lookup fails, so methods and
        # the ``_data`` slot take precedence over dictionary entries.
        try:
            return self._data[key]
        except KeyError:
            # The KeyError is an implementation detail; don't chain it.
            raise AttributeError(key) from None

    def __contains__(self, key: str) -> bool:
        return key in self._data

    def as_readonly(self) -> ReadOnlyProperties[_T]:
        """Return an immutable proxy for this :class:`.Properties`.

        The proxy is given the same dictionary object (not a copy).
        """

        return ReadOnlyProperties(self._data)

    def update(self, value: Dict[str, _T]) -> None:
        """Merge the entries of ``value`` into the dictionary."""
        self._data.update(value)

    @overload
    def get(self, key: str) -> Optional[_T]: ...

    @overload
    def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ...

    def get(
        self, key: str, default: Optional[Union[_DT, _T]] = None
    ) -> Optional[Union[_T, _DT]]:
        """Return the entry for ``key``, or ``default`` if it is absent."""
        if key in self:
            return self[key]
        else:
            return default

    def keys(self) -> List[str]:
        """Return the keys as a new list (a snapshot, not a view)."""
        return list(self._data)

    def values(self) -> List[_T]:
        """Return the values as a new list (a snapshot, not a view)."""
        return list(self._data.values())

    def items(self) -> List[Tuple[str, _T]]:
        """Return ``(key, value)`` pairs as a new list."""
        return list(self._data.items())

    def has_key(self, key: str) -> bool:
        return key in self._data

    def clear(self) -> None:
        self._data.clear()
240
241
class OrderedProperties(Properties[_T]):
    """A :class:`.Properties` backed by a freshly created ``OrderedDict``.

    In this module ``OrderedDict`` is an alias of the built-in ``dict``,
    which keeps insertion order.
    """

    __slots__ = ()

    def __init__(self):
        backing: Dict[str, _T] = OrderedDict()
        Properties.__init__(self, backing)
250
251
class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
    """Provide immutable dict/object attribute access to an underlying
    dictionary.

    ``ReadOnlyContainer`` is listed before :class:`.Properties` in the
    bases, so its methods take precedence in the MRO.  NOTE(review):
    presumably it overrides the mutating methods to raise; confirm in
    ``_immutabledict_cy``.
    """

    # No new instance attributes; the ``_data`` slot comes from Properties.
    __slots__ = ()
256
257
def _ordered_dictionary_sort(d, key=None):
    """Sort a dictionary in place by its keys.

    All entries are removed and re-inserted in sorted key order, relying on
    the dictionary's insertion ordering.  ``key`` is passed to
    :func:`sorted` and applies to the keys.
    """
    ordered_keys = sorted(d, key=key)
    entries = [(k, d[k]) for k in ordered_keys]
    d.clear()
    d.update(entries)


# Plain dicts preserve insertion order, so ``OrderedDict`` is simply an alias.
OrderedDict = dict
sort_dictionary = _ordered_dictionary_sort
270
271
class WeakSequence(Sequence[_T]):
    """A sequence that holds only weak references to its elements.

    When an element is garbage collected, its weak reference's callback
    removes the reference from the sequence.  Iteration also skips any
    reference whose referent is already gone.
    """

    def __init__(self, elements: Sequence[_T] = (), /):
        # adapted from weakref.WeakKeyDictionary, prevent reference
        # cycles in the collection itself: the removal callback reaches
        # this object only through a weak reference.
        def _remove(item, selfref=weakref.ref(self)):
            self = selfref()
            if self is not None:
                self._storage.remove(item)

        self._remove = _remove
        self._storage = [weakref.ref(element, _remove) for element in elements]

    def append(self, item):
        """Append a weak reference to ``item``."""
        self._storage.append(weakref.ref(item, self._remove))

    def __len__(self):
        # Counts stored references; the removal callback deletes the
        # references of collected elements.
        return len(self._storage)

    def __iter__(self):
        return (
            obj for obj in (ref() for ref in self._storage) if obj is not None
        )

    def __getitem__(self, index):
        """Return the referent at ``index``.

        A slice returns a list of referents.  An out-of-range integer index
        raises :class:`IndexError`.
        """
        if isinstance(index, slice):
            # Previously the sliced list of refs was "called", raising
            # TypeError; dereference each ref instead.
            return [ref() for ref in self._storage[index]]
        try:
            ref = self._storage[index]
        except IndexError:
            # List indexing raises IndexError (never KeyError, which the
            # old handler caught, so this message was unreachable).
            raise IndexError("Index %s out of range" % index) from None
        return ref()
304
305
# Alias kept under the historical "ordered" name.  NOTE(review): this
# relies on IdentitySet (from _collections_cy) preserving insertion
# order; confirm there.
OrderedIdentitySet = IdentitySet
307
308
class PopulateDict(Dict[_KT, _VT]):
    """A dict that fills in missing keys by calling a creation function.

    Unlike :class:`collections.defaultdict`, the creation function receives
    the missing key as its only argument.  The created value is stored
    under that key and then returned.
    """

    def __init__(self, creator: Callable[[_KT], _VT]):
        self.creator = creator

    def __missing__(self, key: Any) -> Any:
        value = self.creator(key)
        self[key] = value
        return value
323
324
class WeakPopulateDict(Dict[_KT, _VT]):
    """Like :class:`.PopulateDict`, but built from a bound method.

    The method is split into its underlying function and a weak reference
    to its instance, so the dict does not keep the instance alive and no
    reference cycle is created.  Missing keys are populated by calling
    ``function(instance, key)``.  If the instance has been collected,
    ``None`` is passed in its place.
    """

    def __init__(self, creator_method: types.MethodType):
        owner = creator_method.__self__
        self.creator = creator_method.__func__
        self.weakself = weakref.ref(owner)

    def __missing__(self, key: Any) -> Any:
        value = self.creator(self.weakself(), key)
        self[key] = value
        return value
339
340
# Define collections that are capable of storing
# ColumnElement objects as hashable keys/elements.
# These are now plain aliases and mostly historical; the
# implementations used to be more complicated.
column_set = set
column_dict = dict
ordered_column_set = OrderedSet
348
349
class UniqueAppender(Generic[_T]):
    """Appends items to a collection, ignoring repeated appends.

    Uniqueness is tracked by object identity (``id()`` / ``is``), not by
    equality (``==``).
    """

    __slots__ = "data", "_data_appender", "_unique"

    data: Union[Iterable[_T], Set[_T], List[_T]]
    _data_appender: Callable[[_T], None]
    _unique: Dict[int, Literal[True]]

    def __init__(
        self,
        data: Union[Iterable[_T], Set[_T], List[_T]],
        via: Optional[str] = None,
    ):
        self.data = data
        self._unique = {}
        if via:
            self._data_appender = getattr(data, via)
        else:
            # Prefer list-style append(), then set-style add().  If data
            # has neither, _data_appender stays unset and append() will
            # raise AttributeError.
            for method_name in ("append", "add"):
                if hasattr(data, method_name):
                    self._data_appender = getattr(data, method_name)
                    break

    def append(self, item: _T) -> None:
        """Add ``item`` to the collection unless this very object was
        already added."""
        identity = id(item)
        if identity in self._unique:
            return
        self._data_appender(item)
        self._unique[identity] = True

    def __iter__(self) -> Iterator[_T]:
        return iter(self.data)
385
386
def coerce_generator_arg(arg: Any) -> List[Any]:
    """Unpack a lone generator argument.

    ``arg`` is a positional-argument sequence.  If it consists of exactly
    one generator, the generator is exhausted into a list; otherwise
    ``arg`` is returned unchanged.
    """
    if len(arg) != 1 or not isinstance(arg[0], types.GeneratorType):
        return cast("List[Any]", arg)
    return list(arg[0])
392
393
def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
    """Coerce ``x`` into a list.

    ``None`` returns ``default`` as-is.  Anything accepted by
    ``is_non_string_iterable()`` is returned unchanged if it is already a
    list, and is otherwise converted with ``list()``.  Any other value is
    wrapped in a one-element list.
    """
    if x is None:
        return default  # type: ignore
    if is_non_string_iterable(x):
        return x if isinstance(x, list) else list(x)
    return [x]
403
404
def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
    """Return True if any element of ``iterable`` is contained in ``set_``.

    Elements whose ``__hash__`` is ``None`` (unhashable objects such as
    lists) are skipped without being looked up, so ``__hash__`` is never
    invoked on them.
    """
    for candidate in iterable:
        if candidate.__hash__ and candidate in set_:
            return True
    return False
413
414
def to_set(x):
    """Coerce ``x`` into a :class:`set`.

    ``None`` becomes an empty set and an existing ``set`` is returned
    unchanged.  Anything else goes through :func:`to_list` first.
    """
    if x is None:
        return set()
    if isinstance(x, set):
        return x
    return set(to_list(x))
422
423
def to_column_set(x: Any) -> Set[Any]:
    """Coerce ``x`` into a ``column_set``.

    ``None`` becomes an empty ``column_set`` and an existing
    ``column_set`` is returned unchanged.  Anything else goes through
    :func:`to_list` first.
    """
    if x is None:
        return column_set()
    if isinstance(x, column_set):
        return x
    return column_set(to_list(x))
431
432
def update_copy(d, _new=None, **kw):
    """Return a copy of ``d`` updated with ``_new`` and then ``kw``.

    ``d`` itself is left unmodified.  Keyword arguments are applied last,
    so they win over entries from ``_new``.
    """
    result = d.copy()
    if _new:
        result.update(_new)
    result.update(**kw)
    return result
441
442
def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
    """Lazily flatten arbitrarily nested iterables into one iterator.

    Any element with an ``__iter__`` attribute is descended into
    recursively, except strings, which are yielded whole.
    """
    item: _T
    for item in x:
        if isinstance(item, str) or not hasattr(item, "__iter__"):
            yield item
            continue
        yield from flatten_iterator(item)
454
455
class LRUCache(typing.MutableMapping[_KT, _VT]):
    """Dictionary with 'squishy' removal of least
    recently used items.

    Each write, and each read through get() or [], stamps the entry with
    the next value of an internal counter.  Once the number of entries
    exceeds :attr:`size_threshold`, i.e. ``capacity * (1 + threshold)``,
    the entries with the oldest stamps are discarded until ``capacity``
    remain.  Pruning is best-effort: if another caller is already pruning,
    a write returns without pruning.

    Note that either get() or [] should be used here, but
    generally it's not safe to do an "in" check first as the dictionary
    can change subsequent to that call.

    """

    __slots__ = (
        "capacity",
        "threshold",
        "size_alert",
        "_data",
        "_counter",
        "_mutex",
    )

    capacity: int
    threshold: float
    size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]]

    def __init__(
        self,
        capacity: int = 100,
        threshold: float = 0.5,
        size_alert: Optional[Callable[..., None]] = None,
    ):
        """Construct a new :class:`.LRUCache`.

        :param capacity: number of entries kept after a pruning pass.
        :param threshold: fraction of ``capacity`` by which the cache may
          grow before pruning is triggered.
        :param size_alert: optional callable, invoked with the cache
          itself at most once per pruning pass, before entries are
          discarded.
        """
        self.capacity = capacity
        self.threshold = threshold
        self.size_alert = size_alert
        self._counter = 0
        self._mutex = threading.Lock()
        # key -> (key, value, [last-access counter]); the counter lives in
        # a one-element list so reads can update it in place.
        self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}

    def _inc_counter(self) -> int:
        """Advance and return the access counter."""
        self._counter += 1
        return self._counter

    @overload
    def get(self, key: _KT) -> Optional[_VT]: ...

    @overload
    def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...

    def get(
        self, key: _KT, default: Optional[Union[_VT, _T]] = None
    ) -> Optional[Union[_VT, _T]]:
        """Return the value for ``key`` (marking it as recently used), or
        ``default`` if absent."""
        item = self._data.get(key)
        if item is not None:
            item[2][0] = self._inc_counter()
            return item[1]
        else:
            return default

    def __getitem__(self, key: _KT) -> _VT:
        item = self._data[key]
        item[2][0] = self._inc_counter()
        return item[1]

    def __iter__(self) -> Iterator[_KT]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def values(self) -> ValuesView[_VT]:
        # Built from a snapshot; does not update access counters.
        return typing.ValuesView({k: i[1] for k, i in self._data.items()})

    def __setitem__(self, key: _KT, value: _VT) -> None:
        self._data[key] = (key, value, [self._inc_counter()])
        self._manage_size()

    def __delitem__(self, __v: _KT) -> None:
        del self._data[__v]

    @property
    def size_threshold(self) -> float:
        """Entry count above which pruning is triggered."""
        return self.capacity + self.capacity * self.threshold

    def _manage_size(self) -> None:
        # Non-blocking acquire: if another caller is already pruning,
        # return immediately instead of waiting; the cache may briefly
        # exceed its threshold.
        if not self._mutex.acquire(False):
            return
        try:
            size_alert = bool(self.size_alert)
            # size_threshold is re-evaluated on every pass, so changes made
            # by the size_alert callback are taken into account.
            while len(self) > self.size_threshold:
                if size_alert:
                    size_alert = False
                    self.size_alert(self)  # type: ignore
                # Most recently used first; everything past ``capacity``
                # is discarded.
                by_counter = sorted(
                    self._data.values(),
                    key=operator.itemgetter(2),
                    reverse=True,
                )
                for item in by_counter[self.capacity :]:
                    try:
                        del self._data[item[0]]
                    except KeyError:
                        # deleted elsewhere; skip
                        continue
        finally:
            self._mutex.release()
559
560
class _CreateFuncType(Protocol[_T_co]):
    """Callable protocol: a zero-argument factory returning ``_T_co``."""

    def __call__(self) -> _T_co: ...


class _ScopeFuncType(Protocol):
    """Callable protocol: a zero-argument function returning the key of
    the current scope."""

    def __call__(self) -> Any: ...
567
568
class ScopedRegistry(Generic[_T]):
    """A Registry that can store one or multiple instances of a single
    class on the basis of a "scope" function.

    Calling the registry (``myregistry()``) returns the object stored for
    the current scope, creating it with ``createfunc`` on first access.

    :param createfunc:
      a callable that returns a new object to be placed in the registry

    :param scopefunc:
      a callable that will return a key to store/retrieve an object.
    """

    __slots__ = "createfunc", "scopefunc", "registry"

    createfunc: _CreateFuncType[_T]
    scopefunc: _ScopeFuncType
    registry: Any

    def __init__(
        self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
    ):
        """Construct a new :class:`.ScopedRegistry`.

        :param createfunc: A creation function that will generate
          a new value for the current scope, if none is present.

        :param scopefunc: A function that returns a hashable
          token representing the current scope (such as, current
          thread identifier).

        """
        self.createfunc = createfunc
        self.scopefunc = scopefunc
        self.registry = {}

    def __call__(self) -> _T:
        scope = self.scopefunc()
        try:
            found = self.registry[scope]
        except KeyError:
            # setdefault() keeps an entry stored under this key in the
            # meantime instead of overwriting it.
            return self.registry.setdefault(scope, self.createfunc())  # type: ignore[no-any-return] # noqa: E501
        return found  # type: ignore[no-any-return]

    def has(self) -> bool:
        """Return True if an object is present in the current scope."""
        scope = self.scopefunc()
        return scope in self.registry

    def set(self, obj: _T) -> None:
        """Set the value for the current scope."""
        scope = self.scopefunc()
        self.registry[scope] = obj

    def clear(self) -> None:
        """Clear the current scope, if any."""
        try:
            scope = self.scopefunc()
            del self.registry[scope]
        except KeyError:
            pass
631
632
class ThreadLocalRegistry(ScopedRegistry[_T]):
    """A :class:`.ScopedRegistry` that uses a ``threading.local()``
    variable for storage.

    The object is kept in the ``value`` attribute of a
    ``threading.local``, so each thread has its own value and no scope
    function is used.
    """

    def __init__(self, createfunc: Callable[[], _T]):
        self.createfunc = createfunc
        self.registry = threading.local()

    def __call__(self) -> _T:
        try:
            return self.registry.value  # type: ignore[no-any-return]
        except AttributeError:
            created = self.createfunc()
            self.registry.value = created
            return created

    def has(self) -> bool:
        try:
            self.registry.value
        except AttributeError:
            return False
        return True

    def set(self, obj: _T) -> None:
        self.registry.value = obj

    def clear(self) -> None:
        try:
            del self.registry.value
        except AttributeError:
            pass
661
662
def has_dupes(sequence, target):
    """Return True if ``target`` occurs more than once in ``sequence``.

    Occurrences are matched by identity (``is``), not equality.  Zero or
    one occurrence returns False, and scanning stops at the second
    occurrence.
    """
    # compare to the commented-out .index() version below: this version
    # introduces less function overhead and is usually the same speed.  At
    # 15000 items (way bigger than a relationship-bound collection in
    # memory usually is) it begins to fall behind the other version only
    # by microseconds.
    hits = 0
    for candidate in sequence:
        if candidate is not target:
            continue
        hits += 1
        if hits > 1:
            return True
    return False
680
681
# Alternative .index()-based version, kept for reference; its two
# __contains__ calls, plus .index() and isinstance(), make it slower.
684# def has_dupes(sequence, target):
685# if target not in sequence:
686# return False
687# elif not isinstance(sequence, collections_abc.Sequence):
688# return False
689#
690# idx = sequence.index(target)
691# return target in sequence[idx + 1:]