1# util/langhelpers.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Routines to help with the creation, loading and introspection of
10modules, classes, hierarchies, attributes, functions, and methods.
11
12"""
13
14from __future__ import annotations
15
16import collections
17import enum
18from functools import update_wrapper
19import inspect
20import itertools
21import operator
22import re
23import sys
24import textwrap
25import threading
26import types
27from types import CodeType
28from typing import Any
29from typing import Callable
30from typing import cast
31from typing import Dict
32from typing import FrozenSet
33from typing import Generic
34from typing import Iterator
35from typing import List
36from typing import NoReturn
37from typing import Optional
38from typing import overload
39from typing import Sequence
40from typing import Set
41from typing import Tuple
42from typing import Type
43from typing import TYPE_CHECKING
44from typing import TypeVar
45from typing import Union
46import warnings
47
48from . import _collections
49from . import compat
50from ._has_cy import HAS_CYEXTENSION
51from .typing import Literal
52from .. import exc
53
# Generic type variables shared by the helpers in this module.
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
# Any callable; lets a decorator declare that it returns exactly the type
# it was given (see memoized_instancemethod).
_F = TypeVar("_F", bound=Callable[..., Any])
# Self-types for descriptor / helper classes.  The bounds are strings
# because they are forward references: memoized_property and HasMemoized
# are only defined further down in this module (hybridproperty and
# hybridmethod are presumably defined later too -- not visible here).
_MP = TypeVar("_MP", bound="memoized_property[Any]")
_MA = TypeVar("_MA", bound="HasMemoized.memoized_attribute[Any]")
_HP = TypeVar("_HP", bound="hybridproperty[Any]")
_HM = TypeVar("_HM", bound="hybridmethod[Any]")
61
62
def md5_hex(x: Any) -> str:
    """Return the hexadecimal MD5 digest of ``x`` encoded as UTF-8.

    The hash object comes from ``compat.md5_not_for_security()``; as that
    helper's name says, the digest must not be relied on for security.
    """
    payload = x.encode("utf-8")
    hasher = compat.md5_not_for_security()
    hasher.update(payload)
    return cast(str, hasher.hexdigest())
68
69
class safe_reraise:
    """Context manager that re-raises the exception currently being
    handled once its body has finished.

    The exception information is captured in ``__enter__``, before the
    body runs, so it survives a potential coroutine context switch that
    happens inside the body.  If the body raises an exception of its own,
    that newer exception propagates instead.

    e.g.::

        try:
            sess.commit()
        except:
            with safe_reraise():
                sess.rollback()

    TODO: we should at some point evaluate current behaviors in this regard
    based on current greenlet, gevent/eventlet implementations in Python 3, and
    also see the degree to which our own asyncio (based on greenlet also) is
    impacted by this. .rollback() will cause IO / context switch to occur in
    all these scenarios; what happens to the exception context from an
    "except:" block if we don't explicitly store it? Original issue was #2703.

    """

    __slots__ = ("_exc_info",)

    _exc_info: Union[
        None,
        Tuple[
            Type[BaseException],
            BaseException,
            types.TracebackType,
        ],
        Tuple[None, None, None],
    ]

    def __enter__(self) -> None:
        # snapshot the in-flight exception before the handler body runs
        self._exc_info = sys.exc_info()

    def __exit__(
        self,
        type_: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[types.TracebackType],
    ) -> NoReturn:
        assert self._exc_info is not None
        # see #2703 for notes
        if type_ is not None:
            # the handler body raised something itself; that wins
            self._exc_info = None  # remove potential circular references
            assert value is not None
            raise value.with_traceback(traceback)

        # the handler body finished cleanly: re-raise the stored exception
        # together with its original traceback
        _, stored_value, stored_tb = self._exc_info
        assert stored_value is not None
        self._exc_info = None  # remove potential circular references
        raise stored_value.with_traceback(stored_tb)
127
128
def walk_subclasses(cls: Type[_T]) -> Iterator[Type[_T]]:
    """Yield ``cls`` and every class that transitively subclasses it.

    Each class is produced exactly once, even in diamond hierarchies.  The
    walk is depth-first over an explicit stack, so direct subclasses are
    visited last-registered first.
    """
    visited: Set[Any] = set()
    pending = [cls]
    while pending:
        current = pending.pop()
        if current not in visited:
            visited.add(current)
            pending.extend(current.__subclasses__())
            yield current
141
142
def string_or_unprintable(element: Any) -> str:
    """Return ``element`` as a string for use in messages.

    Strings are returned unchanged; anything else goes through ``str()``.
    If ``str()`` raises, a placeholder built from ``repr()`` is returned.
    """
    if not isinstance(element, str):
        try:
            element = str(element)
        except Exception:
            return "unprintable element %r" % element
    return element
151
152
def clsname_as_plain_name(
    cls: Type[Any], use_name: Optional[str] = None
) -> str:
    """Turn a CamelCase class name into lower-case space-separated words.

    ``use_name`` overrides ``cls.__name__`` when non-empty.  Words are
    capitalized runs (``[A-Z][a-z]+``) plus the literal ``SQL``; e.g.
    ``SQLCompiler`` becomes ``"sql compiler"``.
    """
    words = re.findall(r"([A-Z][a-z]+|SQL)", use_name or cls.__name__)
    return " ".join(word.lower() for word in words)
158
159
def method_is_overridden(
    instance_or_cls: Union[Type[Any], object],
    against_method: Callable[..., Any],
) -> bool:
    """Return True if the two class methods don't match.

    Looks up the attribute named like ``against_method`` on the class
    (or on the instance's class) and compares it with ``against_method``.
    """
    owner = (
        instance_or_cls
        if isinstance(instance_or_cls, type)
        else instance_or_cls.__class__
    )
    resolved: types.MethodType = getattr(owner, against_method.__name__)
    return resolved != against_method
176
177
def decode_slice(slc: slice) -> Tuple[Any, ...]:
    """Decode a slice object as sent to ``__getitem__``.

    Returns ``(start, stop, step)``.  Any component that provides
    ``__index__()`` is replaced by the result of that call; other values
    (including ``None``) pass through unchanged.
    """
    bounds = (slc.start, slc.stop, slc.step)
    return tuple(
        value.__index__() if hasattr(value, "__index__") else value
        for value in bounds
    )
190
191
192def _unique_symbols(used: Sequence[str], *bases: str) -> Iterator[str]:
193 used_set = set(used)
194 for base in bases:
195 pool = itertools.chain(
196 (base,),
197 map(lambda i: base + str(i), range(1000)),
198 )
199 for sym in pool:
200 if sym not in used_set:
201 used_set.add(sym)
202 yield sym
203 break
204 else:
205 raise NameError("exhausted namespace for symbol base %s" % base)
206
207
def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]:
    """Yield ``fn(b)`` for each set bit ``b`` of ``n``, lowest bit first.

    Each ``b`` is the bit's value (a power of two), e.g. ``n=0b1011``
    calls ``fn`` with 1, 2 and 8.  ``n=0`` yields nothing.

    :raises ValueError: if ``n`` is negative.  A negative int has
        infinitely many set bits in two's complement, so the loop below
        would never terminate.  Because this is a generator, the error is
        raised on the first ``next()``.
    """
    if n < 0:
        raise ValueError(
            "map_bits() requires a non-negative integer, got %r" % (n,)
        )
    while n:
        lowest = n & -n  # isolate the lowest set bit
        yield fn(lowest)
        n ^= lowest
215
216
# TypeVar for the function type that decorator() accepts and returns
# unchanged.
_Fn = TypeVar("_Fn", bound="Callable[..., Any]")

# this seems to be in flux in recent mypy versions
# NOTE(review): unclear what "this" refers to (presumably the typing of
# decorator() below) -- confirm or remove.
220
221
def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
    """A signature-matching decorator factory.

    ``decorator(target)`` returns a decorator.  Applied to a function or
    method ``fn``, that decorator compiles a new function with ``fn``'s
    name and argument signature.  Its body is
    ``return target(fn, <fn's arguments>)``, using ``async def`` /
    ``await`` when ``fn`` is a coroutine function.  ``__defaults__``,
    ``__kwdefaults__`` and the attributes copied by
    :func:`functools.update_wrapper` are taken from ``fn``.

    :param target: callable invoked as ``target(fn, ...)`` on every call
        of a decorated function.
    :return: the decorator, itself passed through ``update_wrapper`` so it
        carries ``target``'s name and docstring.
    :raises Exception: when applied to something that is neither a
        function nor a method.
    """

    def decorate(fn: _Fn) -> _Fn:
        if not inspect.isfunction(fn) and not inspect.ismethod(fn):
            raise Exception("not a decoratable function")

        # Python 3.14 defers creating __annotations__ until it is used.
        # We do not want to create __annotations__ now, so __annotate__
        # is cleared temporarily while the signature is read.
        annofunc = getattr(fn, "__annotate__", None)
        if annofunc is not None:
            fn.__annotate__ = None  # type: ignore[union-attr]
            try:
                spec = compat.inspect_getfullargspec(fn)
            finally:
                fn.__annotate__ = annofunc  # type: ignore[union-attr]
        else:
            spec = compat.inspect_getfullargspec(fn)

        # Do not generate code for annotations.
        # update_wrapper() copies the annotation from fn to decorated.
        # We use dummy defaults for code generation to avoid having
        # copy of large globals for compiling.
        # We copy __defaults__ and __kwdefaults__ from fn to decorated.
        empty_defaults = (None,) * len(spec.defaults or ())
        empty_kwdefaults = dict.fromkeys(spec.kwonlydefaults or ())
        spec = spec._replace(
            annotations={},
            defaults=empty_defaults,
            kwonlydefaults=empty_kwdefaults,
        )

        # Every name that is a parameter of the generated function (or the
        # function's own name).  The globals chosen below for ``target``
        # and ``fn`` must avoid all of them, otherwise a parameter would
        # shadow them inside the generated body.  Keyword-only parameters
        # are part of the generated signature too, so they are included;
        # previously a kw-only argument named e.g. ``target`` or ``fn``
        # made the generated code call that argument instead.
        names = (
            tuple(cast("Tuple[str, ...]", spec[0]))
            + cast("Tuple[str, ...]", spec[1:3])
            + tuple(spec.kwonlyargs or ())
            + (fn.__name__,)
        )
        targ_name, fn_name = _unique_symbols(names, "target", "fn")

        metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
        metadata.update(format_argspec_plus(spec, grouped=False))
        metadata["name"] = fn.__name__

        if inspect.iscoroutinefunction(fn):
            metadata["prefix"] = "async "
            metadata["target_prefix"] = "await "
        else:
            metadata["prefix"] = ""
            metadata["target_prefix"] = ""

        # look for __ positional arguments. This is a convention in
        # SQLAlchemy that arguments should be passed positionally
        # rather than as keyword
        # arguments. note that apply_pos doesn't currently work in all cases
        # such as when a kw-only indicator "*" is present, which is why
        # we limit the use of this to just that case we can detect. As we add
        # more kinds of methods that use @decorator, things may have to
        # be further improved in this area
        if "__" in repr(spec[0]):
            code = """\
%(prefix)sdef %(name)s%(grouped_args)s:
    return %(target_prefix)s%(target)s(%(fn)s, %(apply_pos)s)
""" % metadata
        else:
            code = """\
%(prefix)sdef %(name)s%(grouped_args)s:
    return %(target_prefix)s%(target)s(%(fn)s, %(apply_kw)s)
""" % metadata

        # the generated function sees only these globals: the two
        # collision-free symbols plus __name__ for its __module__
        env: Dict[str, Any] = {
            targ_name: target,
            fn_name: fn,
            "__name__": fn.__module__,
        }

        decorated = cast(
            types.FunctionType,
            _exec_code_in_env(code, env, fn.__name__),
        )
        decorated.__defaults__ = fn.__defaults__
        decorated.__kwdefaults__ = fn.__kwdefaults__  # type: ignore
        return update_wrapper(decorated, fn)  # type: ignore[return-value]

    return update_wrapper(decorate, target)  # type: ignore[return-value]
306
307
308def _exec_code_in_env(
309 code: Union[str, types.CodeType], env: Dict[str, Any], fn_name: str
310) -> Callable[..., Any]:
311 exec(code, env)
312 return env[fn_name] # type: ignore[no-any-return]
313
314
# NOTE(review): _PF and _TE are not referenced in the visible portion of
# this module; confirm whether they are still used before removing them.
_PF = TypeVar("_PF")
_TE = TypeVar("_TE")
317
318
class PluginLoader:
    """Registry that resolves plugin names to loaded objects.

    ``load(name)`` looks for a factory in this order:

    1. one already registered or cached in ``impls``;
    2. one returned by ``auto_fn(name)``, if given (and then cached);
    3. an entry from ``compat.importlib_metadata_get(group)`` whose
       ``.name`` matches; its ``.load`` is cached.

    The cached factory is called and its result returned.  If none of
    these applies, ``exc.NoSuchModuleError`` is raised.
    """

    def __init__(
        self, group: str, auto_fn: Optional[Callable[..., Any]] = None
    ):
        self.group = group
        self.impls: Dict[str, Any] = {}
        self.auto_fn = auto_fn

    def clear(self) -> None:
        self.impls.clear()

    def load(self, name: str) -> Any:
        try:
            cached = self.impls[name]
        except KeyError:
            pass
        else:
            return cached()

        auto = self.auto_fn
        if auto:
            found = auto(name)
            if found:
                self.impls[name] = found
                return found()

        for entry in compat.importlib_metadata_get(self.group):
            if entry.name == name:
                self.impls[name] = entry.load
                return entry.load()

        raise exc.NoSuchModuleError(
            "Can't load plugin: %s:%s" % (self.group, name)
        )

    def register(self, name: str, modulepath: str, objname: str) -> None:
        """Register a lazy factory returning ``modulepath.objname``."""

        def load():
            # __import__ returns the top-level package; walk down to the
            # leaf module attribute by attribute
            module = __import__(modulepath)
            for attr in modulepath.split(".")[1:]:
                module = getattr(module, attr)
            return getattr(module, objname)

        self.impls[name] = load

    def deregister(self, name: str) -> None:
        del self.impls[name]
360
361
362def _inspect_func_args(fn):
363 try:
364 co_varkeywords = inspect.CO_VARKEYWORDS
365 except AttributeError:
366 # https://docs.python.org/3/library/inspect.html
367 # The flags are specific to CPython, and may not be defined in other
368 # Python implementations. Furthermore, the flags are an implementation
369 # detail, and can be removed or deprecated in future Python releases.
370 spec = compat.inspect_getfullargspec(fn)
371 return spec[0], bool(spec[2])
372 else:
373 # use fn.__code__ plus flags to reduce method call overhead
374 co = fn.__code__
375 nargs = co.co_argcount
376 return (
377 list(co.co_varnames[:nargs]),
378 bool(co.co_flags & co_varkeywords),
379 )
380
381
# raiseerr=True: a None result is never returned.
@overload
def get_cls_kwargs(
    cls: type,
    *,
    _set: Optional[Set[str]] = None,
    raiseerr: Literal[True] = ...,
) -> Set[str]: ...


@overload
def get_cls_kwargs(
    cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
) -> Optional[Set[str]]: ...


def get_cls_kwargs(
    cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
) -> Optional[Set[str]]:
    r"""Return the full set of inherited kwargs for the given `cls`.

    Probes a class's __init__ method, collecting all named arguments. If the
    __init__ defines a \**kwargs catch-all, then the constructor is presumed
    to pass along unrecognized keywords to its base classes, and the
    collection process is repeated recursively on each of the bases.

    Uses a subset of inspect.getfullargspec() to cut down on method overhead,
    as this is used within the Core typing system to create copies of type
    objects which is a performance-sensitive operation.

    No anonymous tuple arguments please !

    :param cls: the class whose constructor argument names are collected.
    :param _set: internal accumulator shared by the recursive calls on
        base classes; callers normally omit it.  A call without it is the
        "top level" call.
    :param raiseerr: only consulted on a non-top-level call (``_set``
        passed in) for a class whose ``__init__`` has no ``**kwargs``.  In
        that case ``TypeError`` is raised instead of returning ``None``.
        The recursive calls made below do not pass ``raiseerr``.
    :return: the collected argument names with ``"self"`` removed.  A
        top-level call always returns the set.  A non-top-level call
        returns ``None`` for a class whose ``__init__`` lacks ``**kwargs``
        (after adding that ``__init__``'s names), which makes the caller
        stop scanning its remaining bases.

    """
    toplevel = _set is None
    if toplevel:
        _set = set()
    assert _set is not None

    # only an __init__ defined directly on this class is examined here;
    # inherited ones are reached through the recursion over __bases__
    ctr = cls.__dict__.get("__init__", False)

    # must be a plain Python function with a real code object; anything
    # else (e.g. a C-level slot wrapper) is treated as "no __init__"
    has_init = (
        ctr
        and isinstance(ctr, types.FunctionType)
        and isinstance(ctr.__code__, types.CodeType)
    )

    if has_init:
        names, has_kw = _inspect_func_args(ctr)
        _set.update(names)

        if not has_kw and not toplevel:
            if raiseerr:
                # NOTE(review): the message says the class has no
                # __init__, but this branch is reached when it has one
                # that lacks **kwargs.
                raise TypeError(
                    f"given cls {cls} doesn't have an __init__ method"
                )
            else:
                return None
    else:
        has_kw = False

    # No usable __init__ here, or it forwards **kwargs: collect from the
    # bases.  A base returning None stops the scan of the remaining bases.
    if not has_init or has_kw:
        for c in cls.__bases__:
            if get_cls_kwargs(c, _set=_set) is None:
                break

    # the instance argument is not a keyword argument
    _set.discard("self")
    return _set
448
449
def get_func_kwargs(func: Callable[..., Any]) -> List[str]:
    """Return the names of ``func``'s positional-or-keyword arguments.

    This is the ``args`` field of ``compat.inspect_getfullargspec(func)``,
    which makes it usable for methods as well as plain functions.
    Keyword-only arguments live in a separate field and are not included.
    """
    full_spec = compat.inspect_getfullargspec(func)
    return full_spec.args
459
460
def get_callable_argspec(
    fn: Callable[..., Any], no_self: bool = False, _is_init: bool = False
) -> compat.FullArgSpec:
    """Return the argument signature for any callable.

    All pure-Python callables are accepted, including
    functions, methods, classes, objects with __call__;
    builtins and other edge cases like functools.partial() objects
    raise a TypeError.

    :param fn: the callable to introspect.
    :param no_self: when True, the leading ``self`` argument is removed
        for bound methods, for classes (whose ``__init__`` is inspected),
        and for objects whose ``__call__`` is a bound method.
    :param _is_init: internal flag set when recursing from a class into
        its ``__init__``.
    :raises TypeError: for builtins and callables that cannot be
        introspected.

    """
    if inspect.isbuiltin(fn):
        raise TypeError("Can't inspect builtin: %s" % fn)
    elif inspect.isfunction(fn) or (
        hasattr(fn, "__code__")
        and not inspect.isclass(fn)
        and not inspect.ismethod(fn)
    ):
        if _is_init and no_self:
            # an __init__ reached through its class: drop "self"
            spec = compat.inspect_getfullargspec(fn)
            return compat.FullArgSpec(
                spec.args[1:],
                spec.varargs,
                spec.varkw,
                spec.defaults,
                spec.kwonlyargs,
                spec.kwonlydefaults,
                spec.annotations,
            )
        else:
            return compat.inspect_getfullargspec(fn)
    elif inspect.ismethod(fn):
        # Compare __self__ against None rather than testing its truth
        # value.  A bound method's instance may legitimately be falsy
        # (e.g. it defines __len__ returning 0 or __bool__ returning
        # False), and it must still have "self" stripped.
        if no_self and (_is_init or fn.__self__ is not None):
            spec = compat.inspect_getfullargspec(fn.__func__)
            return compat.FullArgSpec(
                spec.args[1:],
                spec.varargs,
                spec.varkw,
                spec.defaults,
                spec.kwonlyargs,
                spec.kwonlydefaults,
                spec.annotations,
            )
        else:
            return compat.inspect_getfullargspec(fn.__func__)
    elif inspect.isclass(fn):
        return get_callable_argspec(
            fn.__init__, no_self=no_self, _is_init=True
        )
    elif hasattr(fn, "__func__"):
        # wrapper objects exposing the underlying function (e.g.
        # staticmethod / classmethod); the spec is returned unmodified
        return compat.inspect_getfullargspec(fn.__func__)
    elif hasattr(fn, "__call__"):
        if inspect.ismethod(fn.__call__):
            return get_callable_argspec(fn.__call__, no_self=no_self)
        else:
            raise TypeError("Can't inspect callable: %s" % fn)
    else:
        raise TypeError("Can't inspect callable: %s" % fn)
519
520
def format_argspec_plus(
    fn: Union[Callable[..., Any], compat.FullArgSpec], grouped: bool = True
) -> Dict[str, Optional[str]]:
    """Returns a dictionary of formatted, introspected function arguments.

    An enhanced variant of inspect.formatargspec to support code generation.

    fn
       An inspectable callable, or an already-computed
       ``compat.FullArgSpec`` for one (anything that is not ``callable()``
       is used as the spec directly).
    grouped
       Defaults to True; include (parens, around, argument) lists.  When
       False, the outer parentheses are stripped from the ``apply_*``
       entries only; ``grouped_args`` always keeps them.

    Returns:

    grouped_args
      Full inspect.formatargspec for fn, including default values.
    self_arg
      The name of the first positional argument, the expression
      ``"<varargs>[0]"`` if the function only has ``*varargs``, or None
      if it has neither.
    apply_pos
      args, re-written in calling rather than receiving syntax. Arguments are
      passed positionally.
    apply_kw
      Like apply_pos, except keyword-ish args are passed as keywords:
      positional arguments that have defaults, plus every keyword-only
      argument, are rendered as ``name=name``.
    apply_pos_proxied
      Like apply_pos but omits the self/cls argument
    apply_kw_proxied
      Like apply_kw but omits the self/cls argument

    Example::

      >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
      {'grouped_args': '(self, a, b, c=3, **d)',
       'self_arg': 'self',
       'apply_pos': '(self, a, b, c, **d)',
       'apply_kw': '(self, a, b, c=c, **d)',
       'apply_pos_proxied': '(a, b, c, **d)',
       'apply_kw_proxied': '(a, b, c=c, **d)'}

    """
    if callable(fn):
        spec = compat.inspect_getfullargspec(fn)
    else:
        spec = fn

    # spec indexes: 0=args, 1=varargs, 2=varkw, 3=defaults, 4=kwonlyargs
    args = compat.inspect_formatargspec(*spec)

    apply_pos = compat.inspect_formatargspec(
        spec[0], spec[1], spec[2], None, spec[4]
    )

    if spec[0]:
        self_arg = spec[0][0]

        apply_pos_proxied = compat.inspect_formatargspec(
            spec[0][1:], spec[1], spec[2], None, spec[4]
        )

    elif spec[1]:
        # no named positional parameters, only *varargs: the first element
        # of the varargs tuple stands in for the self argument
        self_arg = "%s[0]" % spec[1]

        apply_pos_proxied = apply_pos
    else:
        self_arg = None
        apply_pos_proxied = apply_pos

    # Count the positional defaults plus *all* keyword-only arguments
    # (spec[4] is kwonlyargs, not kwonlydefaults), so that every kw-only
    # argument is rendered as name=name in apply_kw below.
    num_defaults = 0
    if spec[3]:
        num_defaults += len(cast(Tuple[Any], spec[3]))
    if spec[4]:
        num_defaults += len(spec[4])

    name_args = spec[0] + spec[4]

    defaulted_vals: Union[List[str], Tuple[()]]

    # the "default" of each keyword-ish argument is its own name, which
    # formatvalue turns into "=name"
    if num_defaults:
        defaulted_vals = name_args[0 - num_defaults :]
    else:
        defaulted_vals = ()

    apply_kw = compat.inspect_formatargspec(
        name_args,
        spec[1],
        spec[2],
        defaulted_vals,
        formatvalue=lambda x: "=" + str(x),
    )

    if spec[0]:
        apply_kw_proxied = compat.inspect_formatargspec(
            name_args[1:],
            spec[1],
            spec[2],
            defaulted_vals,
            formatvalue=lambda x: "=" + str(x),
        )
    else:
        apply_kw_proxied = apply_kw

    if grouped:
        return dict(
            grouped_args=args,
            self_arg=self_arg,
            apply_pos=apply_pos,
            apply_kw=apply_kw,
            apply_pos_proxied=apply_pos_proxied,
            apply_kw_proxied=apply_kw_proxied,
        )
    else:
        # strip the surrounding parentheses from the call-side strings
        return dict(
            grouped_args=args,
            self_arg=self_arg,
            apply_pos=apply_pos[1:-1],
            apply_kw=apply_kw[1:-1],
            apply_pos_proxied=apply_pos_proxied[1:-1],
            apply_kw_proxied=apply_kw_proxied[1:-1],
        )
636
637
def format_argspec_init(method, grouped=True):
    """format_argspec_plus with considerations for typical __init__ methods

    Delegates to format_argspec_plus, with two special cases for typical
    __init__ methods:

    .. sourcecode:: text

        object.__init__ -> (self)
        other unreflectable (usually C) -> (self, *args, **kwargs)

    An "unreflectable" method is one for which format_argspec_plus raises
    TypeError.  ``grouped`` controls the parentheses on the ``apply_*``
    entries only.

    """
    if method is not object.__init__:
        try:
            return format_argspec_plus(method, grouped=grouped)
        except TypeError:
            grouped_args = "(self, *args, **kwargs)"
            if grouped:
                args, proxied = grouped_args, "(*args, **kwargs)"
            else:
                args, proxied = "self, *args, **kwargs", "*args, **kwargs"
    else:
        grouped_args = "(self)"
        args, proxied = ("(self)", "()") if grouped else ("self", "")

    return dict(
        self_arg="self",
        grouped_args=grouped_args,
        apply_pos=args,
        apply_kw=args,
        apply_pos_proxied=proxied,
        apply_kw_proxied=proxied,
    )
669
670
def create_proxy_methods(
    target_cls: Type[Any],
    target_cls_sphinx_name: str,
    proxy_cls_sphinx_name: str,
    classmethods: Sequence[str] = (),
    methods: Sequence[str] = (),
    attributes: Sequence[str] = (),
    use_intermediate_variable: Sequence[str] = (),
) -> Callable[[_T], _T]:
    """Mark a class as proxying attributes of ``target_cls``.

    At runtime this is a no-op marker: the returned decorator hands the
    class back untouched, and none of the arguments are used here.  The
    tools/generate_proxy_methods.py script reads these decorations to
    statically generate proxy methods and attributes that typing tools
    such as mypy fully recognize.

    """

    def decorate(cls: _T) -> _T:
        # identity: nothing is added to or changed on the class
        return cls

    return decorate
694
695
def getargspec_init(method):
    """inspect.getargspec with considerations for typical __init__ methods

    Returns ``compat.inspect_getfullargspec(method)``.  If that raises
    TypeError, a substitute spec is returned instead:

    .. sourcecode:: text

        object.__init__ -> (self)
        other unreflectable (usually C) -> (self, *args, **kwargs)

    NOTE(review): the substitute is a plain 4-tuple
    ``(args, varargs, varkw, defaults)``, not a FullArgSpec, so callers
    that need fields beyond index 3 must not rely on the fallback shape.

    """
    try:
        return compat.inspect_getfullargspec(method)
    except TypeError:
        if method is object.__init__:
            varargs, varkw = None, None
        else:
            varargs, varkw = "args", "kwargs"
        return (["self"], varargs, varkw, None)
714
715
def unbound_method_to_callable(func_or_cls):
    """Adjust the incoming callable such that a 'self' argument is not
    required.

    Only a method object whose ``__self__`` is None (an "unbound" method)
    is replaced by its underlying function.  Everything else, including
    every normally bound method, is returned unchanged.

    The check compares ``__self__`` with None rather than testing its
    truth value.  A method bound to a *falsy* instance (e.g. an empty
    container subclass) used to be unwrapped by mistake, which turned a
    ready-to-call bound method into one that requires ``self``.

    """

    if (
        isinstance(func_or_cls, types.MethodType)
        and func_or_cls.__self__ is None
    ):
        return func_or_cls.__func__
    else:
        return func_or_cls
726
727
def generic_repr(
    obj: Any,
    additional_kw: Sequence[Tuple[str, Any]] = (),
    to_inspect: Optional[Union[object, List[object]]] = None,
    omit_kwarg: Sequence[str] = (),
) -> str:
    """Produce a __repr__() based on direct association of the __init__()
    specification vs. same-named attributes present.

    The ``__init__`` of the first entry of ``to_inspect`` (default:
    ``[obj]``) supplies the positional arguments, rendered by value, and
    its ``*varargs`` name.  The non-default arguments of later entries,
    and every argument with a default, are rendered as ``name=value``,
    but only when ``obj`` has that attribute and its value differs from
    the default.  Entries whose ``__init__`` cannot be introspected
    (TypeError) are skipped.

    :param obj: the object to represent.
    :param additional_kw: extra ``(name, default)`` pairs, rendered like
        keyword arguments after the introspected ones.
    :param to_inspect: object or list of objects whose ``__init__`` is
        introspected.
    :param omit_kwarg: introspected keyword names that are never rendered.
    :return: ``"ClassName(pos1, pos2, kw=value, ...)"``.

    """
    if to_inspect is None:
        to_inspect = [obj]
    else:
        to_inspect = _collections.to_list(to_inspect)

    missing = object()

    pos_args: List[str] = []
    kw_args: _collections.OrderedDict[str, Any] = _collections.OrderedDict()
    vargs = None
    for i, insp in enumerate(to_inspect):
        try:
            spec = compat.inspect_getfullargspec(insp.__init__)
        except TypeError:
            continue
        else:
            default_len = len(spec.defaults) if spec.defaults else 0
            # Arguments after "self" that have no default.  Use an
            # explicit end index: ``spec.args[1:-default_len]`` is empty
            # when default_len is 0, which silently dropped every
            # argument of the secondary to_inspect entries.
            required_args = spec.args[1 : len(spec.args) - default_len]
            if i == 0:
                if spec.varargs:
                    vargs = spec.varargs
                pos_args.extend(required_args)
            else:
                kw_args.update((arg, missing) for arg in required_args)

            if default_len:
                assert spec.defaults
                kw_args.update(
                    [
                        (arg, default)
                        for arg, default in zip(
                            spec.args[-default_len:], spec.defaults
                        )
                    ]
                )
    output: List[str] = []

    output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)

    if vargs is not None and hasattr(obj, vargs):
        output.extend([repr(val) for val in getattr(obj, vargs)])

    for arg, defval in kw_args.items():
        if arg in omit_kwarg:
            continue
        # best-effort: an attribute access or comparison that raises just
        # leaves that argument out instead of breaking repr()
        try:
            val = getattr(obj, arg, missing)
            if val is not missing and val != defval:
                output.append("%s=%r" % (arg, val))
        except Exception:
            pass

    if additional_kw:
        for arg, defval in additional_kw:
            try:
                val = getattr(obj, arg, missing)
                if val is not missing and val != defval:
                    output.append("%s=%r" % (arg, val))
            except Exception:
                pass

    return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
804
805
class portable_instancemethod:
    """Serializable stand-in for a bound method.

    Stores the method's instance and name rather than the method object,
    so the pair can be pickled.  Calling it looks the method up again on
    the stored instance.  ``kwargs`` is merged into every call's keyword
    arguments and overrides values of the same name passed by the caller.

    """

    __slots__ = "target", "name", "kwargs", "__weakref__"

    def __init__(self, meth, kwargs=()):
        self.target = meth.__self__
        self.name = meth.__name__
        self.kwargs = kwargs

    def __getstate__(self):
        return dict(target=self.target, name=self.name, kwargs=self.kwargs)

    def __setstate__(self, state):
        self.target, self.name = state["target"], state["name"]
        # state from older pickles may lack "kwargs"
        self.kwargs = state.get("kwargs", ())

    def __call__(self, *arg, **kw):
        kw.update(self.kwargs)
        method = getattr(self.target, self.name)
        return method(*arg, **kw)
834
835
def class_hierarchy(cls):
    """Return an unordered sequence of all classes related to cls.

    "Related" means reachable from ``cls`` by repeatedly following base
    classes and subclasses, so diamond hierarchies are traversed fully.

    Fibs slightly: subclasses are not followed from classes defined in
    ``builtins``.  Thus class_hierarchy(class A(object)) returns (A, object),
    not A plus every class systemwide that derives from object.

    """
    seen = {cls}
    pending = list(cls.__mro__)

    def _discover(candidates):
        # record and queue every class not met before
        for candidate in candidates:
            if candidate not in seen:
                seen.add(candidate)
                pending.append(candidate)

    while pending:
        current = pending.pop()
        _discover(current.__bases__)

        if current.__module__ == "builtins" or not hasattr(
            current, "__subclasses__"
        ):
            continue

        # on a metaclass, __subclasses__ must be given the class explicitly
        if issubclass(current, type):
            _discover(current.__subclasses__(current))
        else:
            _discover(current.__subclasses__())

    return list(seen)
872
873
def iterate_attributes(cls):
    """Yield ``(name, raw_value)`` for every name in ``dir(cls)``.

    The value is read from the ``__dict__`` of the first class in the MRO
    that defines the name, never via getattr().  Descriptors such as
    ``property`` are therefore returned as-is rather than having
    ``__get__()`` invoked.  Names that no MRO ``__dict__`` contains are
    skipped.

    """
    for name in dir(cls):
        owner = next(
            (klass for klass in cls.__mro__ if name in klass.__dict__), None
        )
        if owner is not None:
            yield (name, owner.__dict__[name])
888
889
def monkeypatch_proxied_specials(
    into_cls,
    from_cls,
    skip=None,
    only=None,
    name="self.proxy",
    from_instance=None,
):
    """Automates delegation of __specials__ for a proxying type.

    For each selected dunder name, compiles a function whose body is
    ``return <name>.<method>(<args>)`` and installs it on ``into_cls``
    with setattr().

    :param into_cls: class that receives the generated methods.
    :param from_cls: class whose dunder methods are mirrored.
    :param skip: names to exclude when ``only`` is not given; defaults to
        the tuple below (``__slots__``, ``__del__``, ``__getattribute__``,
        ...).
    :param only: explicit list of names to delegate; when given, ``skip``
        is ignored and names already present on ``into_cls`` are not
        filtered out.
    :param name: expression text placed in the generated code for the
        delegate target, ``"self.proxy"`` by default.
    :param from_instance: if not None, put into the exec globals under the
        key ``name``.  In that case ``name`` presumably needs to be a plain
        identifier for the lookup to work -- confirm with callers.

    Only the positional parameter names of each original method are
    reproduced.  If introspection raises TypeError, the signature falls
    back to ``(self, *args, **kw)``.
    """

    if only:
        dunders = only
    else:
        if skip is None:
            skip = (
                "__slots__",
                "__del__",
                "__getattribute__",
                "__metaclass__",
                "__getstate__",
                "__setstate__",
            )
        # dunders on from_cls that into_cls doesn't already have
        dunders = [
            m
            for m in dir(from_cls)
            if (
                m.startswith("__")
                and m.endswith("__")
                and not hasattr(into_cls, m)
                and m not in skip
            )
        ]

    for method in dunders:
        try:
            maybe_fn = getattr(from_cls, method)
            # non-callable dunders (e.g. __doc__) are not delegated
            if not hasattr(maybe_fn, "__call__"):
                continue
            # unwrap method objects to their underlying function
            maybe_fn = getattr(maybe_fn, "__func__", maybe_fn)
            fn = cast(types.FunctionType, maybe_fn)

        except AttributeError:
            continue
        try:
            spec = compat.inspect_getfullargspec(fn)
            fn_args = compat.inspect_formatargspec(spec[0])
            # same arguments minus the first one (self)
            d_args = compat.inspect_formatargspec(spec[0][1:])
        except TypeError:
            fn_args = "(self, *args, **kw)"
            d_args = "(*args, **kw)"

        # NOTE: formatted from locals(), so the local names method,
        # fn_args, name and d_args are part of this template.
        py = (
            "def %(method)s%(fn_args)s: "
            "return %(name)s.%(method)s%(d_args)s" % locals()
        )

        env: Dict[str, types.FunctionType] = (
            from_instance is not None and {name: from_instance} or {}
        )
        exec(py, env)
        try:
            env[method].__defaults__ = fn.__defaults__
        except AttributeError:
            pass
        setattr(into_cls, method, env[method])
955
956
def methods_equivalent(meth1, meth2):
    """Return True if the two methods are the same implementation.

    Bound methods are compared by their underlying ``__func__``; anything
    without ``__func__`` is compared as-is, by identity.
    """

    def _underlying(candidate):
        return getattr(candidate, "__func__", candidate)

    return _underlying(meth1) is _underlying(meth2)
963
964
def as_interface(obj, cls=None, methods=None, required=None):
    """Ensure basic interface compliance for an instance or dict of callables.

    Checks that ``obj`` implements public methods of ``cls`` or has members
    listed in ``methods``. If ``required`` is not supplied, implementing at
    least one interface method is sufficient. Methods present on ``obj`` that
    are not in the interface are ignored.

    If ``obj`` is a dict and ``dict`` does not meet the interface
    requirements, the keys of the dictionary are inspected. Keys present in
    ``obj`` that are not in the interface will raise TypeErrors.

    Raises TypeError if ``obj`` does not meet the interface criteria.

    In all passing cases, an object with callable members is returned. In the
    simple case, ``obj`` is returned as-is; if dict processing kicks in then
    an anonymous class is returned.

    obj
      A type, instance, or dictionary of callables.
    cls
      Optional, a type. All public methods of cls are considered the
      interface. An ``obj`` instance of cls will always pass, ignoring
      ``required``.
    methods
      Optional, a sequence of method names to consider as the interface.
    required
      Optional, a sequence of mandatory implementations. If omitted, an
      ``obj`` that provides at least one interface method is considered
      sufficient. As a convenience, required may be a type, in which case
      every interface method (as computed from ``cls`` / ``methods``) is
      required.

    """
    if not cls and not methods:
        raise TypeError("a class or collection of method names are required")

    if isinstance(cls, type) and isinstance(obj, cls):
        return obj

    if methods:
        interface = set(methods)
    else:
        interface = {m for m in dir(cls) if not m.startswith("_")}
    implemented = set(dir(obj))

    # NOTE(review): a type passed as ``required`` selects the computed
    # interface, not the public methods of ``required`` itself; the two
    # only coincide when ``required is cls``.
    if isinstance(required, type):
        required, complies = interface, operator.ge
    elif required:
        required, complies = set(required), operator.ge
    else:
        # nothing mandatory: at least one interface method must exist
        required, complies = set(), operator.gt

    if complies(implemented & interface, required):
        return obj

    # No dict duck typing here.
    if not isinstance(obj, dict):
        qualifier = "any of" if complies is operator.gt else "all of"
        raise TypeError(
            "%r does not implement %s: %s"
            % (obj, qualifier, ", ".join(interface))
        )

    class AnonymousInterface:
        """A callable-holding shell."""

    if cls:
        AnonymousInterface.__name__ = "Anonymous" + cls.__name__

    found = set()
    for method, impl in dictlike_iteritems(obj):
        if method not in interface:
            raise TypeError("%r: unknown in this interface" % method)
        if not callable(impl):
            raise TypeError("%r=%r is not callable" % (method, impl))
        setattr(AnonymousInterface, method, staticmethod(impl))
        found.add(method)

    if not complies(found, required):
        raise TypeError(
            "dictionary does not contain required keys %s"
            % ", ".join(required - found)
        )
    return AnonymousInterface
1049
1050
# Self-type for generic_fn_descriptor: class-level access (obj is None) is
# typed as returning the descriptor (sub)class itself.
_GFD = TypeVar("_GFD", bound="generic_fn_descriptor[Any]")
1052
1053
class generic_fn_descriptor(Generic[_T_co]):
    """Descriptor which proxies a function when the attribute is not
    present in dict.

    This superclass is organized in a particular way with "memoized" and
    "non-memoized" implementation classes that are hidden from type checkers,
    as Mypy seems to not be able to handle seeing multiple kinds of descriptor
    classes used for the same attribute.

    This base class only declares the interface.  ``__get__``, ``_reset``
    and ``reset`` raise NotImplementedError here and are supplied by the
    implementation subclasses.

    """

    # the wrapped getter; __name__ is taken from it (the memoizing
    # subclass uses it as the instance __dict__ key)
    fget: Callable[..., _T_co]
    __doc__: Optional[str]
    __name__: str

    def __init__(self, fget: Callable[..., _T_co], doc: Optional[str] = None):
        self.fget = fget
        # an explicit doc wins; otherwise inherit the getter's docstring
        self.__doc__ = doc or fget.__doc__
        self.__name__ = fget.__name__

    # Typing contract: access on the class (obj is None) gives back the
    # descriptor itself; access on an instance gives the getter's value.
    @overload
    def __get__(self: _GFD, obj: None, cls: Any) -> _GFD: ...

    @overload
    def __get__(self, obj: object, cls: Any) -> _T_co: ...

    def __get__(self: _GFD, obj: Any, cls: Any) -> Union[_GFD, _T_co]:
        raise NotImplementedError()

    # Declared for type checkers only, so assignment / deletion of the
    # attribute type-check.  Not defining them at runtime keeps this a
    # non-data descriptor, so a value stored in the instance __dict__
    # takes precedence over it on lookup.
    if TYPE_CHECKING:

        def __set__(self, instance: Any, value: Any) -> None: ...

        def __delete__(self, instance: Any) -> None: ...

    def _reset(self, obj: Any) -> None:
        raise NotImplementedError()

    @classmethod
    def reset(cls, obj: Any, name: str) -> None:
        raise NotImplementedError()
1095
1096
class _non_memoized_property(generic_fn_descriptor[_T_co]):
    """Plain descriptor that calls ``fget`` on every instance access.

    It exists so that a non-caching attribute can be declared with the
    same descriptor family as memoized_property, which mypy then treats
    as equivalent.

    """

    if not TYPE_CHECKING:

        def __get__(self, obj, cls):
            # class-level access hands back the descriptor itself
            return self if obj is None else self.fget(obj)
1112
1113
class _memoized_property(generic_fn_descriptor[_T_co]):
    """A @property-like descriptor whose getter runs once per instance.

    The first instance access calls ``fget`` and stores the result in the
    instance ``__dict__`` under the getter's name.  No ``__set__`` exists
    at runtime, so later lookups find the stored value directly.
    ``reset`` / ``_reset`` discard the stored value again.

    """

    if not TYPE_CHECKING:

        def __get__(self, obj, cls):
            if obj is None:
                return self
            value = self.fget(obj)
            obj.__dict__[self.__name__] = value
            return value

    def _reset(self, obj):
        # forget this property's cached value on ``obj``, if any
        _memoized_property.reset(obj, self.__name__)

    @classmethod
    def reset(cls, obj, name):
        # popping with a default makes resetting an uncached name a no-op
        obj.__dict__.pop(name, None)
1131
1132
# Despite many attempts to get Mypy to recognize an overridden descriptor
# where one is memoized and the other isn't, there seems to be no reliable
# way other than completely deceiving the type checker into thinking there
# is just one single descriptor type everywhere. Otherwise, if a superclass
# has non-memoized and a subclass has memoized, that requires
# "class memoized(non_memoized)". But then if a superclass has memoized and
# a subclass has non-memoized, the class hierarchy of the descriptors
# would need to be reversed: "class non_memoized(memoized)". So there's no
# way to achieve this.
# Additional issues, RO properties:
# https://github.com/python/mypy/issues/12440
if TYPE_CHECKING:
    # allow memoized and non-memoized to be freely mixed by having them
    # be the same class
    memoized_property = generic_fn_descriptor
    non_memoized_property = generic_fn_descriptor

    # for read only situations, mypy only sees @property as read only.
    # read only is needed when a subtype specializes the return type
    # of a property, meaning assignment needs to be disallowed
    ro_memoized_property = property
    ro_non_memoized_property = property

else:
    # at runtime each "ro_" name is simply the same class as its
    # non-"ro_" counterpart; the read-only distinction exists only for
    # the type checker.
    memoized_property = ro_memoized_property = _memoized_property
    non_memoized_property = ro_non_memoized_property = _non_memoized_property
1159
1160
def memoized_instancemethod(fn: _F) -> _F:
    """Decorate a method so that it memoizes its return value.

    The first call runs ``fn`` and then installs, in the instance
    ``__dict__`` under the method's name, a replacement callable that
    returns that first result.

    Best applied to no-arg methods: memoization is not sensitive to
    argument values, and will always return the same value even when
    called with different arguments.

    """

    def oneshot(self, *args, **kw):
        cached = fn(self, *args, **kw)
        attr_name = fn.__name__

        def memo(*a, **kw):
            return cached

        memo.__name__ = attr_name
        memo.__doc__ = fn.__doc__
        # the instance-level entry shadows the class-level function from
        # now on, so ``fn`` is not called again for this instance
        self.__dict__[attr_name] = memo
        return cached

    return update_wrapper(oneshot, fn)  # type: ignore
1182
1183
class HasMemoized:
    """A mixin class that maintains the names of memoized elements in a
    collection for easy cache clearing, generative, etc.

    Memoized values are stored in the instance ``__dict__``.  Their names
    are accumulated in the ``_memoized_keys`` frozenset, which
    :meth:`._reset_memoizations` uses to discard them all at once.

    """

    if not TYPE_CHECKING:
        # support classes that want to have __slots__ with an explicit
        # slot for __dict__.  (whether the base __slots__ is actually
        # required for that is unverified.)
        __slots__ = ()

    # class-level default; every memoization rebinds a new frozenset on the
    # instance rather than mutating this shared one
    _memoized_keys: FrozenSet[str] = frozenset()

    def _reset_memoizations(self) -> None:
        """Drop every memoized value whose name is recorded."""
        state = self.__dict__
        for attr in self._memoized_keys:
            state.pop(attr, None)

    def _assert_no_memoizations(self) -> None:
        """Assert that no recorded memoized name is currently cached."""
        for attr in self._memoized_keys:
            assert attr not in self.__dict__

    def _set_memoized_attribute(self, key: str, value: Any) -> None:
        """Cache ``value`` as ``key`` and record ``key`` as memoized."""
        self.__dict__[key] = value
        self._memoized_keys |= {key}

    class memoized_attribute(memoized_property[_T]):
        """A read-only @property that is only evaluated once.

        In addition to caching the value in the instance ``__dict__``, the
        attribute name is added to the owner's ``_memoized_keys``.

        :meta private:

        """

        fget: Callable[..., _T]
        __doc__: Optional[str]
        __name__: str

        def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None):
            self.fget = fget
            self.__doc__ = doc or fget.__doc__
            self.__name__ = fget.__name__

        @overload
        def __get__(self: _MA, obj: None, cls: Any) -> _MA: ...

        @overload
        def __get__(self, obj: Any, cls: Any) -> _T: ...

        def __get__(self, obj, cls):
            if obj is None:
                return self
            value = self.fget(obj)
            attr = self.__name__
            obj.__dict__[attr] = value
            obj._memoized_keys |= {attr}
            return value

    @classmethod
    def memoized_instancemethod(cls, fn: _F) -> _F:
        """Decorate a method so that it memoizes its return value.

        After the first call, the instance ``__dict__`` holds a callable
        under the method's name that returns the first result, and that
        name is recorded in ``_memoized_keys``.

        :meta private:

        """

        def oneshot(self: Any, *args: Any, **kw: Any) -> Any:
            first_result = fn(self, *args, **kw)
            method_name = fn.__name__

            def memo(*a, **kw):
                return first_result

            memo.__name__ = method_name
            memo.__doc__ = fn.__doc__
            self.__dict__[method_name] = memo
            self._memoized_keys |= {method_name}
            return first_result

        return update_wrapper(oneshot, fn)  # type: ignore
1259
1260
# Read-only spelling of HasMemoized.memoized_attribute: the type checker
# sees a plain (read-only) ``property``, while at runtime the name is the
# memoized_attribute descriptor class itself.
if TYPE_CHECKING:
    HasMemoized_ro_memoized_attribute = property
else:
    HasMemoized_ro_memoized_attribute = HasMemoized.memoized_attribute
1265
1266
class MemoizedSlots:
    """Apply memoized items to an object using a __getattr__ scheme.

    This allows the functionality of memoized_property and
    memoized_instancemethod to be available to a class using __slots__.

    A subclass defines ``_memoized_attr_<name>()`` to compute attribute
    ``<name>`` on first access, or ``_memoized_method_<name>()`` so that
    method ``<name>`` memoizes its first return value.  Results are stored
    with ``setattr`` under ``<name>``.

    The memoized get is not threadsafe under freethreading and the
    creator method may in extremely rare cases be called more than once.

    """

    __slots__ = ()

    def _fallback_getattr(self, key):
        # called for names that are not memoized; raises by default
        raise AttributeError(key)

    def __getattr__(self, key: str) -> Any:
        # never try to memoize the factory names themselves
        if key.startswith(("_memoized_attr_", "_memoized_method_")):
            raise AttributeError(key)

        # consult the class only, not the instance, to avoid recursion
        # errors when interacting with other __getattr__ schemes that
        # refer to this one.
        owner = self.__class__
        attr_factory = f"_memoized_attr_{key}"
        method_factory = f"_memoized_method_{key}"

        if hasattr(owner, attr_factory):
            computed = getattr(self, attr_factory)()
            setattr(self, key, computed)
            return computed

        if hasattr(owner, method_factory):
            bound = getattr(self, method_factory)

            def oneshot(*args, **kw):
                first = bound(*args, **kw)

                def memo(*a, **kw):
                    return first

                memo.__name__ = bound.__name__
                memo.__doc__ = bound.__doc__
                setattr(self, key, memo)
                return first

            oneshot.__doc__ = bound.__doc__
            return oneshot

        return self._fallback_getattr(key)
1313
1314
1315# from paste.deploy.converters
def asbool(obj: Any) -> bool:
    """Interpret ``obj`` as a boolean.

    Strings are stripped and lower-cased, then matched against common
    true/false spellings; any other string raises ``ValueError``.
    Non-string values are passed to ``bool()``.
    """
    if not isinstance(obj, str):
        return bool(obj)

    token = obj.strip().lower()
    if token in ("true", "yes", "on", "y", "t", "1"):
        return True
    if token in ("false", "no", "off", "n", "f", "0"):
        return False
    raise ValueError("String is not true/false: %r" % token)
1326
1327
def bool_or_str(*text: str) -> Callable[[str], Union[str, bool]]:
    """Build a converter that passes through the given alternate strings
    unchanged and interprets anything else with :func:`.asbool`.

    """
    alternates = text

    def bool_or_value(obj: str) -> Union[str, bool]:
        return obj if obj in alternates else asbool(obj)

    return bool_or_value
1341
1342
def asint(value: Any) -> Optional[int]:
    """Coerce ``value`` to ``int``, passing ``None`` through unchanged."""
    return None if value is None else int(value)
1349
1350
def coerce_kw_type(
    kw: Dict[str, Any],
    key: str,
    type_: Type[Any],
    flexi_bool: bool = True,
    dest: Optional[Dict[str, Any]] = None,
) -> None:
    r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_'
    if necessary, writing the result into ``dest`` (default: ``kw``).

    Values that are ``None``, or already instances of ``type_`` (when
    ``type_`` is a class), are left alone.  When ``type_`` is ``bool`` and
    'flexi_bool' is True, the value is interpreted with :func:`.asbool`,
    so strings such as '0' or 'false' are considered false.
    """
    target = kw if dest is None else dest

    if key not in kw:
        return
    raw = kw[key]
    # already the desired type (only checkable when type_ is a real class)
    if isinstance(type_, type) and isinstance(raw, type_):
        return
    if raw is None:
        return

    if type_ is bool and flexi_bool:
        target[key] = asbool(raw)
    else:
        target[key] = type_(raw)
1375
1376
def constructor_key(obj: Any, cls: Type[Any]) -> Tuple[Any, ...]:
    """Build a cache-key tuple for ``obj`` from ``cls``'s constructor args.

    The result is ``cls`` followed by ``(name, value)`` pairs for each
    constructor keyword of ``cls`` that is present in ``obj.__dict__``.

    """
    accepted = get_cls_kwargs(cls)
    state = obj.__dict__
    pairs = tuple(
        (arg_name, state[arg_name]) for arg_name in accepted if arg_name in state
    )
    return (cls, *pairs)
1386
1387
def constructor_copy(obj: _T, cls: Type[_T], *args: Any, **kw: Any) -> _T:
    """Instantiate ``cls`` using values from ``obj.__dict__`` as keyword
    arguments.

    Only constructor keywords of ``cls`` (as found by inspection) that are
    not already given in ``kw`` are filled in from ``obj``.

    """
    accepted = get_cls_kwargs(cls)
    state = obj.__dict__
    kw.update(
        {
            arg_name: state[arg_name]
            for arg_name in accepted.difference(kw)
            if arg_name in state
        }
    )
    return cls(*args, **kw)
1400
1401
def counter() -> Callable[[], int]:
    """Return a threadsafe counter function.

    The returned callable yields 1, 2, 3, ... on successive calls; each
    counter has its own sequence and its own lock.
    """
    guard = threading.Lock()
    sequence = itertools.count(1)

    def _next():
        with guard:
            return next(sequence)

    return _next
1414
1415
def duck_type_collection(
    specimen: Any, default: Optional[Type[Any]] = None
) -> Optional[Type[Any]]:
    """Guess which basic collection type (list, set or dict) ``specimen``
    is or acts like.

    ``specimen`` may be an instance or a class.  An ``__emulates__``
    attribute takes precedence; otherwise real subclassing is checked,
    then the presence of ``append`` / ``add`` / ``set`` methods.  Returns
    ``default`` when nothing matches.
    """
    if hasattr(specimen, "__emulates__"):
        emulated = specimen.__emulates__
        # canonicalize set vs sets.Set to a standard: the builtin set
        if emulated is not None and issubclass(emulated, set):
            return set
        return emulated  # type: ignore

    check = issubclass if isinstance(specimen, type) else isinstance
    for builtin_type in (list, set, dict):
        if check(specimen, builtin_type):
            return builtin_type

    for method_name, implied in (("append", list), ("add", set), ("set", dict)):
        if hasattr(specimen, method_name):
            return implied

    return default
1449
1450
def assert_arg_type(
    arg: Any, argtype: Union[Tuple[Type[Any], ...], Type[Any]], name: str
) -> Any:
    """Return ``arg`` unchanged if it is an instance of ``argtype`` (a type
    or tuple of types); otherwise raise :class:`.exc.ArgumentError` naming
    the argument ``name``.
    """
    if isinstance(arg, argtype):
        return arg

    if isinstance(argtype, tuple):
        expected = " or ".join("'%s'" % a for a in argtype)
        raise exc.ArgumentError(
            "Argument '%s' is expected to be one of type %s, got '%s'"
            % (name, expected, type(arg))
        )

    raise exc.ArgumentError(
        "Argument '%s' is expected to be of type '%s', got '%s'"
        % (name, argtype, type(arg))
    )
1467
1468
def dictlike_iteritems(dictlike):
    """Return ``(key, value)`` pairs for almost any dict-like object.

    Objects with ``items()`` get a list of their items.  Otherwise a lazy
    iterator is built from ``iterkeys()`` or ``keys()`` combined with
    ``__getitem__`` (or ``get``).  Raises ``TypeError`` when no usable
    accessor exists.
    """
    if hasattr(dictlike, "items"):
        return list(dictlike.items())

    getter = getattr(dictlike, "__getitem__", getattr(dictlike, "get", None))
    if getter is None:
        raise TypeError("Object '%r' is not dict-like" % dictlike)

    if hasattr(dictlike, "iterkeys"):

        def _pairs():
            for k in dictlike.iterkeys():
                assert getter is not None
                yield k, getter(k)

        return _pairs()

    if hasattr(dictlike, "keys"):
        return ((k, getter(k)) for k in dictlike.keys())

    raise TypeError("Object '%r' is not dict-like" % dictlike)
1491
1492
class classproperty(property):
    """A decorator that behaves like @property except that it operates
    on classes rather than instances.

    The getter always receives the owning class, whether the attribute is
    read from the class or from an instance.

    The decorator is currently special when using the declarative
    module, but note that the
    :class:`~.sqlalchemy.ext.declarative.declared_attr`
    decorator should be used for this purpose with declarative.

    """

    fget: Callable[[Any], Any]

    def __init__(self, fget: Callable[[Any], Any], *arg: Any, **kw: Any):
        super().__init__(fget, *arg, **kw)
        # surface the getter's docstring on the descriptor itself
        self.__doc__ = fget.__doc__

    def __get__(self, obj: Any, cls: Optional[type] = None) -> Any:
        owner = cls
        return self.fget(owner)
1512
1513
class hybridproperty(Generic[_T]):
    """A property-like descriptor with a separate class-level getter.

    Instance access calls ``func(instance)``; class access calls
    ``clslevel(owner)``, which is ``func`` until replaced via
    :meth:`classlevel`.
    """

    def __init__(self, func: Callable[..., _T]):
        self.func = func
        self.clslevel = func

    def __get__(self, instance: Any, owner: Any) -> _T:
        if instance is not None:
            return self.func(instance)
        return self.clslevel(owner)

    def classlevel(self, func: Callable[..., Any]) -> hybridproperty[_T]:
        """Use ``func(owner)`` for class-level access; returns ``self``."""
        self.clslevel = func
        return self
1529
1530
class rw_hybridproperty(Generic[_T]):
    """A hybrid property that additionally supports instance assignment.

    Instance access calls ``func(instance)``; class access calls
    ``clslevel(owner)``.  Assignment on an instance calls the function
    registered via :meth:`setter`, which must have been registered.
    """

    def __init__(self, func: Callable[..., _T]):
        self.func = func
        self.clslevel = func
        self.setfn: Optional[Callable[..., Any]] = None

    def __get__(self, instance: Any, owner: Any) -> _T:
        if instance is not None:
            return self.func(instance)
        return self.clslevel(owner)

    def __set__(self, instance: Any, value: Any) -> None:
        assert self.setfn is not None
        self.setfn(instance, value)

    def setter(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
        """Register ``func(instance, value)`` as setter; returns ``self``."""
        self.setfn = func
        return self

    def classlevel(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
        """Use ``func(owner)`` for class-level access; returns ``self``."""
        self.clslevel = func
        return self
1555
1556
class hybridmethod(Generic[_T]):
    """Decorate a function as cls- or instance- level.

    From an instance, ``func`` is bound to the instance; from the class,
    ``clslevel`` (``func`` unless replaced via :meth:`classlevel`) is bound
    to the class.
    """

    def __init__(self, func: Callable[..., _T]):
        self.func = self.__func__ = func
        self.clslevel = func

    def __get__(self, instance: Any, owner: Any) -> Callable[..., _T]:
        if instance is not None:
            return self.func.__get__(  # type: ignore[no-any-return]
                instance, owner
            )
        return self.clslevel.__get__(  # type: ignore[no-any-return]
            owner, owner.__class__
        )

    def classlevel(self, func: Callable[..., Any]) -> hybridmethod[_T]:
        """Use ``func`` for class-level calls; returns ``self``."""
        self.clslevel = func
        return self
1577
1578
class symbol(int):
    """A constant symbol.

    >>> symbol("foo") is symbol("foo")
    True
    >>> symbol("foo")
    symbol('foo')

    A slight refinement of the MAGICCOOKIE=object() pattern.  The primary
    advantage of symbol() is its repr().  They are also singletons.

    Repeated calls of symbol('name') will all return the same instance.
    A symbol is also an ``int``: its value is ``canonical`` when given at
    creation time, otherwise ``hash(name)``.  Requesting an existing name
    with a different non-zero ``canonical`` raises ``TypeError``.

    """

    name: str

    # registry of every symbol created so far, keyed by name
    symbols: Dict[str, symbol] = {}
    _lock = threading.Lock()

    def __new__(
        cls,
        name: str,
        doc: Optional[str] = None,
        canonical: Optional[int] = None,
    ) -> symbol:
        with cls._lock:
            sym = cls.symbols.get(name)
            if sym is None:
                assert isinstance(name, str)
                if canonical is None:
                    canonical = hash(name)
                sym = int.__new__(symbol, canonical)
                sym.name = name
                if doc:
                    sym.__doc__ = doc

                # NOTE: we should ultimately get rid of this global thing,
                # however, currently it is to support pickling. The best
                # change would be when we are on py3.11 at a minimum, we
                # switch to stdlib enum.IntFlag.
                cls.symbols[name] = sym
            else:
                # a falsy canonical (None or 0) skips this consistency check
                if canonical and canonical != sym:
                    raise TypeError(
                        f"Can't replace canonical symbol for {name!r} "
                        f"with new int value {canonical}"
                    )
            return sym

    def __reduce__(self):
        # unpickle through symbol() so the registered singleton is reused;
        # the "x" doc is only applied if the name is not yet registered
        return symbol, (self.name, "x", int(self))

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return f"symbol({self.name!r})"
1637
1638
class _IntFlagMeta(type):
    def __init__(
        cls,
        classname: str,
        bases: Tuple[Type[Any], ...],
        dict_: Dict[str, Any],
        **kw: Any,
    ) -> None:
        """Turn the int-valued attributes of a new class into symbols.

        Dunder names are skipped.  Each int value is replaced on the class
        with a :class:`.symbol` of the same name and value and collected in
        ``_items`` / ``__members__``.  A non-int value under a public name
        raises ``TypeError``; non-int values under ``_``-prefixed names
        are left alone.
        """
        members: List[symbol] = []
        cls._items = members
        for attr_name, raw in dict_.items():
            if re.match(r"^__.*__$", attr_name):
                continue
            if not isinstance(raw, int):
                if attr_name.startswith("_"):
                    continue
                raise TypeError("Expected integer values for IntFlag")
            flag = symbol(attr_name, canonical=raw)
            setattr(cls, attr_name, flag)
            members.append(flag)

        cls.__members__ = _collections.immutabledict(
            {flag.name: flag for flag in members}
        )

    def __iter__(self) -> Iterator[symbol]:
        raise NotImplementedError(
            "iter not implemented to ensure compatibility with "
            "Python 3.11 IntFlag. Please use __members__. See "
            "https://github.com/python/cpython/issues/99304"
        )
1671
1672
class _FastIntFlag(metaclass=_IntFlagMeta):
    """An 'IntFlag' copycat that isn't slow when performing bitwise
    operations.

    Subclasses declare integer class attributes; the ``_IntFlagMeta``
    metaclass replaces each one with a :class:`.symbol`, which is an
    ``int`` subclass, so bitwise operations on members are plain integer
    operations.

    The ``FastIntFlag`` name refers to ``enum.IntFlag`` under TYPE_CHECKING
    and to ``_FastIntFlag`` otherwise.

    """


if TYPE_CHECKING:
    # type checkers see the stdlib IntFlag API
    from enum import IntFlag

    FastIntFlag = IntFlag
else:
    FastIntFlag = _FastIntFlag
1689
1690
_E = TypeVar("_E", bound=enum.Enum)


def parse_user_argument_for_enum(
    arg: Any,
    choices: Dict[_E, List[Any]],
    name: str,
    resolve_symbol_names: bool = False,
) -> Optional[_E]:
    """Resolve a user-supplied parameter to one of a set of choices,
    typically Enum members.

    ``arg`` matches a member if it *is* that member, if it equals the
    member's ``.name`` (only when ``resolve_symbol_names`` is true), or if
    it appears in that member's list of alternates (True/False/None etc.).
    Members are tried in ``choices`` order.

    :param arg: the user argument.
    :param choices: dictionary of enum values to lists of possible
     entries for each.
    :param name: name of the argument. Used in an :class:`.ArgumentError`
     that is raised if the parameter doesn't match any available argument.
    :param resolve_symbol_names: also accept a member's name as a string.
    :return: the matching member, or ``None`` if ``arg`` is ``None`` and no
     member lists ``None`` as an alternate.

    """
    for member, alternates in choices.items():
        if (
            arg is member
            or (resolve_symbol_names and arg == member.name)
            or arg in alternates
        ):
            return member

    if arg is None:
        return None

    raise exc.ArgumentError(f"Invalid value for '{name}': {arg!r}")
1726
1727
# next sequence value handed out by set_creation_order()
_creation_order = 1


def set_creation_order(instance: Any) -> None:
    """Stamp ``instance._creation_order`` with the next value of a
    module-level sequence.

    Lets instances be sorted by creation order (typically within a single
    thread; the read-then-increment below is not particularly threadsafe).

    """
    global _creation_order
    current = _creation_order
    instance._creation_order = current
    _creation_order = current + 1
1742
1743
def warn_exception(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call ``func(*args, **kwargs)`` and return its result.

    Any :class:`Exception` raised is reported through :func:`.warn`
    instead of propagating, and ``None`` is returned in that case.

    """
    try:
        return func(*args, **kwargs)
    except Exception as err:
        warn("%s('%s') ignored" % (type(err), err))
1753
1754
def ellipses_string(value, len_=25):
    """Shorten ``value`` to ``len_`` items followed by ``...``.

    Values whose length is at most ``len_`` are returned unchanged.  Longer
    values are sliced and rendered as the string ``"<slice>..."``.  Objects
    that do not support ``len()`` or slicing are returned unchanged.

    """
    try:
        if len(value) > len_:
            # wrap the slice in a 1-tuple: a bare tuple slice would
            # otherwise be unpacked as the %-format argument list and raise
            # TypeError, returning long tuples untruncated
            return "%s..." % (value[0:len_],)
        else:
            return value
    except TypeError:
        return value
1763
1764
class _hash_limit_string(str):
    """A string subclass that can only be hashed on a maximum amount
    of unique values.

    This is used for warnings so that we can send out parameterized warnings
    without the __warningregistry__ of the module, or the non-overridable
    "once" registry within warnings.py, overloading memory.

    The hash is derived from the un-interpolated template ``value`` plus
    one of ``num`` buckets chosen from the interpolated text, so a given
    template yields at most ``num`` distinct hashes.  Equality follows the
    hash, not the string contents.

    """

    _hash: int

    def __new__(
        cls, value: str, num: int, args: Sequence[Any]
    ) -> _hash_limit_string:
        interpolated = (value % args) + (
            " (this warning may be suppressed after %d occurrences)" % num
        )
        self = super().__new__(cls, interpolated)
        self._hash = hash("%s_%d" % (value, hash(interpolated) % num))
        return self

    def __hash__(self) -> int:
        return self._hash

    def __eq__(self, other: Any) -> bool:
        # compare by bucketed hash so registries keyed on these strings
        # collapse to at most ``num`` entries per template
        return hash(self) == hash(other)

    def __ne__(self, other: Any) -> bool:
        # str.__ne__ would otherwise be inherited and compare contents,
        # making ``a == b`` and ``a != b`` both true
        return not self == other
1793
1794
def warn(msg: str, code: Optional[str] = None) -> None:
    """Issue a :class:`.exc.SAWarning` with the given message.

    With a ``code``, an ``SAWarning`` instance carrying that code is
    emitted; otherwise ``msg`` is emitted with ``SAWarning`` as its
    category.

    """
    if not code:
        _warnings_warn(msg, exc.SAWarning)
    else:
        _warnings_warn(exc.SAWarning(msg, code=code))
1806
1807
def warn_limited(msg: str, args: Sequence[Any]) -> None:
    """Issue an ``SAWarning`` from a %-format template ``msg`` and ``args``.

    When ``args`` is non-empty the text is wrapped in ``_hash_limit_string``
    with 10 buckets, limiting how many distinct registry entries a single
    template can produce.

    """
    text = _hash_limit_string(msg, 10, args) if args else msg
    _warnings_warn(text, exc.SAWarning)
1816
1817
# code object -> (message suffix, default warning category)
_warning_tags: Dict[CodeType, Tuple[str, Type[Warning]]] = {}


def tag_method_for_warnings(
    message: str, category: Type[Warning]
) -> Callable[[_F], _F]:
    """Return a decorator that registers a function's code object.

    Warnings issued through ``_warnings_warn`` while a registered function
    is on the call stack get ``message`` appended and, if no category was
    given, use ``category``.  The decorated function is returned unchanged.

    """

    def go(fn):
        _warning_tags.update({fn.__code__: (message, category)})
        return fn

    return go
1829
1830
# Module ``__name__`` values treated as "inside SQLAlchemy" when choosing the
# frame a warning is attributed to: ``sqlalchemy.*`` except
# ``sqlalchemy.testing*``, plus ``alembic.*``.
_not_sa_pattern = re.compile(r"^(?:sqlalchemy\.(?!testing)|alembic\.)")


def _warnings_warn(
    message: Union[str, Warning],
    category: Optional[Type[Warning]] = None,
    stacklevel: int = 2,
) -> None:
    """Emit a warning attributed to the nearest frame outside SQLAlchemy.

    Starting at the frame ``stacklevel`` levels above this function, the
    stack is walked outward and ``stacklevel`` is incremented once for each
    frame whose module ``__name__`` matches ``_not_sa_pattern``, up to the
    first frame that does not match.  During the walk, any frame whose code
    object was registered by ``tag_method_for_warnings`` appends its suffix
    to ``message`` (turning it into a string) and supplies its category if
    ``category`` is still unset.  The walk stops once an outside frame has
    been found and a tagged frame has been seen, or when the stack is
    exhausted.

    :param message: the warning text, or a ``Warning`` instance.
    :param category: warning class passed to :func:`warnings.warn`; when it
     stays ``None`` it is omitted from the call.
    :param stacklevel: frame offset, relative to this function, at which
     the search starts.

    """
    # adjust the given stacklevel to be outside of SQLAlchemy
    try:
        frame = sys._getframe(stacklevel)
    except ValueError:
        # being called from less than 3 (or given) stacklevels, weird,
        # but don't crash
        stacklevel = 0
    except:
        # _getframe() doesn't work, weird interpreter issue, weird,
        # ok, but don't crash
        # NOTE(review): a bare except also traps KeyboardInterrupt /
        # SystemExit; ``except Exception`` may be what is intended.
        stacklevel = 0
    else:
        stacklevel_found = warning_tag_found = False
        while frame is not None:
            # using __name__ here requires that we have __name__ in the
            # __globals__ of the decorated string functions we make also.
            # we generate this using {"__name__": fn.__module__}
            if not stacklevel_found and not re.match(
                _not_sa_pattern, frame.f_globals.get("__name__", "")
            ):
                # stop incrementing stack level if an out-of-SQLA line
                # were found.
                stacklevel_found = True

                # however, for the warning tag thing, we have to keep
                # scanning up the whole traceback

            if frame.f_code in _warning_tags:
                warning_tag_found = True
                _suffix, _category = _warning_tags[frame.f_code]
                category = category or _category
                message = f"{message} ({_suffix})"

            frame = frame.f_back  # type: ignore[assignment]

            if not stacklevel_found:
                stacklevel += 1
            elif stacklevel_found and warning_tag_found:
                break

    # ``sys._getframe(n)`` here corresponds to ``warnings.warn(stacklevel=
    # n + 1)``, since warnings.warn counts this function as level 1.
    if category is not None:
        warnings.warn(message, category, stacklevel=stacklevel + 1)
    else:
        warnings.warn(message, stacklevel=stacklevel + 1)
1883
1884
def only_once(
    fn: Callable[..., _T], retry_on_exception: bool
) -> Callable[..., Optional[_T]]:
    """Decorate the given function to be a no-op after it is called exactly
    once.

    The first call returns ``fn``'s result; later calls return ``None``.
    If the first call raises and ``retry_on_exception`` is true, ``fn`` is
    re-armed so that the next call tries again; the exception propagates
    either way.
    """

    pending = [fn]

    def go(*arg: Any, **kw: Any) -> Optional[_T]:
        # strong reference fn so that it isn't garbage collected,
        # which interferes with the event system's expectations
        keep_alive = fn  # noqa
        if not pending:
            return None
        target = pending.pop()
        try:
            return target(*arg, **kw)
        except BaseException:
            if retry_on_exception:
                pending.insert(0, target)
            raise

    return go
1909
1910
# default trimming patterns for chop_traceback(): trailing frames from
# sqlalchemy/... .py files, and leading unittest / unit2 frames
_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")


def chop_traceback(
    tb: List[str],
    exclude_prefix: re.Pattern[str] = _UNITTEST_RE,
    exclude_suffix: re.Pattern[str] = _SQLA_RE,
) -> List[str]:
    """Chop extraneous lines off beginning and end of a traceback.

    :param tb:
      a list of traceback lines as returned by ``traceback.format_stack()``

    :param exclude_prefix:
      a regular expression object matching lines to skip at beginning of
      ``tb``

    :param exclude_suffix:
      a regular expression object matching lines to skip at end of ``tb``

    :return: a new list holding the remaining middle section (possibly
      empty).
    """
    first = 0
    stop = len(tb)
    while first < stop and exclude_prefix.search(tb[first]):
        first += 1
    while first < stop and exclude_suffix.search(tb[stop - 1]):
        stop -= 1
    return tb[first:stop]
1939
1940
# module-level alias for the class of ``None``
NoneType = type(None)
1942
1943
def attrsetter(attrname):
    """Return a compiled ``set(obj, value)`` function that performs
    ``obj.<attrname> = value``.

    NOTE: ``attrname`` is interpolated into source passed to ``exec()``;
    it must be a trusted identifier, never external input.
    """
    source = "def set(obj, value): obj.%s = value" % attrname
    namespace = {}
    exec(source, namespace)
    return namespace["set"]
1949
1950
# matches dunder names such as ``__module__`` or ``__slots__``
_dunders = re.compile("^__.+__$")


class TypingOnly:
    """A mixin class that marks a class as 'typing only', meaning it has
    absolutely no methods, attributes, or runtime functionality whatsoever.

    Checked at subclass creation: a class that lists ``TypingOnly``
    directly among its bases may define only dunder names, otherwise
    ``AssertionError`` is raised.  Indirect subclasses are not checked.

    """

    __slots__ = ()

    def __init_subclass__(cls) -> None:
        if TypingOnly in cls.__bases__:
            extra_names = set()
            for attr in cls.__dict__:
                if not _dunders.match(attr):
                    extra_names.add(attr)
            if extra_names:
                raise AssertionError(
                    f"Class {cls} directly inherits TypingOnly but has "
                    f"additional attributes {extra_names}."
                )
        super().__init_subclass__()
1973
1974
class EnsureKWArg:
    r"""Apply translation of functions to accept \**kw arguments if they
    don't already.

    Used to ensure cross-compatibility with third party legacy code, for things
    like compiler visit methods that need to accept ``**kw`` arguments,
    but may have been copied from old code that didn't accept them.

    """

    ensure_kwarg: str
    """a regular expression that indicates method names for which the method
    should accept ``**kw`` arguments.

    The class will scan for methods matching the name template and decorate
    them if necessary to ensure ``**kw`` parameters are accepted.

    """

    def __init_subclass__(cls) -> None:
        # only names defined directly in this subclass's own namespace
        # are examined
        pattern = cls.ensure_kwarg
        own_attrs = cls.__dict__
        if pattern:
            for attr_name in own_attrs:
                if not re.match(pattern, attr_name):
                    continue
                method = own_attrs[attr_name]
                if compat.inspect_getfullargspec(method).varkw:
                    continue
                setattr(cls, attr_name, cls._wrap_w_kw(method))
        super().__init_subclass__()

    @classmethod
    def _wrap_w_kw(cls, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Wrap ``fn`` so that it accepts, and discards, keyword args."""

        def wrap(*arg: Any, **kw: Any) -> Any:
            # keyword arguments are intentionally dropped
            return fn(*arg)

        return update_wrapper(wrap, fn)
2014
2015
def wrap_callable(wrapper, fn):
    """Augment functools.update_wrapper() to work with objects with
    a ``__call__()`` method.

    Objects that have ``__name__`` go through ``update_wrapper``.  For
    other callables, ``wrapper`` takes the class name, the module (if any),
    and the docstring of ``__call__`` or else of the object itself.

    :param fn:
      object with __call__ method

    :return: ``wrapper``, updated in place.

    """
    if hasattr(fn, "__name__"):
        return update_wrapper(wrapper, fn)

    wrapper.__name__ = fn.__class__.__name__
    if hasattr(fn, "__module__"):
        wrapper.__module__ = fn.__module__

    call_doc = getattr(fn.__call__, "__doc__", None)
    if call_doc:
        wrapper.__doc__ = call_doc
    elif fn.__doc__:
        wrapper.__doc__ = fn.__doc__

    return wrapper
2038
2039
def quoted_token_parser(value):
    """Parse a dotted identifier with accommodation for quoted names.

    Includes support for SQL-style double quotes as a literal character:
    inside a quoted name, a doubled ``""`` yields a single ``"``.  Dots
    inside quotes do not split tokens.

    E.g.::

        >>> quoted_token_parser("name")
        ['name']
        >>> quoted_token_parser("schema.name")
        ['schema', 'name']
        >>> quoted_token_parser('"Schema"."Name"')
        ['Schema', 'Name']
        >>> quoted_token_parser('"Schema"."Name""Foo"')
        ['Schema', 'Name"Foo']

    """

    if '"' not in value:
        return value.split(".")

    # 0 = outside of quotes
    # 1 = inside of quotes
    state = 0
    result: List[List[str]] = [[]]
    idx = 0
    lv = len(value)
    while idx < lv:
        char = value[idx]
        if char == '"':
            if state == 1 and idx < lv - 1 and value[idx + 1] == '"':
                # escaped quote inside a quoted name: keep one, skip the pair
                result[-1].append('"')
                idx += 1
            else:
                state ^= 1
        elif char == "." and state == 0:
            # an unquoted dot starts the next token
            result.append([])
        else:
            result[-1].append(char)
        idx += 1

    return ["".join(token) for token in result]
2082
2083
def add_parameter_text(params: Any, text: str) -> Callable[[_F], _F]:
    """Return a decorator that injects ``text`` into the ``:param:`` entry
    of each name in ``params`` within the decorated function's docstring.

    A missing docstring is replaced with an empty string.
    """
    param_names = _collections.to_list(params)

    def decorate(fn):
        doc = fn.__doc__ or ""
        if doc:
            doc = inject_param_text(doc, dict.fromkeys(param_names, text))
        fn.__doc__ = doc
        return fn

    return decorate
2095
2096
def _dedent_docstring(text: str) -> str:
    """Dedent a docstring, tolerating an unindented first line.

    When the first line does not start with a space (text placed right
    after the opening quotes), only the remaining lines are dedented;
    otherwise the whole text is dedented.  Single-line text is returned
    as-is.
    """
    first, newline, rest = text.partition("\n")
    if not newline:
        return text
    if first.startswith(" "):
        return textwrap.dedent(text)
    return first + "\n" + textwrap.dedent(rest)
2107
2108
def inject_docstring_text(
    given_doctext: Optional[str], injecttext: str, pos: int
) -> str:
    """Insert dedented ``injecttext`` into a docstring at a paragraph break.

    Candidate insertion points are the top of the (dedented) docstring
    followed by every blank line.  ``pos`` indexes that list and is clamped
    to its last entry.  The text is inserted just before the chosen line,
    with a leading blank line added if ``injecttext`` does not begin with
    one.
    """
    doc_lines = _dedent_docstring(given_doctext or "").split("\n")
    if len(doc_lines) == 1:
        doc_lines.append("")

    new_lines = textwrap.dedent(injecttext).split("\n")
    if new_lines[0]:
        new_lines.insert(0, "")

    candidates = [0] + [
        idx for idx, text in enumerate(doc_lines) if not text.strip()
    ]
    insert_at = candidates[min(pos, len(candidates) - 1)]

    return "\n".join(doc_lines[:insert_at] + new_lines + doc_lines[insert_at:])
2127
2128
# matches an indented ``:param name:`` line; group 1 is the indentation,
# group 2 the parameter name (possibly prefixed with ``*`` / ``**``)
_param_reg = re.compile(r"(\s+):param (.+?):")


def inject_param_text(doctext: str, inject_params: Dict[str, str]) -> str:
    """Append text to the ``:param:`` sections of selected parameters.

    For each ``:param <name>:`` line whose name (leading ``*`` stripped)
    is a key of ``inject_params``, the corresponding text is emitted at
    the end of that parameter's section.  The section ends at the next
    ``:param:`` line, at the next blank line, or at the end of the
    docstring.  The text is indented like the line following the
    ``:param:`` line if that line is indented, otherwise one space deeper
    than the ``:param:`` line.

    :param doctext: the docstring to modify.
    :param inject_params: mapping of parameter name to text to inject.
    :return: the modified docstring (trailing newline not preserved).
    """
    doclines = collections.deque(doctext.splitlines())
    lines = []

    # TODO: this is not working for params like ":param case_sensitive=True:"

    to_inject = None
    while doclines:
        line = doclines.popleft()

        m = _param_reg.match(line)

        if to_inject is None:
            if m:
                param = m.group(2).lstrip("*")
                if param in inject_params:
                    # default indent to that of :param: plus one
                    indent = " " * len(m.group(1)) + " "

                    # but if the next line has text, use that line's
                    # indentation
                    if doclines:
                        m2 = re.match(r"(\s+)\S", doclines[0])
                        if m2:
                            indent = " " * len(m2.group(1))

                    to_inject = indent + inject_params[param]
        elif m:
            lines.extend(["\n", to_inject, "\n"])
            to_inject = None
        elif not line.rstrip():
            lines.extend([line, to_inject, "\n"])
            to_inject = None
        elif line.endswith("::"):
            # TODO: this still won't cover if the code example itself has
            # blank lines in it, need to detect those via indentation.
            lines.append(line)
            # a trailing "::" line has no following line to carry over
            if doclines:
                lines.append(doclines.popleft())
            continue
        lines.append(line)

    if to_inject is not None:
        # the docstring ended inside the target parameter's section;
        # emit the pending text rather than silently dropping it
        lines.extend(["", to_inject])

    return "\n".join(lines)
2173
2174
def repr_tuple_names(names: List[str]) -> Optional[str]:
    """Summarize ``names`` as a comma-separated string of up to four names.

    More than four names are trimmed from the middle: the first three and
    the last are kept, separated by ``...``.  Names longer than 11
    characters are cut to 11 characters followed by ``..``.  Returns
    ``None`` for an empty list.
    """
    if not names:
        return None

    def shorten(name):
        return "%s.." % name[:11] if len(name) > 11 else name

    if len(names) <= 4:
        return ", ".join(shorten(n) for n in names)
    leading = ", ".join(shorten(n) for n in names[:3])
    return "%s, ..., %s" % (leading, shorten(names[-1]))
2187
2188
def has_compiled_ext(raise_=False):
    """Report whether the Cython extensions are available.

    :param raise_: when true, raise ``ImportError`` instead of returning
     ``False`` if the extensions are missing.
    """
    if not HAS_CYEXTENSION and raise_:
        raise ImportError(
            "cython extensions were expected to be installed, "
            "but are not present"
        )
    return bool(HAS_CYEXTENSION)
2199
2200
class _Missing(enum.Enum):
    # single-member enum whose member serves as a sentinel that is
    # distinct from None and can be named in a ``Literal`` type
    Missing = enum.auto()


# the sentinel value itself
Missing = _Missing.Missing
# type alias: either a ``_T`` value or the ``Missing`` sentinel
MissingOr = Union[_T, Literal[_Missing.Missing]]