Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy/orm/clsregistry.py: 29%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

338 statements  

1# orm/clsregistry.py 

2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7 

8"""Routines to handle the string class registry used by declarative. 

9 

10This system allows specification of classes and expressions used in 

11:func:`_orm.relationship` using strings. 

12 

13""" 

14 

15from __future__ import annotations 

16 

17import re 

18from typing import Any 

19from typing import Callable 

20from typing import cast 

21from typing import Dict 

22from typing import Generator 

23from typing import Iterable 

24from typing import List 

25from typing import Mapping 

26from typing import MutableMapping 

27from typing import NoReturn 

28from typing import Optional 

29from typing import Set 

30from typing import Tuple 

31from typing import Type 

32from typing import TYPE_CHECKING 

33from typing import TypeVar 

34from typing import Union 

35import weakref 

36 

37from . import attributes 

38from . import interfaces 

39from .descriptor_props import SynonymProperty 

40from .properties import ColumnProperty 

41from .util import _metadata_for_cls 

42from .util import class_mapper 

43from .. import exc 

44from .. import inspection 

45from .. import util 

46from ..sql.schema import _get_table_key 

47from ..util.typing import CallableReference 

48 

49if TYPE_CHECKING: 

50 from .relationships import RelationshipProperty 

51 from ..sql.schema import MetaData 

52 from ..sql.schema import Table 

53 

54_T = TypeVar("_T", bound=Any) 

55 

56_ClsRegistryType = MutableMapping[str, Union[type, "_ClsRegistryToken"]] 

57 

58# strong references to registries which we place in 

59# the _decl_class_registry, which is usually weak referencing. 

60# the internal registries here link to classes with weakrefs and remove 

61# themselves when all references to contained classes are removed. 

62_registries: Set[_ClsRegistryToken] = set() 

63 

64 

def _add_class(
    classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType
) -> None:
    """Register ``cls`` under ``classname`` in ``decl_class_registry``.

    Two structures are updated:

    * the flat name -> class mapping.  When a name is reused for the first
      time, the existing class and ``cls`` are both placed in a new
      ``_MultipleClassMarker``;
    * the module tree stored under the ``"_sa_module_registry"`` key, which
      is created on demand.

    :raises InvalidRequestError: if ``classname`` collides with a module
      name in the module tree.
    """
    if classname not in decl_class_registry:
        decl_class_registry[classname] = cls
    else:
        existing = decl_class_registry[classname]
        if not isinstance(existing, _MultipleClassMarker):
            decl_class_registry[classname] = _MultipleClassMarker(
                [cls, cast("Type[Any]", existing)]
            )
        # NOTE(review): if the name already maps to a _MultipleClassMarker,
        # ``cls`` is not added to it here; it is only recorded in the module
        # tree below.  Confirm this is intended.

    root_module = decl_class_registry.get("_sa_module_registry")
    if root_module is None:
        root_module = _ModuleMarker("_sa_module_registry", None)
        decl_class_registry["_sa_module_registry"] = root_module
    root = cast(_ModuleMarker, root_module)

    # Register the class under every trailing slice of its dotted module
    # path.  For module "myapp.snacks.nuts" the tree gains
    #
    #   myapp -> snacks -> nuts -> (classes)
    #   snacks -> nuts -> (classes)
    #   nuts -> (classes)
    #
    # so that partial module paths can be used in string lookups.
    path_tokens = cls.__module__.split(".")
    for start in range(len(path_tokens)):
        module: Any = root
        for token in path_tokens[start:]:
            module = module.get_module(token)

        try:
            module.add_class(classname, cls)
        except AttributeError as ae:
            # the path ended on a class marker instead of a module marker
            if not isinstance(module, _ModuleMarker):
                raise exc.InvalidRequestError(
                    f'name "{classname}" matches both a '
                    "class name and a module name"
                ) from ae
            raise

117 

118 

def _remove_class(
    classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType
) -> None:
    """Undo ``_add_class``: drop ``cls`` (registered as ``classname``) from
    ``decl_class_registry`` and from its module tree, if one exists.

    If the final node reached along a module path is a class marker rather
    than a module marker, the resulting ``AttributeError`` is ignored.
    """
    if classname in decl_class_registry:
        entry = decl_class_registry[classname]
        if not isinstance(entry, _MultipleClassMarker):
            # NOTE(review): the plain entry is deleted without checking
            # that it is ``cls`` itself.
            del decl_class_registry[classname]
        else:
            entry.remove_item(cls)

    root_module = decl_class_registry.get("_sa_module_registry")
    if root_module is None:
        return
    root = cast(_ModuleMarker, root_module)

    # walk the same trailing-suffix paths that _add_class populated
    path_tokens = cls.__module__.split(".")
    for start in range(len(path_tokens)):
        module: Any = root
        for token in path_tokens[start:]:
            module = module.get_module(token)
        try:
            module.remove_class(classname, cls)
        except AttributeError:
            if isinstance(module, _ModuleMarker):
                raise

150 

151 

152def _key_is_empty( 

153 key: str, 

154 decl_class_registry: _ClsRegistryType, 

155 test: Callable[[Any], bool], 

156) -> bool: 

157 """test if a key is empty of a certain object. 

158 

159 used for unit tests against the registry to see if garbage collection 

160 is working. 

161 

162 "test" is a callable that will be passed an object should return True 

163 if the given object is the one we were looking for. 

164 

165 We can't pass the actual object itself b.c. this is for testing garbage 

166 collection; the caller will have to have removed references to the 

167 object itself. 

168 

169 """ 

170 if key not in decl_class_registry: 

171 return True 

172 

173 thing = decl_class_registry[key] 

174 if isinstance(thing, _MultipleClassMarker): 

175 for sub_thing in thing.contents: 

176 if test(sub_thing): 

177 return False 

178 else: 

179 raise NotImplementedError("unknown codepath") 

180 else: 

181 return not test(thing) 

182 

183 

184class _ClsRegistryToken: 

185 """an object that can be in the registry._class_registry as a value.""" 

186 

187 __slots__ = () 

188 

189 

class _MultipleClassMarker(_ClsRegistryToken):
    """Registry value standing in for one or more classes of the same name.

    Classes are held only through weak references.  When a reference dies
    (garbage collection) or is removed via :meth:`remove_item`, it is
    dropped from ``contents``.  Once ``contents`` is empty, the marker
    discards itself from the module-level ``_registries`` set and calls
    ``on_remove``, if one was given.
    """

    __slots__ = "on_remove", "contents", "__weakref__"

    contents: Set[weakref.ref[Type[Any]]]
    on_remove: CallableReference[Optional[Callable[[], None]]]

    def __init__(
        self,
        classes: Iterable[Type[Any]],
        on_remove: Optional[Callable[[], None]] = None,
    ):
        self.on_remove = on_remove
        self.contents = {
            weakref.ref(item, self._remove_item) for item in classes
        }
        # strong reference so the marker survives inside a registry that
        # may only hold weak references to its values
        _registries.add(self)

    def remove_item(self, cls: Type[Any]) -> None:
        """Drop ``cls`` from this marker; a no-op if it is not present."""
        # a fresh weakref compares equal to the stored one while the
        # referent is alive, so it can be used to discard it
        self._remove_item(weakref.ref(cls))

    def __iter__(self) -> Generator[Optional[Type[Any]], None, None]:
        # yields None for any reference whose class has been collected
        return (ref() for ref in self.contents)

    def attempt_get(self, path: List[str], key: str) -> Type[Any]:
        """Return the single class held by this marker.

        :param path: module path of the marker, used in the error message.
        :param key: the name being looked up.
        :raises InvalidRequestError: if more than one class is held, so the
          name is ambiguous at ``path``.
        :raises NameError: if no live class is held, i.e. every reference
          was removed or its class garbage collected.
        """
        if len(self.contents) > 1:
            raise exc.InvalidRequestError(
                'Multiple classes found for path "%s" '
                "in the registry of this declarative "
                "base. Please use a fully module-qualified path."
                % (".".join(path + [key]))
            )
        # Take a snapshot, since weakref callbacks may mutate ``contents``.
        # An emptied marker can remain in a registry after all of its classes
        # are gone; report that the same way as a dead reference (NameError)
        # instead of failing with IndexError.
        refs = list(self.contents)
        cls = refs[0]() if refs else None
        if cls is None:
            raise NameError(key)
        return cls

    def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None:
        # also used as the weakref callback for every entry in ``contents``
        self.contents.discard(ref)
        if not self.contents:
            _registries.discard(self)
            if self.on_remove:
                self.on_remove()

    def add_item(self, item: Type[Any]) -> None:
        """Add another class under this marker's name.

        A warning is emitted if a live class from the same module is
        already held.
        """
        # protect against class registration race condition against
        # asynchronous garbage collection calling _remove_item,
        # [ticket:3208] and [ticket:10782]
        modules = {
            cls.__module__
            for cls in [ref() for ref in list(self.contents)]
            if cls is not None
        }
        if item.__module__ in modules:
            # NOTE(review): despite the warning text, the earlier class is
            # not removed here; both references remain in ``contents``.
            util.warn(
                "This declarative base already contains a class with the "
                "same class name and module name as %s.%s, and will "
                "be replaced in the string-lookup table."
                % (item.__module__, item.__name__)
            )
        self.contents.add(weakref.ref(item, self._remove_item))

257 

258 

class _ModuleMarker(_ClsRegistryToken):
    """A node in the tree of module names kept in a declarative registry.

    ``contents`` maps each child name either to a nested ``_ModuleMarker``
    (a sub-module) or to a ``_MultipleClassMarker`` holding the classes
    registered under that name.  ``path`` lists the names from the root
    down to this node; the root node's path is empty.
    """

    __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"

    parent: Optional[_ModuleMarker]
    contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]]
    mod_ns: _ModNS
    path: List[str]

    def __init__(self, name: str, parent: Optional[_ModuleMarker]):
        self.parent = parent
        self.name = name
        self.contents = {}
        self.mod_ns = _ModNS(self)
        self.path = self.parent.path + [self.name] if self.parent else []
        # strong reference; released again in _remove_item
        _registries.add(self)

    def __contains__(self, name: str) -> bool:
        return name in self.contents

    def __getitem__(self, name: str) -> _ClsRegistryToken:
        return self.contents[name]

    def _remove_item(self, name: str) -> None:
        """Forget child ``name``.  When this leaves the node empty, detach
        the node from its parent (recursively) and release it."""
        self.contents.pop(name, None)
        if self.contents:
            return
        if self.parent is not None:
            self.parent._remove_item(self.name)
        _registries.discard(self)

    def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]:
        # call __getattr__ directly so that the normal attribute lookup on
        # the _ModNS object itself is bypassed
        return self.mod_ns.__getattr__(key)

    def get_module(self, name: str) -> _ModuleMarker:
        """Return the child node ``name``, creating it if it is missing."""
        marker = self.contents.get(name)
        if marker is None:
            marker = _ModuleMarker(name, self)
            self.contents[name] = marker
        return cast(_ModuleMarker, marker)

    def add_class(self, name: str, cls: Type[Any]) -> None:
        """Record ``cls`` under ``name`` in this module node.

        :raises InvalidRequestError: if ``name`` is already used by a
          sub-module of this node.
        """
        existing = self.contents.get(name)
        if existing is None:
            self.contents[name] = _MultipleClassMarker(
                [cls], on_remove=lambda: self._remove_item(name)
            )
            return

        try:
            cast(_MultipleClassMarker, existing).add_item(cls)
        except AttributeError as ae:
            if not isinstance(existing, _MultipleClassMarker):
                raise exc.InvalidRequestError(
                    f'name "{name}" matches both a '
                    "class name and a module name"
                ) from ae
            raise

    def remove_class(self, name: str, cls: Type[Any]) -> None:
        """Drop ``cls`` from the class marker stored under ``name``, if any."""
        existing = self.contents.get(name)
        if existing is not None:
            cast(_MultipleClassMarker, existing).remove_item(cls)

329 

330 

331class _ModNS: 

332 __slots__ = ("__parent",) 

333 

334 __parent: _ModuleMarker 

335 

336 def __init__(self, parent: _ModuleMarker): 

337 self.__parent = parent 

338 

339 def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]: 

340 try: 

341 value = self.__parent.contents[key] 

342 except KeyError: 

343 pass 

344 else: 

345 if value is not None: 

346 if isinstance(value, _ModuleMarker): 

347 return value.mod_ns 

348 else: 

349 assert isinstance(value, _MultipleClassMarker) 

350 return value.attempt_get(self.__parent.path, key) 

351 raise NameError( 

352 "Module %r has no mapped classes " 

353 "registered under the name %r" % (self.__parent.name, key) 

354 ) 

355 

356 

357class _GetColumns: 

358 __slots__ = ("cls",) 

359 

360 cls: Type[Any] 

361 

362 def __init__(self, cls: Type[Any]): 

363 self.cls = cls 

364 

365 def __getattr__(self, key: str) -> Any: 

366 mp = class_mapper(self.cls, configure=False) 

367 if mp: 

368 if key not in mp.all_orm_descriptors: 

369 raise AttributeError( 

370 "Class %r does not have a mapped column named %r" 

371 % (self.cls, key) 

372 ) 

373 

374 desc = mp.all_orm_descriptors[key] 

375 if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION: 

376 assert isinstance(desc, attributes.QueryableAttribute) 

377 prop = desc.property 

378 if isinstance(prop, SynonymProperty): 

379 key = prop.name 

380 elif not isinstance(prop, ColumnProperty): 

381 raise exc.InvalidRequestError( 

382 "Property %r is not an instance of" 

383 " ColumnProperty (i.e. does not correspond" 

384 " directly to a Column)." % key 

385 ) 

386 return getattr(self.cls, key) 

387 

388 

# Register _GetColumns with the inspection system, so that inspecting a
# proxy returns the inspection result for the class it wraps.
inspection._inspects(_GetColumns)(
    lambda proxy: inspection.inspect(proxy.cls)
)

392 

393 

394class _GetTable: 

395 __slots__ = "key", "metadata" 

396 

397 key: str 

398 metadata: MetaData 

399 

400 def __init__(self, key: str, metadata: MetaData): 

401 self.key = key 

402 self.metadata = metadata 

403 

404 def __getattr__(self, key: str) -> Table: 

405 return self.metadata.tables[_get_table_key(key, self.key)] 

406 

407 

def _determine_container(key: str, value: Any) -> _GetColumns:
    """Wrap a class-registry value in a ``_GetColumns`` proxy.

    A ``_MultipleClassMarker`` is first narrowed to its single class via
    ``attempt_get``, which raises if the name is ambiguous or no live class
    remains.
    """
    target = (
        value.attempt_get([], key)
        if isinstance(value, _MultipleClassMarker)
        else value
    )
    return _GetColumns(target)

412 

413 

class _class_resolver:
    """Deferred resolver for a string argument of a relationship.

    A name is looked up, in order, among:

    * the declarative class registry;
    * ``MetaData`` tables and schema names (tried before the class registry
      when ``tables_only`` is set);
    * the registry's module tree;
    * any callables in ``_resolvers``;
    * ``fallback``.

    ``_dict`` is a ``util.PopulateDict`` built on :meth:`_access_cls`, and
    is the namespace used for these lookups.
    """

    __slots__ = (
        "cls",
        "prop",
        "arg",
        "fallback",
        "_dict",
        "_resolvers",
        "tables_only",
    )

    cls: Type[Any]
    prop: RelationshipProperty[Any]
    fallback: Mapping[str, Any]
    arg: str
    tables_only: bool
    _resolvers: Tuple[Callable[[str], Any], ...]

    def __init__(
        self,
        cls: Type[Any],
        prop: RelationshipProperty[Any],
        fallback: Mapping[str, Any],
        arg: str,
        tables_only: bool = False,
    ):
        self.cls = cls
        self.prop = prop
        self.arg = arg
        self.fallback = fallback
        self._dict = util.PopulateDict(self._access_cls)
        self._resolvers = ()
        self.tables_only = tables_only

    def _resolve_table_key(
        self, key: str, metadata: MetaData
    ) -> Optional[Table]:
        """Return the ``Table`` in ``metadata`` named by ``key``, or None.

        When ``metadata.schema`` is set and ``key`` is unqualified, the
        schema-qualified name is tried first.  A match on the bare name is
        still returned in that case, but emits a deprecation warning.
        """
        if metadata.schema is not None and "." not in key:
            schema_key = _get_table_key(key, metadata.schema)
            if schema_key in metadata.tables:
                return metadata.tables[schema_key]
            if key in metadata.tables:
                util.warn_deprecated(
                    "The string '%s' was resolved to the "
                    "non-schema-qualified table '%s', however "
                    "the MetaData object has a default schema "
                    "of '%s'. In a future version of SQLAlchemy, "
                    "this unqualified name will be resolved as "
                    "'%s'. To reference a table without a "
                    "schema, use the Table object directly."
                    % (key, key, metadata.schema, schema_key),
                    "2.1",
                )
                return metadata.tables[key]
        elif key in metadata.tables:
            return metadata.tables[key]
        return None

    def _access_cls(self, key: str) -> Any:
        """Resolve a single name token; the creator for ``self._dict``.

        :raises KeyError: if ``fallback`` is reached and lacks ``key``.
        """
        cls = self.cls

        manager = attributes.manager_of_class(cls)
        registry = manager.registry
        assert registry is not None
        decl_class_registry = registry._class_registry
        metadata = _metadata_for_cls(cls, registry)

        if self.tables_only:
            table = self._resolve_table_key(key, metadata)
            if table is not None:
                return table
            elif key in metadata._schemas:
                return _GetTable(key, metadata)

        if key in decl_class_registry:
            dt = _determine_container(key, decl_class_registry[key])
            if self.tables_only:
                return dt.cls
            else:
                return dt

        if not self.tables_only:
            table = self._resolve_table_key(key, metadata)
            if table is not None:
                return table
            elif key in metadata._schemas:
                return _GetTable(key, metadata)

        if "_sa_module_registry" in decl_class_registry and key in cast(
            _ModuleMarker, decl_class_registry["_sa_module_registry"]
        ):
            _module_registry = cast(
                _ModuleMarker, decl_class_registry["_sa_module_registry"]
            )
            return _module_registry.resolve_attr(key)

        if self._resolvers:
            for resolv in self._resolvers:
                value = resolv(key)
                if value is not None:
                    return value

        return self.fallback[key]

    def _raise_for_name(self, name: str, err: Exception) -> NoReturn:
        """Raise ``InvalidRequestError`` for a failed lookup of ``name``,
        chained from ``err``.  A name that looks like a generic subscript
        (``X[Y]``) gets a hint to use a ``Mapped[...]`` annotation instead.
        """
        generic_match = re.match(r"(.+)\[(.+)\]", name)

        if generic_match:
            clsarg = generic_match.group(2).strip("'")
            raise exc.InvalidRequestError(
                f"When initializing mapper {self.prop.parent}, "
                f'expression "relationship({self.arg!r})" seems to be '
                "using a generic class as the argument to relationship(); "
                "please state the generic argument "
                "using an annotation, e.g. "
                f'"{self.prop.key}: Mapped[{generic_match.group(1)}'
                f"['{clsarg}']] = relationship()\""
            ) from err
        else:
            manager = attributes.manager_of_class(self.cls)
            registry = manager.registry
            metadata = (
                _metadata_for_cls(self.cls, registry)
                if registry is not None
                else None
            )

            # when deprecated fallback lookup in
            # _resolve_table_key is removed, consider adding
            # additional context to the error message if the
            # unqualified key is located under BLANK_SCHEMA
            if metadata is not None and metadata.schema is not None:
                schema_key = _get_table_key(name, metadata.schema)
                assert schema_key not in metadata.tables

            raise exc.InvalidRequestError(
                "When initializing mapper %s, expression %r failed to "
                "locate a name (%r). If this is a class name, consider "
                "adding this relationship() to the %r class after "
                "both dependent classes have been defined."
                % (
                    self.prop.parent,
                    self.arg,
                    name,
                    self.cls,
                )
            ) from err

    def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]:
        """Resolve ``self.arg`` as a dotted name such as ``"pkg.mod.Cls"``.

        The first token is looked up in ``self._dict``.  Each following
        token is applied with ``getattr``.  A ``_GetColumns`` result is
        unwrapped to the class it proxies.

        :raises InvalidRequestError: if a lookup fails with ``KeyError`` or
          ``NameError``.
        """
        name = self.arg
        # Handle the first token explicitly instead of using ``None`` as a
        # "not started" sentinel.  Otherwise an intermediate value of None
        # would send the next token back through the namespace dict.
        first, *rest = name.split(".")
        try:
            rval = self._dict[first]
            for token in rest:
                rval = getattr(rval, token)
        except KeyError as err:
            self._raise_for_name(name, err)
        except NameError as n:
            self._raise_for_name(n.args[0], n)
        else:
            if isinstance(rval, _GetColumns):
                return rval.cls
            if TYPE_CHECKING:
                assert isinstance(rval, (type, Table, _ModNS))
            return rval

    def __call__(self) -> Any:
        """Resolve ``self.arg``.

        With ``tables_only``, the string is looked up as a single name.
        Otherwise it is evaluated as a Python expression.

        :raises InvalidRequestError: if a name cannot be located.
        """
        if self.tables_only:
            try:
                return self._dict[self.arg]
            except KeyError as k:
                self._raise_for_name(self.arg, k)
            except NameError as n:
                # e.g. from _MultipleClassMarker.attempt_get when no live
                # class remains; report it like the eval() branch below
                # rather than leaking a bare NameError
                self._raise_for_name(n.args[0], n)
        else:
            try:
                # SECURITY: self.arg is executed by eval().  It must come
                # only from trusted mapping configuration, never from
                # untrusted input.
                x = eval(self.arg, globals(), self._dict)

                if isinstance(x, _GetColumns):
                    return x.cls
                else:
                    return x
            except NameError as n:
                self._raise_for_name(n.args[0], n)

600 

601 

602_fallback_dict: Mapping[str, Any] = None # type: ignore 

603 

604 

def _resolver(cls: Type[Any], prop: RelationshipProperty[Any]) -> Tuple[
    Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
    Callable[[str, bool], _class_resolver],
]:
    """Build the pair of string-resolution factories for ``prop`` on ``cls``.

    Returns ``(resolve_name, resolve_arg)``:

    * ``resolve_name(arg)`` returns a zero-argument callable that resolves
      ``arg`` as a dotted name;
    * ``resolve_arg(arg, tables_only=False)`` returns a ``_class_resolver``
      that resolves ``arg`` when called.

    On first use, the module-level fallback namespace is built from the
    ``sqlalchemy`` package namespace plus ``foreign`` and ``remote``.
    """
    global _fallback_dict

    if _fallback_dict is None:
        # imported lazily, at first call rather than at module import
        import sqlalchemy
        from . import foreign
        from . import remote

        orm_helpers = {"foreign": foreign, "remote": remote}
        base_namespace = util.immutabledict(sqlalchemy.__dict__)
        _fallback_dict = base_namespace.union(orm_helpers)

    def resolve_name(
        arg: str,
    ) -> Callable[[], Union[Type[Any], Table, _ModNS]]:
        resolver = _class_resolver(cls, prop, _fallback_dict, arg)
        return resolver._resolve_name

    def resolve_arg(arg: str, tables_only: bool = False) -> _class_resolver:
        return _class_resolver(
            cls, prop, _fallback_dict, arg, tables_only=tables_only
        )

    return resolve_name, resolve_arg