1# orm/clsregistry.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""Routines to handle the string class registry used by declarative.
9
10This system allows specification of classes and expressions used in
11:func:`_orm.relationship` using strings.
12
13"""
14
15from __future__ import annotations
16
17import re
18from typing import Any
19from typing import Callable
20from typing import cast
21from typing import Dict
22from typing import Generator
23from typing import Iterable
24from typing import List
25from typing import Mapping
26from typing import MutableMapping
27from typing import NoReturn
28from typing import Optional
29from typing import Set
30from typing import Tuple
31from typing import Type
32from typing import TYPE_CHECKING
33from typing import TypeVar
34from typing import Union
35import weakref
36
37from . import attributes
38from . import interfaces
39from .descriptor_props import SynonymProperty
40from .properties import ColumnProperty
41from .util import class_mapper
42from .. import exc
43from .. import inspection
44from .. import util
45from ..sql.schema import _get_table_key
46from ..util.typing import CallableReference
47
48if TYPE_CHECKING:
49 from .relationships import RelationshipProperty
50 from ..sql.schema import MetaData
51 from ..sql.schema import Table
52
# type variable used to annotate the class handed to add_class()
_T = TypeVar("_T", bound=Any)

# type of the string-keyed class registry: each name maps either to a
# class directly or to a ClsRegistryToken stored in its place.
_ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]]

# Strong references to the marker objects which we place in
# the _decl_class_registry, which is usually weak referencing.
# The markers defined in this module link to classes with weakrefs, add
# themselves to this set when constructed, and remove themselves from it
# once all of the items they contain have been removed.
_registries: Set[ClsRegistryToken] = set()
62
63
def add_class(
    classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType
) -> None:
    """Register ``cls`` under ``classname`` in the given
    ``_decl_class_registry``.

    The class is recorded in two places:

    * directly under ``classname``; if a non-marker object is already
      stored under that name, the entry is replaced with a
      :class:`._MultipleClassMarker` referring to both the new class and
      the previous entry.  An entry that already is a marker is left
      as it is.
    * in the module tree kept under the ``"_sa_module_registry"`` key
      (created here if absent), once for each trailing portion of
      ``cls.__module__``.

    :param classname: name to register the class under.
    :param cls: the class being registered.
    :param decl_class_registry: the registry mapping to modify.
    :raises sqlalchemy.exc.InvalidRequestError: if adding to the module
     tree fails with ``AttributeError`` because the object found at the
     end of a module path is not a :class:`._ModuleMarker`.

    """
    if classname not in decl_class_registry:
        decl_class_registry[classname] = cls
    else:
        # class already exists; unless a marker is already in place,
        # swap the entry for a marker referring to both classes.
        previous = decl_class_registry[classname]
        if not isinstance(previous, _MultipleClassMarker):
            decl_class_registry[classname] = _MultipleClassMarker(
                [cls, cast("Type[Any]", previous)]
            )

    try:
        root_module = cast(
            _ModuleMarker, decl_class_registry["_sa_module_registry"]
        )
    except KeyError:
        root_module = _ModuleMarker("_sa_module_registry", None)
        decl_class_registry["_sa_module_registry"] = root_module

    module_tokens = cls.__module__.split(".")

    # build up a tree like this:
    # modulename:  myapp.snacks.nuts
    #
    # myapp->snacks->nuts->(classes)
    # snacks->nuts->(classes)
    # nuts->(classes)
    #
    # this allows partial token paths to be used.
    for start in range(len(module_tokens)):
        node = root_module
        for part in module_tokens[start:]:
            node = node.get_module(part)

        try:
            node.add_class(classname, cls)
        except AttributeError as ae:
            if isinstance(node, _ModuleMarker):
                # a genuine error from inside the marker; pass it on
                raise
            raise exc.InvalidRequestError(
                f'name "{classname}" matches both a '
                "class name and a module name"
            ) from ae
116
117
def remove_class(
    classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType
) -> None:
    """Remove ``cls`` from the given ``_decl_class_registry``.

    The entry stored under ``classname`` is deleted, or, when it is a
    :class:`._MultipleClassMarker`, ``cls`` is removed from that marker.
    The class is then removed from the module tree kept under the
    ``"_sa_module_registry"`` key, once for each trailing portion of
    ``cls.__module__``.  If the registry has no module tree, nothing
    further is done.

    :param classname: name the class was registered under.
    :param cls: the class being removed.
    :param decl_class_registry: the registry mapping to modify.

    """
    if classname in decl_class_registry:
        entry = decl_class_registry[classname]
        if not isinstance(entry, _MultipleClassMarker):
            del decl_class_registry[classname]
        else:
            entry.remove_item(cls)

    try:
        root_module = cast(
            _ModuleMarker, decl_class_registry["_sa_module_registry"]
        )
    except KeyError:
        # no module tree present in this registry
        return

    module_tokens = cls.__module__.split(".")

    for start in range(len(module_tokens)):
        node = root_module
        for part in module_tokens[start:]:
            node = node.get_module(part)
        try:
            node.remove_class(classname, cls)
        except AttributeError:
            # an AttributeError is ignored only when the object at the
            # end of the path is not a module marker.
            if isinstance(node, _ModuleMarker):
                raise
149
150
151def _key_is_empty(
152 key: str,
153 decl_class_registry: _ClsRegistryType,
154 test: Callable[[Any], bool],
155) -> bool:
156 """test if a key is empty of a certain object.
157
158 used for unit tests against the registry to see if garbage collection
159 is working.
160
161 "test" is a callable that will be passed an object should return True
162 if the given object is the one we were looking for.
163
164 We can't pass the actual object itself b.c. this is for testing garbage
165 collection; the caller will have to have removed references to the
166 object itself.
167
168 """
169 if key not in decl_class_registry:
170 return True
171
172 thing = decl_class_registry[key]
173 if isinstance(thing, _MultipleClassMarker):
174 for sub_thing in thing.contents:
175 if test(sub_thing):
176 return False
177 else:
178 raise NotImplementedError("unknown codepath")
179 else:
180 return not test(thing)
181
182
class ClsRegistryToken:
    """an object that can be in the registry._class_registry as a value.

    This base class declares empty ``__slots__``; the marker subclasses
    in this module each declare their own.

    """

    __slots__ = ()
187
188
class _MultipleClassMarker(ClsRegistryToken):
    """refers to multiple classes of the same name
    within _decl_class_registry.

    The classes are held only through weak references stored in
    ``contents``.  Each reference is created with :meth:`._remove_item`
    as its callback, and the marker keeps itself in the module-level
    ``_registries`` set until ``contents`` becomes empty.

    """

    __slots__ = "on_remove", "contents", "__weakref__"

    # weak references to the classes registered under the shared name
    contents: Set[weakref.ref[Type[Any]]]
    # optional hook invoked by _remove_item() once ``contents`` is empty
    on_remove: CallableReference[Optional[Callable[[], None]]]

    def __init__(
        self,
        classes: Iterable[Type[Any]],
        on_remove: Optional[Callable[[], None]] = None,
    ):
        """Construct a marker referring to ``classes``.

        :param classes: the classes to refer to; only weak references
         to them are stored.
        :param on_remove: optional callable invoked with no arguments
         when the last item has been removed from the marker.

        """
        self.on_remove = on_remove
        self.contents = {
            weakref.ref(item, self._remove_item) for item in classes
        }
        _registries.add(self)

    def remove_item(self, cls: Type[Any]) -> None:
        """Remove ``cls`` from this marker."""
        # a fresh weak reference to ``cls`` is used as the lookup key
        # for the reference stored in ``contents``.
        self._remove_item(weakref.ref(cls))

    def __iter__(self) -> Generator[Optional[Type[Any]], None, None]:
        """Iterate the referred-to classes; a dead reference yields None."""
        return (ref() for ref in self.contents)

    def attempt_get(self, path: List[str], key: str) -> Type[Any]:
        """Return the single class this marker refers to.

        :param path: module tokens leading to this marker; used only to
         build the error message.
        :param key: the name being looked up.
        :raises sqlalchemy.exc.InvalidRequestError: if the marker refers
         to more than one class.
        :raises NameError: if the marker is empty, or if its only
         remaining reference is dead.

        """
        # work from a single snapshot; a weakref callback may run
        # _remove_item() and shrink ``contents`` at any point, so the
        # size check and the element access must see the same data.
        refs = list(self.contents)
        if len(refs) > 1:
            raise exc.InvalidRequestError(
                'Multiple classes found for path "%s" '
                "in the registry of this declarative "
                "base. Please use a fully module-qualified path."
                % (".".join(path + [key]))
            )
        if not refs:
            # an emptied marker can still be reachable, e.g. when it has
            # no on_remove hook and the registry holds it strongly;
            # report it like any other name that can't be located
            # rather than failing with IndexError.
            raise NameError(key)

        cls = refs[0]()
        if cls is None:
            raise NameError(key)
        return cls

    def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None:
        """Discard ``ref``; serves as the weakref callback as well.

        When the marker becomes empty it removes itself from
        ``_registries`` and invokes ``on_remove``, if one was given.

        """
        self.contents.discard(ref)
        if not self.contents:
            _registries.discard(self)
            if self.on_remove:
                self.on_remove()

    def add_item(self, item: Type[Any]) -> None:
        """Add ``item`` to this marker.

        Emits a warning if a live class from the same module as
        ``item`` is already present.

        """
        # protect against class registration race condition against
        # asynchronous garbage collection calling _remove_item,
        # [ticket:3208] and [ticket:10782]
        modules = {
            cls.__module__
            for cls in [ref() for ref in list(self.contents)]
            if cls is not None
        }
        if item.__module__ in modules:
            util.warn(
                "This declarative base already contains a class with the "
                "same class name and module name as %s.%s, and will "
                "be replaced in the string-lookup table."
                % (item.__module__, item.__name__)
            )
        self.contents.add(weakref.ref(item, self._remove_item))
256
257
class _ModuleMarker(ClsRegistryToken):
    """Refers to a module name within
    _decl_class_registry.

    A marker is one node of a tree: ``contents`` maps a name either to
    a child :class:`._ModuleMarker` or to a
    :class:`._MultipleClassMarker` holding classes.  The marker keeps
    itself in the module-level ``_registries`` set until its contents
    have been emptied via :meth:`._remove_item`.

    """

    __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"

    parent: Optional[_ModuleMarker]
    contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]]
    mod_ns: _ModNS
    path: List[str]

    def __init__(self, name: str, parent: Optional[_ModuleMarker]):
        """Construct a marker called ``name`` beneath ``parent``.

        ``path`` is the parent's path plus ``name``; a marker without a
        parent gets an empty path.

        """
        self.parent = parent
        self.name = name
        self.contents = {}
        self.mod_ns = _ModNS(self)
        self.path = self.parent.path + [self.name] if self.parent else []
        _registries.add(self)

    def __contains__(self, name: str) -> bool:
        return name in self.contents

    def __getitem__(self, name: str) -> ClsRegistryToken:
        return self.contents[name]

    def _remove_item(self, name: str) -> None:
        """Drop ``name`` from this marker; once the marker is empty,
        remove it from its parent, if any, and from ``_registries``.

        """
        self.contents.pop(name, None)
        if self.contents:
            return

        if self.parent is not None:
            self.parent._remove_item(self.name)
        _registries.discard(self)

    def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]:
        """Look up ``key`` through this marker's :class:`._ModNS`."""
        return self.mod_ns.__getattr__(key)

    def get_module(self, name: str) -> _ModuleMarker:
        """Return the entry stored under ``name``, first creating a
        child module marker there if the name is not present.

        """
        try:
            found = self.contents[name]
        except KeyError:
            found = self.contents[name] = _ModuleMarker(name, self)
        # an existing entry is returned as-is; it is not verified to
        # actually be a module marker.
        return cast(_ModuleMarker, found)

    def add_class(self, name: str, cls: Type[Any]) -> None:
        """Add ``cls`` under ``name``, creating the class marker for
        that name if there is none yet.

        :raises sqlalchemy.exc.InvalidRequestError: if adding to the
         existing entry fails with ``AttributeError`` and that entry is
         not a :class:`._MultipleClassMarker`.

        """
        if name not in self.contents:
            # the new marker removes its own entry from this module
            # once it becomes empty.
            self.contents[name] = _MultipleClassMarker(
                [cls], on_remove=lambda: self._remove_item(name)
            )
            return

        existing = cast(_MultipleClassMarker, self.contents[name])
        try:
            existing.add_item(cls)
        except AttributeError as ae:
            if isinstance(existing, _MultipleClassMarker):
                raise
            raise exc.InvalidRequestError(
                f'name "{name}" matches both a '
                "class name and a module name"
            ) from ae

    def remove_class(self, name: str, cls: Type[Any]) -> None:
        """Remove ``cls`` from the entry stored under ``name``, if there
        is such an entry.

        """
        if name in self.contents:
            cast(_MultipleClassMarker, self.contents[name]).remove_item(cls)
328
329
330class _ModNS:
331 __slots__ = ("__parent",)
332
333 __parent: _ModuleMarker
334
335 def __init__(self, parent: _ModuleMarker):
336 self.__parent = parent
337
338 def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]:
339 try:
340 value = self.__parent.contents[key]
341 except KeyError:
342 pass
343 else:
344 if value is not None:
345 if isinstance(value, _ModuleMarker):
346 return value.mod_ns
347 else:
348 assert isinstance(value, _MultipleClassMarker)
349 return value.attempt_get(self.__parent.path, key)
350 raise NameError(
351 "Module %r has no mapped classes "
352 "registered under the name %r" % (self.__parent.name, key)
353 )
354
355
356class _GetColumns:
357 __slots__ = ("cls",)
358
359 cls: Type[Any]
360
361 def __init__(self, cls: Type[Any]):
362 self.cls = cls
363
364 def __getattr__(self, key: str) -> Any:
365 mp = class_mapper(self.cls, configure=False)
366 if mp:
367 if key not in mp.all_orm_descriptors:
368 raise AttributeError(
369 "Class %r does not have a mapped column named %r"
370 % (self.cls, key)
371 )
372
373 desc = mp.all_orm_descriptors[key]
374 if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION:
375 assert isinstance(desc, attributes.QueryableAttribute)
376 prop = desc.property
377 if isinstance(prop, SynonymProperty):
378 key = prop.name
379 elif not isinstance(prop, ColumnProperty):
380 raise exc.InvalidRequestError(
381 "Property %r is not an instance of"
382 " ColumnProperty (i.e. does not correspond"
383 " directly to a Column)." % key
384 )
385 return getattr(self.cls, key)
386
387
# associate _GetColumns with an inspection callable that inspects the
# class wrapped by the _GetColumns object rather than the wrapper.
inspection._inspects(_GetColumns)(
    lambda target: inspection.inspect(target.cls)
)
391
392
393class _GetTable:
394 __slots__ = "key", "metadata"
395
396 key: str
397 metadata: MetaData
398
399 def __init__(self, key: str, metadata: MetaData):
400 self.key = key
401 self.metadata = metadata
402
403 def __getattr__(self, key: str) -> Table:
404 return self.metadata.tables[_get_table_key(key, self.key)]
405
406
def _determine_container(key: str, value: Any) -> _GetColumns:
    """Wrap a registry entry in a :class:`._GetColumns`.

    A :class:`._MultipleClassMarker` is first resolved to its single
    class via ``attempt_get()``, using an empty path; any other value is
    wrapped as given.

    """
    target = (
        value.attempt_get([], key)
        if isinstance(value, _MultipleClassMarker)
        else value
    )
    return _GetColumns(target)
411
412
class _class_resolver:
    """Resolves the string argument ``arg`` given for a relationship
    property, either as a dotted name (:meth:`._resolve_name`) or as a
    Python expression (:meth:`.__call__`).

    Individual names are looked up through ``_dict``, a
    ``util.PopulateDict`` that is constructed with :meth:`._access_cls`.

    """

    __slots__ = (
        "cls",
        "prop",
        "arg",
        "fallback",
        "_dict",
        "_resolvers",
        "favor_tables",
    )

    cls: Type[Any]
    prop: RelationshipProperty[Any]
    fallback: Mapping[str, Any]
    arg: str
    favor_tables: bool
    _resolvers: Tuple[Callable[[str], Any], ...]

    def __init__(
        self,
        cls: Type[Any],
        prop: RelationshipProperty[Any],
        fallback: Mapping[str, Any],
        arg: str,
        favor_tables: bool = False,
    ):
        """Construct a resolver.

        :param cls: class whose registry and metadata are consulted.
        :param prop: the relationship property; used in error messages.
        :param fallback: mapping consulted for a name after every other
         lookup has come up empty.
        :param arg: the string to resolve.
        :param favor_tables: if True, table and schema names are checked
         before class names rather than after.

        """
        self.cls = cls
        self.prop = prop
        self.arg = arg
        self.fallback = fallback
        self.favor_tables = favor_tables
        self._resolvers = ()
        self._dict = util.PopulateDict(self._access_cls)

    def _access_cls(self, key: str) -> Any:
        """Locate the object named ``key``.

        Checked in order: tables and schemas of the metadata (first if
        ``favor_tables`` is set, otherwise after the class registry), the
        class registry, the module registry, any additional
        ``_resolvers``, and finally ``fallback``.

        """
        cls = self.cls

        manager = attributes.manager_of_class(cls)
        decl_base = manager.registry
        assert decl_base is not None
        decl_class_registry = decl_base._class_registry
        metadata = decl_base.metadata

        def from_metadata() -> Any:
            # a table name yields the Table; a schema name yields a
            # _GetTable for that schema; anything else yields None.
            if key in metadata.tables:
                return metadata.tables[key]
            if key in metadata._schemas:
                return _GetTable(key, getattr(cls, "metadata", metadata))
            return None

        if self.favor_tables:
            found = from_metadata()
            if found is not None:
                return found

        if key in decl_class_registry:
            return _determine_container(key, decl_class_registry[key])

        if not self.favor_tables:
            found = from_metadata()
            if found is not None:
                return found

        if "_sa_module_registry" in decl_class_registry:
            module_root = cast(
                _ModuleMarker, decl_class_registry["_sa_module_registry"]
            )
            if key in module_root:
                return module_root.resolve_attr(key)

        # reached only when the module registry did not produce a match
        for resolve in self._resolvers:
            value = resolve(key)
            if value is not None:
                return value

        return self.fallback[key]

    def _raise_for_name(self, name: str, err: Exception) -> NoReturn:
        """Raise ``InvalidRequestError`` for a name that could not be
        located, chained from ``err``.

        A name of the form ``X[Y]`` gets a message suggesting that the
        generic argument be stated with an annotation instead.

        """
        generic_match = re.match(r"(.+)\[(.+)\]", name)

        if generic_match is None:
            raise exc.InvalidRequestError(
                "When initializing mapper %s, expression %r failed to "
                "locate a name (%r). If this is a class name, consider "
                "adding this relationship() to the %r class after "
                "both dependent classes have been defined."
                % (self.prop.parent, self.arg, name, self.cls)
            ) from err

        clsarg = generic_match.group(2).strip("'")
        raise exc.InvalidRequestError(
            f"When initializing mapper {self.prop.parent}, "
            f'expression "relationship({self.arg!r})" seems to be '
            "using a generic class as the argument to relationship(); "
            "please state the generic argument "
            "using an annotation, e.g. "
            f'"{self.prop.key}: Mapped[{generic_match.group(1)}'
            f"['{clsarg}']] = relationship()\""
        ) from err

    def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]:
        """Resolve ``arg`` as a dotted name.

        The first token is looked up in ``_dict``; each later token is
        fetched with ``getattr()`` from the result so far.  A resulting
        :class:`._GetColumns` is unwrapped to its class.

        :raises sqlalchemy.exc.InvalidRequestError: via
         :meth:`._raise_for_name`, when a ``KeyError`` or ``NameError``
         occurs during the lookup.

        """
        dotted = self.arg
        lookup = self._dict
        resolved = None
        try:
            for token in dotted.split("."):
                resolved = (
                    lookup[token]
                    if resolved is None
                    else getattr(resolved, token)
                )
        except KeyError as err:
            self._raise_for_name(dotted, err)
        except NameError as n:
            self._raise_for_name(n.args[0], n)
        else:
            if isinstance(resolved, _GetColumns):
                return resolved.cls
            if TYPE_CHECKING:
                assert isinstance(resolved, (type, Table, _ModNS))
            return resolved

    def __call__(self) -> Any:
        """Evaluate ``arg`` as a Python expression and return the
        result; a resulting :class:`._GetColumns` is unwrapped to its
        class.

        :raises sqlalchemy.exc.InvalidRequestError: via
         :meth:`._raise_for_name`, when evaluation raises ``NameError``.

        """
        try:
            # NOTE: ``arg`` is passed to eval() with this module's
            # globals and ``_dict`` as the local namespace; the string
            # must therefore never originate from untrusted input.
            result = eval(self.arg, globals(), self._dict)

            return result.cls if isinstance(result, _GetColumns) else result
        except NameError as n:
            self._raise_for_name(n.args[0], n)
541
542
# namespace consulted last when resolving names; starts out as None and
# is populated by _resolver() the first time it is called.
_fallback_dict: Mapping[str, Any] = None  # type: ignore
544
545
def _resolver(cls: Type[Any], prop: RelationshipProperty[Any]) -> Tuple[
    Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
    Callable[[str, bool], _class_resolver],
]:
    """Return a ``(resolve_name, resolve_arg)`` pair of functions bound
    to ``cls`` and ``prop``.

    ``resolve_name(arg)`` returns the ``_resolve_name`` method of a new
    :class:`._class_resolver` for ``arg``; ``resolve_arg(arg,
    favor_tables=False)`` returns the :class:`._class_resolver` itself.

    The first call also populates the module-level ``_fallback_dict``
    from the ``sqlalchemy`` package namespace plus ``foreign`` and
    ``remote``.

    """
    global _fallback_dict

    if _fallback_dict is None:
        # imported at call time rather than at module level
        # (presumably to avoid import cycles -- TODO confirm)
        import sqlalchemy
        from . import foreign
        from . import remote

        sqla_namespace = util.immutabledict(sqlalchemy.__dict__)
        _fallback_dict = sqla_namespace.union(
            {"foreign": foreign, "remote": remote}
        )

    def resolve_name(
        arg: str,
    ) -> Callable[[], Union[Type[Any], Table, _ModNS]]:
        resolver = _class_resolver(cls, prop, _fallback_dict, arg)
        return resolver._resolve_name

    def resolve_arg(arg: str, favor_tables: bool = False) -> _class_resolver:
        return _class_resolver(
            cls, prop, _fallback_dict, arg, favor_tables=favor_tables
        )

    return resolve_name, resolve_arg