1# orm/clsregistry.py 
    2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: https://www.opensource.org/licenses/mit-license.php 
    7 
    8"""Routines to handle the string class registry used by declarative. 
    9 
    10This system allows specification of classes and expressions used in 
    11:func:`_orm.relationship` using strings. 
    12 
    13""" 
    14 
    15from __future__ import annotations 
    16 
    17import re 
    18from typing import Any 
    19from typing import Callable 
    20from typing import cast 
    21from typing import Dict 
    22from typing import Generator 
    23from typing import Iterable 
    24from typing import List 
    25from typing import Mapping 
    26from typing import MutableMapping 
    27from typing import NoReturn 
    28from typing import Optional 
    29from typing import Set 
    30from typing import Tuple 
    31from typing import Type 
    32from typing import TYPE_CHECKING 
    33from typing import TypeVar 
    34from typing import Union 
    35import weakref 
    36 
    37from . import attributes 
    38from . import interfaces 
    39from .descriptor_props import SynonymProperty 
    40from .properties import ColumnProperty 
    41from .util import class_mapper 
    42from .. import exc 
    43from .. import inspection 
    44from .. import util 
    45from ..sql.schema import _get_table_key 
    46from ..util.typing import CallableReference 
    47 
    48if TYPE_CHECKING: 
    49    from .relationships import RelationshipProperty 
    50    from ..sql.schema import MetaData 
    51    from ..sql.schema import Table 
    52 
# generic type variable; used below to type the class handed to _add_class()
_T = TypeVar("_T", bound=Any)

# shape of the string-keyed class registry handled by this module: each
# value is either a class or a _ClsRegistryToken marker object
_ClsRegistryType = MutableMapping[str, Union[type, "_ClsRegistryToken"]]

# strong references to registries which we place in
# the _decl_class_registry, which is usually weak referencing.
# the internal registries here link to classes with weakrefs and remove
# themselves when all references to contained classes are removed.
_registries: Set[_ClsRegistryToken] = set()
    62 
    63 
    64def _add_class( 
    65    classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType 
    66) -> None: 
    67    """Add a class to the _decl_class_registry associated with the 
    68    given declarative class. 
    69 
    70    """ 
    71    if classname in decl_class_registry: 
    72        # class already exists. 
    73        existing = decl_class_registry[classname] 
    74        if not isinstance(existing, _MultipleClassMarker): 
    75            decl_class_registry[classname] = _MultipleClassMarker( 
    76                [cls, cast("Type[Any]", existing)] 
    77            ) 
    78    else: 
    79        decl_class_registry[classname] = cls 
    80 
    81    try: 
    82        root_module = cast( 
    83            _ModuleMarker, decl_class_registry["_sa_module_registry"] 
    84        ) 
    85    except KeyError: 
    86        decl_class_registry["_sa_module_registry"] = root_module = ( 
    87            _ModuleMarker("_sa_module_registry", None) 
    88        ) 
    89 
    90    tokens = cls.__module__.split(".") 
    91 
    92    # build up a tree like this: 
    93    # modulename:  myapp.snacks.nuts 
    94    # 
    95    # myapp->snack->nuts->(classes) 
    96    # snack->nuts->(classes) 
    97    # nuts->(classes) 
    98    # 
    99    # this allows partial token paths to be used. 
    100    while tokens: 
    101        token = tokens.pop(0) 
    102        module = root_module.get_module(token) 
    103        for token in tokens: 
    104            module = module.get_module(token) 
    105 
    106        try: 
    107            module.add_class(classname, cls) 
    108        except AttributeError as ae: 
    109            if not isinstance(module, _ModuleMarker): 
    110                raise exc.InvalidRequestError( 
    111                    f'name "{classname}" matches both a ' 
    112                    "class name and a module name" 
    113                ) from ae 
    114            else: 
    115                raise 
    116 
    117 
    118def _remove_class( 
    119    classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType 
    120) -> None: 
    121    if classname in decl_class_registry: 
    122        existing = decl_class_registry[classname] 
    123        if isinstance(existing, _MultipleClassMarker): 
    124            existing.remove_item(cls) 
    125        else: 
    126            del decl_class_registry[classname] 
    127 
    128    try: 
    129        root_module = cast( 
    130            _ModuleMarker, decl_class_registry["_sa_module_registry"] 
    131        ) 
    132    except KeyError: 
    133        return 
    134 
    135    tokens = cls.__module__.split(".") 
    136 
    137    while tokens: 
    138        token = tokens.pop(0) 
    139        module = root_module.get_module(token) 
    140        for token in tokens: 
    141            module = module.get_module(token) 
    142        try: 
    143            module.remove_class(classname, cls) 
    144        except AttributeError: 
    145            if not isinstance(module, _ModuleMarker): 
    146                pass 
    147            else: 
    148                raise 
    149 
    150 
    151def _key_is_empty( 
    152    key: str, 
    153    decl_class_registry: _ClsRegistryType, 
    154    test: Callable[[Any], bool], 
    155) -> bool: 
    156    """test if a key is empty of a certain object. 
    157 
    158    used for unit tests against the registry to see if garbage collection 
    159    is working. 
    160 
    161    "test" is a callable that will be passed an object should return True 
    162    if the given object is the one we were looking for. 
    163 
    164    We can't pass the actual object itself b.c. this is for testing garbage 
    165    collection; the caller will have to have removed references to the 
    166    object itself. 
    167 
    168    """ 
    169    if key not in decl_class_registry: 
    170        return True 
    171 
    172    thing = decl_class_registry[key] 
    173    if isinstance(thing, _MultipleClassMarker): 
    174        for sub_thing in thing.contents: 
    175            if test(sub_thing): 
    176                return False 
    177        else: 
    178            raise NotImplementedError("unknown codepath") 
    179    else: 
    180        return not test(thing) 
    181 
    182 
class _ClsRegistryToken:
    """an object that can be in the registry._class_registry as a value.

    Serves as the common base of the marker classes defined in this
    module (:class:`._MultipleClassMarker` and :class:`._ModuleMarker`).

    """

    # declares no instance attributes of its own; subclasses list theirs
    __slots__ = ()
    187 
    188 
    189class _MultipleClassMarker(_ClsRegistryToken): 
    190    """refers to multiple classes of the same name 
    191    within _decl_class_registry. 
    192 
    193    """ 
    194 
    195    __slots__ = "on_remove", "contents", "__weakref__" 
    196 
    197    contents: Set[weakref.ref[Type[Any]]] 
    198    on_remove: CallableReference[Optional[Callable[[], None]]] 
    199 
    200    def __init__( 
    201        self, 
    202        classes: Iterable[Type[Any]], 
    203        on_remove: Optional[Callable[[], None]] = None, 
    204    ): 
    205        self.on_remove = on_remove 
    206        self.contents = { 
    207            weakref.ref(item, self._remove_item) for item in classes 
    208        } 
    209        _registries.add(self) 
    210 
    211    def remove_item(self, cls: Type[Any]) -> None: 
    212        self._remove_item(weakref.ref(cls)) 
    213 
    214    def __iter__(self) -> Generator[Optional[Type[Any]], None, None]: 
    215        return (ref() for ref in self.contents) 
    216 
    217    def attempt_get(self, path: List[str], key: str) -> Type[Any]: 
    218        if len(self.contents) > 1: 
    219            raise exc.InvalidRequestError( 
    220                'Multiple classes found for path "%s" ' 
    221                "in the registry of this declarative " 
    222                "base. Please use a fully module-qualified path." 
    223                % (".".join(path + [key])) 
    224            ) 
    225        else: 
    226            ref = list(self.contents)[0] 
    227            cls = ref() 
    228            if cls is None: 
    229                raise NameError(key) 
    230            return cls 
    231 
    232    def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None: 
    233        self.contents.discard(ref) 
    234        if not self.contents: 
    235            _registries.discard(self) 
    236            if self.on_remove: 
    237                self.on_remove() 
    238 
    239    def add_item(self, item: Type[Any]) -> None: 
    240        # protect against class registration race condition against 
    241        # asynchronous garbage collection calling _remove_item, 
    242        # [ticket:3208] and [ticket:10782] 
    243        modules = { 
    244            cls.__module__ 
    245            for cls in [ref() for ref in list(self.contents)] 
    246            if cls is not None 
    247        } 
    248        if item.__module__ in modules: 
    249            util.warn( 
    250                "This declarative base already contains a class with the " 
    251                "same class name and module name as %s.%s, and will " 
    252                "be replaced in the string-lookup table." 
    253                % (item.__module__, item.__name__) 
    254            ) 
    255        self.contents.add(weakref.ref(item, self._remove_item)) 
    256 
    257 
    258class _ModuleMarker(_ClsRegistryToken): 
    259    """Refers to a module name within 
    260    _decl_class_registry. 
    261 
    262    """ 
    263 
    264    __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__" 
    265 
    266    parent: Optional[_ModuleMarker] 
    267    contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]] 
    268    mod_ns: _ModNS 
    269    path: List[str] 
    270 
    271    def __init__(self, name: str, parent: Optional[_ModuleMarker]): 
    272        self.parent = parent 
    273        self.name = name 
    274        self.contents = {} 
    275        self.mod_ns = _ModNS(self) 
    276        if self.parent: 
    277            self.path = self.parent.path + [self.name] 
    278        else: 
    279            self.path = [] 
    280        _registries.add(self) 
    281 
    282    def __contains__(self, name: str) -> bool: 
    283        return name in self.contents 
    284 
    285    def __getitem__(self, name: str) -> _ClsRegistryToken: 
    286        return self.contents[name] 
    287 
    288    def _remove_item(self, name: str) -> None: 
    289        self.contents.pop(name, None) 
    290        if not self.contents: 
    291            if self.parent is not None: 
    292                self.parent._remove_item(self.name) 
    293            _registries.discard(self) 
    294 
    295    def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]: 
    296        return self.mod_ns.__getattr__(key) 
    297 
    298    def get_module(self, name: str) -> _ModuleMarker: 
    299        if name not in self.contents: 
    300            marker = _ModuleMarker(name, self) 
    301            self.contents[name] = marker 
    302        else: 
    303            marker = cast(_ModuleMarker, self.contents[name]) 
    304        return marker 
    305 
    306    def add_class(self, name: str, cls: Type[Any]) -> None: 
    307        if name in self.contents: 
    308            existing = cast(_MultipleClassMarker, self.contents[name]) 
    309            try: 
    310                existing.add_item(cls) 
    311            except AttributeError as ae: 
    312                if not isinstance(existing, _MultipleClassMarker): 
    313                    raise exc.InvalidRequestError( 
    314                        f'name "{name}" matches both a ' 
    315                        "class name and a module name" 
    316                    ) from ae 
    317                else: 
    318                    raise 
    319        else: 
    320            self.contents[name] = _MultipleClassMarker( 
    321                [cls], on_remove=lambda: self._remove_item(name) 
    322            ) 
    323 
    324    def remove_class(self, name: str, cls: Type[Any]) -> None: 
    325        if name in self.contents: 
    326            existing = cast(_MultipleClassMarker, self.contents[name]) 
    327            existing.remove_item(cls) 
    328 
    329 
    330class _ModNS: 
    331    __slots__ = ("__parent",) 
    332 
    333    __parent: _ModuleMarker 
    334 
    335    def __init__(self, parent: _ModuleMarker): 
    336        self.__parent = parent 
    337 
    338    def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]: 
    339        try: 
    340            value = self.__parent.contents[key] 
    341        except KeyError: 
    342            pass 
    343        else: 
    344            if value is not None: 
    345                if isinstance(value, _ModuleMarker): 
    346                    return value.mod_ns 
    347                else: 
    348                    assert isinstance(value, _MultipleClassMarker) 
    349                    return value.attempt_get(self.__parent.path, key) 
    350        raise NameError( 
    351            "Module %r has no mapped classes " 
    352            "registered under the name %r" % (self.__parent.name, key) 
    353        ) 
    354 
    355 
    356class _GetColumns: 
    357    __slots__ = ("cls",) 
    358 
    359    cls: Type[Any] 
    360 
    361    def __init__(self, cls: Type[Any]): 
    362        self.cls = cls 
    363 
    364    def __getattr__(self, key: str) -> Any: 
    365        mp = class_mapper(self.cls, configure=False) 
    366        if mp: 
    367            if key not in mp.all_orm_descriptors: 
    368                raise AttributeError( 
    369                    "Class %r does not have a mapped column named %r" 
    370                    % (self.cls, key) 
    371                ) 
    372 
    373            desc = mp.all_orm_descriptors[key] 
    374            if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION: 
    375                assert isinstance(desc, attributes.QueryableAttribute) 
    376                prop = desc.property 
    377                if isinstance(prop, SynonymProperty): 
    378                    key = prop.name 
    379                elif not isinstance(prop, ColumnProperty): 
    380                    raise exc.InvalidRequestError( 
    381                        "Property %r is not an instance of" 
    382                        " ColumnProperty (i.e. does not correspond" 
    383                        " directly to a Column)." % key 
    384                    ) 
    385        return getattr(self.cls, key) 
    386 
    387 
# hook _GetColumns into the inspection system; the registered callable
# hands inspection of a _GetColumns instance over to the class it wraps
inspection._inspects(_GetColumns)(
    lambda target: inspection.inspect(target.cls)
)
    391 
    392 
    393class _GetTable: 
    394    __slots__ = "key", "metadata" 
    395 
    396    key: str 
    397    metadata: MetaData 
    398 
    399    def __init__(self, key: str, metadata: MetaData): 
    400        self.key = key 
    401        self.metadata = metadata 
    402 
    403    def __getattr__(self, key: str) -> Table: 
    404        return self.metadata.tables[_get_table_key(key, self.key)] 
    405 
    406 
    407def _determine_container(key: str, value: Any) -> _GetColumns: 
    408    if isinstance(value, _MultipleClassMarker): 
    409        value = value.attempt_get([], key) 
    410    return _GetColumns(value) 
    411 
    412 
class _class_resolver:
    """Resolve a string argument into the object(s) it names.

    An instance holds the string ``arg`` together with the class and
    property it belongs to.  Individual names are looked up through
    ``_access_cls()``, which is installed as the creation function of
    ``self._dict`` (a ``util.PopulateDict``).

    Two entry points are provided: ``_resolve_name()`` walks a dotted
    path token by token, while calling the instance evaluates ``arg``,
    either as a single key (``tables_only``) or as an expression.

    """

    __slots__ = (
        "cls",
        "prop",
        "arg",
        "fallback",
        "_dict",
        "_resolvers",
        "tables_only",
    )

    cls: Type[Any]
    prop: RelationshipProperty[Any]
    fallback: Mapping[str, Any]
    arg: str
    tables_only: bool
    _resolvers: Tuple[Callable[[str], Any], ...]

    def __init__(
        self,
        cls: Type[Any],
        prop: RelationshipProperty[Any],
        fallback: Mapping[str, Any],
        arg: str,
        tables_only: bool = False,
    ):
        """Set up a resolver for the string ``arg``.

        :param cls: class whose registry and metadata are consulted.
        :param prop: the property being configured; used in error
         messages.
        :param fallback: mapping consulted last when a name is not found
         anywhere else.
        :param arg: the string to resolve.
        :param tables_only: when True, table names take precedence over
         class names and ``__call__`` treats ``arg`` as a single key.
        """
        self.cls = cls
        self.prop = prop
        self.arg = arg
        self.fallback = fallback
        # presumably calls _access_cls for keys not yet present -- verify
        # against util.PopulateDict
        self._dict = util.PopulateDict(self._access_cls)
        # additional name resolvers; empty unless assigned from outside
        self._resolvers = ()
        self.tables_only = tables_only

    def _access_cls(self, key: str) -> Any:
        """Locate the object registered under the single name ``key``.

        Sources are tried in this order:

        1. when ``tables_only`` is set: ``metadata.tables``, then
           ``metadata._schemas`` (yielding a :class:`._GetTable`)
        2. the class registry (yielding a :class:`._GetColumns`, or the
           class itself when ``tables_only`` is set)
        3. when ``tables_only`` is not set: ``metadata.tables``, then
           ``metadata._schemas``
        4. the ``"_sa_module_registry"`` module tree
        5. each callable in ``self._resolvers``, first non-None result
        6. ``self.fallback[key]``

        """
        cls = self.cls

        manager = attributes.manager_of_class(cls)
        decl_base = manager.registry
        assert decl_base is not None
        decl_class_registry = decl_base._class_registry
        metadata = decl_base.metadata

        # tables first when only tables are wanted
        if self.tables_only:
            if key in metadata.tables:
                return metadata.tables[key]
            elif key in metadata._schemas:
                # prefers a "metadata" attribute on the class itself, if
                # there is one, over the registry's metadata
                return _GetTable(key, getattr(cls, "metadata", metadata))

        if key in decl_class_registry:
            dt = _determine_container(key, decl_class_registry[key])
            if self.tables_only:
                return dt.cls
            else:
                return dt

        # otherwise classes were tried first, tables come second
        if not self.tables_only:
            if key in metadata.tables:
                return metadata.tables[key]
            elif key in metadata._schemas:
                return _GetTable(key, getattr(cls, "metadata", metadata))

        # key may be the first token of a module path
        if "_sa_module_registry" in decl_class_registry and key in cast(
            _ModuleMarker, decl_class_registry["_sa_module_registry"]
        ):
            registry = cast(
                _ModuleMarker, decl_class_registry["_sa_module_registry"]
            )
            return registry.resolve_attr(key)

        if self._resolvers:
            for resolv in self._resolvers:
                value = resolv(key)
                if value is not None:
                    return value

        # last resort; a lookup failure here propagates to the caller
        return self.fallback[key]

    def _raise_for_name(self, name: str, err: Exception) -> NoReturn:
        """Raise ``InvalidRequestError`` for a name that failed to
        resolve, chained from ``err``.

        A name of the form ``something[...]`` gets a dedicated message
        about generic classes; any other name gets the general
        "failed to locate a name" message.

        """
        # re.match anchors at the start: one or more characters, then a
        # bracketed section
        generic_match = re.match(r"(.+)\[(.+)\]", name)

        if generic_match:
            # drop single quotes surrounding the bracketed argument
            clsarg = generic_match.group(2).strip("'")
            raise exc.InvalidRequestError(
                f"When initializing mapper {self.prop.parent}, "
                f'expression "relationship({self.arg!r})" seems to be '
                "using a generic class as the argument to relationship(); "
                "please state the generic argument "
                "using an annotation, e.g. "
                f'"{self.prop.key}: Mapped[{generic_match.group(1)}'
                f"['{clsarg}']] = relationship()\""
            ) from err
        else:
            raise exc.InvalidRequestError(
                "When initializing mapper %s, expression %r failed to "
                "locate a name (%r). If this is a class name, consider "
                "adding this relationship() to the %r class after "
                "both dependent classes have been defined."
                % (self.prop.parent, self.arg, name, self.cls)
            ) from err

    def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]:
        """Resolve ``self.arg`` as a dotted name, without ``eval()``.

        The first token is looked up in ``self._dict``; every following
        token is fetched as an attribute of the previous result.  A
        :class:`._GetColumns` result is unwrapped to its class.

        :raises InvalidRequestError: via ``_raise_for_name()`` when a
         ``KeyError`` or ``NameError`` occurs during the lookup.
        """
        name = self.arg
        d = self._dict
        rval = None
        try:
            for token in name.split("."):
                if rval is None:
                    rval = d[token]
                else:
                    rval = getattr(rval, token)
        except KeyError as err:
            self._raise_for_name(name, err)
        except NameError as n:
            # report the name carried by the NameError itself
            self._raise_for_name(n.args[0], n)
        else:
            if isinstance(rval, _GetColumns):
                return rval.cls
            else:
                # the assertion exists for the type checker only and is
                # not executed at runtime
                if TYPE_CHECKING:
                    assert isinstance(rval, (type, Table, _ModNS))
                return rval

    def __call__(self) -> Any:
        """Resolve ``self.arg`` and return the result.

        With ``tables_only`` the whole string is used as one key into
        ``self._dict``.  Otherwise the string is evaluated as a Python
        expression with ``self._dict`` as the local namespace, and a
        :class:`._GetColumns` result is unwrapped to its class.

        :raises InvalidRequestError: via ``_raise_for_name()`` when the
         key is missing or the expression raises ``NameError``.
        """
        if self.tables_only:
            try:
                return self._dict[self.arg]
            except KeyError as k:
                self._raise_for_name(self.arg, k)
        else:
            try:
                # NOTE(review): eval() runs arbitrary Python; presumably
                # self.arg only ever comes from developer-written
                # relationship() arguments -- confirm it can never be
                # built from untrusted input
                x = eval(self.arg, globals(), self._dict)

                if isinstance(x, _GetColumns):
                    return x.cls
                else:
                    return x
            except NameError as n:
                self._raise_for_name(n.args[0], n)
    552 
    553 
# names made available to _class_resolver as a last resort; starts out as
# None and is built on the first call to _resolver()
_fallback_dict: Mapping[str, Any] = None  # type: ignore
    555 
    556 
    557def _resolver(cls: Type[Any], prop: RelationshipProperty[Any]) -> Tuple[ 
    558    Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]], 
    559    Callable[[str, bool], _class_resolver], 
    560]: 
    561    global _fallback_dict 
    562 
    563    if _fallback_dict is None: 
    564        import sqlalchemy 
    565        from . import foreign 
    566        from . import remote 
    567 
    568        _fallback_dict = util.immutabledict(sqlalchemy.__dict__).union( 
    569            {"foreign": foreign, "remote": remote} 
    570        ) 
    571 
    572    def resolve_arg(arg: str, tables_only: bool = False) -> _class_resolver: 
    573        return _class_resolver( 
    574            cls, prop, _fallback_dict, arg, tables_only=tables_only 
    575        ) 
    576 
    577    def resolve_name( 
    578        arg: str, 
    579    ) -> Callable[[], Union[Type[Any], Table, _ModNS]]: 
    580        return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name 
    581 
    582    return resolve_name, resolve_arg