Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/sqlalchemy/orm/clsregistry.py: 31%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

308 statements  

1# orm/clsregistry.py 

2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7 

8"""Routines to handle the string class registry used by declarative. 

9 

10This system allows specification of classes and expressions used in 

11:func:`_orm.relationship` using strings. 

12 

13""" 

14 

15from __future__ import annotations 

16 

17import re 

18from typing import Any 

19from typing import Callable 

20from typing import cast 

21from typing import Dict 

22from typing import Generator 

23from typing import Iterable 

24from typing import List 

25from typing import Mapping 

26from typing import MutableMapping 

27from typing import NoReturn 

28from typing import Optional 

29from typing import Set 

30from typing import Tuple 

31from typing import Type 

32from typing import TYPE_CHECKING 

33from typing import TypeVar 

34from typing import Union 

35import weakref 

36 

37from . import attributes 

38from . import interfaces 

39from .descriptor_props import SynonymProperty 

40from .properties import ColumnProperty 

41from .util import class_mapper 

42from .. import exc 

43from .. import inspection 

44from .. import util 

45from ..sql.schema import _get_table_key 

46from ..util.typing import CallableReference 

47 

48if TYPE_CHECKING: 

49 from .relationships import RelationshipProperty 

50 from ..sql.schema import MetaData 

51 from ..sql.schema import Table 

52 

# Type variable for the class object handed to add_class().
_T = TypeVar("_T", bound=Any)

# A class registry maps a string name either to a class or to one of the
# ClsRegistryToken marker objects defined in this module.
_ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]]

# Strong references to the marker objects placed in the
# _decl_class_registry, which is usually weak referencing.  The markers
# themselves refer to classes through weakrefs and discard themselves
# from this set once every class they contain has been removed.
_registries: Set[ClsRegistryToken] = set()

62 

63 

def add_class(
    classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType
) -> None:
    """Register ``cls`` under ``classname`` in ``decl_class_registry``.

    The class is recorded twice:

    * at the top level of the registry, keyed by ``classname``; when a
      plain class is already stored under that name, the entry is
      replaced by a :class:`_MultipleClassMarker` holding both classes.
    * in the module tree rooted at the ``"_sa_module_registry"`` entry
      (created here on first use), once for every suffix of the dotted
      ``cls.__module__`` path.

    :raises InvalidRequestError: if the module-tree node that should
     receive the class is not a :class:`_ModuleMarker`, i.e. the name is
     already in use as a class name along that path.
    """
    if classname not in decl_class_registry:
        decl_class_registry[classname] = cls
    else:
        # class already exists.
        previous = decl_class_registry[classname]
        if not isinstance(previous, _MultipleClassMarker):
            marker = _MultipleClassMarker([cls, cast("Type[Any]", previous)])
            decl_class_registry[classname] = marker

    try:
        root = cast(_ModuleMarker, decl_class_registry["_sa_module_registry"])
    except KeyError:
        root = _ModuleMarker("_sa_module_registry", None)
        decl_class_registry["_sa_module_registry"] = root

    parts = cls.__module__.split(".")

    # build up a tree like this:
    # modulename:  myapp.snacks.nuts
    #
    # myapp->snack->nuts->(classes)
    # snack->nuts->(classes)
    # nuts->(classes)
    #
    # this allows partial token paths to be used.
    for start in range(len(parts)):
        module = root
        for part in parts[start:]:
            module = module.get_module(part)

        try:
            module.add_class(classname, cls)
        except AttributeError as ae:
            if isinstance(module, _ModuleMarker):
                # a genuine AttributeError from inside add_class()
                raise
            raise exc.InvalidRequestError(
                f'name "{classname}" matches both a '
                "class name and a module name"
            ) from ae

116 

117 

def remove_class(
    classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType
) -> None:
    """Remove ``cls`` from ``decl_class_registry``.

    The top-level entry for ``classname`` is deleted, or, when it is a
    :class:`_MultipleClassMarker`, only ``cls`` is taken out of it.  The
    class is then removed from every suffix path of ``cls.__module__``
    in the module tree stored under ``"_sa_module_registry"``; if that
    entry does not exist the function returns after the top-level step.
    """
    if classname in decl_class_registry:
        entry = decl_class_registry[classname]
        if not isinstance(entry, _MultipleClassMarker):
            del decl_class_registry[classname]
        else:
            entry.remove_item(cls)

    try:
        root = cast(_ModuleMarker, decl_class_registry["_sa_module_registry"])
    except KeyError:
        return

    parts = cls.__module__.split(".")

    # walk the same suffix paths that add_class() populates
    for start in range(len(parts)):
        module = root
        for part in parts[start:]:
            module = module.get_module(part)
        try:
            module.remove_class(classname, cls)
        except AttributeError:
            # a node that is not a _ModuleMarker has no remove_class();
            # that case is ignored, anything else is a real error.
            if isinstance(module, _ModuleMarker):
                raise

149 

150 

def _key_is_empty(
    key: str,
    decl_class_registry: _ClsRegistryType,
    test: Callable[[Any], bool],
) -> bool:
    """Report whether ``key`` no longer holds a particular object.

    Used by unit tests against the registry to check that garbage
    collection is working.

    ``test`` is a callable that receives a candidate object and returns
    True if it is the one being looked for.  The object itself cannot be
    passed in, because the caller must already have dropped its own
    references to it for the garbage collection check to mean anything.

    Returns True when ``key`` is absent, or when the single value stored
    under it does not satisfy ``test``.  For a
    :class:`_MultipleClassMarker` entry, returns False as soon as one
    member of its ``contents`` satisfies ``test``.

    :raises NotImplementedError: for a :class:`_MultipleClassMarker`
     entry in which no member satisfies ``test``.
    """
    if key not in decl_class_registry:
        return True

    candidate = decl_class_registry[key]
    if not isinstance(candidate, _MultipleClassMarker):
        return not test(candidate)

    if any(test(member) for member in candidate.contents):
        return False
    raise NotImplementedError("unknown codepath")

181 

182 

class ClsRegistryToken:
    """Base for objects that may be stored as a value in
    ``registry._class_registry``.

    The class defines no instance attributes of its own.
    """

    __slots__ = ()

187 

188 

class _MultipleClassMarker(ClsRegistryToken):
    """refers to multiple classes of the same name
    within _decl_class_registry.

    Classes are held through weak references.  Each reference is created
    with :meth:`_remove_item` as its callback, so a class drops out of
    ``contents`` when it is garbage collected.  Once ``contents`` is
    empty the marker removes itself from the module-level ``_registries``
    set and invokes ``on_remove`` if one was given.
    """

    __slots__ = "on_remove", "contents", "__weakref__"

    # weak references to every class registered under this name
    contents: Set[weakref.ref[Type[Any]]]
    # optional zero-argument hook called when ``contents`` becomes empty
    on_remove: CallableReference[Optional[Callable[[], None]]]

    def __init__(
        self,
        classes: Iterable[Type[Any]],
        on_remove: Optional[Callable[[], None]] = None,
    ):
        """Create a marker for ``classes`` and add it to ``_registries``."""
        self.on_remove = on_remove
        self.contents = {
            weakref.ref(item, self._remove_item) for item in classes
        }
        _registries.add(self)

    def remove_item(self, cls: Type[Any]) -> None:
        """Explicitly remove ``cls`` from this marker."""
        self._remove_item(weakref.ref(cls))

    def __iter__(self) -> Generator[Optional[Type[Any]], None, None]:
        """Iterate the referenced classes; a dead reference yields None."""
        return (ref() for ref in self.contents)

    def attempt_get(self, path: List[str], key: str) -> Type[Any]:
        """Return the single class held by this marker.

        :param path: module path tokens leading to this marker, used
         only to build the error message.
        :param key: the class name being resolved.
        :raises InvalidRequestError: if more than one class is present.
        :raises NameError: if the marker is empty or its only reference
         is dead.
        """
        # take one snapshot so that a weakref callback firing during the
        # checks below cannot change what is being inspected.
        refs = list(self.contents)
        if len(refs) > 1:
            raise exc.InvalidRequestError(
                'Multiple classes found for path "%s" '
                "in the registry of this declarative "
                "base. Please use a fully module-qualified path."
                % (".".join(path + [key]))
            )
        if not refs:
            # every class has been removed; report this the same way as
            # a dead reference rather than failing with IndexError.
            raise NameError(key)

        cls = refs[0]()
        if cls is None:
            raise NameError(key)
        return cls

    def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None:
        """Discard ``ref``; also used as the weakref callback.

        When the last reference goes away the marker is removed from
        ``_registries`` and ``on_remove`` is called if set.
        """
        self.contents.discard(ref)
        if not self.contents:
            _registries.discard(self)
            if self.on_remove:
                self.on_remove()

    def add_item(self, item: Type[Any]) -> None:
        """Add ``item`` to this marker.

        Emits a warning if a live class from the same module is already
        present.
        """
        # protect against class registration race condition against
        # asynchronous garbage collection calling _remove_item,
        # [ticket:3208] and [ticket:10782]
        modules = {
            cls.__module__
            for cls in [ref() for ref in list(self.contents)]
            if cls is not None
        }
        if item.__module__ in modules:
            util.warn(
                "This declarative base already contains a class with the "
                "same class name and module name as %s.%s, and will "
                "be replaced in the string-lookup table."
                % (item.__module__, item.__name__)
            )
        self.contents.add(weakref.ref(item, self._remove_item))

256 

257 

class _ModuleMarker(ClsRegistryToken):
    """Refers to a module name within
    _decl_class_registry.

    Markers form a tree: ``contents`` maps a name either to a child
    :class:`_ModuleMarker` or to a :class:`_MultipleClassMarker` holding
    the classes registered under that name.
    """

    __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"

    # enclosing marker, or None for the root of the tree
    parent: Optional[_ModuleMarker]
    # child modules and class markers, keyed by name
    contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]]
    # attribute-style namespace over ``contents``
    mod_ns: _ModNS
    # names from just below the root down to this marker
    path: List[str]

    def __init__(self, name: str, parent: Optional[_ModuleMarker]):
        """Create a marker called ``name`` and add it to ``_registries``."""
        self.name = name
        self.parent = parent
        self.contents = {}
        self.mod_ns = _ModNS(self)
        self.path = self.parent.path + [self.name] if self.parent else []
        _registries.add(self)

    def __contains__(self, name: str) -> bool:
        return name in self.contents

    def __getitem__(self, name: str) -> ClsRegistryToken:
        return self.contents[name]

    def _remove_item(self, name: str) -> None:
        """Drop ``name``; if this leaves a non-root marker empty, remove
        the marker from its parent and from ``_registries`` as well.
        """
        self.contents.pop(name, None)
        if self.contents or self.parent is None:
            return
        self.parent._remove_item(self.name)
        _registries.discard(self)

    def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]:
        """Resolve ``key`` through this marker's :class:`_ModNS`."""
        # call __getattr__ directly so that the lookup always goes
        # through ``contents`` rather than normal attribute access.
        return self.mod_ns.__getattr__(key)

    def get_module(self, name: str) -> _ModuleMarker:
        """Return the entry stored under ``name``, creating a child
        marker if there is none.
        """
        if name in self.contents:
            return cast(_ModuleMarker, self.contents[name])
        child = _ModuleMarker(name, self)
        self.contents[name] = child
        return child

    def add_class(self, name: str, cls: Type[Any]) -> None:
        """Register ``cls`` under ``name`` within this module.

        :raises InvalidRequestError: if ``name`` already refers to an
         entry that is not a :class:`_MultipleClassMarker`.
        """
        if name not in self.contents:
            # the hook removes the entry again once the marker empties
            self.contents[name] = _MultipleClassMarker(
                [cls], on_remove=lambda: self._remove_item(name)
            )
            return

        marker = cast(_MultipleClassMarker, self.contents[name])
        try:
            marker.add_item(cls)
        except AttributeError as ae:
            if isinstance(marker, _MultipleClassMarker):
                raise
            raise exc.InvalidRequestError(
                f'name "{name}" matches both a '
                "class name and a module name"
            ) from ae

    def remove_class(self, name: str, cls: Type[Any]) -> None:
        """Remove ``cls`` from the entry stored under ``name``, if any."""
        if name not in self.contents:
            return
        cast(_MultipleClassMarker, self.contents[name]).remove_item(cls)

327 

328 

class _ModNS:
    """Attribute-style view over the contents of a ``_ModuleMarker``.

    Looking up an attribute resolves the name against the parent
    marker's ``contents``.
    """

    __slots__ = ("__parent",)

    # the marker whose contents this namespace exposes
    __parent: _ModuleMarker

    def __init__(self, parent: _ModuleMarker):
        self.__parent = parent

    def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]:
        """Resolve ``key`` to a nested namespace or a class.

        A ``_ModuleMarker`` entry yields that marker's ``mod_ns``; any
        other entry is resolved with ``attempt_get()``.

        :raises NameError: if ``key`` is absent or stored as None.
        """
        owner = self.__parent
        # a missing key and an explicit None entry are handled alike
        entry = owner.contents.get(key)
        if entry is not None:
            if isinstance(entry, _ModuleMarker):
                return entry.mod_ns
            assert isinstance(entry, _MultipleClassMarker)
            return entry.attempt_get(owner.path, key)

        raise NameError(
            "Module %r has no mapped classes "
            "registered under the name %r" % (owner.name, key)
        )

353 

354 

class _GetColumns:
    """Wrap a class so that attribute access is limited to names that
    appear in its mapper's ``all_orm_descriptors``.
    """

    __slots__ = ("cls",)

    # the wrapped class
    cls: Type[Any]

    def __init__(self, cls: Type[Any]):
        self.cls = cls

    def __getattr__(self, key: str) -> Any:
        """Return the attribute of the wrapped class named by ``key``.

        A synonym is followed to the name it refers to.

        :raises AttributeError: if ``key`` is not among the mapper's
         ORM descriptors.
        :raises InvalidRequestError: if ``key`` names a non-extension
         descriptor whose property is neither a ``SynonymProperty`` nor
         a ``ColumnProperty``.
        """
        attr_name = key
        mapper = class_mapper(self.cls, configure=False)
        if mapper:
            if attr_name not in mapper.all_orm_descriptors:
                raise AttributeError(
                    "Class %r does not have a mapped column named %r"
                    % (self.cls, attr_name)
                )

            descriptor = mapper.all_orm_descriptors[attr_name]
            is_plain = (
                descriptor.extension_type
                is interfaces.NotExtension.NOT_EXTENSION
            )
            if is_plain:
                assert isinstance(descriptor, attributes.QueryableAttribute)
                prop = descriptor.property
                if isinstance(prop, SynonymProperty):
                    # fetch the attribute the synonym refers to instead
                    attr_name = prop.name
                elif not isinstance(prop, ColumnProperty):
                    raise exc.InvalidRequestError(
                        "Property %r is not an instance of"
                        " ColumnProperty (i.e. does not correspond"
                        " directly to a Column)." % attr_name
                    )
        return getattr(self.cls, attr_name)

385 

386 

# Register an inspection handler for _GetColumns: inspecting a wrapper
# is delegated to inspecting the class it wraps.
inspection._inspects(_GetColumns)(
    lambda wrapper: inspection.inspect(wrapper.cls)
)

390 

391 

class _GetTable:
    """Resolve attribute access to a table in a ``MetaData`` collection.

    ``key`` is stored at construction and passed as the second argument
    to ``_get_table_key()`` on each lookup.
    """

    __slots__ = "key", "metadata"

    key: str
    metadata: MetaData

    def __init__(self, key: str, metadata: MetaData):
        self.metadata = metadata
        self.key = key

    def __getattr__(self, key: str) -> Table:
        """Return ``metadata.tables`` entry for attribute name ``key``
        combined with the stored ``self.key``.
        """
        tables = self.metadata.tables
        return tables[_get_table_key(key, self.key)]

404 

405 

def _determine_container(key: str, value: Any) -> _GetColumns:
    """Wrap a registry value in :class:`_GetColumns`.

    A :class:`_MultipleClassMarker` is first reduced to its single
    class with ``attempt_get([], key)``; any other value is wrapped
    as given.
    """
    target = (
        value.attempt_get([], key)
        if isinstance(value, _MultipleClassMarker)
        else value
    )
    return _GetColumns(target)

410 

411 

class _class_resolver:
    """Resolve the string ``arg`` given to a relationship into the
    object it names.

    Names are looked up lazily through ``_dict``, a ``util.PopulateDict``
    whose values are produced by :meth:`_access_cls`.
    """

    __slots__ = (
        "cls",
        "prop",
        "arg",
        "fallback",
        "_dict",
        "_resolvers",
        "favor_tables",
    )

    cls: Type[Any]
    prop: RelationshipProperty[Any]
    fallback: Mapping[str, Any]
    arg: str
    favor_tables: bool
    _resolvers: Tuple[Callable[[str], Any], ...]

    def __init__(
        self,
        cls: Type[Any],
        prop: RelationshipProperty[Any],
        fallback: Mapping[str, Any],
        arg: str,
        favor_tables: bool = False,
    ):
        """
        :param cls: class whose registry and metadata are searched.
        :param prop: the relationship being configured; used in error
         messages.
        :param fallback: mapping consulted when nothing else matches.
        :param arg: the string to resolve.
        :param favor_tables: when True, table names are tried before
         class names instead of after.
        """
        self.cls = cls
        self.prop = prop
        self.arg = arg
        self.fallback = fallback
        self.favor_tables = favor_tables
        self._resolvers = ()
        self._dict = util.PopulateDict(self._access_cls)

    def _access_cls(self, key: str) -> Any:
        """Produce the value for a single name.

        Lookup order: tables and classes (tables first only when
        ``favor_tables`` is set), then the module registry, then any
        ``_resolvers``, and finally ``fallback[key]``.
        """
        cls = self.cls

        manager = attributes.manager_of_class(cls)
        decl_base = manager.registry
        assert decl_base is not None
        decl_class_registry = decl_base._class_registry
        metadata = decl_base.metadata

        if self.favor_tables:
            phases = ("tables", "classes")
        else:
            phases = ("classes", "tables")

        for phase in phases:
            if phase == "classes":
                if key in decl_class_registry:
                    return _determine_container(
                        key, decl_class_registry[key]
                    )
            elif key in metadata.tables:
                return metadata.tables[key]
            elif key in metadata._schemas:
                return _GetTable(key, getattr(cls, "metadata", metadata))

        if "_sa_module_registry" in decl_class_registry:
            module_root = cast(
                _ModuleMarker, decl_class_registry["_sa_module_registry"]
            )
            if key in module_root:
                return module_root.resolve_attr(key)

        for resolver in self._resolvers:
            found = resolver(key)
            if found is not None:
                return found

        return self.fallback[key]

    def _raise_for_name(self, name: str, err: Exception) -> NoReturn:
        """Raise ``InvalidRequestError`` for an unresolvable ``name``,
        chained to ``err``.

        A name shaped like ``Something[Other]`` gets a message that
        suggests stating the generic type in an annotation instead.
        """
        generic = re.match(r"(.+)\[(.+)\]", name)

        if generic is None:
            raise exc.InvalidRequestError(
                "When initializing mapper %s, expression %r failed to "
                "locate a name (%r). If this is a class name, consider "
                "adding this relationship() to the %r class after "
                "both dependent classes have been defined."
                % (self.prop.parent, self.arg, name, self.cls)
            ) from err

        container = generic.group(1)
        clsarg = generic.group(2).strip("'")
        raise exc.InvalidRequestError(
            f"When initializing mapper {self.prop.parent}, "
            f'expression "relationship({self.arg!r})" seems to be '
            "using a generic class as the argument to relationship(); "
            "please state the generic argument "
            "using an annotation, e.g. "
            f'"{self.prop.key}: Mapped[{container}'
            f"['{clsarg}']] = relationship()\""
        ) from err

    def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]:
        """Resolve ``arg`` as a dotted name, without evaluating it as an
        expression.

        The first token is looked up in ``_dict``; later tokens are read
        as attributes of the previous result.
        """
        name = self.arg
        lookup = self._dict
        rval = None
        try:
            for token in name.split("."):
                # fall back to the dict for as long as nothing is found
                rval = lookup[token] if rval is None else getattr(rval, token)
        except KeyError as err:
            self._raise_for_name(name, err)
        except NameError as n:
            self._raise_for_name(n.args[0], n)
        else:
            if isinstance(rval, _GetColumns):
                return rval.cls
            if TYPE_CHECKING:
                assert isinstance(rval, (type, Table, _ModNS))
            return rval

    def __call__(self) -> Any:
        """Evaluate ``arg`` as a Python expression and return the result.

        A :class:`_GetColumns` result is unwrapped to its class.
        """
        try:
            # NOTE(review): eval() runs ``arg`` as arbitrary Python, with
            # this module's globals and ``_dict`` as locals.  ``arg``
            # must never be built from untrusted input.
            result = eval(self.arg, globals(), self._dict)
            return result.cls if isinstance(result, _GetColumns) else result
        except NameError as n:
            self._raise_for_name(n.args[0], n)

540 

541 

# Names available to relationship() strings as a last resort; filled in
# on the first call to _resolver().
_fallback_dict: Mapping[str, Any] = None  # type: ignore


def _resolver(
    cls: Type[Any], prop: RelationshipProperty[Any]
) -> Tuple[
    Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
    Callable[[str, bool], _class_resolver],
]:
    """Return ``(resolve_name, resolve_arg)`` factories bound to ``cls``
    and ``prop``.

    ``resolve_arg(arg, favor_tables=False)`` returns a
    :class:`_class_resolver`; ``resolve_name(arg)`` returns the
    ``_resolve_name`` method of a new :class:`_class_resolver`.
    """
    global _fallback_dict

    if _fallback_dict is None:
        # imported here, on first use, rather than at module level
        import sqlalchemy
        from . import foreign
        from . import remote

        extras = {"foreign": foreign, "remote": remote}
        _fallback_dict = util.immutabledict(sqlalchemy.__dict__).union(extras)

    def resolve_name(
        arg: str,
    ) -> Callable[[], Union[Type[Any], Table, _ModNS]]:
        resolver = _class_resolver(cls, prop, _fallback_dict, arg)
        return resolver._resolve_name

    def resolve_arg(arg: str, favor_tables: bool = False) -> _class_resolver:
        return _class_resolver(
            cls, prop, _fallback_dict, arg, favor_tables=favor_tables
        )

    return resolve_name, resolve_arg