1# event/attr.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""Attribute implementation for _Dispatch classes.
9
10The various listener targets for a particular event class are represented
11as attributes, which refer to collections of listeners to be fired off.
12These collections can exist at the class level as well as at the instance
13level. An event is fired off using code like this::
14
15 some_object.dispatch.first_connect(arg1, arg2)
16
17Above, ``some_object.dispatch`` would be an instance of ``_Dispatch`` and
18``first_connect`` is typically an instance of ``_ListenerCollection``
19if event listeners are present, or ``_EmptyListener`` if none are present.
20
21The attribute mechanics here spend effort trying to ensure listener functions
22are available with a minimum of function call overhead, that unnecessary
objects aren't created (e.g. many empty per-instance listener collections),
24as well as that everything is garbage collectable when owning references are
25lost. Other features such as "propagation" of listener functions across
26many ``_Dispatch`` instances, "joining" of multiple ``_Dispatch`` instances,
27as well as support for subclass propagation (e.g. events assigned to
28``Pool`` vs. ``QueuePool``) are all implemented here.
29
30"""
31from __future__ import annotations
32
33import collections
34from itertools import chain
35import threading
36from types import TracebackType
37import typing
38from typing import Any
39from typing import cast
40from typing import Collection
41from typing import Deque
42from typing import FrozenSet
43from typing import Generic
44from typing import Iterator
45from typing import MutableMapping
46from typing import MutableSequence
47from typing import NoReturn
48from typing import Optional
49from typing import Sequence
50from typing import Set
51from typing import Tuple
52from typing import Type
53from typing import TypeVar
54from typing import Union
55import weakref
56
57from . import legacy
58from . import registry
59from .registry import _ET
60from .registry import _EventKey
61from .registry import _ListenerFnType
62from .. import exc
63from .. import util
64from ..util.concurrency import AsyncAdaptedLock
65from ..util.typing import Protocol
66
# Generic element type for the container helpers below (e.g. the
# ``_empty_collection`` class and the ``_ListenerFnSequenceType`` alias).
_T = TypeVar("_T", bound=Any)
68
69if typing.TYPE_CHECKING:
70 from .base import _Dispatch
71 from .base import _DispatchCommon
72 from .base import _HasEventsDispatch
73
74
class RefCollection(util.MemoizedSlots, Generic[_ET]):
    """Base for listener collections that the event registry refers to
    through a weak reference.

    ``ref`` is supplied via :class:`.util.MemoizedSlots` from
    ``_memoized_attr_ref`` (presumably computed on first access and then
    cached in the slot).
    """

    __slots__ = ("ref",)

    ref: weakref.ref[RefCollection[_ET]]

    def _memoized_attr_ref(self) -> weakref.ref[RefCollection[_ET]]:
        # the registry's _collection_gced is installed as the weakref
        # callback, so it is notified once this collection is collected
        on_collected = registry._collection_gced
        return weakref.ref(self, on_collected)
82
83
84class _empty_collection(Collection[_T]):
85 def append(self, element: _T) -> None:
86 pass
87
88 def appendleft(self, element: _T) -> None:
89 pass
90
91 def extend(self, other: Sequence[_T]) -> None:
92 pass
93
94 def remove(self, element: _T) -> None:
95 pass
96
97 def __contains__(self, element: Any) -> bool:
98 return False
99
100 def __iter__(self) -> Iterator[_T]:
101 return iter([])
102
103 def clear(self) -> None:
104 pass
105
106 def __len__(self) -> int:
107 return 0
108
109
# Per-class listener sequence held in ``_ClsLevelDispatch._clslevel``: a real
# deque normally, or an ``_empty_collection`` for classes whose
# ``_sa_propagate_class_events`` attribute is False (see
# ``_ClsLevelDispatch.update_subclass``).
_ListenerFnSequenceType = Union[Deque[_T], _empty_collection[_T]]
111
112
class _ClsLevelDispatch(RefCollection[_ET]):
    """Class-level events on :class:`._Dispatch` classes.

    Listeners are stored per target class in ``_clslevel``, a
    :class:`weakref.WeakKeyDictionary` keyed on the class, so each class
    (and each subclass reached via ``util.walk_subclasses``) carries its own
    ordered listener sequence, and entries go away with their classes.
    """

    __slots__ = (
        "clsname",
        "name",
        "arg_names",
        "has_kw",
        "legacy_signatures",
        "_clslevel",
        "__weakref__",
    )

    clsname: str
    name: str
    arg_names: Sequence[str]
    has_kw: bool
    legacy_signatures: MutableSequence[legacy._LegacySignatureType]
    _clslevel: MutableMapping[
        Type[_ET], _ListenerFnSequenceType[_ListenerFnType]
    ]

    def __init__(
        self,
        parent_dispatch_cls: Type[_HasEventsDispatch[_ET]],
        fn: _ListenerFnType,
    ):
        self.name = fn.__name__
        self.clsname = parent_dispatch_cls.__name__

        spec = util.inspect_getfullargspec(fn)
        # the first positional parameter of the event declaration is not
        # an event argument, so it is skipped
        self.arg_names = spec.args[1:]
        self.has_kw = bool(spec.varkw)

        # ascending sort on the first element, then an in-place reverse;
        # identical to reversed(sorted(...)), including the order of ties
        signatures = sorted(
            getattr(fn, "_legacy_signatures", []), key=lambda sig: sig[0]
        )
        signatures.reverse()
        self.legacy_signatures = signatures

        # ``self`` is handed to the doc builder, presumably so it can read
        # the attributes populated above
        fn.__doc__ = legacy._augment_fn_docs(self, parent_dispatch_cls, fn)

        self._clslevel = weakref.WeakKeyDictionary()

    def _adjust_fn_spec(
        self, fn: _ListenerFnType, named: bool
    ) -> _ListenerFnType:
        """Adapt a user listener to this event's calling convention.

        Wraps ``fn`` for keyword-only invocation when ``named`` is set, and
        then for legacy signatures when any exist and ``fn``'s argspec can
        be determined (a ``TypeError`` while inspecting leaves it as-is).
        """
        if named:
            fn = self._wrap_fn_for_kw(fn)
        if not self.legacy_signatures:
            return fn
        try:
            argspec = util.get_callable_argspec(fn, no_self=True)
        except TypeError:
            return fn
        return legacy._wrap_fn_for_legacy(self, fn, argspec)

    def _wrap_fn_for_kw(self, fn: _ListenerFnType) -> _ListenerFnType:
        """Return a wrapper mapping positional event arguments onto
        ``arg_names`` and invoking ``fn`` with keyword arguments only;
        explicit keywords override positional ones."""

        def wrap_kw(*args: Any, **kw: Any) -> Any:
            return fn(**dict(zip(self.arg_names, args), **kw))

        return wrap_kw

    def _do_insert_or_append(
        self, event_key: _EventKey[_ET], is_append: bool
    ) -> None:
        """Add a listener to the target class and every subclass that
        already has a listener sequence.

        Subclasses without one only get a sequence initialized from their
        MRO via :meth:`update_subclass`.
        """
        target = event_key.dispatch_target
        assert isinstance(
            target, type
        ), "Class-level Event targets must be classes."
        if not getattr(target, "_sa_propagate_class_events", True):
            raise exc.InvalidRequestError(
                f"Can't assign an event directly to the {target} class"
            )

        by_cls = self._clslevel
        cls: Type[_ET]

        for cls in util.walk_subclasses(target):
            already_tracked = cls in by_cls
            if not already_tracked:
                self.update_subclass(cls)
            if not (cls is target or already_tracked):
                continue
            listeners = by_cls[cls]
            if is_append:
                listeners.append(event_key._listen_fn)
            else:
                listeners.appendleft(event_key._listen_fn)
        registry._stored_in_collection(event_key, self)

    def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None:
        # ``propagate`` is not used at the class level
        self._do_insert_or_append(event_key, is_append=False)

    def append(self, event_key: _EventKey[_ET], propagate: bool) -> None:
        # ``propagate`` is not used at the class level
        self._do_insert_or_append(event_key, is_append=True)

    def update_subclass(self, target: Type[_ET]) -> None:
        """Ensure ``target`` has a listener sequence and merge in any
        listeners from classes in its MRO that are not already present."""
        by_cls = self._clslevel
        if target not in by_cls:
            by_cls[target] = (
                collections.deque()
                if getattr(target, "_sa_propagate_class_events", True)
                else _empty_collection()
            )

        own = by_cls[target]
        ancestor: Type[_ET]
        for ancestor in target.__mro__[1:]:
            if ancestor not in by_cls:
                continue
            own.extend([fn for fn in by_cls[ancestor] if fn not in own])

    def remove(self, event_key: _EventKey[_ET]) -> None:
        """Remove a listener from the target class and all tracked
        subclasses."""
        by_cls = self._clslevel
        cls: Type[_ET]
        for cls in util.walk_subclasses(event_key.dispatch_target):
            if cls in by_cls:
                by_cls[cls].remove(event_key._listen_fn)
        registry._removed_from_collection(event_key, self)

    def clear(self) -> None:
        """Clear all class level listeners"""

        removed: Set[_ListenerFnType] = set()
        for listeners in self._clslevel.values():
            removed.update(listeners)
            listeners.clear()
        registry._clear(self, removed)

    def for_modify(self, obj: _Dispatch[_ET]) -> _ClsLevelDispatch[_ET]:
        """Return an event collection which can be modified.

        A class-level dispatcher is modified in place, so this is
        always ``self``.

        """
        return self
250
251
class _InstanceLevelDispatch(RefCollection[_ET], Collection[_ListenerFnType]):
    """Abstract interface for the instance-level event collections found
    on :class:`._Dispatch` instances.

    Apart from ``_adjust_fn_spec`` and ``for_modify``, every operation
    here raises :class:`NotImplementedError` and must be provided by a
    subclass.
    """

    __slots__ = ()

    parent: _ClsLevelDispatch[_ET]

    def _adjust_fn_spec(
        self, fn: _ListenerFnType, named: bool
    ) -> _ListenerFnType:
        # signature adaptation is owned by the class-level dispatcher
        return self.parent._adjust_fn_spec(fn, named)

    # -- collection protocol --

    def __contains__(self, item: Any) -> bool:
        raise NotImplementedError

    def __len__(self) -> int:
        raise NotImplementedError

    def __iter__(self) -> Iterator[_ListenerFnType]:
        raise NotImplementedError

    def __bool__(self) -> bool:
        raise NotImplementedError

    # -- execution --

    def exec_once(self, *args: Any, **kw: Any) -> None:
        raise NotImplementedError

    def exec_once_unless_exception(self, *args: Any, **kw: Any) -> None:
        raise NotImplementedError

    def _exec_w_sync_on_first_run(self, *args: Any, **kw: Any) -> None:
        raise NotImplementedError

    def __call__(self, *args: Any, **kw: Any) -> None:
        raise NotImplementedError

    # -- modification --

    def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None:
        raise NotImplementedError

    def append(self, event_key: _EventKey[_ET], propagate: bool) -> None:
        raise NotImplementedError

    def remove(self, event_key: _EventKey[_ET]) -> None:
        raise NotImplementedError

    def for_modify(
        self, obj: _DispatchCommon[_ET]
    ) -> _InstanceLevelDispatch[_ET]:
        """Return an event collection which can be modified.

        The default returns ``self``; subclasses may override this to
        return (and install) a different collection.

        """
        return self
305
306
class _EmptyListener(_InstanceLevelDispatch[_ET]):
    """Serves as a proxy interface to the events
    served by a _ClsLevelDispatch, when there are no
    instance-level events present.

    Is replaced by _ListenerCollection when instance-level
    events are added.

    """

    __slots__ = "parent", "parent_listeners", "name"

    propagate: FrozenSet[_ListenerFnType] = frozenset()
    listeners: Tuple[()] = ()
    parent: _ClsLevelDispatch[_ET]
    parent_listeners: _ListenerFnSequenceType[_ListenerFnType]
    name: str

    def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]):
        # make sure the class-level sequence for target_cls exists; it is
        # then shared by reference (not copied) as parent_listeners
        if target_cls not in parent._clslevel:
            parent.update_subclass(target_cls)
        self.parent = parent
        self.parent_listeners = parent._clslevel[target_cls]
        self.name = parent.name

    def for_modify(
        self, obj: _DispatchCommon[_ET]
    ) -> _ListenerCollection[_ET]:
        """Return an event collection which can be modified.

        For _EmptyListener at the instance level of
        a dispatcher, this generates a new
        _ListenerCollection, applies it to the instance,
        and returns it.

        If another collection has already been installed on ``obj``
        (e.g. by a concurrent caller), that collection is returned
        instead.

        """
        obj = cast("_Dispatch[_ET]", obj)

        assert obj._instance_cls is not None

        with util.mini_gil:
            # Read the current attribute while holding the lock, so that
            # the check below and the setattr() are one atomic step.  If it
            # were read before acquiring the lock, two threads could both
            # observe ``existing is self``; each would build a new
            # _ListenerCollection and the second setattr() would silently
            # discard the first, along with any listener already added to it.
            existing = getattr(obj, self.name)

            if existing is self or isinstance(existing, _JoinedListener):
                result = _ListenerCollection(self.parent, obj._instance_cls)
            else:
                # a collection was already installed, e.g. by another
                # thread or while the caller held a stale reference to this
                # _EmptyListener (observed in test_pool.py->test_timeout_race
                # with freethreaded)
                assert isinstance(existing, _ListenerCollection)
                return existing

            if existing is self:
                setattr(obj, self.name, result)
            return result

    def _needs_modify(self, *args: Any, **kw: Any) -> NoReturn:
        raise NotImplementedError("need to call for_modify()")

    def exec_once(self, *args: Any, **kw: Any) -> NoReturn:
        self._needs_modify(*args, **kw)

    def exec_once_unless_exception(self, *args: Any, **kw: Any) -> NoReturn:
        self._needs_modify(*args, **kw)

    def insert(self, *args: Any, **kw: Any) -> NoReturn:
        self._needs_modify(*args, **kw)

    def append(self, *args: Any, **kw: Any) -> NoReturn:
        self._needs_modify(*args, **kw)

    def remove(self, *args: Any, **kw: Any) -> NoReturn:
        self._needs_modify(*args, **kw)

    def clear(self, *args: Any, **kw: Any) -> NoReturn:
        self._needs_modify(*args, **kw)

    def __call__(self, *args: Any, **kw: Any) -> None:
        """Execute this event."""

        for fn in self.parent_listeners:
            fn(*args, **kw)

    def __contains__(self, item: Any) -> bool:
        return item in self.parent_listeners

    def __len__(self) -> int:
        return len(self.parent_listeners)

    def __iter__(self) -> Iterator[_ListenerFnType]:
        return iter(self.parent_listeners)

    def __bool__(self) -> bool:
        return bool(self.parent_listeners)
400
401
402class _MutexProtocol(Protocol):
403 def __enter__(self) -> bool: ...
404
405 def __exit__(
406 self,
407 exc_type: Optional[Type[BaseException]],
408 exc_val: Optional[BaseException],
409 exc_tb: Optional[TracebackType],
410 ) -> Optional[bool]: ...
411
412
class _CompoundListener(_InstanceLevelDispatch[_ET]):
    """Base for instance-level collections that fire ``parent_listeners``
    followed by their own ``listeners``.

    Also provides the ``exec_once`` family, which serializes execution
    through a mutex created lazily by :meth:`_get_exec_once_mutex`.
    """

    __slots__ = (
        "_exec_once_mutex",
        "_exec_once",
        "_exec_w_sync_once",
        "_is_asyncio",
    )

    _exec_once_mutex: Optional[_MutexProtocol]
    parent_listeners: Collection[_ListenerFnType]
    listeners: Collection[_ListenerFnType]
    _exec_once: bool
    _exec_w_sync_once: bool
    _is_asyncio: bool

    def __init__(self, *arg: Any, **kw: Any):
        super().__init__(*arg, **kw)
        self._is_asyncio = False

    def _set_asyncio(self) -> None:
        """Request an ``AsyncAdaptedLock`` instead of a
        :class:`threading.Lock`.

        This only has an effect if the mutex has not been created yet.
        """
        self._is_asyncio = True

    def _get_exec_once_mutex(self) -> _MutexProtocol:
        """Return the exec-once mutex, creating it on first use."""
        with util.mini_gil:
            if self._exec_once_mutex is not None:
                return self._exec_once_mutex

            if self._is_asyncio:
                mutex = AsyncAdaptedLock()
            else:
                mutex = threading.Lock()  # type: ignore[assignment]
            self._exec_once_mutex = mutex

            return mutex

    def _exec_once_impl(
        self, retry_on_exception: bool, *args: Any, **kw: Any
    ) -> None:
        """Run the event under the mutex unless it already ran.

        :param retry_on_exception: if True, a raising run does not mark the
         collection as executed, so a later call may try again.
        """
        with self._get_exec_once_mutex():
            # re-check under the mutex; another caller may have completed
            # the run while this one was waiting
            if self._exec_once:
                return
            try:
                self(*args, **kw)
            except BaseException:
                if not retry_on_exception:
                    self._exec_once = True
                raise
            else:
                self._exec_once = True

    def exec_once(self, *args: Any, **kw: Any) -> None:
        """Execute this event, but only if it has not been
        executed already for this collection."""

        if not self._exec_once:
            self._exec_once_impl(False, *args, **kw)

    def exec_once_unless_exception(self, *args: Any, **kw: Any) -> None:
        """Execute this event, but only if it has not been
        executed already for this collection, or was called
        by a previous exec_once_unless_exception call and
        raised an exception.

        If exec_once was already called, then this method will never run
        the callable regardless of whether it raised or not.

        .. versionadded:: 1.3.8

        """
        if not self._exec_once:
            self._exec_once_impl(True, *args, **kw)

    def _exec_w_sync_on_first_run(self, *args: Any, **kw: Any) -> None:
        """Execute this event, and use a mutex if it has not been
        executed already for this collection, or was called
        by a previous _exec_w_sync_on_first_run call and
        raised an exception.

        If _exec_w_sync_on_first_run was already called and didn't raise an
        exception, then a mutex is not used. It's not guaranteed
        the mutex won't be used more than once in the case of very rare
        race conditions.

        .. versionadded:: 1.4.11

        """
        if self._exec_w_sync_once:
            self(*args, **kw)
            return

        with self._get_exec_once_mutex():
            self(*args, **kw)
            # only reached when the call above did not raise
            self._exec_w_sync_once = True

    def __call__(self, *args: Any, **kw: Any) -> None:
        """Execute this event."""

        for fn in self.parent_listeners:
            fn(*args, **kw)
        for fn in self.listeners:
            fn(*args, **kw)

    def __contains__(self, item: Any) -> bool:
        return item in self.parent_listeners or item in self.listeners

    def __len__(self) -> int:
        return len(self.parent_listeners) + len(self.listeners)

    def __iter__(self) -> Iterator[_ListenerFnType]:
        return chain(self.parent_listeners, self.listeners)

    def __bool__(self) -> bool:
        return bool(self.listeners or self.parent_listeners)
528
529
class _ListenerCollection(_CompoundListener[_ET]):
    """Instance-level attributes on instances of :class:`._Dispatch`.

    Represents a collection of listeners.

    As of 0.7.9, _ListenerCollection is only first
    created via the _EmptyListener.for_modify() method.

    """

    __slots__ = (
        "parent_listeners",
        "parent",
        "name",
        "listeners",
        "propagate",
        "__weakref__",
    )

    parent_listeners: Collection[_ListenerFnType]
    parent: _ClsLevelDispatch[_ET]
    name: str
    listeners: Deque[_ListenerFnType]
    propagate: Set[_ListenerFnType]

    def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]):
        super().__init__()
        # make sure the class-level sequence for target_cls exists; it is
        # shared by reference as parent_listeners
        if target_cls not in parent._clslevel:
            parent.update_subclass(target_cls)
        self._exec_once = False
        self._exec_w_sync_once = False
        self._exec_once_mutex = None
        self.parent_listeners = parent._clslevel[target_cls]
        self.parent = parent
        self.name = parent.name
        self.listeners = collections.deque()
        self.propagate = set()

    def for_modify(
        self, obj: _DispatchCommon[_ET]
    ) -> _ListenerCollection[_ET]:
        """Return an event collection which can be modified.

        For _ListenerCollection at the instance level of
        a dispatcher, this returns self.

        """
        return self

    def _update(
        self, other: _ListenerCollection[_ET], only_propagate: bool = True
    ) -> None:
        """Populate from the listeners in another :class:`_Dispatch`
        object.

        :param other: collection whose listeners are copied.
        :param only_propagate: when True (the default), only listeners in
         ``self.propagate`` (after merging in ``other.propagate``) are
         copied; when False, all of ``other``'s listeners are copied.

        Listeners already present in this collection are never appended a
        second time.

        """
        existing_listeners = self.listeners
        existing_listener_set = set(existing_listeners)
        self.propagate.update(other.propagate)
        # The parentheses matter: ``A and B or C`` parses as
        # ``(A and B) or C``, which would skip the "already present" check
        # for propagating listeners and append them again, so they would
        # fire twice.
        other_listeners = [
            fn
            for fn in other.listeners
            if fn not in existing_listener_set
            and (not only_propagate or fn in self.propagate)
        ]

        existing_listeners.extend(other_listeners)

        if other._is_asyncio:
            self._set_asyncio()

        to_associate = other.propagate.union(other_listeners)
        registry._stored_in_collection_multi(self, other, to_associate)

    def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None:
        # prepend_to_list reports whether the listener was actually added
        if event_key.prepend_to_list(self, self.listeners):
            if propagate:
                self.propagate.add(event_key._listen_fn)

    def append(self, event_key: _EventKey[_ET], propagate: bool) -> None:
        # append_to_list reports whether the listener was actually added
        if event_key.append_to_list(self, self.listeners):
            if propagate:
                self.propagate.add(event_key._listen_fn)

    def remove(self, event_key: _EventKey[_ET]) -> None:
        self.listeners.remove(event_key._listen_fn)
        self.propagate.discard(event_key._listen_fn)
        registry._removed_from_collection(event_key, self)

    def clear(self) -> None:
        # unregister before emptying, while the listeners are still present
        registry._clear(self, self.listeners)
        self.propagate.clear()
        self.listeners.clear()
622
623
class _JoinedListener(_CompoundListener[_ET]):
    """Instance-level collection that fires the listeners of ``local``
    followed by those currently found under ``name`` on
    ``parent_dispatch``."""

    __slots__ = "parent_dispatch", "name", "local", "parent_listeners"

    parent_dispatch: _DispatchCommon[_ET]
    name: str
    local: _InstanceLevelDispatch[_ET]
    parent_listeners: Collection[_ListenerFnType]

    def __init__(
        self,
        parent_dispatch: _DispatchCommon[_ET],
        name: str,
        local: _EmptyListener[_ET],
    ):
        self._exec_once = False
        self._exec_w_sync_once = False
        self._exec_once_mutex = None
        # _CompoundListener.__init__ is not invoked here, so the flag it
        # would set must be initialized explicitly.  _get_exec_once_mutex()
        # reads it, and an unset slot would make exec_once(),
        # exec_once_unless_exception() and _exec_w_sync_on_first_run()
        # raise AttributeError.
        self._is_asyncio = False
        self.parent_dispatch = parent_dispatch
        self.name = name
        self.local = local
        self.parent_listeners = self.local

    if not typing.TYPE_CHECKING:
        # first error, I don't really understand:
        # Signature of "listeners" incompatible with
        # supertype "_CompoundListener" [override]
        # the name / return type are exactly the same
        # second error is getattr_isn't typed, the cast() here
        # adds too much method overhead
        @property
        def listeners(self) -> Collection[_ListenerFnType]:
            return getattr(self.parent_dispatch, self.name)

    def _adjust_fn_spec(
        self, fn: _ListenerFnType, named: bool
    ) -> _ListenerFnType:
        return self.local._adjust_fn_spec(fn, named)

    def for_modify(self, obj: _DispatchCommon[_ET]) -> _JoinedListener[_ET]:
        # swap in the modifiable form of ``local``; it also serves as
        # parent_listeners
        self.local = self.parent_listeners = self.local.for_modify(obj)
        return self

    def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None:
        self.local.insert(event_key, propagate)

    def append(self, event_key: _EventKey[_ET], propagate: bool) -> None:
        self.local.append(event_key, propagate)

    def remove(self, event_key: _EventKey[_ET]) -> None:
        self.local.remove(event_key)

    def clear(self) -> None:
        raise NotImplementedError()