1# util/langhelpers.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Routines to help with the creation, loading and introspection of
10modules, classes, hierarchies, attributes, functions, and methods.
11
12"""
13from __future__ import annotations
14
15import collections
16import enum
17from functools import update_wrapper
18import importlib.util
19import inspect
20import itertools
21import operator
22import re
23import sys
24import textwrap
25import threading
26import types
27from types import CodeType
28from types import ModuleType
29from typing import Any
30from typing import Callable
31from typing import cast
32from typing import Dict
33from typing import FrozenSet
34from typing import Generic
35from typing import Iterator
36from typing import List
37from typing import Mapping
38from typing import NoReturn
39from typing import Optional
40from typing import overload
41from typing import Sequence
42from typing import Set
43from typing import Tuple
44from typing import Type
45from typing import TYPE_CHECKING
46from typing import TypeVar
47from typing import Union
48import warnings
49
50from . import _collections
51from . import compat
52from .typing import Literal
53from .. import exc
54
# Generic type variables shared across this module.
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)  # covariant; used by read-only descriptors
_F = TypeVar("_F", bound=Callable[..., Any])  # any callable
_MA = TypeVar("_MA", bound="HasMemoized.memoized_attribute[Any]")
_M = TypeVar("_M", bound=ModuleType)
60
if compat.py314:
    # Python 3.14+: annotations may be deferred (PEP 649), so we
    # vendor a minimal form of get_annotations per
    # https://github.com/python/cpython/issues/133684#issuecomment-2863841891

    from annotationlib import call_annotate_function  # type: ignore
    from annotationlib import Format

    def _get_and_call_annotate(obj, format):  # noqa: A002
        """Invoke obj.__annotate__ with the given format, if present.

        Returns the resulting dict, or None when obj has no
        ``__annotate__`` function.
        """
        annotate = getattr(obj, "__annotate__", None)
        if annotate is not None:
            ann = call_annotate_function(annotate, format, owner=obj)
            if not isinstance(ann, dict):
                raise ValueError(f"{obj!r}.__annotate__ returned a non-dict")
            return ann
        return None

    # raw descriptor for type.__annotations__; bypasses metaclass/instance
    # lookup.  this is ported from py3.13.0a7
    _BASE_GET_ANNOTATIONS = type.__dict__["__annotations__"].__get__  # type: ignore # noqa: E501

    def _get_dunder_annotations(obj):
        """Return a copy of obj's ``__annotations__`` dict, or {} when
        absent.  Raises ValueError for a non-dict ``__annotations__``."""
        if isinstance(obj, type):
            try:
                ann = _BASE_GET_ANNOTATIONS(obj)
            except AttributeError:
                # For static types, the descriptor raises AttributeError.
                return {}
        else:
            ann = getattr(obj, "__annotations__", None)
            if ann is None:
                return {}

        if not isinstance(ann, dict):
            raise ValueError(
                f"{obj!r}.__annotations__ is neither a dict nor None"
            )
        return dict(ann)

    def _vendored_get_annotations(
        obj: Any, *, format: Format  # noqa: A002
    ) -> Mapping[str, Any]:
        """A sparse implementation of annotationlib.get_annotations()"""

        try:
            ann = _get_dunder_annotations(obj)
        except Exception:
            # fall through to the __annotate__ path below
            pass
        else:
            if ann is not None:
                return dict(ann)

        # But if __annotations__ threw a NameError, we try calling __annotate__
        ann = _get_and_call_annotate(obj, format)
        if ann is None:
            # If that didn't work either, we have a very weird object:
            # evaluating
            # __annotations__ threw NameError and there is no __annotate__.
            # In that case,
            # we fall back to trying __annotations__ again.
            ann = _get_dunder_annotations(obj)

        if ann is None:
            if isinstance(obj, type) or callable(obj):
                return {}
            raise TypeError(f"{obj!r} does not have annotations")

        if not ann:
            return {}

        return dict(ann)

    def get_annotations(obj: Any) -> Mapping[str, Any]:
        """Return the annotations dict for obj without evaluating them."""
        # FORWARDREF has the effect of giving us ForwardRefs and not
        # actually trying to evaluate the annotations. We need this so
        # that the annotations act as much like
        # "from __future__ import annotations" as possible, which is going
        # away in future python as a separate mode
        return _vendored_get_annotations(obj, format=Format.FORWARDREF)

elif compat.py310:

    def get_annotations(obj: Any) -> Mapping[str, Any]:
        """Return the annotations dict for obj (stdlib implementation)."""
        return inspect.get_annotations(obj)

else:

    def get_annotations(obj: Any) -> Mapping[str, Any]:
        """Return the annotations dict for obj on Python < 3.10."""
        # https://docs.python.org/3/howto/annotations.html#annotations-howto
        if isinstance(obj, type):
            # look only in the class's own __dict__ so inherited
            # annotations are not returned
            ann = obj.__dict__.get("__annotations__", None)
        else:
            ann = obj.__annotations__

        if ann is None:
            return _collections.EMPTY_DICT
        else:
            return cast("Mapping[str, Any]", ann)
157
158
def md5_hex(x: Any) -> str:
    """Return the hexadecimal MD5 digest of the given string.

    Uses the non-security MD5 constructor so the call is permitted
    under FIPS-restricted environments.
    """
    digest = compat.md5_not_for_security()
    digest.update(x.encode("utf-8"))
    return cast(str, digest.hexdigest())
164
165
class safe_reraise:
    """Reraise an exception after invoking some
    handler code.

    Stores the existing exception info before
    invoking so that it is maintained across a potential
    coroutine context switch.

    e.g.::

        try:
            sess.commit()
        except:
            with safe_reraise():
                sess.rollback()

    TODO: we should at some point evaluate current behaviors in this regard
    based on current greenlet, gevent/eventlet implementations in Python 3, and
    also see the degree to which our own asyncio (based on greenlet also) is
    impacted by this. .rollback() will cause IO / context switch to occur in
    all these scenarios; what happens to the exception context from an
    "except:" block if we don't explicitly store it? Original issue was #2703.

    """

    __slots__ = ("_exc_info",)

    # exception info captured at __enter__ time; set to None in __exit__
    # so the traceback frames don't participate in a reference cycle
    _exc_info: Union[
        None,
        Tuple[
            Type[BaseException],
            BaseException,
            types.TracebackType,
        ],
        Tuple[None, None, None],
    ]

    def __enter__(self) -> None:
        # capture the exception currently being handled, if any
        self._exc_info = sys.exc_info()

    def __exit__(
        self,
        type_: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[types.TracebackType],
    ) -> NoReturn:
        assert self._exc_info is not None
        # see #2703 for notes
        if type_ is None:
            # handler body completed without raising; re-raise the
            # exception that was in flight when the block was entered
            exc_type, exc_value, exc_tb = self._exc_info
            assert exc_value is not None
            self._exc_info = None  # remove potential circular references
            raise exc_value.with_traceback(exc_tb)
        else:
            # handler body raised its own exception; propagate that one
            self._exc_info = None  # remove potential circular references
            assert value is not None
            raise value.with_traceback(traceback)
223
224
def walk_subclasses(cls: Type[_T]) -> Iterator[Type[_T]]:
    """Yield ``cls`` and every (transitive) subclass of it, each exactly
    once, in depth-first order."""
    visited: Set[Any] = set()

    to_process = [cls]
    while to_process:
        current = to_process.pop()
        if current in visited:
            continue
        visited.add(current)
        to_process.extend(current.__subclasses__())
        yield current
237
238
def string_or_unprintable(element: Any) -> str:
    """Return ``element`` as a string, or a placeholder when it cannot
    be stringified.

    Previously the ``%r`` fallback ran outside the ``try``, so an object
    whose ``__repr__`` also raises would escape this function with an
    exception, defeating its purpose of always producing printable text.
    """
    if isinstance(element, str):
        return element
    try:
        return str(element)
    except Exception:
        try:
            # str() may fail while repr() still works (broken __str__)
            return "unprintable element %r" % element
        except Exception:
            # repr() is broken too; give up with a generic placeholder
            return "unprintable element"
247
248
def clsname_as_plain_name(
    cls: Type[Any], use_name: Optional[str] = None
) -> str:
    """Convert a CamelCase class name into lower-cased, space-separated
    words; the literal token "SQL" is kept as a single word."""
    camel_name = use_name or cls.__name__
    words = re.findall(r"([A-Z][a-z]+|SQL)", camel_name)
    return " ".join(word.lower() for word in words)
254
255
def method_is_overridden(
    instance_or_cls: Union[Type[Any], object],
    against_method: Callable[..., Any],
) -> bool:
    """Return True if the two class methods don't match."""

    if isinstance(instance_or_cls, type):
        cls = instance_or_cls
    else:
        cls = instance_or_cls.__class__

    # look the method up by name on the class; an overriding subclass
    # yields a different function object than against_method
    found: types.MethodType = getattr(cls, against_method.__name__)

    return found != against_method
272
273
def decode_slice(slc: slice) -> Tuple[Any, ...]:
    """decode a slice object as sent to __getitem__.

    Applies ``__index__()`` to each of start/stop/step when the
    component provides it, passing other values through unchanged.

    """
    return tuple(
        component.__index__() if hasattr(component, "__index__") else component
        for component in (slc.start, slc.stop, slc.step)
    )
286
287
288def _unique_symbols(used: Sequence[str], *bases: str) -> Iterator[str]:
289 used_set = set(used)
290 for base in bases:
291 pool = itertools.chain(
292 (base,),
293 map(lambda i: base + str(i), range(1000)),
294 )
295 for sym in pool:
296 if sym not in used_set:
297 used_set.add(sym)
298 yield sym
299 break
300 else:
301 raise NameError("exhausted namespace for symbol base %s" % base)
302
303
def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]:
    """Call the given function given each nonzero bit from n."""

    remaining = n
    while remaining:
        # isolate the lowest set bit (two's-complement trick)
        lowest_bit = remaining & -remaining
        yield fn(lowest_bit)
        # clear that bit and continue with the rest
        remaining &= remaining - 1
311
312
# TypeVar used by @decorator below so the wrapper keeps the decorated
# function's exact type for type checkers.
_Fn = TypeVar("_Fn", bound="Callable[..., Any]")

# this seems to be in flux in recent mypy versions
317
def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
    """A signature-matching decorator factory.

    ``target`` is called as ``target(fn, *args, **kw)``; the returned
    decorator generates (via exec) a wrapper whose signature exactly
    matches the decorated function ``fn``, so signature introspection
    on the wrapper keeps working.
    """

    def decorate(fn: _Fn) -> _Fn:
        if not inspect.isfunction(fn) and not inspect.ismethod(fn):
            raise Exception("not a decoratable function")

        # Python 3.14 defer creating __annotations__ until its used.
        # We do not want to create __annotations__ now.
        annofunc = getattr(fn, "__annotate__", None)
        if annofunc is not None:
            # temporarily disable __annotate__ so getfullargspec does
            # not force annotation evaluation; restore it afterwards
            fn.__annotate__ = None  # type: ignore[union-attr]
            try:
                spec = compat.inspect_getfullargspec(fn)
            finally:
                fn.__annotate__ = annofunc  # type: ignore[union-attr]
        else:
            spec = compat.inspect_getfullargspec(fn)

        # Do not generate code for annotations.
        # update_wrapper() copies the annotation from fn to decorated.
        # We use dummy defaults for code generation to avoid having
        # copy of large globals for compiling.
        # We copy __defaults__ and __kwdefaults__ from fn to decorated.
        empty_defaults = (None,) * len(spec.defaults or ())
        empty_kwdefaults = dict.fromkeys(spec.kwonlydefaults or ())
        spec = spec._replace(
            annotations={},
            defaults=empty_defaults,
            kwonlydefaults=empty_kwdefaults,
        )

        # every name already present in the signature, so the generated
        # symbols for "target" and "fn" cannot collide with them
        names = (
            tuple(cast("Tuple[str, ...]", spec[0]))
            + cast("Tuple[str, ...]", spec[1:3])
            + (fn.__name__,)
        )
        targ_name, fn_name = _unique_symbols(names, "target", "fn")

        metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
        metadata.update(format_argspec_plus(spec, grouped=False))
        metadata["name"] = fn.__name__

        # async functions get an async wrapper that awaits the target
        if inspect.iscoroutinefunction(fn):
            metadata["prefix"] = "async "
            metadata["target_prefix"] = "await "
        else:
            metadata["prefix"] = ""
            metadata["target_prefix"] = ""

        # look for __ positional arguments. This is a convention in
        # SQLAlchemy that arguments should be passed positionally
        # rather than as keyword
        # arguments. note that apply_pos doesn't currently work in all cases
        # such as when a kw-only indicator "*" is present, which is why
        # we limit the use of this to just that case we can detect. As we add
        # more kinds of methods that use @decorator, things may have to
        # be further improved in this area
        if "__" in repr(spec[0]):
            code = (
                """\
%(prefix)sdef %(name)s%(grouped_args)s:
    return %(target_prefix)s%(target)s(%(fn)s, %(apply_pos)s)
"""
                % metadata
            )
        else:
            code = (
                """\
%(prefix)sdef %(name)s%(grouped_args)s:
    return %(target_prefix)s%(target)s(%(fn)s, %(apply_kw)s)
"""
                % metadata
            )

        # minimal globals for the generated wrapper: just the target,
        # the wrapped fn, and a module name for tracebacks
        env: Dict[str, Any] = {
            targ_name: target,
            fn_name: fn,
            "__name__": fn.__module__,
        }

        decorated = cast(
            types.FunctionType,
            _exec_code_in_env(code, env, fn.__name__),
        )
        # restore the real defaults that were stripped for codegen
        decorated.__defaults__ = fn.__defaults__
        decorated.__kwdefaults__ = fn.__kwdefaults__  # type: ignore
        return update_wrapper(decorated, fn)  # type: ignore[return-value]

    return update_wrapper(decorate, target)  # type: ignore[return-value]
408
409
410def _exec_code_in_env(
411 code: Union[str, types.CodeType], env: Dict[str, Any], fn_name: str
412) -> Callable[..., Any]:
413 exec(code, env)
414 return env[fn_name] # type: ignore[no-any-return]
415
416
# additional generic type variables for this module
_PF = TypeVar("_PF")
_TE = TypeVar("_TE")
419
420
class PluginLoader:
    """Registry of named plugins, resolved lazily from explicit
    registrations, an optional resolver callable, or entry points in
    ``self.group``."""

    def __init__(
        self, group: str, auto_fn: Optional[Callable[..., Any]] = None
    ):
        self.group = group
        self.auto_fn = auto_fn
        # name -> zero-argument loader callable
        self.impls: Dict[str, Any] = {}

    def clear(self):
        """Discard all cached loaders."""
        self.impls.clear()

    def load(self, name: str) -> Any:
        """Resolve and return the plugin registered under ``name``.

        Raises NoSuchModuleError when no loader can be found.
        """
        try:
            loader = self.impls[name]
        except KeyError:
            pass
        else:
            return loader()

        # consult the optional resolver function next
        if self.auto_fn:
            loader = self.auto_fn(name)
            if loader:
                self.impls[name] = loader
                return loader()

        # finally, scan entry points for the configured group
        for impl in compat.importlib_metadata_get(self.group):
            if impl.name == name:
                self.impls[name] = impl.load
                return impl.load()

        raise exc.NoSuchModuleError(
            "Can't load plugin: %s:%s" % (self.group, name)
        )

    def register(self, name: str, modulepath: str, objname: str) -> None:
        """Register attribute ``objname`` of module ``modulepath`` as the
        plugin called ``name``; the import happens on first load."""

        def load():
            module = __import__(modulepath)
            for attr in modulepath.split(".")[1:]:
                module = getattr(module, attr)
            return getattr(module, objname)

        self.impls[name] = load

    def deregister(self, name: str) -> None:
        """Remove the loader registered under ``name``."""
        del self.impls[name]
462
463
464def _inspect_func_args(fn):
465 try:
466 co_varkeywords = inspect.CO_VARKEYWORDS
467 except AttributeError:
468 # https://docs.python.org/3/library/inspect.html
469 # The flags are specific to CPython, and may not be defined in other
470 # Python implementations. Furthermore, the flags are an implementation
471 # detail, and can be removed or deprecated in future Python releases.
472 spec = compat.inspect_getfullargspec(fn)
473 return spec[0], bool(spec[2])
474 else:
475 # use fn.__code__ plus flags to reduce method call overhead
476 co = fn.__code__
477 nargs = co.co_argcount
478 return (
479 list(co.co_varnames[:nargs]),
480 bool(co.co_flags & co_varkeywords),
481 )
482
483
@overload
def get_cls_kwargs(
    cls: type,
    *,
    _set: Optional[Set[str]] = None,
    raiseerr: Literal[True] = ...,
) -> Set[str]: ...


@overload
def get_cls_kwargs(
    cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
) -> Optional[Set[str]]: ...


def get_cls_kwargs(
    cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
) -> Optional[Set[str]]:
    r"""Return the full set of inherited kwargs for the given `cls`.

    Probes a class's __init__ method, collecting all named arguments. If the
    __init__ defines a \**kwargs catch-all, then the constructor is presumed
    to pass along unrecognized keywords to its base classes, and the
    collection process is repeated recursively on each of the bases.

    Uses a subset of inspect.getfullargspec() to cut down on method overhead,
    as this is used within the Core typing system to create copies of type
    objects which is a performance-sensitive operation.

    No anonymous tuple arguments please !

    :param _set: accumulator set shared across the recursive calls;
        callers leave this as None.
    :param raiseerr: when True, raise TypeError instead of returning None
        for a class whose __init__ cannot be probed.
    """
    toplevel = _set is None
    if toplevel:
        _set = set()
    assert _set is not None

    # only consider an __init__ defined directly on this class
    ctr = cls.__dict__.get("__init__", False)

    has_init = (
        ctr
        and isinstance(ctr, types.FunctionType)
        and isinstance(ctr.__code__, types.CodeType)
    )

    if has_init:
        names, has_kw = _inspect_func_args(ctr)
        _set.update(names)

        # a recursive probe of a base whose __init__ takes no **kwargs
        # means unrecognized keywords would not propagate further; signal
        # the caller by returning None
        if not has_kw and not toplevel:
            if raiseerr:
                raise TypeError(
                    f"given cls {cls} doesn't have an __init__ method"
                )
            else:
                return None
    else:
        has_kw = False

    # recurse into bases when this class either defines no __init__ or
    # forwards **kwargs onward
    if not has_init or has_kw:
        for c in cls.__bases__:
            if get_cls_kwargs(c, _set=_set) is None:
                break

    _set.discard("self")
    return _set
550
551
def get_func_kwargs(func: Callable[..., Any]) -> List[str]:
    """Return the set of legal kwargs for the given `func`.

    Uses getargspec so is safe to call for methods, functions,
    etc.

    """
    spec = compat.inspect_getfullargspec(func)
    return spec[0]
561
562
def get_callable_argspec(
    fn: Callable[..., Any], no_self: bool = False, _is_init: bool = False
) -> compat.FullArgSpec:
    """Return the argument signature for any callable.

    All pure-Python callables are accepted, including
    functions, methods, classes, objects with __call__;
    builtins and other edge cases like functools.partial() objects
    raise a TypeError.

    :param no_self: strip the leading self/cls argument from the result.
    :param _is_init: internal flag set when recursing into a class's
        __init__.
    """
    if inspect.isbuiltin(fn):
        raise TypeError("Can't inspect builtin: %s" % fn)
    elif inspect.isfunction(fn):
        # plain function: strip the first arg only for __init__ probes
        if _is_init and no_self:
            spec = compat.inspect_getfullargspec(fn)
            return compat.FullArgSpec(
                spec.args[1:],
                spec.varargs,
                spec.varkw,
                spec.defaults,
                spec.kwonlyargs,
                spec.kwonlydefaults,
                spec.annotations,
            )
        else:
            return compat.inspect_getfullargspec(fn)
    elif inspect.ismethod(fn):
        # bound method: inspect the underlying function, dropping the
        # bound self/cls argument when requested
        if no_self and (_is_init or fn.__self__):
            spec = compat.inspect_getfullargspec(fn.__func__)
            return compat.FullArgSpec(
                spec.args[1:],
                spec.varargs,
                spec.varkw,
                spec.defaults,
                spec.kwonlyargs,
                spec.kwonlydefaults,
                spec.annotations,
            )
        else:
            return compat.inspect_getfullargspec(fn.__func__)
    elif inspect.isclass(fn):
        # a class's "signature" is its constructor's
        return get_callable_argspec(
            fn.__init__, no_self=no_self, _is_init=True
        )
    elif hasattr(fn, "__func__"):
        return compat.inspect_getfullargspec(fn.__func__)
    elif hasattr(fn, "__call__"):
        # callable instance: recurse into its __call__ method
        if inspect.ismethod(fn.__call__):
            return get_callable_argspec(fn.__call__, no_self=no_self)
        else:
            raise TypeError("Can't inspect callable: %s" % fn)
    else:
        raise TypeError("Can't inspect callable: %s" % fn)
617
618
def format_argspec_plus(
    fn: Union[Callable[..., Any], compat.FullArgSpec], grouped: bool = True
) -> Dict[str, Optional[str]]:
    """Returns a dictionary of formatted, introspected function arguments.

    A enhanced variant of inspect.formatargspec to support code generation.

    fn
       An inspectable callable or tuple of inspect getargspec() results.
    grouped
      Defaults to True; include (parens, around, argument) lists

    Returns:

    args
      Full inspect.formatargspec for fn
    self_arg
      The name of the first positional argument, varargs[0], or None
      if the function defines no positional arguments.
    apply_pos
      args, re-written in calling rather than receiving syntax.  Arguments are
      passed positionally.
    apply_kw
      Like apply_pos, except keyword-ish args are passed as keywords.
    apply_pos_proxied
      Like apply_pos but omits the self/cls argument

    Example::

      >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
      {'grouped_args': '(self, a, b, c=3, **d)',
       'self_arg': 'self',
       'apply_kw': '(self, a, b, c=c, **d)',
       'apply_pos': '(self, a, b, c, **d)'}

    """
    if callable(fn):
        spec = compat.inspect_getfullargspec(fn)
    else:
        spec = fn

    # full receiving-side signature, defaults included
    args = compat.inspect_formatargspec(*spec)

    # calling-side form with no defaults: positional args passed through
    apply_pos = compat.inspect_formatargspec(
        spec[0], spec[1], spec[2], None, spec[4]
    )

    if spec[0]:
        # first positional argument is assumed to be self/cls
        self_arg = spec[0][0]

        apply_pos_proxied = compat.inspect_formatargspec(
            spec[0][1:], spec[1], spec[2], None, spec[4]
        )

    elif spec[1]:
        # I'm not sure what this is
        self_arg = "%s[0]" % spec[1]

        apply_pos_proxied = apply_pos
    else:
        self_arg = None
        apply_pos_proxied = apply_pos

    # count trailing defaulted positionals plus keyword-only args, which
    # determine the "name=name" portion of apply_kw
    num_defaults = 0
    if spec[3]:
        num_defaults += len(cast(Tuple[Any], spec[3]))
    if spec[4]:
        num_defaults += len(spec[4])

    name_args = spec[0] + spec[4]

    defaulted_vals: Union[List[str], Tuple[()]]

    if num_defaults:
        defaulted_vals = name_args[0 - num_defaults :]
    else:
        defaulted_vals = ()

    # formatvalue renders each defaulted arg as "name=name", producing
    # keyword-style pass-through in the generated call
    apply_kw = compat.inspect_formatargspec(
        name_args,
        spec[1],
        spec[2],
        defaulted_vals,
        formatvalue=lambda x: "=" + str(x),
    )

    if spec[0]:
        apply_kw_proxied = compat.inspect_formatargspec(
            name_args[1:],
            spec[1],
            spec[2],
            defaulted_vals,
            formatvalue=lambda x: "=" + str(x),
        )
    else:
        apply_kw_proxied = apply_kw

    if grouped:
        return dict(
            grouped_args=args,
            self_arg=self_arg,
            apply_pos=apply_pos,
            apply_kw=apply_kw,
            apply_pos_proxied=apply_pos_proxied,
            apply_kw_proxied=apply_kw_proxied,
        )
    else:
        # ungrouped: strip the surrounding parentheses from each form
        return dict(
            grouped_args=args,
            self_arg=self_arg,
            apply_pos=apply_pos[1:-1],
            apply_kw=apply_kw[1:-1],
            apply_pos_proxied=apply_pos_proxied[1:-1],
            apply_kw_proxied=apply_kw_proxied[1:-1],
        )
734
735
def format_argspec_init(method, grouped=True):
    """format_argspec_plus with considerations for typical __init__ methods

    Wraps format_argspec_plus with error handling strategies for typical
    __init__ cases:

    .. sourcecode:: text

      object.__init__ -> (self)
      other unreflectable (usually C) -> (self, *args, **kwargs)

    """
    if method is object.__init__:
        # object.__init__ accepts nothing beyond self
        grouped_args = "(self)"
        args = "(self)" if grouped else "self"
        proxied = "()" if grouped else ""
    else:
        try:
            return format_argspec_plus(method, grouped=grouped)
        except TypeError:
            # uninspectable (usually C-implemented); assume the most
            # permissive signature
            grouped_args = "(self, *args, **kwargs)"
            args = grouped_args if grouped else "self, *args, **kwargs"
            proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs"
    return dict(
        self_arg="self",
        grouped_args=grouped_args,
        apply_pos=args,
        apply_kw=args,
        apply_pos_proxied=proxied,
        apply_kw_proxied=proxied,
    )
767
768
def create_proxy_methods(
    target_cls: Type[Any],
    target_cls_sphinx_name: str,
    proxy_cls_sphinx_name: str,
    classmethods: Sequence[str] = (),
    methods: Sequence[str] = (),
    attributes: Sequence[str] = (),
    use_intermediate_variable: Sequence[str] = (),
) -> Callable[[_T], _T]:
    """A class decorator indicating attributes should refer to a proxy
    class.

    At runtime this is purely a marker: the returned decorator hands the
    class back untouched.  The arguments are read statically by the
    tools/generate_proxy_methods.py script, which generates proxy methods
    and attributes that typing tools such as mypy fully recognize.

    """

    def passthrough(cls):
        # no runtime behavior; generation happens offline
        return cls

    return passthrough
792
793
def getargspec_init(method):
    """inspect.getargspec with considerations for typical __init__ methods

    Wraps inspect.getargspec with error handling for typical __init__ cases:

    .. sourcecode:: text

      object.__init__ -> (self)
      other unreflectable (usually C) -> (self, *args, **kwargs)

    """
    # precompute what to report if the method can't be inspected
    fallback = (
        (["self"], None, None, None)
        if method is object.__init__
        else (["self"], "args", "kwargs", None)
    )
    try:
        return compat.inspect_getfullargspec(method)
    except TypeError:
        return fallback
812
813
def unbound_method_to_callable(func_or_cls):
    """Adjust the incoming callable such that a 'self' argument is not
    required.

    """

    is_unbound = (
        isinstance(func_or_cls, types.MethodType)
        and not func_or_cls.__self__
    )
    return func_or_cls.__func__ if is_unbound else func_or_cls
824
825
def generic_repr(
    obj: Any,
    additional_kw: Sequence[Tuple[str, Any]] = (),
    to_inspect: Optional[Union[object, List[object]]] = None,
    omit_kwarg: Sequence[str] = (),
) -> str:
    """Produce a __repr__() based on direct association of the __init__()
    specification vs. same-named attributes present.

    :param additional_kw: extra ``(name, default)`` pairs rendered the
        same way as constructor keywords.
    :param to_inspect: object or list of objects whose ``__init__``
        signatures are probed; defaults to ``obj`` itself.
    :param omit_kwarg: keyword names to exclude from the output.
    """
    if to_inspect is None:
        to_inspect = [obj]
    else:
        to_inspect = _collections.to_list(to_inspect)

    # sentinel distinguishing "attribute absent" from a None value
    missing = object()

    pos_args = []
    kw_args: _collections.OrderedDict[str, Any] = _collections.OrderedDict()
    vargs = None
    for i, insp in enumerate(to_inspect):
        try:
            spec = compat.inspect_getfullargspec(insp.__init__)
        except TypeError:
            # uninspectable __init__ (usually C-implemented); skip it
            continue
        else:
            default_len = len(spec.defaults) if spec.defaults else 0
            if i == 0:
                # the first inspected object contributes positional args
                if spec.varargs:
                    vargs = spec.varargs
                if default_len:
                    pos_args.extend(spec.args[1:-default_len])
                else:
                    pos_args.extend(spec.args[1:])
            else:
                # subsequent objects contribute keyword-style args only
                kw_args.update(
                    [(arg, missing) for arg in spec.args[1:-default_len]]
                )

            if default_len:
                # pair each defaulted arg name with its default value
                assert spec.defaults
                kw_args.update(
                    [
                        (arg, default)
                        for arg, default in zip(
                            spec.args[-default_len:], spec.defaults
                        )
                    ]
                )
    output: List[str] = []

    output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)

    if vargs is not None and hasattr(obj, vargs):
        output.extend([repr(val) for val in getattr(obj, vargs)])

    # keyword args are rendered only when they differ from their default
    for arg, defval in kw_args.items():
        if arg in omit_kwarg:
            continue
        try:
            val = getattr(obj, arg, missing)
            if val is not missing and val != defval:
                output.append("%s=%r" % (arg, val))
        except Exception:
            pass

    if additional_kw:
        for arg, defval in additional_kw:
            try:
                val = getattr(obj, arg, missing)
                if val is not missing and val != defval:
                    output.append("%s=%r" % (arg, val))
            except Exception:
                pass

    return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
902
903
def class_hierarchy(cls):
    """Return an unordered sequence of all classes related to cls.

    Traverses diamond hierarchies.

    Fibs slightly: subclasses of builtin types are not returned.  Thus
    class_hierarchy(class A(object)) returns (A, object), not A plus every
    class systemwide that derives from object.

    """

    related = {cls}
    pending = list(cls.__mro__)
    while pending:
        current = pending.pop()

        # walk upward through bases not yet seen
        for base in current.__bases__:
            if base not in related:
                pending.append(base)
                related.add(base)

        # don't descend into subclasses of builtins (would pull in
        # unrelated classes systemwide)
        if current.__module__ == "builtins" or not hasattr(
            current, "__subclasses__"
        ):
            continue

        # metaclasses require passing the class to __subclasses__
        if issubclass(current, type):
            subclasses = current.__subclasses__(current)
        else:
            subclasses = current.__subclasses__()

        for sub in subclasses:
            if sub not in related:
                pending.append(sub)
                related.add(sub)

    return list(related)
940
941
def iterate_attributes(cls):
    """iterate all the keys and attributes associated
    with a class, without using getattr().

    Does not use getattr() so that class-sensitive
    descriptors (i.e. property.__get__()) are not called.

    """
    for key in dir(cls):
        # find the most-derived class in the MRO that defines this key
        for klass in cls.__mro__:
            if key in klass.__dict__:
                yield (key, klass.__dict__[key])
                break
956
957
def monkeypatch_proxied_specials(
    into_cls,
    from_cls,
    skip=None,
    only=None,
    name="self.proxy",
    from_instance=None,
):
    """Automates delegation of __specials__ for a proxying type.

    Generates forwarding dunder methods on ``into_cls`` that delegate to
    the expression given by ``name`` (default ``self.proxy``).

    :param skip: dunder names never to proxy; a default list is used
        when None.
    :param only: when given, proxy exactly these names instead of
        discovering them from ``from_cls``.
    :param from_instance: optional object bound into the generated
        methods' namespace under ``name``.
    """

    if only:
        dunders = only
    else:
        if skip is None:
            skip = (
                "__slots__",
                "__del__",
                "__getattribute__",
                "__metaclass__",
                "__getstate__",
                "__setstate__",
            )
        # all callable dunders on from_cls not already present on into_cls
        dunders = [
            m
            for m in dir(from_cls)
            if (
                m.startswith("__")
                and m.endswith("__")
                and not hasattr(into_cls, m)
                and m not in skip
            )
        ]

    for method in dunders:
        try:
            maybe_fn = getattr(from_cls, method)
            if not hasattr(maybe_fn, "__call__"):
                continue
            maybe_fn = getattr(maybe_fn, "__func__", maybe_fn)
            fn = cast(types.FunctionType, maybe_fn)

        except AttributeError:
            continue
        try:
            spec = compat.inspect_getfullargspec(fn)
            fn_args = compat.inspect_formatargspec(spec[0])
            # delegate call drops the leading self argument
            d_args = compat.inspect_formatargspec(spec[0][1:])
        except TypeError:
            # uninspectable; fall back to a fully generic signature
            fn_args = "(self, *args, **kw)"
            d_args = "(*args, **kw)"

        # generate the forwarding method source, e.g.
        # "def __getitem__(self, index): return self.proxy.__getitem__(index)"
        py = (
            "def %(method)s%(fn_args)s: "
            "return %(name)s.%(method)s%(d_args)s" % locals()
        )

        env: Dict[str, types.FunctionType] = (
            from_instance is not None and {name: from_instance} or {}
        )
        exec(py, env)
        try:
            # preserve the original defaults where possible
            env[method].__defaults__ = fn.__defaults__
        except AttributeError:
            pass
        setattr(into_cls, method, env[method])
1023
1024
def methods_equivalent(meth1, meth2):
    """Return True if the two methods are the same implementation."""

    # unwrap bound methods down to the underlying function objects
    impl1 = getattr(meth1, "__func__", meth1)
    impl2 = getattr(meth2, "__func__", meth2)
    return impl1 is impl2
1031
1032
def as_interface(obj, cls=None, methods=None, required=None):
    """Ensure basic interface compliance for an instance or dict of callables.

    Checks that ``obj`` implements public methods of ``cls`` or has members
    listed in ``methods``.  If ``required`` is not supplied, implementing at
    least one interface method is sufficient.  Methods present on ``obj`` that
    are not in the interface are ignored.

    If ``obj`` is a dict and ``dict`` does not meet the interface
    requirements, the keys of the dictionary are inspected.  Keys present in
    ``obj`` that are not in the interface will raise TypeErrors.

    Raises TypeError if ``obj`` does not meet the interface criteria.

    In all passing cases, an object with callable members is returned.  In the
    simple case, ``obj`` is returned as-is; if dict processing kicks in then
    an anonymous class is returned.

    obj
      A type, instance, or dictionary of callables.
    cls
      Optional, a type.  All public methods of cls are considered the
      interface.  An ``obj`` instance of cls will always pass, ignoring
      ``required``..
    methods
      Optional, a sequence of method names to consider as the interface.
    required
      Optional, a sequence of mandatory implementations. If omitted, an
      ``obj`` that provides at least one interface method is considered
      sufficient.  As a convenience, required may be a type, in which case
      all public methods of the type are required.

    """
    if not cls and not methods:
        raise TypeError("a class or collection of method names are required")

    # instances of cls trivially comply
    if isinstance(cls, type) and isinstance(obj, cls):
        return obj

    interface = set(methods or [m for m in dir(cls) if not m.startswith("_")])
    implemented = set(dir(obj))

    # ">=" means "implements all of required"; ">" (used when required is
    # empty) means "implements at least one interface member"
    complies = operator.ge
    if isinstance(required, type):
        required = interface
    elif not required:
        required = set()
        complies = operator.gt
    else:
        required = set(required)

    if complies(implemented.intersection(interface), required):
        return obj

    # No dict duck typing here.
    if not isinstance(obj, dict):
        qualifier = complies is operator.gt and "any of" or "all of"
        raise TypeError(
            "%r does not implement %s: %s"
            % (obj, qualifier, ", ".join(interface))
        )

    class AnonymousInterface:
        """A callable-holding shell."""

    if cls:
        AnonymousInterface.__name__ = "Anonymous" + cls.__name__
    found = set()

    # build the anonymous interface from the dict's callables, rejecting
    # unknown or non-callable entries
    for method, impl in dictlike_iteritems(obj):
        if method not in interface:
            raise TypeError("%r: unknown in this interface" % method)
        if not callable(impl):
            raise TypeError("%r=%r is not callable" % (method, impl))
        setattr(AnonymousInterface, method, staticmethod(impl))
        found.add(method)

    if complies(found, required):
        return AnonymousInterface

    raise TypeError(
        "dictionary does not contain required keys %s"
        % ", ".join(required - found)
    )
1117
1118
1119_GFD = TypeVar("_GFD", bound="generic_fn_descriptor[Any]")
1120
1121
class generic_fn_descriptor(Generic[_T_co]):
    """Descriptor which proxies a function when the attribute is not
    present in dict

    This superclass is organized in a particular way with "memoized" and
    "non-memoized" implementation classes that are hidden from type checkers,
    as Mypy seems to not be able to handle seeing multiple kinds of descriptor
    classes used for the same attribute.

    """

    # the wrapped function, called with the instance as sole argument
    fget: Callable[..., _T_co]
    __doc__: Optional[str]
    __name__: str

    def __init__(self, fget: Callable[..., _T_co], doc: Optional[str] = None):
        self.fget = fget
        self.__doc__ = doc or fget.__doc__
        self.__name__ = fget.__name__

    @overload
    def __get__(self: _GFD, obj: None, cls: Any) -> _GFD: ...

    @overload
    def __get__(self, obj: object, cls: Any) -> _T_co: ...

    def __get__(self: _GFD, obj: Any, cls: Any) -> Union[_GFD, _T_co]:
        # concrete behavior is supplied by subclasses
        raise NotImplementedError()

    if TYPE_CHECKING:
        # declared for typing only so assignment/deletion type-checks;
        # no runtime implementation exists

        def __set__(self, instance: Any, value: Any) -> None: ...

        def __delete__(self, instance: Any) -> None: ...

    def _reset(self, obj: Any) -> None:
        # subclasses that memoize implement cache invalidation here
        raise NotImplementedError()

    @classmethod
    def reset(cls, obj: Any, name: str) -> None:
        raise NotImplementedError()
1163
1164
class _non_memoized_property(generic_fn_descriptor[_T_co]):
    """A plain, non-caching descriptor that proxies a function.

    Exists so that uncached attributes remain interchangeable with
    memoized_property, with mypy treating the two as equivalent.

    """

    if not TYPE_CHECKING:

        def __get__(self, obj, cls):
            # class-level access yields the descriptor itself;
            # instance-level access re-invokes the function every time
            return self if obj is None else self.fget(obj)
1180
1181
class _memoized_property(generic_fn_descriptor[_T_co]):
    """A read-only @property computed once, then cached on the instance."""

    if not TYPE_CHECKING:

        def __get__(self, obj, cls):
            if obj is None:
                return self
            # cache under the same attribute name; later access hits the
            # instance __dict__ directly and never reaches this descriptor
            value = self.fget(obj)
            obj.__dict__[self.__name__] = value
            return value

    def _reset(self, obj):
        _memoized_property.reset(obj, self.__name__)

    @classmethod
    def reset(cls, obj, name):
        # drop the cached value (if any) so it is recomputed on next access
        obj.__dict__.pop(name, None)
1199
1200
# despite many attempts to get Mypy to recognize an overridden descriptor
# where one is memoized and the other isn't, there seems to be no reliable
# way other than completely deceiving the type checker into thinking there
# is just one single descriptor type everywhere. Otherwise, if a superclass
# has non-memoized and subclass has memoized, that requires
# "class memoized(non_memoized)". but then if a superclass has memoized and
# superclass has non-memoized, the class hierarchy of the descriptors
# would need to be reversed; "class non_memoized(memoized)". so there's no
# way to achieve this.
# additional issues, RO properties:
# https://github.com/python/mypy/issues/12440
if TYPE_CHECKING:
    # allow memoized and non-memoized to be freely mixed by having them
    # be the same class
    memoized_property = generic_fn_descriptor
    non_memoized_property = generic_fn_descriptor

    # for read only situations, mypy only sees @property as read only.
    # read only is needed when a subtype specializes the return type
    # of a property, meaning assignment needs to be disallowed
    ro_memoized_property = property
    ro_non_memoized_property = property

else:
    # at runtime the "ro" variants are identical to the plain ones; the
    # distinction exists only for the benefit of the type checker
    memoized_property = ro_memoized_property = _memoized_property
    non_memoized_property = ro_non_memoized_property = _non_memoized_property
1227
1228
def memoized_instancemethod(fn: _F) -> _F:
    """Decorate a method so its return value is computed once and memoized.

    Best applied to no-arg methods: memoization is not sensitive to
    argument values, and will always return the same value even when
    called with different arguments.

    """

    def _make_memo(value):
        # build a constant-returning stand-in carrying the original
        # function's identity
        def memo(*a, **kw):
            return value

        memo.__name__ = fn.__name__
        memo.__doc__ = fn.__doc__
        return memo

    def oneshot(self, *args, **kw):
        value = fn(self, *args, **kw)
        # shadow the method on the instance so later calls skip computation
        self.__dict__[fn.__name__] = _make_memo(value)
        return value

    return update_wrapper(oneshot, fn)  # type: ignore
1250
1251
class HasMemoized:
    """A mixin class that maintains the names of memoized elements in a
    collection for easy cache clearing, generative, etc.

    """

    if not TYPE_CHECKING:
        # support classes that want to have __slots__ with an explicit
        # slot for __dict__. not sure if that requires base __slots__ here.
        __slots__ = ()

    # names of attributes currently memoized in self.__dict__; rebound to a
    # new frozenset (never mutated in place) each time a name is recorded
    _memoized_keys: FrozenSet[str] = frozenset()

    def _reset_memoizations(self) -> None:
        # discard all memoized values so they are recomputed on next access
        for elem in self._memoized_keys:
            self.__dict__.pop(elem, None)

    def _assert_no_memoizations(self) -> None:
        for elem in self._memoized_keys:
            assert elem not in self.__dict__

    def _set_memoized_attribute(self, key: str, value: Any) -> None:
        # memoize an externally computed value and track it for clearing
        self.__dict__[key] = value
        self._memoized_keys |= {key}

    class memoized_attribute(memoized_property[_T]):
        """A read-only @property that is only evaluated once.

        Caches the value in the instance ``__dict__`` and records the
        attribute name in ``_memoized_keys`` for cache clearing.

        :meta private:

        """

        fget: Callable[..., _T]
        __doc__: Optional[str]
        __name__: str

        def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None):
            self.fget = fget
            self.__doc__ = doc or fget.__doc__
            self.__name__ = fget.__name__

        @overload
        def __get__(self: _MA, obj: None, cls: Any) -> _MA: ...

        @overload
        def __get__(self, obj: Any, cls: Any) -> _T: ...

        def __get__(self, obj, cls):
            if obj is None:
                return self
            # cache the computed value, then record the name so that
            # _reset_memoizations() can remove it later
            obj.__dict__[self.__name__] = result = self.fget(obj)
            obj._memoized_keys |= {self.__name__}
            return result

    @classmethod
    def memoized_instancemethod(cls, fn: _F) -> _F:
        """Decorate a method to memoize its return value.

        :meta private:

        """

        def oneshot(self: Any, *args: Any, **kw: Any) -> Any:
            result = fn(self, *args, **kw)

            def memo(*a, **kw):
                return result

            memo.__name__ = fn.__name__
            memo.__doc__ = fn.__doc__
            # replace the bound method with a constant-returning callable
            # and track the name for cache clearing
            self.__dict__[fn.__name__] = memo
            self._memoized_keys |= {fn.__name__}
            return result

        return update_wrapper(oneshot, fn)  # type: ignore
1327
1328
if TYPE_CHECKING:
    # a plain @property gives the type checker a read-only attribute whose
    # return type subclasses may specialize
    HasMemoized_ro_memoized_attribute = property
else:
    # at runtime this is an ordinary HasMemoized.memoized_attribute
    HasMemoized_ro_memoized_attribute = HasMemoized.memoized_attribute
1333
1334
class MemoizedSlots:
    """Apply memoized items to an object using a __getattr__ scheme.

    Makes the behavior of memoized_property and memoized_instancemethod
    available to classes using __slots__: a missing attribute ``name`` is
    produced by ``_memoized_attr_<name>()`` or by wrapping
    ``_memoized_method_<name>``, with the result stored via ``setattr()``.

    """

    __slots__ = ()

    def _fallback_getattr(self, key):
        # subclass hook for names that have no memoization implementation
        raise AttributeError(key)

    def __getattr__(self, key: str) -> Any:
        # a missing implementation method must not recurse back into here
        if key.startswith(("_memoized_attr_", "_memoized_method_")):
            raise AttributeError(key)

        # test against __class__ rather than self so that cooperating
        # __getattr__ schemes elsewhere don't cause recursion
        cls = self.__class__

        attr_impl = f"_memoized_attr_{key}"
        if hasattr(cls, attr_impl):
            computed = getattr(self, attr_impl)()
            setattr(self, key, computed)
            return computed

        method_impl = f"_memoized_method_{key}"
        if hasattr(cls, method_impl):
            fn = getattr(self, method_impl)

            def oneshot(*args, **kw):
                value = fn(*args, **kw)

                def memo(*a, **kw):
                    return value

                memo.__name__ = fn.__name__
                memo.__doc__ = fn.__doc__
                # future lookups hit the memo directly via setattr
                setattr(self, key, memo)
                return value

            oneshot.__doc__ = fn.__doc__
            return oneshot

        return self._fallback_getattr(key)
1378
1379
1380# from paste.deploy.converters
def asbool(obj: Any) -> bool:
    """Coerce a value to a boolean, interpreting common string spellings.

    Strings are stripped and lowercased; unrecognized strings raise
    ``ValueError``.  Non-string values are passed through ``bool()``.
    """
    if not isinstance(obj, str):
        return bool(obj)
    text = obj.strip().lower()
    if text in ("true", "yes", "on", "y", "t", "1"):
        return True
    if text in ("false", "no", "off", "n", "f", "0"):
        return False
    raise ValueError("String is not true/false: %r" % text)
1391
1392
def bool_or_str(*text: str) -> Callable[[str], Union[str, bool]]:
    """Return a callable that evaluates a string as a boolean, unless the
    string matches one of the given "alternate" values, in which case it is
    passed through unchanged.

    """

    def bool_or_value(obj: str) -> Union[str, bool]:
        # recognized alternates short-circuit boolean coercion
        return obj if obj in text else asbool(obj)

    return bool_or_value
1406
1407
def asint(value: Any) -> Optional[int]:
    """Coerce to integer, passing ``None`` through unchanged."""
    return None if value is None else int(value)
1414
1415
def coerce_kw_type(
    kw: Dict[str, Any],
    key: str,
    type_: Type[Any],
    flexi_bool: bool = True,
    dest: Optional[Dict[str, Any]] = None,
) -> None:
    r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if
    necessary.  If 'flexi_bool' is True, the string '0' is considered false
    when coercing to boolean.  The coerced value is written to ``dest``,
    which defaults to ``kw`` itself.
    """

    if dest is None:
        dest = kw

    if key not in kw:
        return
    value = kw[key]
    if value is None:
        return
    # already the right type -> nothing to do; when type_ is not an actual
    # class (e.g. a factory callable), coercion always applies
    if isinstance(type_, type) and isinstance(value, type_):
        return

    if type_ is bool and flexi_bool:
        dest[key] = asbool(value)
    else:
        dest[key] = type_(value)
1440
1441
def constructor_key(obj: Any, cls: Type[Any]) -> Tuple[Any, ...]:
    """Produce a cacheable tuple keyed on ``cls`` plus the constructor
    arguments of ``cls`` found in ``obj.__dict__``.

    """
    obj_dict = obj.__dict__
    pairs = [
        (name, obj_dict[name])
        for name in get_cls_kwargs(cls)
        if name in obj_dict
    ]
    return (cls, *pairs)
1451
1452
def constructor_copy(obj: _T, cls: Type[_T], *args: Any, **kw: Any) -> _T:
    """Instantiate ``cls`` using values from ``obj.__dict__`` as keyword
    arguments.

    Uses inspect to match the named arguments of ``cls``; entries already
    present in ``kw`` take precedence over values found on ``obj``.

    """
    obj_dict = obj.__dict__
    for name in get_cls_kwargs(cls).difference(kw):
        if name in obj_dict:
            kw[name] = obj_dict[name]
    return cls(*args, **kw)
1465
1466
def counter() -> Callable[[], int]:
    """Return a threadsafe counter function that starts at 1."""

    mutex = threading.Lock()
    sequence = itertools.count(1)

    def _next():
        # serialize access so concurrent callers never see duplicates
        with mutex:
            return next(sequence)

    return _next
1479
1480
def duck_type_collection(
    specimen: Any, default: Optional[Type[Any]] = None
) -> Optional[Type[Any]]:
    """Given an instance or class, guess if it is or is acting as one of
    the basic collection types: list, set and dict.  If the __emulates__
    property is present, return that preferentially.
    """

    if hasattr(specimen, "__emulates__"):
        emulated = specimen.__emulates__
        # canonicalize set vs sets.Set to a standard: the builtin set
        if emulated is not None and issubclass(emulated, set):
            return set
        return emulated  # type: ignore

    # classes are tested with issubclass, instances with isinstance
    isa = issubclass if isinstance(specimen, type) else isinstance
    for collection_type in (list, set, dict):
        if isa(specimen, collection_type):
            return collection_type

    # fall back to duck typing on characteristic mutator methods
    if hasattr(specimen, "append"):
        return list
    if hasattr(specimen, "add"):
        return set
    if hasattr(specimen, "set"):
        return dict
    return default
1514
1515
def assert_arg_type(
    arg: Any, argtype: Union[Tuple[Type[Any], ...], Type[Any]], name: str
) -> Any:
    """Return ``arg`` if it is an instance of ``argtype``; otherwise raise
    :class:`.exc.ArgumentError` naming the offending argument.
    """
    if isinstance(arg, argtype):
        return arg

    if isinstance(argtype, tuple):
        expected = " or ".join("'%s'" % a for a in argtype)
        raise exc.ArgumentError(
            "Argument '%s' is expected to be one of type %s, got '%s'"
            % (name, expected, type(arg))
        )
    raise exc.ArgumentError(
        "Argument '%s' is expected to be of type '%s', got '%s'"
        % (name, argtype, type(arg))
    )
1532
1533
def dictlike_iteritems(dictlike):
    """Return a (key, value) iterator for almost any dict-like object."""

    # modern mappings: use items() directly
    if hasattr(dictlike, "items"):
        return list(dictlike.items())

    getter = getattr(dictlike, "__getitem__", getattr(dictlike, "get", None))
    if getter is None:
        raise TypeError("Object '%r' is not dict-like" % dictlike)

    if hasattr(dictlike, "iterkeys"):
        # legacy py2-style mapping interface
        return ((key, getter(key)) for key in dictlike.iterkeys())
    if hasattr(dictlike, "keys"):
        return iter((key, getter(key)) for key in dictlike.keys())
    raise TypeError("Object '%r' is not dict-like" % dictlike)
1556
1557
class classproperty(property):
    """A decorator that behaves like @property except that operates
    on classes rather than instances.

    Within declarative mappings, the
    :class:`~.sqlalchemy.ext.declarative.declared_attr` decorator should be
    used for this purpose instead.

    """

    fget: Callable[[Any], Any]

    def __init__(self, fget: Callable[[Any], Any], *arg: Any, **kw: Any):
        super().__init__(fget, *arg, **kw)
        self.__doc__ = fget.__doc__

    def __get__(self, obj: Any, cls: Optional[type] = None) -> Any:
        # always invoke with the owning class, even when accessed
        # through an instance
        return self.fget(cls)
1577
1578
class hybridproperty(Generic[_T]):
    """A property-like descriptor whose class-level and instance-level
    behaviors may differ; by default both invoke the same function."""

    def __init__(self, func: Callable[..., _T]):
        self.func = func
        self.clslevel = func

    def __get__(self, instance: Any, owner: Any) -> _T:
        # class-level access dispatches to the (possibly distinct)
        # class-level function
        if instance is None:
            return self.clslevel(owner)
        return self.func(instance)

    def classlevel(self, func: Callable[..., Any]) -> hybridproperty[_T]:
        """Install a distinct function for class-level access."""
        self.clslevel = func
        return self
1594
1595
class rw_hybridproperty(Generic[_T]):
    """A read/write variant of :class:`.hybridproperty`."""

    def __init__(self, func: Callable[..., _T]):
        self.func = func
        self.clslevel = func
        self.setfn: Optional[Callable[..., Any]] = None

    def __get__(self, instance: Any, owner: Any) -> _T:
        if instance is None:
            return self.clslevel(owner)
        return self.func(instance)

    def __set__(self, instance: Any, value: Any) -> None:
        # a setter must have been installed via .setter()
        assert self.setfn is not None
        self.setfn(instance, value)

    def setter(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
        """Install the function used for instance-level assignment."""
        self.setfn = func
        return self

    def classlevel(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
        """Install a distinct function for class-level access."""
        self.clslevel = func
        return self
1620
1621
class hybridmethod(Generic[_T]):
    """Decorate a function as cls- or instance- level."""

    def __init__(self, func: Callable[..., _T]):
        self.func = self.__func__ = func
        self.clslevel = func

    def __get__(self, instance: Any, owner: Any) -> Callable[..., _T]:
        if instance is None:
            # bind the class-level function to the owning class
            return self.clslevel.__get__(owner, owner.__class__)  # type:ignore
        # bind the instance-level function like an ordinary method
        return self.func.__get__(instance, owner)  # type:ignore

    def classlevel(self, func: Callable[..., Any]) -> hybridmethod[_T]:
        """Install a distinct function for class-level invocation."""
        self.clslevel = func
        return self
1638
1639
class symbol(int):
    """A constant symbol.

    >>> symbol("foo") is symbol("foo")
    True
    >>> symbol("foo")
    symbol('foo')

    A slight refinement of the MAGICCOOKIE=object() pattern.  The primary
    advantage of symbol() is its repr().  They are also singletons.

    Repeated calls of symbol('name') will all return the same instance.

    """

    name: str

    # process-wide registry of interned symbols, keyed by name
    symbols: Dict[str, symbol] = {}
    _lock = threading.Lock()

    def __new__(
        cls,
        name: str,
        doc: Optional[str] = None,
        canonical: Optional[int] = None,
    ) -> symbol:
        with cls._lock:
            sym = cls.symbols.get(name)
            if sym is None:
                assert isinstance(name, str)
                # the integer value defaults to the hash of the name
                if canonical is None:
                    canonical = hash(name)
                sym = int.__new__(symbol, canonical)
                sym.name = name
                if doc:
                    sym.__doc__ = doc

                # NOTE: we should ultimately get rid of this global thing,
                # however, currently it is to support pickling. The best
                # change would be when we are on py3.11 at a minimum, we
                # switch to stdlib enum.IntFlag.
                cls.symbols[name] = sym
            else:
                # an existing symbol may not be rebound to a different int
                if canonical and canonical != sym:
                    raise TypeError(
                        f"Can't replace canonical symbol for {name!r} "
                        f"with new int value {canonical}"
                    )
        return sym

    def __reduce__(self):
        # pickle via the constructor, which returns the already-interned
        # instance for this name on unpickle
        return symbol, (self.name, "x", int(self))

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return f"symbol({self.name!r})"
1698
1699
class _IntFlagMeta(type):
    """Metaclass which converts plain integer class attributes into
    interned :class:`.symbol` instances, collected in ``__members__``."""

    def __init__(
        cls,
        classname: str,
        bases: Tuple[Type[Any], ...],
        dict_: Dict[str, Any],
        **kw: Any,
    ) -> None:
        items: List[symbol]
        cls._items = items = []
        for k, v in dict_.items():
            if re.match(r"^__.*__$", k):
                # skip dunders such as __module__ / __qualname__
                continue
            if isinstance(v, int):
                sym = symbol(k, canonical=v)
            elif not k.startswith("_"):
                raise TypeError("Expected integer values for IntFlag")
            else:
                # underscore-prefixed non-integer attributes are ignored
                continue
            setattr(cls, k, sym)
            items.append(sym)

        cls.__members__ = _collections.immutabledict(
            {sym.name: sym for sym in items}
        )

    def __iter__(self) -> Iterator[symbol]:
        # deliberately unsupported; see the linked CPython issue
        raise NotImplementedError(
            "iter not implemented to ensure compatibility with "
            "Python 3.11 IntFlag. Please use __members__. See "
            "https://github.com/python/cpython/issues/99304"
        )
1732
1733
class _FastIntFlag(metaclass=_IntFlagMeta):
    """An 'IntFlag' copycat that isn't slow when performing bitwise
    operations.

    the ``FastIntFlag`` class will return ``enum.IntFlag`` under TYPE_CHECKING
    and ``_FastIntFlag`` otherwise.

    """


if TYPE_CHECKING:
    from enum import IntFlag

    # type checkers see a real IntFlag for full flag semantics
    FastIntFlag = IntFlag
else:
    # at runtime, use the fast symbol-based implementation
    FastIntFlag = _FastIntFlag
1750
1751
# TypeVar for enum types handled by parse_user_argument_for_enum()
_E = TypeVar("_E", bound=enum.Enum)
1753
1754
def parse_user_argument_for_enum(
    arg: Any,
    choices: Dict[_E, List[Any]],
    name: str,
    resolve_symbol_names: bool = False,
) -> Optional[_E]:
    """Resolve a user-supplied parameter to one of a set of choice
    objects, typically Enum values.

    The argument may be the enum value itself, the member's string name
    (when ``resolve_symbol_names`` is True), or any of the alternate
    choices such as True/False/None listed for that member.

    :param arg: the user argument.
    :param choices: dictionary of enum values to lists of possible
      entries for each.
    :param name: name of the argument.  Used in an :class:`.ArgumentError`
      that is raised if the parameter doesn't match any available argument.

    """
    for member, alternates in choices.items():
        if (
            arg is member
            or (resolve_symbol_names and arg == member.name)
            or arg in alternates
        ):
            return member

    if arg is None:
        return None

    raise exc.ArgumentError(f"Invalid value for '{name}': {arg!r}")
1787
1788
# module-level sequence backing set_creation_order()
_creation_order = 1


def set_creation_order(instance: Any) -> None:
    """Stamp the given instance with a monotonically increasing
    ``_creation_order`` attribute.

    This allows multiple instances to be sorted in order of creation
    (typically within a single thread; the counter is not particularly
    threadsafe).

    """
    global _creation_order
    current = _creation_order
    instance._creation_order = current
    _creation_order = current + 1
1803
1804
def warn_exception(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Invoke the given function, converting any raised exception into a
    warning; returns the function's result on success, else None.

    """
    try:
        return func(*args, **kwargs)
    except Exception:
        exc_type, exc_value = sys.exc_info()[0:2]
        warn("%s('%s') ignored" % (exc_type, exc_value))
1814
1815
def ellipses_string(value, len_=25):
    """Truncate ``value`` to at most ``len_`` characters, appending '...'
    when truncated; values without a usable length or slice are returned
    unchanged."""
    try:
        if len(value) <= len_:
            return value
        # slicing stays inside the try: sized-but-unsliceable values
        # fall through unchanged, as before
        return "%s..." % value[0:len_]
    except TypeError:
        return value
1824
1825
1826class _hash_limit_string(str):
1827 """A string subclass that can only be hashed on a maximum amount
1828 of unique values.
1829
1830 This is used for warnings so that we can send out parameterized warnings
1831 without the __warningregistry__ of the module, or the non-overridable
1832 "once" registry within warnings.py, overloading memory,
1833
1834
1835 """
1836
1837 _hash: int
1838
1839 def __new__(
1840 cls, value: str, num: int, args: Sequence[Any]
1841 ) -> _hash_limit_string:
1842 interpolated = (value % args) + (
1843 " (this warning may be suppressed after %d occurrences)" % num
1844 )
1845 self = super().__new__(cls, interpolated)
1846 self._hash = hash("%s_%d" % (value, hash(interpolated) % num))
1847 return self
1848
1849 def __hash__(self) -> int:
1850 return self._hash
1851
1852 def __eq__(self, other: Any) -> bool:
1853 return hash(self) == hash(other)
1854
1855
def warn(msg: str, code: Optional[str] = None) -> None:
    """Issue a warning using :class:`.exc.SAWarning` as the category,
    optionally tagging it with an error code.

    """
    if code:
        # the code rides along on the warning instance itself
        _warnings_warn(exc.SAWarning(msg, code=code))
    else:
        _warnings_warn(msg, exc.SAWarning)
1867
1868
def warn_limited(msg: str, args: Sequence[Any]) -> None:
    """Issue a warning with a parameterized string, limiting the number
    of registrations via :class:`._hash_limit_string`.

    """
    if args:
        # cap at 10 distinct registrations per message template
        msg = _hash_limit_string(msg, 10, args)
    _warnings_warn(msg, exc.SAWarning)
1877
1878
1879_warning_tags: Dict[CodeType, Tuple[str, Type[Warning]]] = {}
1880
1881
1882def tag_method_for_warnings(
1883 message: str, category: Type[Warning]
1884) -> Callable[[_F], _F]:
1885 def go(fn):
1886 _warning_tags[fn.__code__] = (message, category)
1887 return fn
1888
1889 return go
1890
1891
# module names considered "inside SQLAlchemy" when adjusting stacklevel;
# sqlalchemy.testing is deliberately treated as user code
_not_sa_pattern = re.compile(r"^(?:sqlalchemy\.(?!testing)|alembic\.)")


def _warnings_warn(
    message: Union[str, Warning],
    category: Optional[Type[Warning]] = None,
    stacklevel: int = 2,
) -> None:
    """Emit a warning, walking the call stack to point ``stacklevel``
    outside of SQLAlchemy and to apply any tags registered via
    :func:`.tag_method_for_warnings` found along the way."""

    if category is None and isinstance(message, Warning):
        category = type(message)

    # adjust the given stacklevel to be outside of SQLAlchemy
    try:
        frame = sys._getframe(stacklevel)
    except ValueError:
        # being called from less than 3 (or given) stacklevels, weird,
        # but don't crash
        stacklevel = 0
    except:
        # _getframe() doesn't work, weird interpreter issue, weird,
        # ok, but don't crash
        stacklevel = 0
    else:
        stacklevel_found = warning_tag_found = False
        while frame is not None:
            # using __name__ here requires that we have __name__ in the
            # __globals__ of the decorated string functions we make also.
            # we generate this using {"__name__": fn.__module__}
            if not stacklevel_found and not re.match(
                _not_sa_pattern, frame.f_globals.get("__name__", "")
            ):
                # stop incrementing stack level if an out-of-SQLA line
                # were found.
                stacklevel_found = True

                # however, for the warning tag thing, we have to keep
                # scanning up the whole traceback

            if frame.f_code in _warning_tags:
                warning_tag_found = True
                (_suffix, _category) = _warning_tags[frame.f_code]
                category = category or _category
                message = f"{message} ({_suffix})"

            frame = frame.f_back  # type: ignore[assignment]

            if not stacklevel_found:
                stacklevel += 1
            elif stacklevel_found and warning_tag_found:
                # both jobs are done; no need to walk further up the stack
                break

    if category is not None:
        warnings.warn(message, category, stacklevel=stacklevel + 1)
    else:
        warnings.warn(message, stacklevel=stacklevel + 1)
1948
1949
def only_once(
    fn: Callable[..., _T], retry_on_exception: bool
) -> Callable[..., Optional[_T]]:
    """Decorate the given function to be a no-op after it is called exactly
    once.

    With ``retry_on_exception`` set, a call that raises does not count as
    the single invocation and the function may be attempted again.

    """

    pending = [fn]

    def go(*arg: Any, **kw: Any) -> Optional[_T]:
        # strong reference fn so that it isn't garbage collected,
        # which interferes with the event system's expectations
        strong_fn = fn  # noqa
        if not pending:
            return None
        target = pending.pop()
        try:
            return target(*arg, **kw)
        except:
            if retry_on_exception:
                pending.insert(0, target)
            raise

    return go
1974
1975
# traceback lines originating from SQLAlchemy's own modules
_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
# traceback lines originating from unittest-style runners
_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")


def chop_traceback(
    tb: List[str],
    exclude_prefix: re.Pattern[str] = _UNITTEST_RE,
    exclude_suffix: re.Pattern[str] = _SQLA_RE,
) -> List[str]:
    """Chop extraneous lines off beginning and end of a traceback.

    :param tb:
      a list of traceback lines as returned by ``traceback.format_stack()``

    :param exclude_prefix:
      a regular expression object matching lines to skip at beginning of
      ``tb``

    :param exclude_suffix:
      a regular expression object matching lines to skip at end of ``tb``
    """
    first = 0
    last = len(tb) - 1
    while first <= last and exclude_prefix.search(tb[first]):
        first += 1
    while first <= last and exclude_suffix.search(tb[last]):
        last -= 1
    return tb[first : last + 1]
2004
2005
# the type of the None singleton
NoneType = type(None)
2007
2008
def attrsetter(attrname):
    """Return a callable ``set(obj, value)`` that assigns ``value`` to the
    attribute path ``attrname`` of ``obj``, including dotted paths such as
    ``"parent.child"``.

    Previously implemented by interpolating ``attrname`` into source text
    passed to ``exec()``; plain closures are equivalent for any valid
    attribute path and avoid generating code from strings.
    """
    if "." in attrname:
        # dotted path: navigate to the parent object, then set the leaf
        head, _, leaf = attrname.rpartition(".")
        get_parent = operator.attrgetter(head)

        def set(obj, value):
            setattr(get_parent(obj), leaf, value)

    else:

        def set(obj, value):
            setattr(obj, attrname, value)

    return set
2014
2015
# matches names of the form __name__
_dunders = re.compile("^__.+__$")


class TypingOnly:
    """A mixin class that marks a class as 'typing only', meaning it has
    absolutely no methods, attributes, or runtime functionality whatsoever.

    Direct subclasses adding anything other than dunder attributes raise
    ``AssertionError`` at class-definition time.

    """

    __slots__ = ()

    def __init_subclass__(cls) -> None:
        # enforce only on direct subclasses
        if TypingOnly in cls.__bases__:
            remaining = {
                attr for attr in cls.__dict__ if not _dunders.match(attr)
            }
            if remaining:
                raise AssertionError(
                    f"Class {cls} directly inherits TypingOnly but has "
                    f"additional attributes {remaining}."
                )
        super().__init_subclass__()
2038
2039
class EnsureKWArg:
    r"""Apply translation of functions to accept \**kw arguments if they
    don't already.

    Used to ensure cross-compatibility with third party legacy code, for
    things like compiler visit methods that need to accept ``**kw``
    arguments, but may have been copied from old code that didn't accept
    them.

    """

    ensure_kwarg: str
    """a regular expression that indicates method names for which the method
    should accept ``**kw`` arguments.

    The class will scan for methods matching the name template and decorate
    them if necessary to ensure ``**kw`` parameters are accepted.

    """

    def __init_subclass__(cls) -> None:
        pattern = cls.ensure_kwarg
        if pattern:
            # iterate a snapshot; setattr replaces entries in cls.__dict__
            for name, member in list(cls.__dict__.items()):
                if not re.match(pattern, name):
                    continue
                # only wrap callables that don't already accept **kw
                if not compat.inspect_getfullargspec(member).varkw:
                    setattr(cls, name, cls._wrap_w_kw(member))
        super().__init_subclass__()

    @classmethod
    def _wrap_w_kw(cls, fn: Callable[..., Any]) -> Callable[..., Any]:
        def wrap(*arg: Any, **kw: Any) -> Any:
            # silently discard **kw; positional args pass through
            return fn(*arg)

        return update_wrapper(wrap, fn)
2079
2080
def wrap_callable(wrapper, fn):
    """Augment functools.update_wrapper() to work with objects with
    a ``__call__()`` method.

    :param fn:
      object with __call__ method

    """
    if hasattr(fn, "__name__"):
        # plain function / method: defer to the stdlib helper
        return update_wrapper(wrapper, fn)

    # a callable object: copy identifying metadata from its class and
    # its __call__ method onto the wrapper
    wrapper.__name__ = fn.__class__.__name__
    if hasattr(fn, "__module__"):
        wrapper.__module__ = fn.__module__

    call_doc = getattr(fn.__call__, "__doc__", None)
    if call_doc:
        wrapper.__doc__ = call_doc
    elif fn.__doc__:
        wrapper.__doc__ = fn.__doc__

    return wrapper
2103
2104
def quoted_token_parser(value):
    """Parse a dotted identifier with accommodation for quoted names.

    Includes support for SQL-style double quotes as a literal character.

    E.g.::

        >>> quoted_token_parser("name")
        ["name"]
        >>> quoted_token_parser("schema.name")
        ["schema", "name"]
        >>> quoted_token_parser('"Schema"."Name"')
        ['Schema', 'Name']
        >>> quoted_token_parser('"Schema"."Name""Foo"')
        ['Schema', 'Name""Foo']

    """

    if '"' not in value:
        # fast path: no quoting involved, plain dot split
        return value.split(".")

    # 0 = outside of quotes
    # 1 = inside of quotes
    state = 0
    result: List[List[str]] = [[]]
    idx = 0
    lv = len(value)
    while idx < lv:
        char = value[idx]
        if char == '"':
            if state == 1 and idx < lv - 1 and value[idx + 1] == '"':
                # doubled quote inside a quoted name: preserve the pair
                # verbatim, per the doctest above.  (previously appended a
                # single '"', collapsing the pair and contradicting the
                # documented output)
                result[-1].append('""')
                idx += 1
            else:
                # begin or end a quoted section
                state ^= 1
        elif char == "." and state == 0:
            # dots outside of quotes delimit tokens
            result.append([])
        else:
            result[-1].append(char)
        idx += 1

    return ["".join(token) for token in result]
2147
2148
def add_parameter_text(params: Any, text: str) -> Callable[[_F], _F]:
    """Decorator which appends ``text`` to the ``:param:`` documentation
    of each parameter named in ``params`` on the decorated function."""
    param_names = _collections.to_list(params)

    def decorate(fn):
        doc = fn.__doc__ or ""
        if doc:
            doc = inject_param_text(
                doc, {param: text for param in param_names}
            )
        # note: a previously-None __doc__ becomes "" here, as before
        fn.__doc__ = doc
        return fn

    return decorate
2160
2161
2162def _dedent_docstring(text: str) -> str:
2163 split_text = text.split("\n", 1)
2164 if len(split_text) == 1:
2165 return text
2166 else:
2167 firstline, remaining = split_text
2168 if not firstline.startswith(" "):
2169 return firstline + "\n" + textwrap.dedent(remaining)
2170 else:
2171 return textwrap.dedent(text)
2172
2173
def inject_docstring_text(
    given_doctext: Optional[str], injecttext: str, pos: int
) -> str:
    """Insert ``injecttext`` into a docstring at the ``pos``-th blank line,
    returning the combined, dedented text."""
    doctext: str = _dedent_docstring(given_doctext or "")
    lines = doctext.split("\n")
    if len(lines) == 1:
        lines.append("")

    injectlines = textwrap.dedent(injecttext).split("\n")
    if injectlines[0]:
        # ensure the injected block starts on its own blank line
        injectlines.insert(0, "")

    # indexes of blank lines, with 0 always present as a fallback anchor
    blanks = [0]
    blanks.extend(num for num, line in enumerate(lines) if not line.strip())

    anchor = blanks[min(pos, len(blanks) - 1)]
    return "\n".join(lines[:anchor] + injectlines + lines[anchor:])
2192
2193
# captures the leading whitespace and parameter name of a ":param x:" line
_param_reg = re.compile(r"(\s+):param (.+?):")


def inject_param_text(doctext: str, inject_params: Dict[str, str]) -> str:
    """Return ``doctext`` with the text from ``inject_params`` appended
    at the end of each matching ``:param:`` section."""
    doclines = collections.deque(doctext.splitlines())
    lines = []

    # TODO: this is not working for params like ":param case_sensitive=True:"

    # text waiting to be emitted once the current :param: section ends
    to_inject = None
    while doclines:
        line = doclines.popleft()

        m = _param_reg.match(line)

        if to_inject is None:
            if m:
                param = m.group(2).lstrip("*")
                if param in inject_params:
                    # default indent to that of :param: plus one
                    indent = " " * len(m.group(1)) + " "

                    # but if the next line has text, use that line's
                    # indentation
                    if doclines:
                        m2 = re.match(r"(\s+)\S", doclines[0])
                        if m2:
                            indent = " " * len(m2.group(1))

                    to_inject = indent + inject_params[param]
        elif m:
            # the next :param: begins; flush the pending injection first
            lines.extend(["\n", to_inject, "\n"])
            to_inject = None
        elif not line.rstrip():
            # blank line closes the section; inject before continuing
            lines.extend([line, to_inject, "\n"])
            to_inject = None
        elif line.endswith("::"):
            # TODO: this still won't cover if the code example itself has
            # blank lines in it, need to detect those via indentation.
            lines.extend([line, doclines.popleft()])
            continue
        lines.append(line)

    return "\n".join(lines)
2238
2239
def repr_tuple_names(names: List[str]) -> Optional[str]:
    """Trims a list of strings from the middle and return a string of up to
    four elements.  Strings greater than 11 characters will be truncated"""
    if not names:
        return None

    within_limit = len(names) <= 4
    shown = names[:4] if within_limit else names[:3] + names[-1:]
    trimmed = [
        "%s.." % name[:11] if len(name) > 11 else name for name in shown
    ]
    if within_limit:
        return ", ".join(trimmed)
    # middle elements elided
    return "%s, ..., %s" % (", ".join(trimmed[:3]), trimmed[-1])
2252
2253
def has_compiled_ext(raise_=False):
    """Report whether the cython extensions are present, optionally raising
    ``ImportError`` when they are expected but missing."""
    from ._has_cython import HAS_CYEXTENSION

    if HAS_CYEXTENSION:
        return True
    if raise_:
        raise ImportError(
            "cython extensions were expected to be installed, "
            "but are not present"
        )
    return False
2266
2267
def load_uncompiled_module(module: _M) -> _M:
    """Load the plain-Python (non-compiled) version of a module that is
    also compiled with cython, returning a freshly executed module object.
    """
    full_name = module.__name__
    assert module.__spec__
    parent_name = module.__spec__.parent
    assert parent_name

    parent_module = sys.modules[parent_name]
    assert parent_module.__spec__
    package_path = parent_module.__spec__.origin
    assert package_path and package_path.endswith("__init__.py")

    # the .py source lives alongside the package's __init__.py
    short_name = full_name.rsplit(".", 1)[-1]
    module_path = package_path.replace("__init__.py", f"{short_name}.py")

    py_spec = importlib.util.spec_from_file_location(full_name, module_path)
    assert py_spec
    py_module = importlib.util.module_from_spec(py_spec)
    assert py_spec.loader
    py_spec.loader.exec_module(py_module)
    return cast(_M, py_module)
2290
2291
class _Missing(enum.Enum):
    # single-valued enum serving as a typed "no value supplied" sentinel
    Missing = enum.auto()


# the sentinel value itself
Missing = _Missing.Missing
# union alias: either a value of type _T or the Missing sentinel
MissingOr = Union[_T, Literal[_Missing.Missing]]