1# util/langhelpers.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Routines to help with the creation, loading and introspection of
10modules, classes, hierarchies, attributes, functions, and methods.
11
12"""
13from __future__ import annotations
14
15import collections
16import enum
17from functools import update_wrapper
18import inspect
19import itertools
20import operator
21import re
22import sys
23import textwrap
24import threading
25import types
26from types import CodeType
27from typing import Any
28from typing import Callable
29from typing import cast
30from typing import Dict
31from typing import FrozenSet
32from typing import Generic
33from typing import Iterator
34from typing import List
35from typing import NoReturn
36from typing import Optional
37from typing import overload
38from typing import Sequence
39from typing import Set
40from typing import Tuple
41from typing import Type
42from typing import TYPE_CHECKING
43from typing import TypeVar
44from typing import Union
45import warnings
46
47from . import _collections
48from . import compat
49from ._has_cy import HAS_CYEXTENSION
50from .typing import Literal
51from .. import exc
52
# Type variables shared by the helpers in this module.
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
# any callable; lets a decorator be typed as returning what it was given
_F = TypeVar("_F", bound=Callable[..., Any])
# the bounds below are given as strings, i.e. as forward references
_MP = TypeVar("_MP", bound="memoized_property[Any]")
_MA = TypeVar("_MA", bound="HasMemoized.memoized_attribute[Any]")
_HP = TypeVar("_HP", bound="hybridproperty[Any]")
_HM = TypeVar("_HM", bound="hybridmethod[Any]")
60
61
def md5_hex(x: Any) -> str:
    """Return the hexadecimal MD5 digest of ``x`` encoded as UTF-8.

    The hash object comes from ``compat.md5_not_for_security()``.

    """
    payload = x.encode("utf-8")
    hasher = compat.md5_not_for_security()
    hasher.update(payload)
    return cast(str, hasher.hexdigest())
67
68
class safe_reraise:
    """Context manager that re-raises the exception being handled once
    the enclosed handler code has run.

    The current exception info is captured on entry, so that it is
    maintained across a potential coroutine context switch while the
    handler code runs.

    e.g.::

        try:
            sess.commit()
        except:
            with safe_reraise():
                sess.rollback()

    If the handler code itself raises, that newer exception is the one
    propagated.

    TODO: we should at some point evaluate current behaviors in this regard
    based on current greenlet, gevent/eventlet implementations in Python 3,
    and also see the degree to which our own asyncio (based on greenlet also)
    is impacted by this. .rollback() will cause IO / context switch to occur
    in all these scenarios; what happens to the exception context from an
    "except:" block if we don't explicitly store it? Original issue was
    #2703.

    """

    __slots__ = ("_exc_info",)

    _exc_info: Union[
        None,
        Tuple[
            Type[BaseException],
            BaseException,
            types.TracebackType,
        ],
        Tuple[None, None, None],
    ]

    def __enter__(self) -> None:
        # capture whatever exception is being handled right now
        self._exc_info = sys.exc_info()

    def __exit__(
        self,
        type_: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[types.TracebackType],
    ) -> NoReturn:
        stored = self._exc_info
        assert stored is not None
        # see #2703 for notes
        if type_ is not None:
            # the handler code raised; propagate that exception instead
            self._exc_info = None  # remove potential circular references
            assert value is not None
            raise value.with_traceback(traceback)

        _, pending, pending_tb = stored
        assert pending is not None
        self._exc_info = None  # remove potential circular references
        raise pending.with_traceback(pending_tb)
126
127
def walk_subclasses(cls: Type[_T]) -> Iterator[Type[_T]]:
    """Yield ``cls`` and every direct or indirect subclass of it, once each.

    Traversal is depth-first using an explicit stack; a class reachable
    through several inheritance paths is yielded only the first time it
    is encountered.

    """
    visited: Set[Any] = set()
    pending = [cls]

    while pending:
        current = pending.pop()
        if current not in visited:
            visited.add(current)
            pending.extend(current.__subclasses__())
            yield current
140
141
def string_or_unprintable(element: Any) -> str:
    """Return ``element`` as a string, never raising from ``str()``.

    Strings are returned unchanged.  For any other object ``str(element)``
    is returned; if that call raises, a placeholder built from
    ``repr(element)`` is returned instead.

    """
    if isinstance(element, str):
        return element

    try:
        return str(element)
    except Exception:
        # wrap in a 1-tuple: a bare tuple (or tuple subclass) on the right
        # of "%" is unpacked as the argument list and would raise TypeError
        return "unprintable element %r" % (element,)
150
151
def clsname_as_plain_name(
    cls: Type[Any], use_name: Optional[str] = None
) -> str:
    """Turn a CamelCase class name into lower-case, space-separated words.

    ``use_name``, when given and non-empty, is used instead of
    ``cls.__name__``.  Only capitalized words (one upper-case letter
    followed by lower-case letters) and the literal token ``SQL`` are
    picked up; any other characters are dropped.

    """
    source = use_name or cls.__name__
    words = re.findall(r"([A-Z][a-z]+|SQL)", source)
    return " ".join(map(str.lower, words))
157
158
def method_is_overridden(
    instance_or_cls: Union[Type[Any], object],
    against_method: Callable[..., Any],
) -> bool:
    """Return True if the class's method differs from ``against_method``.

    ``instance_or_cls`` may be a class or an instance; for an instance its
    ``__class__`` is used.  The attribute named ``against_method.__name__``
    is looked up on that class and compared to ``against_method`` with
    ``!=``.

    """
    if isinstance(instance_or_cls, type):
        owner = instance_or_cls
    else:
        owner = instance_or_cls.__class__

    found: types.MethodType = getattr(owner, against_method.__name__)
    return found != against_method
175
176
def decode_slice(slc: slice) -> Tuple[Any, ...]:
    """Decode a slice object as sent to ``__getitem__``.

    Returns ``(start, stop, step)``, where each member that provides
    ``__index__()`` has been replaced by the result of calling it; other
    members (such as ``None``) are passed through unchanged.

    """

    def _as_index(member: Any) -> Any:
        if hasattr(member, "__index__"):
            return member.__index__()
        return member

    return tuple(_as_index(m) for m in (slc.start, slc.stop, slc.step))
189
190
191def _unique_symbols(used: Sequence[str], *bases: str) -> Iterator[str]:
192 used_set = set(used)
193 for base in bases:
194 pool = itertools.chain(
195 (base,),
196 map(lambda i: base + str(i), range(1000)),
197 )
198 for sym in pool:
199 if sym not in used_set:
200 used_set.add(sym)
201 yield sym
202 break
203 else:
204 raise NameError("exhausted namespace for symbol base %s" % base)
205
206
def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]:
    """Yield ``fn(bit)`` for each nonzero bit of ``n``, lowest bit first.

    ``bit`` is the value of the bit (1, 2, 4, ...), not its position.

    """
    remaining = n
    while remaining:
        # two's complement trick: isolates the lowest set bit
        lowest = remaining & -remaining
        yield fn(lowest)
        remaining ^= lowest
214
215
_Fn = TypeVar("_Fn", bound="Callable[..., Any]")

# this seems to be in flux in recent mypy versions


def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
    """A signature-matching decorator factory.

    Returns a decorator which, applied to a function or method ``fn``,
    generates (via ``exec``) a new function with the same name and
    argument list as ``fn``.  The generated function calls
    ``target(fn, <arguments>)`` and returns the result; for coroutine
    functions an ``async def`` is generated that awaits ``target``.

    Raises ``Exception`` if the decorated object is neither a function
    nor a method.

    """

    def decorate(fn: _Fn) -> _Fn:
        if not inspect.isfunction(fn) and not inspect.ismethod(fn):
            raise Exception("not a decoratable function")

        # Python 3.14 defers creating __annotations__ until it is used.
        # We do not want to create __annotations__ now, so __annotate__ is
        # temporarily set to None while the argspec is read and restored
        # afterwards.
        annofunc = getattr(fn, "__annotate__", None)
        if annofunc is not None:
            fn.__annotate__ = None  # type: ignore[union-attr]
            try:
                spec = compat.inspect_getfullargspec(fn)
            finally:
                fn.__annotate__ = annofunc  # type: ignore[union-attr]
        else:
            spec = compat.inspect_getfullargspec(fn)

        # Do not generate code for annotations.
        # update_wrapper() copies the annotation from fn to decorated.
        # We use dummy defaults for code generation to avoid having
        # copy of large globals for compiling.
        # We copy __defaults__ and __kwdefaults__ from fn to decorated.
        empty_defaults = (None,) * len(spec.defaults or ())
        empty_kwdefaults = dict.fromkeys(spec.kwonlydefaults or ())
        spec = spec._replace(
            annotations={},
            defaults=empty_defaults,
            kwonlydefaults=empty_kwdefaults,
        )

        # every name the generated function already uses: positional
        # args, *varargs / **varkw names, and the function's own name
        names = (
            tuple(cast("Tuple[str, ...]", spec[0]))
            + cast("Tuple[str, ...]", spec[1:3])
            + (fn.__name__,)
        )
        # pick names for "target" and "fn" that don't collide with the above
        targ_name, fn_name = _unique_symbols(names, "target", "fn")

        # substitution values for the code templates below
        metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
        metadata.update(format_argspec_plus(spec, grouped=False))
        metadata["name"] = fn.__name__

        if inspect.iscoroutinefunction(fn):
            metadata["prefix"] = "async "
            metadata["target_prefix"] = "await "
        else:
            metadata["prefix"] = ""
            metadata["target_prefix"] = ""

        # look for __ positional arguments.  This is a convention in
        # SQLAlchemy that arguments should be passed positionally
        # rather than as keyword
        # arguments.  note that apply_pos doesn't currently work in all cases
        # such as when a kw-only indicator "*" is present, which is why
        # we limit the use of this to just that case we can detect.  As we add
        # more kinds of methods that use @decorator, things may have to
        # be further improved in this area
        if "__" in repr(spec[0]):
            code = (
                """\
%(prefix)sdef %(name)s%(grouped_args)s:
    return %(target_prefix)s%(target)s(%(fn)s, %(apply_pos)s)
"""
                % metadata
            )
        else:
            code = (
                """\
%(prefix)sdef %(name)s%(grouped_args)s:
    return %(target_prefix)s%(target)s(%(fn)s, %(apply_kw)s)
"""
                % metadata
            )

        # globals for the generated function: only target, fn and the
        # module name of fn
        env: Dict[str, Any] = {
            targ_name: target,
            fn_name: fn,
            "__name__": fn.__module__,
        }

        decorated = cast(
            types.FunctionType,
            _exec_code_in_env(code, env, fn.__name__),
        )
        # replace the dummy defaults used for code generation with the
        # real ones
        decorated.__defaults__ = fn.__defaults__
        decorated.__kwdefaults__ = fn.__kwdefaults__  # type: ignore
        return update_wrapper(decorated, fn)  # type: ignore[return-value]

    return update_wrapper(decorate, target)  # type: ignore[return-value]
311
312
313def _exec_code_in_env(
314 code: Union[str, types.CodeType], env: Dict[str, Any], fn_name: str
315) -> Callable[..., Any]:
316 exec(code, env)
317 return env[fn_name] # type: ignore[no-any-return]
318
319
# NOTE(review): no use of these two type variables is visible near their
# definition -- confirm they are still referenced elsewhere in the module.
_PF = TypeVar("_PF")
_TE = TypeVar("_TE")
322
323
class PluginLoader:
    """Registry that resolves plugin names within a group to loaded
    objects.

    Each entry in ``impls`` maps a name to a zero-argument callable that
    returns the plugin object.  Lookups fall back first to the optional
    ``auto_fn`` and then to the entries that
    ``compat.importlib_metadata_get()`` returns for the group.

    """

    def __init__(
        self, group: str, auto_fn: Optional[Callable[..., Any]] = None
    ):
        self.group = group
        self.impls: Dict[str, Any] = {}
        self.auto_fn = auto_fn

    def clear(self):
        """Forget every registered or cached loader."""
        self.impls.clear()

    def load(self, name: str) -> Any:
        """Return the plugin object registered under ``name``.

        Raises ``exc.NoSuchModuleError`` if the name cannot be resolved.

        """
        try:
            known = self.impls[name]
        except KeyError:
            pass
        else:
            return known()

        if self.auto_fn:
            discovered = self.auto_fn(name)
            if discovered:
                # cache the loader so auto_fn is not consulted again
                self.impls[name] = discovered
                return discovered()

        for entry in compat.importlib_metadata_get(self.group):
            if entry.name == name:
                self.impls[name] = entry.load
                return entry.load()

        raise exc.NoSuchModuleError(
            "Can't load plugin: %s:%s" % (self.group, name)
        )

    def register(self, name: str, modulepath: str, objname: str) -> None:
        """Register a lazy loader for ``objname`` within ``modulepath``.

        Nothing is imported until the plugin is loaded.

        """

        def load():
            # __import__("a.b.c") returns the top-level package "a";
            # walk down to the requested submodule attribute by attribute
            module = __import__(modulepath)
            _, *submodules = modulepath.split(".")
            for part in submodules:
                module = getattr(module, part)
            return getattr(module, objname)

        self.impls[name] = load

    def deregister(self, name: str) -> None:
        """Remove the loader for ``name``; KeyError if it is not present."""
        del self.impls[name]
365
366
367def _inspect_func_args(fn):
368 try:
369 co_varkeywords = inspect.CO_VARKEYWORDS
370 except AttributeError:
371 # https://docs.python.org/3/library/inspect.html
372 # The flags are specific to CPython, and may not be defined in other
373 # Python implementations. Furthermore, the flags are an implementation
374 # detail, and can be removed or deprecated in future Python releases.
375 spec = compat.inspect_getfullargspec(fn)
376 return spec[0], bool(spec[2])
377 else:
378 # use fn.__code__ plus flags to reduce method call overhead
379 co = fn.__code__
380 nargs = co.co_argcount
381 return (
382 list(co.co_varnames[:nargs]),
383 bool(co.co_flags & co_varkeywords),
384 )
385
386
@overload
def get_cls_kwargs(
    cls: type,
    *,
    _set: Optional[Set[str]] = None,
    raiseerr: Literal[True] = ...,
) -> Set[str]: ...


@overload
def get_cls_kwargs(
    cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
) -> Optional[Set[str]]: ...


def get_cls_kwargs(
    cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
) -> Optional[Set[str]]:
    r"""Return the full set of inherited kwargs for the given `cls`.

    Looks at the ``__init__`` defined directly on the class and collects
    its argument names.  If that ``__init__`` accepts \**kwargs, or the
    class defines no pure-Python ``__init__`` of its own, the same
    collection is repeated on each base class, stopping at the first base
    for which the recursive call returns ``None``.

    ``_set`` is the accumulator used for the recursive calls; when it is
    passed, a class whose ``__init__`` has no \**kwargs produces ``None``
    (or ``TypeError`` if ``raiseerr`` is set).

    Uses a subset of inspect.getfullargspec() to cut down on method
    overhead, as this is used within the Core typing system to create
    copies of type objects which is a performance-sensitive operation.

    No anonymous tuple arguments please !

    """
    is_root = _set is None
    collected: Set[str] = set() if _set is None else _set

    init = cls.__dict__.get("__init__", False)

    # only a plain Python function can be inspected via its code object
    inspectable = (
        bool(init)
        and isinstance(init, types.FunctionType)
        and isinstance(init.__code__, types.CodeType)
    )

    accepts_kw = False
    if inspectable:
        arg_names, accepts_kw = _inspect_func_args(init)
        collected.update(arg_names)

        if not (accepts_kw or is_root):
            if raiseerr:
                raise TypeError(
                    f"given cls {cls} doesn't have an __init__ method"
                )
            return None

    if accepts_kw or not inspectable:
        for base in cls.__bases__:
            if get_cls_kwargs(base, _set=collected) is None:
                break

    collected.discard("self")
    return collected
453
454
def get_func_kwargs(func: Callable[..., Any]) -> List[str]:
    """Return the list of argument names for the given `func`.

    This is the first member of the spec returned by
    ``compat.inspect_getfullargspec()`` for ``func``.

    """
    full_spec = compat.inspect_getfullargspec(func)
    return full_spec[0]
464
465
def _drop_first_arg(spec: compat.FullArgSpec) -> compat.FullArgSpec:
    """Return a copy of ``spec`` without its first positional argument."""
    return compat.FullArgSpec(
        spec.args[1:],
        spec.varargs,
        spec.varkw,
        spec.defaults,
        spec.kwonlyargs,
        spec.kwonlydefaults,
        spec.annotations,
    )


def get_callable_argspec(
    fn: Callable[..., Any], no_self: bool = False, _is_init: bool = False
) -> compat.FullArgSpec:
    """Return the argument signature for any callable.

    All pure-Python callables are accepted, including
    functions, methods, classes, objects with __call__;
    builtins and other edge cases like functools.partial() objects
    raise a TypeError.

    :param fn: the callable to inspect.  For a class, its ``__init__``
     is inspected.
    :param no_self: when True, the leading ``self`` / ``cls`` argument of
     bound methods and of ``__init__`` is omitted from the result.
    :param _is_init: internal; set when inspecting the ``__init__`` of a
     class.
    :raises TypeError: if ``fn`` is a builtin or cannot be inspected.

    """
    if inspect.isbuiltin(fn):
        raise TypeError("Can't inspect builtin: %s" % fn)
    elif inspect.isfunction(fn):
        spec = compat.inspect_getfullargspec(fn)
        if _is_init and no_self:
            return _drop_first_arg(spec)
        return spec
    elif inspect.ismethod(fn):
        spec = compat.inspect_getfullargspec(fn.__func__)
        # compare against None rather than testing truthiness: the bound
        # instance may be falsy, or may raise from __bool__ / __len__
        if no_self and (_is_init or fn.__self__ is not None):
            return _drop_first_arg(spec)
        return spec
    elif inspect.isclass(fn):
        return get_callable_argspec(
            fn.__init__, no_self=no_self, _is_init=True
        )
    elif hasattr(fn, "__func__"):
        return compat.inspect_getfullargspec(fn.__func__)
    elif hasattr(fn, "__call__"):
        if inspect.ismethod(fn.__call__):
            return get_callable_argspec(fn.__call__, no_self=no_self)
        else:
            raise TypeError("Can't inspect callable: %s" % fn)
    else:
        raise TypeError("Can't inspect callable: %s" % fn)
520
521
def format_argspec_plus(
    fn: Union[Callable[..., Any], compat.FullArgSpec], grouped: bool = True
) -> Dict[str, Optional[str]]:
    """Returns a dictionary of formatted, introspected function arguments.

    An enhanced variant of inspect.formatargspec to support code generation.

    fn
       An inspectable callable or tuple of inspect getargspec() results.
    grouped
      Defaults to True; include (parens, around, argument) lists

    Returns:

    grouped_args
      Full inspect.formatargspec for fn
    self_arg
      The name of the first positional argument, varargs[0], or None
      if the function defines no positional arguments.
    apply_pos
      args, re-written in calling rather than receiving syntax.  Arguments are
      passed positionally.
    apply_kw
      Like apply_pos, except keyword-ish args are passed as keywords.
    apply_pos_proxied
      Like apply_pos but omits the self/cls argument
    apply_kw_proxied
      Like apply_kw but omits the self/cls argument

    Example (the returned dictionary also contains the ``*_proxied``
    keys, not shown here)::

      >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
      {'grouped_args': '(self, a, b, c=3, **d)',
       'self_arg': 'self',
       'apply_kw': '(self, a, b, c=c, **d)',
       'apply_pos': '(self, a, b, c, **d)'}

    """
    if callable(fn):
        spec = compat.inspect_getfullargspec(fn)
    else:
        spec = fn

    # spec is indexed positionally below:
    # 0=args, 1=varargs, 2=varkw, 3=defaults, 4=kwonlyargs
    args = compat.inspect_formatargspec(*spec)

    apply_pos = compat.inspect_formatargspec(
        spec[0], spec[1], spec[2], None, spec[4]
    )

    if spec[0]:
        self_arg = spec[0][0]

        apply_pos_proxied = compat.inspect_formatargspec(
            spec[0][1:], spec[1], spec[2], None, spec[4]
        )

    elif spec[1]:
        # I'm not sure what this is
        self_arg = "%s[0]" % spec[1]

        apply_pos_proxied = apply_pos
    else:
        self_arg = None
        apply_pos_proxied = apply_pos

    # number of trailing names to pass as keywords: the positional
    # arguments that have defaults plus every keyword-only argument
    num_defaults = 0
    if spec[3]:
        num_defaults += len(cast(Tuple[Any], spec[3]))
    if spec[4]:
        num_defaults += len(spec[4])

    name_args = spec[0] + spec[4]

    defaulted_vals: Union[List[str], Tuple[()]]

    if num_defaults:
        defaulted_vals = name_args[0 - num_defaults :]
    else:
        defaulted_vals = ()

    # the names themselves are handed over as the "default values", so
    # that formatvalue renders each keyword-ish argument as name=name
    apply_kw = compat.inspect_formatargspec(
        name_args,
        spec[1],
        spec[2],
        defaulted_vals,
        formatvalue=lambda x: "=" + str(x),
    )

    if spec[0]:
        apply_kw_proxied = compat.inspect_formatargspec(
            name_args[1:],
            spec[1],
            spec[2],
            defaulted_vals,
            formatvalue=lambda x: "=" + str(x),
        )
    else:
        apply_kw_proxied = apply_kw

    if grouped:
        return dict(
            grouped_args=args,
            self_arg=self_arg,
            apply_pos=apply_pos,
            apply_kw=apply_kw,
            apply_pos_proxied=apply_pos_proxied,
            apply_kw_proxied=apply_kw_proxied,
        )
    else:
        # strip the first and last character (the enclosing parentheses)
        # from the apply_* forms; grouped_args is returned as-is
        return dict(
            grouped_args=args,
            self_arg=self_arg,
            apply_pos=apply_pos[1:-1],
            apply_kw=apply_kw[1:-1],
            apply_pos_proxied=apply_pos_proxied[1:-1],
            apply_kw_proxied=apply_kw_proxied[1:-1],
        )
637
638
def format_argspec_init(method, grouped=True):
    """format_argspec_plus with considerations for typical __init__ methods

    Delegates to format_argspec_plus, substituting fixed signatures for the
    typical __init__ cases that it is not asked about or cannot reflect:

    .. sourcecode:: text

      object.__init__ -> (self)
      other unreflectable (usually C) -> (self, *args, **kwargs)

    """
    if method is not object.__init__:
        try:
            return format_argspec_plus(method, grouped=grouped)
        except TypeError:
            signature = "(self, *args, **kwargs)"
            call_args = signature if grouped else "self, *args, **kwargs"
            forwarded = "(*args, **kwargs)" if grouped else "*args, **kwargs"
    else:
        signature = "(self)"
        call_args = "(self)" if grouped else "self"
        forwarded = "()" if grouped else ""

    return {
        "self_arg": "self",
        "grouped_args": signature,
        "apply_pos": call_args,
        "apply_kw": call_args,
        "apply_pos_proxied": forwarded,
        "apply_kw_proxied": forwarded,
    }
670
671
def create_proxy_methods(
    target_cls: Type[Any],
    target_cls_sphinx_name: str,
    proxy_cls_sphinx_name: str,
    classmethods: Sequence[str] = (),
    methods: Sequence[str] = (),
    attributes: Sequence[str] = (),
    use_intermediate_variable: Sequence[str] = (),
) -> Callable[[_T], _T]:
    """A class decorator indicating attributes should refer to a proxy
    class.

    At runtime this is only a "marker": none of the arguments are used and
    the decorated class is returned unchanged.  The decorator is instead
    consumed by the tools/generate_proxy_methods.py script, which
    statically generates proxy methods and attributes that are fully
    recognized by typing tools such as mypy.

    """

    def _return_unchanged(cls):
        return cls

    return _return_unchanged
695
696
def getargspec_init(method):
    """inspect.getargspec with considerations for typical __init__ methods

    Returns the spec from ``compat.inspect_getfullargspec()``; when that
    raises TypeError a fixed 4-tuple is returned instead:

    .. sourcecode:: text

      object.__init__ -> (self)
      other unreflectable (usually C) -> (self, *args, **kwargs)

    """
    try:
        spec = compat.inspect_getfullargspec(method)
    except TypeError:
        if method is object.__init__:
            return (["self"], None, None, None)
        return (["self"], "args", "kwargs", None)
    return spec
715
716
def unbound_method_to_callable(func_or_cls):
    """Adjust the incoming callable such that a 'self' argument is not
    required.

    A method object whose ``__self__`` is ``None`` is replaced by its
    underlying function; anything else, including every bound method, is
    returned unchanged.

    """
    # compare against None rather than testing truthiness: a bound method
    # of a falsy instance (e.g. an empty container) must stay bound, and
    # some objects raise from __bool__
    if (
        isinstance(func_or_cls, types.MethodType)
        and func_or_cls.__self__ is None
    ):
        return func_or_cls.__func__
    return func_or_cls
727
728
def generic_repr(
    obj: Any,
    additional_kw: Sequence[Tuple[str, Any]] = (),
    to_inspect: Optional[Union[object, List[object]]] = None,
    omit_kwarg: Sequence[str] = (),
) -> str:
    """Produce a __repr__() based on direct association of the __init__()
    specification vs. same-named attributes present.

    :param obj: the object to represent; attribute values are read from it.
    :param additional_kw: sequence of ``(name, default)`` pairs rendered as
     keyword arguments when the attribute is present and differs from the
     default.
    :param to_inspect: object or list of objects whose ``__init__``
     signatures are consulted; defaults to ``obj`` itself.  Arguments
     without defaults of the first entry are rendered positionally, all
     other arguments as keywords.
    :param omit_kwarg: names of ``__init__`` keyword arguments to leave out.
     Not applied to ``additional_kw``.
    :return: a string of the form ``ClassName(arg, ..., key=value, ...)``.

    """
    if to_inspect is None:
        to_inspect = [obj]
    else:
        to_inspect = _collections.to_list(to_inspect)

    # sentinel for "attribute not present" / "argument has no default"
    missing = object()

    pos_args = []
    kw_args: _collections.OrderedDict[str, Any] = _collections.OrderedDict()
    vargs = None
    for i, insp in enumerate(to_inspect):
        try:
            spec = compat.inspect_getfullargspec(insp.__init__)
        except TypeError:
            # __init__ can't be reflected; skip this entry
            continue

        default_len = len(spec.defaults) if spec.defaults else 0

        # arguments without a default, leaving out "self".  the slice end
        # must be None when there are no defaults, since "-0" is 0 and
        # args[1:0] would select nothing.
        required_args = spec.args[1 : -default_len if default_len else None]

        if i == 0:
            if spec.varargs:
                vargs = spec.varargs
            pos_args.extend(required_args)
        else:
            # "missing" as the default means the value is rendered
            # whenever the attribute is present
            kw_args.update([(arg, missing) for arg in required_args])

        if default_len:
            assert spec.defaults
            kw_args.update(
                [
                    (arg, default)
                    for arg, default in zip(
                        spec.args[-default_len:], spec.defaults
                    )
                ]
            )
    output: List[str] = []

    output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)

    if vargs is not None and hasattr(obj, vargs):
        output.extend([repr(val) for val in getattr(obj, vargs)])

    for arg, defval in kw_args.items():
        if arg in omit_kwarg:
            continue
        try:
            val = getattr(obj, arg, missing)
            if val is not missing and val != defval:
                output.append("%s=%r" % (arg, val))
        except Exception:
            # best effort: an attribute that can't be read, compared or
            # repr()'d is left out rather than failing the whole repr
            pass

    if additional_kw:
        for arg, defval in additional_kw:
            try:
                val = getattr(obj, arg, missing)
                if val is not missing and val != defval:
                    output.append("%s=%r" % (arg, val))
            except Exception:
                # best effort, as above
                pass

    return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
805
806
class portable_instancemethod:
    """Turn an instancemethod into a (parent, name) pair
    to produce a serializable callable.

    The bound object and the method name are stored rather than the method
    itself; the method is looked up again on every call.  ``kwargs`` given
    at construction are merged into the keyword arguments of each call,
    taking precedence over those passed by the caller.

    """

    __slots__ = "target", "name", "kwargs", "__weakref__"

    def __init__(self, meth, kwargs=()):
        self.target = meth.__self__
        self.name = meth.__name__
        self.kwargs = kwargs

    def __getstate__(self):
        return {key: getattr(self, key) for key in ("target", "name", "kwargs")}

    def __setstate__(self, state):
        self.target = state["target"]
        self.name = state["name"]
        # state without a "kwargs" entry is accepted
        self.kwargs = state.get("kwargs", ())

    def __call__(self, *arg, **kw):
        kw.update(self.kwargs)
        bound = getattr(self.target, self.name)
        return bound(*arg, **kw)
835
836
def class_hierarchy(cls):
    """Return an unordered sequence of all classes related to cls.

    Traverses diamond hierarchies.

    Fibs slightly: subclasses of builtin types are not returned.  Thus
    class_hierarchy(class A(object)) returns (A, object), not A plus every
    class systemwide that derives from object.

    """

    seen = {cls}
    pending = list(cls.__mro__)
    while pending:
        current = pending.pop()

        for base in current.__bases__:
            if base not in seen:
                pending.append(base)
                seen.add(base)

        # don't descend into the subclasses of builtin types
        if current.__module__ == "builtins" or not hasattr(
            current, "__subclasses__"
        ):
            continue

        if issubclass(current, type):
            # for a metaclass, __subclasses__ resolves to the unbound
            # method of "type", so the class is passed explicitly
            subclasses = current.__subclasses__(current)
        else:
            subclasses = current.__subclasses__()

        for sub in subclasses:
            if sub not in seen:
                pending.append(sub)
                seen.add(sub)

    return list(seen)
873
874
def iterate_attributes(cls):
    """Iterate ``(key, value)`` for all the names reported by ``dir()``
    for a class, without using getattr().

    Each value is read straight out of the ``__dict__`` of the first class
    in the MRO that defines the name, so that class-sensitive descriptors
    (i.e. property.__get__()) are not called.  Names not found in any
    class ``__dict__`` are skipped.

    """
    for name in dir(cls):
        for klass in cls.__mro__:
            namespace = klass.__dict__
            if name in namespace:
                yield (name, namespace[name])
                break
889
890
def monkeypatch_proxied_specials(
    into_cls,
    from_cls,
    skip=None,
    only=None,
    name="self.proxy",
    from_instance=None,
):
    """Automates delegation of __specials__ for a proxying type.

    For each selected double-underscore method of ``from_cls``, a
    forwarding method is generated with ``exec`` and set on ``into_cls``;
    the generated method returns ``<name>.<method>(<args>)``.

    into_cls
      The class that receives the generated methods.
    from_cls
      The class whose ``__specials__`` are delegated to.
    skip
      Names to leave out when ``only`` is not given; defaults to a fixed
      set of names (see below).
    only
      If given, exactly these method names are proxied; ``skip`` and the
      "already present on into_cls" check are not applied.
    name
      Source-code expression the generated method calls through.
    from_instance
      If not None, made available to the generated code as a global
      named ``name``.

    """

    if only:
        dunders = only
    else:
        if skip is None:
            skip = (
                "__slots__",
                "__del__",
                "__getattribute__",
                "__metaclass__",
                "__getstate__",
                "__setstate__",
            )
        # every __dunder__ of from_cls that into_cls doesn't already have
        dunders = [
            m
            for m in dir(from_cls)
            if (
                m.startswith("__")
                and m.endswith("__")
                and not hasattr(into_cls, m)
                and m not in skip
            )
        ]

    for method in dunders:
        try:
            maybe_fn = getattr(from_cls, method)
            if not hasattr(maybe_fn, "__call__"):
                continue
            # unwrap a method object to its underlying function
            maybe_fn = getattr(maybe_fn, "__func__", maybe_fn)
            fn = cast(types.FunctionType, maybe_fn)

        except AttributeError:
            continue
        try:
            spec = compat.inspect_getfullargspec(fn)
            fn_args = compat.inspect_formatargspec(spec[0])
            d_args = compat.inspect_formatargspec(spec[0][1:])
        except TypeError:
            # not reflectable; use a generic pass-through signature
            fn_args = "(self, *args, **kw)"
            d_args = "(*args, **kw)"

        # NOTE: the template is filled from locals(), so the local names
        # "method", "fn_args", "name" and "d_args" must not be renamed
        py = (
            "def %(method)s%(fn_args)s: "
            "return %(name)s.%(method)s%(d_args)s" % locals()
        )

        env: Dict[str, types.FunctionType] = (
            from_instance is not None and {name: from_instance} or {}
        )
        exec(py, env)
        try:
            env[method].__defaults__ = fn.__defaults__
        except AttributeError:
            # fn has no __defaults__ to copy
            pass
        setattr(into_cls, method, env[method])
956
957
def methods_equivalent(meth1, meth2):
    """Return True if the two methods are the same implementation.

    Objects that have a ``__func__`` attribute (such as bound methods) are
    unwrapped to it before the identity comparison.

    """
    impl1 = getattr(meth1, "__func__", meth1)
    impl2 = getattr(meth2, "__func__", meth2)
    return impl1 is impl2
964
965
def as_interface(obj, cls=None, methods=None, required=None):
    """Ensure basic interface compliance for an instance or dict of callables.

    Checks that ``obj`` implements public methods of ``cls`` or has members
    listed in ``methods``. If ``required`` is not supplied, implementing at
    least one interface method is sufficient. Methods present on ``obj`` that
    are not in the interface are ignored.

    If ``obj`` is a dict and ``dict`` does not meet the interface
    requirements, the keys of the dictionary are inspected. Keys present in
    ``obj`` that are not in the interface will raise TypeErrors.

    Raises TypeError if ``obj`` does not meet the interface criteria.

    In all passing cases, an object with callable members is returned.  In the
    simple case, ``obj`` is returned as-is; if dict processing kicks in then
    an anonymous class is returned.

    obj
      A type, instance, or dictionary of callables.
    cls
      Optional, a type.  All public methods of cls are considered the
      interface.  An ``obj`` instance of cls will always pass, ignoring
      ``required``.
    methods
      Optional, a sequence of method names to consider as the interface.
    required
      Optional, a sequence of mandatory implementations. If omitted, an
      ``obj`` that provides at least one interface method is considered
      sufficient.  As a convenience, required may be a type, in which case
      the whole interface is required.

    """
    if not (cls or methods):
        raise TypeError("a class or collection of method names are required")

    if isinstance(cls, type) and isinstance(obj, cls):
        return obj

    if methods:
        interface = set(methods)
    else:
        interface = {m for m in dir(cls) if not m.startswith("_")}
    implemented = set(dir(obj))

    # with explicit requirements all of them must be present (>=); with
    # none, a proper superset of the empty set is needed (>), i.e. at
    # least one interface method
    if isinstance(required, type):
        required, complies = interface, operator.ge
    elif required:
        required, complies = set(required), operator.ge
    else:
        required, complies = set(), operator.gt

    if complies(implemented & interface, required):
        return obj

    # No dict duck typing here.
    if not isinstance(obj, dict):
        qualifier = "any of" if complies is operator.gt else "all of"
        raise TypeError(
            "%r does not implement %s: %s"
            % (obj, qualifier, ", ".join(interface))
        )

    class AnonymousInterface:
        """A callable-holding shell."""

    if cls:
        AnonymousInterface.__name__ = "Anonymous" + cls.__name__
    found = set()

    for method, impl in dictlike_iteritems(obj):
        if method not in interface:
            raise TypeError("%r: unknown in this interface" % method)
        if not callable(impl):
            raise TypeError("%r=%r is not callable" % (method, impl))
        setattr(AnonymousInterface, method, staticmethod(impl))
        found.add(method)

    if not complies(found, required):
        raise TypeError(
            "dictionary does not contain required keys %s"
            % ", ".join(required - found)
        )
    return AnonymousInterface
1050
1051
# the descriptor's own type, so that class-level access can be typed as
# returning the descriptor itself
_GFD = TypeVar("_GFD", bound="generic_fn_descriptor[Any]")


class generic_fn_descriptor(Generic[_T_co]):
    """Descriptor which proxies a function when the attribute is not
    present in dict

    This superclass is organized in a particular way with "memoized" and
    "non-memoized" implementation classes that are hidden from type checkers,
    as Mypy seems to not be able to handle seeing multiple kinds of descriptor
    classes used for the same attribute.

    In this base class, ``__get__()``, ``_reset()`` and ``reset()`` all
    raise ``NotImplementedError``.

    """

    # the wrapped function; ``__doc__`` and ``__name__`` are taken from it
    # in __init__
    fget: Callable[..., _T_co]
    __doc__: Optional[str]
    __name__: str

    def __init__(self, fget: Callable[..., _T_co], doc: Optional[str] = None):
        self.fget = fget
        # an explicit, non-empty ``doc`` takes precedence over the
        # function's own docstring
        self.__doc__ = doc or fget.__doc__
        self.__name__ = fget.__name__

    # typing: access on the class returns the descriptor ...
    @overload
    def __get__(self: _GFD, obj: None, cls: Any) -> _GFD: ...

    # ... access on an instance returns the function's value
    @overload
    def __get__(self, obj: object, cls: Any) -> _T_co: ...

    def __get__(self: _GFD, obj: Any, cls: Any) -> Union[_GFD, _T_co]:
        raise NotImplementedError()

    if TYPE_CHECKING:
        # declared for type checkers only; not defined at runtime

        def __set__(self, instance: Any, value: Any) -> None: ...

        def __delete__(self, instance: Any) -> None: ...

    def _reset(self, obj: Any) -> None:
        raise NotImplementedError()

    @classmethod
    def reset(cls, obj: Any, name: str) -> None:
        raise NotImplementedError()
1096
1097
class _non_memoized_property(generic_fn_descriptor[_T_co]):
    """A plain descriptor that proxies a function.

    The function is called on every instance access and nothing is stored.
    The primary rationale is to provide a plain attribute that's
    compatible with memoized_property and which is also recognized as
    equivalent by mypy.

    """

    if not TYPE_CHECKING:

        def __get__(self, obj, cls):
            # class-level access returns the descriptor itself
            return self if obj is None else self.fget(obj)
1113
1114
class _memoized_property(generic_fn_descriptor[_T_co]):
    """A read-only @property that is only evaluated once.

    On instance access the function's result is stored in the instance
    ``__dict__`` under the function's name.

    """

    if not TYPE_CHECKING:

        def __get__(self, obj, cls):
            if obj is None:
                # class-level access returns the descriptor itself
                return self
            value = self.fget(obj)
            obj.__dict__[self.__name__] = value
            return value

    def _reset(self, obj):
        """Discard the value memoized on ``obj`` for this property."""
        _memoized_property.reset(obj, self.__name__)

    @classmethod
    def reset(cls, obj, name):
        """Discard the entry ``name`` from ``obj.__dict__``, if present."""
        obj.__dict__.pop(name, None)
1132
1133
# despite many attempts to get Mypy to recognize an overridden descriptor
# where one is memoized and the other isn't, there seems to be no reliable
# way other than completely deceiving the type checker into thinking there
# is just one single descriptor type everywhere.  Otherwise, if a superclass
# has non-memoized and subclass has memoized, that requires
# "class memoized(non_memoized)".  but then if a superclass has memoized and
# subclass has non-memoized, the class hierarchy of the descriptors
# would need to be reversed; "class non_memoized(memoized)".  so there's no
# way to achieve this.
# additional issues, RO properties:
# https://github.com/python/mypy/issues/12440
if TYPE_CHECKING:
    # allow memoized and non-memoized to be freely mixed by having them
    # be the same class
    memoized_property = generic_fn_descriptor
    non_memoized_property = generic_fn_descriptor

    # for read only situations, mypy only sees @property as read only.
    # read only is needed when a subtype specializes the return type
    # of a property, meaning assignment needs to be disallowed
    ro_memoized_property = property
    ro_non_memoized_property = property

else:
    # at runtime the "ro_" names are plain aliases of the same two
    # implementation classes
    memoized_property = ro_memoized_property = _memoized_property
    non_memoized_property = ro_non_memoized_property = _non_memoized_property
1160
1161
def memoized_instancemethod(fn: _F) -> _F:
    """Decorate a method so that its return value is memoized.

    The first call runs ``fn`` and then stores, in the instance
    ``__dict__`` under the method's name, a replacement callable that
    returns that first result for every subsequent call.

    Best applied to no-arg methods: memoization is not sensitive to
    argument values, and will always return the same value even when
    called with different arguments.

    """

    def oneshot(self, *args, **kw):
        value = fn(self, *args, **kw)

        def memo(*a, **kw):
            return value

        # make the replacement look like the original method
        memo.__doc__ = fn.__doc__
        memo.__name__ = fn.__name__
        self.__dict__[fn.__name__] = memo
        return value

    wrapped = update_wrapper(oneshot, fn)
    return wrapped  # type: ignore
1183
1184
class HasMemoized:
    """Mixin that tracks, per object, the names of memoized elements so
    they can be cleared or checked as a group.

    Names accumulate in ``_memoized_keys``; the memoized values themselves
    are stored in the instance ``__dict__``.

    """

    if not TYPE_CHECKING:
        # support classes that want to have __slots__ with an explicit
        # slot for __dict__.  not sure if that requires base __slots__ here.
        __slots__ = ()

    _memoized_keys: FrozenSet[str] = frozenset()

    def _reset_memoizations(self) -> None:
        """Remove every tracked memoized value from ``__dict__``."""
        for name in self._memoized_keys:
            self.__dict__.pop(name, None)

    def _assert_no_memoizations(self) -> None:
        """Assert that no tracked name currently has a stored value."""
        for name in self._memoized_keys:
            assert name not in self.__dict__

    def _set_memoized_attribute(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` and track ``key`` as memoized."""
        self.__dict__[key] = value
        self._memoized_keys |= {key}

    class memoized_attribute(memoized_property[_T]):
        """A read-only @property that is only evaluated once, and whose
        name is recorded in the owning object's ``_memoized_keys``.

        :meta private:

        """

        fget: Callable[..., _T]
        __doc__: Optional[str]
        __name__: str

        def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None):
            self.fget = fget
            self.__doc__ = doc if doc else fget.__doc__
            self.__name__ = fget.__name__

        @overload
        def __get__(self: _MA, obj: None, cls: Any) -> _MA: ...

        @overload
        def __get__(self, obj: Any, cls: Any) -> _T: ...

        def __get__(self, obj, cls):
            if obj is None:
                return self
            value = self.fget(obj)
            obj.__dict__[self.__name__] = value
            obj._memoized_keys |= {self.__name__}
            return value

    @classmethod
    def memoized_instancemethod(cls, fn: _F) -> _F:
        """Decorate a method so that its return value is memoized and its
        name is recorded in ``_memoized_keys``.

        :meta private:

        """

        def oneshot(self: Any, *args: Any, **kw: Any) -> Any:
            value = fn(self, *args, **kw)

            def memo(*a, **kw):
                return value

            memo.__doc__ = fn.__doc__
            memo.__name__ = fn.__name__
            self.__dict__[fn.__name__] = memo
            self._memoized_keys |= {fn.__name__}
            return value

        wrapped = update_wrapper(oneshot, fn)
        return wrapped  # type: ignore
1260
1261
# type checkers see a plain read-only ``property``; at runtime the name
# is HasMemoized.memoized_attribute.
if not TYPE_CHECKING:
    HasMemoized_ro_memoized_attribute = HasMemoized.memoized_attribute
else:
    HasMemoized_ro_memoized_attribute = property
1266
1267
class MemoizedSlots:
    """Apply memoized items to an object using a __getattr__ scheme.

    This makes the behavior of memoized_property and
    memoized_instancemethod available to a class that uses __slots__:
    a class defines ``_memoized_attr_<name>`` or
    ``_memoized_method_<name>`` and the result is stored on the object
    as ``<name>`` via ``setattr()`` on first use.

    The memoized get is not threadsafe under freethreading and the
    creator method may in extremely rare cases be called more than once.

    """

    __slots__ = ()

    def _fallback_getattr(self, key):
        """Hook used when ``key`` is not a memoized name; raises
        ``AttributeError`` here."""
        raise AttributeError(key)

    def __getattr__(self, key: str) -> Any:
        # the creator names themselves are never resolved through here
        if key.startswith(("_memoized_attr_", "_memoized_method_")):
            raise AttributeError(key)

        owner = self.__class__

        # to avoid recursion errors when interacting with other __getattr__
        # schemes that refer to this one, when testing for memoized method
        # look at __class__ only rather than going into __getattr__ again.
        if hasattr(owner, f"_memoized_attr_{key}"):
            value = getattr(self, f"_memoized_attr_{key}")()
            setattr(self, key, value)
            return value

        if not hasattr(owner, f"_memoized_method_{key}"):
            return self._fallback_getattr(key)

        meth = getattr(self, f"_memoized_method_{key}")

        def oneshot(*args, **kw):
            result = meth(*args, **kw)

            def memo(*a, **kw):
                return result

            memo.__doc__ = meth.__doc__
            memo.__name__ = meth.__name__
            # later calls go straight to ``memo`` stored on the object
            setattr(self, key, memo)
            return result

        oneshot.__doc__ = meth.__doc__
        return oneshot
1314
1315
# from paste.deploy.converters
def asbool(obj: Any) -> bool:
    """Interpret ``obj`` as a boolean.

    Strings are stripped and lower-cased, then matched against a fixed
    set of true / false words; any other object is passed to ``bool()``.

    :raises ValueError: if a string is not one of the recognized words.

    """
    if not isinstance(obj, str):
        return bool(obj)

    obj = obj.strip().lower()
    if obj in ("true", "yes", "on", "y", "t", "1"):
        return True
    if obj in ("false", "no", "off", "n", "f", "0"):
        return False
    raise ValueError("String is not true/false: %r" % obj)
1327
1328
def bool_or_str(*text: str) -> Callable[[str], Union[str, bool]]:
    """Return a callable that evaluates a string either as one of the
    given "alternate" string values, returned unchanged, or otherwise as
    a boolean via :func:`asbool`.

    """

    def bool_or_value(obj: str) -> Union[str, bool]:
        return obj if obj in text else asbool(obj)

    return bool_or_value
1342
1343
def asint(value: Any) -> Optional[int]:
    """Coerce ``value`` to an integer, passing ``None`` through as is."""

    return None if value is None else int(value)
1350
1351
def coerce_kw_type(
    kw: Dict[str, Any],
    key: str,
    type_: Type[Any],
    flexi_bool: bool = True,
    dest: Optional[Dict[str, Any]] = None,
) -> None:
    r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if
    necessary and store the result in 'dest' ('kw' itself when 'dest' is
    not given).

    ``None`` values and values that are already instances of 'type\_' are
    left alone.  If 'flexi_bool' is True, booleans are coerced with
    :func:`asbool`, so that the string '0' is considered false.
    """

    target = kw if dest is None else dest

    if key not in kw:
        return

    current = kw[key]
    already_typed = isinstance(type_, type) and isinstance(current, type_)
    if already_typed or current is None:
        return

    if type_ is bool and flexi_bool:
        target[key] = asbool(current)
    else:
        target[key] = type_(current)
1376
1377
def constructor_key(obj: Any, cls: Type[Any]) -> Tuple[Any, ...]:
    """Produce a cacheable tuple structure for ``obj``.

    The tuple starts with ``cls`` and is followed by a ``(name, value)``
    pair for each name reported by ``get_cls_kwargs(cls)`` that is
    present in ``obj.__dict__``.

    """
    names = get_cls_kwargs(cls)
    key: List[Any] = [cls]
    for name in names:
        if name in obj.__dict__:
            key.append((name, obj.__dict__[name]))
    return tuple(key)
1387
1388
def constructor_copy(obj: _T, cls: Type[_T], *args: Any, **kw: Any) -> _T:
    """Instantiate cls using the __dict__ of obj as constructor arguments.

    The names reported by ``get_cls_kwargs(cls)`` that were not passed
    explicitly in ``kw`` are filled in from ``obj.__dict__`` where
    present; ``args`` are passed through positionally.

    """

    names = get_cls_kwargs(cls)
    for name in names.difference(kw):
        if name in obj.__dict__:
            kw[name] = obj.__dict__[name]
    return cls(*args, **kw)
1401
1402
def counter() -> Callable[[], int]:
    """Return a threadsafe counter function.

    Each call to the returned function produces the next integer,
    starting at 1; the increment is guarded by a lock.

    """

    guard = threading.Lock()
    sequence = itertools.count(1)

    # avoid the 2to3 "next" transformation...
    def _next():
        with guard:
            return next(sequence)

    return _next
1415
1416
def duck_type_collection(
    specimen: Any, default: Optional[Type[Any]] = None
) -> Optional[Type[Any]]:
    """Given an instance or class, guess if it is or is acting as one of
    the basic collection types: list, set and dict.

    The checks run in this order: an ``__emulates__`` attribute, if
    present, wins; then actual subclass / instance relationships; then
    the presence of an ``append``, ``add`` or ``set`` attribute.
    ``default`` is returned when nothing matches.
    """

    if hasattr(specimen, "__emulates__"):
        emulated = specimen.__emulates__
        # canonicalize set vs sets.Set to a standard: the builtin set
        if emulated is not None and issubclass(emulated, set):
            return set
        return emulated  # type: ignore

    isa = issubclass if isinstance(specimen, type) else isinstance
    for builtin_type in (list, set, dict):
        if isa(specimen, builtin_type):
            return builtin_type

    by_method = (("append", list), ("add", set), ("set", dict))
    for method_name, builtin_type in by_method:
        if hasattr(specimen, method_name):
            return builtin_type

    return default
1450
1451
def assert_arg_type(
    arg: Any, argtype: Union[Tuple[Type[Any], ...], Type[Any]], name: str
) -> Any:
    """Return ``arg`` if it is an instance of ``argtype``, which may be a
    single type or a tuple of types.

    :raises ArgumentError: if ``arg`` is not of an accepted type; the
     message names the argument via ``name``.

    """
    if isinstance(arg, argtype):
        return arg

    if isinstance(argtype, tuple):
        raise exc.ArgumentError(
            "Argument '%s' is expected to be one of type %s, got '%s'"
            % (name, " or ".join("'%s'" % a for a in argtype), type(arg))
        )

    raise exc.ArgumentError(
        "Argument '%s' is expected to be of type '%s', got '%s'"
        % (name, argtype, type(arg))
    )
1468
1469
def dictlike_iteritems(dictlike):
    """Return (key, value) pairs for almost any dict-like object.

    An object with ``items()`` yields a list of its items.  Otherwise an
    iterator is returned that walks ``iterkeys()`` or ``keys()`` and
    looks each key up through ``__getitem__`` (or ``get`` when there is
    no ``__getitem__``).

    :raises TypeError: if the object offers neither a getter nor a way
     to enumerate its keys.

    """

    if hasattr(dictlike, "items"):
        return list(dictlike.items())

    fallback = getattr(dictlike, "get", None)
    getter = getattr(dictlike, "__getitem__", fallback)
    if getter is None:
        raise TypeError("Object '%r' is not dict-like" % dictlike)

    if hasattr(dictlike, "iterkeys"):

        def iterator():
            for key in dictlike.iterkeys():
                assert getter is not None
                yield key, getter(key)

        return iterator()

    if hasattr(dictlike, "keys"):
        return ((key, getter(key)) for key in dictlike.keys())

    raise TypeError("Object '%r' is not dict-like" % dictlike)
1492
1493
class classproperty(property):
    """A decorator that behaves like @property, except that the getter is
    called with the class rather than the instance.

    The decorator is currently special when using the declarative
    module, but note that the
    :class:`~.sqlalchemy.ext.declarative.declared_attr`
    decorator should be used for this purpose with declarative.

    """

    fget: Callable[[Any], Any]

    def __init__(self, fget: Callable[[Any], Any], *arg: Any, **kw: Any):
        super().__init__(fget, *arg, **kw)
        # always take the docstring from the getter
        self.__doc__ = fget.__doc__

    def __get__(self, obj: Any, cls: Optional[type] = None) -> Any:
        # ``obj`` is ignored; the getter only ever receives ``cls``
        getter = self.fget
        return getter(cls)
1513
1514
class hybridproperty(Generic[_T]):
    """Read-only descriptor that calls one function for instance access
    and another for class access.

    Both default to the decorated function; use :meth:`classlevel` to
    supply a separate class-level function.

    """

    def __init__(self, func: Callable[..., _T]):
        self.func = func
        self.clslevel = func

    def __get__(self, instance: Any, owner: Any) -> _T:
        if instance is not None:
            return self.func(instance)
        return self.clslevel(owner)

    def classlevel(self, func: Callable[..., Any]) -> hybridproperty[_T]:
        """Set the function used for class-level access; returns self."""
        self.clslevel = func
        return self
1530
1531
class rw_hybridproperty(Generic[_T]):
    """Descriptor that calls one function for instance access and another
    for class access, and that also supports assignment on instances
    through a function registered with :meth:`setter`.

    """

    def __init__(self, func: Callable[..., _T]):
        self.func = func
        self.clslevel = func
        self.setfn: Optional[Callable[..., Any]] = None

    def __get__(self, instance: Any, owner: Any) -> _T:
        if instance is not None:
            return self.func(instance)
        return self.clslevel(owner)

    def __set__(self, instance: Any, value: Any) -> None:
        # a setter function must have been registered before assignment
        assert self.setfn is not None
        self.setfn(instance, value)

    def setter(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
        """Set the function called on assignment; returns self."""
        self.setfn = func
        return self

    def classlevel(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
        """Set the function used for class-level access; returns self."""
        self.clslevel = func
        return self
1556
1557
class hybridmethod(Generic[_T]):
    """Decorate a function as cls- or instance- level.

    Instance access binds the decorated function to the instance; class
    access binds the class-level function (the same function unless
    :meth:`classlevel` replaced it) to the class.

    """

    def __init__(self, func: Callable[..., _T]):
        self.func = func
        self.__func__ = func
        self.clslevel = func

    def __get__(self, instance: Any, owner: Any) -> Callable[..., _T]:
        if instance is not None:
            return self.func.__get__(instance, owner)  # type:ignore
        return self.clslevel.__get__(owner, owner.__class__)  # type:ignore

    def classlevel(self, func: Callable[..., Any]) -> hybridmethod[_T]:
        """Set the function used for class-level access; returns self."""
        self.clslevel = func
        return self
1574
1575
class symbol(int):
    """A constant symbol.

    >>> symbol("foo") is symbol("foo")
    True
    >>> symbol("foo")
    symbol('foo')

    A slight refinement of the MAGICCOOKIE=object() pattern.  The primary
    advantage of symbol() is its repr().  They are also singletons.

    Repeated calls of symbol('name') will all return the same instance.
    The integer value is ``canonical`` when given, otherwise
    ``hash(name)``.

    """

    name: str

    symbols: Dict[str, symbol] = {}
    _lock = threading.Lock()

    def __new__(
        cls,
        name: str,
        doc: Optional[str] = None,
        canonical: Optional[int] = None,
    ) -> symbol:
        with cls._lock:
            existing = cls.symbols.get(name)
            if existing is not None:
                # an already registered symbol can't be given a
                # different integer value
                if canonical and canonical != existing:
                    raise TypeError(
                        f"Can't replace canonical symbol for {name!r} "
                        f"with new int value {canonical}"
                    )
                return existing

            assert isinstance(name, str)
            if canonical is None:
                canonical = hash(name)
            sym = int.__new__(symbol, canonical)
            sym.name = name
            if doc:
                sym.__doc__ = doc

            # NOTE: we should ultimately get rid of this global thing,
            # however, currently it is to support pickling.  The best
            # change would be when we are on py3.11 at a minimum, we
            # switch to stdlib enum.IntFlag.
            cls.symbols[name] = sym
            return sym

    def __reduce__(self):
        # re-creating through symbol() returns the registered instance
        # for this name when one exists
        return symbol, (self.name, "x", int(self))

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return f"symbol({self.name!r})"
1634
1635
class _IntFlagMeta(type):
    """Metaclass that replaces the integer attributes of a class body with
    :class:`.symbol` objects and exposes them as ``__members__``.

    Dunder names are skipped, as are non-integer attributes whose name
    starts with an underscore; any other non-integer attribute raises
    ``TypeError``.

    """

    def __init__(
        cls,
        classname: str,
        bases: Tuple[Type[Any], ...],
        dict_: Dict[str, Any],
        **kw: Any,
    ) -> None:
        items: List[symbol]
        cls._items = items = []
        for attr_name, attr_value in dict_.items():
            if re.match(r"^__.*__$", attr_name):
                continue

            if not isinstance(attr_value, int):
                if attr_name.startswith("_"):
                    continue
                raise TypeError("Expected integer values for IntFlag")

            sym = symbol(attr_name, canonical=attr_value)
            setattr(cls, attr_name, sym)
            items.append(sym)

        members = {sym.name: sym for sym in items}
        cls.__members__ = _collections.immutabledict(members)

    def __iter__(self) -> Iterator[symbol]:
        raise NotImplementedError(
            "iter not implemented to ensure compatibility with "
            "Python 3.11 IntFlag.  Please use __members__.  See "
            "https://github.com/python/cpython/issues/99304"
        )
1668
1669
class _FastIntFlag(metaclass=_IntFlagMeta):
    """An 'IntFlag' copycat that isn't slow when performing bitwise
    operations.

    Subclasses declare integer class attributes, which the metaclass
    turns into :class:`.symbol` members.

    The public ``FastIntFlag`` name refers to ``enum.IntFlag`` under
    TYPE_CHECKING and to this class otherwise.

    """
1678
1679
# type checkers see the stdlib IntFlag; the runtime class is _FastIntFlag
if not TYPE_CHECKING:
    FastIntFlag = _FastIntFlag
else:
    from enum import IntFlag

    FastIntFlag = IntFlag
1686
1687
_E = TypeVar("_E", bound=enum.Enum)


def parse_user_argument_for_enum(
    arg: Any,
    choices: Dict[_E, List[Any]],
    name: str,
    resolve_symbol_names: bool = False,
) -> Optional[_E]:
    """Given a user parameter, parse the parameter into a chosen value
    from a list of choice objects, typically Enum values.

    The user argument can be a string name that matches the name of a
    symbol, or the symbol object itself, or any number of alternate choices
    such as True/False/ None etc.

    :param arg: the user argument.
    :param choices: dictionary of enum values to lists of possible
     entries for each.
    :param name: name of the argument.   Used in an :class:`.ArgumentError`
     that is raised if the parameter doesn't match any available argument.
    :param resolve_symbol_names: when True, ``arg`` also matches an enum
     value whose ``.name`` it equals.
    :return: the matching enum value, or ``None`` if ``arg`` is ``None``
     and nothing in ``choices`` matched it.

    """
    for enum_value, choice in choices.items():
        matched = (
            arg is enum_value
            or (resolve_symbol_names and arg == enum_value.name)
            or arg in choice
        )
        if matched:
            return enum_value

    if arg is None:
        return None

    raise exc.ArgumentError(f"Invalid value for '{name}': {arg!r}")
1723
1724
# next value handed out by set_creation_order()
_creation_order = 1


def set_creation_order(instance: Any) -> None:
    """Assign a '_creation_order' sequence to the given instance.

    Each call stamps the instance with the current value of a
    module-level counter and then advances the counter by one, so that
    multiple instances can be sorted in order of creation (typically
    within a single thread; the counter is not particularly threadsafe).

    """
    global _creation_order
    instance._creation_order = _creation_order
    _creation_order = _creation_order + 1
1739
1740
def warn_exception(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call ``func(*args, **kwargs)`` and return its result.

    Any ``Exception`` raised by the call is converted into a warning
    naming the exception type and value; ``None`` is returned in that
    case.

    """
    try:
        return func(*args, **kwargs)
    except Exception as err:
        warn("%s('%s') ignored" % (type(err), err))
1750
1751
def ellipses_string(value, len_=25):
    """Return ``value`` cut to ``len_`` items and followed by ``"..."``
    when it is longer than ``len_``; otherwise return it unchanged.

    Values for which measuring or formatting raises ``TypeError`` are
    returned unchanged as well.

    """
    try:
        too_long = len(value) > len_
        return "%s..." % value[0:len_] if too_long else value
    except TypeError:
        return value
1760
1761
class _hash_limit_string(str):
    """A string subclass that can only be hashed on a maximum amount
    of unique values.

    This is used for warnings so that we can send out parameterized warnings
    without the __warningregistry__ of the module, or the non-overridable
    "once" registry within warnings.py, overloading memory.

    The hash is derived from the un-interpolated template plus the hash
    of the interpolated text modulo ``num``, and equality compares
    hashes only.

    """

    _hash: int

    def __new__(
        cls, value: str, num: int, args: Sequence[Any]
    ) -> _hash_limit_string:
        body = value % args
        interpolated = body + (
            " (this warning may be suppressed after %d occurrences)" % num
        )
        self = super().__new__(cls, interpolated)
        # fold the interpolated text into one of ``num`` buckets
        bucket = hash(interpolated) % num
        self._hash = hash("%s_%d" % (value, bucket))
        return self

    def __hash__(self) -> int:
        return self._hash

    def __eq__(self, other: Any) -> bool:
        return hash(self) == hash(other)
1790
1791
def warn(msg: str, code: Optional[str] = None) -> None:
    """Issue a warning.

    If msg is a string, :class:`.exc.SAWarning` is used as
    the category.  When ``code`` is given, an :class:`.exc.SAWarning`
    instance carrying that code is emitted instead.

    """
    if not code:
        _warnings_warn(msg, exc.SAWarning)
    else:
        _warnings_warn(exc.SAWarning(msg, code=code))
1803
1804
def warn_limited(msg: str, args: Sequence[Any]) -> None:
    """Issue a warning with a parameterized string, limiting the number
    of registrations.

    When ``args`` is non-empty, ``msg`` is interpolated with it through
    ``_hash_limit_string`` using a limit of 10; otherwise ``msg`` is
    emitted as given.

    """
    message = _hash_limit_string(msg, 10, args) if args else msg
    _warnings_warn(message, exc.SAWarning)
1813
1814
# code object of a tagged function -> (message suffix, warning category)
_warning_tags: Dict[CodeType, Tuple[str, Type[Warning]]] = {}


def tag_method_for_warnings(
    message: str, category: Type[Warning]
) -> Callable[[_F], _F]:
    """Return a decorator that registers the decorated function's code
    object in ``_warning_tags`` together with ``message`` and
    ``category``.

    The function itself is returned unchanged.

    """
    tag = (message, category)

    def go(fn):
        _warning_tags[fn.__code__] = tag
        return fn

    return go
1826
1827
# matches module names that start with "sqlalchemy." (unless that is
# followed by "testing") or with "alembic."; a frame whose module name does
# NOT match is treated as being outside of the library.
_not_sa_pattern = re.compile(r"^(?:sqlalchemy\.(?!testing)|alembic\.)")


def _warnings_warn(
    message: Union[str, Warning],
    category: Optional[Type[Warning]] = None,
    stacklevel: int = 2,
) -> None:
    """Emit a warning through ``warnings.warn()``, adjusting the stack
    level so that it refers to the first frame outside of SQLAlchemy.

    Starting ``stacklevel`` frames above this function, the stack is
    walked towards the outermost caller.  ``stacklevel`` is incremented
    for each frame visited until one is seen whose module ``__name__``
    does not match ``_not_sa_pattern``.  Every frame visited whose code
    object is registered in ``_warning_tags`` appends that tag's suffix
    to the message and supplies the category if none was given.  The
    walk ends once both an outside frame and a tag have been found, or
    when the stack is exhausted.

    :param message: warning text or Warning instance.
    :param category: warning class; when None it may be taken from a
     warning tag, else ``warnings.warn()`` is called without a category.
    :param stacklevel: number of frames above this one at which the
     walk begins.

    """
    # adjust the given stacklevel to be outside of SQLAlchemy
    try:
        frame = sys._getframe(stacklevel)
    except ValueError:
        # being called from less than 3 (or given) stacklevels, weird,
        # but don't crash
        stacklevel = 0
    except:
        # _getframe() doesn't work, weird interpreter issue, weird,
        # ok, but don't crash
        stacklevel = 0
    else:
        stacklevel_found = warning_tag_found = False
        while frame is not None:
            # using __name__ here requires that we have __name__ in the
            # __globals__ of the decorated string functions we make also.
            # we generate this using {"__name__": fn.__module__}
            if not stacklevel_found and not re.match(
                _not_sa_pattern, frame.f_globals.get("__name__", "")
            ):
                # stop incrementing stack level if an out-of-SQLA line
                # were found.
                stacklevel_found = True

                # however, for the warning tag thing, we have to keep
                # scanning up the whole traceback

            if frame.f_code in _warning_tags:
                warning_tag_found = True
                (_suffix, _category) = _warning_tags[frame.f_code]
                # an explicitly passed category wins over the tag's
                category = category or _category
                # note: a Warning instance passed as ``message`` becomes a
                # plain string at this point
                message = f"{message} ({_suffix})"

            frame = frame.f_back  # type: ignore[assignment]

            if not stacklevel_found:
                stacklevel += 1
            elif stacklevel_found and warning_tag_found:
                break

    # the + 1 accounts for this function's own frame
    if category is not None:
        warnings.warn(message, category, stacklevel=stacklevel + 1)
    else:
        warnings.warn(message, stacklevel=stacklevel + 1)
1880
1881
def only_once(
    fn: Callable[..., _T], retry_on_exception: bool
) -> Callable[..., Optional[_T]]:
    """Decorate the given function to be a no-op after it is called exactly
    once.

    The returned callable passes its first call through to ``fn``; later
    calls return ``None`` without calling it.  If that first call raises
    and ``retry_on_exception`` is true, ``fn`` is re-armed so that the
    next call tries again; the exception propagates either way.

    """

    pending = [fn]

    def go(*arg: Any, **kw: Any) -> Optional[_T]:
        # strong reference fn so that it isn't garbage collected,
        # which interferes with the event system's expectations
        strong_fn = fn  # noqa
        if not pending:
            return None

        target = pending.pop()
        try:
            return target(*arg, **kw)
        except BaseException:
            if retry_on_exception:
                pending.insert(0, target)
            raise

    return go
1906
1907
# traceback lines that refer to a file inside the sqlalchemy package
_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
# traceback lines that refer to unittest / unittest2 / unit2
_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")


def chop_traceback(
    tb: List[str],
    exclude_prefix: re.Pattern[str] = _UNITTEST_RE,
    exclude_suffix: re.Pattern[str] = _SQLA_RE,
) -> List[str]:
    """Chop extraneous lines off beginning and end of a traceback.

    :param tb:
      a list of traceback lines as returned by ``traceback.format_stack()``

    :param exclude_prefix:
      a regular expression object matching lines to skip at beginning of
      ``tb``

    :param exclude_suffix:
      a regular expression object matching lines to skip at end of ``tb``

    :return: a new list holding the remaining lines, possibly empty.
    """
    first = 0
    last = len(tb) - 1

    # trim matching lines from the front, then from the back
    while first <= last and exclude_prefix.search(tb[first]):
        first += 1
    while first <= last and exclude_suffix.search(tb[last]):
        last -= 1

    return tb[first : last + 1]
1936
1937
NoneType = type(None)


def attrsetter(attrname):
    """Return a function ``set(obj, value)`` that assigns ``value`` to
    ``obj.<attrname>``.

    The function is generated from source text with ``exec()``, so
    ``attrname`` is interpolated into code as is (a dotted path such as
    ``"a.b"`` therefore works) and must not come from untrusted input.

    """
    code = "def set(obj, value): obj.%s = value" % attrname
    env = {"attrname": attrname, "code": code}
    exec(code, env)
    return env["set"]
1946
1947
# names of the form __xxx__
_dunders = re.compile("^__.+__$")


class TypingOnly:
    """A mixin class that marks a class as 'typing only', meaning it has
    absolutely no methods, attributes, or runtime functionality whatsoever.

    This is checked when a class that lists ``TypingOnly`` directly among
    its bases is created: any name in the class body other than dunder
    names raises ``AssertionError``.

    """

    __slots__ = ()

    def __init_subclass__(cls) -> None:
        direct_subclass = TypingOnly in cls.__bases__
        if direct_subclass:
            remaining = set()
            for name in cls.__dict__:
                if not _dunders.match(name):
                    remaining.add(name)
            if remaining:
                raise AssertionError(
                    f"Class {cls} directly inherits TypingOnly but has "
                    f"additional attributes {remaining}."
                )
        super().__init_subclass__()
1970
1971
class EnsureKWArg:
    r"""Apply translation of functions to accept \**kw arguments if they
    don't already.

    Used to ensure cross-compatibility with third party legacy code, for things
    like compiler visit methods that need to accept ``**kw`` arguments,
    but may have been copied from old code that didn't accept them.

    """

    ensure_kwarg: str
    """a regular expression that indicates method names for which the method
    should accept ``**kw`` arguments.

    The class will scan for methods matching the name template and decorate
    them if necessary to ensure ``**kw`` parameters are accepted.

    """

    def __init_subclass__(cls) -> None:
        fn_reg = cls.ensure_kwarg
        clsdict = cls.__dict__
        if fn_reg:
            for key in clsdict:
                if not re.match(fn_reg, key):
                    continue
                fn = clsdict[key]
                # functions that already take **kw are left alone
                if compat.inspect_getfullargspec(fn).varkw:
                    continue
                setattr(cls, key, cls._wrap_w_kw(fn))
        super().__init_subclass__()

    @classmethod
    def _wrap_w_kw(cls, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Return a wrapper for ``fn`` that accepts keyword arguments and
        discards them, passing only the positional arguments along."""

        def wrap(*arg: Any, **kw: Any) -> Any:
            return fn(*arg)

        wrapped = update_wrapper(wrap, fn)
        return wrapped
2011
2012
def wrap_callable(wrapper, fn):
    """Augment functools.update_wrapper() to work with objects with
    a ``__call__()`` method.

    When ``fn`` has a ``__name__``, this is plain ``update_wrapper()``.
    Otherwise ``wrapper`` takes its name from the class of ``fn``, its
    module from ``fn`` if available, and its docstring from
    ``fn.__call__`` or else from ``fn`` itself.

    :param wrapper:
      the function to be updated; it is also the return value

    :param fn:
      object with __call__ method

    """
    if hasattr(fn, "__name__"):
        return update_wrapper(wrapper, fn)

    wrapper.__name__ = fn.__class__.__name__
    if hasattr(fn, "__module__"):
        wrapper.__module__ = fn.__module__

    call_doc = getattr(fn.__call__, "__doc__", None)
    if call_doc:
        wrapper.__doc__ = call_doc
    elif fn.__doc__:
        wrapper.__doc__ = fn.__doc__

    return wrapper
2035
2036
def quoted_token_parser(value):
    """Parse a dotted identifier with accommodation for quoted names.

    Dots inside double quotes do not split tokens, the quotes themselves
    are dropped, and a doubled double quote inside a quoted section
    stands for one literal double quote character.

    E.g.::

        >>> quoted_token_parser("name")
        ['name']
        >>> quoted_token_parser("schema.name")
        ['schema', 'name']
        >>> quoted_token_parser('"Schema"."Name"')
        ['Schema', 'Name']
        >>> quoted_token_parser('"Schema"."Name""Foo"')
        ['Schema', 'Name"Foo']

    """

    if '"' not in value:
        return value.split(".")

    tokens: List[List[str]] = [[]]
    in_quotes = False
    pos = 0
    last = len(value) - 1
    while pos <= last:
        char = value[pos]
        if char == '"':
            if in_quotes and pos < last and value[pos + 1] == '"':
                # doubled quote inside quotes: keep one, skip the other
                tokens[-1].append('"')
                pos += 1
            else:
                in_quotes = not in_quotes
        elif char == "." and not in_quotes:
            tokens.append([])
        else:
            tokens[-1].append(char)
        pos += 1

    return ["".join(chars) for chars in tokens]
2079
2080
def add_parameter_text(params: Any, text: str) -> Callable[[_F], _F]:
    """Return a decorator that injects ``text`` into the documentation of
    each of the given parameters in the decorated function's docstring.

    ``params`` is normalized to a list.  A function without a docstring
    ends up with an empty string as its ``__doc__``.

    """
    params = _collections.to_list(params)

    def decorate(fn):
        doc = fn.__doc__ or ""
        if doc:
            doc = inject_param_text(doc, dict.fromkeys(params, text))
        fn.__doc__ = doc
        return fn

    return decorate
2092
2093
def _dedent_docstring(text: str) -> str:
    """Dedent docstring text whose first line may not be indented.

    Single-line text is returned unchanged.  When the first line does
    not start with a space, only the lines after it are dedented;
    otherwise the text is dedented as a whole.

    """
    firstline, newline, remaining = text.partition("\n")
    if not newline:
        return text
    if firstline.startswith(" "):
        return textwrap.dedent(text)
    return firstline + "\n" + textwrap.dedent(remaining)
2104
2105
def inject_docstring_text(
    given_doctext: Optional[str], injecttext: str, pos: int
) -> str:
    """Insert ``injecttext`` into a docstring at a blank-line position.

    The docstring is dedented, and candidate positions are the start of
    the text followed by each blank line, in order.  ``pos`` selects
    among these, with a value past the end meaning the last one.  The
    injected text is dedented and preceded by an empty line if it
    doesn't begin with one.

    """
    doctext: str = _dedent_docstring(given_doctext or "")
    lines = doctext.split("\n")
    if len(lines) == 1:
        lines.append("")

    injectlines = textwrap.dedent(injecttext).split("\n")
    if injectlines[0]:
        injectlines.insert(0, "")

    blanks = [0]
    blanks.extend(num for num, line in enumerate(lines) if not line.strip())

    inject_pos = blanks[min(pos, len(blanks) - 1)]
    lines[inject_pos:inject_pos] = injectlines
    return "\n".join(lines)
2124
2125
# an indented ":param <name>:" at the start of a line
_param_reg = re.compile(r"(\s+):param (.+?):")


def inject_param_text(doctext: str, inject_params: Dict[str, str]) -> str:
    """Inject additional text into the ``:param:`` sections of a docstring.

    For each ``:param <name>:`` line whose name (leading ``*`` stripped)
    is a key of ``inject_params``, the corresponding text is added after
    that parameter's description.  This happens at the next blank line or
    just ahead of the next ``:param:`` line, whichever comes first.  The
    text is indented like the line following the ``:param:`` line, or
    one space deeper than ``:param:`` itself when that line is blank or
    not indented.

    If the docstring ends before a blank line or another ``:param:``
    line is seen, the pending text is not injected.

    :param doctext: the docstring to process.
    :param inject_params: mapping of parameter name to text to inject.
    :return: the resulting docstring, with lines joined by newlines.

    """
    doclines = collections.deque(doctext.splitlines())
    lines = []

    # TODO: this is not working for params like ":param case_sensitive=True:"

    to_inject = None
    while doclines:
        line = doclines.popleft()

        m = _param_reg.match(line)

        if to_inject is not None:
            if m:
                # the next :param: begins; flush the pending text ahead
                # of it, then fall through so that this parameter is
                # considered for injection as well
                lines.extend(["\n", to_inject, "\n"])
                to_inject = None
            elif not line.rstrip():
                lines.extend([line, to_inject, "\n"])
                to_inject = None
            elif line.endswith("::"):
                # TODO: this still won't cover if the code example itself has
                # blank lines in it, need to detect those via indentation.
                lines.append(line)
                # carry the line after the "::" marker along with it, if
                # the docstring has one
                if doclines:
                    lines.append(doclines.popleft())
                continue

        if to_inject is None and m:
            param = m.group(2).lstrip("*")
            if param in inject_params:
                # default indent to that of :param: plus one
                indent = " " * len(m.group(1)) + " "

                # but if the next line has text, use that line's
                # indentation
                if doclines:
                    m2 = re.match(r"(\s+)\S", doclines[0])
                    if m2:
                        indent = " " * len(m2.group(1))

                to_inject = indent + inject_params[param]

        lines.append(line)

    return "\n".join(lines)
2170
2171
def repr_tuple_names(names: List[str]) -> Optional[str]:
    """Render a list of names as a short comma-separated string.

    Up to four names are shown in full; for longer lists the first three
    and the last one are shown, with ``...`` in between.  A name longer
    than 11 characters is cut to 11 characters followed by ``..``.
    Returns ``None`` for an empty list.

    """
    if len(names) == 0:
        return None

    def shorten(name):
        return "%s.." % name[:11] if len(name) > 11 else name

    if len(names) <= 4:
        return ", ".join(shorten(name) for name in names[0:4])

    head = ", ".join(shorten(name) for name in names[0:3])
    return "%s, ..., %s" % (head, shorten(names[-1]))
2184
2185
def has_compiled_ext(raise_=False):
    """Return True if ``HAS_CYEXTENSION`` is set, otherwise False.

    :param raise_: when true, raise ``ImportError`` instead of returning
     False.

    """
    if HAS_CYEXTENSION:
        return True
    if not raise_:
        return False
    raise ImportError(
        "cython extensions were expected to be installed, "
        "but are not present"
    )
2196
2197
class _Missing(enum.Enum):
    """Single-member enum providing the ``Missing`` sentinel value."""

    Missing = enum.auto()


# the sentinel itself
Missing = _Missing.Missing
# type alias: either a value of type _T or the Missing sentinel
MissingOr = Union[_T, Literal[_Missing.Missing]]