1# util/_collections.py 
    2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: https://www.opensource.org/licenses/mit-license.php 
    7# mypy: allow-untyped-defs, allow-untyped-calls 
    8 
    9"""Collection classes and helpers.""" 
    10from __future__ import annotations 
    11 
    12import operator 
    13import threading 
    14import types 
    15import typing 
    16from typing import Any 
    17from typing import Callable 
    18from typing import cast 
    19from typing import Container 
    20from typing import Dict 
    21from typing import FrozenSet 
    22from typing import Generic 
    23from typing import Iterable 
    24from typing import Iterator 
    25from typing import List 
    26from typing import Literal 
    27from typing import Mapping 
    28from typing import NoReturn 
    29from typing import Optional 
    30from typing import overload 
    31from typing import Protocol 
    32from typing import Sequence 
    33from typing import Set 
    34from typing import Tuple 
    35from typing import TypeVar 
    36from typing import Union 
    37from typing import ValuesView 
    38import weakref 
    39 
    40from ._collections_cy import IdentitySet as IdentitySet 
    41from ._collections_cy import OrderedSet as OrderedSet 
    42from ._collections_cy import unique_list as unique_list  # noqa: F401 
    43from ._immutabledict_cy import immutabledict as immutabledict 
    44from ._immutabledict_cy import ImmutableDictBase as ImmutableDictBase 
    45from ._immutabledict_cy import ReadOnlyContainer as ReadOnlyContainer 
    46from .typing import is_non_string_iterable 
    47 
    48 
# Type variables shared by the generic helpers in this module.
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT", bound=Any)  # used as the key type of mappings below
_VT = TypeVar("_VT", bound=Any)  # used as the value type of mappings below
_T_co = TypeVar("_T_co", covariant=True)

# Shared, immutable set constants: the empty set and the set holding only None.
EMPTY_SET: FrozenSet[Any] = frozenset()
NONE_SET: FrozenSet[Any] = frozenset([None])
    56 
    57 
    58def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]: 
    59    """merge two lists, maintaining ordering as much as possible. 
    60 
    61    this is to reconcile vars(cls) with cls.__annotations__. 
    62 
    63    Example:: 
    64 
    65        >>> a = ["__tablename__", "id", "x", "created_at"] 
    66        >>> b = ["id", "name", "data", "y", "created_at"] 
    67        >>> merge_lists_w_ordering(a, b) 
    68        ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at'] 
    69 
    70    This is not necessarily the ordering that things had on the class, 
    71    in this case the class is:: 
    72 
    73        class User(Base): 
    74            __tablename__ = "users" 
    75 
    76            id: Mapped[int] = mapped_column(primary_key=True) 
    77            name: Mapped[str] 
    78            data: Mapped[Optional[str]] 
    79            x = Column(Integer) 
    80            y: Mapped[int] 
    81            created_at: Mapped[datetime.datetime] = mapped_column() 
    82 
    83    But things are *mostly* ordered. 
    84 
    85    The algorithm could also be done by creating a partial ordering for 
    86    all items in both lists and then using topological_sort(), but that 
    87    is too much overhead. 
    88 
    89    Background on how I came up with this is at: 
    90    https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae 
    91 
    92    """ 
    93    overlap = set(a).intersection(b) 
    94 
    95    result = [] 
    96 
    97    current, other = iter(a), iter(b) 
    98 
    99    while True: 
    100        for element in current: 
    101            if element in overlap: 
    102                overlap.discard(element) 
    103                other, current = current, other 
    104                break 
    105 
    106            result.append(element) 
    107        else: 
    108            result.extend(other) 
    109            break 
    110 
    111    return result 
    112 
    113 
    114def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]: 
    115    if not d: 
    116        return EMPTY_DICT 
    117    elif isinstance(d, immutabledict): 
    118        return d 
    119    else: 
    120        return immutabledict(d) 
    121 
    122 
# Shared empty immutabledict; coerce_to_immutabledict() returns this object
# for any falsy input instead of building a new one.
EMPTY_DICT: immutabledict[Any, Any] = immutabledict()
    124 
    125 
    126class FacadeDict(ImmutableDictBase[_KT, _VT]): 
    127    """A dictionary that is not publicly mutable.""" 
    128 
    129    def __new__(cls, *args: Any) -> FacadeDict[Any, Any]: 
    130        new: FacadeDict[Any, Any] = ImmutableDictBase.__new__(cls) 
    131        return new 
    132 
    133    def copy(self) -> NoReturn: 
    134        raise NotImplementedError( 
    135            "an immutabledict shouldn't need to be copied.  use dict(d) " 
    136            "if you need a mutable dictionary." 
    137        ) 
    138 
    139    def __reduce__(self) -> Any: 
    140        return FacadeDict, (dict(self),) 
    141 
    142    def _insert_item(self, key: _KT, value: _VT) -> None: 
    143        """insert an item into the dictionary directly.""" 
    144        dict.__setitem__(self, key, value) 
    145 
    146    def __repr__(self) -> str: 
    147        return "FacadeDict(%s)" % dict.__repr__(self) 
    148 
    149 
# Type of the caller-supplied default in Properties.get().
_DT = TypeVar("_DT", bound=Any)

# Element type of the right-hand operand in Properties.__add__().
_F = TypeVar("_F", bound=Any)
    153 
    154 
    155class Properties(Generic[_T]): 
    156    """Provide a __getattr__/__setattr__ interface over a dict.""" 
    157 
    158    __slots__ = ("_data",) 
    159 
    160    _data: Dict[str, _T] 
    161 
    162    def __init__(self, data: Dict[str, _T]): 
    163        object.__setattr__(self, "_data", data) 
    164 
    165    def __len__(self) -> int: 
    166        return len(self._data) 
    167 
    168    def __iter__(self) -> Iterator[_T]: 
    169        return iter(list(self._data.values())) 
    170 
    171    def __dir__(self) -> List[str]: 
    172        return dir(super()) + [str(k) for k in self._data.keys()] 
    173 
    174    def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]: 
    175        return list(self) + list(other) 
    176 
    177    def __setitem__(self, key: str, obj: _T) -> None: 
    178        self._data[key] = obj 
    179 
    180    def __getitem__(self, key: str) -> _T: 
    181        return self._data[key] 
    182 
    183    def __delitem__(self, key: str) -> None: 
    184        del self._data[key] 
    185 
    186    def __setattr__(self, key: str, obj: _T) -> None: 
    187        self._data[key] = obj 
    188 
    189    def __getstate__(self) -> Dict[str, Any]: 
    190        return {"_data": self._data} 
    191 
    192    def __setstate__(self, state: Dict[str, Any]) -> None: 
    193        object.__setattr__(self, "_data", state["_data"]) 
    194 
    195    def __getattr__(self, key: str) -> _T: 
    196        try: 
    197            return self._data[key] 
    198        except KeyError: 
    199            raise AttributeError(key) 
    200 
    201    def __contains__(self, key: str) -> bool: 
    202        return key in self._data 
    203 
    204    def as_readonly(self) -> ReadOnlyProperties[_T]: 
    205        """Return an immutable proxy for this :class:`.Properties`.""" 
    206 
    207        return ReadOnlyProperties(self._data) 
    208 
    209    def update(self, value: Dict[str, _T]) -> None: 
    210        self._data.update(value) 
    211 
    212    @overload 
    213    def get(self, key: str) -> Optional[_T]: ... 
    214 
    215    @overload 
    216    def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ... 
    217 
    218    def get( 
    219        self, key: str, default: Optional[Union[_DT, _T]] = None 
    220    ) -> Optional[Union[_T, _DT]]: 
    221        if key in self: 
    222            return self[key] 
    223        else: 
    224            return default 
    225 
    226    def keys(self) -> List[str]: 
    227        return list(self._data) 
    228 
    229    def values(self) -> List[_T]: 
    230        return list(self._data.values()) 
    231 
    232    def items(self) -> List[Tuple[str, _T]]: 
    233        return list(self._data.items()) 
    234 
    235    def has_key(self, key: str) -> bool: 
    236        return key in self._data 
    237 
    238    def clear(self) -> None: 
    239        self._data.clear() 
    240 
    241 
    242class OrderedProperties(Properties[_T]): 
    243    """Provide a __getattr__/__setattr__ interface with an OrderedDict 
    244    as backing store.""" 
    245 
    246    __slots__ = () 
    247 
    248    def __init__(self): 
    249        Properties.__init__(self, OrderedDict()) 
    250 
    251 
class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
    """Provide immutable dict/object attribute to an underlying dictionary.

    Combines ``ReadOnlyContainer`` with :class:`.Properties`; instances are
    created by ``Properties.as_readonly()`` and wrap the same dictionary.
    """

    # NOTE(review): the mutation-blocking behavior presumably comes from
    # ReadOnlyContainer (defined in _immutabledict_cy) -- confirm there.
    __slots__ = ()
    256 
    257 
    258def _ordered_dictionary_sort(d, key=None): 
    259    """Sort an OrderedDict in-place.""" 
    260 
    261    items = [(k, d[k]) for k in sorted(d, key=key)] 
    262 
    263    d.clear() 
    264 
    265    d.update(items) 
    266 
    267 
    268OrderedDict = dict 
    269sort_dictionary = _ordered_dictionary_sort 
    270 
    271 
    272class WeakSequence(Sequence[_T]): 
    273    def __init__(self, __elements: Sequence[_T] = (), /): 
    274        # adapted from weakref.WeakKeyDictionary, prevent reference 
    275        # cycles in the collection itself 
    276        def _remove(item, selfref=weakref.ref(self)): 
    277            self = selfref() 
    278            if self is not None: 
    279                self._storage.remove(item) 
    280 
    281        self._remove = _remove 
    282        self._storage = [ 
    283            weakref.ref(element, _remove) for element in __elements 
    284        ] 
    285 
    286    def append(self, item): 
    287        self._storage.append(weakref.ref(item, self._remove)) 
    288 
    289    def __len__(self): 
    290        return len(self._storage) 
    291 
    292    def __iter__(self): 
    293        return ( 
    294            obj for obj in (ref() for ref in self._storage) if obj is not None 
    295        ) 
    296 
    297    def __getitem__(self, index): 
    298        try: 
    299            obj = self._storage[index] 
    300        except KeyError: 
    301            raise IndexError("Index %s out of range" % index) 
    302        else: 
    303            return obj() 
    304 
    305 
# Alias: OrderedIdentitySet is the same class as IdentitySet.
OrderedIdentitySet = IdentitySet
    307 
    308 
    309class PopulateDict(Dict[_KT, _VT]): 
    310    """A dict which populates missing values via a creation function. 
    311 
    312    Note the creation function takes a key, unlike 
    313    collections.defaultdict. 
    314 
    315    """ 
    316 
    317    def __init__(self, creator: Callable[[_KT], _VT]): 
    318        self.creator = creator 
    319 
    320    def __missing__(self, key: Any) -> Any: 
    321        self[key] = val = self.creator(key) 
    322        return val 
    323 
    324 
    325class WeakPopulateDict(Dict[_KT, _VT]): 
    326    """Like PopulateDict, but assumes a self + a method and does not create 
    327    a reference cycle. 
    328 
    329    """ 
    330 
    331    def __init__(self, creator_method: types.MethodType): 
    332        self.creator = creator_method.__func__ 
    333        weakself = creator_method.__self__ 
    334        self.weakself = weakref.ref(weakself) 
    335 
    336    def __missing__(self, key: Any) -> Any: 
    337        self[key] = val = self.creator(self.weakself(), key) 
    338        return val 
    339 
    340 
# Collections capable of storing ColumnElement objects as hashable
# keys/elements.  These are now mostly historical aliases; things
# used to be more complicated.
column_set = set
column_dict = dict
ordered_column_set = OrderedSet
    348 
    349 
    350class UniqueAppender(Generic[_T]): 
    351    """Appends items to a collection ensuring uniqueness. 
    352 
    353    Additional appends() of the same object are ignored.  Membership is 
    354    determined by identity (``is a``) not equality (``==``). 
    355    """ 
    356 
    357    __slots__ = "data", "_data_appender", "_unique" 
    358 
    359    data: Union[Iterable[_T], Set[_T], List[_T]] 
    360    _data_appender: Callable[[_T], None] 
    361    _unique: Dict[int, Literal[True]] 
    362 
    363    def __init__( 
    364        self, 
    365        data: Union[Iterable[_T], Set[_T], List[_T]], 
    366        via: Optional[str] = None, 
    367    ): 
    368        self.data = data 
    369        self._unique = {} 
    370        if via: 
    371            self._data_appender = getattr(data, via) 
    372        elif hasattr(data, "append"): 
    373            self._data_appender = cast("List[_T]", data).append 
    374        elif hasattr(data, "add"): 
    375            self._data_appender = cast("Set[_T]", data).add 
    376 
    377    def append(self, item: _T) -> None: 
    378        id_ = id(item) 
    379        if id_ not in self._unique: 
    380            self._data_appender(item) 
    381            self._unique[id_] = True 
    382 
    383    def __iter__(self) -> Iterator[_T]: 
    384        return iter(self.data) 
    385 
    386 
    387def coerce_generator_arg(arg: Any) -> List[Any]: 
    388    if len(arg) == 1 and isinstance(arg[0], types.GeneratorType): 
    389        return list(arg[0]) 
    390    else: 
    391        return cast("List[Any]", arg) 
    392 
    393 
    394def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]: 
    395    if x is None: 
    396        return default  # type: ignore 
    397    if not is_non_string_iterable(x): 
    398        return [x] 
    399    elif isinstance(x, list): 
    400        return x 
    401    else: 
    402        return list(x) 
    403 
    404 
    405def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool: 
    406    r"""return True if any items of set\_ are present in iterable. 
    407 
    408    Goes through special effort to ensure __hash__ is not called 
    409    on items in iterable that don't support it. 
    410 
    411    """ 
    412    return any(i in set_ for i in iterable if i.__hash__) 
    413 
    414 
    415def to_set(x): 
    416    if x is None: 
    417        return set() 
    418    if not isinstance(x, set): 
    419        return set(to_list(x)) 
    420    else: 
    421        return x 
    422 
    423 
    424def to_column_set(x: Any) -> Set[Any]: 
    425    if x is None: 
    426        return column_set() 
    427    if not isinstance(x, column_set): 
    428        return column_set(to_list(x)) 
    429    else: 
    430        return x 
    431 
    432 
    433def update_copy( 
    434    d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any 
    435) -> Dict[Any, Any]: 
    436    """Copy the given dict and update with the given values.""" 
    437 
    438    d = d.copy() 
    439    if _new: 
    440        d.update(_new) 
    441    d.update(**kw) 
    442    return d 
    443 
    444 
    445def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]: 
    446    """Given an iterator of which further sub-elements may also be 
    447    iterators, flatten the sub-elements into a single iterator. 
    448 
    449    """ 
    450    elem: _T 
    451    for elem in x: 
    452        if not isinstance(elem, str) and hasattr(elem, "__iter__"): 
    453            yield from flatten_iterator(elem) 
    454        else: 
    455            yield elem 
    456 
    457 
    458class LRUCache(typing.MutableMapping[_KT, _VT]): 
    459    """Dictionary with 'squishy' removal of least 
    460    recently used items. 
    461 
    462    Note that either get() or [] should be used here, but 
    463    generally its not safe to do an "in" check first as the dictionary 
    464    can change subsequent to that call. 
    465 
    466    """ 
    467 
    468    __slots__ = ( 
    469        "capacity", 
    470        "threshold", 
    471        "size_alert", 
    472        "_data", 
    473        "_counter", 
    474        "_mutex", 
    475    ) 
    476 
    477    capacity: int 
    478    threshold: float 
    479    size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]] 
    480 
    481    def __init__( 
    482        self, 
    483        capacity: int = 100, 
    484        threshold: float = 0.5, 
    485        size_alert: Optional[Callable[..., None]] = None, 
    486    ): 
    487        self.capacity = capacity 
    488        self.threshold = threshold 
    489        self.size_alert = size_alert 
    490        self._counter = 0 
    491        self._mutex = threading.Lock() 
    492        self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {} 
    493 
    494    def _inc_counter(self): 
    495        self._counter += 1 
    496        return self._counter 
    497 
    498    @overload 
    499    def get(self, key: _KT) -> Optional[_VT]: ... 
    500 
    501    @overload 
    502    def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ... 
    503 
    504    def get( 
    505        self, key: _KT, default: Optional[Union[_VT, _T]] = None 
    506    ) -> Optional[Union[_VT, _T]]: 
    507        item = self._data.get(key) 
    508        if item is not None: 
    509            item[2][0] = self._inc_counter() 
    510            return item[1] 
    511        else: 
    512            return default 
    513 
    514    def __getitem__(self, key: _KT) -> _VT: 
    515        item = self._data[key] 
    516        item[2][0] = self._inc_counter() 
    517        return item[1] 
    518 
    519    def __iter__(self) -> Iterator[_KT]: 
    520        return iter(self._data) 
    521 
    522    def __len__(self) -> int: 
    523        return len(self._data) 
    524 
    525    def values(self) -> ValuesView[_VT]: 
    526        return typing.ValuesView({k: i[1] for k, i in self._data.items()}) 
    527 
    528    def __setitem__(self, key: _KT, value: _VT) -> None: 
    529        self._data[key] = (key, value, [self._inc_counter()]) 
    530        self._manage_size() 
    531 
    532    def __delitem__(self, __v: _KT) -> None: 
    533        del self._data[__v] 
    534 
    535    @property 
    536    def size_threshold(self) -> float: 
    537        return self.capacity + self.capacity * self.threshold 
    538 
    539    def _manage_size(self) -> None: 
    540        if not self._mutex.acquire(False): 
    541            return 
    542        try: 
    543            size_alert = bool(self.size_alert) 
    544            while len(self) > self.capacity + self.capacity * self.threshold: 
    545                if size_alert: 
    546                    size_alert = False 
    547                    self.size_alert(self)  # type: ignore 
    548                by_counter = sorted( 
    549                    self._data.values(), 
    550                    key=operator.itemgetter(2), 
    551                    reverse=True, 
    552                ) 
    553                for item in by_counter[self.capacity :]: 
    554                    try: 
    555                        del self._data[item[0]] 
    556                    except KeyError: 
    557                        # deleted elsewhere; skip 
    558                        continue 
    559        finally: 
    560            self._mutex.release() 
    561 
    562 
class _CreateFuncType(Protocol[_T_co]):
    """Protocol for a zero-argument callable returning a ``_T_co``."""

    def __call__(self) -> _T_co: ...
    565 
    566 
class _ScopeFuncType(Protocol):
    """Protocol for a zero-argument callable returning any value."""

    def __call__(self) -> Any: ...
    569 
    570 
    571class ScopedRegistry(Generic[_T]): 
    572    """A Registry that can store one or multiple instances of a single 
    573    class on the basis of a "scope" function. 
    574 
    575    The object implements ``__call__`` as the "getter", so by 
    576    calling ``myregistry()`` the contained object is returned 
    577    for the current scope. 
    578 
    579    :param createfunc: 
    580      a callable that returns a new object to be placed in the registry 
    581 
    582    :param scopefunc: 
    583      a callable that will return a key to store/retrieve an object. 
    584    """ 
    585 
    586    __slots__ = "createfunc", "scopefunc", "registry" 
    587 
    588    createfunc: _CreateFuncType[_T] 
    589    scopefunc: _ScopeFuncType 
    590    registry: Any 
    591 
    592    def __init__( 
    593        self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any] 
    594    ): 
    595        """Construct a new :class:`.ScopedRegistry`. 
    596 
    597        :param createfunc:  A creation function that will generate 
    598          a new value for the current scope, if none is present. 
    599 
    600        :param scopefunc:  A function that returns a hashable 
    601          token representing the current scope (such as, current 
    602          thread identifier). 
    603 
    604        """ 
    605        self.createfunc = createfunc 
    606        self.scopefunc = scopefunc 
    607        self.registry = {} 
    608 
    609    def __call__(self) -> _T: 
    610        key = self.scopefunc() 
    611        try: 
    612            return self.registry[key]  # type: ignore[no-any-return] 
    613        except KeyError: 
    614            return self.registry.setdefault(key, self.createfunc())  # type: ignore[no-any-return] # noqa: E501 
    615 
    616    def has(self) -> bool: 
    617        """Return True if an object is present in the current scope.""" 
    618 
    619        return self.scopefunc() in self.registry 
    620 
    621    def set(self, obj: _T) -> None: 
    622        """Set the value for the current scope.""" 
    623 
    624        self.registry[self.scopefunc()] = obj 
    625 
    626    def clear(self) -> None: 
    627        """Clear the current scope, if any.""" 
    628 
    629        try: 
    630            del self.registry[self.scopefunc()] 
    631        except KeyError: 
    632            pass 
    633 
    634 
    635class ThreadLocalRegistry(ScopedRegistry[_T]): 
    636    """A :class:`.ScopedRegistry` that uses a ``threading.local()`` 
    637    variable for storage. 
    638 
    639    """ 
    640 
    641    def __init__(self, createfunc: Callable[[], _T]): 
    642        self.createfunc = createfunc 
    643        self.registry = threading.local() 
    644 
    645    def __call__(self) -> _T: 
    646        try: 
    647            return self.registry.value  # type: ignore[no-any-return] 
    648        except AttributeError: 
    649            val = self.registry.value = self.createfunc() 
    650            return val 
    651 
    652    def has(self) -> bool: 
    653        return hasattr(self.registry, "value") 
    654 
    655    def set(self, obj: _T) -> None: 
    656        self.registry.value = obj 
    657 
    658    def clear(self) -> None: 
    659        try: 
    660            del self.registry.value 
    661        except AttributeError: 
    662            pass 
    663 
    664 
    665def has_dupes(sequence, target): 
    666    """Given a sequence and search object, return True if there's more 
    667    than one, False if zero or one of them. 
    668 
    669 
    670    """ 
    671    # compare to .index version below, this version introduces less function 
    672    # overhead and is usually the same speed.  At 15000 items (way bigger than 
    673    # a relationship-bound collection in memory usually is) it begins to 
    674    # fall behind the other version only by microseconds. 
    675    c = 0 
    676    for item in sequence: 
    677        if item is target: 
    678            c += 1 
    679            if c > 1: 
    680                return True 
    681    return False 
    682 
    683 
    684# .index version.  the two __contains__ calls as well 
    685# as .index() and isinstance() slow this down. 
    686# def has_dupes(sequence, target): 
    687#    if target not in sequence: 
    688#        return False 
    689#    elif not isinstance(sequence, collections_abc.Sequence): 
    690#        return False 
    691# 
    692#    idx = sequence.index(target) 
    693#    return target in sequence[idx + 1:]