1# util/langhelpers.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Routines to help with the creation, loading and introspection of
10modules, classes, hierarchies, attributes, functions, and methods.
11
12"""
13from __future__ import annotations
14
15import collections
16import enum
17from functools import update_wrapper
18import importlib.util
19import inspect
20import itertools
21import operator
22import re
23import sys
24import textwrap
25import threading
26import types
27from types import CodeType
28from types import ModuleType
29from typing import Any
30from typing import Callable
31from typing import cast
32from typing import Dict
33from typing import FrozenSet
34from typing import Generic
35from typing import Iterator
36from typing import List
37from typing import Literal
38from typing import Mapping
39from typing import NoReturn
40from typing import Optional
41from typing import overload
42from typing import Sequence
43from typing import Set
44from typing import Tuple
45from typing import Type
46from typing import TYPE_CHECKING
47from typing import TypeVar
48from typing import Union
49import warnings
50
51from . import _collections
52from . import compat
53from .. import exc
54
55_T = TypeVar("_T")
56_T_co = TypeVar("_T_co", covariant=True)
57_F = TypeVar("_F", bound=Callable[..., Any])
58_MA = TypeVar("_MA", bound="HasMemoized.memoized_attribute[Any]")
59_M = TypeVar("_M", bound=ModuleType)
60
61if compat.py314:
62 # vendor a minimal form of get_annotations per
63 # https://github.com/python/cpython/issues/133684#issuecomment-2863841891
64
65 from annotationlib import call_annotate_function # type: ignore
66 from annotationlib import Format
67
68 def _get_and_call_annotate(obj, format): # noqa: A002
69 annotate = getattr(obj, "__annotate__", None)
70 if annotate is not None:
71 ann = call_annotate_function(annotate, format, owner=obj)
72 if not isinstance(ann, dict):
73 raise ValueError(f"{obj!r}.__annotate__ returned a non-dict")
74 return ann
75 return None
76
77 # this is ported from py3.13.0a7
78 _BASE_GET_ANNOTATIONS = type.__dict__["__annotations__"].__get__ # type: ignore # noqa: E501
79
80 def _get_dunder_annotations(obj):
81 if isinstance(obj, type):
82 try:
83 ann = _BASE_GET_ANNOTATIONS(obj)
84 except AttributeError:
85 # For static types, the descriptor raises AttributeError.
86 return {}
87 else:
88 ann = getattr(obj, "__annotations__", None)
89 if ann is None:
90 return {}
91
92 if not isinstance(ann, dict):
93 raise ValueError(
94 f"{obj!r}.__annotations__ is neither a dict nor None"
95 )
96 return dict(ann)
97
98 def _vendored_get_annotations(
99 obj: Any, *, format: Format # noqa: A002
100 ) -> Mapping[str, Any]:
101 """A sparse implementation of annotationlib.get_annotations()"""
102
103 try:
104 ann = _get_dunder_annotations(obj)
105 except Exception:
106 pass
107 else:
108 if ann is not None:
109 return dict(ann)
110
111 # But if __annotations__ threw a NameError, we try calling __annotate__
112 ann = _get_and_call_annotate(obj, format)
113 if ann is None:
114 # If that didn't work either, we have a very weird object:
115 # evaluating
116 # __annotations__ threw NameError and there is no __annotate__.
117 # In that case,
118 # we fall back to trying __annotations__ again.
119 ann = _get_dunder_annotations(obj)
120
121 if ann is None:
122 if isinstance(obj, type) or callable(obj):
123 return {}
124 raise TypeError(f"{obj!r} does not have annotations")
125
126 if not ann:
127 return {}
128
129 return dict(ann)
130
131 def get_annotations(obj: Any) -> Mapping[str, Any]:
132 # FORWARDREF has the effect of giving us ForwardRefs and not
133 # actually trying to evaluate the annotations. We need this so
134 # that the annotations act as much like
135 # "from __future__ import annotations" as possible, which is going
136 # away in future python as a separate mode
137 return _vendored_get_annotations(obj, format=Format.FORWARDREF)
138
139else:
140
141 def get_annotations(obj: Any) -> Mapping[str, Any]:
142 return inspect.get_annotations(obj)
143
144
145def md5_hex(x: Any) -> str:
146 x = x.encode("utf-8")
147 m = compat.md5_not_for_security()
148 m.update(x)
149 return cast(str, m.hexdigest())
150
151
152class safe_reraise:
153 """Reraise an exception after invoking some
154 handler code.
155
156 Stores the existing exception info before
157 invoking so that it is maintained across a potential
158 coroutine context switch.
159
160 e.g.::
161
162 try:
163 sess.commit()
164 except:
165 with safe_reraise():
166 sess.rollback()
167
168 TODO: we should at some point evaluate current behaviors in this regard
169 based on current greenlet, gevent/eventlet implementations in Python 3, and
170 also see the degree to which our own asyncio (based on greenlet also) is
171 impacted by this. .rollback() will cause IO / context switch to occur in
172 all these scenarios; what happens to the exception context from an
173 "except:" block if we don't explicitly store it? Original issue was #2703.
174
175 """
176
177 __slots__ = ("_exc_info",)
178
179 _exc_info: Union[
180 None,
181 Tuple[
182 Type[BaseException],
183 BaseException,
184 types.TracebackType,
185 ],
186 Tuple[None, None, None],
187 ]
188
189 def __enter__(self) -> None:
190 self._exc_info = sys.exc_info()
191
192 def __exit__(
193 self,
194 type_: Optional[Type[BaseException]],
195 value: Optional[BaseException],
196 traceback: Optional[types.TracebackType],
197 ) -> NoReturn:
198 assert self._exc_info is not None
199 # see #2703 for notes
200 if type_ is None:
201 exc_type, exc_value, exc_tb = self._exc_info
202 assert exc_value is not None
203 self._exc_info = None # remove potential circular references
204 raise exc_value.with_traceback(exc_tb)
205 else:
206 self._exc_info = None # remove potential circular references
207 assert value is not None
208 raise value.with_traceback(traceback)
209
210
211def walk_subclasses(cls: Type[_T]) -> Iterator[Type[_T]]:
212 seen: Set[Any] = set()
213
214 stack = [cls]
215 while stack:
216 cls = stack.pop()
217 if cls in seen:
218 continue
219 else:
220 seen.add(cls)
221 stack.extend(cls.__subclasses__())
222 yield cls
223
224
225def string_or_unprintable(element: Any) -> str:
226 if isinstance(element, str):
227 return element
228 else:
229 try:
230 return str(element)
231 except Exception:
232 return "unprintable element %r" % element
233
234
235def clsname_as_plain_name(
236 cls: Type[Any], use_name: Optional[str] = None
237) -> str:
238 name = use_name or cls.__name__
239 return " ".join(n.lower() for n in re.findall(r"([A-Z][a-z]+|SQL)", name))
240
241
242def method_is_overridden(
243 instance_or_cls: Union[Type[Any], object],
244 against_method: Callable[..., Any],
245) -> bool:
246 """Return True if the two class methods don't match."""
247
248 if not isinstance(instance_or_cls, type):
249 current_cls = instance_or_cls.__class__
250 else:
251 current_cls = instance_or_cls
252
253 method_name = against_method.__name__
254
255 current_method: types.MethodType = getattr(current_cls, method_name)
256
257 return current_method != against_method
258
259
260def decode_slice(slc: slice) -> Tuple[Any, ...]:
261 """decode a slice object as sent to __getitem__.
262
263 takes into account the 2.5 __index__() method, basically.
264
265 """
266 ret: List[Any] = []
267 for x in slc.start, slc.stop, slc.step:
268 if hasattr(x, "__index__"):
269 x = x.__index__()
270 ret.append(x)
271 return tuple(ret)
272
273
274def _unique_symbols(used: Sequence[str], *bases: str) -> Iterator[str]:
275 used_set = set(used)
276 for base in bases:
277 pool = itertools.chain(
278 (base,),
279 map(lambda i: base + str(i), range(1000)),
280 )
281 for sym in pool:
282 if sym not in used_set:
283 used_set.add(sym)
284 yield sym
285 break
286 else:
287 raise NameError("exhausted namespace for symbol base %s" % base)
288
289
290def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]:
291 """Call the given function given each nonzero bit from n."""
292
293 while n:
294 b = n & (~n + 1)
295 yield fn(b)
296 n ^= b
297
298
299_Fn = TypeVar("_Fn", bound="Callable[..., Any]")
300
301# this seems to be in flux in recent mypy versions
302
303
304def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
305 """A signature-matching decorator factory."""
306
307 def decorate(fn: _Fn) -> _Fn:
308 if not inspect.isfunction(fn) and not inspect.ismethod(fn):
309 raise Exception("not a decoratable function")
310
311 # Python 3.14 defer creating __annotations__ until its used.
312 # We do not want to create __annotations__ now.
313 annofunc = getattr(fn, "__annotate__", None)
314 if annofunc is not None:
315 fn.__annotate__ = None # type: ignore[union-attr]
316 try:
317 spec = compat.inspect_getfullargspec(fn)
318 finally:
319 fn.__annotate__ = annofunc # type: ignore[union-attr]
320 else:
321 spec = compat.inspect_getfullargspec(fn)
322
323 # Do not generate code for annotations.
324 # update_wrapper() copies the annotation from fn to decorated.
325 # We use dummy defaults for code generation to avoid having
326 # copy of large globals for compiling.
327 # We copy __defaults__ and __kwdefaults__ from fn to decorated.
328 empty_defaults = (None,) * len(spec.defaults or ())
329 empty_kwdefaults = dict.fromkeys(spec.kwonlydefaults or ())
330 spec = spec._replace(
331 annotations={},
332 defaults=empty_defaults,
333 kwonlydefaults=empty_kwdefaults,
334 )
335
336 names = (
337 tuple(cast("Tuple[str, ...]", spec[0]))
338 + cast("Tuple[str, ...]", spec[1:3])
339 + (fn.__name__,)
340 )
341 targ_name, fn_name = _unique_symbols(names, "target", "fn")
342
343 metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
344 metadata.update(format_argspec_plus(spec, grouped=False))
345 metadata["name"] = fn.__name__
346
347 if inspect.iscoroutinefunction(fn):
348 metadata["prefix"] = "async "
349 metadata["target_prefix"] = "await "
350 else:
351 metadata["prefix"] = ""
352 metadata["target_prefix"] = ""
353
354 # look for __ positional arguments. This is a convention in
355 # SQLAlchemy that arguments should be passed positionally
356 # rather than as keyword
357 # arguments. note that apply_pos doesn't currently work in all cases
358 # such as when a kw-only indicator "*" is present, which is why
359 # we limit the use of this to just that case we can detect. As we add
360 # more kinds of methods that use @decorator, things may have to
361 # be further improved in this area
362 if "__" in repr(spec[0]):
363 code = (
364 """\
365%(prefix)sdef %(name)s%(grouped_args)s:
366 return %(target_prefix)s%(target)s(%(fn)s, %(apply_pos)s)
367"""
368 % metadata
369 )
370 else:
371 code = (
372 """\
373%(prefix)sdef %(name)s%(grouped_args)s:
374 return %(target_prefix)s%(target)s(%(fn)s, %(apply_kw)s)
375"""
376 % metadata
377 )
378
379 env: Dict[str, Any] = {
380 targ_name: target,
381 fn_name: fn,
382 "__name__": fn.__module__,
383 }
384
385 decorated = cast(
386 types.FunctionType,
387 _exec_code_in_env(code, env, fn.__name__),
388 )
389 decorated.__defaults__ = fn.__defaults__
390 decorated.__kwdefaults__ = fn.__kwdefaults__ # type: ignore
391 return update_wrapper(decorated, fn) # type: ignore[return-value]
392
393 return update_wrapper(decorate, target) # type: ignore[return-value]
394
395
396def _exec_code_in_env(
397 code: Union[str, types.CodeType], env: Dict[str, Any], fn_name: str
398) -> Callable[..., Any]:
399 exec(code, env)
400 return env[fn_name] # type: ignore[no-any-return]
401
402
403_PF = TypeVar("_PF")
404_TE = TypeVar("_TE")
405
406
407class PluginLoader:
408 def __init__(
409 self, group: str, auto_fn: Optional[Callable[..., Any]] = None
410 ):
411 self.group = group
412 self.impls: Dict[str, Any] = {}
413 self.auto_fn = auto_fn
414
415 def clear(self):
416 self.impls.clear()
417
418 def load(self, name: str) -> Any:
419 if name in self.impls:
420 return self.impls[name]()
421
422 if self.auto_fn:
423 loader = self.auto_fn(name)
424 if loader:
425 self.impls[name] = loader
426 return loader()
427
428 for impl in compat.importlib_metadata_get(self.group):
429 if impl.name == name:
430 self.impls[name] = impl.load
431 return impl.load()
432
433 raise exc.NoSuchModuleError(
434 "Can't load plugin: %s:%s" % (self.group, name)
435 )
436
437 def register(self, name: str, modulepath: str, objname: str) -> None:
438 def load():
439 mod = __import__(modulepath)
440 for token in modulepath.split(".")[1:]:
441 mod = getattr(mod, token)
442 return getattr(mod, objname)
443
444 self.impls[name] = load
445
446 def deregister(self, name: str) -> None:
447 del self.impls[name]
448
449
450def _inspect_func_args(fn):
451 try:
452 co_varkeywords = inspect.CO_VARKEYWORDS
453 except AttributeError:
454 # https://docs.python.org/3/library/inspect.html
455 # The flags are specific to CPython, and may not be defined in other
456 # Python implementations. Furthermore, the flags are an implementation
457 # detail, and can be removed or deprecated in future Python releases.
458 spec = compat.inspect_getfullargspec(fn)
459 return spec[0], bool(spec[2])
460 else:
461 # use fn.__code__ plus flags to reduce method call overhead
462 co = fn.__code__
463 nargs = co.co_argcount
464 return (
465 list(co.co_varnames[:nargs]),
466 bool(co.co_flags & co_varkeywords),
467 )
468
469
470@overload
471def get_cls_kwargs(
472 cls: type,
473 *,
474 _set: Optional[Set[str]] = None,
475 raiseerr: Literal[True] = ...,
476) -> Set[str]: ...
477
478
479@overload
480def get_cls_kwargs(
481 cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
482) -> Optional[Set[str]]: ...
483
484
485def get_cls_kwargs(
486 cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
487) -> Optional[Set[str]]:
488 r"""Return the full set of inherited kwargs for the given `cls`.
489
490 Probes a class's __init__ method, collecting all named arguments. If the
491 __init__ defines a \**kwargs catch-all, then the constructor is presumed
492 to pass along unrecognized keywords to its base classes, and the
493 collection process is repeated recursively on each of the bases.
494
495 Uses a subset of inspect.getfullargspec() to cut down on method overhead,
496 as this is used within the Core typing system to create copies of type
497 objects which is a performance-sensitive operation.
498
499 No anonymous tuple arguments please !
500
501 """
502 toplevel = _set is None
503 if toplevel:
504 _set = set()
505 assert _set is not None
506
507 ctr = cls.__dict__.get("__init__", False)
508
509 has_init = (
510 ctr
511 and isinstance(ctr, types.FunctionType)
512 and isinstance(ctr.__code__, types.CodeType)
513 )
514
515 if has_init:
516 names, has_kw = _inspect_func_args(ctr)
517 _set.update(names)
518
519 if not has_kw and not toplevel:
520 if raiseerr:
521 raise TypeError(
522 f"given cls {cls} doesn't have an __init__ method"
523 )
524 else:
525 return None
526 else:
527 has_kw = False
528
529 if not has_init or has_kw:
530 for c in cls.__bases__:
531 if get_cls_kwargs(c, _set=_set) is None:
532 break
533
534 _set.discard("self")
535 return _set
536
537
538def get_func_kwargs(func: Callable[..., Any]) -> List[str]:
539 """Return the set of legal kwargs for the given `func`.
540
541 Uses getargspec so is safe to call for methods, functions,
542 etc.
543
544 """
545
546 return compat.inspect_getfullargspec(func)[0]
547
548
549def get_callable_argspec(
550 fn: Callable[..., Any], no_self: bool = False, _is_init: bool = False
551) -> compat.FullArgSpec:
552 """Return the argument signature for any callable.
553
554 All pure-Python callables are accepted, including
555 functions, methods, classes, objects with __call__;
556 builtins and other edge cases like functools.partial() objects
557 raise a TypeError.
558
559 """
560 if inspect.isbuiltin(fn):
561 raise TypeError("Can't inspect builtin: %s" % fn)
562 elif inspect.isfunction(fn):
563 if _is_init and no_self:
564 spec = compat.inspect_getfullargspec(fn)
565 return compat.FullArgSpec(
566 spec.args[1:],
567 spec.varargs,
568 spec.varkw,
569 spec.defaults,
570 spec.kwonlyargs,
571 spec.kwonlydefaults,
572 spec.annotations,
573 )
574 else:
575 return compat.inspect_getfullargspec(fn)
576 elif inspect.ismethod(fn):
577 if no_self and (_is_init or fn.__self__):
578 spec = compat.inspect_getfullargspec(fn.__func__)
579 return compat.FullArgSpec(
580 spec.args[1:],
581 spec.varargs,
582 spec.varkw,
583 spec.defaults,
584 spec.kwonlyargs,
585 spec.kwonlydefaults,
586 spec.annotations,
587 )
588 else:
589 return compat.inspect_getfullargspec(fn.__func__)
590 elif inspect.isclass(fn):
591 return get_callable_argspec(
592 fn.__init__, no_self=no_self, _is_init=True
593 )
594 elif hasattr(fn, "__func__"):
595 return compat.inspect_getfullargspec(fn.__func__)
596 elif hasattr(fn, "__call__"):
597 if inspect.ismethod(fn.__call__):
598 return get_callable_argspec(fn.__call__, no_self=no_self)
599 else:
600 raise TypeError("Can't inspect callable: %s" % fn)
601 else:
602 raise TypeError("Can't inspect callable: %s" % fn)
603
604
605def format_argspec_plus(
606 fn: Union[Callable[..., Any], compat.FullArgSpec], grouped: bool = True
607) -> Dict[str, Optional[str]]:
608 """Returns a dictionary of formatted, introspected function arguments.
609
610 A enhanced variant of inspect.formatargspec to support code generation.
611
612 fn
613 An inspectable callable or tuple of inspect getargspec() results.
614 grouped
615 Defaults to True; include (parens, around, argument) lists
616
617 Returns:
618
619 args
620 Full inspect.formatargspec for fn
621 self_arg
622 The name of the first positional argument, varargs[0], or None
623 if the function defines no positional arguments.
624 apply_pos
625 args, re-written in calling rather than receiving syntax. Arguments are
626 passed positionally.
627 apply_kw
628 Like apply_pos, except keyword-ish args are passed as keywords.
629 apply_pos_proxied
630 Like apply_pos but omits the self/cls argument
631
632 Example::
633
634 >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
635 {'grouped_args': '(self, a, b, c=3, **d)',
636 'self_arg': 'self',
637 'apply_kw': '(self, a, b, c=c, **d)',
638 'apply_pos': '(self, a, b, c, **d)'}
639
640 """
641 if callable(fn):
642 spec = compat.inspect_getfullargspec(fn)
643 else:
644 spec = fn
645
646 args = compat.inspect_formatargspec(*spec)
647
648 apply_pos = compat.inspect_formatargspec(
649 spec[0], spec[1], spec[2], None, spec[4]
650 )
651
652 if spec[0]:
653 self_arg = spec[0][0]
654
655 apply_pos_proxied = compat.inspect_formatargspec(
656 spec[0][1:], spec[1], spec[2], None, spec[4]
657 )
658
659 elif spec[1]:
660 # I'm not sure what this is
661 self_arg = "%s[0]" % spec[1]
662
663 apply_pos_proxied = apply_pos
664 else:
665 self_arg = None
666 apply_pos_proxied = apply_pos
667
668 num_defaults = 0
669 if spec[3]:
670 num_defaults += len(cast(Tuple[Any], spec[3]))
671 if spec[4]:
672 num_defaults += len(spec[4])
673
674 name_args = spec[0] + spec[4]
675
676 defaulted_vals: Union[List[str], Tuple[()]]
677
678 if num_defaults:
679 defaulted_vals = name_args[0 - num_defaults :]
680 else:
681 defaulted_vals = ()
682
683 apply_kw = compat.inspect_formatargspec(
684 name_args,
685 spec[1],
686 spec[2],
687 defaulted_vals,
688 formatvalue=lambda x: "=" + str(x),
689 )
690
691 if spec[0]:
692 apply_kw_proxied = compat.inspect_formatargspec(
693 name_args[1:],
694 spec[1],
695 spec[2],
696 defaulted_vals,
697 formatvalue=lambda x: "=" + str(x),
698 )
699 else:
700 apply_kw_proxied = apply_kw
701
702 if grouped:
703 return dict(
704 grouped_args=args,
705 self_arg=self_arg,
706 apply_pos=apply_pos,
707 apply_kw=apply_kw,
708 apply_pos_proxied=apply_pos_proxied,
709 apply_kw_proxied=apply_kw_proxied,
710 )
711 else:
712 return dict(
713 grouped_args=args,
714 self_arg=self_arg,
715 apply_pos=apply_pos[1:-1],
716 apply_kw=apply_kw[1:-1],
717 apply_pos_proxied=apply_pos_proxied[1:-1],
718 apply_kw_proxied=apply_kw_proxied[1:-1],
719 )
720
721
722def format_argspec_init(method, grouped=True):
723 """format_argspec_plus with considerations for typical __init__ methods
724
725 Wraps format_argspec_plus with error handling strategies for typical
726 __init__ cases:
727
728 .. sourcecode:: text
729
730 object.__init__ -> (self)
731 other unreflectable (usually C) -> (self, *args, **kwargs)
732
733 """
734 if method is object.__init__:
735 grouped_args = "(self)"
736 args = "(self)" if grouped else "self"
737 proxied = "()" if grouped else ""
738 else:
739 try:
740 return format_argspec_plus(method, grouped=grouped)
741 except TypeError:
742 grouped_args = "(self, *args, **kwargs)"
743 args = grouped_args if grouped else "self, *args, **kwargs"
744 proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs"
745 return dict(
746 self_arg="self",
747 grouped_args=grouped_args,
748 apply_pos=args,
749 apply_kw=args,
750 apply_pos_proxied=proxied,
751 apply_kw_proxied=proxied,
752 )
753
754
755def create_proxy_methods(
756 target_cls: Type[Any],
757 target_cls_sphinx_name: str,
758 proxy_cls_sphinx_name: str,
759 classmethods: Sequence[str] = (),
760 methods: Sequence[str] = (),
761 attributes: Sequence[str] = (),
762 use_intermediate_variable: Sequence[str] = (),
763) -> Callable[[_T], _T]:
764 """A class decorator indicating attributes should refer to a proxy
765 class.
766
767 This decorator is now a "marker" that does nothing at runtime. Instead,
768 it is consumed by the tools/generate_proxy_methods.py script to
769 statically generate proxy methods and attributes that are fully
770 recognized by typing tools such as mypy.
771
772 """
773
774 def decorate(cls):
775 return cls
776
777 return decorate
778
779
780def getargspec_init(method):
781 """inspect.getargspec with considerations for typical __init__ methods
782
783 Wraps inspect.getargspec with error handling for typical __init__ cases:
784
785 .. sourcecode:: text
786
787 object.__init__ -> (self)
788 other unreflectable (usually C) -> (self, *args, **kwargs)
789
790 """
791 try:
792 return compat.inspect_getfullargspec(method)
793 except TypeError:
794 if method is object.__init__:
795 return (["self"], None, None, None)
796 else:
797 return (["self"], "args", "kwargs", None)
798
799
800def unbound_method_to_callable(func_or_cls):
801 """Adjust the incoming callable such that a 'self' argument is not
802 required.
803
804 """
805
806 if isinstance(func_or_cls, types.MethodType) and not func_or_cls.__self__:
807 return func_or_cls.__func__
808 else:
809 return func_or_cls
810
811
812def generic_repr(
813 obj: Any,
814 additional_kw: Sequence[Tuple[str, Any]] = (),
815 to_inspect: Optional[Union[object, List[object]]] = None,
816 omit_kwarg: Sequence[str] = (),
817) -> str:
818 """Produce a __repr__() based on direct association of the __init__()
819 specification vs. same-named attributes present.
820
821 """
822 if to_inspect is None:
823 to_inspect = [obj]
824 else:
825 to_inspect = _collections.to_list(to_inspect)
826
827 missing = object()
828
829 pos_args = []
830 kw_args: _collections.OrderedDict[str, Any] = _collections.OrderedDict()
831 vargs = None
832 for i, insp in enumerate(to_inspect):
833 try:
834 spec = compat.inspect_getfullargspec(insp.__init__)
835 except TypeError:
836 continue
837 else:
838 default_len = len(spec.defaults) if spec.defaults else 0
839 if i == 0:
840 if spec.varargs:
841 vargs = spec.varargs
842 if default_len:
843 pos_args.extend(spec.args[1:-default_len])
844 else:
845 pos_args.extend(spec.args[1:])
846 else:
847 kw_args.update(
848 [(arg, missing) for arg in spec.args[1:-default_len]]
849 )
850
851 if default_len:
852 assert spec.defaults
853 kw_args.update(
854 [
855 (arg, default)
856 for arg, default in zip(
857 spec.args[-default_len:], spec.defaults
858 )
859 ]
860 )
861 output: List[str] = []
862
863 output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)
864
865 if vargs is not None and hasattr(obj, vargs):
866 output.extend([repr(val) for val in getattr(obj, vargs)])
867
868 for arg, defval in kw_args.items():
869 if arg in omit_kwarg:
870 continue
871 try:
872 val = getattr(obj, arg, missing)
873 if val is not missing and val != defval:
874 output.append("%s=%r" % (arg, val))
875 except Exception:
876 pass
877
878 if additional_kw:
879 for arg, defval in additional_kw:
880 try:
881 val = getattr(obj, arg, missing)
882 if val is not missing and val != defval:
883 output.append("%s=%r" % (arg, val))
884 except Exception:
885 pass
886
887 return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
888
889
890def class_hierarchy(cls):
891 """Return an unordered sequence of all classes related to cls.
892
893 Traverses diamond hierarchies.
894
895 Fibs slightly: subclasses of builtin types are not returned. Thus
896 class_hierarchy(class A(object)) returns (A, object), not A plus every
897 class systemwide that derives from object.
898
899 """
900
901 hier = {cls}
902 process = list(cls.__mro__)
903 while process:
904 c = process.pop()
905 bases = (_ for _ in c.__bases__ if _ not in hier)
906
907 for b in bases:
908 process.append(b)
909 hier.add(b)
910
911 if c.__module__ == "builtins" or not hasattr(c, "__subclasses__"):
912 continue
913
914 for s in [
915 _
916 for _ in (
917 c.__subclasses__()
918 if not issubclass(c, type)
919 else c.__subclasses__(c)
920 )
921 if _ not in hier
922 ]:
923 process.append(s)
924 hier.add(s)
925 return list(hier)
926
927
928def iterate_attributes(cls):
929 """iterate all the keys and attributes associated
930 with a class, without using getattr().
931
932 Does not use getattr() so that class-sensitive
933 descriptors (i.e. property.__get__()) are not called.
934
935 """
936 keys = dir(cls)
937 for key in keys:
938 for c in cls.__mro__:
939 if key in c.__dict__:
940 yield (key, c.__dict__[key])
941 break
942
943
944def monkeypatch_proxied_specials(
945 into_cls,
946 from_cls,
947 skip=None,
948 only=None,
949 name="self.proxy",
950 from_instance=None,
951):
952 """Automates delegation of __specials__ for a proxying type."""
953
954 if only:
955 dunders = only
956 else:
957 if skip is None:
958 skip = (
959 "__slots__",
960 "__del__",
961 "__getattribute__",
962 "__metaclass__",
963 "__getstate__",
964 "__setstate__",
965 )
966 dunders = [
967 m
968 for m in dir(from_cls)
969 if (
970 m.startswith("__")
971 and m.endswith("__")
972 and not hasattr(into_cls, m)
973 and m not in skip
974 )
975 ]
976
977 for method in dunders:
978 try:
979 maybe_fn = getattr(from_cls, method)
980 if not hasattr(maybe_fn, "__call__"):
981 continue
982 maybe_fn = getattr(maybe_fn, "__func__", maybe_fn)
983 fn = cast(types.FunctionType, maybe_fn)
984
985 except AttributeError:
986 continue
987 try:
988 spec = compat.inspect_getfullargspec(fn)
989 fn_args = compat.inspect_formatargspec(spec[0])
990 d_args = compat.inspect_formatargspec(spec[0][1:])
991 except TypeError:
992 fn_args = "(self, *args, **kw)"
993 d_args = "(*args, **kw)"
994
995 py = (
996 "def %(method)s%(fn_args)s: "
997 "return %(name)s.%(method)s%(d_args)s" % locals()
998 )
999
1000 env: Dict[str, types.FunctionType] = (
1001 from_instance is not None and {name: from_instance} or {}
1002 )
1003 exec(py, env)
1004 try:
1005 env[method].__defaults__ = fn.__defaults__
1006 except AttributeError:
1007 pass
1008 setattr(into_cls, method, env[method])
1009
1010
1011def methods_equivalent(meth1, meth2):
1012 """Return True if the two methods are the same implementation."""
1013
1014 return getattr(meth1, "__func__", meth1) is getattr(
1015 meth2, "__func__", meth2
1016 )
1017
1018
1019def as_interface(obj, cls=None, methods=None, required=None):
1020 """Ensure basic interface compliance for an instance or dict of callables.
1021
1022 Checks that ``obj`` implements public methods of ``cls`` or has members
1023 listed in ``methods``. If ``required`` is not supplied, implementing at
1024 least one interface method is sufficient. Methods present on ``obj`` that
1025 are not in the interface are ignored.
1026
1027 If ``obj`` is a dict and ``dict`` does not meet the interface
1028 requirements, the keys of the dictionary are inspected. Keys present in
1029 ``obj`` that are not in the interface will raise TypeErrors.
1030
1031 Raises TypeError if ``obj`` does not meet the interface criteria.
1032
1033 In all passing cases, an object with callable members is returned. In the
1034 simple case, ``obj`` is returned as-is; if dict processing kicks in then
1035 an anonymous class is returned.
1036
1037 obj
1038 A type, instance, or dictionary of callables.
1039 cls
1040 Optional, a type. All public methods of cls are considered the
1041 interface. An ``obj`` instance of cls will always pass, ignoring
1042 ``required``..
1043 methods
1044 Optional, a sequence of method names to consider as the interface.
1045 required
1046 Optional, a sequence of mandatory implementations. If omitted, an
1047 ``obj`` that provides at least one interface method is considered
1048 sufficient. As a convenience, required may be a type, in which case
1049 all public methods of the type are required.
1050
1051 """
1052 if not cls and not methods:
1053 raise TypeError("a class or collection of method names are required")
1054
1055 if isinstance(cls, type) and isinstance(obj, cls):
1056 return obj
1057
1058 interface = set(methods or [m for m in dir(cls) if not m.startswith("_")])
1059 implemented = set(dir(obj))
1060
1061 complies = operator.ge
1062 if isinstance(required, type):
1063 required = interface
1064 elif not required:
1065 required = set()
1066 complies = operator.gt
1067 else:
1068 required = set(required)
1069
1070 if complies(implemented.intersection(interface), required):
1071 return obj
1072
1073 # No dict duck typing here.
1074 if not isinstance(obj, dict):
1075 qualifier = complies is operator.gt and "any of" or "all of"
1076 raise TypeError(
1077 "%r does not implement %s: %s"
1078 % (obj, qualifier, ", ".join(interface))
1079 )
1080
1081 class AnonymousInterface:
1082 """A callable-holding shell."""
1083
1084 if cls:
1085 AnonymousInterface.__name__ = "Anonymous" + cls.__name__
1086 found = set()
1087
1088 for method, impl in dictlike_iteritems(obj):
1089 if method not in interface:
1090 raise TypeError("%r: unknown in this interface" % method)
1091 if not callable(impl):
1092 raise TypeError("%r=%r is not callable" % (method, impl))
1093 setattr(AnonymousInterface, method, staticmethod(impl))
1094 found.add(method)
1095
1096 if complies(found, required):
1097 return AnonymousInterface
1098
1099 raise TypeError(
1100 "dictionary does not contain required keys %s"
1101 % ", ".join(required - found)
1102 )
1103
1104
1105_GFD = TypeVar("_GFD", bound="generic_fn_descriptor[Any]")
1106
1107
1108class generic_fn_descriptor(Generic[_T_co]):
1109 """Descriptor which proxies a function when the attribute is not
1110 present in dict
1111
1112 This superclass is organized in a particular way with "memoized" and
1113 "non-memoized" implementation classes that are hidden from type checkers,
1114 as Mypy seems to not be able to handle seeing multiple kinds of descriptor
1115 classes used for the same attribute.
1116
1117 """
1118
1119 fget: Callable[..., _T_co]
1120 __doc__: Optional[str]
1121 __name__: str
1122
1123 def __init__(self, fget: Callable[..., _T_co], doc: Optional[str] = None):
1124 self.fget = fget
1125 self.__doc__ = doc or fget.__doc__
1126 self.__name__ = fget.__name__
1127
1128 @overload
1129 def __get__(self: _GFD, obj: None, cls: Any) -> _GFD: ...
1130
1131 @overload
1132 def __get__(self, obj: object, cls: Any) -> _T_co: ...
1133
1134 def __get__(self: _GFD, obj: Any, cls: Any) -> Union[_GFD, _T_co]:
1135 raise NotImplementedError()
1136
1137 if TYPE_CHECKING:
1138
1139 def __set__(self, instance: Any, value: Any) -> None: ...
1140
1141 def __delete__(self, instance: Any) -> None: ...
1142
1143 def _reset(self, obj: Any) -> None:
1144 raise NotImplementedError()
1145
1146 @classmethod
1147 def reset(cls, obj: Any, name: str) -> None:
1148 raise NotImplementedError()
1149
1150
1151class _non_memoized_property(generic_fn_descriptor[_T_co]):
1152 """a plain descriptor that proxies a function.
1153
1154 primary rationale is to provide a plain attribute that's
1155 compatible with memoized_property which is also recognized as equivalent
1156 by mypy.
1157
1158 """
1159
1160 if not TYPE_CHECKING:
1161
1162 def __get__(self, obj, cls):
1163 if obj is None:
1164 return self
1165 return self.fget(obj)
1166
1167
1168class _memoized_property(generic_fn_descriptor[_T_co]):
1169 """A read-only @property that is only evaluated once."""
1170
1171 if not TYPE_CHECKING:
1172
1173 def __get__(self, obj, cls):
1174 if obj is None:
1175 return self
1176 obj.__dict__[self.__name__] = result = self.fget(obj)
1177 return result
1178
1179 def _reset(self, obj):
1180 _memoized_property.reset(obj, self.__name__)
1181
1182 @classmethod
1183 def reset(cls, obj, name):
1184 obj.__dict__.pop(name, None)
1185
1186
1187# despite many attempts to get Mypy to recognize an overridden descriptor
1188# where one is memoized and the other isn't, there seems to be no reliable
1189# way other than completely deceiving the type checker into thinking there
1190# is just one single descriptor type everywhere. Otherwise, if a superclass
1191# has non-memoized and subclass has memoized, that requires
1192# "class memoized(non_memoized)". but then if a superclass has memoized and
1193# superclass has non-memoized, the class hierarchy of the descriptors
1194# would need to be reversed; "class non_memoized(memoized)". so there's no
1195# way to achieve this.
1196# additional issues, RO properties:
1197# https://github.com/python/mypy/issues/12440
1198if TYPE_CHECKING:
1199 # allow memoized and non-memoized to be freely mixed by having them
1200 # be the same class
1201 memoized_property = generic_fn_descriptor
1202 non_memoized_property = generic_fn_descriptor
1203
1204 # for read only situations, mypy only sees @property as read only.
1205 # read only is needed when a subtype specializes the return type
1206 # of a property, meaning assignment needs to be disallowed
1207 ro_memoized_property = property
1208 ro_non_memoized_property = property
1209
1210else:
1211 memoized_property = ro_memoized_property = _memoized_property
1212 non_memoized_property = ro_non_memoized_property = _non_memoized_property
1213
1214
1215def memoized_instancemethod(fn: _F) -> _F:
1216 """Decorate a method memoize its return value.
1217
1218 Best applied to no-arg methods: memoization is not sensitive to
1219 argument values, and will always return the same value even when
1220 called with different arguments.
1221
1222 """
1223
1224 def oneshot(self, *args, **kw):
1225 result = fn(self, *args, **kw)
1226
1227 def memo(*a, **kw):
1228 return result
1229
1230 memo.__name__ = fn.__name__
1231 memo.__doc__ = fn.__doc__
1232 self.__dict__[fn.__name__] = memo
1233 return result
1234
1235 return update_wrapper(oneshot, fn) # type: ignore
1236
1237
1238class HasMemoized:
1239 """A mixin class that maintains the names of memoized elements in a
1240 collection for easy cache clearing, generative, etc.
1241
1242 """
1243
1244 if not TYPE_CHECKING:
1245 # support classes that want to have __slots__ with an explicit
1246 # slot for __dict__. not sure if that requires base __slots__ here.
1247 __slots__ = ()
1248
1249 _memoized_keys: FrozenSet[str] = frozenset()
1250
1251 def _reset_memoizations(self) -> None:
1252 for elem in self._memoized_keys:
1253 self.__dict__.pop(elem, None)
1254
1255 def _assert_no_memoizations(self) -> None:
1256 for elem in self._memoized_keys:
1257 assert elem not in self.__dict__
1258
1259 def _set_memoized_attribute(self, key: str, value: Any) -> None:
1260 self.__dict__[key] = value
1261 self._memoized_keys |= {key}
1262
1263 class memoized_attribute(memoized_property[_T]):
1264 """A read-only @property that is only evaluated once.
1265
1266 :meta private:
1267
1268 """
1269
1270 fget: Callable[..., _T]
1271 __doc__: Optional[str]
1272 __name__: str
1273
1274 def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None):
1275 self.fget = fget
1276 self.__doc__ = doc or fget.__doc__
1277 self.__name__ = fget.__name__
1278
1279 @overload
1280 def __get__(self: _MA, obj: None, cls: Any) -> _MA: ...
1281
1282 @overload
1283 def __get__(self, obj: Any, cls: Any) -> _T: ...
1284
1285 def __get__(self, obj, cls):
1286 if obj is None:
1287 return self
1288 obj.__dict__[self.__name__] = result = self.fget(obj)
1289 obj._memoized_keys |= {self.__name__}
1290 return result
1291
1292 @classmethod
1293 def memoized_instancemethod(cls, fn: _F) -> _F:
1294 """Decorate a method memoize its return value.
1295
1296 :meta private:
1297
1298 """
1299
1300 def oneshot(self: Any, *args: Any, **kw: Any) -> Any:
1301 result = fn(self, *args, **kw)
1302
1303 def memo(*a, **kw):
1304 return result
1305
1306 memo.__name__ = fn.__name__
1307 memo.__doc__ = fn.__doc__
1308 self.__dict__[fn.__name__] = memo
1309 self._memoized_keys |= {fn.__name__}
1310 return result
1311
1312 return update_wrapper(oneshot, fn) # type: ignore
1313
1314
1315if TYPE_CHECKING:
1316 HasMemoized_ro_memoized_attribute = property
1317else:
1318 HasMemoized_ro_memoized_attribute = HasMemoized.memoized_attribute
1319
1320
1321class MemoizedSlots:
1322 """Apply memoized items to an object using a __getattr__ scheme.
1323
1324 This allows the functionality of memoized_property and
1325 memoized_instancemethod to be available to a class using __slots__.
1326
1327 The memoized get is not threadsafe under freethreading and the
1328 creator method may in extremely rare cases be called more than once.
1329
1330 """
1331
1332 __slots__ = ()
1333
1334 def _fallback_getattr(self, key):
1335 raise AttributeError(key)
1336
1337 def __getattr__(self, key: str) -> Any:
1338 if key.startswith("_memoized_attr_") or key.startswith(
1339 "_memoized_method_"
1340 ):
1341 raise AttributeError(key)
1342 # to avoid recursion errors when interacting with other __getattr__
1343 # schemes that refer to this one, when testing for memoized method
1344 # look at __class__ only rather than going into __getattr__ again.
1345 elif hasattr(self.__class__, f"_memoized_attr_{key}"):
1346 value = getattr(self, f"_memoized_attr_{key}")()
1347 setattr(self, key, value)
1348 return value
1349 elif hasattr(self.__class__, f"_memoized_method_{key}"):
1350 meth = getattr(self, f"_memoized_method_{key}")
1351
1352 def oneshot(*args, **kw):
1353 result = meth(*args, **kw)
1354
1355 def memo(*a, **kw):
1356 return result
1357
1358 memo.__name__ = meth.__name__
1359 memo.__doc__ = meth.__doc__
1360 setattr(self, key, memo)
1361 return result
1362
1363 oneshot.__doc__ = meth.__doc__
1364 return oneshot
1365 else:
1366 return self._fallback_getattr(key)
1367
1368
1369# from paste.deploy.converters
1370def asbool(obj: Any) -> bool:
1371 if isinstance(obj, str):
1372 obj = obj.strip().lower()
1373 if obj in ["true", "yes", "on", "y", "t", "1"]:
1374 return True
1375 elif obj in ["false", "no", "off", "n", "f", "0"]:
1376 return False
1377 else:
1378 raise ValueError("String is not true/false: %r" % obj)
1379 return bool(obj)
1380
1381
1382def bool_or_str(*text: str) -> Callable[[str], Union[str, bool]]:
1383 """Return a callable that will evaluate a string as
1384 boolean, or one of a set of "alternate" string values.
1385
1386 """
1387
1388 def bool_or_value(obj: str) -> Union[str, bool]:
1389 if obj in text:
1390 return obj
1391 else:
1392 return asbool(obj)
1393
1394 return bool_or_value
1395
1396
1397def asint(value: Any) -> Optional[int]:
1398 """Coerce to integer."""
1399
1400 if value is None:
1401 return value
1402 return int(value)
1403
1404
1405def coerce_kw_type(
1406 kw: Dict[str, Any],
1407 key: str,
1408 type_: Type[Any],
1409 flexi_bool: bool = True,
1410 dest: Optional[Dict[str, Any]] = None,
1411) -> None:
1412 r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if
1413 necessary. If 'flexi_bool' is True, the string '0' is considered false
1414 when coercing to boolean.
1415 """
1416
1417 if dest is None:
1418 dest = kw
1419
1420 if (
1421 key in kw
1422 and (not isinstance(type_, type) or not isinstance(kw[key], type_))
1423 and kw[key] is not None
1424 ):
1425 if type_ is bool and flexi_bool:
1426 dest[key] = asbool(kw[key])
1427 else:
1428 dest[key] = type_(kw[key])
1429
1430
1431def constructor_key(obj: Any, cls: Type[Any]) -> Tuple[Any, ...]:
1432 """Produce a tuple structure that is cacheable using the __dict__ of
1433 obj to retrieve values
1434
1435 """
1436 names = get_cls_kwargs(cls)
1437 return (cls,) + tuple(
1438 (k, obj.__dict__[k]) for k in names if k in obj.__dict__
1439 )
1440
1441
1442def constructor_copy(obj: _T, cls: Type[_T], *args: Any, **kw: Any) -> _T:
1443 """Instantiate cls using the __dict__ of obj as constructor arguments.
1444
1445 Uses inspect to match the named arguments of ``cls``.
1446
1447 """
1448
1449 names = get_cls_kwargs(cls)
1450 kw.update(
1451 (k, obj.__dict__[k]) for k in names.difference(kw) if k in obj.__dict__
1452 )
1453 return cls(*args, **kw)
1454
1455
1456def counter() -> Callable[[], int]:
1457 """Return a threadsafe counter function."""
1458
1459 lock = threading.Lock()
1460 counter = itertools.count(1)
1461
1462 # avoid the 2to3 "next" transformation...
1463 def _next():
1464 with lock:
1465 return next(counter)
1466
1467 return _next
1468
1469
1470def duck_type_collection(
1471 specimen: Any, default: Optional[Type[Any]] = None
1472) -> Optional[Type[Any]]:
1473 """Given an instance or class, guess if it is or is acting as one of
1474 the basic collection types: list, set and dict. If the __emulates__
1475 property is present, return that preferentially.
1476 """
1477
1478 if hasattr(specimen, "__emulates__"):
1479 # canonicalize set vs sets.Set to a standard: the builtin set
1480 if specimen.__emulates__ is not None and issubclass(
1481 specimen.__emulates__, set
1482 ):
1483 return set
1484 else:
1485 return specimen.__emulates__ # type: ignore
1486
1487 isa = issubclass if isinstance(specimen, type) else isinstance
1488 if isa(specimen, list):
1489 return list
1490 elif isa(specimen, set):
1491 return set
1492 elif isa(specimen, dict):
1493 return dict
1494
1495 if hasattr(specimen, "append"):
1496 return list
1497 elif hasattr(specimen, "add"):
1498 return set
1499 elif hasattr(specimen, "set"):
1500 return dict
1501 else:
1502 return default
1503
1504
1505def assert_arg_type(
1506 arg: Any, argtype: Union[Tuple[Type[Any], ...], Type[Any]], name: str
1507) -> Any:
1508 if isinstance(arg, argtype):
1509 return arg
1510 else:
1511 if isinstance(argtype, tuple):
1512 raise exc.ArgumentError(
1513 "Argument '%s' is expected to be one of type %s, got '%s'"
1514 % (name, " or ".join("'%s'" % a for a in argtype), type(arg))
1515 )
1516 else:
1517 raise exc.ArgumentError(
1518 "Argument '%s' is expected to be of type '%s', got '%s'"
1519 % (name, argtype, type(arg))
1520 )
1521
1522
1523def dictlike_iteritems(dictlike):
1524 """Return a (key, value) iterator for almost any dict-like object."""
1525
1526 if hasattr(dictlike, "items"):
1527 return list(dictlike.items())
1528
1529 getter = getattr(dictlike, "__getitem__", getattr(dictlike, "get", None))
1530 if getter is None:
1531 raise TypeError("Object '%r' is not dict-like" % dictlike)
1532
1533 if hasattr(dictlike, "iterkeys"):
1534
1535 def iterator():
1536 for key in dictlike.iterkeys():
1537 assert getter is not None
1538 yield key, getter(key)
1539
1540 return iterator()
1541 elif hasattr(dictlike, "keys"):
1542 return iter((key, getter(key)) for key in dictlike.keys())
1543 else:
1544 raise TypeError("Object '%r' is not dict-like" % dictlike)
1545
1546
1547class classproperty(property):
1548 """A decorator that behaves like @property except that operates
1549 on classes rather than instances.
1550
1551 The decorator is currently special when using the declarative
1552 module, but note that the
1553 :class:`~.sqlalchemy.ext.declarative.declared_attr`
1554 decorator should be used for this purpose with declarative.
1555
1556 """
1557
1558 fget: Callable[[Any], Any]
1559
1560 def __init__(self, fget: Callable[[Any], Any], *arg: Any, **kw: Any):
1561 super().__init__(fget, *arg, **kw)
1562 self.__doc__ = fget.__doc__
1563
1564 def __get__(self, obj: Any, cls: Optional[type] = None) -> Any:
1565 return self.fget(cls)
1566
1567
1568class hybridproperty(Generic[_T]):
1569 def __init__(self, func: Callable[..., _T]):
1570 self.func = func
1571 self.clslevel = func
1572
1573 def __get__(self, instance: Any, owner: Any) -> _T:
1574 if instance is None:
1575 clsval = self.clslevel(owner)
1576 return clsval
1577 else:
1578 return self.func(instance)
1579
1580 def classlevel(self, func: Callable[..., Any]) -> hybridproperty[_T]:
1581 self.clslevel = func
1582 return self
1583
1584
1585class rw_hybridproperty(Generic[_T]):
1586 def __init__(self, func: Callable[..., _T]):
1587 self.func = func
1588 self.clslevel = func
1589 self.setfn: Optional[Callable[..., Any]] = None
1590
1591 def __get__(self, instance: Any, owner: Any) -> _T:
1592 if instance is None:
1593 clsval = self.clslevel(owner)
1594 return clsval
1595 else:
1596 return self.func(instance)
1597
1598 def __set__(self, instance: Any, value: Any) -> None:
1599 assert self.setfn is not None
1600 self.setfn(instance, value)
1601
1602 def setter(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
1603 self.setfn = func
1604 return self
1605
1606 def classlevel(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
1607 self.clslevel = func
1608 return self
1609
1610
1611class hybridmethod(Generic[_T]):
1612 """Decorate a function as cls- or instance- level."""
1613
1614 def __init__(self, func: Callable[..., _T]):
1615 self.func = self.__func__ = func
1616 self.clslevel = func
1617
1618 def __get__(self, instance: Any, owner: Any) -> Callable[..., _T]:
1619 if instance is None:
1620 return self.clslevel.__get__(owner, owner.__class__) # type:ignore
1621 else:
1622 return self.func.__get__(instance, owner) # type:ignore
1623
1624 def classlevel(self, func: Callable[..., Any]) -> hybridmethod[_T]:
1625 self.clslevel = func
1626 return self
1627
1628
1629class symbol(int):
1630 """A constant symbol.
1631
1632 >>> symbol("foo") is symbol("foo")
1633 True
1634 >>> symbol("foo")
1635 <symbol 'foo>
1636
1637 A slight refinement of the MAGICCOOKIE=object() pattern. The primary
1638 advantage of symbol() is its repr(). They are also singletons.
1639
1640 Repeated calls of symbol('name') will all return the same instance.
1641
1642 """
1643
1644 name: str
1645
1646 symbols: Dict[str, symbol] = {}
1647 _lock = threading.Lock()
1648
1649 def __new__(
1650 cls,
1651 name: str,
1652 doc: Optional[str] = None,
1653 canonical: Optional[int] = None,
1654 ) -> symbol:
1655 with cls._lock:
1656 sym = cls.symbols.get(name)
1657 if sym is None:
1658 assert isinstance(name, str)
1659 if canonical is None:
1660 canonical = hash(name)
1661 sym = int.__new__(symbol, canonical)
1662 sym.name = name
1663 if doc:
1664 sym.__doc__ = doc
1665
1666 # NOTE: we should ultimately get rid of this global thing,
1667 # however, currently it is to support pickling. The best
1668 # change would be when we are on py3.11 at a minimum, we
1669 # switch to stdlib enum.IntFlag.
1670 cls.symbols[name] = sym
1671 else:
1672 if canonical and canonical != sym:
1673 raise TypeError(
1674 f"Can't replace canonical symbol for {name!r} "
1675 f"with new int value {canonical}"
1676 )
1677 return sym
1678
1679 def __reduce__(self):
1680 return symbol, (self.name, "x", int(self))
1681
1682 def __str__(self):
1683 return repr(self)
1684
1685 def __repr__(self):
1686 return f"symbol({self.name!r})"
1687
1688
1689class _IntFlagMeta(type):
1690 def __init__(
1691 cls,
1692 classname: str,
1693 bases: Tuple[Type[Any], ...],
1694 dict_: Dict[str, Any],
1695 **kw: Any,
1696 ) -> None:
1697 items: List[symbol]
1698 cls._items = items = []
1699 for k, v in dict_.items():
1700 if re.match(r"^__.*__$", k):
1701 continue
1702 if isinstance(v, int):
1703 sym = symbol(k, canonical=v)
1704 elif not k.startswith("_"):
1705 raise TypeError("Expected integer values for IntFlag")
1706 else:
1707 continue
1708 setattr(cls, k, sym)
1709 items.append(sym)
1710
1711 cls.__members__ = _collections.immutabledict(
1712 {sym.name: sym for sym in items}
1713 )
1714
1715 def __iter__(self) -> Iterator[symbol]:
1716 raise NotImplementedError(
1717 "iter not implemented to ensure compatibility with "
1718 "Python 3.11 IntFlag. Please use __members__. See "
1719 "https://github.com/python/cpython/issues/99304"
1720 )
1721
1722
1723class _FastIntFlag(metaclass=_IntFlagMeta):
1724 """An 'IntFlag' copycat that isn't slow when performing bitwise
1725 operations.
1726
1727 the ``FastIntFlag`` class will return ``enum.IntFlag`` under TYPE_CHECKING
1728 and ``_FastIntFlag`` otherwise.
1729
1730 """
1731
1732
1733if TYPE_CHECKING:
1734 from enum import IntFlag
1735
1736 FastIntFlag = IntFlag
1737else:
1738 FastIntFlag = _FastIntFlag
1739
1740
1741_E = TypeVar("_E", bound=enum.Enum)
1742
1743
1744def parse_user_argument_for_enum(
1745 arg: Any,
1746 choices: Dict[_E, List[Any]],
1747 name: str,
1748 resolve_symbol_names: bool = False,
1749) -> Optional[_E]:
1750 """Given a user parameter, parse the parameter into a chosen value
1751 from a list of choice objects, typically Enum values.
1752
1753 The user argument can be a string name that matches the name of a
1754 symbol, or the symbol object itself, or any number of alternate choices
1755 such as True/False/ None etc.
1756
1757 :param arg: the user argument.
1758 :param choices: dictionary of enum values to lists of possible
1759 entries for each.
1760 :param name: name of the argument. Used in an :class:`.ArgumentError`
1761 that is raised if the parameter doesn't match any available argument.
1762
1763 """
1764 for enum_value, choice in choices.items():
1765 if arg is enum_value:
1766 return enum_value
1767 elif resolve_symbol_names and arg == enum_value.name:
1768 return enum_value
1769 elif arg in choice:
1770 return enum_value
1771
1772 if arg is None:
1773 return None
1774
1775 raise exc.ArgumentError(f"Invalid value for '{name}': {arg!r}")
1776
1777
1778_creation_order = 1
1779
1780
1781def set_creation_order(instance: Any) -> None:
1782 """Assign a '_creation_order' sequence to the given instance.
1783
1784 This allows multiple instances to be sorted in order of creation
1785 (typically within a single thread; the counter is not particularly
1786 threadsafe).
1787
1788 """
1789 global _creation_order
1790 instance._creation_order = _creation_order
1791 _creation_order += 1
1792
1793
1794def warn_exception(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
1795 """executes the given function, catches all exceptions and converts to
1796 a warning.
1797
1798 """
1799 try:
1800 return func(*args, **kwargs)
1801 except Exception:
1802 warn("%s('%s') ignored" % sys.exc_info()[0:2])
1803
1804
1805def ellipses_string(value, len_=25):
1806 try:
1807 if len(value) > len_:
1808 return "%s..." % value[0:len_]
1809 else:
1810 return value
1811 except TypeError:
1812 return value
1813
1814
1815class _hash_limit_string(str):
1816 """A string subclass that can only be hashed on a maximum amount
1817 of unique values.
1818
1819 This is used for warnings so that we can send out parameterized warnings
1820 without the __warningregistry__ of the module, or the non-overridable
1821 "once" registry within warnings.py, overloading memory,
1822
1823
1824 """
1825
1826 _hash: int
1827
1828 def __new__(
1829 cls, value: str, num: int, args: Sequence[Any]
1830 ) -> _hash_limit_string:
1831 interpolated = (value % args) + (
1832 " (this warning may be suppressed after %d occurrences)" % num
1833 )
1834 self = super().__new__(cls, interpolated)
1835 self._hash = hash("%s_%d" % (value, hash(interpolated) % num))
1836 return self
1837
1838 def __hash__(self) -> int:
1839 return self._hash
1840
1841 def __eq__(self, other: Any) -> bool:
1842 return hash(self) == hash(other)
1843
1844
1845def warn(msg: str, code: Optional[str] = None) -> None:
1846 """Issue a warning.
1847
1848 If msg is a string, :class:`.exc.SAWarning` is used as
1849 the category.
1850
1851 """
1852 if code:
1853 _warnings_warn(exc.SAWarning(msg, code=code))
1854 else:
1855 _warnings_warn(msg, exc.SAWarning)
1856
1857
1858def warn_limited(msg: str, args: Sequence[Any]) -> None:
1859 """Issue a warning with a parameterized string, limiting the number
1860 of registrations.
1861
1862 """
1863 if args:
1864 msg = _hash_limit_string(msg, 10, args)
1865 _warnings_warn(msg, exc.SAWarning)
1866
1867
1868_warning_tags: Dict[CodeType, Tuple[str, Type[Warning]]] = {}
1869
1870
1871def tag_method_for_warnings(
1872 message: str, category: Type[Warning]
1873) -> Callable[[_F], _F]:
1874 def go(fn):
1875 _warning_tags[fn.__code__] = (message, category)
1876 return fn
1877
1878 return go
1879
1880
1881_not_sa_pattern = re.compile(r"^(?:sqlalchemy\.(?!testing)|alembic\.)")
1882
1883
1884def _warnings_warn(
1885 message: Union[str, Warning],
1886 category: Optional[Type[Warning]] = None,
1887 stacklevel: int = 2,
1888) -> None:
1889
1890 if category is None and isinstance(message, Warning):
1891 category = type(message)
1892
1893 # adjust the given stacklevel to be outside of SQLAlchemy
1894 try:
1895 frame = sys._getframe(stacklevel)
1896 except ValueError:
1897 # being called from less than 3 (or given) stacklevels, weird,
1898 # but don't crash
1899 stacklevel = 0
1900 except:
1901 # _getframe() doesn't work, weird interpreter issue, weird,
1902 # ok, but don't crash
1903 stacklevel = 0
1904 else:
1905 stacklevel_found = warning_tag_found = False
1906 while frame is not None:
1907 # using __name__ here requires that we have __name__ in the
1908 # __globals__ of the decorated string functions we make also.
1909 # we generate this using {"__name__": fn.__module__}
1910 if not stacklevel_found and not re.match(
1911 _not_sa_pattern, frame.f_globals.get("__name__", "")
1912 ):
1913 # stop incrementing stack level if an out-of-SQLA line
1914 # were found.
1915 stacklevel_found = True
1916
1917 # however, for the warning tag thing, we have to keep
1918 # scanning up the whole traceback
1919
1920 if frame.f_code in _warning_tags:
1921 warning_tag_found = True
1922 (_suffix, _category) = _warning_tags[frame.f_code]
1923 category = category or _category
1924 message = f"{message} ({_suffix})"
1925
1926 frame = frame.f_back # type: ignore[assignment]
1927
1928 if not stacklevel_found:
1929 stacklevel += 1
1930 elif stacklevel_found and warning_tag_found:
1931 break
1932
1933 if category is not None:
1934 warnings.warn(message, category, stacklevel=stacklevel + 1)
1935 else:
1936 warnings.warn(message, stacklevel=stacklevel + 1)
1937
1938
1939def only_once(
1940 fn: Callable[..., _T], retry_on_exception: bool
1941) -> Callable[..., Optional[_T]]:
1942 """Decorate the given function to be a no-op after it is called exactly
1943 once."""
1944
1945 once = [fn]
1946
1947 def go(*arg: Any, **kw: Any) -> Optional[_T]:
1948 # strong reference fn so that it isn't garbage collected,
1949 # which interferes with the event system's expectations
1950 strong_fn = fn # noqa
1951 if once:
1952 once_fn = once.pop()
1953 try:
1954 return once_fn(*arg, **kw)
1955 except:
1956 if retry_on_exception:
1957 once.insert(0, once_fn)
1958 raise
1959
1960 return None
1961
1962 return go
1963
1964
1965_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
1966_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")
1967
1968
1969def chop_traceback(
1970 tb: List[str],
1971 exclude_prefix: re.Pattern[str] = _UNITTEST_RE,
1972 exclude_suffix: re.Pattern[str] = _SQLA_RE,
1973) -> List[str]:
1974 """Chop extraneous lines off beginning and end of a traceback.
1975
1976 :param tb:
1977 a list of traceback lines as returned by ``traceback.format_stack()``
1978
1979 :param exclude_prefix:
1980 a regular expression object matching lines to skip at beginning of
1981 ``tb``
1982
1983 :param exclude_suffix:
1984 a regular expression object matching lines to skip at end of ``tb``
1985 """
1986 start = 0
1987 end = len(tb) - 1
1988 while start <= end and exclude_prefix.search(tb[start]):
1989 start += 1
1990 while start <= end and exclude_suffix.search(tb[end]):
1991 end -= 1
1992 return tb[start : end + 1]
1993
1994
1995def attrsetter(attrname):
1996 code = "def set(obj, value): obj.%s = value" % attrname
1997 env = locals().copy()
1998 exec(code, env)
1999 return env["set"]
2000
2001
2002_dunders = re.compile("^__.+__$")
2003
2004
2005class TypingOnly:
2006 """A mixin class that marks a class as 'typing only', meaning it has
2007 absolutely no methods, attributes, or runtime functionality whatsoever.
2008
2009 """
2010
2011 __slots__ = ()
2012
2013 def __init_subclass__(cls) -> None:
2014 if TypingOnly in cls.__bases__:
2015 remaining = {
2016 name for name in cls.__dict__ if not _dunders.match(name)
2017 }
2018 if remaining:
2019 raise AssertionError(
2020 f"Class {cls} directly inherits TypingOnly but has "
2021 f"additional attributes {remaining}."
2022 )
2023 super().__init_subclass__()
2024
2025
2026class EnsureKWArg:
2027 r"""Apply translation of functions to accept \**kw arguments if they
2028 don't already.
2029
2030 Used to ensure cross-compatibility with third party legacy code, for things
2031 like compiler visit methods that need to accept ``**kw`` arguments,
2032 but may have been copied from old code that didn't accept them.
2033
2034 """
2035
2036 ensure_kwarg: str
2037 """a regular expression that indicates method names for which the method
2038 should accept ``**kw`` arguments.
2039
2040 The class will scan for methods matching the name template and decorate
2041 them if necessary to ensure ``**kw`` parameters are accepted.
2042
2043 """
2044
2045 def __init_subclass__(cls) -> None:
2046 fn_reg = cls.ensure_kwarg
2047 clsdict = cls.__dict__
2048 if fn_reg:
2049 for key in clsdict:
2050 m = re.match(fn_reg, key)
2051 if m:
2052 fn = clsdict[key]
2053 spec = compat.inspect_getfullargspec(fn)
2054 if not spec.varkw:
2055 wrapped = cls._wrap_w_kw(fn)
2056 setattr(cls, key, wrapped)
2057 super().__init_subclass__()
2058
2059 @classmethod
2060 def _wrap_w_kw(cls, fn: Callable[..., Any]) -> Callable[..., Any]:
2061 def wrap(*arg: Any, **kw: Any) -> Any:
2062 return fn(*arg)
2063
2064 return update_wrapper(wrap, fn)
2065
2066
2067def wrap_callable(wrapper, fn):
2068 """Augment functools.update_wrapper() to work with objects with
2069 a ``__call__()`` method.
2070
2071 :param fn:
2072 object with __call__ method
2073
2074 """
2075 if hasattr(fn, "__name__"):
2076 return update_wrapper(wrapper, fn)
2077 else:
2078 _f = wrapper
2079 _f.__name__ = fn.__class__.__name__
2080 if hasattr(fn, "__module__"):
2081 _f.__module__ = fn.__module__
2082
2083 if hasattr(fn.__call__, "__doc__") and fn.__call__.__doc__:
2084 _f.__doc__ = fn.__call__.__doc__
2085 elif fn.__doc__:
2086 _f.__doc__ = fn.__doc__
2087
2088 return _f
2089
2090
2091def quoted_token_parser(value):
2092 """Parse a dotted identifier with accommodation for quoted names.
2093
2094 Includes support for SQL-style double quotes as a literal character.
2095
2096 E.g.::
2097
2098 >>> quoted_token_parser("name")
2099 ["name"]
2100 >>> quoted_token_parser("schema.name")
2101 ["schema", "name"]
2102 >>> quoted_token_parser('"Schema"."Name"')
2103 ['Schema', 'Name']
2104 >>> quoted_token_parser('"Schema"."Name""Foo"')
2105 ['Schema', 'Name""Foo']
2106
2107 """
2108
2109 if '"' not in value:
2110 return value.split(".")
2111
2112 # 0 = outside of quotes
2113 # 1 = inside of quotes
2114 state = 0
2115 result: List[List[str]] = [[]]
2116 idx = 0
2117 lv = len(value)
2118 while idx < lv:
2119 char = value[idx]
2120 if char == '"':
2121 if state == 1 and idx < lv - 1 and value[idx + 1] == '"':
2122 result[-1].append('"')
2123 idx += 1
2124 else:
2125 state ^= 1
2126 elif char == "." and state == 0:
2127 result.append([])
2128 else:
2129 result[-1].append(char)
2130 idx += 1
2131
2132 return ["".join(token) for token in result]
2133
2134
2135def add_parameter_text(params: Any, text: str) -> Callable[[_F], _F]:
2136 params = _collections.to_list(params)
2137
2138 def decorate(fn):
2139 doc = fn.__doc__ is not None and fn.__doc__ or ""
2140 if doc:
2141 doc = inject_param_text(doc, {param: text for param in params})
2142 fn.__doc__ = doc
2143 return fn
2144
2145 return decorate
2146
2147
2148def _dedent_docstring(text: str) -> str:
2149 split_text = text.split("\n", 1)
2150 if len(split_text) == 1:
2151 return text
2152 else:
2153 firstline, remaining = split_text
2154 if not firstline.startswith(" "):
2155 return firstline + "\n" + textwrap.dedent(remaining)
2156 else:
2157 return textwrap.dedent(text)
2158
2159
2160def inject_docstring_text(
2161 given_doctext: Optional[str], injecttext: str, pos: int
2162) -> str:
2163 doctext: str = _dedent_docstring(given_doctext or "")
2164 lines = doctext.split("\n")
2165 if len(lines) == 1:
2166 lines.append("")
2167 injectlines = textwrap.dedent(injecttext).split("\n")
2168 if injectlines[0]:
2169 injectlines.insert(0, "")
2170
2171 blanks = [num for num, line in enumerate(lines) if not line.strip()]
2172 blanks.insert(0, 0)
2173
2174 inject_pos = blanks[min(pos, len(blanks) - 1)]
2175
2176 lines = lines[0:inject_pos] + injectlines + lines[inject_pos:]
2177 return "\n".join(lines)
2178
2179
2180_param_reg = re.compile(r"(\s+):param (.+?):")
2181
2182
2183def inject_param_text(doctext: str, inject_params: Dict[str, str]) -> str:
2184 doclines = collections.deque(doctext.splitlines())
2185 lines = []
2186
2187 # TODO: this is not working for params like ":param case_sensitive=True:"
2188
2189 to_inject = None
2190 while doclines:
2191 line = doclines.popleft()
2192
2193 m = _param_reg.match(line)
2194
2195 if to_inject is None:
2196 if m:
2197 param = m.group(2).lstrip("*")
2198 if param in inject_params:
2199 # default indent to that of :param: plus one
2200 indent = " " * len(m.group(1)) + " "
2201
2202 # but if the next line has text, use that line's
2203 # indentation
2204 if doclines:
2205 m2 = re.match(r"(\s+)\S", doclines[0])
2206 if m2:
2207 indent = " " * len(m2.group(1))
2208
2209 to_inject = indent + inject_params[param]
2210 elif m:
2211 lines.extend(["\n", to_inject, "\n"])
2212 to_inject = None
2213 elif not line.rstrip():
2214 lines.extend([line, to_inject, "\n"])
2215 to_inject = None
2216 elif line.endswith("::"):
2217 # TODO: this still won't cover if the code example itself has
2218 # blank lines in it, need to detect those via indentation.
2219 lines.extend([line, doclines.popleft()])
2220 continue
2221 lines.append(line)
2222
2223 return "\n".join(lines)
2224
2225
2226def repr_tuple_names(names: List[str]) -> Optional[str]:
2227 """Trims a list of strings from the middle and return a string of up to
2228 four elements. Strings greater than 11 characters will be truncated"""
2229 if len(names) == 0:
2230 return None
2231 flag = len(names) <= 4
2232 names = names[0:4] if flag else names[0:3] + names[-1:]
2233 res = ["%s.." % name[:11] if len(name) > 11 else name for name in names]
2234 if flag:
2235 return ", ".join(res)
2236 else:
2237 return "%s, ..., %s" % (", ".join(res[0:3]), res[-1])
2238
2239
2240def has_compiled_ext(raise_=False):
2241 from ._has_cython import HAS_CYEXTENSION
2242
2243 if HAS_CYEXTENSION:
2244 return True
2245 elif raise_:
2246 raise ImportError(
2247 "cython extensions were expected to be installed, "
2248 "but are not present"
2249 )
2250 else:
2251 return False
2252
2253
2254def load_uncompiled_module(module: _M) -> _M:
2255 """Load the non-compied version of a module that is also
2256 compiled with cython.
2257 """
2258 full_name = module.__name__
2259 assert module.__spec__
2260 parent_name = module.__spec__.parent
2261 assert parent_name
2262 parent_module = sys.modules[parent_name]
2263 assert parent_module.__spec__
2264 package_path = parent_module.__spec__.origin
2265 assert package_path and package_path.endswith("__init__.py")
2266
2267 name = full_name.split(".")[-1]
2268 module_path = package_path.replace("__init__.py", f"{name}.py")
2269
2270 py_spec = importlib.util.spec_from_file_location(full_name, module_path)
2271 assert py_spec
2272 py_module = importlib.util.module_from_spec(py_spec)
2273 assert py_spec.loader
2274 py_spec.loader.exec_module(py_module)
2275 return cast(_M, py_module)
2276
2277
2278class _Missing(enum.Enum):
2279 Missing = enum.auto()
2280
2281
2282Missing = _Missing.Missing
2283MissingOr = Union[_T, Literal[_Missing.Missing]]