1# util/langhelpers.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Routines to help with the creation, loading and introspection of
10modules, classes, hierarchies, attributes, functions, and methods.
11
12"""
13from __future__ import annotations
14
15import collections
16import enum
17from functools import update_wrapper
18import importlib.util
19import inspect
20import itertools
21import operator
22import re
23import sys
24import textwrap
25import threading
26import types
27from types import CodeType
28from types import ModuleType
29from typing import Any
30from typing import Callable
31from typing import cast
32from typing import Dict
33from typing import FrozenSet
34from typing import Generic
35from typing import Iterator
36from typing import List
37from typing import Literal
38from typing import NoReturn
39from typing import Optional
40from typing import overload
41from typing import Sequence
42from typing import Set
43from typing import Tuple
44from typing import Type
45from typing import TYPE_CHECKING
46from typing import TypeVar
47from typing import Union
48import warnings
49
50from . import _collections
51from . import compat
52from .. import exc
53
# generic TypeVars shared by the helpers in this module
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)  # covariant: read-only descriptor results
_F = TypeVar("_F", bound=Callable[..., Any])  # any callable, for decorators
_MA = TypeVar("_MA", bound="HasMemoized.memoized_attribute[Any]")
_M = TypeVar("_M", bound=ModuleType)
59
60
def restore_annotations(
    cls: type, new_annotations: dict[str, Any]
) -> Callable[[], None]:
    """Swap in alternate ``__annotations__`` on a class, returning a
    callable that restores the class to its former pristine state.

    This is used strictly to provide dataclasses on a mapped class, where
    in some cases we are making dataclass fields based on an attribute
    that is actually a python descriptor on a superclass which we called
    to get a value.  If dataclasses gave us a way to achieve this without
    swapping ``__annotations__``, that would be much better.
    """
    _absent = object()  # sentinel: attribute was not present at all

    # pep-649 means classes have "__annotate__", and it's a callable. if it's
    # there and is None, we're in "legacy future mode", where it's python 3.14
    # or higher and "from __future__ import annotations" is set. in "legacy
    # future mode" we have to do the same steps we do for older pythons,
    # __annotate__ can be ignored
    if hasattr(cls, "__annotate__") and cls.__annotate__ is not None:
        saved = {"__annotate__": getattr(cls, "__annotate__", _absent)}
    else:
        saved = {"__annotations__": getattr(cls, "__annotations__", _absent)}

    cls.__annotations__ = new_annotations

    def restore() -> None:
        # put back (or remove) exactly what we memoized above
        for attrname, previous in saved.items():
            if previous is _absent:
                delattr(cls, attrname)
            else:
                setattr(cls, attrname, previous)

    return restore
101
102
def md5_hex(x: Any) -> str:
    """Return the hex MD5 digest of string *x* (non-security use only)."""
    digest = compat.md5_not_for_security()
    digest.update(x.encode("utf-8"))
    return cast(str, digest.hexdigest())
108
109
class safe_reraise:
    """Reraise an exception after invoking some
    handler code.

    Stores the existing exception info before
    invoking so that it is maintained across a potential
    coroutine context switch.

    e.g.::

        try:
            sess.commit()
        except:
            with safe_reraise():
                sess.rollback()

    TODO: we should at some point evaluate current behaviors in this regard
    based on current greenlet, gevent/eventlet implementations in Python 3, and
    also see the degree to which our own asyncio (based on greenlet also) is
    impacted by this. .rollback() will cause IO / context switch to occur in
    all these scenarios; what happens to the exception context from an
    "except:" block if we don't explicitly store it? Original issue was #2703.

    """

    __slots__ = ("_exc_info",)

    # exception info captured on __enter__; None once consumed
    _exc_info: Union[
        None,
        Tuple[
            Type[BaseException],
            BaseException,
            types.TracebackType,
        ],
        Tuple[None, None, None],
    ]

    def __enter__(self) -> None:
        # capture whatever exception is currently being handled so it
        # survives the handler body (which may context-switch)
        self._exc_info = sys.exc_info()

    def __exit__(
        self,
        type_: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[types.TracebackType],
    ) -> NoReturn:
        assert self._exc_info is not None
        # see #2703 for notes
        saved = self._exc_info
        self._exc_info = None  # remove potential circular references
        if type_ is None:
            # handler body exited cleanly: re-raise the exception we
            # captured on entry
            _, entry_value, entry_tb = saved
            assert entry_value is not None
            raise entry_value.with_traceback(entry_tb)
        else:
            # handler body itself raised: propagate that newer exception
            assert value is not None
            raise value.with_traceback(traceback)
167
168
def walk_subclasses(cls: Type[_T]) -> Iterator[Type[_T]]:
    """Yield *cls* plus every transitive subclass of it, each exactly once."""
    visited: Set[Any] = set()
    pending = [cls]
    while pending:
        current = pending.pop()
        if current in visited:
            continue
        visited.add(current)
        pending.extend(current.__subclasses__())
        yield current
181
182
def string_or_unprintable(element: Any) -> str:
    """Coerce *element* to a string, with a placeholder if str() fails."""
    if isinstance(element, str):
        return element
    try:
        return str(element)
    except Exception:
        return "unprintable element %r" % element
191
192
def clsname_as_plain_name(
    cls: Type[Any], use_name: Optional[str] = None
) -> str:
    """Render a CamelCase class name as lowercase words separated by
    spaces; the token "SQL" is kept as a single word."""
    source = use_name or cls.__name__
    words = re.findall(r"([A-Z][a-z]+|SQL)", source)
    return " ".join(word.lower() for word in words)
198
199
def method_is_overridden(
    instance_or_cls: Union[Type[Any], object],
    against_method: Callable[..., Any],
) -> bool:
    """Return True if the two class methods don't match."""

    target_cls = (
        instance_or_cls
        if isinstance(instance_or_cls, type)
        else instance_or_cls.__class__
    )
    found: types.MethodType = getattr(target_cls, against_method.__name__)
    return found != against_method
216
217
def decode_slice(slc: slice) -> Tuple[Any, ...]:
    """decode a slice object as sent to __getitem__.

    takes into account the 2.5 __index__() method, basically.

    """
    return tuple(
        part.__index__() if hasattr(part, "__index__") else part
        for part in (slc.start, slc.stop, slc.step)
    )
230
231
232def _unique_symbols(used: Sequence[str], *bases: str) -> Iterator[str]:
233 used_set = set(used)
234 for base in bases:
235 pool = itertools.chain(
236 (base,),
237 map(lambda i: base + str(i), range(1000)),
238 )
239 for sym in pool:
240 if sym not in used_set:
241 used_set.add(sym)
242 yield sym
243 break
244 else:
245 raise NameError("exhausted namespace for symbol base %s" % base)
246
247
def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]:
    """Call the given function given each nonzero bit from n."""
    remaining = n
    while remaining:
        lowest = remaining & (~remaining + 1)  # isolate the lowest set bit
        yield fn(lowest)
        remaining ^= lowest
255
256
# callable-bound TypeVar used by the @decorator factory below
_Fn = TypeVar("_Fn", bound="Callable[..., Any]")

# this seems to be in flux in recent mypy versions
260
261
def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
    """A signature-matching decorator factory.

    Returns a decorator which, applied to a function ``fn``, generates a
    wrapper whose signature matches ``fn`` exactly and whose body calls
    ``target(fn, <args>)``.  This keeps introspection of the decorated
    function (argspec, documentation tools) intact.
    """

    def decorate(fn: _Fn) -> _Fn:
        if not inspect.isfunction(fn) and not inspect.ismethod(fn):
            raise Exception("not a decoratable function")

        # Python 3.14 defer creating __annotations__ until its used.
        # We do not want to create __annotations__ now.
        annofunc = getattr(fn, "__annotate__", None)
        if annofunc is not None:
            fn.__annotate__ = None  # type: ignore[union-attr]
            try:
                spec = compat.inspect_getfullargspec(fn)
            finally:
                fn.__annotate__ = annofunc  # type: ignore[union-attr]
        else:
            spec = compat.inspect_getfullargspec(fn)

        # Do not generate code for annotations.
        # update_wrapper() copies the annotation from fn to decorated.
        # We use dummy defaults for code generation to avoid having
        # copy of large globals for compiling.
        # We copy __defaults__ and __kwdefaults__ from fn to decorated.
        empty_defaults = (None,) * len(spec.defaults or ())
        empty_kwdefaults = dict.fromkeys(spec.kwonlydefaults or ())
        spec = spec._replace(
            annotations={},
            defaults=empty_defaults,
            kwonlydefaults=empty_kwdefaults,
        )

        # pick symbol names for the target/fn references that can't
        # collide with the wrapped function's own argument names
        names = (
            tuple(cast("Tuple[str, ...]", spec[0]))
            + cast("Tuple[str, ...]", spec[1:3])
            + (fn.__name__,)
        )
        targ_name, fn_name = _unique_symbols(names, "target", "fn")

        metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
        metadata.update(format_argspec_plus(spec, grouped=False))
        metadata["name"] = fn.__name__

        # coroutine functions get an async wrapper that awaits the target
        if inspect.iscoroutinefunction(fn):
            metadata["prefix"] = "async "
            metadata["target_prefix"] = "await "
        else:
            metadata["prefix"] = ""
            metadata["target_prefix"] = ""

        # look for __ positional arguments. This is a convention in
        # SQLAlchemy that arguments should be passed positionally
        # rather than as keyword
        # arguments. note that apply_pos doesn't currently work in all cases
        # such as when a kw-only indicator "*" is present, which is why
        # we limit the use of this to just that case we can detect. As we add
        # more kinds of methods that use @decorator, things may have to
        # be further improved in this area
        if "__" in repr(spec[0]):
            code = (
                """\
%(prefix)sdef %(name)s%(grouped_args)s:
    return %(target_prefix)s%(target)s(%(fn)s, %(apply_pos)s)
"""
                % metadata
            )
        else:
            code = (
                """\
%(prefix)sdef %(name)s%(grouped_args)s:
    return %(target_prefix)s%(target)s(%(fn)s, %(apply_kw)s)
"""
                % metadata
            )

        # compile the wrapper with a minimal namespace: just the target,
        # the wrapped fn, and a module name for tracebacks
        env: Dict[str, Any] = {
            targ_name: target,
            fn_name: fn,
            "__name__": fn.__module__,
        }

        decorated = cast(
            types.FunctionType,
            _exec_code_in_env(code, env, fn.__name__),
        )
        # restore the real defaults that were blanked out above
        decorated.__defaults__ = fn.__defaults__
        decorated.__kwdefaults__ = fn.__kwdefaults__  # type: ignore
        return update_wrapper(decorated, fn)  # type: ignore[return-value]

    return update_wrapper(decorate, target)  # type: ignore[return-value]
352
353
354def _exec_code_in_env(
355 code: Union[str, types.CodeType], env: Dict[str, Any], fn_name: str
356) -> Callable[..., Any]:
357 exec(code, env)
358 return env[fn_name] # type: ignore[no-any-return]
359
360
# TypeVars used by generic helpers elsewhere in this module
_PF = TypeVar("_PF")
_TE = TypeVar("_TE")
363
364
class PluginLoader:
    """Lazily resolve named implementations registered directly, produced
    by an ``auto_fn`` callback, or discovered as entry points under
    ``group``."""

    def __init__(
        self, group: str, auto_fn: Optional[Callable[..., Any]] = None
    ):
        self.group = group
        self.impls: Dict[str, Any] = {}  # name -> zero-arg loader callable
        self.auto_fn = auto_fn

    def clear(self):
        """Discard all cached loaders."""
        self.impls.clear()

    def load(self, name: str) -> Any:
        """Return the implementation for *name*.

        Consults, in order: the loader cache, the ``auto_fn`` callback,
        then entry points for ``self.group``; raises NoSuchModuleError
        if nothing matches.
        """
        try:
            cached = self.impls[name]
        except KeyError:
            pass
        else:
            return cached()

        if self.auto_fn:
            found = self.auto_fn(name)
            if found:
                self.impls[name] = found
                return found()

        for impl in compat.importlib_metadata_get(self.group):
            if impl.name == name:
                self.impls[name] = impl.load
                return impl.load()

        raise exc.NoSuchModuleError(
            "Can't load plugin: %s:%s" % (self.group, name)
        )

    def register(self, name: str, modulepath: str, objname: str) -> None:
        """Register attribute *objname* of module *modulepath* under *name*."""

        def load():
            # walk from the top-level package down to the target module
            module = __import__(modulepath)
            for token in modulepath.split(".")[1:]:
                module = getattr(module, token)
            return getattr(module, objname)

        self.impls[name] = load

    def deregister(self, name: str) -> None:
        """Remove the loader registered under *name*."""
        del self.impls[name]
406
407
408def _inspect_func_args(fn):
409 try:
410 co_varkeywords = inspect.CO_VARKEYWORDS
411 except AttributeError:
412 # https://docs.python.org/3/library/inspect.html
413 # The flags are specific to CPython, and may not be defined in other
414 # Python implementations. Furthermore, the flags are an implementation
415 # detail, and can be removed or deprecated in future Python releases.
416 spec = compat.inspect_getfullargspec(fn)
417 return spec[0], bool(spec[2])
418 else:
419 # use fn.__code__ plus flags to reduce method call overhead
420 co = fn.__code__
421 nargs = co.co_argcount
422 return (
423 list(co.co_varnames[:nargs]),
424 bool(co.co_flags & co_varkeywords),
425 )
426
427
@overload
def get_cls_kwargs(
    cls: type,
    *,
    _set: Optional[Set[str]] = None,
    raiseerr: Literal[True] = ...,
) -> Set[str]: ...


@overload
def get_cls_kwargs(
    cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
) -> Optional[Set[str]]: ...


def get_cls_kwargs(
    cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
) -> Optional[Set[str]]:
    r"""Return the full set of inherited kwargs for the given `cls`.

    Probes a class's __init__ method, collecting all named arguments. If the
    __init__ defines a \**kwargs catch-all, then the constructor is presumed
    to pass along unrecognized keywords to its base classes, and the
    collection process is repeated recursively on each of the bases.

    Uses a subset of inspect.getfullargspec() to cut down on method overhead,
    as this is used within the Core typing system to create copies of type
    objects which is a performance-sensitive operation.

    No anonymous tuple arguments please !

    """
    toplevel = _set is None
    if toplevel:
        _set = set()
    assert _set is not None

    # only consider an __init__ defined directly on this class; the
    # inherited one is handled by the recursive calls on the bases
    ctr = cls.__dict__.get("__init__", False)

    has_init = (
        ctr
        and isinstance(ctr, types.FunctionType)
        and isinstance(ctr.__code__, types.CodeType)
    )

    if not has_init:
        has_kw = False
    else:
        names, has_kw = _inspect_func_args(ctr)
        _set.update(names)

        # a non-toplevel __init__ without **kw terminates the search;
        # whether that's an error depends on raiseerr
        if not has_kw and not toplevel:
            if not raiseerr:
                return None
            raise TypeError(
                f"given cls {cls} doesn't have an __init__ method"
            )

    # recurse into bases when this class either forwards **kw or has no
    # __init__ of its own
    if has_kw or not has_init:
        for base in cls.__bases__:
            if get_cls_kwargs(base, _set=_set) is None:
                break

    _set.discard("self")
    return _set
494
495
def get_func_kwargs(func: Callable[..., Any]) -> List[str]:
    """Return the set of legal kwargs for the given `func`.

    Uses getargspec so is safe to call for methods, functions,
    etc.

    """
    spec = compat.inspect_getfullargspec(func)
    return spec[0]
505
506
def get_callable_argspec(
    fn: Callable[..., Any], no_self: bool = False, _is_init: bool = False
) -> compat.FullArgSpec:
    """Return the argument signature for any callable.

    All pure-Python callables are accepted, including
    functions, methods, classes, objects with __call__;
    builtins and other edge cases like functools.partial() objects
    raise a TypeError.

    """

    def _strip_leading_arg(spec: compat.FullArgSpec) -> compat.FullArgSpec:
        # drop the self/cls argument from the front of the spec
        return compat.FullArgSpec(
            spec.args[1:],
            spec.varargs,
            spec.varkw,
            spec.defaults,
            spec.kwonlyargs,
            spec.kwonlydefaults,
            spec.annotations,
        )

    if inspect.isbuiltin(fn):
        raise TypeError("Can't inspect builtin: %s" % fn)

    if inspect.isfunction(fn):
        spec = compat.inspect_getfullargspec(fn)
        if _is_init and no_self:
            return _strip_leading_arg(spec)
        return spec

    if inspect.ismethod(fn):
        spec = compat.inspect_getfullargspec(fn.__func__)
        if no_self and (_is_init or fn.__self__):
            return _strip_leading_arg(spec)
        return spec

    if inspect.isclass(fn):
        # inspect the constructor, noting that its self arg may need
        # to be removed
        return get_callable_argspec(
            fn.__init__, no_self=no_self, _is_init=True
        )

    if hasattr(fn, "__func__"):
        return compat.inspect_getfullargspec(fn.__func__)

    if hasattr(fn, "__call__") and inspect.ismethod(fn.__call__):
        return get_callable_argspec(fn.__call__, no_self=no_self)

    raise TypeError("Can't inspect callable: %s" % fn)
561
562
def format_argspec_plus(
    fn: Union[Callable[..., Any], compat.FullArgSpec], grouped: bool = True
) -> Dict[str, Optional[str]]:
    """Returns a dictionary of formatted, introspected function arguments.

    A enhanced variant of inspect.formatargspec to support code generation.

    fn
       An inspectable callable or tuple of inspect getargspec() results.
    grouped
      Defaults to True; include (parens, around, argument) lists

    Returns:

    args
      Full inspect.formatargspec for fn
    self_arg
      The name of the first positional argument, varargs[0], or None
      if the function defines no positional arguments.
    apply_pos
      args, re-written in calling rather than receiving syntax. Arguments are
      passed positionally.
    apply_kw
      Like apply_pos, except keyword-ish args are passed as keywords.
    apply_pos_proxied
      Like apply_pos but omits the self/cls argument

    Example::

      >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
      {'grouped_args': '(self, a, b, c=3, **d)',
       'self_arg': 'self',
       'apply_kw': '(self, a, b, c=c, **d)',
       'apply_pos': '(self, a, b, c, **d)'}

    """
    # spec fields by index: 0=args, 1=varargs, 2=varkw, 3=defaults,
    # 4=kwonlyargs (see inspect.getfullargspec)
    if callable(fn):
        spec = compat.inspect_getfullargspec(fn)
    else:
        spec = fn

    args = compat.inspect_formatargspec(*spec)

    # calling-style rendering with no default values
    apply_pos = compat.inspect_formatargspec(
        spec[0], spec[1], spec[2], None, spec[4]
    )

    if spec[0]:
        self_arg = spec[0][0]

        # proxied variant: same, with the leading self/cls removed
        apply_pos_proxied = compat.inspect_formatargspec(
            spec[0][1:], spec[1], spec[2], None, spec[4]
        )

    elif spec[1]:
        # no positional args but *varargs is present: refer to the first
        # element of the varargs tuple as the "self" argument
        self_arg = "%s[0]" % spec[1]

        apply_pos_proxied = apply_pos
    else:
        self_arg = None
        apply_pos_proxied = apply_pos

    # count how many trailing names carry defaults (positional defaults
    # plus all keyword-only args)
    num_defaults = 0
    if spec[3]:
        num_defaults += len(cast(Tuple[Any], spec[3]))
    if spec[4]:
        num_defaults += len(spec[4])

    name_args = spec[0] + spec[4]

    defaulted_vals: Union[List[str], Tuple[()]]

    if num_defaults:
        defaulted_vals = name_args[0 - num_defaults :]
    else:
        defaulted_vals = ()

    # render defaulted args as "name=name" so the generated call
    # forwards them as keywords
    apply_kw = compat.inspect_formatargspec(
        name_args,
        spec[1],
        spec[2],
        defaulted_vals,
        formatvalue=lambda x: "=" + str(x),
    )

    if spec[0]:
        apply_kw_proxied = compat.inspect_formatargspec(
            name_args[1:],
            spec[1],
            spec[2],
            defaulted_vals,
            formatvalue=lambda x: "=" + str(x),
        )
    else:
        apply_kw_proxied = apply_kw

    if grouped:
        return dict(
            grouped_args=args,
            self_arg=self_arg,
            apply_pos=apply_pos,
            apply_kw=apply_kw,
            apply_pos_proxied=apply_pos_proxied,
            apply_kw_proxied=apply_kw_proxied,
        )
    else:
        # ungrouped: strip the surrounding parentheses from each rendering
        return dict(
            grouped_args=args,
            self_arg=self_arg,
            apply_pos=apply_pos[1:-1],
            apply_kw=apply_kw[1:-1],
            apply_pos_proxied=apply_pos_proxied[1:-1],
            apply_kw_proxied=apply_kw_proxied[1:-1],
        )
678
679
def format_argspec_init(method, grouped=True):
    """format_argspec_plus with considerations for typical __init__ methods

    Wraps format_argspec_plus with error handling strategies for typical
    __init__ cases:

    .. sourcecode:: text

      object.__init__ -> (self)
      other unreflectable (usually C) -> (self, *args, **kwargs)

    """

    def _fallback(grouped_args, args, proxied):
        # synthesize the same dict shape format_argspec_plus would return
        return dict(
            self_arg="self",
            grouped_args=grouped_args,
            apply_pos=args,
            apply_kw=args,
            apply_pos_proxied=proxied,
            apply_kw_proxied=proxied,
        )

    if method is object.__init__:
        if grouped:
            return _fallback("(self)", "(self)", "()")
        return _fallback("(self)", "self", "")

    try:
        return format_argspec_plus(method, grouped=grouped)
    except TypeError:
        # unreflectable (usually C-implemented) callable
        if grouped:
            return _fallback(
                "(self, *args, **kwargs)",
                "(self, *args, **kwargs)",
                "(*args, **kwargs)",
            )
        return _fallback(
            "(self, *args, **kwargs)",
            "self, *args, **kwargs",
            "*args, **kwargs",
        )
711
712
def create_proxy_methods(
    target_cls: Type[Any],
    target_cls_sphinx_name: str,
    proxy_cls_sphinx_name: str,
    classmethods: Sequence[str] = (),
    methods: Sequence[str] = (),
    attributes: Sequence[str] = (),
    use_intermediate_variable: Sequence[str] = (),
) -> Callable[[_T], _T]:
    """A class decorator indicating attributes should refer to a proxy
    class.

    At runtime this is a pure pass-through marker; the
    tools/generate_proxy_methods.py script consumes it to statically
    generate proxy methods and attributes that typing tools such as
    mypy fully recognize.

    """

    def decorate(cls):
        # intentionally a no-op: the decorated class is returned unchanged
        return cls

    return decorate
736
737
def getargspec_init(method):
    """inspect.getargspec with considerations for typical __init__ methods

    Wraps inspect.getargspec with error handling for typical __init__ cases:

    .. sourcecode:: text

      object.__init__ -> (self)
      other unreflectable (usually C) -> (self, *args, **kwargs)

    """
    try:
        return compat.inspect_getfullargspec(method)
    except TypeError:
        # unreflectable; synthesize a plausible spec
        return (
            (["self"], None, None, None)
            if method is object.__init__
            else (["self"], "args", "kwargs", None)
        )
756
757
def unbound_method_to_callable(func_or_cls):
    """Adjust the incoming callable such that a 'self' argument is not
    required.

    """
    is_unbound = (
        isinstance(func_or_cls, types.MethodType)
        and not func_or_cls.__self__
    )
    return func_or_cls.__func__ if is_unbound else func_or_cls
768
769
def generic_repr(
    obj: Any,
    additional_kw: Sequence[Tuple[str, Any]] = (),
    to_inspect: Optional[Union[object, List[object]]] = None,
    omit_kwarg: Sequence[str] = (),
) -> str:
    """Produce a __repr__() based on direct association of the __init__()
    specification vs. same-named attributes present.

    The first entry of ``to_inspect`` contributes positional args; any
    further entries contribute keyword args only.  Keyword args whose
    current value equals the __init__ default are omitted from the output.
    """
    if to_inspect is None:
        to_inspect = [obj]
    else:
        to_inspect = _collections.to_list(to_inspect)

    missing = object()  # sentinel: attribute absent / no default known

    pos_args = []
    kw_args: _collections.OrderedDict[str, Any] = _collections.OrderedDict()
    vargs = None
    for i, insp in enumerate(to_inspect):
        try:
            spec = compat.inspect_getfullargspec(insp.__init__)
        except TypeError:
            # unreflectable __init__ (usually C); skip this entry
            continue
        else:
            default_len = len(spec.defaults) if spec.defaults else 0
            if i == 0:
                if spec.varargs:
                    vargs = spec.varargs
                if default_len:
                    pos_args.extend(spec.args[1:-default_len])
                else:
                    pos_args.extend(spec.args[1:])
            else:
                # NOTE(review): when default_len is 0 this slice is
                # args[1:-0] == args[1:0] == empty, so non-defaulted args of
                # secondary entries are dropped — confirm this is intended
                kw_args.update(
                    [(arg, missing) for arg in spec.args[1:-default_len]]
                )

            if default_len:
                assert spec.defaults
                # pair each defaulted arg name with its default value
                kw_args.update(
                    [
                        (arg, default)
                        for arg, default in zip(
                            spec.args[-default_len:], spec.defaults
                        )
                    ]
                )
    output: List[str] = []

    # positional args render as bare repr() values
    output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)

    if vargs is not None and hasattr(obj, vargs):
        output.extend([repr(val) for val in getattr(obj, vargs)])

    # keyword args render as name=value, skipping defaults and omissions
    for arg, defval in kw_args.items():
        if arg in omit_kwarg:
            continue
        try:
            val = getattr(obj, arg, missing)
            if val is not missing and val != defval:
                output.append("%s=%r" % (arg, val))
        except Exception:
            pass

    if additional_kw:
        for arg, defval in additional_kw:
            try:
                val = getattr(obj, arg, missing)
                if val is not missing and val != defval:
                    output.append("%s=%r" % (arg, val))
            except Exception:
                pass

    return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
846
847
def class_hierarchy(cls):
    """Return an unordered sequence of all classes related to cls.

    Traverses diamond hierarchies.

    Fibs slightly: subclasses of builtin types are not returned. Thus
    class_hierarchy(class A(object)) returns (A, object), not A plus every
    class systemwide that derives from object.

    """

    hier = {cls}
    to_process = list(cls.__mro__)
    while to_process:
        current = to_process.pop()

        # walk upward through unseen base classes
        for base in current.__bases__:
            if base not in hier:
                to_process.append(base)
                hier.add(base)

        # don't descend into subclasses of builtins (e.g. object)
        if current.__module__ == "builtins" or not hasattr(
            current, "__subclasses__"
        ):
            continue

        # metaclasses need the explicit one-arg form of __subclasses__
        if issubclass(current, type):
            subclasses = current.__subclasses__(current)
        else:
            subclasses = current.__subclasses__()

        for sub in subclasses:
            if sub not in hier:
                to_process.append(sub)
                hier.add(sub)
    return list(hier)
884
885
def iterate_attributes(cls):
    """iterate all the keys and attributes associated
    with a class, without using getattr().

    Does not use getattr() so that class-sensitive
    descriptors (i.e. property.__get__()) are not called.

    """
    for key in dir(cls):
        # locate the first class in the MRO that defines the name and
        # read the raw value from its __dict__
        for klass in cls.__mro__:
            if key in klass.__dict__:
                yield (key, klass.__dict__[key])
                break
900
901
def monkeypatch_proxied_specials(
    into_cls,
    from_cls,
    skip=None,
    only=None,
    name="self.proxy",
    from_instance=None,
):
    """Automates delegation of __specials__ for a proxying type.

    For each eligible dunder method found on ``from_cls``, generates and
    installs onto ``into_cls`` a method that forwards the call to
    ``name`` (an expression string, by default ``self.proxy``).
    """

    if only:
        dunders = only
    else:
        if skip is None:
            skip = (
                "__slots__",
                "__del__",
                "__getattribute__",
                "__metaclass__",
                "__getstate__",
                "__setstate__",
            )
        # every dunder on from_cls not already present on into_cls and
        # not in the skip list
        dunders = [
            m
            for m in dir(from_cls)
            if (
                m.startswith("__")
                and m.endswith("__")
                and not hasattr(into_cls, m)
                and m not in skip
            )
        ]

    for method in dunders:
        try:
            maybe_fn = getattr(from_cls, method)
            if not hasattr(maybe_fn, "__call__"):
                continue
            maybe_fn = getattr(maybe_fn, "__func__", maybe_fn)
            fn = cast(types.FunctionType, maybe_fn)

        except AttributeError:
            continue
        try:
            # build the receiving and forwarding argument lists from the
            # real signature when it is introspectable
            spec = compat.inspect_getfullargspec(fn)
            fn_args = compat.inspect_formatargspec(spec[0])
            d_args = compat.inspect_formatargspec(spec[0][1:])
        except TypeError:
            fn_args = "(self, *args, **kw)"
            d_args = "(*args, **kw)"

        # generated source: def <method><fn_args>: return <name>.<method><d_args>
        py = (
            "def %(method)s%(fn_args)s: "
            "return %(name)s.%(method)s%(d_args)s" % locals()
        )

        # if delegating to an instance (rather than "self.<attr>"), make
        # that instance available under `name` in the exec namespace
        env: Dict[str, types.FunctionType] = (
            from_instance is not None and {name: from_instance} or {}
        )
        exec(py, env)
        try:
            env[method].__defaults__ = fn.__defaults__
        except AttributeError:
            pass
        setattr(into_cls, method, env[method])
967
968
def methods_equivalent(meth1, meth2):
    """Return True if the two methods are the same implementation."""

    impl1 = getattr(meth1, "__func__", meth1)
    impl2 = getattr(meth2, "__func__", meth2)
    return impl1 is impl2
975
976
def as_interface(obj, cls=None, methods=None, required=None):
    """Ensure basic interface compliance for an instance or dict of callables.

    Checks that ``obj`` implements public methods of ``cls`` or has members
    listed in ``methods``. If ``required`` is not supplied, implementing at
    least one interface method is sufficient. Methods present on ``obj`` that
    are not in the interface are ignored.

    If ``obj`` is a dict and ``dict`` does not meet the interface
    requirements, the keys of the dictionary are inspected. Keys present in
    ``obj`` that are not in the interface will raise TypeErrors.

    Raises TypeError if ``obj`` does not meet the interface criteria.

    In all passing cases, an object with callable members is returned. In the
    simple case, ``obj`` is returned as-is; if dict processing kicks in then
    an anonymous class is returned.

    obj
      A type, instance, or dictionary of callables.
    cls
      Optional, a type. All public methods of cls are considered the
      interface. An ``obj`` instance of cls will always pass, ignoring
      ``required``.
    methods
      Optional, a sequence of method names to consider as the interface.
    required
      Optional, a sequence of mandatory implementations. If omitted, an
      ``obj`` that provides at least one interface method is considered
      sufficient. As a convenience, required may be a type, in which case
      all public methods of the type are required.

    """
    if not cls and not methods:
        raise TypeError("a class or collection of method names are required")

    # an actual instance of cls always complies
    if isinstance(cls, type) and isinstance(obj, cls):
        return obj

    interface = set(methods or [m for m in dir(cls) if not m.startswith("_")])
    implemented = set(dir(obj))

    # ">=" means "implements everything required"; ">" means "implements
    # at least one interface member" (required is then empty)
    complies = operator.ge
    if isinstance(required, type):
        required = interface
    elif not required:
        required = set()
        complies = operator.gt
    else:
        required = set(required)

    if complies(implemented.intersection(interface), required):
        return obj

    # No dict duck typing here.
    if not isinstance(obj, dict):
        qualifier = "any of" if complies is operator.gt else "all of"
        raise TypeError(
            "%r does not implement %s: %s"
            % (obj, qualifier, ", ".join(interface))
        )

    class AnonymousInterface:
        """A callable-holding shell."""

    if cls:
        AnonymousInterface.__name__ = "Anonymous" + cls.__name__
    found = set()

    # install each dict entry as a staticmethod on the shell class
    for member_name, member in dictlike_iteritems(obj):
        if member_name not in interface:
            raise TypeError("%r: unknown in this interface" % member_name)
        if not callable(member):
            raise TypeError("%r=%r is not callable" % (member_name, member))
        setattr(AnonymousInterface, member_name, staticmethod(member))
        found.add(member_name)

    if complies(found, required):
        return AnonymousInterface

    raise TypeError(
        "dictionary does not contain required keys %s"
        % ", ".join(required - found)
    )
1061
1062
# TypeVar bound to generic_fn_descriptor, used for its __get__ overloads
_GFD = TypeVar("_GFD", bound="generic_fn_descriptor[Any]")
1064
1065
class generic_fn_descriptor(Generic[_T_co]):
    """Descriptor which proxies a function when the attribute is not
    present in dict

    This superclass is organized in a particular way with "memoized" and
    "non-memoized" implementation classes that are hidden from type checkers,
    as Mypy seems to not be able to handle seeing multiple kinds of descriptor
    classes used for the same attribute.

    """

    # the wrapped function; __doc__ and __name__ mirror it
    fget: Callable[..., _T_co]
    __doc__: Optional[str]
    __name__: str

    def __init__(self, fget: Callable[..., _T_co], doc: Optional[str] = None):
        self.fget = fget
        self.__doc__ = doc or fget.__doc__
        self.__name__ = fget.__name__

    @overload
    def __get__(self: _GFD, obj: None, cls: Any) -> _GFD: ...

    @overload
    def __get__(self, obj: object, cls: Any) -> _T_co: ...

    def __get__(self: _GFD, obj: Any, cls: Any) -> Union[_GFD, _T_co]:
        # implemented by the _memoized / _non_memoized subclasses
        raise NotImplementedError()

    if TYPE_CHECKING:
        # declared for typing only, so the descriptor appears writable
        # and deletable to the type checker

        def __set__(self, instance: Any, value: Any) -> None: ...

        def __delete__(self, instance: Any) -> None: ...

    def _reset(self, obj: Any) -> None:
        raise NotImplementedError()

    @classmethod
    def reset(cls, obj: Any, name: str) -> None:
        raise NotImplementedError()
1107
1108
class _non_memoized_property(generic_fn_descriptor[_T_co]):
    """a plain descriptor that proxies a function.

    primary rationale is to provide a plain attribute that's
    compatible with memoized_property which is also recognized as equivalent
    by mypy.

    """

    if not TYPE_CHECKING:

        def __get__(self, instance, owner):
            # plain proxy: evaluate the function on every access
            return self if instance is None else self.fget(instance)
1124
1125
class _memoized_property(generic_fn_descriptor[_T_co]):
    """A read-only @property that is only evaluated once."""

    if not TYPE_CHECKING:

        def __get__(self, instance, owner):
            if instance is None:
                return self
            # store the computed value in the instance dict under the same
            # name; being a non-data descriptor, this entry shadows the
            # descriptor so later reads skip __get__ entirely
            value = self.fget(instance)
            instance.__dict__[self.__name__] = value
            return value

    def _reset(self, obj):
        _memoized_property.reset(obj, self.__name__)

    @classmethod
    def reset(cls, obj, name):
        # drop the cached value so the next access recomputes it
        obj.__dict__.pop(name, None)
1143
1144
# despite many attempts to get Mypy to recognize an overridden descriptor
# where one is memoized and the other isn't, there seems to be no reliable
# way other than completely deceiving the type checker into thinking there
# is just one single descriptor type everywhere. Otherwise, if a superclass
# has non-memoized and subclass has memoized, that requires
# "class memoized(non_memoized)". but then if a superclass has memoized and
# superclass has non-memoized, the class hierarchy of the descriptors
# would need to be reversed; "class non_memoized(memoized)". so there's no
# way to achieve this.
# additional issues, RO properties:
# https://github.com/python/mypy/issues/12440
if TYPE_CHECKING:
    # allow memoized and non-memoized to be freely mixed by having them
    # be the same class
    memoized_property = generic_fn_descriptor
    non_memoized_property = generic_fn_descriptor

    # for read only situations, mypy only sees @property as read only.
    # read only is needed when a subtype specializes the return type
    # of a property, meaning assignment needs to be disallowed
    ro_memoized_property = property
    ro_non_memoized_property = property

else:
    # at runtime all four public names resolve to the two real
    # descriptor implementations above
    memoized_property = ro_memoized_property = _memoized_property
    non_memoized_property = ro_non_memoized_property = _non_memoized_property
1171
1172
def memoized_instancemethod(fn: _F) -> _F:
    """Decorate a method so its return value is computed once and cached.

    Best applied to no-arg methods: memoization is not sensitive to
    argument values, and will always return the same value even when
    called with different arguments.

    """

    def oneshot(self, *args, **kw):
        value = fn(self, *args, **kw)

        def memo(*a, **kw):
            return value

        memo.__name__ = fn.__name__
        memo.__doc__ = fn.__doc__
        # shadow the method on the instance with a constant-returning stub
        self.__dict__[fn.__name__] = memo
        return value

    return update_wrapper(oneshot, fn)  # type: ignore
1194
1195
class HasMemoized:
    """A mixin class that maintains the names of memoized elements in a
    collection for easy cache clearing, generative, etc.

    """

    if not TYPE_CHECKING:
        # support classes that want to have __slots__ with an explicit
        # slot for __dict__. not sure if that requires base __slots__ here.
        __slots__ = ()

    # names of all values memoized on this instance so far
    _memoized_keys: FrozenSet[str] = frozenset()

    def _reset_memoizations(self) -> None:
        # discard every cached value recorded in _memoized_keys
        for name in self._memoized_keys:
            self.__dict__.pop(name, None)

    def _assert_no_memoizations(self) -> None:
        for name in self._memoized_keys:
            assert name not in self.__dict__

    def _set_memoized_attribute(self, key: str, value: Any) -> None:
        # store the value directly and record the key for later reset
        self.__dict__[key] = value
        self._memoized_keys |= {key}

    class memoized_attribute(memoized_property[_T]):
        """A read-only @property that is only evaluated once.

        :meta private:

        """

        fget: Callable[..., _T]
        __doc__: Optional[str]
        __name__: str

        def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None):
            self.fget = fget
            self.__doc__ = doc if doc else fget.__doc__
            self.__name__ = fget.__name__

        @overload
        def __get__(self: _MA, obj: None, cls: Any) -> _MA: ...

        @overload
        def __get__(self, obj: Any, cls: Any) -> _T: ...

        def __get__(self, obj, cls):
            if obj is None:
                return self
            value = self.fget(obj)
            # cache on the instance and record the key for reset tracking
            obj.__dict__[self.__name__] = value
            obj._memoized_keys |= {self.__name__}
            return value

    @classmethod
    def memoized_instancemethod(cls, fn: _F) -> _F:
        """Decorate a method to memoize its return value.

        :meta private:

        """

        def oneshot(self: Any, *args: Any, **kw: Any) -> Any:
            value = fn(self, *args, **kw)

            def memo(*a, **kw):
                return value

            memo.__name__ = fn.__name__
            memo.__doc__ = fn.__doc__
            # replace the method on the instance and track the key
            self.__dict__[fn.__name__] = memo
            self._memoized_keys |= {fn.__name__}
            return value

        return update_wrapper(oneshot, fn)  # type: ignore
1271
1272
if TYPE_CHECKING:
    # presented as a plain read-only property to the type checker, which
    # permits subclasses to narrow the return type
    HasMemoized_ro_memoized_attribute = property
else:
    HasMemoized_ro_memoized_attribute = HasMemoized.memoized_attribute
1277
1278
class MemoizedSlots:
    """Apply memoized items to an object using a __getattr__ scheme.

    This allows the functionality of memoized_property and
    memoized_instancemethod to be available to a class using __slots__.

    The memoized get is not threadsafe under freethreading and the
    creator method may in extremely rare cases be called more than once.

    """

    __slots__ = ()

    def _fallback_getattr(self, key):
        raise AttributeError(key)

    def __getattr__(self, key: str) -> Any:
        if key.startswith(("_memoized_attr_", "_memoized_method_")):
            # a missing creator method must not recurse into this scheme
            raise AttributeError(key)
        # to avoid recursion errors when interacting with other __getattr__
        # schemes that refer to this one, when testing for memoized method
        # look at __class__ only rather than going into __getattr__ again.
        cls = self.__class__
        if hasattr(cls, f"_memoized_attr_{key}"):
            value = getattr(self, f"_memoized_attr_{key}")()
            setattr(self, key, value)
            return value
        if hasattr(cls, f"_memoized_method_{key}"):
            creator = getattr(self, f"_memoized_method_{key}")

            def oneshot(*args, **kw):
                result = creator(*args, **kw)

                def memo(*a, **kw):
                    return result

                memo.__name__ = creator.__name__
                memo.__doc__ = creator.__doc__
                # subsequent attribute access finds memo, not __getattr__
                setattr(self, key, memo)
                return result

            oneshot.__doc__ = creator.__doc__
            return oneshot
        return self._fallback_getattr(key)
1325
1326
1327# from paste.deploy.converters
# from paste.deploy.converters
def asbool(obj: Any) -> bool:
    """Coerce a value to boolean, interpreting common true/false strings.

    Non-string values pass through ``bool()``; unrecognized strings raise
    ``ValueError``.
    """
    if isinstance(obj, str):
        text = obj.strip().lower()
        if text in ("true", "yes", "on", "y", "t", "1"):
            return True
        if text in ("false", "no", "off", "n", "f", "0"):
            return False
        # note: the normalized (stripped/lowered) text is reported
        raise ValueError("String is not true/false: %r" % text)
    return bool(obj)
1338
1339
def bool_or_str(*text: str) -> Callable[[str], Union[str, bool]]:
    """Return a callable that will evaluate a string as
    boolean, or one of a set of "alternate" string values.

    """

    def bool_or_value(obj: str) -> Union[str, bool]:
        # pass through any of the designated alternate strings unchanged
        return obj if obj in text else asbool(obj)

    return bool_or_value
1353
1354
def asint(value: Any) -> Optional[int]:
    """Coerce to integer, passing ``None`` through unchanged."""
    return None if value is None else int(value)
1361
1362
def coerce_kw_type(
    kw: Dict[str, Any],
    key: str,
    type_: Type[Any],
    flexi_bool: bool = True,
    dest: Optional[Dict[str, Any]] = None,
) -> None:
    r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if
    necessary.  If 'flexi_bool' is True, the string '0' is considered false
    when coercing to boolean.

    :param dest: optional alternate dictionary to receive the coerced
     value; defaults to ``kw`` itself.
    """

    target = kw if dest is None else dest

    if key not in kw:
        return
    value = kw[key]
    if value is None:
        return
    # already an instance of a real class type: nothing to do
    if isinstance(type_, type) and isinstance(value, type_):
        return
    if type_ is bool and flexi_bool:
        target[key] = asbool(value)
    else:
        target[key] = type_(value)
1387
1388
def constructor_key(obj: Any, cls: Type[Any]) -> Tuple[Any, ...]:
    """Produce a tuple structure that is cacheable using the __dict__ of
    obj to retrieve values

    """
    state = obj.__dict__
    pairs = tuple(
        (name, state[name]) for name in get_cls_kwargs(cls) if name in state
    )
    return (cls,) + pairs
1398
1399
def constructor_copy(obj: _T, cls: Type[_T], *args: Any, **kw: Any) -> _T:
    """Instantiate cls using the __dict__ of obj as constructor arguments.

    Uses inspect to match the named arguments of ``cls``.

    """
    state = obj.__dict__
    # explicit kw entries win; fill the rest from obj's state
    for name in get_cls_kwargs(cls).difference(kw):
        if name in state:
            kw[name] = state[name]
    return cls(*args, **kw)
1412
1413
def counter() -> Callable[[], int]:
    """Return a threadsafe counter function starting at 1."""

    mutex = threading.Lock()
    sequence = itertools.count(1)

    def _next():
        # lock so concurrent callers each receive a distinct value
        with mutex:
            return next(sequence)

    return _next
1426
1427
def duck_type_collection(
    specimen: Any, default: Optional[Type[Any]] = None
) -> Optional[Type[Any]]:
    """Given an instance or class, guess if it is or is acting as one of
    the basic collection types: list, set and dict.  If the __emulates__
    property is present, return that preferentially.
    """

    if hasattr(specimen, "__emulates__"):
        # canonicalize set vs sets.Set to a standard: the builtin set
        emulated = specimen.__emulates__
        if emulated is not None and issubclass(emulated, set):
            return set
        return emulated  # type: ignore

    isa = issubclass if isinstance(specimen, type) else isinstance
    for collection_type in (list, set, dict):
        if isa(specimen, collection_type):
            return collection_type

    # duck-type by characteristic mutator method
    if hasattr(specimen, "append"):
        return list
    if hasattr(specimen, "add"):
        return set
    if hasattr(specimen, "set"):
        return dict
    return default
1461
1462
def assert_arg_type(
    arg: Any, argtype: Union[Tuple[Type[Any], ...], Type[Any]], name: str
) -> Any:
    """Return ``arg`` if it is an instance of ``argtype``; otherwise raise
    an ``ArgumentError`` naming the offending argument."""
    if isinstance(arg, argtype):
        return arg
    if isinstance(argtype, tuple):
        raise exc.ArgumentError(
            "Argument '%s' is expected to be one of type %s, got '%s'"
            % (name, " or ".join("'%s'" % a for a in argtype), type(arg))
        )
    raise exc.ArgumentError(
        "Argument '%s' is expected to be of type '%s', got '%s'"
        % (name, argtype, type(arg))
    )
1479
1480
def dictlike_iteritems(dictlike):
    """Return a (key, value) iterator for almost any dict-like object."""

    if hasattr(dictlike, "items"):
        return list(dictlike.items())

    getter = getattr(dictlike, "__getitem__", getattr(dictlike, "get", None))
    if getter is None:
        raise TypeError("Object '%r' is not dict-like" % dictlike)

    if hasattr(dictlike, "iterkeys"):

        def _lazy_pairs():
            # iterkeys() is only invoked once iteration actually begins
            for key in dictlike.iterkeys():
                assert getter is not None
                yield key, getter(key)

        return _lazy_pairs()

    if hasattr(dictlike, "keys"):
        return iter((key, getter(key)) for key in dictlike.keys())

    raise TypeError("Object '%r' is not dict-like" % dictlike)
1503
1504
class classproperty(property):
    """A decorator that behaves like @property except that operates
    on classes rather than instances.

    The decorator is currently special when using the declarative
    module, but note that the
    :class:`~.sqlalchemy.ext.declarative.declared_attr`
    decorator should be used for this purpose with declarative.

    """

    fget: Callable[[Any], Any]

    def __init__(self, fget: Callable[[Any], Any], *arg: Any, **kw: Any):
        super().__init__(fget, *arg, **kw)
        self.__doc__ = fget.__doc__

    def __get__(self, obj: Any, cls: Optional[type] = None) -> Any:
        # always invoke against the owning class, even for instance access
        return self.fget(cls)
1524
1525
class hybridproperty(Generic[_T]):
    """A property-like descriptor with separate class-level and
    instance-level getters."""

    def __init__(self, func: Callable[..., _T]):
        self.func = func
        self.clslevel = func

    def __get__(self, instance: Any, owner: Any) -> _T:
        if instance is not None:
            return self.func(instance)
        # class-level access invokes the (possibly replaced) class getter
        return self.clslevel(owner)

    def classlevel(self, func: Callable[..., Any]) -> hybridproperty[_T]:
        """Attach an alternate callable used for class-level access."""
        self.clslevel = func
        return self
1541
1542
class rw_hybridproperty(Generic[_T]):
    """A read/write variant of :class:`.hybridproperty` supporting an
    optional setter."""

    def __init__(self, func: Callable[..., _T]):
        self.func = func
        self.clslevel = func
        self.setfn: Optional[Callable[..., Any]] = None

    def __get__(self, instance: Any, owner: Any) -> _T:
        if instance is not None:
            return self.func(instance)
        return self.clslevel(owner)

    def __set__(self, instance: Any, value: Any) -> None:
        # a setter must have been attached via .setter() before assignment
        assert self.setfn is not None
        self.setfn(instance, value)

    def setter(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
        self.setfn = func
        return self

    def classlevel(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
        self.clslevel = func
        return self
1567
1568
class hybridmethod(Generic[_T]):
    """Decorate a function as cls- or instance- level."""

    def __init__(self, func: Callable[..., _T]):
        self.func = self.__func__ = func
        self.clslevel = func

    def __get__(self, instance: Any, owner: Any) -> Callable[..., _T]:
        if instance is not None:
            # bind the instance-level function like a normal method
            return self.func.__get__(instance, owner)  # type:ignore
        # bind the class-level function with the class as "self"
        return self.clslevel.__get__(owner, owner.__class__)  # type:ignore

    def classlevel(self, func: Callable[..., Any]) -> hybridmethod[_T]:
        self.clslevel = func
        return self
1585
1586
class symbol(int):
    """A constant symbol.

    >>> symbol("foo") is symbol("foo")
    True
    >>> symbol("foo")
    symbol('foo')

    A slight refinement of the MAGICCOOKIE=object() pattern.  The primary
    advantage of symbol() is its repr().  They are also singletons.

    Repeated calls of symbol('name') will all return the same instance.

    """

    name: str

    # process-wide registry of interned symbols; guarded by _lock
    symbols: Dict[str, symbol] = {}
    _lock = threading.Lock()

    def __new__(
        cls,
        name: str,
        doc: Optional[str] = None,
        canonical: Optional[int] = None,
    ) -> symbol:
        """Return the interned symbol for ``name``, creating it on first use.

        :param name: symbolic name; acts as the interning key.
        :param doc: optional docstring, applied only on first creation.
        :param canonical: optional explicit integer value; defaults to
         ``hash(name)``.  A conflicting value for an existing symbol
         raises ``TypeError``.
        """
        with cls._lock:
            sym = cls.symbols.get(name)
            if sym is None:
                assert isinstance(name, str)
                if canonical is None:
                    canonical = hash(name)
                sym = int.__new__(symbol, canonical)
                sym.name = name
                if doc:
                    sym.__doc__ = doc

                # NOTE: we should ultimately get rid of this global thing,
                # however, currently it is to support pickling.  The best
                # change would be when we are on py3.11 at a minimum, we
                # switch to stdlib enum.IntFlag.
                cls.symbols[name] = sym
            else:
                if canonical and canonical != sym:
                    raise TypeError(
                        f"Can't replace canonical symbol for {name!r} "
                        f"with new int value {canonical}"
                    )
        return sym

    def __reduce__(self):
        # re-intern by name on unpickle; the "x" doc arg is ignored for
        # an already-existing symbol
        return symbol, (self.name, "x", int(self))

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return f"symbol({self.name!r})"
1645
1646
class _IntFlagMeta(type):
    """Metaclass for :class:`._FastIntFlag`: converts integer class
    attributes into :class:`.symbol` instances and builds ``__members__``.
    """

    def __init__(
        cls,
        classname: str,
        bases: Tuple[Type[Any], ...],
        dict_: Dict[str, Any],
        **kw: Any,
    ) -> None:
        items: List[symbol]
        cls._items = items = []
        for key, value in dict_.items():
            if re.match(r"^__.*__$", key):
                # dunder entries are never flag values
                continue
            if isinstance(value, int):
                sym = symbol(key, canonical=value)
            elif not key.startswith("_"):
                raise TypeError("Expected integer values for IntFlag")
            else:
                continue
            setattr(cls, key, sym)
            items.append(sym)

        cls.__members__ = _collections.immutabledict(
            {sym.name: sym for sym in items}
        )

    def __iter__(self) -> Iterator[symbol]:
        raise NotImplementedError(
            "iter not implemented to ensure compatibility with "
            "Python 3.11 IntFlag. Please use __members__. See "
            "https://github.com/python/cpython/issues/99304"
        )
1679
1680
class _FastIntFlag(metaclass=_IntFlagMeta):
    """An 'IntFlag' copycat that isn't slow when performing bitwise
    operations.

    Integer attributes declared on subclasses are converted into
    :class:`.symbol` instances by the metaclass.

    the ``FastIntFlag`` class will return ``enum.IntFlag`` under TYPE_CHECKING
    and ``_FastIntFlag`` otherwise.

    """
1689
1690
if TYPE_CHECKING:
    from enum import IntFlag

    # under type checking, present the real stdlib IntFlag for accurate
    # member typing; at runtime the faster copycat is used instead
    FastIntFlag = IntFlag
else:
    FastIntFlag = _FastIntFlag


# TypeVar bound to Enum, used by parse_user_argument_for_enum
_E = TypeVar("_E", bound=enum.Enum)
1700
1701
def parse_user_argument_for_enum(
    arg: Any,
    choices: Dict[_E, List[Any]],
    name: str,
    resolve_symbol_names: bool = False,
) -> Optional[_E]:
    """Given a user parameter, parse the parameter into a chosen value
    from a list of choice objects, typically Enum values.

    The user argument can be a string name that matches the name of a
    symbol, or the symbol object itself, or any number of alternate choices
    such as True/False/ None etc.

    :param arg: the user argument.
    :param choices: dictionary of enum values to lists of possible
     entries for each.
    :param name: name of the argument.   Used in an :class:`.ArgumentError`
     that is raised if the parameter doesn't match any available argument.
    :param resolve_symbol_names: when True, a string equal to an enum
     member's name is also accepted as a match for that member.

    """
    for candidate, aliases in choices.items():
        if (
            arg is candidate
            or (resolve_symbol_names and arg == candidate.name)
            or arg in aliases
        ):
            return candidate

    if arg is None:
        return None

    raise exc.ArgumentError(f"Invalid value for '{name}': {arg!r}")
1734
1735
# module-level monotonically increasing sequence value
_creation_order = 1


def set_creation_order(instance: Any) -> None:
    """Assign a '_creation_order' sequence to the given instance.

    This allows multiple instances to be sorted in order of creation
    (typically within a single thread; the counter is not particularly
    threadsafe).

    """
    global _creation_order
    # assign the current value and advance the sequence in one step
    instance._creation_order, _creation_order = (
        _creation_order,
        _creation_order + 1,
    )
1750
1751
def warn_exception(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """executes the given function, catches all exceptions and converts to
    a warning.

    :return: the return value of ``func``, or ``None`` if it raised.

    """
    try:
        return func(*args, **kwargs)
    except Exception:
        # report the exception type and value as a warning, then continue
        warn("%s('%s') ignored" % sys.exc_info()[0:2])
1761
1762
def ellipses_string(value, len_=25):
    """Return ``value`` truncated to ``len_`` characters with a trailing
    ``...``; values without a length (or unsliceable ones) pass through
    unchanged."""
    try:
        if len(value) <= len_:
            return value
        return "%s..." % value[0:len_]
    except TypeError:
        return value
1771
1772
1773class _hash_limit_string(str):
1774 """A string subclass that can only be hashed on a maximum amount
1775 of unique values.
1776
1777 This is used for warnings so that we can send out parameterized warnings
1778 without the __warningregistry__ of the module, or the non-overridable
1779 "once" registry within warnings.py, overloading memory,
1780
1781
1782 """
1783
1784 _hash: int
1785
1786 def __new__(
1787 cls, value: str, num: int, args: Sequence[Any]
1788 ) -> _hash_limit_string:
1789 interpolated = (value % args) + (
1790 " (this warning may be suppressed after %d occurrences)" % num
1791 )
1792 self = super().__new__(cls, interpolated)
1793 self._hash = hash("%s_%d" % (value, hash(interpolated) % num))
1794 return self
1795
1796 def __hash__(self) -> int:
1797 return self._hash
1798
1799 def __eq__(self, other: Any) -> bool:
1800 return hash(self) == hash(other)
1801
1802
def warn(msg: str, code: Optional[str] = None) -> None:
    """Issue a warning.

    If msg is a string, :class:`.exc.SAWarning` is used as
    the category.

    """
    if not code:
        _warnings_warn(msg, exc.SAWarning)
    else:
        # wrap in an SAWarning instance carrying the error code
        _warnings_warn(exc.SAWarning(msg, code=code))
1814
1815
def warn_limited(msg: str, args: Sequence[Any]) -> None:
    """Issue a warning with a parameterized string, limiting the number
    of registrations.

    :param msg: ``%``-style format string.
    :param args: interpolation arguments; when non-empty the message is
     wrapped in :class:`._hash_limit_string` capped at 10 unique hashes.

    """
    if args:
        msg = _hash_limit_string(msg, 10, args)
    _warnings_warn(msg, exc.SAWarning)
1824
1825
1826_warning_tags: Dict[CodeType, Tuple[str, Type[Warning]]] = {}
1827
1828
1829def tag_method_for_warnings(
1830 message: str, category: Type[Warning]
1831) -> Callable[[_F], _F]:
1832 def go(fn):
1833 _warning_tags[fn.__code__] = (message, category)
1834 return fn
1835
1836 return go
1837
1838
# matches module names inside SQLAlchemy (excluding the testing package)
# and Alembic; frames from these modules don't count toward stacklevel
_not_sa_pattern = re.compile(r"^(?:sqlalchemy\.(?!testing)|alembic\.)")


def _warnings_warn(
    message: Union[str, Warning],
    category: Optional[Type[Warning]] = None,
    stacklevel: int = 2,
) -> None:
    """Emit a warning, adjusting ``stacklevel`` to point outside of
    SQLAlchemy/Alembic frames and applying any tags registered via
    :func:`.tag_method_for_warnings` found on the current stack.

    :param message: warning text or a ``Warning`` instance.
    :param category: warning category; when ``None`` and ``message`` is a
     ``Warning`` instance, its type is used.
    :param stacklevel: initial stack level, incremented past
     library-internal frames.
    """

    if category is None and isinstance(message, Warning):
        category = type(message)

    # adjust the given stacklevel to be outside of SQLAlchemy
    try:
        frame = sys._getframe(stacklevel)
    except ValueError:
        # being called from less than 3 (or given) stacklevels, weird,
        # but don't crash
        stacklevel = 0
    except:
        # deliberate bare except: _getframe() not working at all on this
        # interpreter; fall back to an unadjusted warning rather than crash
        stacklevel = 0
    else:
        stacklevel_found = warning_tag_found = False
        while frame is not None:
            # using __name__ here requires that we have __name__ in the
            # __globals__ of the decorated string functions we make also.
            # we generate this using {"__name__": fn.__module__}
            if not stacklevel_found and not re.match(
                _not_sa_pattern, frame.f_globals.get("__name__", "")
            ):
                # stop incrementing stack level if an out-of-SQLA line
                # were found.
                stacklevel_found = True

                # however, for the warning tag thing, we have to keep
                # scanning up the whole traceback

            if frame.f_code in _warning_tags:
                warning_tag_found = True
                (_suffix, _category) = _warning_tags[frame.f_code]
                category = category or _category
                message = f"{message} ({_suffix})"

            frame = frame.f_back  # type: ignore[assignment]

            if not stacklevel_found:
                stacklevel += 1
            elif stacklevel_found and warning_tag_found:
                break

    if category is not None:
        warnings.warn(message, category, stacklevel=stacklevel + 1)
    else:
        warnings.warn(message, stacklevel=stacklevel + 1)
1895
1896
def only_once(
    fn: Callable[..., _T], retry_on_exception: bool
) -> Callable[..., Optional[_T]]:
    """Decorate the given function to be a no-op after it is called exactly
    once.

    :param retry_on_exception: when True, a call that raises does not
     consume the single invocation; the next call will try again.

    """

    pending = [fn]

    def go(*arg: Any, **kw: Any) -> Optional[_T]:
        # strong reference fn so that it isn't garbage collected,
        # which interferes with the event system's expectations
        strong_fn = fn  # noqa
        if not pending:
            return None
        target = pending.pop()
        try:
            return target(*arg, **kw)
        except:
            # deliberate bare except so BaseException also restores state
            if retry_on_exception:
                pending.insert(0, target)
            raise

    return go
1921
1922
# default suffix filter: lines from sqlalchemy package source files
_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
# default prefix filter: lines from unittest-style runners
_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")


def chop_traceback(
    tb: List[str],
    exclude_prefix: re.Pattern[str] = _UNITTEST_RE,
    exclude_suffix: re.Pattern[str] = _SQLA_RE,
) -> List[str]:
    """Chop extraneous lines off beginning and end of a traceback.

    :param tb:
      a list of traceback lines as returned by ``traceback.format_stack()``

    :param exclude_prefix:
      a regular expression object matching lines to skip at beginning of
      ``tb``

    :param exclude_suffix:
      a regular expression object matching lines to skip at end of ``tb``
    """
    lo, hi = 0, len(tb) - 1
    while lo <= hi and exclude_prefix.search(tb[lo]):
        lo += 1
    while lo <= hi and exclude_suffix.search(tb[hi]):
        hi -= 1
    return tb[lo : hi + 1]
1951
1952
def attrsetter(attrname):
    """Return a callable ``set(obj, value)`` that assigns ``value`` to the
    (possibly dotted) attribute path ``attrname`` of ``obj``."""
    # built via exec so that dotted paths like "a.b" compile down to a
    # plain attribute assignment; attrname is internal, trusted input only
    namespace: dict = {}
    exec("def set(obj, value): obj.%s = value" % attrname, namespace)
    return namespace["set"]
1958
1959
# matches dunder names such as __module__, __qualname__, __doc__
_dunders = re.compile("^__.+__$")


class TypingOnly:
    """A mixin class that marks a class as 'typing only', meaning it has
    absolutely no methods, attributes, or runtime functionality whatsoever.

    """

    __slots__ = ()

    def __init_subclass__(cls) -> None:
        if TypingOnly in cls.__bases__:
            # any non-dunder entry on a direct subclass violates the contract
            remaining = {
                attr for attr in cls.__dict__ if not _dunders.match(attr)
            }
            if remaining:
                raise AssertionError(
                    f"Class {cls} directly inherits TypingOnly but has "
                    f"additional attributes {remaining}."
                )
        super().__init_subclass__()
1982
1983
class EnsureKWArg:
    r"""Apply translation of functions to accept \**kw arguments if they
    don't already.

    Used to ensure cross-compatibility with third party legacy code, for things
    like compiler visit methods that need to accept ``**kw`` arguments,
    but may have been copied from old code that didn't accept them.

    """

    ensure_kwarg: str
    """a regular expression that indicates method names for which the method
    should accept ``**kw`` arguments.

    The class will scan for methods matching the name template and decorate
    them if necessary to ensure ``**kw`` parameters are accepted.

    """

    def __init_subclass__(cls) -> None:
        pattern = cls.ensure_kwarg
        if pattern:
            # snapshot the class dict since we replace entries as we go
            for name, member in list(cls.__dict__.items()):
                if not re.match(pattern, name):
                    continue
                spec = compat.inspect_getfullargspec(member)
                if not spec.varkw:
                    setattr(cls, name, cls._wrap_w_kw(member))
        super().__init_subclass__()

    @classmethod
    def _wrap_w_kw(cls, fn: Callable[..., Any]) -> Callable[..., Any]:
        # wrapper accepts and silently discards **kw
        def wrap(*arg: Any, **kw: Any) -> Any:
            return fn(*arg)

        return update_wrapper(wrap, fn)
2023
2024
def wrap_callable(wrapper, fn):
    """Augment functools.update_wrapper() to work with objects with
    a ``__call__()`` method.

    :param fn:
      object with __call__ method

    """
    if hasattr(fn, "__name__"):
        return update_wrapper(wrapper, fn)

    # no __name__: copy metadata over manually, using the class name
    wrapper.__name__ = fn.__class__.__name__
    if hasattr(fn, "__module__"):
        wrapper.__module__ = fn.__module__

    call_doc = getattr(fn.__call__, "__doc__", None)
    if call_doc:
        wrapper.__doc__ = call_doc
    elif fn.__doc__:
        wrapper.__doc__ = fn.__doc__

    return wrapper
2047
2048
def quoted_token_parser(value):
    """Parse a dotted identifier with accommodation for quoted names.

    Includes support for SQL-style double quotes as a literal character;
    a doubled quote inside a quoted section is an escape producing a
    single literal quote in the token.

    E.g.::

        >>> quoted_token_parser("name")
        ["name"]
        >>> quoted_token_parser("schema.name")
        ["schema", "name"]
        >>> quoted_token_parser('"Schema"."Name"')
        ['Schema', 'Name']
        >>> quoted_token_parser('"Schema"."Name""Foo"')
        ['Schema', 'Name"Foo']

    """

    if '"' not in value:
        # fast path: no quoting, plain dotted split
        return value.split(".")

    # 0 = outside of quotes
    # 1 = inside of quotes
    state = 0
    result: List[List[str]] = [[]]
    idx = 0
    lv = len(value)
    while idx < lv:
        char = value[idx]
        if char == '"':
            if state == 1 and idx < lv - 1 and value[idx + 1] == '"':
                # escaped (doubled) quote within a quoted name: emit one
                # literal quote and consume both characters
                result[-1].append('"')
                idx += 1
            else:
                state ^= 1
        elif char == "." and state == 0:
            # unquoted dot separates tokens
            result.append([])
        else:
            result[-1].append(char)
        idx += 1

    return ["".join(token) for token in result]
2091
2092
def add_parameter_text(params: Any, text: str) -> Callable[[_F], _F]:
    """Decorator factory that injects ``text`` as documentation under each
    named parameter in the decorated function's docstring.

    :param params: a parameter name or list of parameter names.
    :param text: documentation text to inject for each parameter.

    """
    params = _collections.to_list(params)

    def decorate(fn):
        # replaces the legacy ``x is not None and x or ""`` idiom; both
        # forms yield "" for a missing or empty docstring
        doc = fn.__doc__ or ""
        if doc:
            doc = inject_param_text(doc, {param: text for param in params})
        fn.__doc__ = doc
        return fn

    return decorate
2104
2105
2106def _dedent_docstring(text: str) -> str:
2107 split_text = text.split("\n", 1)
2108 if len(split_text) == 1:
2109 return text
2110 else:
2111 firstline, remaining = split_text
2112 if not firstline.startswith(" "):
2113 return firstline + "\n" + textwrap.dedent(remaining)
2114 else:
2115 return textwrap.dedent(text)
2116
2117
def inject_docstring_text(
    given_doctext: Optional[str], injecttext: str, pos: int
) -> str:
    """Insert ``injecttext`` into ``given_doctext`` at the ``pos``-th
    blank-line boundary and return the combined docstring."""
    body = _dedent_docstring(given_doctext or "")
    doc_lines = body.split("\n")
    if len(doc_lines) == 1:
        doc_lines.append("")
    insert_lines = textwrap.dedent(injecttext).split("\n")
    if insert_lines[0]:
        # ensure a separating blank line precedes the injected text
        insert_lines.insert(0, "")

    # candidate insertion points: start of text plus every blank line
    anchors = [0] + [
        idx for idx, line in enumerate(doc_lines) if not line.strip()
    ]
    target = anchors[min(pos, len(anchors) - 1)]

    return "\n".join(doc_lines[:target] + insert_lines + doc_lines[target:])
2136
2137
# captures the leading whitespace and name of a ``:param <name>:`` line
_param_reg = re.compile(r"(\s+):param (.+?):")


def inject_param_text(doctext: str, inject_params: Dict[str, str]) -> str:
    """Append extra documentation text to selected ``:param:`` blocks.

    :param doctext: docstring text to rewrite.
    :param inject_params: mapping of parameter name to the text to append
     at the end of that parameter's block.
    """
    doclines = collections.deque(doctext.splitlines())
    lines = []

    # TODO: this is not working for params like ":param case_sensitive=True:"

    to_inject = None
    while doclines:
        line = doclines.popleft()

        m = _param_reg.match(line)

        if to_inject is None:
            if m:
                param = m.group(2).lstrip("*")
                if param in inject_params:
                    # default indent to that of :param: plus one
                    indent = " " * len(m.group(1)) + " "

                    # but if the next line has text, use that line's
                    # indentation
                    if doclines:
                        m2 = re.match(r"(\s+)\S", doclines[0])
                        if m2:
                            indent = " " * len(m2.group(1))

                    to_inject = indent + inject_params[param]
        elif m:
            # a new :param: begins; flush the pending injection before it
            lines.extend(["\n", to_inject, "\n"])
            to_inject = None
        elif not line.rstrip():
            # blank line ends the current param block; flush pending text
            lines.extend([line, to_inject, "\n"])
            to_inject = None
        elif line.endswith("::"):
            # TODO: this still won't cover if the code example itself has
            # blank lines in it, need to detect those via indentation.
            lines.extend([line, doclines.popleft()])
            continue
        lines.append(line)

    return "\n".join(lines)
2182
2183
def repr_tuple_names(names: List[str]) -> Optional[str]:
    """Trims a list of strings from the middle and return a string of up to
    four elements. Strings greater than 11 characters will be truncated"""
    if not names:
        return None
    within_limit = len(names) <= 4
    # show everything when short; otherwise first three plus the last
    shown = names[0:4] if within_limit else names[0:3] + names[-1:]
    trimmed = [
        "%s.." % name[:11] if len(name) > 11 else name for name in shown
    ]
    if within_limit:
        return ", ".join(trimmed)
    return "%s, ..., %s" % (", ".join(trimmed[0:3]), trimmed[-1])
2196
2197
def has_compiled_ext(raise_=False):
    """Report whether the cython extensions are installed.

    :param raise_: when True, raise ``ImportError`` instead of returning
     False when the extensions are missing.
    """
    from ._has_cython import HAS_CYEXTENSION

    if HAS_CYEXTENSION:
        return True
    if raise_:
        raise ImportError(
            "cython extensions were expected to be installed, "
            "but are not present"
        )
    return False
2210
2211
def load_uncompiled_module(module: _M) -> _M:
    """Load the non-compiled (plain Python) version of a module that is
    also compiled with cython.
    """
    qualified_name = module.__name__
    assert module.__spec__
    parent_name = module.__spec__.parent
    assert parent_name
    parent = sys.modules[parent_name]
    assert parent.__spec__
    package_init = parent.__spec__.origin
    assert package_init and package_init.endswith("__init__.py")

    # the plain .py source lives alongside the package __init__.py
    short_name = qualified_name.split(".")[-1]
    source_path = package_init.replace("__init__.py", f"{short_name}.py")

    spec = importlib.util.spec_from_file_location(qualified_name, source_path)
    assert spec
    fresh = importlib.util.module_from_spec(spec)
    assert spec.loader
    spec.loader.exec_module(fresh)
    return cast(_M, fresh)
2234
2235
class _Missing(enum.Enum):
    # single-member enum used as a typed "no value supplied" sentinel
    Missing = enum.auto()


# canonical sentinel instance; compare with ``is``
Missing = _Missing.Missing
# annotation alias: either a value of type _T or the Missing sentinel
MissingOr = Union[_T, Literal[_Missing.Missing]]