1# orm/clsregistry.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""Routines to handle the string class registry used by declarative.
9
10This system allows specification of classes and expressions used in
11:func:`_orm.relationship` using strings.
12
13"""
14
15from __future__ import annotations
16
17import re
18from typing import Any
19from typing import Callable
20from typing import cast
21from typing import Dict
22from typing import Generator
23from typing import Iterable
24from typing import List
25from typing import Mapping
26from typing import MutableMapping
27from typing import NoReturn
28from typing import Optional
29from typing import Set
30from typing import Tuple
31from typing import Type
32from typing import TYPE_CHECKING
33from typing import TypeVar
34from typing import Union
35import weakref
36
37from . import attributes
38from . import interfaces
39from .descriptor_props import SynonymProperty
40from .properties import ColumnProperty
41from .util import class_mapper
42from .. import exc
43from .. import inspection
44from .. import util
45from ..sql.schema import _get_table_key
46from ..util.typing import CallableReference
47
48if TYPE_CHECKING:
49 from .relationships import RelationshipProperty
50 from ..sql.schema import MetaData
51 from ..sql.schema import Table
52
# type variable for the class handed to add_class()
_T = TypeVar("_T", bound=Any)

# shape of a declarative class registry: string keys mapped either to a
# class or to a ClsRegistryToken marker object.
_ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]]

# strong references to the marker objects which we place in
# the _decl_class_registry, which is usually weak referencing.
# the markers link to classes with weakrefs and remove themselves
# from this set when all references to contained classes are removed.
_registries: Set[ClsRegistryToken] = set()
62
63
def add_class(
    classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType
) -> None:
    """Register ``cls`` under ``classname`` in a declarative class registry.

    The class is recorded at the top level of ``decl_class_registry``
    under its plain name, and again inside the module tree stored under
    the ``"_sa_module_registry"`` key, once for every trailing portion of
    the class's dotted module path.

    :raises sqlalchemy.exc.InvalidRequestError: if ``classname`` is
     already in use as a module name at the position where the class
     is to be stored.

    """
    if classname not in decl_class_registry:
        decl_class_registry[classname] = cls
    else:
        # the name is taken already; make sure the entry is a marker
        # that is able to refer to more than one class.
        current = decl_class_registry[classname]
        if not isinstance(current, _MultipleClassMarker):
            decl_class_registry[classname] = _MultipleClassMarker(
                [cls, cast("Type[Any]", current)]
            )

    try:
        root_module = cast(
            _ModuleMarker, decl_class_registry["_sa_module_registry"]
        )
    except KeyError:
        # first class for this registry; create the root of the tree
        root_module = _ModuleMarker("_sa_module_registry", None)
        decl_class_registry["_sa_module_registry"] = root_module

    path_tokens = cls.__module__.split(".")

    # for a class in module "myapp.snacks.nuts", store it under each
    # suffix of the module path:
    #
    # myapp->snacks->nuts->(classes)
    # snacks->nuts->(classes)
    # nuts->(classes)
    #
    # so that a partial module path is enough to locate the class.
    for start in range(len(path_tokens)):
        module = root_module
        for part in path_tokens[start:]:
            module = module.get_module(part)

        try:
            module.add_class(classname, cls)
        except AttributeError as ae:
            # only translate the error when the object found at this
            # path turned out not to be a module marker at all.
            if isinstance(module, _ModuleMarker):
                raise
            raise exc.InvalidRequestError(
                f'name "{classname}" matches both a '
                "class name and a module name"
            ) from ae
116
117
def remove_class(
    classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType
) -> None:
    """Remove ``cls`` from a declarative class registry.

    The top-level entry for ``classname`` is handled first: a
    ``_MultipleClassMarker`` has ``cls`` removed from it, any other
    entry is deleted outright.  The class is then removed from every
    position in the ``"_sa_module_registry"`` tree that corresponds to
    a trailing portion of its dotted module path.  If the registry has
    no module tree, only the top-level step is performed.

    """
    if classname in decl_class_registry:
        entry = decl_class_registry[classname]
        if isinstance(entry, _MultipleClassMarker):
            entry.remove_item(cls)
        else:
            del decl_class_registry[classname]

    try:
        root_module = cast(
            _ModuleMarker, decl_class_registry["_sa_module_registry"]
        )
    except KeyError:
        # no module tree in this registry; nothing further to do
        return

    path_tokens = cls.__module__.split(".")

    # visit the same suffix paths that add_class() populates
    for start in range(len(path_tokens)):
        module = root_module
        for part in path_tokens[start:]:
            module = module.get_module(part)

        try:
            module.remove_class(classname, cls)
        except AttributeError:
            # tolerated only when the object at this path is not a
            # module marker; a genuine module marker failing is an error.
            if isinstance(module, _ModuleMarker):
                raise
149
150
def _key_is_empty(
    key: str,
    decl_class_registry: _ClsRegistryType,
    test: Callable[[Any], bool],
) -> bool:
    """Report whether ``key`` no longer holds a particular object.

    This exists for unit tests that verify garbage collection of
    registry entries.

    ``test`` is a callable which receives a candidate object and
    returns True if it is the one being looked for.  The object itself
    can't be passed in, since the caller must already have dropped its
    references to it for collection to be possible.

    When the entry is a ``_MultipleClassMarker``, ``test`` is applied to
    each member of the marker's ``contents``; if none match,
    ``NotImplementedError`` is raised.

    """
    if key not in decl_class_registry:
        return True

    entry = decl_class_registry[key]
    if not isinstance(entry, _MultipleClassMarker):
        return not test(entry)

    if any(test(member) for member in entry.contents):
        return False
    raise NotImplementedError("unknown codepath")
181
182
class ClsRegistryToken:
    """Base for marker objects that may be stored as values in
    ``registry._class_registry`` alongside plain classes.

    """

    __slots__ = ()
187
188
class _MultipleClassMarker(ClsRegistryToken):
    """refers to multiple classes of the same name
    within _decl_class_registry.

    The classes are held through weak references.  Once the last
    reference is gone, the marker removes itself from the module-level
    ``_registries`` set and calls ``on_remove``, if one was given.

    """

    __slots__ = "on_remove", "contents", "__weakref__"

    contents: Set[weakref.ref[Type[Any]]]
    on_remove: CallableReference[Optional[Callable[[], None]]]

    def __init__(
        self,
        classes: Iterable[Type[Any]],
        on_remove: Optional[Callable[[], None]] = None,
    ):
        """Create a marker for ``classes``.

        :param classes: the classes initially referred to.
        :param on_remove: optional callable invoked with no arguments
         when the marker becomes empty.

        """
        self.on_remove = on_remove
        # each weakref calls back into _remove_item() when its class is
        # garbage collected, so that dead entries prune themselves.
        self.contents = {
            weakref.ref(item, self._remove_item) for item in classes
        }
        # keep the marker alive for as long as it has contents
        _registries.add(self)

    def remove_item(self, cls: Type[Any]) -> None:
        """Remove ``cls`` from this marker."""
        # a fresh weakref to a live class compares equal to the one
        # stored in self.contents, so it can be used as the lookup key.
        self._remove_item(weakref.ref(cls))

    def __iter__(self) -> Generator[Optional[Type[Any]], None, None]:
        """Iterate the referred classes; a collected class yields None."""
        return (ref() for ref in self.contents)

    def attempt_get(self, path: List[str], key: str) -> Type[Any]:
        """Return the single class this marker refers to.

        :param path: module path tokens leading to this marker; used
         only to build the error message.
        :param key: the class name being looked up.
        :raises sqlalchemy.exc.InvalidRequestError: if more than one
         class is present, i.e. the name is ambiguous.
        :raises NameError: if the marker is empty or its only class has
         been garbage collected.

        """
        # work from one snapshot, as weakref callbacks may alter
        # self.contents at any time.
        refs = list(self.contents)
        if len(refs) > 1:
            raise exc.InvalidRequestError(
                'Multiple classes found for path "%s" '
                "in the registry of this declarative "
                "base. Please use a fully module-qualified path."
                % (".".join(path + [key]))
            )
        if not refs:
            # an emptied marker is the same situation as a dead
            # reference; report it as NameError, which the resolvers
            # translate, rather than as a bare IndexError.
            raise NameError(key)

        cls = refs[0]()
        if cls is None:
            raise NameError(key)
        return cls

    def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None:
        """Discard ``ref``; clean up if the marker is now empty."""
        self.contents.discard(ref)
        if not self.contents:
            _registries.discard(self)
            if self.on_remove:
                self.on_remove()

    def add_item(self, item: Type[Any]) -> None:
        """Add ``item`` to this marker.

        Emits a warning if a class from the same module is already
        present.

        """
        # protect against class registration race condition against
        # asynchronous garbage collection calling _remove_item,
        # [ticket:3208] and [ticket:10782]; iterate a list() copy of the
        # set and skip references that are already dead.
        modules = {
            cls.__module__
            for cls in [ref() for ref in list(self.contents)]
            if cls is not None
        }
        if item.__module__ in modules:
            util.warn(
                "This declarative base already contains a class with the "
                "same class name and module name as %s.%s, and will "
                "be replaced in the string-lookup table."
                % (item.__module__, item.__name__)
            )
        self.contents.add(weakref.ref(item, self._remove_item))
256
257
class _ModuleMarker(ClsRegistryToken):
    """Refers to a module name within
    _decl_class_registry.

    A marker is one node in a tree of module name tokens; its
    ``contents`` map names either to child module markers or to
    ``_MultipleClassMarker`` objects holding classes.

    """

    __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"

    parent: Optional[_ModuleMarker]
    contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]]
    mod_ns: _ModNS
    path: List[str]

    def __init__(self, name: str, parent: Optional[_ModuleMarker]):
        self.parent = parent
        self.name = name
        self.contents = {}
        self.mod_ns = _ModNS(self)
        # a marker without a parent is the root and has an empty path;
        # every other marker extends the path of its parent.
        self.path = self.parent.path + [self.name] if self.parent else []
        _registries.add(self)

    def __contains__(self, name: str) -> bool:
        return name in self.contents

    def __getitem__(self, name: str) -> ClsRegistryToken:
        return self.contents[name]

    def _remove_item(self, name: str) -> None:
        """Drop ``name`` from this marker; if that leaves a non-root
        marker empty, remove the marker from its parent as well.

        """
        self.contents.pop(name, None)
        if self.contents or self.parent is None:
            return
        self.parent._remove_item(self.name)
        _registries.discard(self)

    def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]:
        """Look up ``key`` through this marker's ``_ModNS`` namespace."""
        return self.mod_ns.__getattr__(key)

    def get_module(self, name: str) -> _ModuleMarker:
        """Return the entry stored under ``name``, creating a new child
        module marker if there is none.

        """
        try:
            present = self.contents[name]
        except KeyError:
            pass
        else:
            return cast(_ModuleMarker, present)

        created = _ModuleMarker(name, self)
        self.contents[name] = created
        return created

    def add_class(self, name: str, cls: Type[Any]) -> None:
        """Store ``cls`` under ``name`` within this module marker.

        :raises sqlalchemy.exc.InvalidRequestError: if ``name`` is
         already occupied by something other than a class marker.

        """
        if name not in self.contents:
            # the new marker removes itself from this module once its
            # last class is gone.
            self.contents[name] = _MultipleClassMarker(
                [cls], on_remove=lambda: self._remove_item(name)
            )
            return

        marker = cast(_MultipleClassMarker, self.contents[name])
        try:
            marker.add_item(cls)
        except AttributeError as ae:
            if isinstance(marker, _MultipleClassMarker):
                raise
            raise exc.InvalidRequestError(
                f'name "{name}" matches both a '
                "class name and a module name"
            ) from ae

    def remove_class(self, name: str, cls: Type[Any]) -> None:
        """Remove ``cls`` from the entry stored under ``name``, if any."""
        if name not in self.contents:
            return
        marker = cast(_MultipleClassMarker, self.contents[name])
        marker.remove_item(cls)
327
328
class _ModNS:
    """Attribute-style view over the contents of a ``_ModuleMarker``.

    Attribute access yields the namespace of a child module, or the
    class registered under that name.

    """

    __slots__ = ("__parent",)

    __parent: _ModuleMarker

    def __init__(self, parent: _ModuleMarker):
        self.__parent = parent

    def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]:
        """Resolve ``key`` against the owning module marker.

        :raises NameError: if nothing is registered under ``key``.

        """
        owner = self.__parent
        entry = owner.contents.get(key)

        if entry is None:
            raise NameError(
                "Module %r has no mapped classes "
                "registered under the name %r" % (owner.name, key)
            )
        if isinstance(entry, _ModuleMarker):
            # descend into the child module
            return entry.mod_ns

        assert isinstance(entry, _MultipleClassMarker)
        return entry.attempt_get(owner.path, key)
353
354
355class _GetColumns:
356 __slots__ = ("cls",)
357
358 cls: Type[Any]
359
360 def __init__(self, cls: Type[Any]):
361 self.cls = cls
362
363 def __getattr__(self, key: str) -> Any:
364 mp = class_mapper(self.cls, configure=False)
365 if mp:
366 if key not in mp.all_orm_descriptors:
367 raise AttributeError(
368 "Class %r does not have a mapped column named %r"
369 % (self.cls, key)
370 )
371
372 desc = mp.all_orm_descriptors[key]
373 if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION:
374 assert isinstance(desc, attributes.QueryableAttribute)
375 prop = desc.property
376 if isinstance(prop, SynonymProperty):
377 key = prop.name
378 elif not isinstance(prop, ColumnProperty):
379 raise exc.InvalidRequestError(
380 "Property %r is not an instance of"
381 " ColumnProperty (i.e. does not correspond"
382 " directly to a Column)." % key
383 )
384 return getattr(self.cls, key)
385
386
# hook handed to inspection._inspects() for _GetColumns: it answers with
# the inspection of the class that the _GetColumns object wraps.
inspection._inspects(_GetColumns)(
    lambda target: inspection.inspect(target.cls)
)
390
391
class _GetTable:
    """Attribute-style access to tables in a ``MetaData`` collection.

    The instance carries a ``key``; attribute access combines the
    attribute name with that key via ``_get_table_key()`` and returns
    the matching entry of ``metadata.tables``.

    """

    __slots__ = "key", "metadata"

    key: str
    metadata: MetaData

    def __init__(self, key: str, metadata: MetaData):
        self.key = key
        self.metadata = metadata

    def __getattr__(self, key: str) -> Table:
        tables = self.metadata.tables
        return tables[_get_table_key(key, self.key)]
404
405
def _determine_container(key: str, value: Any) -> _GetColumns:
    """Wrap a class registry entry in a ``_GetColumns`` object.

    A ``_MultipleClassMarker`` is first resolved to the single class it
    refers to; any other value is wrapped as given.

    """
    target = (
        value.attempt_get([], key)
        if isinstance(value, _MultipleClassMarker)
        else value
    )
    return _GetColumns(target)
410
411
class _class_resolver:
    """Resolves the string argument given to a relationship() into the
    object it names.

    Names are looked up among the tables of the registry's metadata,
    the classes and modules of the class registry, any additional
    ``_resolvers`` and finally the ``fallback`` mapping.

    """

    __slots__ = (
        "cls",
        "prop",
        "arg",
        "fallback",
        "_dict",
        "_resolvers",
        "favor_tables",
    )

    cls: Type[Any]
    prop: RelationshipProperty[Any]
    fallback: Mapping[str, Any]
    arg: str
    favor_tables: bool
    _resolvers: Tuple[Callable[[str], Any], ...]

    def __init__(
        self,
        cls: Type[Any],
        prop: RelationshipProperty[Any],
        fallback: Mapping[str, Any],
        arg: str,
        favor_tables: bool = False,
    ):
        self.cls = cls
        self.prop = prop
        self.arg = arg
        self.fallback = fallback
        self.favor_tables = favor_tables
        self._resolvers = ()
        # namespace used for evaluation; presumably fills in missing
        # names on demand by calling _access_cls -- see util.PopulateDict
        self._dict = util.PopulateDict(self._access_cls)

    def _access_cls(self, key: str) -> Any:
        """Locate the object that the single name ``key`` refers to.

        Tables are consulted before classes when ``favor_tables`` is
        set, and after them otherwise.  Module names, extra resolvers
        and the fallback mapping follow, in that order; a name found
        nowhere surfaces as the ``KeyError`` of the fallback mapping.

        """
        owner = self.cls

        decl_base = attributes.manager_of_class(owner).registry
        assert decl_base is not None
        class_registry = decl_base._class_registry
        metadata = decl_base.metadata

        not_found = object()

        def from_tables() -> Any:
            if key in metadata.tables:
                return metadata.tables[key]
            if key in metadata._schemas:
                return _GetTable(key, getattr(owner, "metadata", metadata))
            return not_found

        def from_classes() -> Any:
            if key in class_registry:
                return _determine_container(key, class_registry[key])
            return not_found

        if self.favor_tables:
            lookups = (from_tables, from_classes)
        else:
            lookups = (from_classes, from_tables)

        for lookup in lookups:
            located = lookup()
            if located is not not_found:
                return located

        if "_sa_module_registry" in class_registry:
            module_root = cast(
                _ModuleMarker, class_registry["_sa_module_registry"]
            )
            if key in module_root:
                return module_root.resolve_attr(key)

        for resolver in self._resolvers:
            candidate = resolver(key)
            if candidate is not None:
                return candidate

        return self.fallback[key]

    def _raise_for_name(self, name: str, err: Exception) -> NoReturn:
        """Raise ``InvalidRequestError`` for a name that could not be
        located, chained to ``err``.

        A name of the form ``Something[Else]`` gets a dedicated message
        about generic classes.

        """
        generic_match = re.match(r"(.+)\[(.+)\]", name)

        if not generic_match:
            raise exc.InvalidRequestError(
                "When initializing mapper %s, expression %r failed to "
                "locate a name (%r). If this is a class name, consider "
                "adding this relationship() to the %r class after "
                "both dependent classes have been defined."
                % (self.prop.parent, self.arg, name, self.cls)
            ) from err

        clsarg = generic_match.group(2).strip("'")
        raise exc.InvalidRequestError(
            f"When initializing mapper {self.prop.parent}, "
            f'expression "relationship({self.arg!r})" seems to be '
            "using a generic class as the argument to relationship(); "
            "please state the generic argument "
            "using an annotation, e.g. "
            f'"{self.prop.key}: Mapped[{generic_match.group(1)}'
            f"['{clsarg}']] = relationship()\""
        ) from err

    def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]:
        """Resolve ``arg`` as a dotted name, without evaluating it as
        an expression.

        The first token is looked up in the resolver namespace, the
        remaining ones by attribute access on the previous result.

        """
        resolved: Any = None
        try:
            for part in self.arg.split("."):
                if resolved is None:
                    resolved = self._dict[part]
                else:
                    resolved = getattr(resolved, part)
        except KeyError as err:
            self._raise_for_name(self.arg, err)
        except NameError as err:
            self._raise_for_name(err.args[0], err)

        # both handlers above always raise, so this is the success path
        if isinstance(resolved, _GetColumns):
            return resolved.cls
        if TYPE_CHECKING:
            assert isinstance(resolved, (type, Table, _ModNS))
        return resolved

    def __call__(self) -> Any:
        """Evaluate ``arg`` as a Python expression and return the result;
        a ``_GetColumns`` wrapper is unwrapped to its class.

        """
        try:
            # NOTE(review): eval() runs ``arg`` as arbitrary Python; the
            # string must come from application code, never from
            # untrusted input.
            evaluated = eval(self.arg, globals(), self._dict)

            if isinstance(evaluated, _GetColumns):
                return evaluated.cls
            return evaluated
        except NameError as err:
            self._raise_for_name(err.args[0], err)
540
541
# namespace of last resort for name resolution; left as None until the
# first call to _resolver() fills it in.
_fallback_dict: Mapping[str, Any] = None  # type: ignore
543
544
def _resolver(cls: Type[Any], prop: RelationshipProperty[Any]) -> Tuple[
    Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
    Callable[[str, bool], _class_resolver],
]:
    """Build the two resolver factories for ``cls`` and ``prop``.

    Returns a ``(resolve_name, resolve_arg)`` tuple.  ``resolve_name``
    takes a string and returns a callable which resolves it as a dotted
    name; ``resolve_arg`` takes a string and returns a
    ``_class_resolver`` which evaluates it as an expression when called.

    """
    global _fallback_dict

    if _fallback_dict is None:
        # first use: build the shared fallback namespace out of the
        # top-level sqlalchemy namespace plus foreign() and remote().
        import sqlalchemy
        from . import foreign
        from . import remote

        base_namespace = util.immutabledict(sqlalchemy.__dict__)
        _fallback_dict = base_namespace.union(
            {"foreign": foreign, "remote": remote}
        )

    def resolve_name(
        arg: str,
    ) -> Callable[[], Union[Type[Any], Table, _ModNS]]:
        resolver = _class_resolver(cls, prop, _fallback_dict, arg)
        return resolver._resolve_name

    def resolve_arg(arg: str, favor_tables: bool = False) -> _class_resolver:
        return _class_resolver(
            cls, prop, _fallback_dict, arg, favor_tables=favor_tables
        )

    return resolve_name, resolve_arg