1# orm/instrumentation.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Defines SQLAlchemy's system of class instrumentation.
10
11This module is usually not directly visible to user applications, but
12defines a large part of the ORM's interactivity.
13
14instrumentation.py deals with registration of end-user classes
15for state tracking. It interacts closely with state.py
16and attributes.py which establish per-instance and per-class-attribute
17instrumentation, respectively.
18
19The class instrumentation system can be customized on a per-class
20or global basis using the :mod:`sqlalchemy.ext.instrumentation`
21module, which provides the means to build and specify
22alternate instrumentation forms.
23
.. versionchanged:: 0.8
    The instrumentation extension system was moved out of the
    ORM and into the external :mod:`sqlalchemy.ext.instrumentation`
    package. When that package is imported, it installs
    itself within sqlalchemy.orm so that its more comprehensive
    resolution mechanics take effect.
30
31"""
32
33
34from __future__ import annotations
35
36from typing import Any
37from typing import Callable
38from typing import cast
39from typing import Collection
40from typing import Dict
41from typing import Generic
42from typing import Iterable
43from typing import List
44from typing import Optional
45from typing import Protocol
46from typing import Set
47from typing import Tuple
48from typing import Type
49from typing import TYPE_CHECKING
50from typing import TypeVar
51from typing import Union
52import weakref
53
54from . import base
55from . import collections
56from . import exc
57from . import interfaces
58from . import state
59from ._typing import _O
60from .attributes import _is_collection_attribute_impl
61from .. import util
62from ..event import EventTarget
63from ..util import HasMemoized
64from ..util.typing import Literal
65
66if TYPE_CHECKING:
67 from ._typing import _RegistryType
68 from .attributes import AttributeImpl
69 from .attributes import QueryableAttribute
70 from .collections import _AdaptedCollectionProtocol
71 from .collections import _CollectionFactoryType
72 from .decl_base import _MapperConfig
73 from .events import InstanceEvents
74 from .mapper import Mapper
75 from .state import InstanceState
76 from ..event import dispatcher
77
# Sentinel recorded in ClassManager.originals by install_member() when the
# member being replaced was not present in the class __dict__;
# uninstall_member() then deletes the member instead of restoring it.
DEL_ATTR = util.symbol("DEL_ATTR")

# Type variable for the class handled by ClassManager._subclass_manager().
_T = TypeVar("_T", bound=Any)
80
81
82class _ExpiredAttributeLoaderProto(Protocol):
83 def __call__(
84 self,
85 state: state.InstanceState[Any],
86 toload: Set[str],
87 passive: base.PassiveFlag,
88 ) -> None: ...
89
90
91class _ManagerFactory(Protocol):
92 def __call__(self, class_: Type[_O]) -> ClassManager[_O]: ...
93
94
class ClassManager(
    HasMemoized,
    Dict[str, "QueryableAttribute[Any]"],
    Generic[_O],
    EventTarget,
):
    """Tracks state information at the class level.

    A ``ClassManager`` is itself a ``dict`` mapping attribute keys to the
    :class:`.QueryableAttribute` objects instrumented for ``class_``,
    including those copied from the managers of instrumented base
    classes.  It installs itself on the class under :attr:`.MANAGER_ATTR`
    (see :meth:`.manage`), installs and removes class-level descriptors
    and members (including the generated ``__init__``), creates
    per-instance :class:`.InstanceState` objects, and propagates
    attribute instrumentation to the managers of subclasses.

    """

    dispatch: dispatcher[ClassManager[_O]]

    # class attribute under which the manager is installed on class_, and
    # instance attribute under which each instance's InstanceState lives
    MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR
    STATE_ATTR = base.DEFAULT_STATE_ATTR

    # callable (instance, state) that stores ``state`` under STATE_ATTR;
    # used by new_instance(), setup_instance() and _new_state_if_none()
    _state_setter = staticmethod(util.attrsetter(STATE_ATTR))

    expired_attribute_loader: _ExpiredAttributeLoaderProto
    "previously known as deferred_scalar_loader"

    # optional constructor supplied through _update_state(); it becomes
    # original_init only when class_ does not define its own __init__
    # (i.e. class_.__init__ is object.__init__)
    init_method: Optional[Callable[..., None]]
    # the __init__ the generated __init__ delegates to; recomputed by
    # _update_state() until the manager is finalized
    original_init: Optional[Callable[..., None]] = None

    # the factory that produced this manager; assigned by
    # InstrumentationFactory.create_manager_for_cls()
    factory: Optional[_ManagerFactory]

    # weak reference to the declarative _MapperConfig passed to
    # _update_state(), if any
    declarative_scan: Optional[weakref.ref[_MapperConfig]] = None

    registry: _RegistryType

    if not TYPE_CHECKING:
        # starts as None during setup
        registry = None

    class_: Type[_O]

    # managers of the direct base classes of class_ that are instrumented
    _bases: List[ClassManager[Any]]

    # deprecated alias (read/write) for expired_attribute_loader
    @property
    @util.deprecated(
        "1.4",
        message="The ClassManager.deferred_scalar_loader attribute is now "
        "named expired_attribute_loader",
    )
    def deferred_scalar_loader(self):
        return self.expired_attribute_loader

    @deferred_scalar_loader.setter
    @util.deprecated(
        "1.4",
        message="The ClassManager.deferred_scalar_loader attribute is now "
        "named expired_attribute_loader",
    )
    def deferred_scalar_loader(self, obj):
        self.expired_attribute_loader = obj

    def __init__(self, class_):
        """Set up a manager for ``class_`` and install it on the class.

        Copies the attributes of the instrumented base-class managers into
        this manager, merges event listeners from every class in
        ``class_.__mro__`` that has a manager, then calls :meth:`.manage`.
        Emits a warning when ``class_`` defines ``__del__`` in its own
        ``__dict__``.
        """
        self.class_ = class_
        self.info = {}
        self.new_init = None
        # attributes instrumented directly on this class (not propagated)
        self.local_attrs = {}
        # members replaced by install_member(), keyed by name
        self.originals = {}
        self._finalized = False
        self.factory = None
        self.init_method = None

        # managers of direct bases; non-type bases are skipped and bases
        # without a manager are dropped
        self._bases = [
            mgr
            for mgr in cast(
                "List[Optional[ClassManager[Any]]]",
                [
                    opt_manager_of_class(base)
                    for base in self.class_.__bases__
                    if isinstance(base, type)
                ],
            )
            if mgr is not None
        ]

        # start out with every attribute already known to the base managers
        for base_ in self._bases:
            self.update(base_)

        # notify InstanceEvents about this new manager; presumably so that
        # listeners established ahead of time can be applied -- see
        # InstanceEvents._new_classmanager_instance to confirm
        cast(
            "InstanceEvents", self.dispatch._events
        )._new_classmanager_instance(class_, self)

        # merge in event listeners from every managed class in the MRO
        for basecls in class_.__mro__:
            mgr = opt_manager_of_class(basecls)
            if mgr is not None:
                self.dispatch._update(mgr.dispatch)

        self.manage()

        if "__del__" in class_.__dict__:
            util.warn(
                "__del__() method on class %s will "
                "cause unreachable cycles and memory leaks, "
                "as SQLAlchemy instrumentation often creates "
                "reference cycles. Please remove this method." % class_
            )

    def _update_state(
        self,
        finalize: bool = False,
        mapper: Optional[Mapper[_O]] = None,
        registry: Optional[_RegistryType] = None,
        declarative_scan: Optional[_MapperConfig] = None,
        expired_attribute_loader: Optional[
            _ExpiredAttributeLoaderProto
        ] = None,
        init_method: Optional[Callable[..., None]] = None,
    ) -> None:
        """Apply configuration to this manager.

        Each truthy argument is applied: ``mapper`` is stored as
        :attr:`.mapper`, ``registry`` receives this manager through
        ``registry._add_manager()``, ``declarative_scan`` is stored as a
        weak reference, and ``expired_attribute_loader`` / ``init_method``
        replace the current values.  While the manager is not finalized,
        :attr:`.original_init` is recomputed.  If ``finalize`` is True the
        manager is then finalized via :meth:`._finalize`.

        An ``init_method`` passed after finalization trips an ``assert``
        (AssertionError, unless assertions are disabled with ``-O``).
        """
        if mapper:
            # stored in __dict__, which is what is_mapped checks for
            self.mapper = mapper  #
        if registry:
            registry._add_manager(self)
        if declarative_scan:
            self.declarative_scan = weakref.ref(declarative_scan)
        if expired_attribute_loader:
            self.expired_attribute_loader = expired_attribute_loader

        if init_method:
            assert not self._finalized, (
                "class is already instrumented, "
                "init_method %s can't be applied" % init_method
            )
            self.init_method = init_method

        if not self._finalized:
            # prefer the supplied init_method only when the class has no
            # __init__ of its own
            self.original_init = (
                self.init_method
                if self.init_method is not None
                and self.class_.__init__ is object.__init__
                else self.class_.__init__
            )

        if finalize and not self._finalized:
            self._finalize()

    def _finalize(self) -> None:
        """Finalize this manager, at most once.

        Installs the generated ``__init__`` and fires ``class_instrument``
        on the module-level instrumentation factory's dispatch.
        """
        if self._finalized:
            return
        self._finalized = True

        self._instrument_init()

        _instrumentation_factory.dispatch.class_instrument(self.class_)

    # identity-based hashing/equality, overriding dict's value-based
    # (and unhashable) behavior
    def __hash__(self) -> int:  # type: ignore[override]
        return id(self)

    def __eq__(self, other: Any) -> bool:
        return other is self

    @property
    def is_mapped(self) -> bool:
        """True once :attr:`.mapper` has been assigned on this manager.

        The ``mapper`` getter below raises, so the ``"mapper"`` key only
        appears in ``__dict__`` after an explicit assignment.

        NOTE(review): :meth:`.unregister` assigns ``self.mapper = None``,
        after which this property still returns True.
        """
        return "mapper" in self.__dict__

    # memoized views over the instrumented attributes; instrument_attribute()
    # and uninstrument_attribute() call _reset_memoizations(), presumably to
    # invalidate these when the key set changes

    @HasMemoized.memoized_attribute
    def _all_key_set(self):
        # every instrumented attribute key
        return frozenset(self)

    @HasMemoized.memoized_attribute
    def _collection_impl_keys(self):
        # keys whose impl is a collection
        return frozenset(
            [attr.key for attr in self.values() if attr.impl.collection]
        )

    @HasMemoized.memoized_attribute
    def _scalar_loader_impls(self):
        # impls that accept a scalar loader
        return frozenset(
            [
                attr.impl
                for attr in self.values()
                if attr.impl.accepts_scalar_loader
            ]
        )

    @HasMemoized.memoized_attribute
    def _loader_impls(self):
        # all attribute impls
        return frozenset([attr.impl for attr in self.values()])

    @util.memoized_property
    def mapper(self) -> Mapper[_O]:
        # raises unless self.mapper has been assigned
        raise exc.UnmappedClassError(self.class_)

    def _all_sqla_attributes(self, exclude=None):
        """Return an iterator of ``(key, attribute)`` pairs for all
        class-bound attributes that implement :class:`.InspectionAttr`
        and have ``is_attribute`` set.

        This includes :class:`.QueryableAttribute` as well as extension
        types such as :class:`.hybrid_property` and
        :class:`.AssociationProxy`.

        Values are read from class ``__dict__`` entries rather than via
        ``getattr()``, so descriptors are not invoked.  The ``exclude``
        argument is accepted but not used by this implementation.

        """

        # key -> most-derived class seen so far whose __dict__ has the key
        found: Dict[str, Any] = {}

        # constraints:
        # 1. yield keys in cls.__dict__ order
        # 2. if a subclass has the same key as a superclass, include that
        #    key as part of the ordering of the superclass, because an
        #    overridden key is usually installed by the mapper, which
        #    follows a different ordering
        # 3. don't use getattr() as this fires off descriptors

        # [0:-1] leaves out the last MRO entry (``object``)
        for supercls in self.class_.__mro__[0:-1]:
            inherits = supercls.__mro__[1]
            for key in supercls.__dict__:
                found.setdefault(key, supercls)
                # defer keys that the next class also defines; they are
                # yielded at that class's position, using the value from
                # the most-derived class recorded in ``found``
                if key in inherits.__dict__:
                    continue
                val = found[key].__dict__[key]
                if (
                    isinstance(val, interfaces.InspectionAttr)
                    and val.is_attribute
                ):
                    yield key, val

    def _get_class_attr_mro(self, key, default=None):
        """return an attribute on the class without tripping it.

        Returns the first ``__dict__`` entry for ``key`` found while walking
        ``class_.__mro__``, or ``default`` if no class defines it.
        """

        for supercls in self.class_.__mro__:
            if key in supercls.__dict__:
                return supercls.__dict__[key]
        else:
            return default

    def _attr_has_impl(self, key: str) -> bool:
        """Return True if the given attribute is fully initialized.

        i.e. has an impl.
        """

        return key in self and self[key].impl is not None

    def _subclass_manager(self, cls: Type[_T]) -> ClassManager[_T]:
        """Create a new ClassManager for a subclass of this ClassManager's
        class.

        This is called automatically when attributes are instrumented so that
        the attributes can be propagated to subclasses against their own
        class-local manager, without the need for mappers etc. to have already
        pre-configured managers for the full class hierarchy. Mappers
        can post-configure the auto-generated ClassManager when needed.

        Because this goes through ``register_class(cls, finalize=False)``,
        an existing manager for ``cls`` is returned as-is.

        """
        return register_class(cls, finalize=False)

    def _instrument_init(self):
        """Generate the instrumented ``__init__`` around ``original_init``
        and install it on the class via :meth:`.install_member`."""
        self.new_init = _generate_init(self.class_, self, self.original_init)
        self.install_member("__init__", self.new_init)

    @util.memoized_property
    def _state_constructor(self) -> Type[state.InstanceState[_O]]:
        # fires first_init and returns the InstanceState class; being a
        # memoized property, the event presumably fires only on first access
        self.dispatch.first_init(self, self.class_)
        return state.InstanceState

    def manage(self):
        """Mark this instance as the manager for its class."""

        setattr(self.class_, self.MANAGER_ATTR, self)

    @util.hybridmethod
    def manager_getter(self):
        """Return the default (class) -> ClassManager callable."""
        return _default_manager_getter

    @util.hybridmethod
    def state_getter(self):
        """Return a (instance) -> InstanceState callable.

        "state getter" callables should raise either KeyError or
        AttributeError if no InstanceState could be found for the
        instance.
        """

        return _default_state_getter

    @util.hybridmethod
    def dict_getter(self):
        """Return the default (instance) -> dict callable."""
        return _default_dict_getter

    def instrument_attribute(
        self,
        key: str,
        inst: QueryableAttribute[Any],
        propagated: bool = False,
    ) -> None:
        """Instrument ``inst`` under ``key`` on this manager.

        A local (non-propagated) attribute is recorded in ``local_attrs``
        and installed as a descriptor on ``class_``.  A propagated
        attribute is ignored when a local attribute with the same key
        exists.  The attribute is then stored in this dict and propagated
        to the manager of every direct subclass of ``class_``, which
        :meth:`._subclass_manager` creates when missing.
        """
        if propagated:
            if key in self.local_attrs:
                return  # don't override local attr with inherited attr
        else:
            self.local_attrs[key] = inst
            self.install_descriptor(key, inst)
        self._reset_memoizations()
        self[key] = inst

        for cls in self.class_.__subclasses__():
            manager = self._subclass_manager(cls)
            manager.instrument_attribute(key, inst, True)

    def subclass_managers(self, recursive):
        """Yield the existing managers of direct subclasses of ``class_``;
        with ``recursive``, also yield each of their subclass managers."""
        for cls in self.class_.__subclasses__():
            mgr = opt_manager_of_class(cls)
            if mgr is not None and mgr is not self:
                yield mgr
                if recursive:
                    yield from mgr.subclass_managers(True)

    def post_configure_attribute(self, key):
        """Fire ``attribute_instrument`` on the module-level
        instrumentation factory for the attribute stored under ``key``."""
        _instrumentation_factory.dispatch.attribute_instrument(
            self.class_, key, self[key]
        )

    def uninstrument_attribute(self, key, propagated=False):
        """Reverse :meth:`.instrument_attribute` for ``key``.

        Does nothing if ``key`` is not instrumented.  A propagated removal
        leaves local attributes in place.  A local removal deletes the
        entry from ``local_attrs`` and removes the class descriptor.  The
        removal then cascades to the existing managers of direct
        subclasses.
        """
        if key not in self:
            return
        if propagated:
            if key in self.local_attrs:
                return  # don't get rid of local attr
        else:
            # NOTE(review): raises KeyError if ``key`` is instrumented here
            # only as an inherited attribute
            del self.local_attrs[key]
            self.uninstall_descriptor(key)
        self._reset_memoizations()
        del self[key]
        for cls in self.class_.__subclasses__():
            manager = opt_manager_of_class(cls)
            if manager:
                manager.uninstrument_attribute(key, True)

    def unregister(self) -> None:
        """remove all instrumentation established by this ClassManager."""

        # restore or delete every member replaced via install_member(),
        # such as the generated __init__
        for key in list(self.originals):
            self.uninstall_member(key)

        self.mapper = None
        self.dispatch = None  # type: ignore
        self.new_init = None
        self.info.clear()

        # only attributes instrumented locally are removed here; propagated
        # ones were installed through base-class managers
        for key in list(self):
            if key in self.local_attrs:
                self.uninstrument_attribute(key)

        # only remove the manager attribute when it lives on class_ itself
        if self.MANAGER_ATTR in self.class_.__dict__:
            delattr(self.class_, self.MANAGER_ATTR)

    def install_descriptor(
        self, key: str, inst: QueryableAttribute[Any]
    ) -> None:
        """Set ``inst`` as class attribute ``key`` on ``class_``.

        Raises KeyError if ``key`` collides with STATE_ATTR or MANAGER_ATTR.
        """
        if key in (self.STATE_ATTR, self.MANAGER_ATTR):
            raise KeyError(
                "%r: requested attribute name conflicts with "
                "instrumentation attribute of the same name." % key
            )
        setattr(self.class_, key, inst)

    def uninstall_descriptor(self, key: str) -> None:
        """Delete class attribute ``key`` from ``class_``."""
        delattr(self.class_, key)

    def install_member(self, key: str, implementation: Any) -> None:
        """Replace class member ``key`` with ``implementation``.

        The first time ``key`` is installed, the previous value from
        ``class_.__dict__`` (or DEL_ATTR if there was none) is saved in
        ``originals`` so :meth:`.uninstall_member` can undo it.  Raises
        KeyError if ``key`` collides with STATE_ATTR or MANAGER_ATTR.
        """
        if key in (self.STATE_ATTR, self.MANAGER_ATTR):
            raise KeyError(
                "%r: requested attribute name conflicts with "
                "instrumentation attribute of the same name." % key
            )
        self.originals.setdefault(key, self.class_.__dict__.get(key, DEL_ATTR))
        setattr(self.class_, key, implementation)

    def uninstall_member(self, key: str) -> None:
        """Undo :meth:`.install_member` for ``key``.

        Restores the saved original, or deletes the member when the
        original was DEL_ATTR.

        NOTE(review): if ``key`` is not in ``originals`` the attribute is
        set to None rather than left alone.
        """
        original = self.originals.pop(key, None)
        if original is not DEL_ATTR:
            setattr(self.class_, key, original)
        else:
            delattr(self.class_, key)

    def instrument_collection_class(
        self, key: str, collection_class: Type[Collection[Any]]
    ) -> _CollectionFactoryType:
        """Return the collection factory produced by
        ``collections.prepare_instrumentation()`` for ``collection_class``;
        ``key`` is not used here."""
        return collections.prepare_instrumentation(collection_class)

    def initialize_collection(
        self,
        key: str,
        state: InstanceState[_O],
        factory: _CollectionFactoryType,
    ) -> Tuple[collections.CollectionAdapter, _AdaptedCollectionProtocol]:
        """Create a new collection from ``factory`` and wrap it in a
        :class:`.CollectionAdapter` for attribute ``key`` on ``state``.

        Returns ``(adapter, user_data)``.
        """
        user_data = factory()
        impl = self.get_impl(key)
        assert _is_collection_attribute_impl(impl)
        adapter = collections.CollectionAdapter(impl, state, user_data)
        return adapter, user_data

    def is_instrumented(self, key: str, search: bool = False) -> bool:
        """Return True if ``key`` is instrumented.

        With ``search=True`` inherited attributes count as well; otherwise
        only attributes local to this class are considered.
        """
        if search:
            return key in self
        else:
            return key in self.local_attrs

    def get_impl(self, key: str) -> AttributeImpl:
        """Return the AttributeImpl of the attribute stored under ``key``."""
        return self[key].impl

    @property
    def attributes(self) -> Iterable[Any]:
        """Iterator over the instrumented attribute objects."""
        return iter(self.values())

    # InstanceState management

    def new_instance(self, state: Optional[InstanceState[_O]] = None) -> _O:
        """Create an instance of ``class_`` without calling ``__init__``
        and attach ``state`` (or a newly constructed InstanceState)."""
        # here, we would prefer _O to be bound to "object"
        # so that mypy sees that __new__ is present. currently
        # it's bound to Any as there were other problems not having
        # it that way but these can be revisited
        instance = self.class_.__new__(self.class_)
        if state is None:
            state = self._state_constructor(instance, self)
        self._state_setter(instance, state)
        return instance

    def setup_instance(
        self, instance: _O, state: Optional[InstanceState[_O]] = None
    ) -> None:
        """Attach ``state`` (or a newly constructed InstanceState) to an
        existing ``instance``."""
        if state is None:
            state = self._state_constructor(instance, self)
        self._state_setter(instance, state)

    def teardown_instance(self, instance: _O) -> None:
        """Remove the InstanceState attribute from ``instance``."""
        delattr(instance, self.STATE_ATTR)

    def _serialize(
        self, state: InstanceState[_O], state_dict: Dict[str, Any]
    ) -> _SerializeManager:
        """Return a :class:`._SerializeManager` standing in for this
        manager while ``state`` is pickled."""
        return _SerializeManager(state, state_dict)

    def _new_state_if_none(
        self, instance: _O
    ) -> Union[Literal[False], InstanceState[_O]]:
        """Install a default InstanceState if none is present.

        A private convenience method used by the __init__ decorator.

        Returns False if ``instance`` already has a state, otherwise the
        newly installed InstanceState.

        """
        if hasattr(instance, self.STATE_ATTR):
            return False
        elif self.class_ is not instance.__class__ and self.is_mapped:
            # this will create a new ClassManager for the
            # subclass, without a mapper. This is likely a
            # user error situation but allow the object
            # to be constructed, so that it is usable
            # in a non-ORM context at least.
            return self._subclass_manager(
                instance.__class__
            )._new_state_if_none(instance)
        else:
            state = self._state_constructor(instance, self)
            self._state_setter(instance, state)
            return state

    def has_state(self, instance: _O) -> bool:
        """Return True if ``instance`` carries an InstanceState."""
        return hasattr(instance, self.STATE_ATTR)

    def has_parent(
        self, state: InstanceState[_O], key: str, optimistic: bool = False
    ) -> bool:
        """Return the result of ``hasparent(state, optimistic=...)`` on the
        AttributeImpl for ``key``."""
        return self.get_impl(key).hasparent(state, optimistic=optimistic)

    def __bool__(self) -> bool:
        """All ClassManagers are non-zero regardless of attribute state."""
        # without this, an empty manager would be falsy like any empty dict
        return True

    def __repr__(self) -> str:
        return "<%s of %r at %x>" % (
            self.__class__.__name__,
            self.class_,
            id(self),
        )
571
572
573class _SerializeManager:
574 """Provide serialization of a :class:`.ClassManager`.
575
576 The :class:`.InstanceState` uses ``__init__()`` on serialize
577 and ``__call__()`` on deserialize.
578
579 """
580
581 def __init__(self, state: state.InstanceState[Any], d: Dict[str, Any]):
582 self.class_ = state.class_
583 manager = state.manager
584 manager.dispatch.pickle(state, d)
585
586 def __call__(self, state, inst, state_dict):
587 state.manager = manager = opt_manager_of_class(self.class_)
588 if manager is None:
589 raise exc.UnmappedInstanceError(
590 inst,
591 "Cannot deserialize object of type %r - "
592 "no mapper() has "
593 "been configured for this class within the current "
594 "Python process!" % self.class_,
595 )
596 elif manager.is_mapped and not manager.mapper.configured:
597 manager.mapper._check_configure()
598
599 # setup _sa_instance_state ahead of time so that
600 # unpickle events can access the object normally.
601 # see [ticket:2362]
602 if inst is not None:
603 manager.setup_instance(inst, state)
604 manager.dispatch.unpickle(state, state_dict)
605
606
class InstrumentationFactory(EventTarget):
    """Factory for new ClassManager instances.

    Subclasses may override :meth:`._locate_extended_factory` and
    :meth:`._check_conflicts` to supply alternative managers.
    """

    dispatch: dispatcher[InstrumentationFactory]

    def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]:
        """Create the manager for a not-yet-instrumented ``class_``.

        An extended factory located by :meth:`._locate_extended_factory`
        takes precedence; otherwise a plain :class:`.ClassManager` is
        built.  The factory used is recorded on ``manager.factory``.
        """
        assert class_ is not None
        assert opt_manager_of_class(class_) is None

        # a more specialized subclass gets the first chance to build
        # the manager
        manager, factory = self._locate_extended_factory(class_)

        if factory is not None:
            assert manager is not None
        else:
            factory = ClassManager
            manager = factory(class_)

        self._check_conflicts(class_, factory)
        manager.factory = factory
        return manager

    def _locate_extended_factory(
        self, class_: Type[_O]
    ) -> Tuple[Optional[ClassManager[_O]], Optional[_ManagerFactory]]:
        """Overridden by a subclass to do an extended lookup."""
        return None, None

    def _check_conflicts(
        self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]]
    ) -> None:
        """Overridden by a subclass to test for conflicting factories."""

    def unregister(self, class_: Type[_O]) -> None:
        """Tear down the instrumentation of ``class_``, then fire
        ``class_uninstrument``."""
        manager_of_class(class_).unregister()
        self.dispatch.class_uninstrument(class_)
647
648
# The module-wide instrumentation factory.  This attribute is replaced by
# sqlalchemy.ext.instrumentation when that module is imported.
_instrumentation_factory = InstrumentationFactory()

# Instance / class accessors.  These attributes are replaced by
# sqlalchemy.ext.instrumentation when a non-standard InstrumentationManager
# class is first used to instrument a class.  The ``_default_*`` names are
# what ClassManager.manager_getter / state_getter / dict_getter return.
_default_state_getter = base.instance_state
instance_state = _default_state_getter

_default_dict_getter = base.instance_dict
instance_dict = _default_dict_getter

_default_manager_getter = base.manager_of_class
manager_of_class = _default_manager_getter

_default_opt_manager_getter = base.opt_manager_of_class
opt_manager_of_class = _default_opt_manager_getter
662
663
def register_class(
    class_: Type[_O],
    finalize: bool = True,
    mapper: Optional[Mapper[_O]] = None,
    registry: Optional[_RegistryType] = None,
    declarative_scan: Optional[_MapperConfig] = None,
    expired_attribute_loader: Optional[_ExpiredAttributeLoaderProto] = None,
    init_method: Optional[Callable[..., None]] = None,
) -> ClassManager[_O]:
    """Register class instrumentation.

    The :class:`.ClassManager` already present on ``class_`` is reused;
    otherwise the module-level instrumentation factory creates one.  The
    remaining arguments are then applied to the manager through
    ``ClassManager._update_state()``.

    :param class_: the class to instrument.
    :param finalize: when True, finalize the manager (install the
     instrumented ``__init__``) if that has not happened yet.
    :param mapper: optional mapper to associate with the manager.
    :param registry: optional registry the manager is added to.
    :param declarative_scan: optional declarative configuration, kept by
     weak reference.
    :param expired_attribute_loader: optional loader for expired
     attributes.
    :param init_method: optional constructor, used only when ``class_``
     has no ``__init__`` of its own; not accepted once the manager is
     finalized.
    :return: the existing or newly created class manager.

    """
    existing = opt_manager_of_class(class_)
    manager = (
        _instrumentation_factory.create_manager_for_cls(class_)
        if existing is None
        else existing
    )
    manager._update_state(
        finalize=finalize,
        mapper=mapper,
        registry=registry,
        declarative_scan=declarative_scan,
        expired_attribute_loader=expired_attribute_loader,
        init_method=init_method,
    )
    return manager
692
693
def unregister_class(class_):
    """Unregister class instrumentation.

    Delegates to the module-level instrumentation factory, which tears
    down the class's :class:`.ClassManager` and then fires
    ``class_uninstrument``.
    """
    factory = _instrumentation_factory
    factory.unregister(class_)
698
699
def is_instrumented(instance, key):
    """Return True if the given attribute on the given instance is
    instrumented by the attributes package.

    This function may be used regardless of instrumentation
    applied directly to the class, i.e. no descriptors are required.

    Inherited attributes count: the lookup is done with ``search=True``
    on the manager of ``instance.__class__``.
    """
    manager = manager_of_class(instance.__class__)
    return manager.is_instrumented(key, search=True)
711
712
def _generate_init(class_, class_manager, original_init):
    """Build an __init__ decorator that triggers ClassManager events.

    The generated function has the argument list that
    ``util.format_argspec_init(original_init, grouped=False)`` renders for
    ``original_init``.  When called, it asks ``class_manager`` to install
    an InstanceState on the instance if none is present.  If a new state
    was created, the call goes through
    ``new_state._initialize_instance(...)``; otherwise ``original_init``
    is called directly.  Either way its return value is passed through.

    :param class_: the class being instrumented; its ``__init__`` is used
     when ``original_init`` is None.
    :param class_manager: the ClassManager for ``class_``.
    :param original_init: the ``__init__`` to wrap, or None.
    :return: the generated ``__init__`` function.  It carries the
     original's ``__doc__``, its ``__defaults__`` / ``__kwdefaults__``
     when those are non-empty, and an ``_sa_original_init`` attribute
     referring to ``original_init``.
    """

    # TODO: we should use the ClassManager's notion of the
    # original '__init__' method, once ClassManager is fixed
    # to always reference that.

    if original_init is None:
        original_init = class_.__init__

    # Go through some effort here and don't change the user's __init__
    # calling signature, including the unlikely case that it has
    # a return value.
    # FIXME: need to juggle local names to avoid constructor argument
    # clashes.
    func_body = """\
def __init__(%(apply_pos)s):
    new_state = class_manager._new_state_if_none(%(self_arg)s)
    if new_state:
        return new_state._initialize_instance(%(apply_kw)s)
    else:
        return original_init(%(apply_kw)s)
"""
    # keys used by the template: apply_pos, self_arg, apply_kw
    func_vars = util.format_argspec_init(original_init, grouped=False)
    func_text = func_body % func_vars

    func_defaults = getattr(original_init, "__defaults__", None)
    func_kw_defaults = getattr(original_init, "__kwdefaults__", None)

    # The template refers to ``class_manager`` and ``original_init`` by
    # name; they reach the exec'd code through this copy of the local
    # namespace, so those local names must not be renamed.
    env = locals().copy()
    env["__name__"] = __name__
    exec(func_text, env)
    __init__ = env["__init__"]
    __init__.__doc__ = original_init.__doc__
    __init__._sa_original_init = original_init

    # the generated signature has no default values of its own, so
    # restore the original's defaults
    if func_defaults:
        __init__.__defaults__ = func_defaults
    if func_kw_defaults:
        __init__.__kwdefaults__ = func_kw_defaults

    return __init__