1# orm/clsregistry.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""Routines to handle the string class registry used by declarative.
9
10This system allows specification of classes and expressions used in
11:func:`_orm.relationship` using strings.
12
13"""
14
15from __future__ import annotations
16
17import re
18from typing import Any
19from typing import Callable
20from typing import cast
21from typing import Dict
22from typing import Generator
23from typing import Iterable
24from typing import List
25from typing import Mapping
26from typing import MutableMapping
27from typing import NoReturn
28from typing import Optional
29from typing import Set
30from typing import Tuple
31from typing import Type
32from typing import TYPE_CHECKING
33from typing import TypeVar
34from typing import Union
35import weakref
36
37from . import attributes
38from . import interfaces
39from .descriptor_props import SynonymProperty
40from .properties import ColumnProperty
41from .util import _metadata_for_cls
42from .util import class_mapper
43from .. import exc
44from .. import inspection
45from .. import util
46from ..sql.schema import _get_table_key
47from ..util.typing import CallableReference
48
49if TYPE_CHECKING:
50 from .relationships import RelationshipProperty
51 from ..sql.schema import MetaData
52 from ..sql.schema import Table
53
# generic type variable used to annotate the class passed to _add_class()
_T = TypeVar("_T", bound=Any)

# type of the class registry mapping: each string name maps either
# directly to a class or to a _ClsRegistryToken marker object.
_ClsRegistryType = MutableMapping[str, Union[type, "_ClsRegistryToken"]]

# strong references to registries which we place in
# the _decl_class_registry, which is usually weak referencing.
# the internal registries here link to classes with weakrefs and remove
# themselves when all references to contained classes are removed.
_registries: Set[_ClsRegistryToken] = set()
63
64
def _add_class(
    classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType
) -> None:
    """Register ``cls`` under ``classname`` in the given class registry.

    The class is recorded at the top level of ``decl_class_registry``
    under its plain name, and additionally inside the module tree rooted
    at the ``"_sa_module_registry"`` entry, which is created on first use.

    :param classname: string name the class is registered under.
    :param cls: the class being registered; its ``__module__`` determines
     where it is placed in the module tree.
    :param decl_class_registry: the registry mapping to populate.
    :raises sqlalchemy.exc.InvalidRequestError: when ``classname`` matches
     both a class name and a module name within the module tree.

    """
    if classname not in decl_class_registry:
        decl_class_registry[classname] = cls
    else:
        # the name is already taken.  a plain class entry is promoted to
        # a multi-class marker holding the new and the previous class; an
        # entry that already is a marker is left untouched here.
        previous = decl_class_registry[classname]
        if not isinstance(previous, _MultipleClassMarker):
            decl_class_registry[classname] = _MultipleClassMarker(
                [cls, cast("Type[Any]", previous)]
            )

    try:
        root = cast(
            _ModuleMarker, decl_class_registry["_sa_module_registry"]
        )
    except KeyError:
        root = _ModuleMarker("_sa_module_registry", None)
        decl_class_registry["_sa_module_registry"] = root

    path = cls.__module__.split(".")

    # register the class under every suffix of its module path, e.g. for
    # module "myapp.snacks.nuts":
    #
    #   myapp->snacks->nuts->(classes)
    #   snacks->nuts->(classes)
    #   nuts->(classes)
    #
    # so that partial token paths can be used for lookup.
    for start in range(len(path)):
        node = root
        for part in path[start:]:
            node = node.get_module(part)

        try:
            node.add_class(classname, cls)
        except AttributeError as ae:
            # only translate the error when the object found at this path
            # is not actually a module marker
            if isinstance(node, _ModuleMarker):
                raise
            raise exc.InvalidRequestError(
                f'name "{classname}" matches both a '
                "class name and a module name"
            ) from ae
117
118
def _remove_class(
    classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType
) -> None:
    """Remove ``cls`` from the given class registry.

    The top-level entry for ``classname`` is either deleted (plain class)
    or has ``cls`` removed from it (multi-class marker).  The class is
    then removed from every suffix path of its module inside the module
    tree, if the registry has one.

    :param classname: string name the class was registered under.
    :param cls: the class to remove.
    :param decl_class_registry: the registry mapping to clean up.

    """
    if classname in decl_class_registry:
        entry = decl_class_registry[classname]
        if not isinstance(entry, _MultipleClassMarker):
            del decl_class_registry[classname]
        else:
            entry.remove_item(cls)

    try:
        root = cast(
            _ModuleMarker, decl_class_registry["_sa_module_registry"]
        )
    except KeyError:
        # no module tree present in this registry; nothing more to do
        return

    path = cls.__module__.split(".")

    # NOTE(review): the walk uses get_module(), which creates a marker
    # for any path element that is not present yet.
    for start in range(len(path)):
        node = root
        for part in path[start:]:
            node = node.get_module(part)
        try:
            node.remove_class(classname, cls)
        except AttributeError:
            # tolerated only when the object at this path is not a module
            # marker; a genuine module marker failing is re-raised
            if isinstance(node, _ModuleMarker):
                raise
150
151
def _key_is_empty(
    key: str,
    decl_class_registry: _ClsRegistryType,
    test: Callable[[Any], bool],
) -> bool:
    """Report whether ``key`` no longer holds a particular object.

    Used by unit tests against the registry to verify that garbage
    collection is working.

    ``test`` is a callable which receives a candidate object and returns
    True if it is the one being looked for.  The object itself cannot be
    passed in, since the caller must already have dropped its references
    to it for garbage collection to be tested.

    :return: True if ``key`` is absent or its plain value does not
     satisfy ``test``; False if a matching object is found.
    :raises NotImplementedError: if ``key`` holds a multi-class marker
     none of whose contents satisfy ``test``.

    """
    if key not in decl_class_registry:
        return True

    entry = decl_class_registry[key]
    if not isinstance(entry, _MultipleClassMarker):
        return not test(entry)

    # the members of ``contents`` themselves are what get handed to test()
    if any(test(member) for member in entry.contents):
        return False
    raise NotImplementedError("unknown codepath")
182
183
class _ClsRegistryToken:
    """an object that can be in the registry._class_registry as a value.

    Acts as the common base of the marker classes in this module and
    declares no instance attributes of its own.

    """

    __slots__ = ()
188
189
class _MultipleClassMarker(_ClsRegistryToken):
    """refers to multiple classes of the same name
    within _decl_class_registry.

    Member classes are held through weak references.  Once the last
    reference has been discarded, the marker removes itself from the
    module-level ``_registries`` set and invokes ``on_remove`` if one
    was given.

    """

    __slots__ = "on_remove", "contents", "__weakref__"

    contents: Set[weakref.ref[Type[Any]]]
    on_remove: CallableReference[Optional[Callable[[], None]]]

    def __init__(
        self,
        classes: Iterable[Type[Any]],
        on_remove: Optional[Callable[[], None]] = None,
    ):
        """Create a marker holding weak references to ``classes``.

        :param classes: the classes initially referred to.
        :param on_remove: optional callable invoked with no arguments
         once the marker has become empty.

        """
        self.on_remove = on_remove
        # each weakref calls back into _remove_item when its class dies
        self.contents = {
            weakref.ref(item, self._remove_item) for item in classes
        }
        # keep the marker itself alive while it still has contents
        _registries.add(self)

    def remove_item(self, cls: Type[Any]) -> None:
        """Explicitly remove ``cls`` from this marker."""
        # a fresh weakref to a live object compares equal to the stored
        # one, so it can be used to discard the stored reference
        self._remove_item(weakref.ref(cls))

    def __iter__(self) -> Generator[Optional[Type[Any]], None, None]:
        """Iterate over the referents of the stored weak references.

        A reference whose class is gone yields ``None``.

        """
        return (ref() for ref in self.contents)

    def attempt_get(self, path: List[str], key: str) -> Type[Any]:
        """Return the single class this marker refers to.

        :param path: module path tokens leading to ``key``; used for the
         error message only.
        :param key: the class name being looked up.
        :raises sqlalchemy.exc.InvalidRequestError: if more than one class
         is present, making the name ambiguous.
        :raises NameError: if no class is available, either because the
         marker is empty or because the sole reference is dead.

        """
        # work on a snapshot; the weakref callback may discard entries
        # from self.contents at any point
        refs = list(self.contents)
        if len(refs) > 1:
            raise exc.InvalidRequestError(
                'Multiple classes found for path "%s" '
                "in the registry of this declarative "
                "base. Please use a fully module-qualified path."
                % (".".join(path + [key]))
            )
        if not refs:
            # an emptied marker previously surfaced as an IndexError,
            # which callers do not translate into a lookup failure.
            raise NameError(key)

        cls = refs[0]()
        if cls is None:
            raise NameError(key)
        return cls

    def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None:
        """Discard ``ref``; tear the marker down when it becomes empty."""
        self.contents.discard(ref)
        if not self.contents:
            _registries.discard(self)
            if self.on_remove:
                self.on_remove()

    def add_item(self, item: Type[Any]) -> None:
        """Add ``item`` to this marker.

        A warning is emitted when a live class from the same module is
        already present.

        """
        # protect against class registration race condition against
        # asynchronous garbage collection calling _remove_item,
        # [ticket:3208] and [ticket:10782]
        modules = {
            cls.__module__
            for cls in [ref() for ref in list(self.contents)]
            if cls is not None
        }
        if item.__module__ in modules:
            util.warn(
                "This declarative base already contains a class with the "
                "same class name and module name as %s.%s, and will "
                "be replaced in the string-lookup table."
                % (item.__module__, item.__name__)
            )
        self.contents.add(weakref.ref(item, self._remove_item))
257
258
class _ModuleMarker(_ClsRegistryToken):
    """Refers to a module name within
    _decl_class_registry.

    Each marker is one node of a tree: ``contents`` maps names to child
    module markers or to multi-class markers, and ``path`` is the list of
    names leading from the root to this node.

    """

    __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"

    parent: Optional[_ModuleMarker]
    contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]]
    mod_ns: _ModNS
    path: List[str]

    def __init__(self, name: str, parent: Optional[_ModuleMarker]):
        """Create a marker called ``name`` below ``parent``.

        A marker without a parent starts with an empty ``path``.

        """
        self.parent = parent
        self.name = name
        self.contents = {}
        self.mod_ns = _ModNS(self)
        self.path = (self.parent.path + [self.name]) if self.parent else []
        _registries.add(self)

    def __contains__(self, name: str) -> bool:
        return name in self.contents

    def __getitem__(self, name: str) -> _ClsRegistryToken:
        return self.contents[name]

    def _remove_item(self, name: str) -> None:
        """Drop ``name`` from the contents; prune this node once empty."""
        self.contents.pop(name, None)
        if self.contents:
            return

        # nothing left below this node: detach it from its parent and
        # release the strong reference held in _registries
        if self.parent is not None:
            self.parent._remove_item(self.name)
        _registries.discard(self)

    def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]:
        """Look up ``key`` through this marker's namespace object."""
        namespace = self.mod_ns
        return namespace.__getattr__(key)

    def get_module(self, name: str) -> _ModuleMarker:
        """Return the child marker called ``name``, creating it if absent."""
        if name in self.contents:
            return cast(_ModuleMarker, self.contents[name])

        child = _ModuleMarker(name, self)
        self.contents[name] = child
        return child

    def add_class(self, name: str, cls: Type[Any]) -> None:
        """Register ``cls`` under ``name`` directly below this marker.

        :raises sqlalchemy.exc.InvalidRequestError: if ``name`` is already
         occupied by something that is not a multi-class marker.

        """
        if name not in self.contents:
            self.contents[name] = _MultipleClassMarker(
                [cls], on_remove=lambda: self._remove_item(name)
            )
            return

        occupant = cast(_MultipleClassMarker, self.contents[name])
        try:
            occupant.add_item(cls)
        except AttributeError as ae:
            # a real multi-class marker failing is unexpected; re-raise
            if isinstance(occupant, _MultipleClassMarker):
                raise
            raise exc.InvalidRequestError(
                f'name "{name}" matches both a '
                "class name and a module name"
            ) from ae

    def remove_class(self, name: str, cls: Type[Any]) -> None:
        """Remove ``cls`` from the entry stored under ``name``, if any."""
        if name not in self.contents:
            return
        occupant = cast(_MultipleClassMarker, self.contents[name])
        occupant.remove_item(cls)
329
330
class _ModNS:
    """Attribute-style namespace over the contents of a module marker.

    Attribute access returns the namespace of a nested module marker or
    the class held by a multi-class marker.  Unknown names raise
    ``NameError`` rather than ``AttributeError``.

    """

    __slots__ = ("__parent",)

    __parent: _ModuleMarker

    def __init__(self, parent: _ModuleMarker):
        self.__parent = parent

    def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]:
        """Resolve ``key`` against the parent marker's contents.

        :raises NameError: if nothing is registered under ``key``.

        """
        owner = self.__parent
        entry = owner.contents.get(key)

        # a missing key and a None value are both treated as "not found"
        if entry is not None:
            if isinstance(entry, _ModuleMarker):
                return entry.mod_ns
            assert isinstance(entry, _MultipleClassMarker)
            return entry.attempt_get(owner.path, key)

        raise NameError(
            "Module %r has no mapped classes "
            "registered under the name %r" % (owner.name, key)
        )
355
356
class _GetColumns:
    """Wrapper around a class exposing its column-based attributes.

    Attribute access is forwarded to the wrapped class after checking,
    when the class has a mapper, that the name refers to a column-based
    ORM descriptor or a synonym.

    """

    __slots__ = ("cls",)

    cls: Type[Any]

    def __init__(self, cls: Type[Any]):
        self.cls = cls

    def __getattr__(self, key: str) -> Any:
        """Return attribute ``key`` of the wrapped class.

        :raises AttributeError: if the mapper has no ORM descriptor named
         ``key``.
        :raises sqlalchemy.exc.InvalidRequestError: if the descriptor's
         property is neither a synonym nor a ``ColumnProperty``.

        """
        mapper = class_mapper(self.cls, configure=False)
        if mapper:
            if key not in mapper.all_orm_descriptors:
                raise AttributeError(
                    "Class %r does not have a mapped column named %r"
                    % (self.cls, key)
                )

            descriptor = mapper.all_orm_descriptors[key]
            not_extension = interfaces.NotExtension.NOT_EXTENSION
            if descriptor.extension_type is not_extension:
                assert isinstance(descriptor, attributes.QueryableAttribute)
                mapped_prop = descriptor.property
                if not isinstance(
                    mapped_prop, (SynonymProperty, ColumnProperty)
                ):
                    raise exc.InvalidRequestError(
                        "Property %r is not an instance of"
                        " ColumnProperty (i.e. does not correspond"
                        " directly to a Column)." % key
                    )
                if isinstance(mapped_prop, SynonymProperty):
                    # follow the synonym to the name it points at
                    key = mapped_prop.name
        return getattr(self.cls, key)
387
388
# register _GetColumns with the inspection system: inspecting a
# _GetColumns wrapper delegates to inspecting the class it wraps.
inspection._inspects(_GetColumns)(
    lambda target: inspection.inspect(target.cls)
)
392
393
class _GetTable:
    """Attribute-style access to the tables of a ``MetaData`` collection.

    ``key`` is stored at construction and combined with the requested
    attribute name via ``_get_table_key()`` to form the lookup key into
    ``metadata.tables``.

    """

    __slots__ = "key", "metadata"

    key: str
    metadata: MetaData

    def __init__(self, key: str, metadata: MetaData):
        self.key, self.metadata = key, metadata

    def __getattr__(self, key: str) -> Table:
        """Return the table stored under the combined key.

        A ``KeyError`` from the ``metadata.tables`` lookup propagates
        unchanged.

        """
        tables = self.metadata.tables
        return tables[_get_table_key(key, self.key)]
406
407
def _determine_container(key: str, value: Any) -> _GetColumns:
    """Wrap a class registry value in a ``_GetColumns`` container.

    A multi-class marker is first reduced to its single class via
    ``attempt_get()``; any other value is wrapped as given.

    """
    target = (
        value.attempt_get([], key)
        if isinstance(value, _MultipleClassMarker)
        else value
    )
    return _GetColumns(target)
412
413
class _class_resolver:
    """Callable resolving a string argument into the object it names.

    Names are looked up lazily through ``_dict``, a ``PopulateDict``
    backed by :meth:`_access_cls`, which consults tables, the class
    registry, the module registry, extra resolvers and finally the
    ``fallback`` mapping.

    """

    __slots__ = (
        "cls",
        "prop",
        "arg",
        "fallback",
        "_dict",
        "_resolvers",
        "tables_only",
    )

    cls: Type[Any]
    prop: RelationshipProperty[Any]
    fallback: Mapping[str, Any]
    arg: str
    tables_only: bool
    _resolvers: Tuple[Callable[[str], Any], ...]

    def __init__(
        self,
        cls: Type[Any],
        prop: RelationshipProperty[Any],
        fallback: Mapping[str, Any],
        arg: str,
        tables_only: bool = False,
    ):
        """Set up resolution of ``arg`` on behalf of ``cls`` / ``prop``.

        :param cls: class whose registry and metadata are consulted.
        :param prop: relationship property, used in error messages.
        :param fallback: mapping consulted last for a name.
        :param arg: the string to be resolved.
        :param tables_only: if True, tables are looked up before classes
         and ``arg`` is treated as a single key rather than evaluated.

        """
        self.cls = cls
        self.prop = prop
        self.arg = arg
        self.fallback = fallback
        self.tables_only = tables_only
        self._resolvers = ()
        self._dict = util.PopulateDict(self._access_cls)

    def _resolve_table_key(
        self, key: str, metadata: MetaData
    ) -> Optional[Table]:
        """Return the table matching ``key`` in ``metadata``, or None.

        When the metadata has a default schema and ``key`` contains no
        dot, the schema-qualified key is preferred.  Falling back to the
        unqualified key in that situation emits a deprecation warning.

        """
        tables = metadata.tables
        default_schema = metadata.schema

        if default_schema is None or "." in key:
            return tables[key] if key in tables else None

        schema_key = _get_table_key(key, default_schema)
        if schema_key in tables:
            return tables[schema_key]
        if key not in tables:
            return None

        util.warn_deprecated(
            "The string '%s' was resolved to the "
            "non-schema-qualified table '%s', however "
            "the MetaData object has a default schema "
            "of '%s'. In a future version of SQLAlchemy, "
            "this unqualified name will be resolved as "
            "'%s'. To reference a table without a "
            "schema, use the Table object directly."
            % (key, key, default_schema, schema_key),
            "2.1",
        )
        return tables[key]

    def _lookup_table_or_schema(self, key: str, metadata: MetaData) -> Any:
        """Return a table, a schema accessor for ``key``, or None."""
        table = self._resolve_table_key(key, metadata)
        if table is not None:
            return table
        if key in metadata._schemas:
            return _GetTable(key, metadata)
        return None

    def _access_cls(self, key: str) -> Any:
        """Locate the object named ``key``.

        The table lookup runs before the class registry lookup when
        ``tables_only`` is set and after it otherwise.  A ``KeyError``
        from the final ``fallback`` lookup propagates to the caller.

        """
        cls = self.cls

        manager = attributes.manager_of_class(cls)
        registry = manager.registry
        assert registry is not None
        decl_class_registry = registry._class_registry
        metadata = _metadata_for_cls(cls, registry)

        if self.tables_only:
            found = self._lookup_table_or_schema(key, metadata)
            if found is not None:
                return found

        if key in decl_class_registry:
            container = _determine_container(key, decl_class_registry[key])
            return container.cls if self.tables_only else container

        if not self.tables_only:
            found = self._lookup_table_or_schema(key, metadata)
            if found is not None:
                return found

        if "_sa_module_registry" in decl_class_registry:
            module_root = cast(
                _ModuleMarker, decl_class_registry["_sa_module_registry"]
            )
            if key in module_root:
                return module_root.resolve_attr(key)

        # extra resolvers get a chance before the fallback mapping; the
        # first non-None answer wins
        for resolve in self._resolvers:
            value = resolve(key)
            if value is not None:
                return value

        return self.fallback[key]

    def _raise_for_name(self, name: str, err: Exception) -> NoReturn:
        """Raise ``InvalidRequestError`` for a name that failed to resolve.

        A name shaped like ``Something[Arg]`` gets a message about generic
        classes; any other name gets the generic "failed to locate a
        name" message.  ``err`` is chained as the cause in both cases.

        """
        generic_match = re.match(r"(.+)\[(.+)\]", name)

        if generic_match:
            container, raw_arg = generic_match.group(1, 2)
            clsarg = raw_arg.strip("'")
            raise exc.InvalidRequestError(
                f"When initializing mapper {self.prop.parent}, "
                f'expression "relationship({self.arg!r})" seems to be '
                "using a generic class as the argument to relationship(); "
                "please state the generic argument "
                "using an annotation, e.g. "
                f'"{self.prop.key}: Mapped[{container}'
                f"['{clsarg}']] = relationship()\""
            ) from err

        manager = attributes.manager_of_class(self.cls)
        registry = manager.registry
        metadata = None
        if registry is not None:
            metadata = _metadata_for_cls(self.cls, registry)

        # when deprecated fallback lookup in
        # _resolve_table_key is removed, consider adding
        # additional context to the error message if the
        # unqualified key is located under BLANK_SCHEMA
        if metadata is not None and metadata.schema is not None:
            schema_key = _get_table_key(name, metadata.schema)
            assert schema_key not in metadata.tables

        message = (
            "When initializing mapper %s, expression %r failed to "
            "locate a name (%r). If this is a class name, consider "
            "adding this relationship() to the %r class after "
            "both dependent classes have been defined."
            % (
                self.prop.parent,
                self.arg,
                name,
                self.cls,
            )
        )
        raise exc.InvalidRequestError(message) from err

    def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]:
        """Resolve ``arg`` as a dotted name without evaluating it.

        The first token is looked up in ``_dict``; every further token is
        fetched as an attribute of the previous result.

        """
        name = self.arg
        lookup = self._dict
        resolved = None
        try:
            for token in name.split("."):
                if resolved is not None:
                    resolved = getattr(resolved, token)
                else:
                    resolved = lookup[token]
        except KeyError as err:
            self._raise_for_name(name, err)
        except NameError as n:
            self._raise_for_name(n.args[0], n)
        else:
            if not isinstance(resolved, _GetColumns):
                if TYPE_CHECKING:
                    assert isinstance(resolved, (type, Table, _ModNS))
                return resolved
            # unwrap the column container back to the class
            return resolved.cls

    def __call__(self) -> Any:
        """Resolve ``arg`` and return the resulting object.

        With ``tables_only`` the argument is a single lookup key;
        otherwise it is evaluated as a Python expression whose names are
        supplied by ``_dict``.

        """
        if not self.tables_only:
            try:
                # NOTE(review): eval() of the configured string; this must
                # only ever receive developer-supplied expressions, never
                # untrusted input.
                result = eval(self.arg, globals(), self._dict)

                if not isinstance(result, _GetColumns):
                    return result
                return result.cls
            except NameError as n:
                self._raise_for_name(n.args[0], n)
        else:
            try:
                return self._dict[self.arg]
            except KeyError as k:
                self._raise_for_name(self.arg, k)
600
601
# populated lazily by the first call to _resolver(); holds the names made
# available to string expressions as a last resort.
_fallback_dict: Mapping[str, Any] = None  # type: ignore


def _resolver(cls: Type[Any], prop: RelationshipProperty[Any]) -> Tuple[
    Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
    Callable[[str, bool], _class_resolver],
]:
    """Build the string resolver factories for ``cls`` and ``prop``.

    :return: a ``(resolve_name, resolve_arg)`` tuple.  ``resolve_name``
     takes a string and returns a callable resolving it as a dotted name;
     ``resolve_arg`` takes a string plus an optional ``tables_only`` flag
     and returns a ``_class_resolver``.

    """
    global _fallback_dict

    if _fallback_dict is None:
        # imported here rather than at module level; the fallback
        # namespace is the sqlalchemy package namespace plus the
        # foreign() and remote() annotation functions
        import sqlalchemy
        from . import foreign
        from . import remote

        extras = {"foreign": foreign, "remote": remote}
        _fallback_dict = util.immutabledict(sqlalchemy.__dict__).union(
            extras
        )

    def resolve_name(
        arg: str,
    ) -> Callable[[], Union[Type[Any], Table, _ModNS]]:
        resolver = _class_resolver(cls, prop, _fallback_dict, arg)
        return resolver._resolve_name

    def resolve_arg(arg: str, tables_only: bool = False) -> _class_resolver:
        return _class_resolver(
            cls, prop, _fallback_dict, arg, tables_only=tables_only
        )

    return resolve_name, resolve_arg