1# orm/instrumentation.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Defines SQLAlchemy's system of class instrumentation.
10
11This module is usually not directly visible to user applications, but
12defines a large part of the ORM's interactivity.
13
14instrumentation.py deals with registration of end-user classes
15for state tracking. It interacts closely with state.py
16and attributes.py which establish per-instance and per-class-attribute
17instrumentation, respectively.
18
19The class instrumentation system can be customized on a per-class
20or global basis using the :mod:`sqlalchemy.ext.instrumentation`
21module, which provides the means to build and specify
22alternate instrumentation forms.
23
24"""
25
26from __future__ import annotations
27
28from typing import Any
29from typing import Callable
30from typing import cast
31from typing import Collection
32from typing import Dict
33from typing import Generic
34from typing import Iterable
35from typing import List
36from typing import Literal
37from typing import Optional
38from typing import Protocol
39from typing import Set
40from typing import Tuple
41from typing import Type
42from typing import TYPE_CHECKING
43from typing import TypeVar
44from typing import Union
45import weakref
46
47from . import base
48from . import collections
49from . import exc
50from . import interfaces
51from . import state
52from ._typing import _O
53from .attributes import _is_collection_attribute_impl
54from .. import util
55from ..event import EventTarget
56from ..util import HasMemoized
57
58if TYPE_CHECKING:
59 from ._typing import _RegistryType
60 from .attributes import _AttributeImpl
61 from .attributes import QueryableAttribute
62 from .collections import _AdaptedCollectionProtocol
63 from .collections import _CollectionFactoryType
64 from .decl_base import _MapperConfig
65 from .events import InstanceEvents
66 from .mapper import Mapper
67 from .state import InstanceState
68 from ..event import dispatcher
69
70_T = TypeVar("_T", bound=Any)
71DEL_ATTR = util.symbol("DEL_ATTR")
72
73
class _ExpiredAttributeLoaderProto(Protocol):
    """Structural type of :attr:`.ClassManager.expired_attribute_loader`.

    (That attribute was previously known as ``deferred_scalar_loader``.)
    A conforming callable is invoked with an instance state, a set of
    string keys ``toload`` and a :class:`.PassiveFlag`, and returns None.
    """

    def __call__(
        self,
        state: state.InstanceState[Any],
        toload: Set[str],
        passive: base.PassiveFlag,
    ) -> None:
        ...
81
82
class _ManagerFactory(Protocol):
    """Structural type of a callable that builds a ClassManager.

    Given a class, a conforming callable returns the :class:`.ClassManager`
    for it.  Values of this type are stored on ``ClassManager.factory``.
    """

    def __call__(self, class_: Type[_O]) -> ClassManager[_O]:
        ...
85
86
class ClassManager(
    HasMemoized,
    Dict[str, "QueryableAttribute[Any]"],
    Generic[_O],
    EventTarget,
):
    """Tracks state information at the class level.

    A ClassManager is itself a dictionary of attribute name to
    :class:`.QueryableAttribute` for every attribute instrumented on its
    class, including attributes propagated from the managers of base
    classes.  The manager stores itself on the class under
    :attr:`.MANAGER_ATTR`, installs the generated ``__init__`` and the
    instrumented descriptors on the class, records the class members it
    replaced in ``originals``, and restores them in :meth:`.unregister`.

    """

    dispatch: dispatcher[ClassManager[_O]]

    # Name of the class attribute under which this manager is installed
    # (see manage()), and of the per-instance attribute that holds the
    # InstanceState (see has_state() / teardown_instance()).
    MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR
    STATE_ATTR = base.DEFAULT_STATE_ATTR

    # Called as ``_state_setter(instance, state)`` to attach an
    # InstanceState to an instance under STATE_ATTR.
    _state_setter = staticmethod(util.attrsetter(STATE_ATTR))

    expired_attribute_loader: _ExpiredAttributeLoaderProto
    "previously known as deferred_scalar_loader"

    # Constructor used as ``original_init`` when the class's own __init__
    # is object.__init__ (see _update_state()).
    init_method: Optional[Callable[..., None]]
    # The __init__ wrapped by the generated __init__ (see _instrument_init()).
    original_init: Optional[Callable[..., None]] = None

    # The factory that produced this manager; assigned by
    # InstrumentationFactory.create_manager_for_cls().
    factory: Optional[_ManagerFactory]

    # Weak reference to the declarative configuration object, if one was
    # passed to _update_state().
    declarative_scan: Optional[weakref.ref[_MapperConfig]] = None

    registry: _RegistryType

    if not TYPE_CHECKING:
        # starts as None during setup
        registry = None

    class_: Type[_O]

    # Managers of the direct base classes that are themselves instrumented;
    # computed once in __init__.
    _bases: List[ClassManager[Any]]

    # Deprecated read/write alias for ``expired_attribute_loader``.
    @property
    @util.deprecated(
        "1.4",
        message="The ClassManager.deferred_scalar_loader attribute is now "
        "named expired_attribute_loader",
    )
    def deferred_scalar_loader(self):
        return self.expired_attribute_loader

    @deferred_scalar_loader.setter
    @util.deprecated(
        "1.4",
        message="The ClassManager.deferred_scalar_loader attribute is now "
        "named expired_attribute_loader",
    )
    def deferred_scalar_loader(self, obj):
        self.expired_attribute_loader = obj

    def __init__(self, class_):
        """Create the manager for ``class_`` and install it on the class.

        This copies the entries of the managers of ``class_``'s direct base
        classes into this manager and notifies :class:`.InstanceEvents` of
        the new manager.  It then merges the ``dispatch`` of every manager
        found along ``class_.__mro__``, installs itself on the class via
        :meth:`.manage`, and warns if the class defines ``__del__``.

        """
        self.class_ = class_
        self.info = {}
        self.new_init = None
        # attributes instrumented directly on this class, as opposed to
        # those propagated from base-class managers
        self.local_attrs = {}
        # class members replaced by install_member(), keyed by name, so
        # that uninstall_member() can restore them
        self.originals = {}
        self._finalized = False
        self.factory = None
        self.init_method = None

        # managers of the direct bases; bases that are not ``type``
        # instances, or that have no manager, are skipped
        self._bases = [
            mgr
            for mgr in cast(
                "List[Optional[ClassManager[Any]]]",
                [
                    opt_manager_of_class(base)
                    for base in self.class_.__bases__
                    if isinstance(base, type)
                ],
            )
            if mgr is not None
        ]

        # dict.update(): start with the base managers' attribute entries
        for base_ in self._bases:
            self.update(base_)

        cast(
            "InstanceEvents", self.dispatch._events
        )._new_classmanager_instance(class_, self)

        for basecls in class_.__mro__:
            mgr = opt_manager_of_class(basecls)
            if mgr is not None:
                self.dispatch._update(mgr.dispatch)

        self.manage()

        if "__del__" in class_.__dict__:
            util.warn(
                "__del__() method on class %s will "
                "cause unreachable cycles and memory leaks, "
                "as SQLAlchemy instrumentation often creates "
                "reference cycles. Please remove this method." % class_
            )

    def _update_state(
        self,
        finalize: bool = False,
        mapper: Optional[Mapper[_O]] = None,
        registry: Optional[_RegistryType] = None,
        declarative_scan: Optional[_MapperConfig] = None,
        expired_attribute_loader: Optional[
            _ExpiredAttributeLoaderProto
        ] = None,
        init_method: Optional[Callable[..., None]] = None,
    ) -> None:
        """Apply configuration to this manager.

        ``register_class()`` calls this for both newly created and existing
        managers.  Each optional argument is applied only when truthy.
        ``init_method`` may only be supplied before the manager is
        finalized.  While not finalized, ``original_init`` is recomputed:
        it is ``init_method`` when one is set and the class's ``__init__``
        is ``object.__init__``; otherwise it is the class's ``__init__``.
        If ``finalize`` is true, :meth:`_finalize` runs once.

        """
        if mapper:
            self.mapper = mapper  #
        if registry:
            registry._add_manager(self)
        if declarative_scan:
            # held weakly so the manager does not keep it alive
            self.declarative_scan = weakref.ref(declarative_scan)
        if expired_attribute_loader:
            self.expired_attribute_loader = expired_attribute_loader

        if init_method:
            assert not self._finalized, (
                "class is already instrumented, "
                "init_method %s can't be applied" % init_method
            )
            self.init_method = init_method

        if not self._finalized:
            self.original_init = (
                self.init_method
                if self.init_method is not None
                and self.class_.__init__ is object.__init__
                else self.class_.__init__
            )

        if finalize and not self._finalized:
            self._finalize()

    def _finalize(self) -> None:
        """Install the generated __init__ and fire ``class_instrument``.

        This is a no-op if the manager is already finalized.
        """
        if self._finalized:
            return
        self._finalized = True

        self._instrument_init()

        _instrumentation_factory.dispatch.class_instrument(self.class_)

    # Identity-based hashing and equality, replacing dict's value-based
    # equality (and dict's lack of a hash).
    def __hash__(self) -> int:  # type: ignore[override]
        return id(self)

    def __eq__(self, other: Any) -> bool:
        return other is self

    @property
    def is_mapped(self) -> bool:
        """True once ``mapper`` has been assigned on this instance.

        The memoized ``mapper`` property below raises until then, so the
        instance ``__dict__`` is checked directly.
        """
        return "mapper" in self.__dict__

    # The memoized attributes below are derived from the dict contents;
    # instrument_attribute() / uninstrument_attribute() call
    # _reset_memoizations() before changing those contents.

    @HasMemoized.memoized_attribute
    def _all_key_set(self):
        return frozenset(self)

    @HasMemoized.memoized_attribute
    def _collection_impl_keys(self):
        return frozenset(
            [attr.key for attr in self.values() if attr.impl.collection]
        )

    @HasMemoized.memoized_attribute
    def _scalar_loader_impls(self):
        return frozenset(
            [
                attr.impl
                for attr in self.values()
                if attr.impl.accepts_scalar_loader
            ]
        )

    @HasMemoized.memoized_attribute
    def _loader_impls(self):
        return frozenset([attr.impl for attr in self.values()])

    @util.memoized_property
    def mapper(self) -> Mapper[_O]:
        # raises unless self.mapper has been assigned
        raise exc.UnmappedClassError(self.class_)

    def _all_sqla_attributes(self, exclude=None):
        """return an iterator of all classbound attributes that
        implement :class:`.InspectionAttr`.

        This includes :class:`.QueryableAttribute` as well as extension
        types such as :class:`.hybrid_property` and
        :class:`.AssociationProxy`.

        Yields ``(key, value)`` tuples, restricted to values whose
        ``is_attribute`` flag is true.  The ``exclude`` parameter is
        accepted but not used by this implementation.

        """

        found: Dict[str, Any] = {}

        # constraints:
        # 1. yield keys in cls.__dict__ order
        # 2. if a subclass has the same key as a superclass, include that
        #    key as part of the ordering of the superclass, because an
        #    overridden key is usually installed by the mapper which is going
        #    on a different ordering
        # 3. don't use getattr() as this fires off descriptors

        # the MRO minus its last entry (``object``)
        for supercls in self.class_.__mro__[0:-1]:
            inherits = supercls.__mro__[1]
            for key in supercls.__dict__:
                # remember the first (most-derived) class defining key
                found.setdefault(key, supercls)
                # key is also defined on the next class in supercls's MRO:
                # defer, so it is yielded in that superclass's position
                if key in inherits.__dict__:
                    continue
                # yield the most-derived definition, read from __dict__
                val = found[key].__dict__[key]
                if (
                    isinstance(val, interfaces.InspectionAttr)
                    and val.is_attribute
                ):
                    yield key, val

    def _get_class_attr_mro(self, key, default=None):
        """return an attribute on the class without tripping it.

        Reads each class ``__dict__`` along the MRO directly, so
        descriptors are not invoked; returns ``default`` if no class in
        the MRO defines ``key``.
        """

        for supercls in self.class_.__mro__:
            if key in supercls.__dict__:
                return supercls.__dict__[key]
        else:
            return default

    def _attr_has_impl(self, key: str) -> bool:
        """Return True if the given attribute is fully initialized.

        i.e. has an impl.
        """

        return key in self and self[key].impl is not None

    def _subclass_manager(self, cls: Type[_T]) -> ClassManager[_T]:
        """Create a new ClassManager for a subclass of this ClassManager's
        class.

        This is called automatically when attributes are instrumented so that
        the attributes can be propagated to subclasses against their own
        class-local manager, without the need for mappers etc. to have already
        pre-configured managers for the full class hierarchy. Mappers
        can post-configure the auto-generated ClassManager when needed.

        If ``cls`` already has a manager, ``register_class()`` returns that
        existing manager rather than creating a new one.

        """
        return register_class(cls, finalize=False)

    def _instrument_init(self):
        """Generate the wrapping __init__ and install it on the class.

        ``install_member`` records the class's previous ``__init__`` in
        ``originals``.
        """
        self.new_init = _generate_init(self.class_, self, self.original_init)
        self.install_member("__init__", self.new_init)

    @util.memoized_property
    def _state_constructor(self) -> Type[state.InstanceState[_O]]:
        # the class used to build per-instance state objects
        return state.InstanceState

    def manage(self):
        """Mark this instance as the manager for its class."""

        setattr(self.class_, self.MANAGER_ATTR, self)

    @util.hybridmethod
    def manager_getter(self):
        """Return the callable used to look up a class's manager."""
        return _default_manager_getter

    @util.hybridmethod
    def state_getter(self):
        """Return a (instance) -> InstanceState callable.

        "state getter" callables should raise either KeyError or
        AttributeError if no InstanceState could be found for the
        instance.
        """

        return _default_state_getter

    @util.hybridmethod
    def dict_getter(self):
        """Return the callable used to look up an instance's dict."""
        return _default_dict_getter

    def instrument_attribute(
        self,
        key: str,
        inst: QueryableAttribute[Any],
        propagated: bool = False,
    ) -> None:
        """Register ``inst`` under ``key`` here and on all subclasses.

        A non-propagated (local) attribute is recorded in ``local_attrs``
        and installed on the class as a descriptor.  A propagated attribute,
        inherited from a base manager, is not installed as a descriptor and
        is ignored if a local attribute of the same key exists.  The call
        then recurses into every direct subclass, creating subclass
        managers as needed.
        """
        if propagated:
            if key in self.local_attrs:
                return  # don't override local attr with inherited attr
        else:
            self.local_attrs[key] = inst
            self.install_descriptor(key, inst)
        self._reset_memoizations()
        self[key] = inst

        for cls in self.class_.__subclasses__():
            manager = self._subclass_manager(cls)
            manager.instrument_attribute(key, inst, True)

    def subclass_managers(self, recursive):
        """Yield the managers of direct subclasses of this class.

        When ``recursive`` is true, each such manager's own subclass
        managers are yielded as well.  Subclasses with no manager are
        skipped, and their subclasses are not visited.
        """
        for cls in self.class_.__subclasses__():
            mgr = opt_manager_of_class(cls)
            if mgr is not None and mgr is not self:
                yield mgr
                if recursive:
                    yield from mgr.subclass_managers(True)

    def post_configure_attribute(self, key):
        """Fire ``attribute_instrument`` for the attribute under ``key``."""
        _instrumentation_factory.dispatch.attribute_instrument(
            self.class_, key, self[key]
        )

    def uninstrument_attribute(self, key, propagated=False):
        """Remove ``key`` here and from subclass managers.

        This mirrors :meth:`instrument_attribute`.  A local attribute also
        has its class descriptor removed.  A propagated removal leaves a
        local attribute of the same key in place.  Unknown keys are
        ignored.
        """
        if key not in self:
            return
        if propagated:
            if key in self.local_attrs:
                return  # don't get rid of local attr
        else:
            del self.local_attrs[key]
            self.uninstall_descriptor(key)
        self._reset_memoizations()
        del self[key]
        for cls in self.class_.__subclasses__():
            manager = opt_manager_of_class(cls)
            if manager:
                manager.uninstrument_attribute(key, True)

    def unregister(self) -> None:
        """remove all instrumentation established by this ClassManager."""

        # restore (or delete) every member replaced via install_member()
        for key in list(self.originals):
            self.uninstall_member(key)

        # NOTE(review): assigning None places "mapper" in __dict__, so
        # is_mapped keeps reporting True after unregister().
        self.mapper = None
        self.dispatch = None  # type: ignore
        self.new_init = None
        self.info.clear()

        # only locally-instrumented attributes own a class descriptor
        for key in list(self):
            if key in self.local_attrs:
                self.uninstrument_attribute(key)

        if self.MANAGER_ATTR in self.class_.__dict__:
            delattr(self.class_, self.MANAGER_ATTR)

    def install_descriptor(
        self, key: str, inst: QueryableAttribute[Any]
    ) -> None:
        """Set ``inst`` on the class as attribute ``key``.

        :raises KeyError: if ``key`` is STATE_ATTR or MANAGER_ATTR.
        """
        if key in (self.STATE_ATTR, self.MANAGER_ATTR):
            raise KeyError(
                "%r: requested attribute name conflicts with "
                "instrumentation attribute of the same name." % key
            )
        setattr(self.class_, key, inst)

    def uninstall_descriptor(self, key: str) -> None:
        """Delete attribute ``key`` from the class."""
        delattr(self.class_, key)

    def install_member(self, key: str, implementation: Any) -> None:
        """Replace class member ``key``, remembering the original.

        Only the first original is recorded; DEL_ATTR marks that the class's
        own ``__dict__`` had no such member.

        :raises KeyError: if ``key`` is STATE_ATTR or MANAGER_ATTR.
        """
        if key in (self.STATE_ATTR, self.MANAGER_ATTR):
            raise KeyError(
                "%r: requested attribute name conflicts with "
                "instrumentation attribute of the same name." % key
            )
        self.originals.setdefault(key, self.class_.__dict__.get(key, DEL_ATTR))
        setattr(self.class_, key, implementation)

    def uninstall_member(self, key: str) -> None:
        """Undo :meth:`install_member` for ``key``.

        The recorded original is put back on the class; if there was none
        (DEL_ATTR), the member is deleted.
        """
        # NOTE(review): for a key never passed to install_member(), pop()
        # returns None (not DEL_ATTR), so the class attribute is set to None.
        original = self.originals.pop(key, None)
        if original is not DEL_ATTR:
            setattr(self.class_, key, original)
        else:
            delattr(self.class_, key)

    def instrument_collection_class(
        self, key: str, collection_class: Type[Collection[Any]]
    ) -> _CollectionFactoryType:
        """Return a collection factory prepared for ``collection_class``."""
        return collections._prepare_instrumentation(collection_class)

    def initialize_collection(
        self,
        key: str,
        state: InstanceState[_O],
        factory: _CollectionFactoryType,
    ) -> Tuple[collections.CollectionAdapter, _AdaptedCollectionProtocol]:
        """Create a new collection via ``factory`` and wrap it.

        Returns ``(adapter, user_data)``, where ``adapter`` is a
        CollectionAdapter bound to the collection attribute impl for ``key``
        and to ``state``.
        """
        user_data = factory()
        impl = self.get_impl(key)
        assert _is_collection_attribute_impl(impl)
        adapter = collections.CollectionAdapter(impl, state, user_data)
        return adapter, user_data

    def is_instrumented(self, key: str, search: bool = False) -> bool:
        """Return True if ``key`` is instrumented.

        With ``search`` true, inherited (propagated) attributes count too;
        otherwise only attributes local to this class are considered.
        """
        if search:
            return key in self
        else:
            return key in self.local_attrs

    def get_impl(self, key: str) -> _AttributeImpl:
        """Return the attribute impl for ``key`` (KeyError if absent)."""
        return self[key].impl

    @property
    def attributes(self) -> Iterable[Any]:
        """Iterator over the instrumented attribute objects."""
        return iter(self.values())

    # InstanceState management

    def new_instance(self, state: Optional[InstanceState[_O]] = None) -> _O:
        """Create an instance via ``__new__`` (bypassing ``__init__``).

        ``state`` is attached to the new instance; a new state is
        constructed first if it is None.
        """
        # here, we would prefer _O to be bound to "object"
        # so that mypy sees that __new__ is present. currently
        # it's bound to Any as there were other problems not having
        # it that way but these can be revisited
        instance = self.class_.__new__(self.class_)
        if state is None:
            state = self._state_constructor(instance, self)
        self._state_setter(instance, state)
        return instance

    def setup_instance(
        self, instance: _O, state: Optional[InstanceState[_O]] = None
    ) -> None:
        """Attach ``state`` (or a newly constructed one) to ``instance``."""
        if state is None:
            state = self._state_constructor(instance, self)
        self._state_setter(instance, state)

    def teardown_instance(self, instance: _O) -> None:
        """Remove the state attribute from ``instance``."""
        delattr(instance, self.STATE_ATTR)

    def _serialize(
        self, state: InstanceState[_O], state_dict: Dict[str, Any]
    ) -> _SerializeManager:
        """Return the pickle stand-in for this manager."""
        return _SerializeManager(state, state_dict)

    def _new_state_if_none(
        self, instance: _O
    ) -> Union[Literal[False], InstanceState[_O]]:
        """Install a default InstanceState if none is present.

        A private convenience method used by the __init__ decorator.

        Returns False if ``instance`` already has state, otherwise the
        newly installed state.

        """
        if hasattr(instance, self.STATE_ATTR):
            return False
        elif self.class_ is not instance.__class__ and self.is_mapped:
            # this will create a new ClassManager for the
            # subclass, without a mapper. This is likely a
            # user error situation but allow the object
            # to be constructed, so that it is usable
            # in a non-ORM context at least.
            return self._subclass_manager(
                instance.__class__
            )._new_state_if_none(instance)
        else:
            state = self._state_constructor(instance, self)
            self._state_setter(instance, state)
            return state

    def has_state(self, instance: _O) -> bool:
        """Return True if ``instance`` has the state attribute."""
        return hasattr(instance, self.STATE_ATTR)

    def has_parent(
        self, state: InstanceState[_O], key: str, optimistic: bool = False
    ) -> bool:
        """Delegate to the ``hasparent()`` of the attribute impl for ``key``.

        ``state`` and ``optimistic`` are passed through unchanged.
        """
        return self.get_impl(key).hasparent(state, optimistic=optimistic)

    def __bool__(self) -> bool:
        """All ClassManagers are non-zero regardless of attribute state."""
        return True

    def __repr__(self) -> str:
        return "<%s of %r at %x>" % (
            self.__class__.__name__,
            self.class_,
            id(self),
        )
562
563
class _SerializeManager:
    """Pickle-time stand-in for a :class:`.ClassManager`.

    The :class:`.InstanceState` uses ``__init__()`` on serialize
    and ``__call__()`` on deserialize.  Only the class is kept; the live
    manager is looked up again when the object is unpickled.

    """

    def __init__(self, state: state.InstanceState[Any], d: Dict[str, Any]):
        self.class_ = state.class_
        # fire the manager's ``pickle`` event with the state and its dict
        state.manager.dispatch.pickle(state, d)

    def __call__(self, state, inst, state_dict):
        current = opt_manager_of_class(self.class_)
        # the state receives the looked-up manager even when it is None
        state.manager = current
        if current is None:
            raise exc.UnmappedInstanceError(
                inst,
                "Cannot deserialize object of type %r - "
                "no mapper() has "
                "been configured for this class within the current "
                "Python process!" % self.class_,
            )
        if current.is_mapped and not current.mapper.configured:
            current.mapper._check_configure()

        # setup _sa_instance_state ahead of time so that
        # unpickle events can access the object normally.
        # see [ticket:2362]
        if inst is not None:
            current.setup_instance(inst, state)
        current.dispatch.unpickle(state, state_dict)
596
597
class InstrumentationFactory(EventTarget):
    """Factory for new ClassManager instances.

    Subclasses may override :meth:`_locate_extended_factory` and
    :meth:`_check_conflicts` to provide alternative managers.
    """

    dispatch: dispatcher[InstrumentationFactory]

    def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]:
        """Build the manager for a class that does not have one yet.

        The extended-factory hook is consulted first; if it supplies no
        factory, a plain :class:`.ClassManager` is used.  The chosen
        factory is checked for conflicts and stored on the manager.
        """
        assert class_ is not None
        assert opt_manager_of_class(class_) is None

        # give a more complicated subclass
        # a chance to do what it wants here
        manager, factory = self._locate_extended_factory(class_)

        if factory is not None:
            assert manager is not None
        else:
            factory = ClassManager
            manager = ClassManager(class_)

        self._check_conflicts(class_, factory)
        manager.factory = factory
        return manager

    def _locate_extended_factory(
        self, class_: Type[_O]
    ) -> Tuple[Optional[ClassManager[_O]], Optional[_ManagerFactory]]:
        """Overridden by a subclass to do an extended lookup."""
        return None, None

    def _check_conflicts(
        self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]]
    ) -> None:
        """Overridden by a subclass to test for conflicting factories."""

    def unregister(self, class_: Type[_O]) -> None:
        """Unregister ``class_``'s manager, then fire ``class_uninstrument``."""
        manager_of_class(class_).unregister()
        self.dispatch.class_uninstrument(class_)
638
639
# The module-wide instrumentation factory.  sqlalchemy.ext.instrumentation
# replaces this attribute when it is imported.
_instrumentation_factory = InstrumentationFactory()

# Instance/class lookup functions, taken from ``orm.base``.  Each public
# name starts out bound to the same object as its ``_default_*`` alias.
# These attributes are replaced by sqlalchemy.ext.instrumentation when a
# non-standard InstrumentationManager class is first used to instrument
# a class.
_default_state_getter = base.instance_state
instance_state = _default_state_getter

_default_dict_getter = base.instance_dict
instance_dict = _default_dict_getter

_default_manager_getter = base.manager_of_class
manager_of_class = _default_manager_getter

_default_opt_manager_getter = base.opt_manager_of_class
opt_manager_of_class = _default_opt_manager_getter
653
654
def register_class(
    class_: Type[_O],
    finalize: bool = True,
    mapper: Optional[Mapper[_O]] = None,
    registry: Optional[_RegistryType] = None,
    declarative_scan: Optional[_MapperConfig] = None,
    expired_attribute_loader: Optional[_ExpiredAttributeLoaderProto] = None,
    init_method: Optional[Callable[..., None]] = None,
) -> ClassManager[_O]:
    """Register class instrumentation.

    Reuses the manager already installed on ``class_`` or, when there is
    none, has the module-level instrumentation factory create one.  In
    both cases the remaining arguments are forwarded to
    ``ClassManager._update_state()``.

    :return: the existing or newly created class manager.

    """
    existing = opt_manager_of_class(class_)
    if existing is not None:
        manager = existing
    else:
        manager = _instrumentation_factory.create_manager_for_cls(class_)

    manager._update_state(
        finalize=finalize,
        mapper=mapper,
        registry=registry,
        declarative_scan=declarative_scan,
        expired_attribute_loader=expired_attribute_loader,
        init_method=init_method,
    )
    return manager
683
684
def unregister_class(class_):
    """Unregister class instrumentation.

    Delegates to the ``unregister()`` method of the module-level
    instrumentation factory that is current at call time.
    """
    factory = _instrumentation_factory
    factory.unregister(class_)
689
690
def is_instrumented(instance, key):
    """Return True if the given attribute on the given instance is
    instrumented by the attributes package.

    This function may be used regardless of instrumentation
    applied directly to the class, i.e. no descriptors are required.
    The lookup includes attributes inherited from base-class managers.

    """
    class_manager = manager_of_class(instance.__class__)
    return class_manager.is_instrumented(key, search=True)
702
703
def _generate_init(class_, class_manager, original_init):
    """Build an __init__ decorator that triggers ClassManager events.

    The replacement ``__init__`` is generated with ``exec`` and has a
    signature rendered from ``original_init`` by
    ``util.format_argspec_init``.  When called, it first asks
    ``class_manager._new_state_if_none()`` for a new InstanceState.  If a
    state was created, the call is routed to
    ``new_state._initialize_instance(...)``.  Otherwise the instance
    already had state, and ``original_init(...)`` is called directly.
    ``original_init``'s docstring, ``__defaults__`` and ``__kwdefaults__``
    are copied onto the result, and ``original_init`` is kept on it as
    ``_sa_original_init``.

    :param class_: the class being instrumented; its ``__init__`` is used
      when ``original_init`` is None.
    :param class_manager: the ClassManager for ``class_``; referenced by
      the generated code.
    :param original_init: the constructor to wrap, or None.
    :return: the generated ``__init__`` function.
    """

    # TODO: we should use the ClassManager's notion of the
    # original '__init__' method, once ClassManager is fixed
    # to always reference that.

    if original_init is None:
        original_init = class_.__init__

    # Go through some effort here and don't change the user's __init__
    # calling signature, including the unlikely case that it has
    # a return value.
    # FIXME: need to juggle local names to avoid constructor argument
    # clashes.  (A parameter of the user's __init__ named e.g.
    # ``class_manager`` or ``original_init`` would shadow the global of
    # the same name that the generated body relies on.)
    func_body = """\
def __init__(%(apply_pos)s):
    new_state = class_manager._new_state_if_none(%(self_arg)s)
    if new_state:
        return new_state._initialize_instance(%(apply_kw)s)
    else:
        return original_init(%(apply_kw)s)
"""
    func_vars = util.format_argspec_init(original_init, grouped=False)
    func_text = func_body % func_vars

    func_defaults = getattr(original_init, "__defaults__", None)
    func_kw_defaults = getattr(original_init, "__kwdefaults__", None)

    # The generated function's globals are a snapshot of this function's
    # locals at this point; that is how ``class_manager`` and
    # ``original_init`` become visible to the generated code.
    env = locals().copy()
    env["__name__"] = __name__
    exec(func_text, env)
    __init__ = env["__init__"]
    __init__.__doc__ = original_init.__doc__
    __init__._sa_original_init = original_init

    # the rendered signature does not carry default values; copy them over
    if func_defaults:
        __init__.__defaults__ = func_defaults
    if func_kw_defaults:
        __init__.__kwdefaults__ = func_kw_defaults

    return __init__