1# util/langhelpers.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Routines to help with the creation, loading and introspection of
10modules, classes, hierarchies, attributes, functions, and methods.
11
12"""
13from __future__ import annotations
14
15import collections
16import enum
17from functools import update_wrapper
18import importlib.util
19import inspect
20import itertools
21import operator
22import re
23import sys
24import textwrap
25import threading
26import types
27from types import CodeType
28from types import ModuleType
29from typing import Any
30from typing import Callable
31from typing import cast
32from typing import Dict
33from typing import FrozenSet
34from typing import Generic
35from typing import Iterator
36from typing import List
37from typing import Literal
38from typing import Mapping
39from typing import NoReturn
40from typing import Optional
41from typing import overload
42from typing import Sequence
43from typing import Set
44from typing import Tuple
45from typing import Type
46from typing import TYPE_CHECKING
47from typing import TypeVar
48from typing import Union
49import warnings
50
51from . import _collections
52from . import compat
53from .. import exc
54
# generic "any type" variable used in signatures throughout this module
_T = TypeVar("_T")
# covariant variant; used by read-only descriptors such as
# generic_fn_descriptor
_T_co = TypeVar("_T_co", covariant=True)
# bound to callables
_F = TypeVar("_F", bound=Callable[..., Any])
# bound to HasMemoized.memoized_attribute (not defined in this chunk)
_MA = TypeVar("_MA", bound="HasMemoized.memoized_attribute[Any]")
# bound to module objects
_M = TypeVar("_M", bound=ModuleType)
60
if compat.py314:
    # vendor a minimal form of get_annotations per
    # https://github.com/python/cpython/issues/133684#issuecomment-2863841891
    #
    # This branch is only taken on Python 3.14+ (compat.py314), where the
    # annotationlib module exists and annotations may be produced by an
    # ``__annotate__`` function.

    from annotationlib import call_annotate_function  # type: ignore
    from annotationlib import Format

    def _get_and_call_annotate(obj, format):  # noqa: A002
        """Invoke ``obj.__annotate__`` with ``format`` if it has one.

        Returns the dict produced by the annotate function, or None when
        ``obj`` has no ``__annotate__`` attribute (or it is None).

        :raises ValueError: if the annotate function returns a non-dict.
        """
        annotate = getattr(obj, "__annotate__", None)
        if annotate is not None:
            ann = call_annotate_function(annotate, format, owner=obj)
            if not isinstance(ann, dict):
                raise ValueError(f"{obj!r}.__annotate__ returned a non-dict")
            return ann
        return None

    # this is ported from py3.13.0a7
    # The getter of the ``__annotations__`` descriptor defined on ``type``
    # itself; calling it directly bypasses ordinary attribute lookup on the
    # class object.
    _BASE_GET_ANNOTATIONS = type.__dict__["__annotations__"].__get__  # type: ignore # noqa: E501

    def _get_dunder_annotations(obj):
        """Return a new dict copied from ``obj.__annotations__``.

        For classes the value is read through ``_BASE_GET_ANNOTATIONS``;
        an AttributeError from that descriptor yields ``{}``.  For other
        objects the value is read with ``getattr``; a missing or None
        attribute yields ``{}``.  This helper never returns None.

        :raises ValueError: if the value found is not a dict (for classes
         this includes a None value).  Any other exception raised while
         computing ``__annotations__`` (e.g. a NameError) propagates.
        """
        if isinstance(obj, type):
            try:
                ann = _BASE_GET_ANNOTATIONS(obj)
            except AttributeError:
                # For static types, the descriptor raises AttributeError.
                return {}
        else:
            ann = getattr(obj, "__annotations__", None)
            if ann is None:
                return {}

        if not isinstance(ann, dict):
            raise ValueError(
                f"{obj!r}.__annotations__ is neither a dict nor None"
            )
        return dict(ann)

    def _vendored_get_annotations(
        obj: Any, *, format: Format  # noqa: A002
    ) -> Mapping[str, Any]:
        """A sparse implementation of annotationlib.get_annotations()

        Tries ``__annotations__`` first.  If that raises, falls back to
        calling ``__annotate__`` with ``format``.  Returns a new dict.
        """

        try:
            ann = _get_dunder_annotations(obj)
        except Exception:
            pass
        else:
            # NOTE(review): _get_dunder_annotations never returns None, so
            # this check is always true; kept as ported.
            if ann is not None:
                return dict(ann)

        # But if __annotations__ threw a NameError, we try calling __annotate__
        # (note: any Exception from the attempt above leads here)
        ann = _get_and_call_annotate(obj, format)
        if ann is None:
            # If that didn't work either, we have a very weird object:
            # evaluating
            # __annotations__ threw NameError and there is no __annotate__.
            # In that case,
            # we fall back to trying __annotations__ again.
            # (this retry is unguarded; whatever it raises propagates)
            ann = _get_dunder_annotations(obj)

        # NOTE(review): since _get_dunder_annotations never returns None,
        # this branch appears unreachable; kept as ported.
        if ann is None:
            if isinstance(obj, type) or callable(obj):
                return {}
            raise TypeError(f"{obj!r} does not have annotations")

        if not ann:
            return {}

        return dict(ann)

    def get_annotations(obj: Any) -> Mapping[str, Any]:
        """Return ``obj``'s annotations as a new dict (Python 3.14+).

        Reads ``__annotations__`` directly.  If that fails, it calls
        ``__annotate__`` in FORWARDREF format.
        """
        # FORWARDREF has the effect of giving us ForwardRefs and not
        # actually trying to evaluate the annotations. We need this so
        # that the annotations act as much like
        # "from __future__ import annotations" as possible, which is going
        # away in future python as a separate mode
        return _vendored_get_annotations(obj, format=Format.FORWARDREF)

else:

    def get_annotations(obj: Any) -> Mapping[str, Any]:
        """Return ``obj``'s annotations via ``inspect.get_annotations``
        (pre-3.14 implementation)."""
        return inspect.get_annotations(obj)
143
144
def md5_hex(x: Any) -> str:
    """Return the hexadecimal MD5 digest of ``x`` encoded as UTF-8.

    ``x`` must support ``.encode("utf-8")`` (i.e. be a ``str``).  The hash
    object comes from ``compat.md5_not_for_security()``.  As that name
    states, the digest is not meant for security purposes.
    """
    payload = x.encode("utf-8")
    hasher = compat.md5_not_for_security()
    hasher.update(payload)
    hex_digest = hasher.hexdigest()
    return cast(str, hex_digest)
150
151
class safe_reraise:
    """Context manager that re-raises the in-flight exception after running
    some handler code.

    ``__enter__`` snapshots ``sys.exc_info()`` up front, so the original
    exception survives even if the handler block causes a coroutine /
    greenlet context switch.  On exit:

    * if the handler block completed normally, the captured exception is
      raised again with its original traceback;
    * if the handler block raised, that new exception propagates instead.

    e.g.::

        try:
            sess.commit()
        except:
            with safe_reraise():
                sess.rollback()

    TODO: we should at some point evaluate current behaviors in this regard
    based on current greenlet, gevent/eventlet implementations in Python 3,
    and also see the degree to which our own asyncio (based on greenlet
    also) is impacted by this. .rollback() will cause IO / context switch to
    occur in all these scenarios; what happens to the exception context from
    an "except:" block if we don't explicitly store it? Original issue was
    #2703.

    """

    __slots__ = ("_exc_info",)

    # snapshot of sys.exc_info(); set back to None on exit so the
    # exception / traceback references are dropped
    _exc_info: Union[
        None,
        Tuple[
            Type[BaseException],
            BaseException,
            types.TracebackType,
        ],
        Tuple[None, None, None],
    ]

    def __enter__(self) -> None:
        captured = sys.exc_info()
        self._exc_info = captured

    def __exit__(
        self,
        type_: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[types.TracebackType],
    ) -> NoReturn:
        assert self._exc_info is not None
        # see #2703 for notes
        if type_ is not None:
            # the handler itself failed: its exception wins
            self._exc_info = None  # remove potential circular references
            assert value is not None
            raise value.with_traceback(traceback)

        # the handler succeeded: re-raise what was originally being handled
        _, original, original_tb = self._exc_info
        assert original is not None
        self._exc_info = None  # remove potential circular references
        raise original.with_traceback(original_tb)
209
210
def walk_subclasses(cls: Type[_T]) -> Iterator[Type[_T]]:
    """Yield ``cls`` followed by every class that transitively subclasses it.

    The traversal is depth-first with an explicit stack.  A class reachable
    through several bases (a diamond) is yielded only once.
    """
    visited: Set[Any] = set()
    pending = [cls]

    while pending:
        current = pending.pop()
        if current in visited:
            continue
        visited.add(current)
        pending.extend(current.__subclasses__())
        yield current
223
224
def string_or_unprintable(element: Any) -> str:
    """Return ``element`` as a string, never raising from ``__str__``.

    Strings are returned unchanged.  Anything else goes through ``str()``.
    If that raises, a fallback built from ``repr()`` is returned instead.
    """
    if not isinstance(element, str):
        try:
            return str(element)
        except Exception:
            return "unprintable element %r" % element
    return element
233
234
def clsname_as_plain_name(
    cls: Type[Any], use_name: Optional[str] = None
) -> str:
    """Turn a CamelCase class name into lowercase, space-separated words.

    ``use_name``, when non-empty, is used instead of ``cls.__name__``.  The
    name is split into capitalized words and the literal token ``SQL``.
    Characters matching neither pattern are dropped.
    """
    source = use_name if use_name else cls.__name__
    words = re.findall(r"([A-Z][a-z]+|SQL)", source)
    return " ".join(map(str.lower, words))
240
241
def method_is_overridden(
    instance_or_cls: Union[Type[Any], object],
    against_method: Callable[..., Any],
) -> bool:
    """Return True if the class's attribute named like ``against_method``
    is not equal to ``against_method``.

    ``instance_or_cls`` may be a class or an instance; for an instance its
    ``__class__`` is inspected.
    """
    owner = (
        instance_or_cls
        if isinstance(instance_or_cls, type)
        else instance_or_cls.__class__
    )
    found: types.MethodType = getattr(owner, against_method.__name__)
    return found != against_method
258
259
def decode_slice(slc: slice) -> Tuple[Any, ...]:
    """Return ``(start, stop, step)`` of a slice passed to ``__getitem__``.

    Any component that defines ``__index__`` is converted through it.
    Other values, including None, are passed through unchanged.
    """

    def _as_index(part: Any) -> Any:
        return part.__index__() if hasattr(part, "__index__") else part

    return tuple(_as_index(part) for part in (slc.start, slc.stop, slc.step))
272
273
274def _unique_symbols(used: Sequence[str], *bases: str) -> Iterator[str]:
275 used_set = set(used)
276 for base in bases:
277 pool = itertools.chain(
278 (base,),
279 map(lambda i: base + str(i), range(1000)),
280 )
281 for sym in pool:
282 if sym not in used_set:
283 used_set.add(sym)
284 yield sym
285 break
286 else:
287 raise NameError("exhausted namespace for symbol base %s" % base)
288
289
def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]:
    """Yield ``fn(bit)`` for each set bit of ``n``, lowest bit first.

    Each ``bit`` is the power of two for that position (e.g. 13 -> 1, 4, 8).
    """
    remaining = n
    while remaining:
        # isolate the lowest set bit (two's complement trick)
        lowest = remaining & -remaining
        yield fn(lowest)
        remaining ^= lowest
297
298
# callable-bound type variable; decorator() below uses it so the decorated
# function keeps its type
_Fn = TypeVar("_Fn", bound="Callable[..., Any]")

# this seems to be in flux in recent mypy versions
302
303
def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
    """A signature-matching decorator factory.

    ``target`` is invoked as ``target(fn, <arguments>)``.  The returned
    decorator wraps a function ``fn`` in a generated function whose
    signature matches ``fn``'s.  Calling the wrapper forwards its arguments
    to ``target`` along with ``fn`` itself.  If ``fn`` is a coroutine
    function, the wrapper is an ``async def`` that awaits ``target``.

    The returned decorator is passed through ``update_wrapper`` with
    ``target``, so it carries ``target``'s name and docstring.  Each
    generated wrapper is likewise passed through ``update_wrapper`` with
    ``fn``.

    :raises Exception: (from the decorator) when ``fn`` is neither a
     function nor a method.
    """

    def decorate(fn: _Fn) -> _Fn:
        if not inspect.isfunction(fn) and not inspect.ismethod(fn):
            raise Exception("not a decoratable function")

        # Python 3.14 defers creating __annotations__ until it's used.
        # We do not want to create __annotations__ now.
        # __annotate__ is cleared only for the duration of the argspec
        # call and restored in ``finally``.
        annofunc = getattr(fn, "__annotate__", None)
        if annofunc is not None:
            fn.__annotate__ = None  # type: ignore[union-attr]
            try:
                spec = compat.inspect_getfullargspec(fn)
            finally:
                fn.__annotate__ = annofunc  # type: ignore[union-attr]
        else:
            spec = compat.inspect_getfullargspec(fn)

        # Do not generate code for annotations.
        # update_wrapper() copies the annotation from fn to decorated.
        # We use dummy defaults for code generation to avoid having
        # copy of large globals for compiling.
        # We copy __defaults__ and __kwdefaults__ from fn to decorated.
        empty_defaults = (None,) * len(spec.defaults or ())
        empty_kwdefaults = dict.fromkeys(spec.kwonlydefaults or ())
        spec = spec._replace(
            annotations={},
            defaults=empty_defaults,
            kwonlydefaults=empty_kwdefaults,
        )

        # every name already used by the generated function (its args,
        # *varargs / **varkw names and its own name) so that the globals
        # holding ``target`` and ``fn`` can't be shadowed by a parameter
        names = (
            tuple(cast("Tuple[str, ...]", spec[0]))
            + cast("Tuple[str, ...]", spec[1:3])
            + (fn.__name__,)
        )
        targ_name, fn_name = _unique_symbols(names, "target", "fn")

        metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
        metadata.update(format_argspec_plus(spec, grouped=False))
        metadata["name"] = fn.__name__

        if inspect.iscoroutinefunction(fn):
            metadata["prefix"] = "async "
            metadata["target_prefix"] = "await "
        else:
            metadata["prefix"] = ""
            metadata["target_prefix"] = ""

        # look for __ positional arguments. This is a convention in
        # SQLAlchemy that arguments should be passed positionally
        # rather than as keyword
        # arguments. note that apply_pos doesn't currently work in all cases
        # such as when a kw-only indicator "*" is present, which is why
        # we limit the use of this to just that case we can detect. As we add
        # more kinds of methods that use @decorator, things may have to
        # be further improved in this area
        # (the check is a substring test on the repr of the positional
        # argument name list)
        if "__" in repr(spec[0]):
            code = (
                """\
%(prefix)sdef %(name)s%(grouped_args)s:
    return %(target_prefix)s%(target)s(%(fn)s, %(apply_pos)s)
"""
                % metadata
            )
        else:
            code = (
                """\
%(prefix)sdef %(name)s%(grouped_args)s:
    return %(target_prefix)s%(target)s(%(fn)s, %(apply_kw)s)
"""
                % metadata
            )

        # exec globals: target and fn under their collision-free names;
        # __name__ makes the generated function's __module__ match fn's
        env: Dict[str, Any] = {
            targ_name: target,
            fn_name: fn,
            "__name__": fn.__module__,
        }

        decorated = cast(
            types.FunctionType,
            _exec_code_in_env(code, env, fn.__name__),
        )
        # replace the dummy (None) defaults used during codegen with fn's
        # real defaults
        decorated.__defaults__ = fn.__defaults__
        decorated.__kwdefaults__ = fn.__kwdefaults__  # type: ignore
        return update_wrapper(decorated, fn)  # type: ignore[return-value]

    return update_wrapper(decorate, target)  # type: ignore[return-value]
394
395
396def _exec_code_in_env(
397 code: Union[str, types.CodeType], env: Dict[str, Any], fn_name: str
398) -> Callable[..., Any]:
399 exec(code, env)
400 return env[fn_name] # type: ignore[no-any-return]
401
402
# NOTE(review): _PF and _TE are not referenced within this chunk;
# presumably used elsewhere in the module.
_PF = TypeVar("_PF")
_TE = TypeVar("_TE")
405
406
class PluginLoader:
    """Registry resolving plugin names to loaded objects.

    :meth:`load` consults, in order: loaders already in ``impls`` (explicit
    registrations and earlier resolutions), the optional ``auto_fn``, and
    the entry points of ``group`` returned by
    ``compat.importlib_metadata_get``.
    """

    def __init__(
        self, group: str, auto_fn: Optional[Callable[..., Any]] = None
    ):
        self.group = group
        # name -> zero-argument callable producing the plugin
        self.impls: Dict[str, Any] = {}
        self.auto_fn = auto_fn

    def clear(self):
        """Forget every registered and cached loader."""
        self.impls.clear()

    def load(self, name: str) -> Any:
        """Return the plugin registered under ``name``.

        A loader found through ``auto_fn`` or an entry point is cached in
        ``impls`` so later calls skip the search.

        :raises exc.NoSuchModuleError: if no source provides ``name``.
        """
        if name in self.impls:
            return self.impls[name]()

        auto_loader = self.auto_fn(name) if self.auto_fn else None
        if auto_loader:
            self.impls[name] = auto_loader
            return auto_loader()

        for entry_point in compat.importlib_metadata_get(self.group):
            if entry_point.name == name:
                self.impls[name] = entry_point.load
                return entry_point.load()

        raise exc.NoSuchModuleError(
            "Can't load plugin: %s:%s" % (self.group, name)
        )

    def register(self, name: str, modulepath: str, objname: str) -> None:
        """Register ``name`` to lazily import ``modulepath`` and return
        its attribute ``objname``."""

        def load():
            module = __import__(modulepath)
            # __import__("a.b.c") returns the top-level package "a", so
            # walk down to the leaf submodule
            for attr in modulepath.split(".")[1:]:
                module = getattr(module, attr)
            return getattr(module, objname)

        self.impls[name] = load

    def deregister(self, name: str) -> None:
        """Remove ``name``; raises KeyError if it is not registered."""
        self.impls.pop(name)
448
449
450def _inspect_func_args(fn):
451 try:
452 co_varkeywords = inspect.CO_VARKEYWORDS
453 except AttributeError:
454 # https://docs.python.org/3/library/inspect.html
455 # The flags are specific to CPython, and may not be defined in other
456 # Python implementations. Furthermore, the flags are an implementation
457 # detail, and can be removed or deprecated in future Python releases.
458 spec = compat.inspect_getfullargspec(fn)
459 return spec[0], bool(spec[2])
460 else:
461 # use fn.__code__ plus flags to reduce method call overhead
462 co = fn.__code__
463 nargs = co.co_argcount
464 return (
465 list(co.co_varnames[:nargs]),
466 bool(co.co_flags & co_varkeywords),
467 )
468
469
@overload
def get_cls_kwargs(
    cls: type,
    *,
    _set: Optional[Set[str]] = None,
    raiseerr: Literal[True] = ...,
) -> Set[str]: ...


@overload
def get_cls_kwargs(
    cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
) -> Optional[Set[str]]: ...


def get_cls_kwargs(
    cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
) -> Optional[Set[str]]:
    r"""Return the full set of inherited kwargs for the given `cls`.

    Probes a class's __init__ method, collecting all named arguments.  If the
    __init__ defines a \**kwargs catch-all, then the constructor is presumed
    to pass along unrecognized keywords to its base classes, and the
    collection process is repeated recursively on each of the bases.

    Uses a subset of inspect.getfullargspec() to cut down on method overhead,
    as this is used within the Core typing system to create copies of type
    objects which is a performance-sensitive operation.

    No anonymous tuple arguments please !

    :param cls: the class to probe.  Only an ``__init__`` defined directly
     in ``cls.__dict__`` as a plain Python function is inspected; classes
     without one are skipped and their bases examined instead.
    :param _set: internal accumulator shared by the recursive calls;
     callers leave it as None.
    :param raiseerr: intended to raise TypeError instead of returning None
     when a class reached during recursion has an ``__init__`` without
     ``**kwargs``.

    NOTE(review): ``raiseerr`` is only consulted when not at the top level,
    and the recursive call below does not forward it, so as written it
    never takes effect.  Also, the TypeError message speaks of a missing
    ``__init__``, while the branch is reached when ``__init__`` exists but
    lacks ``**kwargs``.  Confirm intent before relying on either.

    :return: at the top level, the set of collected argument names with
     ``self`` removed.  Recursive calls return None to tell the caller to
     stop walking further bases.

    """
    toplevel = _set is None
    if toplevel:
        _set = set()
    assert _set is not None

    # only the class's own __init__, not an inherited one
    ctr = cls.__dict__.get("__init__", False)

    # plain Python functions only; C-level slot wrappers can't be read
    # through __code__
    has_init = (
        ctr
        and isinstance(ctr, types.FunctionType)
        and isinstance(ctr.__code__, types.CodeType)
    )

    if has_init:
        names, has_kw = _inspect_func_args(ctr)
        _set.update(names)

        if not has_kw and not toplevel:
            if raiseerr:
                raise TypeError(
                    f"given cls {cls} doesn't have an __init__ method"
                )
            else:
                return None
    else:
        has_kw = False

    # continue into the bases when this class has no own __init__ or its
    # __init__ forwards **kwargs; stop at the first base whose recursive
    # call signals "no **kwargs" by returning None
    if not has_init or has_kw:
        for c in cls.__bases__:
            if get_cls_kwargs(c, _set=_set) is None:
                break

    _set.discard("self")
    return _set
536
537
def get_func_kwargs(func: Callable[..., Any]) -> List[str]:
    """Return the positional parameter names of ``func``.

    The names are the first field of ``compat.inspect_getfullargspec``, so
    keyword-only parameters are not included.  That helper accepts methods
    as well as plain functions.
    """
    arg_names, *_rest = compat.inspect_getfullargspec(func)
    return arg_names
547
548
def get_callable_argspec(
    fn: Callable[..., Any], no_self: bool = False, _is_init: bool = False
) -> compat.FullArgSpec:
    """Return the argument signature for any pure-Python callable.

    Accepted: functions, bound methods, classes (their ``__init__`` is
    inspected), objects exposing ``__func__``, and instances whose
    ``__call__`` is a bound method.

    With ``no_self=True``, the leading positional argument is dropped for an
    ``__init__`` reached through a class, and for methods whose
    ``__self__`` is truthy.

    :raises TypeError: for builtins and anything else that can't be
     reflected, such as ``functools.partial`` objects.
    """

    def _drop_first_positional(full):
        # identical spec, minus the leading positional (self) argument
        return compat.FullArgSpec(
            full.args[1:],
            full.varargs,
            full.varkw,
            full.defaults,
            full.kwonlyargs,
            full.kwonlydefaults,
            full.annotations,
        )

    if inspect.isbuiltin(fn):
        raise TypeError("Can't inspect builtin: %s" % fn)

    if inspect.isfunction(fn):
        spec = compat.inspect_getfullargspec(fn)
        if _is_init and no_self:
            return _drop_first_positional(spec)
        return spec

    if inspect.ismethod(fn):
        if no_self and (_is_init or fn.__self__):
            return _drop_first_positional(
                compat.inspect_getfullargspec(fn.__func__)
            )
        return compat.inspect_getfullargspec(fn.__func__)

    if inspect.isclass(fn):
        return get_callable_argspec(
            fn.__init__, no_self=no_self, _is_init=True
        )

    if hasattr(fn, "__func__"):
        return compat.inspect_getfullargspec(fn.__func__)

    if hasattr(fn, "__call__") and inspect.ismethod(fn.__call__):
        return get_callable_argspec(fn.__call__, no_self=no_self)

    raise TypeError("Can't inspect callable: %s" % fn)
603
604
def format_argspec_plus(
    fn: Union[Callable[..., Any], compat.FullArgSpec], grouped: bool = True
) -> Dict[str, Optional[str]]:
    """Returns a dictionary of formatted, introspected function arguments.

    An enhanced variant of inspect.formatargspec to support code generation.

    fn
       An inspectable callable, or an argspec tuple as returned by
       ``compat.inspect_getfullargspec()`` (anything non-callable is
       treated as the latter).
    grouped
      Defaults to True; include (parens, around, argument) lists.  When
      False, the outer parentheses are stripped from every value except
      ``grouped_args``.

    Returns:

    grouped_args
      Full inspect.formatargspec for fn (always parenthesized)
    self_arg
      The name of the first positional argument, the expression
      ``"<varargs>[0]"`` if there are only ``*varargs``, or None if the
      function defines no positional arguments.
    apply_pos
      args, re-written in calling rather than receiving syntax.  Arguments are
      passed positionally.
    apply_kw
      Like apply_pos, except keyword-ish args are passed as keywords.
    apply_pos_proxied
      Like apply_pos but omits the self/cls argument
    apply_kw_proxied
      Like apply_kw but omits the self/cls argument

    Example::

      >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
      {'grouped_args': '(self, a, b, c=3, **d)',
       'self_arg': 'self',
       'apply_pos': '(self, a, b, c, **d)',
       'apply_kw': '(self, a, b, c=c, **d)',
       'apply_pos_proxied': '(a, b, c, **d)',
       'apply_kw_proxied': '(a, b, c=c, **d)'}

    """
    if callable(fn):
        spec = compat.inspect_getfullargspec(fn)
    else:
        spec = fn

    # spec fields by index: [0] args, [1] varargs, [2] varkw,
    # [3] defaults, [4] kwonlyargs (then kwonlydefaults, annotations)
    args = compat.inspect_formatargspec(*spec)

    # call syntax: defaults omitted (None), so only names are rendered
    apply_pos = compat.inspect_formatargspec(
        spec[0], spec[1], spec[2], None, spec[4]
    )

    if spec[0]:
        self_arg = spec[0][0]

        apply_pos_proxied = compat.inspect_formatargspec(
            spec[0][1:], spec[1], spec[2], None, spec[4]
        )

    elif spec[1]:
        # I'm not sure what this is
        # (no named positional args but a *varargs: self is referred to as
        # the first element of varargs)
        self_arg = "%s[0]" % spec[1]

        apply_pos_proxied = apply_pos
    else:
        self_arg = None
        apply_pos_proxied = apply_pos

    # how many trailing names of name_args are rendered as ``name=name``:
    # positional args that have defaults plus every keyword-only arg
    # (keyword-only args can only be passed by keyword)
    num_defaults = 0
    if spec[3]:
        num_defaults += len(cast(Tuple[Any], spec[3]))
    if spec[4]:
        num_defaults += len(spec[4])

    name_args = spec[0] + spec[4]

    defaulted_vals: Union[List[str], Tuple[()]]

    if num_defaults:
        defaulted_vals = name_args[0 - num_defaults :]
    else:
        defaulted_vals = ()

    # the "default values" passed here are the argument names themselves,
    # rendered as "=name" by formatvalue
    apply_kw = compat.inspect_formatargspec(
        name_args,
        spec[1],
        spec[2],
        defaulted_vals,
        formatvalue=lambda x: "=" + str(x),
    )

    if spec[0]:
        apply_kw_proxied = compat.inspect_formatargspec(
            name_args[1:],
            spec[1],
            spec[2],
            defaulted_vals,
            formatvalue=lambda x: "=" + str(x),
        )
    else:
        apply_kw_proxied = apply_kw

    if grouped:
        return dict(
            grouped_args=args,
            self_arg=self_arg,
            apply_pos=apply_pos,
            apply_kw=apply_kw,
            apply_pos_proxied=apply_pos_proxied,
            apply_kw_proxied=apply_kw_proxied,
        )
    else:
        # strip the enclosing parentheses from the call-syntax forms
        return dict(
            grouped_args=args,
            self_arg=self_arg,
            apply_pos=apply_pos[1:-1],
            apply_kw=apply_kw[1:-1],
            apply_pos_proxied=apply_pos_proxied[1:-1],
            apply_kw_proxied=apply_kw_proxied[1:-1],
        )
720
721
def format_argspec_init(method, grouped=True):
    """format_argspec_plus with fallbacks for typical __init__ methods.

    ``object.__init__`` is rendered as ``(self)``.  Any ``__init__`` that
    format_argspec_plus can't reflect (raises TypeError, usually C code) is
    rendered as ``(self, *args, **kwargs)``:

    .. sourcecode:: text

        object.__init__ -> (self)
        other unreflectable (usually C) -> (self, *args, **kwargs)

    The returned dict has the same keys as format_argspec_plus.
    """
    if method is object.__init__:
        grouped_args = "(self)"
        apply_args, apply_proxied = (
            ("(self)", "()") if grouped else ("self", "")
        )
    else:
        try:
            return format_argspec_plus(method, grouped=grouped)
        except TypeError:
            grouped_args = "(self, *args, **kwargs)"
            apply_args, apply_proxied = (
                (grouped_args, "(*args, **kwargs)")
                if grouped
                else ("self, *args, **kwargs", "*args, **kwargs")
            )
    return {
        "self_arg": "self",
        "grouped_args": grouped_args,
        "apply_pos": apply_args,
        "apply_kw": apply_args,
        "apply_pos_proxied": apply_proxied,
        "apply_kw_proxied": apply_proxied,
    }
753
754
def create_proxy_methods(
    target_cls: Type[Any],
    target_cls_sphinx_name: str,
    proxy_cls_sphinx_name: str,
    classmethods: Sequence[str] = (),
    methods: Sequence[str] = (),
    attributes: Sequence[str] = (),
    use_intermediate_variable: Sequence[str] = (),
) -> Callable[[_T], _T]:
    """Marker class decorator naming attributes that proxy to another class.

    At runtime this does nothing: the decorated class is returned
    unchanged.  The arguments are read statically by
    tools/generate_proxy_methods.py, which generates proxy methods and
    attributes that typing tools such as mypy fully recognize.
    """

    def decorate(cls):
        # intentionally a pass-through; all work happens in the codegen tool
        return cls

    return decorate
778
779
def getargspec_init(method):
    """Return ``compat.inspect_getfullargspec(method)``, tolerating
    unreflectable ``__init__`` methods.

    When introspection raises TypeError, a plain 4-tuple
    ``(args, varargs, varkw, defaults)`` is returned instead:

    .. sourcecode:: text

        object.__init__                  -> (["self"], None, None, None)
        other unreflectable (usually C)  -> (["self"], "args", "kwargs", None)

    Note the fallback shape (4-tuple) differs from the full spec returned
    on success.
    """
    try:
        spec = compat.inspect_getfullargspec(method)
    except TypeError:
        if method is object.__init__:
            varargs = varkw = None
        else:
            varargs, varkw = "args", "kwargs"
        return (["self"], varargs, varkw, None)
    return spec
798
799
def unbound_method_to_callable(func_or_cls):
    """Return a callable that does not require a ``self`` argument.

    A ``MethodType`` bound to a falsy ``__self__`` is unwrapped to its
    ``__func__``.  Everything else is returned unchanged.
    """
    is_unbound_method = (
        isinstance(func_or_cls, types.MethodType)
        and not func_or_cls.__self__
    )
    return func_or_cls.__func__ if is_unbound_method else func_or_cls
810
811
def generic_repr(
    obj: Any,
    additional_kw: Sequence[Tuple[str, Any]] = (),
    to_inspect: Optional[Union[object, List[object]]] = None,
    omit_kwarg: Sequence[str] = (),
) -> str:
    """Produce a __repr__() based on direct association of the __init__()
    specification vs. same-named attributes present.

    :param obj: the object to render.
    :param additional_kw: extra ``(attribute, default)`` pairs rendered as
     ``attr=value`` when the attribute is present and differs from the
     default.
    :param to_inspect: object(s)/class(es) whose ``__init__`` signatures are
     consulted; defaults to ``[obj]``.  For the first entry, positional
     parameters without defaults (and its ``*args``) are rendered
     positionally.  For later entries, those parameters are rendered as
     ``name=value`` whenever ``obj`` has the attribute.  Defaulted
     parameters of every entry are rendered as ``name=value`` only when the
     value differs from the default.  Entries whose ``__init__`` can't be
     reflected (TypeError) are skipped.
    :param omit_kwarg: names never rendered from the inspected signatures.

    Keyword attributes that raise on access or comparison are skipped.
    """
    if to_inspect is None:
        to_inspect = [obj]
    else:
        to_inspect = _collections.to_list(to_inspect)

    # sentinel: "attribute absent" / "no default to compare against"
    missing = object()

    pos_args = []
    kw_args: _collections.OrderedDict[str, Any] = _collections.OrderedDict()
    vargs = None
    for i, insp in enumerate(to_inspect):
        try:
            spec = compat.inspect_getfullargspec(insp.__init__)
        except TypeError:
            continue

        default_len = len(spec.defaults) if spec.defaults else 0
        # non-defaulted positional parameters, excluding "self".  An
        # explicit end index is required: ``spec.args[1:-default_len]``
        # would be ``spec.args[1:0] == []`` when default_len is 0.
        required_args = spec.args[1 : len(spec.args) - default_len]

        if i == 0:
            if spec.varargs:
                vargs = spec.varargs
            pos_args.extend(required_args)
        else:
            kw_args.update([(arg, missing) for arg in required_args])

        if default_len:
            assert spec.defaults
            kw_args.update(
                [
                    (arg, default)
                    for arg, default in zip(
                        spec.args[-default_len:], spec.defaults
                    )
                ]
            )
    output: List[str] = []

    output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)

    if vargs is not None and hasattr(obj, vargs):
        output.extend([repr(val) for val in getattr(obj, vargs)])

    for arg, defval in kw_args.items():
        if arg in omit_kwarg:
            continue
        try:
            val = getattr(obj, arg, missing)
            if val is not missing and val != defval:
                output.append("%s=%r" % (arg, val))
        except Exception:
            # best-effort repr: skip attributes that can't be read/compared
            pass

    if additional_kw:
        for arg, defval in additional_kw:
            try:
                val = getattr(obj, arg, missing)
                if val is not missing and val != defval:
                    output.append("%s=%r" % (arg, val))
            except Exception:
                pass

    return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
888
889
def class_hierarchy(cls):
    """Return an unordered list of every class related to ``cls``.

    Walks both bases and subclasses, starting from ``cls.__mro__``, so
    diamond hierarchies are fully covered.

    Fibs slightly: subclasses of classes whose ``__module__`` is
    ``"builtins"`` are not followed.  Thus for ``class A(object)`` the
    result is ``A`` and ``object``, not ``A`` plus every class systemwide
    that derives from ``object``.
    """
    related = {cls}
    pending = list(cls.__mro__)

    while pending:
        current = pending.pop()

        for base in current.__bases__:
            if base not in related:
                pending.append(base)
                related.add(base)

        if current.__module__ == "builtins" or not hasattr(
            current, "__subclasses__"
        ):
            continue

        # on subclasses of ``type``, ``__subclasses__`` looked up on the
        # class resolves to the unbound ``type.__subclasses__``, so the
        # class is passed explicitly
        if issubclass(current, type):
            children = current.__subclasses__(current)
        else:
            children = current.__subclasses__()
        unseen = [child for child in children if child not in related]
        for child in unseen:
            pending.append(child)
            related.add(child)

    return list(related)
926
927
def iterate_attributes(cls):
    """Yield ``(name, raw_value)`` for every name in ``dir(cls)``.

    Each value is taken from the ``__dict__`` of the first class in
    ``cls.__mro__`` that defines it, never through ``getattr()``.  As a
    result, descriptors such as properties are returned as-is rather than
    invoked.  Names not present in any ``__dict__`` along the MRO are
    skipped.
    """
    for key in dir(cls):
        owner = next((klass for klass in cls.__mro__ if key in klass.__dict__), None)
        if owner is not None:
            yield (key, owner.__dict__[key])
942
943
def monkeypatch_proxied_specials(
    into_cls,
    from_cls,
    skip=None,
    only=None,
    name="self.proxy",
    from_instance=None,
):
    """Automates delegation of __specials__ for a proxying type.

    For each selected dunder of ``from_cls``, a function of the form
    ``def __x__(<params>): return <name>.__x__(<params minus the first>)``
    is generated with ``exec`` and set on ``into_cls``.

    :param into_cls: class receiving the generated methods.
    :param from_cls: class whose dunder methods are mirrored.
    :param skip: dunder names to leave alone.  Defaults to a fixed tuple
     (``__slots__``, ``__del__``, ``__getattribute__``, ``__metaclass__``,
     ``__getstate__``, ``__setstate__``).  Ignored when ``only`` is given.
    :param only: explicit dunder names to delegate.  When given, neither
     ``skip`` nor the "already present on into_cls" filter is applied.
    :param name: source expression evaluated inside each generated method
     to obtain the proxied object (default ``"self.proxy"``).
    :param from_instance: if not None, placed in the exec globals under the
     key ``name``.  NOTE(review): for the generated code to resolve it,
     ``name`` presumably must then be a plain identifier; confirm with
     callers.
    """

    if only:
        dunders = only
    else:
        if skip is None:
            skip = (
                "__slots__",
                "__del__",
                "__getattribute__",
                "__metaclass__",
                "__getstate__",
                "__setstate__",
            )
        # dunders of from_cls that into_cls doesn't already have
        dunders = [
            m
            for m in dir(from_cls)
            if (
                m.startswith("__")
                and m.endswith("__")
                and not hasattr(into_cls, m)
                and m not in skip
            )
        ]

    for method in dunders:
        try:
            maybe_fn = getattr(from_cls, method)
            # non-callable dunders (plain data) are not delegated
            if not hasattr(maybe_fn, "__call__"):
                continue
            maybe_fn = getattr(maybe_fn, "__func__", maybe_fn)
            fn = cast(types.FunctionType, maybe_fn)

        except AttributeError:
            continue
        try:
            # only positional parameter names (spec[0]) are reproduced;
            # the first one is dropped from the delegated call since the
            # call goes through ``name`` instead
            spec = compat.inspect_getfullargspec(fn)
            fn_args = compat.inspect_formatargspec(spec[0])
            d_args = compat.inspect_formatargspec(spec[0][1:])
        except TypeError:
            # unreflectable (e.g. C-implemented) methods get a generic
            # pass-through signature
            fn_args = "(self, *args, **kw)"
            d_args = "(*args, **kw)"

        # the template is filled from the local names method, fn_args,
        # name and d_args
        py = (
            "def %(method)s%(fn_args)s: "
            "return %(name)s.%(method)s%(d_args)s" % locals()
        )

        env: Dict[str, types.FunctionType] = (
            from_instance is not None and {name: from_instance} or {}
        )
        exec(py, env)
        try:
            env[method].__defaults__ = fn.__defaults__
        except AttributeError:
            pass
        setattr(into_cls, method, env[method])
1009
1010
def methods_equivalent(meth1, meth2):
    """Return True if both arguments share the same underlying function.

    Bound methods are compared by their ``__func__``; anything without
    ``__func__`` is compared by identity.
    """

    def _underlying(meth):
        return getattr(meth, "__func__", meth)

    return _underlying(meth1) is _underlying(meth2)
1017
1018
def as_interface(obj, cls=None, methods=None, required=None):
    """Ensure basic interface compliance for an instance or dict of callables.

    The interface is ``methods`` when given; otherwise it is every public
    (non-underscore) name in ``dir(cls)``.  ``obj`` passes when its
    attributes cover the requirement:

    * all of ``required`` if supplied (a *type* passed as ``required``
      means the whole interface is required);
    * otherwise at least one interface name.

    An instance of ``cls`` always passes, ignoring ``required``.

    If ``obj`` is a ``dict`` that does not pass on its attributes, its
    items (obtained through ``dictlike_iteritems``) are treated as
    ``name -> callable``.  Every key must belong to the interface and
    every value must be callable.  If the keys meet the requirement, a new
    anonymous class holding the callables as static methods is returned.

    :param obj: a type, instance, or dictionary of callables.
    :param cls: optional type whose public methods form the interface.
    :param methods: optional sequence of method names forming the interface.
    :param required: optional sequence of mandatory names, or a type.
    :return: ``obj`` itself, or the anonymous class built from a dict.
    :raises TypeError: if neither ``cls`` nor ``methods`` is given, or if
     ``obj`` does not meet the interface criteria.
    """
    if not cls and not methods:
        raise TypeError("a class or collection of method names are required")

    if isinstance(cls, type) and isinstance(obj, cls):
        return obj

    interface = set(methods or [m for m in dir(cls) if not m.startswith("_")])
    implemented = set(dir(obj))

    if isinstance(required, type):
        required, complies = interface, operator.ge
    elif required:
        required, complies = set(required), operator.ge
    else:
        # nothing mandatory: a strict superset of the empty set, i.e. at
        # least one interface name, is enough
        required, complies = set(), operator.gt

    if complies(implemented & interface, required):
        return obj

    # No dict duck typing here.
    if not isinstance(obj, dict):
        qualifier = "any of" if complies is operator.gt else "all of"
        raise TypeError(
            "%r does not implement %s: %s"
            % (obj, qualifier, ", ".join(interface))
        )

    class AnonymousInterface:
        """A callable-holding shell."""

    if cls:
        AnonymousInterface.__name__ = "Anonymous" + cls.__name__

    provided = set()
    for method, impl in dictlike_iteritems(obj):
        if method not in interface:
            raise TypeError("%r: unknown in this interface" % method)
        if not callable(impl):
            raise TypeError("%r=%r is not callable" % (method, impl))
        setattr(AnonymousInterface, method, staticmethod(impl))
        provided.add(method)

    if not complies(provided, required):
        raise TypeError(
            "dictionary does not contain required keys %s"
            % ", ".join(required - provided)
        )
    return AnonymousInterface
1103
1104
# self-type for generic_fn_descriptor.__get__ overloads: class-level access
# returns the descriptor itself, preserving its concrete subclass type
_GFD = TypeVar("_GFD", bound="generic_fn_descriptor[Any]")
1106
1107
class generic_fn_descriptor(Generic[_T_co]):
    """Descriptor which proxies a function when the attribute is not
    present in dict

    This superclass is organized in a particular way with "memoized" and
    "non-memoized" implementation classes that are hidden from type checkers,
    as Mypy seems to not be able to handle seeing multiple kinds of descriptor
    classes used for the same attribute.

    At runtime this base only stores ``fget``, a ``__doc__`` (``doc`` if
    given, else ``fget.__doc__``) and ``fget.__name__``.  Its ``__get__``,
    ``_reset`` and ``reset`` raise NotImplementedError; subclasses such as
    ``_memoized_property`` supply working versions.

    """

    # the wrapped getter, called with the instance
    fget: Callable[..., _T_co]
    __doc__: Optional[str]
    __name__: str

    def __init__(self, fget: Callable[..., _T_co], doc: Optional[str] = None):
        self.fget = fget
        self.__doc__ = doc or fget.__doc__
        self.__name__ = fget.__name__

    # typing overloads: access on the class (obj is None) yields the
    # descriptor itself; access on an instance yields the value
    @overload
    def __get__(self: _GFD, obj: None, cls: Any) -> _GFD: ...

    @overload
    def __get__(self, obj: object, cls: Any) -> _T_co: ...

    def __get__(self: _GFD, obj: Any, cls: Any) -> Union[_GFD, _T_co]:
        raise NotImplementedError()

    if TYPE_CHECKING:
        # declared for type checkers only (TYPE_CHECKING is False at
        # runtime), so that assignment / deletion of the attribute
        # type-checks without making this a data descriptor at runtime

        def __set__(self, instance: Any, value: Any) -> None: ...

        def __delete__(self, instance: Any) -> None: ...

    def _reset(self, obj: Any) -> None:
        raise NotImplementedError()

    @classmethod
    def reset(cls, obj: Any, name: str) -> None:
        raise NotImplementedError()
1149
1150
class _non_memoized_property(generic_fn_descriptor[_T_co]):
    """Plain, non-caching descriptor that forwards every instance-level
    attribute access to the wrapped function.

    It exists so that a non-memoized attribute shares the same descriptor
    base as memoized_property, which lets mypy treat the two as
    equivalent.

    """

    if not TYPE_CHECKING:

        def __get__(self, obj, cls):
            # class-level access hands back the descriptor itself
            return self if obj is None else self.fget(obj)
1166
1167
class _memoized_property(generic_fn_descriptor[_T_co]):
    """A read-only @property whose getter runs only once per instance.

    The first instance-level access stores the computed value in the
    instance ``__dict__`` under the function's name. That entry then
    shadows this (non-data) descriptor on later lookups.

    """

    if not TYPE_CHECKING:

        def __get__(self, obj, cls):
            if obj is not None:
                value = self.fget(obj)
                obj.__dict__[self.__name__] = value
                return value
            return self

    def _reset(self, obj):
        # deliberately dispatches to this class's reset, not type(self)'s
        _memoized_property.reset(obj, self.__name__)

    @classmethod
    def reset(cls, obj, name):
        # tolerate the value never having been computed
        obj.__dict__.pop(name, None)
1185
1186
1187# despite many attempts to get Mypy to recognize an overridden descriptor
1188# where one is memoized and the other isn't, there seems to be no reliable
1189# way other than completely deceiving the type checker into thinking there
1190# is just one single descriptor type everywhere. Otherwise, if a superclass
# has non-memoized and subclass has memoized, that requires
# "class memoized(non_memoized)". But then if a superclass has memoized and
# a subclass has non-memoized, the class hierarchy of the descriptors
# would need to be reversed; "class non_memoized(memoized)". So there's no
# way to satisfy both.
1196# additional issues, RO properties:
1197# https://github.com/python/mypy/issues/12440
if TYPE_CHECKING:
    # allow memoized and non-memoized to be freely mixed by having them
    # be the same class
    memoized_property = generic_fn_descriptor
    non_memoized_property = generic_fn_descriptor

    # for read only situations, mypy only sees @property as read only.
    # read only is needed when a subtype specializes the return type
    # of a property, meaning assignment needs to be disallowed
    ro_memoized_property = property
    ro_non_memoized_property = property

else:
    # at runtime the "ro_" names are the very same descriptor classes as
    # their non-"ro_" counterparts; the read-only distinction exists for
    # the type checker only.
    memoized_property = ro_memoized_property = _memoized_property
    non_memoized_property = ro_non_memoized_property = _non_memoized_property
1213
1214
def memoized_instancemethod(fn: _F) -> _F:
    """Decorate a method so that its return value is memoized per instance.

    The first call runs ``fn``. It then installs a replacement callable in
    the instance ``__dict__`` under ``fn.__name__``, and that callable
    returns the first result. Best applied to no-arg methods: memoization
    is not sensitive to argument values, and will always return the same
    value even when called with different arguments.

    """

    def oneshot(self, *args, **kw):
        value = fn(self, *args, **kw)

        def replay(*a, **kw):
            return value

        replay.__name__ = fn.__name__
        replay.__doc__ = fn.__doc__
        # the instance attribute now shadows the class-level oneshot
        self.__dict__[fn.__name__] = replay
        return value

    return update_wrapper(oneshot, fn)  # type: ignore
1236
1237
class HasMemoized:
    """Mixin that records which attributes have been memoized on an
    instance, so that all of them can be discarded in one step (e.g. for
    cache clearing or generative copies).

    The tracked names live in ``_memoized_keys``; the values themselves are
    stored in the instance ``__dict__``.

    """

    if not TYPE_CHECKING:
        # support classes that want to have __slots__ with an explicit
        # slot for __dict__. not sure if that requires base __slots__ here.
        __slots__ = ()

    _memoized_keys: FrozenSet[str] = frozenset()

    def _reset_memoizations(self) -> None:
        """Remove every tracked memoized value from the instance."""
        for name in self._memoized_keys:
            self.__dict__.pop(name, None)

    def _assert_no_memoizations(self) -> None:
        """Assert that none of the tracked names currently hold a value."""
        assert not any(name in self.__dict__ for name in self._memoized_keys)

    def _set_memoized_attribute(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` and track ``key`` as memoized."""
        self.__dict__[key] = value
        self._memoized_keys |= {key}

    class memoized_attribute(memoized_property[_T]):
        """A read-only @property that is only evaluated once; the attribute
        name is also added to the owning instance's ``_memoized_keys``.

        :meta private:

        """

        fget: Callable[..., _T]
        __doc__: Optional[str]
        __name__: str

        def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None):
            self.fget = fget
            self.__doc__ = doc or fget.__doc__
            self.__name__ = fget.__name__

        @overload
        def __get__(self: _MA, obj: None, cls: Any) -> _MA: ...

        @overload
        def __get__(self, obj: Any, cls: Any) -> _T: ...

        def __get__(self, obj, cls):
            if obj is not None:
                value = self.fget(obj)
                obj.__dict__[self.__name__] = value
                obj._memoized_keys |= {self.__name__}
                return value
            return self

    @classmethod
    def memoized_instancemethod(cls, fn: _F) -> _F:
        """Decorate a method so its first return value is memoized and its
        name is tracked in ``_memoized_keys``.

        :meta private:

        """

        def oneshot(self: Any, *args: Any, **kw: Any) -> Any:
            value = fn(self, *args, **kw)

            def replay(*a, **kw):
                return value

            replay.__name__ = fn.__name__
            replay.__doc__ = fn.__doc__
            self.__dict__[fn.__name__] = replay
            self._memoized_keys |= {fn.__name__}
            return value

        return update_wrapper(oneshot, fn)  # type: ignore
1313
1314
# Type checkers see a plain ``property`` (which mypy treats as read-only);
# at runtime this is the memoizing ``HasMemoized.memoized_attribute``.
if TYPE_CHECKING:
    HasMemoized_ro_memoized_attribute = property
else:
    HasMemoized_ro_memoized_attribute = HasMemoized.memoized_attribute
1319
1320
class MemoizedSlots:
    """Provide memoized attributes and methods through ``__getattr__``.

    Subclasses define ``_memoized_attr_<name>()`` to compute attribute
    ``<name>`` once. They define ``_memoized_method_<name>()`` for a method
    whose first result is replayed on later calls. Results are stored with
    ``setattr``, so this works for classes using ``__slots__``.

    """

    __slots__ = ()

    def _fallback_getattr(self, key):
        raise AttributeError(key)

    def __getattr__(self, key: str) -> Any:
        # never try to memoize the memoization hooks themselves
        if key.startswith(("_memoized_attr_", "_memoized_method_")):
            raise AttributeError(key)

        owner = self.__class__

        # to avoid recursion errors when interacting with other __getattr__
        # schemes that refer to this one, when testing for memoized method
        # look at __class__ only rather than going into __getattr__ again.
        attr_hook = f"_memoized_attr_{key}"
        if hasattr(owner, attr_hook):
            value = getattr(self, attr_hook)()
            setattr(self, key, value)
            return value

        method_hook = f"_memoized_method_{key}"
        if not hasattr(owner, method_hook):
            return self._fallback_getattr(key)

        fn = getattr(self, method_hook)

        def oneshot(*args, **kw):
            result = fn(*args, **kw)

            def memo(*a, **kw):
                return result

            memo.__name__ = fn.__name__
            memo.__doc__ = fn.__doc__
            setattr(self, key, memo)
            return result

        oneshot.__doc__ = fn.__doc__
        return oneshot
1364
1365
1366# from paste.deploy.converters
def asbool(obj: Any) -> bool:
    """Coerce ``obj`` to a boolean.

    Strings are matched case-insensitively, after stripping whitespace,
    against common true/false spellings. An unrecognized string raises
    ``ValueError``. Any non-string goes through ``bool()``.
    """
    if not isinstance(obj, str):
        return bool(obj)

    text = obj.strip().lower()
    if text in ("true", "yes", "on", "y", "t", "1"):
        return True
    if text in ("false", "no", "off", "n", "f", "0"):
        return False
    raise ValueError("String is not true/false: %r" % text)
1377
1378
def bool_or_str(*text: str) -> Callable[[str], Union[str, bool]]:
    """Return a converter for string options.

    The converter returns its input unchanged when it is one of the given
    ``text`` values. Otherwise it passes the input through :func:`asbool`.
    """

    def bool_or_value(obj: str) -> Union[str, bool]:
        return obj if obj in text else asbool(obj)

    return bool_or_value
1392
1393
def asint(value: Any) -> Optional[int]:
    """Coerce ``value`` with ``int()``; ``None`` is passed through."""
    return None if value is None else int(value)
1400
1401
def coerce_kw_type(
    kw: Dict[str, Any],
    key: str,
    type_: Type[Any],
    flexi_bool: bool = True,
    dest: Optional[Dict[str, Any]] = None,
) -> None:
    r"""Coerce ``kw[key]`` to ``type_`` when needed, writing into ``dest``.

    Nothing happens when ``key`` is absent or its value is ``None``. It
    also does nothing when ``type_`` is a class and the value is already an
    instance of it. With ``flexi_bool`` and ``type_ is bool``, strings such
    as ``'0'`` / ``'no'`` are interpreted via :func:`asbool`. ``dest``
    defaults to ``kw`` itself.
    """
    target = kw if dest is None else dest

    if key not in kw:
        return
    if isinstance(type_, type) and isinstance(kw[key], type_):
        return
    if kw[key] is None:
        return

    if flexi_bool and type_ is bool:
        target[key] = asbool(kw[key])
    else:
        target[key] = type_(kw[key])
1426
1427
def constructor_key(obj: Any, cls: Type[Any]) -> Tuple[Any, ...]:
    """Build a hashable cache key for ``obj`` viewed as a ``cls``.

    The key is ``cls`` followed by ``(name, value)`` pairs. There is one
    pair for each of ``cls``'s constructor keyword arguments (per
    :func:`get_cls_kwargs`) present in ``obj.__dict__``.

    """
    kwarg_names = get_cls_kwargs(cls)
    matched = [
        (name, obj.__dict__[name])
        for name in kwarg_names
        if name in obj.__dict__
    ]
    return (cls, *matched)
1437
1438
def constructor_copy(obj: _T, cls: Type[_T], *args: Any, **kw: Any) -> _T:
    """Instantiate ``cls``, filling constructor keyword arguments from
    ``obj.__dict__``.

    Uses inspect (via :func:`get_cls_kwargs`) to match the named arguments
    of ``cls``. Values passed explicitly in ``kw`` take precedence over
    values found on ``obj``.

    """
    for name in get_cls_kwargs(cls).difference(kw):
        if name in obj.__dict__:
            kw[name] = obj.__dict__[name]
    return cls(*args, **kw)
1451
1452
def counter() -> Callable[[], int]:
    """Return a threadsafe counter function.

    Each call to the returned function yields the next integer, starting
    at 1. A lock serializes access to the shared iterator.
    """
    guard = threading.Lock()
    sequence = itertools.count(1)

    def _next():
        with guard:
            return next(sequence)

    return _next
1465
1466
def duck_type_collection(
    specimen: Any, default: Optional[Type[Any]] = None
) -> Optional[Type[Any]]:
    """Guess which basic collection type (list, set or dict) an instance or
    class is, or is acting as.

    The checks run in this order; the first match wins:

    1. An ``__emulates__`` attribute. It is returned as-is, except that any
       ``set`` subclass is canonicalized to ``set``.
    2. Real ``list`` / ``set`` / ``dict`` instances or subclasses.
    3. Duck typing: ``append`` -> list, ``add`` -> set, ``set`` -> dict.

    ``default`` is returned when nothing matches.
    """
    if hasattr(specimen, "__emulates__"):
        emulated = specimen.__emulates__
        # canonicalize set vs sets.Set to a standard: the builtin set
        if emulated is not None and issubclass(emulated, set):
            return set
        return emulated  # type: ignore

    check = issubclass if isinstance(specimen, type) else isinstance
    for builtin in (list, set, dict):
        if check(specimen, builtin):
            return builtin

    for method_name, kind in (("append", list), ("add", set), ("set", dict)):
        if hasattr(specimen, method_name):
            return kind

    return default
1500
1501
def assert_arg_type(
    arg: Any, argtype: Union[Tuple[Type[Any], ...], Type[Any]], name: str
) -> Any:
    """Return ``arg`` if it is an instance of ``argtype`` (a type or tuple of
    types); otherwise raise :class:`.exc.ArgumentError` naming ``name``.
    """
    if not isinstance(arg, argtype):
        if isinstance(argtype, tuple):
            raise exc.ArgumentError(
                "Argument '%s' is expected to be one of type %s, got '%s'"
                % (name, " or ".join("'%s'" % t for t in argtype), type(arg))
            )
        raise exc.ArgumentError(
            "Argument '%s' is expected to be of type '%s', got '%s'"
            % (name, argtype, type(arg))
        )
    return arg
1518
1519
def dictlike_iteritems(dictlike):
    """Return the ``(key, value)`` pairs of almost any dict-like object.

    * Objects with ``items()``: a list built from ``items()``.
    * Otherwise, values are read with ``__getitem__``, falling back to
      ``get``. Keys come from ``iterkeys()`` if present, else ``keys()``.
      The pairs are returned as an iterator.

    :raises TypeError: if no value getter or no key source is available.
    """
    if hasattr(dictlike, "items"):
        return list(dictlike.items())

    getter = getattr(dictlike, "__getitem__", getattr(dictlike, "get", None))
    if getter is None:
        # the 1-tuple keeps a tuple ``dictlike`` from being unpacked as the
        # %-format argument tuple itself
        raise TypeError("Object '%r' is not dict-like" % (dictlike,))

    if hasattr(dictlike, "iterkeys"):

        def iterator():
            for key in dictlike.iterkeys():
                assert getter is not None
                yield key, getter(key)

        return iterator()
    elif hasattr(dictlike, "keys"):
        return iter((key, getter(key)) for key in dictlike.keys())
    else:
        raise TypeError("Object '%r' is not dict-like" % (dictlike,))
1542
1543
class classproperty(property):
    """Like ``@property``, but the getter is called with the class instead
    of an instance, whether accessed on the class or on an instance.

    The decorator is currently special when using the declarative
    module, but note that the
    :class:`~.sqlalchemy.ext.declarative.declared_attr`
    decorator should be used for this purpose with declarative.

    """

    fget: Callable[[Any], Any]

    def __init__(self, fget: Callable[[Any], Any], *arg: Any, **kw: Any):
        super().__init__(fget, *arg, **kw)
        # expose the getter's docstring on the descriptor
        self.__doc__ = fget.__doc__

    def __get__(self, obj: Any, cls: Optional[type] = None) -> Any:
        # ``obj`` is ignored on purpose: the getter always sees the owner
        getter = self.fget
        return getter(cls)
1563
1564
class hybridproperty(Generic[_T]):
    """Read-only descriptor with separate class-level and instance-level
    getters.

    By default the same function serves both. :meth:`classlevel` registers
    a different function for access on the class.
    """

    def __init__(self, func: Callable[..., _T]):
        self.func = func
        self.clslevel = func

    def __get__(self, instance: Any, owner: Any) -> _T:
        if instance is not None:
            return self.func(instance)
        return self.clslevel(owner)

    def classlevel(self, func: Callable[..., Any]) -> hybridproperty[_T]:
        self.clslevel = func
        return self
1580
1581
class rw_hybridproperty(Generic[_T]):
    """Like ``hybridproperty``, plus an instance-level setter registered via
    :meth:`setter`.
    """

    def __init__(self, func: Callable[..., _T]):
        self.func = func
        self.clslevel = func
        self.setfn: Optional[Callable[..., Any]] = None

    def __get__(self, instance: Any, owner: Any) -> _T:
        if instance is not None:
            return self.func(instance)
        return self.clslevel(owner)

    def __set__(self, instance: Any, value: Any) -> None:
        # assigning without a registered setter trips this assertion
        assert self.setfn is not None
        self.setfn(instance, value)

    def setter(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
        self.setfn = func
        return self

    def classlevel(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
        self.clslevel = func
        return self
1606
1607
class hybridmethod(Generic[_T]):
    """Decorate a function as cls- or instance- level.

    On an instance, the function is bound to that instance. On the class,
    the class-level function (by default the same one, or one registered
    through :meth:`classlevel`) is bound to the class itself.
    """

    def __init__(self, func: Callable[..., _T]):
        self.func = func
        self.__func__ = func
        self.clslevel = func

    def __get__(self, instance: Any, owner: Any) -> Callable[..., _T]:
        if instance is not None:
            return self.func.__get__(instance, owner)  # type:ignore
        # bind the class-level function to the owner class itself
        return self.clslevel.__get__(owner, owner.__class__)  # type:ignore

    def classlevel(self, func: Callable[..., Any]) -> hybridmethod[_T]:
        self.clslevel = func
        return self
1624
1625
class symbol(int):
    """A constant symbol.

    >>> symbol("foo") is symbol("foo")
    True
    >>> symbol("foo")
    symbol('foo')

    A slight refinement of the MAGICCOOKIE=object() pattern. The primary
    advantage of symbol() is its repr(). They are also singletons.

    Repeated calls of symbol('name') will all return the same instance.

    :param name: the symbol's name; existing symbols are looked up by it.
    :param doc: optional docstring, applied only when the symbol is first
     created.
    :param canonical: the symbol's integer value. It defaults to
     ``hash(name)`` on creation. Passing a different non-zero value for an
     already-existing name raises ``TypeError``.

    """

    name: str

    symbols: Dict[str, symbol] = {}
    _lock = threading.Lock()

    def __new__(
        cls,
        name: str,
        doc: Optional[str] = None,
        canonical: Optional[int] = None,
    ) -> symbol:
        with cls._lock:
            sym = cls.symbols.get(name)
            if sym is None:
                assert isinstance(name, str)
                if canonical is None:
                    canonical = hash(name)
                sym = int.__new__(symbol, canonical)
                sym.name = name
                if doc:
                    sym.__doc__ = doc

                # NOTE: we should ultimately get rid of this global thing,
                # however, currently it is to support pickling. The best
                # change would be when we are on py3.11 at a minimum, we
                # switch to stdlib enum.IntFlag.
                cls.symbols[name] = sym
            else:
                # a falsy canonical (None or 0) skips this consistency check
                if canonical and canonical != sym:
                    raise TypeError(
                        f"Can't replace canonical symbol for {name!r} "
                        f"with new int value {canonical}"
                    )
            return sym

    def __reduce__(self):
        # Unpickling looks the symbol up by name (returning the existing
        # instance when present) and checks its int value. No docstring is
        # passed; a placeholder here would become the __doc__ of a symbol
        # recreated in a fresh process.
        return symbol, (self.name, None, int(self))

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return f"symbol({self.name!r})"
1684
1685
class _IntFlagMeta(type):
    def __init__(
        cls,
        classname: str,
        bases: Tuple[Type[Any], ...],
        dict_: Dict[str, Any],
        **kw: Any,
    ) -> None:
        """Turn each integer-valued class attribute of the new class into a
        ``symbol`` member, listed in ``_items`` and ``__members__``."""
        members: List[symbol] = []
        cls._items = members
        for attr_name, attr_value in dict_.items():
            # dunders (__module__, __qualname__, __doc__, ...) never count
            if re.match(r"^__.*__$", attr_name):
                continue
            if not isinstance(attr_value, int):
                # non-integer private names are tolerated and left alone
                if attr_name.startswith("_"):
                    continue
                raise TypeError("Expected integer values for IntFlag")
            member = symbol(attr_name, canonical=attr_value)
            setattr(cls, attr_name, member)
            members.append(member)

        cls.__members__ = _collections.immutabledict(
            {sym.name: sym for sym in members}
        )

    def __iter__(self) -> Iterator[symbol]:
        raise NotImplementedError(
            "iter not implemented to ensure compatibility with "
            "Python 3.11 IntFlag. Please use __members__. See "
            "https://github.com/python/cpython/issues/99304"
        )
1718
1719
class _FastIntFlag(metaclass=_IntFlagMeta):
    """An 'IntFlag' copycat that isn't slow when performing bitwise
    operations.

    Integer-valued attributes of a subclass are converted by
    ``_IntFlagMeta`` into ``symbol`` objects. ``symbol`` subclasses
    ``int``, so bitwise operations on members are plain ``int``
    operations.

    The ``FastIntFlag`` name resolves to ``enum.IntFlag`` under
    TYPE_CHECKING and to ``_FastIntFlag`` otherwise.

    """


# type checkers see the real stdlib IntFlag; runtime uses the fast copycat
if TYPE_CHECKING:
    from enum import IntFlag

    FastIntFlag = IntFlag
else:
    FastIntFlag = _FastIntFlag
1736
1737
_E = TypeVar("_E", bound=enum.Enum)


def parse_user_argument_for_enum(
    arg: Any,
    choices: Dict[_E, List[Any]],
    name: str,
    resolve_symbol_names: bool = False,
) -> Optional[_E]:
    """Map a user-supplied parameter onto one of a set of choices,
    typically Enum members.

    Each choice is tried in ``choices`` order. It matches when ``arg`` is
    the enum value itself, or, with ``resolve_symbol_names``, equals its
    ``.name``. It also matches when ``arg`` is among the value's listed
    alternates (e.g. True/False/None).

    :param arg: the user argument.
    :param choices: dictionary of enum values to lists of possible
     entries for each.
    :param name: name of the argument. Used in an :class:`.ArgumentError`
     that is raised if the parameter doesn't match any available argument.
    :param resolve_symbol_names: also accept the enum member's name
     as a string.
    :return: the matching enum value, or ``None`` if ``arg`` is ``None``
     and no choice lists ``None`` as an alternate.

    """
    for enum_value, alternates in choices.items():
        matched = (
            arg is enum_value
            or (resolve_symbol_names and arg == enum_value.name)
            or arg in alternates
        )
        if matched:
            return enum_value

    if arg is None:
        return None

    raise exc.ArgumentError(f"Invalid value for '{name}': {arg!r}")
1773
1774
_creation_order = 1


def set_creation_order(instance: Any) -> None:
    """Stamp ``instance._creation_order`` with the next value of a
    module-wide sequence.

    This lets instances be sorted by creation order. That ordering is
    reliable within a single thread; the counter is a plain global and is
    not particularly threadsafe.

    """
    global _creation_order
    assigned = _creation_order
    instance._creation_order = assigned
    _creation_order = assigned + 1
1789
1790
def warn_exception(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call ``func(*args, **kwargs)`` and return its result.

    Any ``Exception`` it raises is turned into a warning, and ``None`` is
    returned instead.

    """
    try:
        return func(*args, **kwargs)
    except Exception as err:
        warn("%s('%s') ignored" % (type(err), err))
1800
1801
def ellipses_string(value, len_=25):
    """Truncate ``value`` to its first ``len_`` items, rendering the result
    as a string with a trailing ``"..."``.

    Values no longer than ``len_``, and values without ``len()`` or slicing
    support (``TypeError``), are returned unchanged.
    """
    try:
        if len(value) > len_:
            # wrap the slice in a 1-tuple so that a tuple ``value`` is
            # formatted as one object, not unpacked as %-format arguments
            return "%s..." % (value[0:len_],)
        else:
            return value
    except TypeError:
        return value
1810
1811
class _hash_limit_string(str):
    """A ``str`` subclass whose hash, and therefore equality, is bucketed
    into at most ``num`` distinct values per message template.

    This is used for warnings so that we can send out parameterized warnings
    without the __warningregistry__ of the module, or the non-overridable
    "once" registry within warnings.py, overloading memory.

    """

    _hash: int

    def __new__(
        cls, value: str, num: int, args: Sequence[Any]
    ) -> _hash_limit_string:
        rendered = value % args
        rendered += (
            " (this warning may be suppressed after %d occurrences)" % num
        )
        instance = super().__new__(cls, rendered)
        # one of ``num`` buckets, keyed by the unformatted template
        bucket = hash(rendered) % num
        instance._hash = hash("%s_%d" % (value, bucket))
        return instance

    def __hash__(self) -> int:
        return self._hash

    def __eq__(self, other: Any) -> bool:
        # equality is defined purely by the bucketed hash
        return hash(self) == hash(other)
1840
1841
def warn(msg: str, code: Optional[str] = None) -> None:
    """Issue ``msg`` as an :class:`.exc.SAWarning`.

    When ``code`` is given, an ``SAWarning`` instance carrying that code is
    emitted. Otherwise the plain string is emitted with ``SAWarning`` as
    the category.

    """
    if not code:
        _warnings_warn(msg, exc.SAWarning)
    else:
        _warnings_warn(exc.SAWarning(msg, code=code))
1853
1854
def warn_limited(msg: str, args: Sequence[Any]) -> None:
    """Issue an :class:`.exc.SAWarning` built from ``msg % args``.

    When ``args`` is non-empty, the message is wrapped in a
    ``_hash_limit_string`` with 10 buckets. This bounds how many distinct
    entries the warnings registries can accumulate for this template.

    """
    text = _hash_limit_string(msg, 10, args) if args else msg
    _warnings_warn(text, exc.SAWarning)
1863
1864
# code object -> (suffix appended to warning messages, warning category)
_warning_tags: Dict[CodeType, Tuple[str, Type[Warning]]] = {}


def tag_method_for_warnings(
    message: str, category: Type[Warning]
) -> Callable[[_F], _F]:
    """Return a decorator that registers the decorated function's code
    object in ``_warning_tags`` with ``(message, category)``.

    The function itself is returned unchanged.
    """

    def _register(fn):
        _warning_tags[fn.__code__] = (message, category)
        return fn

    return _register
1876
1877
# Matches the ``__name__`` of modules that count as library code for
# stack-level purposes: any ``sqlalchemy.*`` module other than
# ``sqlalchemy.testing*``, and any ``alembic.*`` module. (Despite the
# "not_sa" name, a match means the frame *is* library code.)
_not_sa_pattern = re.compile(r"^(?:sqlalchemy\.(?!testing)|alembic\.)")


def _warnings_warn(
    message: Union[str, Warning],
    category: Optional[Type[Warning]] = None,
    stacklevel: int = 2,
) -> None:
    """Call ``warnings.warn()``, attributing the warning to the first stack
    frame outside of SQLAlchemy / Alembic.

    The scan starts ``stacklevel`` frames above this function. The stack
    level is incremented for every frame whose module ``__name__`` matches
    ``_not_sa_pattern``. The scan also looks for frames whose code object is
    registered in ``_warning_tags`` (see :func:`.tag_method_for_warnings`).
    For each such frame encountered before the scan stops, its suffix is
    appended to ``message`` in parentheses. Its category is used if no
    category has been determined yet.

    :param message: the warning text, or a ``Warning`` instance. When
     ``category`` is omitted for an instance, ``type(message)`` is used.
    :param category: the warning class passed to ``warnings.warn()``.
    :param stacklevel: how many frames above this function to begin the
     scan.

    """

    if category is None and isinstance(message, Warning):
        category = type(message)

    # adjust the given stacklevel to be outside of SQLAlchemy
    try:
        frame = sys._getframe(stacklevel)
    except ValueError:
        # being called from less than 3 (or given) stacklevels, weird,
        # but don't crash
        stacklevel = 0
    except:
        # _getframe() doesn't work, weird interpreter issue, weird,
        # ok, but don't crash
        stacklevel = 0
    else:
        stacklevel_found = warning_tag_found = False
        while frame is not None:
            # using __name__ here requires that we have __name__ in the
            # __globals__ of the decorated string functions we make also.
            # we generate this using {"__name__": fn.__module__}
            if not stacklevel_found and not re.match(
                _not_sa_pattern, frame.f_globals.get("__name__", "")
            ):
                # stop incrementing stack level if an out-of-SQLA line
                # were found.
                stacklevel_found = True

            # however, for the warning tag thing, we have to keep
            # scanning up the whole traceback

            if frame.f_code in _warning_tags:
                warning_tag_found = True
                (_suffix, _category) = _warning_tags[frame.f_code]
                category = category or _category
                # note: this turns a Warning instance into a plain str
                message = f"{message} ({_suffix})"

            frame = frame.f_back  # type: ignore[assignment]

            # each library frame passed over moves the reported location
            # one level further out; once both searches are satisfied, stop
            if not stacklevel_found:
                stacklevel += 1
            elif stacklevel_found and warning_tag_found:
                break

    # +1 because warnings.warn() counts the frame that called it (this
    # function) as stacklevel 1
    if category is not None:
        warnings.warn(message, category, stacklevel=stacklevel + 1)
    else:
        warnings.warn(message, stacklevel=stacklevel + 1)
1934
1935
def only_once(
    fn: Callable[..., _T], retry_on_exception: bool
) -> Callable[..., Optional[_T]]:
    """Decorate the given function to be a no-op after it is called exactly
    once.

    :param fn: the function to run at most once.
    :param retry_on_exception: if true and the call raises, ``fn`` stays
     pending so that the next call tries again. Otherwise a failed call
     still counts as the one call.
    :return: a wrapper that returns ``fn``'s result on its first effective
     call, and ``None`` afterwards.
    """

    pending = [fn]

    def go(*arg: Any, **kw: Any) -> Optional[_T]:
        # strong reference fn so that it isn't garbage collected,
        # which interferes with the event system's expectations
        strong_fn = fn  # noqa
        if not pending:
            return None

        target = pending.pop()
        try:
            return target(*arg, **kw)
        except BaseException:
            if retry_on_exception:
                pending.insert(0, target)
            raise

    return go
1960
1961
_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")


def chop_traceback(
    tb: List[str],
    exclude_prefix: re.Pattern[str] = _UNITTEST_RE,
    exclude_suffix: re.Pattern[str] = _SQLA_RE,
) -> List[str]:
    """Chop extraneous lines off beginning and end of a traceback.

    :param tb:
      a list of traceback lines as returned by ``traceback.format_stack()``

    :param exclude_prefix:
      a regular expression object matching lines to skip at beginning of
      ``tb``

    :param exclude_suffix:
      a regular expression object matching lines to skip at end of ``tb``

    :return: the remaining middle slice of ``tb`` (possibly empty).
    """
    lo, hi = 0, len(tb)
    while lo < hi and exclude_prefix.search(tb[lo]):
        lo += 1
    while lo < hi and exclude_suffix.search(tb[hi - 1]):
        hi -= 1
    return tb[lo:hi]
1990
1991
def attrsetter(attrname):
    """Return a function ``set(obj, value)`` that performs
    ``obj.<attrname> = value``.

    The setter is compiled from generated source with ``exec``, so
    ``attrname`` must come from trusted code, never from external input.
    """
    source = "def set(obj, value): obj.%s = value" % attrname
    namespace = dict(locals())
    exec(source, namespace)
    return namespace["set"]
1997
1998
_dunders = re.compile("^__.+__$")


class TypingOnly:
    """A mixin class that marks a class as 'typing only', meaning it has
    absolutely no methods, attributes, or runtime functionality whatsoever.

    Classes listing ``TypingOnly`` directly among their bases may define
    only dunder names; anything else raises ``AssertionError`` at class
    creation time.

    """

    __slots__ = ()

    def __init_subclass__(cls) -> None:
        # indirect subclasses are free to add whatever they like
        if TypingOnly in cls.__bases__:
            extras = set(
                filter(lambda attr: not _dunders.match(attr), cls.__dict__)
            )
            if extras:
                raise AssertionError(
                    f"Class {cls} directly inherits TypingOnly but has "
                    f"additional attributes {extras}."
                )
        super().__init_subclass__()
2021
2022
class EnsureKWArg:
    r"""Wrap matching methods so that they accept ``\**kw`` arguments even
    if they were written without them.

    Used to ensure cross-compatibility with third party legacy code, for things
    like compiler visit methods that need to accept ``**kw`` arguments,
    but may have been copied from old code that didn't accept them.

    """

    ensure_kwarg: str
    """a regular expression that indicates method names for which the method
    should accept ``**kw`` arguments.

    The class will scan for methods matching the name template and decorate
    them if necessary to ensure ``**kw`` parameters are accepted.

    """

    def __init_subclass__(cls) -> None:
        pattern = cls.ensure_kwarg
        if pattern:
            namespace = cls.__dict__
            for attr_name in namespace:
                if not re.match(pattern, attr_name):
                    continue
                candidate = namespace[attr_name]
                argspec = compat.inspect_getfullargspec(candidate)
                # methods that already take **kw are left untouched
                if argspec.varkw:
                    continue
                setattr(cls, attr_name, cls._wrap_w_kw(candidate))
        super().__init_subclass__()

    @classmethod
    def _wrap_w_kw(cls, fn: Callable[..., Any]) -> Callable[..., Any]:
        def wrap(*arg: Any, **kw: Any) -> Any:
            # keyword arguments are accepted and intentionally discarded
            return fn(*arg)

        return update_wrapper(wrap, fn)
2062
2063
def wrap_callable(wrapper, fn):
    """Like ``functools.update_wrapper()``, but also handles callable
    objects that have no ``__name__``.

    For such objects the wrapper takes the object's class name and its
    ``__module__``. Its docstring comes from ``__call__``, or failing that
    from the object itself.

    :param fn:
      object with __call__ method

    """
    if hasattr(fn, "__name__"):
        return update_wrapper(wrapper, fn)

    wrapper.__name__ = fn.__class__.__name__
    if hasattr(fn, "__module__"):
        wrapper.__module__ = fn.__module__

    call_doc = getattr(fn.__call__, "__doc__", None)
    if call_doc:
        wrapper.__doc__ = call_doc
    elif fn.__doc__:
        wrapper.__doc__ = fn.__doc__

    return wrapper
2086
2087
def quoted_token_parser(value):
    """Parse a dotted identifier with accommodation for quoted names.

    Includes support for SQL-style double quotes as a literal character:
    inside a quoted name, a doubled quote produces one literal quote.

    E.g.::

        >>> quoted_token_parser("name")
        ['name']
        >>> quoted_token_parser("schema.name")
        ['schema', 'name']
        >>> quoted_token_parser('"Schema"."Name"')
        ['Schema', 'Name']
        >>> quoted_token_parser('"Schema"."Name""Foo"')
        ['Schema', 'Name"Foo']

    """

    if '"' not in value:
        return value.split(".")

    # 0 = outside of quotes
    # 1 = inside of quotes
    state = 0
    result: List[List[str]] = [[]]
    idx = 0
    lv = len(value)
    while idx < lv:
        char = value[idx]
        if char == '"':
            if state == 1 and idx < lv - 1 and value[idx + 1] == '"':
                # escaped quote inside a quoted name: keep one, skip both
                result[-1].append('"')
                idx += 1
            else:
                state ^= 1
        elif char == "." and state == 0:
            # an unquoted dot starts the next token
            result.append([])
        else:
            result[-1].append(char)
        idx += 1

    return ["".join(token) for token in result]
2130
2131
def add_parameter_text(params: Any, text: str) -> Callable[[_F], _F]:
    """Return a decorator that injects ``text`` into the ``:param:`` entry
    of each name in ``params`` within the decorated function's docstring.

    A function whose ``__doc__`` is ``None`` ends up with ``""``.
    """
    param_names = _collections.to_list(params)

    def decorate(fn):
        docstring = fn.__doc__ or ""
        if docstring:
            docstring = inject_param_text(
                docstring, dict.fromkeys(param_names, text)
            )
        fn.__doc__ = docstring
        return fn

    return decorate
2143
2144
def _dedent_docstring(text: str) -> str:
    """Dedent a docstring, tolerating an unindented first line.

    Single-line text is returned as-is. If the first line has no leading
    space, only the remaining lines are dedented. Otherwise the whole text
    is dedented.
    """
    firstline, newline, remaining = text.partition("\n")
    if not newline:
        return text
    if firstline.startswith(" "):
        return textwrap.dedent(text)
    return firstline + "\n" + textwrap.dedent(remaining)
2155
2156
def inject_docstring_text(
    given_doctext: Optional[str], injecttext: str, pos: int
) -> str:
    """Insert ``injecttext`` (dedented) into a docstring.

    Candidate insertion points are the top of the docstring followed by
    each blank line. ``pos`` selects among them and is clamped to the last
    candidate. A leading blank line is added to the injected text when it
    does not already begin with one.
    """
    lines = _dedent_docstring(given_doctext or "").split("\n")
    if len(lines) == 1:
        lines.append("")

    new_lines = textwrap.dedent(injecttext).split("\n")
    if new_lines[0]:
        new_lines.insert(0, "")

    insertion_points = [0]
    insertion_points.extend(
        idx for idx, line in enumerate(lines) if not line.strip()
    )

    target = insertion_points[min(pos, len(insertion_points) - 1)]
    lines[target:target] = new_lines
    return "\n".join(lines)
2175
2176
_param_reg = re.compile(r"(\s+):param (.+?):")


def inject_param_text(doctext: str, inject_params: Dict[str, str]) -> str:
    """Inject extra text into ``:param:`` entries of a docstring.

    A line is treated as a param entry when it matches ``_param_reg`` (it
    needs leading whitespace). The parameter name has leading ``*``
    stripped and is looked up in ``inject_params``. The mapped text is
    indented like the following line, or one column past the ``:param``
    indentation when that line has no text. It is emitted before the next
    ``:param:`` line or at the next blank line, whichever comes first.
    Lines ending in ``::`` carry the following line through untouched. If
    the docstring ends while text is still pending, that text is not
    emitted.

    :param doctext: the docstring to process.
    :param inject_params: mapping of parameter name to text to inject.
    :return: the rewritten docstring.
    """
    doclines = collections.deque(doctext.splitlines())
    lines = []

    # TODO: this is not working for params like ":param case_sensitive=True:"

    to_inject = None
    while doclines:
        line = doclines.popleft()

        m = _param_reg.match(line)

        if to_inject is not None:
            if m:
                # a new :param: begins; flush the pending text before it
                lines.extend(["\n", to_inject, "\n"])
                to_inject = None
            elif not line.rstrip():
                lines.extend([line, to_inject, "\n"])
                to_inject = None
            elif line.endswith("::"):
                # TODO: this still won't cover if the code example itself has
                # blank lines in it, need to detect those via indentation.
                lines.append(line)
                # guard: a trailing "::" line has nothing after it
                if doclines:
                    lines.append(doclines.popleft())
                continue

        # checked after a flush too, so consecutive targeted params each
        # get their text
        if to_inject is None and m:
            param = m.group(2).lstrip("*")
            if param in inject_params:
                # default indent to that of :param: plus one
                indent = " " * len(m.group(1)) + " "

                # but if the next line has text, use that line's
                # indentation
                if doclines:
                    m2 = re.match(r"(\s+)\S", doclines[0])
                    if m2:
                        indent = " " * len(m2.group(1))

                to_inject = indent + inject_params[param]

        lines.append(line)

    return "\n".join(lines)
2221
2222
def repr_tuple_names(names: List[str]) -> Optional[str]:
    """Render up to four names, eliding from the middle.

    Lists of more than four names show the first three, ``...``, and the
    last one. Names longer than 11 characters are cut to 11 characters
    followed by ``..``. An empty list gives ``None``.
    """
    if len(names) == 0:
        return None

    def shorten(name):
        return "%s.." % name[:11] if len(name) > 11 else name

    if len(names) <= 4:
        return ", ".join(shorten(name) for name in names)

    head = ", ".join(shorten(name) for name in names[:3])
    return "%s, ..., %s" % (head, shorten(names[-1]))
2235
2236
def has_compiled_ext(raise_=False):
    """Report whether the cython extensions are available.

    :param raise_: when true, raise ``ImportError`` instead of returning
     ``False`` if the extensions are missing.
    """
    from ._has_cython import HAS_CYEXTENSION

    if HAS_CYEXTENSION:
        return True
    if raise_:
        raise ImportError(
            "cython extensions were expected to be installed, "
            "but are not present"
        )
    return False
2249
2250
def load_uncompiled_module(module: _M) -> _M:
    """Load the pure-Python (non-compiled) version of a module that is also
    compiled with cython.

    The ``.py`` source is located next to the parent package's
    ``__init__.py``. It is executed as a fresh module object under the same
    name, and that new module is returned.
    """
    full_name = module.__name__
    spec = module.__spec__
    assert spec
    parent_name = spec.parent
    assert parent_name
    parent_spec = sys.modules[parent_name].__spec__
    assert parent_spec
    init_path = parent_spec.origin
    assert init_path and init_path.endswith("__init__.py")

    short_name = full_name.split(".")[-1]
    source_path = init_path.replace("__init__.py", f"{short_name}.py")

    py_spec = importlib.util.spec_from_file_location(full_name, source_path)
    assert py_spec
    py_module = importlib.util.module_from_spec(py_spec)
    assert py_spec.loader
    py_spec.loader.exec_module(py_module)
    return cast(_M, py_module)
2273
2274
class _Missing(enum.Enum):
    # single-member enum used as a sentinel value; being an enum member,
    # it can be spelled in a typing ``Literal`` (see MissingOr below)
    Missing = enum.auto()


# module-level alias for the sole sentinel member
Missing = _Missing.Missing
# type alias: a value of type _T, or the ``Missing`` sentinel
# (presumably meaning "no value supplied" -- confirm at call sites)
MissingOr = Union[_T, Literal[_Missing.Missing]]