1# util/langhelpers.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Routines to help with the creation, loading and introspection of
10modules, classes, hierarchies, attributes, functions, and methods.
11
12"""
13from __future__ import annotations
14
15import collections
16import enum
17from functools import update_wrapper
18import importlib.util
19import inspect
20import itertools
21import operator
22import re
23import sys
24import textwrap
25import threading
26import types
27from types import CodeType
28from types import ModuleType
29from typing import Any
30from typing import Callable
31from typing import cast
32from typing import Dict
33from typing import FrozenSet
34from typing import Generic
35from typing import Iterator
36from typing import List
37from typing import Mapping
38from typing import NoReturn
39from typing import Optional
40from typing import overload
41from typing import Sequence
42from typing import Set
43from typing import Tuple
44from typing import Type
45from typing import TYPE_CHECKING
46from typing import TypeVar
47from typing import Union
48import warnings
49
50from . import _collections
51from . import compat
52from .typing import Literal
53from .. import exc
54
55_T = TypeVar("_T")
56_T_co = TypeVar("_T_co", covariant=True)
57_F = TypeVar("_F", bound=Callable[..., Any])
58_MA = TypeVar("_MA", bound="HasMemoized.memoized_attribute[Any]")
59_M = TypeVar("_M", bound=ModuleType)
60
61if compat.py310:
62
63 def get_annotations(obj: Any) -> Mapping[str, Any]:
64 return inspect.get_annotations(obj)
65
66else:
67
68 def get_annotations(obj: Any) -> Mapping[str, Any]:
        # it's been observed that cls.__annotations__ can be missing.
        # it's not clear what causes this; running under tox py38 it
        # happens, running straight pytest it doesn't
72
73 # https://docs.python.org/3/howto/annotations.html#annotations-howto
74 if isinstance(obj, type):
75 ann = obj.__dict__.get("__annotations__", None)
76 else:
77 ann = getattr(obj, "__annotations__", None)
78
79 if ann is None:
80 return _collections.EMPTY_DICT
81 else:
82 return cast("Mapping[str, Any]", ann)
83
84
85def md5_hex(x: Any) -> str:
86 x = x.encode("utf-8")
87 m = compat.md5_not_for_security()
88 m.update(x)
89 return cast(str, m.hexdigest())
90
91
92class safe_reraise:
93 """Reraise an exception after invoking some
94 handler code.
95
96 Stores the existing exception info before
97 invoking so that it is maintained across a potential
98 coroutine context switch.
99
100 e.g.::
101
102 try:
103 sess.commit()
104 except:
105 with safe_reraise():
106 sess.rollback()
107
108 TODO: we should at some point evaluate current behaviors in this regard
109 based on current greenlet, gevent/eventlet implementations in Python 3, and
110 also see the degree to which our own asyncio (based on greenlet also) is
111 impacted by this. .rollback() will cause IO / context switch to occur in
112 all these scenarios; what happens to the exception context from an
113 "except:" block if we don't explicitly store it? Original issue was #2703.
114
115 """
116
117 __slots__ = ("_exc_info",)
118
119 _exc_info: Union[
120 None,
121 Tuple[
122 Type[BaseException],
123 BaseException,
124 types.TracebackType,
125 ],
126 Tuple[None, None, None],
127 ]
128
129 def __enter__(self) -> None:
130 self._exc_info = sys.exc_info()
131
132 def __exit__(
133 self,
134 type_: Optional[Type[BaseException]],
135 value: Optional[BaseException],
136 traceback: Optional[types.TracebackType],
137 ) -> NoReturn:
138 assert self._exc_info is not None
139 # see #2703 for notes
140 if type_ is None:
141 exc_type, exc_value, exc_tb = self._exc_info
142 assert exc_value is not None
143 self._exc_info = None # remove potential circular references
144 raise exc_value.with_traceback(exc_tb)
145 else:
146 self._exc_info = None # remove potential circular references
147 assert value is not None
148 raise value.with_traceback(traceback)
149
150
151def walk_subclasses(cls: Type[_T]) -> Iterator[Type[_T]]:
152 seen: Set[Any] = set()
153
154 stack = [cls]
155 while stack:
156 cls = stack.pop()
157 if cls in seen:
158 continue
159 else:
160 seen.add(cls)
161 stack.extend(cls.__subclasses__())
162 yield cls
163
164
165def string_or_unprintable(element: Any) -> str:
166 if isinstance(element, str):
167 return element
168 else:
169 try:
170 return str(element)
171 except Exception:
172 return "unprintable element %r" % element
173
174
175def clsname_as_plain_name(
176 cls: Type[Any], use_name: Optional[str] = None
177) -> str:
178 name = use_name or cls.__name__
179 return " ".join(n.lower() for n in re.findall(r"([A-Z][a-z]+|SQL)", name))
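
# Rough illustration of the camel-case splitting above (comment-only sketch;
# the class names are made up):
#
#     clsname_as_plain_name(type("MyFancyThing", (), {}))  # -> "my fancy thing"
#     clsname_as_plain_name(type("SQLCompiler", (), {}))    # -> "sql compiler"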
180
181
182def method_is_overridden(
183 instance_or_cls: Union[Type[Any], object],
184 against_method: Callable[..., Any],
185) -> bool:
    """Return True if the given instance or class overrides
    ``against_method``, i.e. the two methods don't match."""
187
188 if not isinstance(instance_or_cls, type):
189 current_cls = instance_or_cls.__class__
190 else:
191 current_cls = instance_or_cls
192
193 method_name = against_method.__name__
194
195 current_method: types.MethodType = getattr(current_cls, method_name)
196
197 return current_method != against_method
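
# Comment-only sketch of the intended usage (hypothetical classes):
#
#     class Base:
#         def configure(self): ...
#
#     class Custom(Base):
#         def configure(self): ...
#
#     method_is_overridden(Custom(), Base.configure)  # True
#     method_is_overridden(Base(), Base.configure)    # False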
198
199
200def decode_slice(slc: slice) -> Tuple[Any, ...]:
201 """decode a slice object as sent to __getitem__.
202
    takes into account the __index__() method introduced in Python 2.5.
204
205 """
206 ret: List[Any] = []
207 for x in slc.start, slc.stop, slc.step:
208 if hasattr(x, "__index__"):
209 x = x.__index__()
210 ret.append(x)
211 return tuple(ret)
212
213
214def _unique_symbols(used: Sequence[str], *bases: str) -> Iterator[str]:
215 used_set = set(used)
216 for base in bases:
217 pool = itertools.chain(
218 (base,),
219 map(lambda i: base + str(i), range(1000)),
220 )
221 for sym in pool:
222 if sym not in used_set:
223 used_set.add(sym)
224 yield sym
225 break
226 else:
227 raise NameError("exhausted namespace for symbol base %s" % base)
228
229
230def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]:
    """Call the given function with each nonzero bit from n."""
232
233 while n:
234 b = n & (~n + 1)
235 yield fn(b)
236 n ^= b
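
# e.g. decomposing a bitmask into its individual flags (comment-only sketch):
#
#     list(map_bits(lambda bit: bit, 12))  # -> [4, 8], since 12 == 0b1100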
237
238
239_Fn = TypeVar("_Fn", bound="Callable[..., Any]")
240
241# this seems to be in flux in recent mypy versions
242
243
244def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
245 """A signature-matching decorator factory."""
246
247 def decorate(fn: _Fn) -> _Fn:
248 if not inspect.isfunction(fn) and not inspect.ismethod(fn):
249 raise Exception("not a decoratable function")
250
251 spec = compat.inspect_getfullargspec(fn)
252 env: Dict[str, Any] = {}
253
254 spec = _update_argspec_defaults_into_env(spec, env)
255
256 names = (
257 tuple(cast("Tuple[str, ...]", spec[0]))
258 + cast("Tuple[str, ...]", spec[1:3])
259 + (fn.__name__,)
260 )
261 targ_name, fn_name = _unique_symbols(names, "target", "fn")
262
263 metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
264 metadata.update(format_argspec_plus(spec, grouped=False))
265 metadata["name"] = fn.__name__
266
267 if inspect.iscoroutinefunction(fn):
268 metadata["prefix"] = "async "
269 metadata["target_prefix"] = "await "
270 else:
271 metadata["prefix"] = ""
272 metadata["target_prefix"] = ""
273
        # look for "__"-prefixed arguments, a SQLAlchemy convention
        # indicating that the argument should be passed positionally
        # rather than as a keyword argument.  note that apply_pos doesn't
        # currently work in all cases, such as when a kw-only indicator "*"
        # is present, which is why we limit its use to just the cases we can
        # detect.  As we add more kinds of methods that use @decorator,
        # things may have to be further improved in this area
282 if "__" in repr(spec[0]):
283 code = (
284 """\
285%(prefix)sdef %(name)s%(grouped_args)s:
286 return %(target_prefix)s%(target)s(%(fn)s, %(apply_pos)s)
287"""
288 % metadata
289 )
290 else:
291 code = (
292 """\
293%(prefix)sdef %(name)s%(grouped_args)s:
294 return %(target_prefix)s%(target)s(%(fn)s, %(apply_kw)s)
295"""
296 % metadata
297 )
298
299 mod = sys.modules[fn.__module__]
300 env.update(vars(mod))
301 env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__})
302
303 decorated = cast(
304 types.FunctionType,
305 _exec_code_in_env(code, env, fn.__name__),
306 )
307 decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
308
309 decorated.__wrapped__ = fn # type: ignore[attr-defined]
310 return update_wrapper(decorated, fn) # type: ignore[return-value]
311
312 return update_wrapper(decorate, target) # type: ignore[return-value]
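
# Comment-only sketch of how @decorator is typically used; "traced" and "add"
# are made-up names.  The generated wrapper preserves add()'s exact signature
# (unlike a generic *args/**kw wrapper) and forwards to the target:
#
#     @decorator
#     def traced(fn, *args, **kw):
#         print("calling %s" % fn.__name__)
#         return fn(*args, **kw)
#
#     @traced
#     def add(x, y=1):
#         return x + y
#
#     add(2)  # prints "calling add", returns 3; add still reports (x, y=1)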
313
314
315def _update_argspec_defaults_into_env(spec, env):
316 """given a FullArgSpec, convert defaults to be symbol names in an env."""
317
318 if spec.defaults:
319 new_defaults = []
320 i = 0
321 for arg in spec.defaults:
322 if type(arg).__module__ not in ("builtins", "__builtin__"):
323 name = "x%d" % i
324 env[name] = arg
325 new_defaults.append(name)
326 i += 1
327 else:
328 new_defaults.append(arg)
329 elem = list(spec)
330 elem[3] = tuple(new_defaults)
331 return compat.FullArgSpec(*elem)
332 else:
333 return spec
334
335
336def _exec_code_in_env(
337 code: Union[str, types.CodeType], env: Dict[str, Any], fn_name: str
338) -> Callable[..., Any]:
339 exec(code, env)
340 return env[fn_name] # type: ignore[no-any-return]
341
342
343_PF = TypeVar("_PF")
344_TE = TypeVar("_TE")
345
346
347class PluginLoader:
348 def __init__(
349 self, group: str, auto_fn: Optional[Callable[..., Any]] = None
350 ):
351 self.group = group
352 self.impls: Dict[str, Any] = {}
353 self.auto_fn = auto_fn
354
355 def clear(self):
356 self.impls.clear()
357
358 def load(self, name: str) -> Any:
359 if name in self.impls:
360 return self.impls[name]()
361
362 if self.auto_fn:
363 loader = self.auto_fn(name)
364 if loader:
365 self.impls[name] = loader
366 return loader()
367
368 for impl in compat.importlib_metadata_get(self.group):
369 if impl.name == name:
370 self.impls[name] = impl.load
371 return impl.load()
372
373 raise exc.NoSuchModuleError(
374 "Can't load plugin: %s:%s" % (self.group, name)
375 )
376
377 def register(self, name: str, modulepath: str, objname: str) -> None:
378 def load():
379 mod = __import__(modulepath)
380 for token in modulepath.split(".")[1:]:
381 mod = getattr(mod, token)
382 return getattr(mod, objname)
383
384 self.impls[name] = load
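
# Comment-only sketch; the group, plugin and module names below are invented:
#
#     loader = PluginLoader("mylib.plugins")
#     loader.register("json", "mylib.plugins.json_impl", "JSONPlugin")
#     plugin_cls = loader.load("json")  # imports the module, returns JSONPlugin
#
# For names not already registered, load() first gives auto_fn (if provided) a
# chance to resolve the name, then scans entry points under the group.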
385
386
387def _inspect_func_args(fn):
388 try:
389 co_varkeywords = inspect.CO_VARKEYWORDS
390 except AttributeError:
391 # https://docs.python.org/3/library/inspect.html
392 # The flags are specific to CPython, and may not be defined in other
393 # Python implementations. Furthermore, the flags are an implementation
394 # detail, and can be removed or deprecated in future Python releases.
395 spec = compat.inspect_getfullargspec(fn)
396 return spec[0], bool(spec[2])
397 else:
398 # use fn.__code__ plus flags to reduce method call overhead
399 co = fn.__code__
400 nargs = co.co_argcount
401 return (
402 list(co.co_varnames[:nargs]),
403 bool(co.co_flags & co_varkeywords),
404 )
405
406
407@overload
408def get_cls_kwargs(
409 cls: type,
410 *,
411 _set: Optional[Set[str]] = None,
412 raiseerr: Literal[True] = ...,
413) -> Set[str]: ...
414
415
416@overload
417def get_cls_kwargs(
418 cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
419) -> Optional[Set[str]]: ...
420
421
422def get_cls_kwargs(
423 cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
424) -> Optional[Set[str]]:
425 r"""Return the full set of inherited kwargs for the given `cls`.
426
427 Probes a class's __init__ method, collecting all named arguments. If the
428 __init__ defines a \**kwargs catch-all, then the constructor is presumed
429 to pass along unrecognized keywords to its base classes, and the
430 collection process is repeated recursively on each of the bases.
431
432 Uses a subset of inspect.getfullargspec() to cut down on method overhead,
433 as this is used within the Core typing system to create copies of type
    objects, which is a performance-sensitive operation.
435
    No anonymous tuple arguments please!
437
438 """
439 toplevel = _set is None
440 if toplevel:
441 _set = set()
442 assert _set is not None
443
444 ctr = cls.__dict__.get("__init__", False)
445
446 has_init = (
447 ctr
448 and isinstance(ctr, types.FunctionType)
449 and isinstance(ctr.__code__, types.CodeType)
450 )
451
452 if has_init:
453 names, has_kw = _inspect_func_args(ctr)
454 _set.update(names)
455
456 if not has_kw and not toplevel:
457 if raiseerr:
458 raise TypeError(
459 f"given cls {cls} doesn't have an __init__ method"
460 )
461 else:
462 return None
463 else:
464 has_kw = False
465
466 if not has_init or has_kw:
467 for c in cls.__bases__:
468 if get_cls_kwargs(c, _set=_set) is None:
469 break
470
471 _set.discard("self")
472 return _set
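
# Comment-only sketch (hypothetical classes): because Child.__init__ accepts
# **kwargs, the search continues into Base and collects its arguments too:
#
#     class Base:
#         def __init__(self, a=None, **kwargs): ...
#
#     class Child(Base):
#         def __init__(self, b=None, **kwargs): ...
#
#     get_cls_kwargs(Child)  # -> {'a', 'b'}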
473
474
475def get_func_kwargs(func: Callable[..., Any]) -> List[str]:
476 """Return the set of legal kwargs for the given `func`.
477
478 Uses getargspec so is safe to call for methods, functions,
479 etc.
480
481 """
482
483 return compat.inspect_getfullargspec(func)[0]
484
485
486def get_callable_argspec(
487 fn: Callable[..., Any], no_self: bool = False, _is_init: bool = False
488) -> compat.FullArgSpec:
489 """Return the argument signature for any callable.
490
491 All pure-Python callables are accepted, including
492 functions, methods, classes, objects with __call__;
493 builtins and other edge cases like functools.partial() objects
494 raise a TypeError.
495
496 """
497 if inspect.isbuiltin(fn):
498 raise TypeError("Can't inspect builtin: %s" % fn)
499 elif inspect.isfunction(fn):
500 if _is_init and no_self:
501 spec = compat.inspect_getfullargspec(fn)
502 return compat.FullArgSpec(
503 spec.args[1:],
504 spec.varargs,
505 spec.varkw,
506 spec.defaults,
507 spec.kwonlyargs,
508 spec.kwonlydefaults,
509 spec.annotations,
510 )
511 else:
512 return compat.inspect_getfullargspec(fn)
513 elif inspect.ismethod(fn):
514 if no_self and (_is_init or fn.__self__):
515 spec = compat.inspect_getfullargspec(fn.__func__)
516 return compat.FullArgSpec(
517 spec.args[1:],
518 spec.varargs,
519 spec.varkw,
520 spec.defaults,
521 spec.kwonlyargs,
522 spec.kwonlydefaults,
523 spec.annotations,
524 )
525 else:
526 return compat.inspect_getfullargspec(fn.__func__)
527 elif inspect.isclass(fn):
528 return get_callable_argspec(
529 fn.__init__, no_self=no_self, _is_init=True
530 )
531 elif hasattr(fn, "__func__"):
532 return compat.inspect_getfullargspec(fn.__func__)
533 elif hasattr(fn, "__call__"):
534 if inspect.ismethod(fn.__call__):
535 return get_callable_argspec(fn.__call__, no_self=no_self)
536 else:
537 raise TypeError("Can't inspect callable: %s" % fn)
538 else:
539 raise TypeError("Can't inspect callable: %s" % fn)
540
541
542def format_argspec_plus(
543 fn: Union[Callable[..., Any], compat.FullArgSpec], grouped: bool = True
544) -> Dict[str, Optional[str]]:
545 """Returns a dictionary of formatted, introspected function arguments.
546
    An enhanced variant of inspect.formatargspec to support code generation.
548
549 fn
550 An inspectable callable or tuple of inspect getargspec() results.
551 grouped
552 Defaults to True; include (parens, around, argument) lists
553
554 Returns:
555
556 args
557 Full inspect.formatargspec for fn
558 self_arg
559 The name of the first positional argument, varargs[0], or None
560 if the function defines no positional arguments.
561 apply_pos
562 args, re-written in calling rather than receiving syntax. Arguments are
563 passed positionally.
564 apply_kw
565 Like apply_pos, except keyword-ish args are passed as keywords.
566 apply_pos_proxied
567 Like apply_pos but omits the self/cls argument
568
569 Example::
570
571 >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
572 {'grouped_args': '(self, a, b, c=3, **d)',
573 'self_arg': 'self',
574 'apply_kw': '(self, a, b, c=c, **d)',
575 'apply_pos': '(self, a, b, c, **d)'}
576
577 """
578 if callable(fn):
579 spec = compat.inspect_getfullargspec(fn)
580 else:
581 spec = fn
582
583 args = compat.inspect_formatargspec(*spec)
584
585 apply_pos = compat.inspect_formatargspec(
586 spec[0], spec[1], spec[2], None, spec[4]
587 )
588
589 if spec[0]:
590 self_arg = spec[0][0]
591
592 apply_pos_proxied = compat.inspect_formatargspec(
593 spec[0][1:], spec[1], spec[2], None, spec[4]
594 )
595
596 elif spec[1]:
597 # I'm not sure what this is
598 self_arg = "%s[0]" % spec[1]
599
600 apply_pos_proxied = apply_pos
601 else:
602 self_arg = None
603 apply_pos_proxied = apply_pos
604
605 num_defaults = 0
606 if spec[3]:
607 num_defaults += len(cast(Tuple[Any], spec[3]))
608 if spec[4]:
609 num_defaults += len(spec[4])
610
611 name_args = spec[0] + spec[4]
612
613 defaulted_vals: Union[List[str], Tuple[()]]
614
615 if num_defaults:
616 defaulted_vals = name_args[0 - num_defaults :]
617 else:
618 defaulted_vals = ()
619
620 apply_kw = compat.inspect_formatargspec(
621 name_args,
622 spec[1],
623 spec[2],
624 defaulted_vals,
625 formatvalue=lambda x: "=" + str(x),
626 )
627
628 if spec[0]:
629 apply_kw_proxied = compat.inspect_formatargspec(
630 name_args[1:],
631 spec[1],
632 spec[2],
633 defaulted_vals,
634 formatvalue=lambda x: "=" + str(x),
635 )
636 else:
637 apply_kw_proxied = apply_kw
638
639 if grouped:
640 return dict(
641 grouped_args=args,
642 self_arg=self_arg,
643 apply_pos=apply_pos,
644 apply_kw=apply_kw,
645 apply_pos_proxied=apply_pos_proxied,
646 apply_kw_proxied=apply_kw_proxied,
647 )
648 else:
649 return dict(
650 grouped_args=args,
651 self_arg=self_arg,
652 apply_pos=apply_pos[1:-1],
653 apply_kw=apply_kw[1:-1],
654 apply_pos_proxied=apply_pos_proxied[1:-1],
655 apply_kw_proxied=apply_kw_proxied[1:-1],
656 )
657
658
659def format_argspec_init(method, grouped=True):
660 """format_argspec_plus with considerations for typical __init__ methods
661
662 Wraps format_argspec_plus with error handling strategies for typical
663 __init__ cases::
664
665 object.__init__ -> (self)
666 other unreflectable (usually C) -> (self, *args, **kwargs)
667
668 """
669 if method is object.__init__:
670 grouped_args = "(self)"
671 args = "(self)" if grouped else "self"
672 proxied = "()" if grouped else ""
673 else:
674 try:
675 return format_argspec_plus(method, grouped=grouped)
676 except TypeError:
677 grouped_args = "(self, *args, **kwargs)"
678 args = grouped_args if grouped else "self, *args, **kwargs"
679 proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs"
680 return dict(
681 self_arg="self",
682 grouped_args=grouped_args,
683 apply_pos=args,
684 apply_kw=args,
685 apply_pos_proxied=proxied,
686 apply_kw_proxied=proxied,
687 )
688
689
690def create_proxy_methods(
691 target_cls: Type[Any],
692 target_cls_sphinx_name: str,
693 proxy_cls_sphinx_name: str,
694 classmethods: Sequence[str] = (),
695 methods: Sequence[str] = (),
696 attributes: Sequence[str] = (),
697 use_intermediate_variable: Sequence[str] = (),
698) -> Callable[[_T], _T]:
699 """A class decorator indicating attributes should refer to a proxy
700 class.
701
702 This decorator is now a "marker" that does nothing at runtime. Instead,
703 it is consumed by the tools/generate_proxy_methods.py script to
704 statically generate proxy methods and attributes that are fully
705 recognized by typing tools such as mypy.
706
707 """
708
709 def decorate(cls):
710 return cls
711
712 return decorate
713
714
715def getargspec_init(method):
716 """inspect.getargspec with considerations for typical __init__ methods
717
718 Wraps inspect.getargspec with error handling for typical __init__ cases::
719
720 object.__init__ -> (self)
721 other unreflectable (usually C) -> (self, *args, **kwargs)
722
723 """
724 try:
725 return compat.inspect_getfullargspec(method)
726 except TypeError:
727 if method is object.__init__:
728 return (["self"], None, None, None)
729 else:
730 return (["self"], "args", "kwargs", None)
731
732
733def unbound_method_to_callable(func_or_cls):
734 """Adjust the incoming callable such that a 'self' argument is not
735 required.
736
737 """
738
739 if isinstance(func_or_cls, types.MethodType) and not func_or_cls.__self__:
740 return func_or_cls.__func__
741 else:
742 return func_or_cls
743
744
745def generic_repr(
746 obj: Any,
747 additional_kw: Sequence[Tuple[str, Any]] = (),
748 to_inspect: Optional[Union[object, List[object]]] = None,
749 omit_kwarg: Sequence[str] = (),
750) -> str:
751 """Produce a __repr__() based on direct association of the __init__()
752 specification vs. same-named attributes present.
753
754 """
755 if to_inspect is None:
756 to_inspect = [obj]
757 else:
758 to_inspect = _collections.to_list(to_inspect)
759
760 missing = object()
761
762 pos_args = []
763 kw_args: _collections.OrderedDict[str, Any] = _collections.OrderedDict()
764 vargs = None
765 for i, insp in enumerate(to_inspect):
766 try:
767 spec = compat.inspect_getfullargspec(insp.__init__)
768 except TypeError:
769 continue
770 else:
771 default_len = len(spec.defaults) if spec.defaults else 0
772 if i == 0:
773 if spec.varargs:
774 vargs = spec.varargs
775 if default_len:
776 pos_args.extend(spec.args[1:-default_len])
777 else:
778 pos_args.extend(spec.args[1:])
779 else:
780 kw_args.update(
781 [(arg, missing) for arg in spec.args[1:-default_len]]
782 )
783
784 if default_len:
785 assert spec.defaults
786 kw_args.update(
787 [
788 (arg, default)
789 for arg, default in zip(
790 spec.args[-default_len:], spec.defaults
791 )
792 ]
793 )
794 output: List[str] = []
795
796 output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)
797
798 if vargs is not None and hasattr(obj, vargs):
799 output.extend([repr(val) for val in getattr(obj, vargs)])
800
801 for arg, defval in kw_args.items():
802 if arg in omit_kwarg:
803 continue
804 try:
805 val = getattr(obj, arg, missing)
806 if val is not missing and val != defval:
807 output.append("%s=%r" % (arg, val))
808 except Exception:
809 pass
810
811 if additional_kw:
812 for arg, defval in additional_kw:
813 try:
814 val = getattr(obj, arg, missing)
815 if val is not missing and val != defval:
816 output.append("%s=%r" % (arg, val))
817 except Exception:
818 pass
819
820 return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
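
# Comment-only sketch (hypothetical class): arguments with defaults are only
# rendered when the current attribute value differs from the default:
#
#     class Point:
#         def __init__(self, x, y=0):
#             self.x = x
#             self.y = y
#
#     generic_repr(Point(1, 2))  # -> "Point(1, y=2)"
#     generic_repr(Point(1))     # -> "Point(1)"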
821
822
823class portable_instancemethod:
824 """Turn an instancemethod into a (parent, name) pair
825 to produce a serializable callable.
826
827 """
828
829 __slots__ = "target", "name", "kwargs", "__weakref__"
830
831 def __getstate__(self):
832 return {
833 "target": self.target,
834 "name": self.name,
835 "kwargs": self.kwargs,
836 }
837
838 def __setstate__(self, state):
839 self.target = state["target"]
840 self.name = state["name"]
841 self.kwargs = state.get("kwargs", ())
842
843 def __init__(self, meth, kwargs=()):
844 self.target = meth.__self__
845 self.name = meth.__name__
846 self.kwargs = kwargs
847
848 def __call__(self, *arg, **kw):
849 kw.update(self.kwargs)
850 return getattr(self.target, self.name)(*arg, **kw)
851
852
853def class_hierarchy(cls):
854 """Return an unordered sequence of all classes related to cls.
855
856 Traverses diamond hierarchies.
857
858 Fibs slightly: subclasses of builtin types are not returned. Thus
859 class_hierarchy(class A(object)) returns (A, object), not A plus every
860 class systemwide that derives from object.
861
862 """
863
864 hier = {cls}
865 process = list(cls.__mro__)
866 while process:
867 c = process.pop()
868 bases = (_ for _ in c.__bases__ if _ not in hier)
869
870 for b in bases:
871 process.append(b)
872 hier.add(b)
873
874 if c.__module__ == "builtins" or not hasattr(c, "__subclasses__"):
875 continue
876
877 for s in [
878 _
879 for _ in (
880 c.__subclasses__()
881 if not issubclass(c, type)
882 else c.__subclasses__(c)
883 )
884 if _ not in hier
885 ]:
886 process.append(s)
887 hier.add(s)
888 return list(hier)
889
890
891def iterate_attributes(cls):
892 """iterate all the keys and attributes associated
893 with a class, without using getattr().
894
895 Does not use getattr() so that class-sensitive
896 descriptors (i.e. property.__get__()) are not called.
897
898 """
899 keys = dir(cls)
900 for key in keys:
901 for c in cls.__mro__:
902 if key in c.__dict__:
903 yield (key, c.__dict__[key])
904 break
905
906
907def monkeypatch_proxied_specials(
908 into_cls,
909 from_cls,
910 skip=None,
911 only=None,
912 name="self.proxy",
913 from_instance=None,
914):
915 """Automates delegation of __specials__ for a proxying type."""
916
917 if only:
918 dunders = only
919 else:
920 if skip is None:
921 skip = (
922 "__slots__",
923 "__del__",
924 "__getattribute__",
925 "__metaclass__",
926 "__getstate__",
927 "__setstate__",
928 )
929 dunders = [
930 m
931 for m in dir(from_cls)
932 if (
933 m.startswith("__")
934 and m.endswith("__")
935 and not hasattr(into_cls, m)
936 and m not in skip
937 )
938 ]
939
940 for method in dunders:
941 try:
942 maybe_fn = getattr(from_cls, method)
943 if not hasattr(maybe_fn, "__call__"):
944 continue
945 maybe_fn = getattr(maybe_fn, "__func__", maybe_fn)
946 fn = cast(types.FunctionType, maybe_fn)
947
948 except AttributeError:
949 continue
950 try:
951 spec = compat.inspect_getfullargspec(fn)
952 fn_args = compat.inspect_formatargspec(spec[0])
953 d_args = compat.inspect_formatargspec(spec[0][1:])
954 except TypeError:
955 fn_args = "(self, *args, **kw)"
956 d_args = "(*args, **kw)"
957
958 py = (
959 "def %(method)s%(fn_args)s: "
960 "return %(name)s.%(method)s%(d_args)s" % locals()
961 )
962
963 env: Dict[str, types.FunctionType] = (
964 from_instance is not None and {name: from_instance} or {}
965 )
966 exec(py, env)
967 try:
968 env[method].__defaults__ = fn.__defaults__
969 except AttributeError:
970 pass
971 setattr(into_cls, method, env[method])
972
973
974def methods_equivalent(meth1, meth2):
975 """Return True if the two methods are the same implementation."""
976
977 return getattr(meth1, "__func__", meth1) is getattr(
978 meth2, "__func__", meth2
979 )
980
981
982def as_interface(obj, cls=None, methods=None, required=None):
983 """Ensure basic interface compliance for an instance or dict of callables.
984
985 Checks that ``obj`` implements public methods of ``cls`` or has members
986 listed in ``methods``. If ``required`` is not supplied, implementing at
987 least one interface method is sufficient. Methods present on ``obj`` that
988 are not in the interface are ignored.
989
    If ``obj`` is a dict and does not itself meet the interface
991 requirements, the keys of the dictionary are inspected. Keys present in
992 ``obj`` that are not in the interface will raise TypeErrors.
993
994 Raises TypeError if ``obj`` does not meet the interface criteria.
995
996 In all passing cases, an object with callable members is returned. In the
997 simple case, ``obj`` is returned as-is; if dict processing kicks in then
998 an anonymous class is returned.
999
1000 obj
1001 A type, instance, or dictionary of callables.
1002 cls
1003 Optional, a type. All public methods of cls are considered the
1004 interface. An ``obj`` instance of cls will always pass, ignoring
      ``required``.
1006 methods
1007 Optional, a sequence of method names to consider as the interface.
1008 required
1009 Optional, a sequence of mandatory implementations. If omitted, an
1010 ``obj`` that provides at least one interface method is considered
1011 sufficient. As a convenience, required may be a type, in which case
1012 all public methods of the type are required.
1013
1014 """
1015 if not cls and not methods:
1016 raise TypeError("a class or collection of method names are required")
1017
1018 if isinstance(cls, type) and isinstance(obj, cls):
1019 return obj
1020
1021 interface = set(methods or [m for m in dir(cls) if not m.startswith("_")])
1022 implemented = set(dir(obj))
1023
1024 complies = operator.ge
1025 if isinstance(required, type):
1026 required = interface
1027 elif not required:
1028 required = set()
1029 complies = operator.gt
1030 else:
1031 required = set(required)
1032
1033 if complies(implemented.intersection(interface), required):
1034 return obj
1035
1036 # No dict duck typing here.
1037 if not isinstance(obj, dict):
1038 qualifier = complies is operator.gt and "any of" or "all of"
1039 raise TypeError(
1040 "%r does not implement %s: %s"
1041 % (obj, qualifier, ", ".join(interface))
1042 )
1043
1044 class AnonymousInterface:
1045 """A callable-holding shell."""
1046
1047 if cls:
1048 AnonymousInterface.__name__ = "Anonymous" + cls.__name__
1049 found = set()
1050
1051 for method, impl in dictlike_iteritems(obj):
1052 if method not in interface:
1053 raise TypeError("%r: unknown in this interface" % method)
1054 if not callable(impl):
1055 raise TypeError("%r=%r is not callable" % (method, impl))
1056 setattr(AnonymousInterface, method, staticmethod(impl))
1057 found.add(method)
1058
1059 if complies(found, required):
1060 return AnonymousInterface
1061
1062 raise TypeError(
1063 "dictionary does not contain required keys %s"
1064 % ", ".join(required - found)
1065 )
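
# Comment-only sketch (hypothetical interface): a dict of callables is turned
# into an anonymous class whose members are staticmethods:
#
#     class Visitor:
#         def visit(self, node): ...
#         def leave(self, node): ...
#
#     impl = as_interface({"visit": lambda node: None}, cls=Visitor)
#     impl.visit(some_node)  # ok; impl is a class named "AnonymousVisitor"
#     as_interface({"paint": lambda node: None}, cls=Visitor)  # TypeError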
1066
1067
1068_GFD = TypeVar("_GFD", bound="generic_fn_descriptor[Any]")
1069
1070
1071class generic_fn_descriptor(Generic[_T_co]):
    """Descriptor which proxies a function when the attribute is not
    present in the object's dict.
1074
1075 This superclass is organized in a particular way with "memoized" and
1076 "non-memoized" implementation classes that are hidden from type checkers,
1077 as Mypy seems to not be able to handle seeing multiple kinds of descriptor
1078 classes used for the same attribute.
1079
1080 """
1081
1082 fget: Callable[..., _T_co]
1083 __doc__: Optional[str]
1084 __name__: str
1085
1086 def __init__(self, fget: Callable[..., _T_co], doc: Optional[str] = None):
1087 self.fget = fget
1088 self.__doc__ = doc or fget.__doc__
1089 self.__name__ = fget.__name__
1090
1091 @overload
1092 def __get__(self: _GFD, obj: None, cls: Any) -> _GFD: ...
1093
1094 @overload
1095 def __get__(self, obj: object, cls: Any) -> _T_co: ...
1096
1097 def __get__(self: _GFD, obj: Any, cls: Any) -> Union[_GFD, _T_co]:
1098 raise NotImplementedError()
1099
1100 if TYPE_CHECKING:
1101
1102 def __set__(self, instance: Any, value: Any) -> None: ...
1103
1104 def __delete__(self, instance: Any) -> None: ...
1105
1106 def _reset(self, obj: Any) -> None:
1107 raise NotImplementedError()
1108
1109 @classmethod
1110 def reset(cls, obj: Any, name: str) -> None:
1111 raise NotImplementedError()
1112
1113
1114class _non_memoized_property(generic_fn_descriptor[_T_co]):
1115 """a plain descriptor that proxies a function.
1116
1117 primary rationale is to provide a plain attribute that's
1118 compatible with memoized_property which is also recognized as equivalent
1119 by mypy.
1120
1121 """
1122
1123 if not TYPE_CHECKING:
1124
1125 def __get__(self, obj, cls):
1126 if obj is None:
1127 return self
1128 return self.fget(obj)
1129
1130
1131class _memoized_property(generic_fn_descriptor[_T_co]):
1132 """A read-only @property that is only evaluated once."""
1133
1134 if not TYPE_CHECKING:
1135
1136 def __get__(self, obj, cls):
1137 if obj is None:
1138 return self
1139 obj.__dict__[self.__name__] = result = self.fget(obj)
1140 return result
1141
1142 def _reset(self, obj):
1143 _memoized_property.reset(obj, self.__name__)
1144
1145 @classmethod
1146 def reset(cls, obj, name):
1147 obj.__dict__.pop(name, None)
1148
1149
1150# despite many attempts to get Mypy to recognize an overridden descriptor
1151# where one is memoized and the other isn't, there seems to be no reliable
1152# way other than completely deceiving the type checker into thinking there
1153# is just one single descriptor type everywhere. Otherwise, if a superclass
1154# has non-memoized and subclass has memoized, that requires
# "class memoized(non_memoized)". but then if a superclass has memoized and
# subclass has non-memoized, the class hierarchy of the descriptors
1157# would need to be reversed; "class non_memoized(memoized)". so there's no
1158# way to achieve this.
1159# additional issues, RO properties:
1160# https://github.com/python/mypy/issues/12440
1161if TYPE_CHECKING:
1162 # allow memoized and non-memoized to be freely mixed by having them
1163 # be the same class
1164 memoized_property = generic_fn_descriptor
1165 non_memoized_property = generic_fn_descriptor
1166
1167 # for read only situations, mypy only sees @property as read only.
1168 # read only is needed when a subtype specializes the return type
1169 # of a property, meaning assignment needs to be disallowed
1170 ro_memoized_property = property
1171 ro_non_memoized_property = property
1172
1173else:
1174 memoized_property = ro_memoized_property = _memoized_property
1175 non_memoized_property = ro_non_memoized_property = _non_memoized_property
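
# Comment-only sketch of memoized_property (hypothetical class): the first
# access runs fget and stores the result in obj.__dict__, which then shadows
# the non-data descriptor on subsequent reads:
#
#     class Widget:
#         @memoized_property
#         def total(self):
#             print("computing")
#             return 42
#
#     w = Widget()
#     w.total   # prints "computing", returns 42
#     w.total   # returns 42 straight from w.__dict__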
1176
1177
1178def memoized_instancemethod(fn: _F) -> _F:
    """Decorate a method to memoize its return value.
1180
1181 Best applied to no-arg methods: memoization is not sensitive to
1182 argument values, and will always return the same value even when
1183 called with different arguments.
1184
1185 """
1186
1187 def oneshot(self, *args, **kw):
1188 result = fn(self, *args, **kw)
1189
1190 def memo(*a, **kw):
1191 return result
1192
1193 memo.__name__ = fn.__name__
1194 memo.__doc__ = fn.__doc__
1195 self.__dict__[fn.__name__] = memo
1196 return result
1197
1198 return update_wrapper(oneshot, fn) # type: ignore
1199
1200
1201class HasMemoized:
1202 """A mixin class that maintains the names of memoized elements in a
1203 collection for easy cache clearing, generative, etc.
1204
1205 """
1206
1207 if not TYPE_CHECKING:
1208 # support classes that want to have __slots__ with an explicit
1209 # slot for __dict__. not sure if that requires base __slots__ here.
1210 __slots__ = ()
1211
1212 _memoized_keys: FrozenSet[str] = frozenset()
1213
1214 def _reset_memoizations(self) -> None:
1215 for elem in self._memoized_keys:
1216 self.__dict__.pop(elem, None)
1217
1218 def _assert_no_memoizations(self) -> None:
1219 for elem in self._memoized_keys:
1220 assert elem not in self.__dict__
1221
1222 def _set_memoized_attribute(self, key: str, value: Any) -> None:
1223 self.__dict__[key] = value
1224 self._memoized_keys |= {key}
1225
1226 class memoized_attribute(memoized_property[_T]):
1227 """A read-only @property that is only evaluated once.
1228
1229 :meta private:
1230
1231 """
1232
1233 fget: Callable[..., _T]
1234 __doc__: Optional[str]
1235 __name__: str
1236
1237 def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None):
1238 self.fget = fget
1239 self.__doc__ = doc or fget.__doc__
1240 self.__name__ = fget.__name__
1241
1242 @overload
1243 def __get__(self: _MA, obj: None, cls: Any) -> _MA: ...
1244
1245 @overload
1246 def __get__(self, obj: Any, cls: Any) -> _T: ...
1247
1248 def __get__(self, obj, cls):
1249 if obj is None:
1250 return self
1251 obj.__dict__[self.__name__] = result = self.fget(obj)
1252 obj._memoized_keys |= {self.__name__}
1253 return result
1254
1255 @classmethod
1256 def memoized_instancemethod(cls, fn: _F) -> _F:
        """Decorate a method to memoize its return value.
1258
1259 :meta private:
1260
1261 """
1262
1263 def oneshot(self: Any, *args: Any, **kw: Any) -> Any:
1264 result = fn(self, *args, **kw)
1265
1266 def memo(*a, **kw):
1267 return result
1268
1269 memo.__name__ = fn.__name__
1270 memo.__doc__ = fn.__doc__
1271 self.__dict__[fn.__name__] = memo
1272 self._memoized_keys |= {fn.__name__}
1273 return result
1274
1275 return update_wrapper(oneshot, fn) # type: ignore
1276
1277
1278if TYPE_CHECKING:
1279 HasMemoized_ro_memoized_attribute = property
1280else:
1281 HasMemoized_ro_memoized_attribute = HasMemoized.memoized_attribute
1282
1283
1284class MemoizedSlots:
1285 """Apply memoized items to an object using a __getattr__ scheme.
1286
1287 This allows the functionality of memoized_property and
1288 memoized_instancemethod to be available to a class using __slots__.
1289
1290 """
1291
1292 __slots__ = ()
1293
1294 def _fallback_getattr(self, key):
1295 raise AttributeError(key)
1296
1297 def __getattr__(self, key: str) -> Any:
1298 if key.startswith("_memoized_attr_") or key.startswith(
1299 "_memoized_method_"
1300 ):
1301 raise AttributeError(key)
1302 # to avoid recursion errors when interacting with other __getattr__
1303 # schemes that refer to this one, when testing for memoized method
1304 # look at __class__ only rather than going into __getattr__ again.
1305 elif hasattr(self.__class__, f"_memoized_attr_{key}"):
1306 value = getattr(self, f"_memoized_attr_{key}")()
1307 setattr(self, key, value)
1308 return value
1309 elif hasattr(self.__class__, f"_memoized_method_{key}"):
1310 fn = getattr(self, f"_memoized_method_{key}")
1311
1312 def oneshot(*args, **kw):
1313 result = fn(*args, **kw)
1314
1315 def memo(*a, **kw):
1316 return result
1317
1318 memo.__name__ = fn.__name__
1319 memo.__doc__ = fn.__doc__
1320 setattr(self, key, memo)
1321 return result
1322
1323 oneshot.__doc__ = fn.__doc__
1324 return oneshot
1325 else:
1326 return self._fallback_getattr(key)
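
# Comment-only sketch (hypothetical class): the memoized name must appear in
# __slots__ so that the setattr() inside __getattr__ has somewhere to store
# the computed value:
#
#     class Thing(MemoizedSlots):
#         __slots__ = ("total",)
#
#         def _memoized_attr_total(self):
#             return sum(range(10))  # computed once, then stored in the slot
#
#     t = Thing()
#     t.total   # computes 45 via __getattr__, stores it
#     t.total   # reads the slot directly; __getattr__ is no longer invoked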
1327
1328
1329# from paste.deploy.converters
1330def asbool(obj: Any) -> bool:
1331 if isinstance(obj, str):
1332 obj = obj.strip().lower()
1333 if obj in ["true", "yes", "on", "y", "t", "1"]:
1334 return True
1335 elif obj in ["false", "no", "off", "n", "f", "0"]:
1336 return False
1337 else:
1338 raise ValueError("String is not true/false: %r" % obj)
1339 return bool(obj)
1340
1341
1342def bool_or_str(*text: str) -> Callable[[str], Union[str, bool]]:
1343 """Return a callable that will evaluate a string as
1344 boolean, or one of a set of "alternate" string values.
1345
1346 """
1347
1348 def bool_or_value(obj: str) -> Union[str, bool]:
1349 if obj in text:
1350 return obj
1351 else:
1352 return asbool(obj)
1353
1354 return bool_or_value
1355
1356
1357def asint(value: Any) -> Optional[int]:
1358 """Coerce to integer."""
1359
1360 if value is None:
1361 return value
1362 return int(value)
1363
1364
1365def coerce_kw_type(
1366 kw: Dict[str, Any],
1367 key: str,
1368 type_: Type[Any],
1369 flexi_bool: bool = True,
1370 dest: Optional[Dict[str, Any]] = None,
1371) -> None:
1372 r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if
1373 necessary. If 'flexi_bool' is True, the string '0' is considered false
1374 when coercing to boolean.
1375 """
1376
1377 if dest is None:
1378 dest = kw
1379
1380 if (
1381 key in kw
1382 and (not isinstance(type_, type) or not isinstance(kw[key], type_))
1383 and kw[key] is not None
1384 ):
1385 if type_ is bool and flexi_bool:
1386 dest[key] = asbool(kw[key])
1387 else:
1388 dest[key] = type_(kw[key])
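
# Comment-only sketch, e.g. normalizing string-valued URL query parameters:
#
#     kw = {"pool_size": "5", "echo": "false"}
#     coerce_kw_type(kw, "pool_size", int)
#     coerce_kw_type(kw, "echo", bool)
#     # kw is now {"pool_size": 5, "echo": False}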
1389
1390
1391def constructor_key(obj: Any, cls: Type[Any]) -> Tuple[Any, ...]:
1392 """Produce a tuple structure that is cacheable using the __dict__ of
1393 obj to retrieve values
1394
1395 """
1396 names = get_cls_kwargs(cls)
1397 return (cls,) + tuple(
1398 (k, obj.__dict__[k]) for k in names if k in obj.__dict__
1399 )
1400
1401
1402def constructor_copy(obj: _T, cls: Type[_T], *args: Any, **kw: Any) -> _T:
1403 """Instantiate cls using the __dict__ of obj as constructor arguments.
1404
1405 Uses inspect to match the named arguments of ``cls``.
1406
1407 """
1408
1409 names = get_cls_kwargs(cls)
1410 kw.update(
1411 (k, obj.__dict__[k]) for k in names.difference(kw) if k in obj.__dict__
1412 )
1413 return cls(*args, **kw)
1414
1415
1416def counter() -> Callable[[], int]:
1417 """Return a threadsafe counter function."""
1418
1419 lock = threading.Lock()
1420 counter = itertools.count(1)
1421
1422 # avoid the 2to3 "next" transformation...
1423 def _next():
1424 with lock:
1425 return next(counter)
1426
1427 return _next
1428
1429
1430def duck_type_collection(
1431 specimen: Any, default: Optional[Type[Any]] = None
1432) -> Optional[Type[Any]]:
1433 """Given an instance or class, guess if it is or is acting as one of
1434 the basic collection types: list, set and dict. If the __emulates__
1435 property is present, return that preferentially.
1436 """
1437
1438 if hasattr(specimen, "__emulates__"):
1439 # canonicalize set vs sets.Set to a standard: the builtin set
1440 if specimen.__emulates__ is not None and issubclass(
1441 specimen.__emulates__, set
1442 ):
1443 return set
1444 else:
1445 return specimen.__emulates__ # type: ignore
1446
1447 isa = issubclass if isinstance(specimen, type) else isinstance
1448 if isa(specimen, list):
1449 return list
1450 elif isa(specimen, set):
1451 return set
1452 elif isa(specimen, dict):
1453 return dict
1454
1455 if hasattr(specimen, "append"):
1456 return list
1457 elif hasattr(specimen, "add"):
1458 return set
1459 elif hasattr(specimen, "set"):
1460 return dict
1461 else:
1462 return default
1463
1464
1465def assert_arg_type(
1466 arg: Any, argtype: Union[Tuple[Type[Any], ...], Type[Any]], name: str
1467) -> Any:
1468 if isinstance(arg, argtype):
1469 return arg
1470 else:
1471 if isinstance(argtype, tuple):
1472 raise exc.ArgumentError(
1473 "Argument '%s' is expected to be one of type %s, got '%s'"
1474 % (name, " or ".join("'%s'" % a for a in argtype), type(arg))
1475 )
1476 else:
1477 raise exc.ArgumentError(
1478 "Argument '%s' is expected to be of type '%s', got '%s'"
1479 % (name, argtype, type(arg))
1480 )
1481
1482
1483def dictlike_iteritems(dictlike):
1484 """Return a (key, value) iterator for almost any dict-like object."""
1485
1486 if hasattr(dictlike, "items"):
1487 return list(dictlike.items())
1488
1489 getter = getattr(dictlike, "__getitem__", getattr(dictlike, "get", None))
1490 if getter is None:
1491 raise TypeError("Object '%r' is not dict-like" % dictlike)
1492
1493 if hasattr(dictlike, "iterkeys"):
1494
1495 def iterator():
1496 for key in dictlike.iterkeys():
1497 assert getter is not None
1498 yield key, getter(key)
1499
1500 return iterator()
1501 elif hasattr(dictlike, "keys"):
1502 return iter((key, getter(key)) for key in dictlike.keys())
1503 else:
1504 raise TypeError("Object '%r' is not dict-like" % dictlike)
1505
1506
1507class classproperty(property):
    """A decorator that behaves like @property except that it operates
1509 on classes rather than instances.
1510
1511 The decorator is currently special when using the declarative
1512 module, but note that the
1513 :class:`~.sqlalchemy.ext.declarative.declared_attr`
1514 decorator should be used for this purpose with declarative.
1515
1516 """
1517
1518 fget: Callable[[Any], Any]
1519
1520 def __init__(self, fget: Callable[[Any], Any], *arg: Any, **kw: Any):
1521 super().__init__(fget, *arg, **kw)
1522 self.__doc__ = fget.__doc__
1523
1524 def __get__(self, obj: Any, cls: Optional[type] = None) -> Any:
1525 return self.fget(cls)
1526
1527
1528class hybridproperty(Generic[_T]):
1529 def __init__(self, func: Callable[..., _T]):
1530 self.func = func
1531 self.clslevel = func
1532
1533 def __get__(self, instance: Any, owner: Any) -> _T:
1534 if instance is None:
1535 clsval = self.clslevel(owner)
1536 return clsval
1537 else:
1538 return self.func(instance)
1539
1540 def classlevel(self, func: Callable[..., Any]) -> hybridproperty[_T]:
1541 self.clslevel = func
1542 return self
1543
1544
1545class rw_hybridproperty(Generic[_T]):
1546 def __init__(self, func: Callable[..., _T]):
1547 self.func = func
1548 self.clslevel = func
1549 self.setfn: Optional[Callable[..., Any]] = None
1550
1551 def __get__(self, instance: Any, owner: Any) -> _T:
1552 if instance is None:
1553 clsval = self.clslevel(owner)
1554 return clsval
1555 else:
1556 return self.func(instance)
1557
1558 def __set__(self, instance: Any, value: Any) -> None:
1559 assert self.setfn is not None
1560 self.setfn(instance, value)
1561
1562 def setter(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
1563 self.setfn = func
1564 return self
1565
1566 def classlevel(self, func: Callable[..., Any]) -> rw_hybridproperty[_T]:
1567 self.clslevel = func
1568 return self
1569
1570
1571class hybridmethod(Generic[_T]):
1572 """Decorate a function as cls- or instance- level."""
1573
1574 def __init__(self, func: Callable[..., _T]):
1575 self.func = self.__func__ = func
1576 self.clslevel = func
1577
1578 def __get__(self, instance: Any, owner: Any) -> Callable[..., _T]:
1579 if instance is None:
1580 return self.clslevel.__get__(owner, owner.__class__) # type:ignore
1581 else:
1582 return self.func.__get__(instance, owner) # type:ignore
1583
1584 def classlevel(self, func: Callable[..., Any]) -> hybridmethod[_T]:
1585 self.clslevel = func
1586 return self
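
# Comment-only sketch (hypothetical class): the same attribute dispatches to a
# class-level or instance-level implementation depending on how it's accessed:
#
#     class Widget:
#         @hybridmethod
#         def describe(self):
#             return "instance"
#
#         @describe.classlevel
#         def describe(cls):
#             return "class"
#
#     Widget.describe()    # -> "class"
#     Widget().describe()  # -> "instance"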
1587
1588
1589class symbol(int):
1590 """A constant symbol.
1591
1592 >>> symbol('foo') is symbol('foo')
1593 True
1594 >>> symbol('foo')
    symbol('foo')
1596
1597 A slight refinement of the MAGICCOOKIE=object() pattern. The primary
1598 advantage of symbol() is its repr(). They are also singletons.
1599
1600 Repeated calls of symbol('name') will all return the same instance.
1601
1602 """
1603
1604 name: str
1605
1606 symbols: Dict[str, symbol] = {}
1607 _lock = threading.Lock()
1608
1609 def __new__(
1610 cls,
1611 name: str,
1612 doc: Optional[str] = None,
1613 canonical: Optional[int] = None,
1614 ) -> symbol:
1615 with cls._lock:
1616 sym = cls.symbols.get(name)
1617 if sym is None:
1618 assert isinstance(name, str)
1619 if canonical is None:
1620 canonical = hash(name)
1621 sym = int.__new__(symbol, canonical)
1622 sym.name = name
1623 if doc:
1624 sym.__doc__ = doc
1625
1626 # NOTE: we should ultimately get rid of this global thing,
1627 # however, currently it is to support pickling. The best
1628 # change would be when we are on py3.11 at a minimum, we
1629 # switch to stdlib enum.IntFlag.
1630 cls.symbols[name] = sym
1631 else:
1632 if canonical and canonical != sym:
1633 raise TypeError(
1634 f"Can't replace canonical symbol for {name!r} "
1635 f"with new int value {canonical}"
1636 )
1637 return sym
1638
1639 def __reduce__(self):
1640 return symbol, (self.name, "x", int(self))
1641
1642 def __str__(self):
1643 return repr(self)
1644
1645 def __repr__(self):
1646 return f"symbol({self.name!r})"
1647
1648
1649class _IntFlagMeta(type):
1650 def __init__(
1651 cls,
1652 classname: str,
1653 bases: Tuple[Type[Any], ...],
1654 dict_: Dict[str, Any],
1655 **kw: Any,
1656 ) -> None:
1657 items: List[symbol]
1658 cls._items = items = []
1659 for k, v in dict_.items():
1660 if re.match(r"^__.*__$", k):
1661 continue
1662 if isinstance(v, int):
1663 sym = symbol(k, canonical=v)
1664 elif not k.startswith("_"):
1665 raise TypeError("Expected integer values for IntFlag")
1666 else:
1667 continue
1668 setattr(cls, k, sym)
1669 items.append(sym)
1670
1671 cls.__members__ = _collections.immutabledict(
1672 {sym.name: sym for sym in items}
1673 )
1674
1675 def __iter__(self) -> Iterator[symbol]:
1676 raise NotImplementedError(
1677 "iter not implemented to ensure compatibility with "
1678 "Python 3.11 IntFlag. Please use __members__. See "
1679 "https://github.com/python/cpython/issues/99304"
1680 )
1681
1682
1683class _FastIntFlag(metaclass=_IntFlagMeta):
1684 """An 'IntFlag' copycat that isn't slow when performing bitwise
1685 operations.
1686
    The ``FastIntFlag`` name resolves to ``enum.IntFlag`` under TYPE_CHECKING
    and to ``_FastIntFlag`` otherwise.
1689
1690 """
1691
1692
1693if TYPE_CHECKING:
1694 from enum import IntFlag
1695
1696 FastIntFlag = IntFlag
1697else:
1698 FastIntFlag = _FastIntFlag
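
# Comment-only sketch (hypothetical flags): members are plain int-valued
# symbols at runtime, so bitwise math avoids enum.IntFlag overhead, while
# type checkers see a real enum.IntFlag:
#
#     class ConnFlags(FastIntFlag):
#         AUTOCOMMIT = 1
#         READ_ONLY = 2
#
#     ConnFlags.AUTOCOMMIT | ConnFlags.READ_ONLY  # == 3, a plain int
#     ConnFlags.__members__   # {'AUTOCOMMIT': ..., 'READ_ONLY': ...}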
1699
1700
1701_E = TypeVar("_E", bound=enum.Enum)
1702
1703
1704def parse_user_argument_for_enum(
1705 arg: Any,
1706 choices: Dict[_E, List[Any]],
1707 name: str,
1708 resolve_symbol_names: bool = False,
1709) -> Optional[_E]:
1710 """Given a user parameter, parse the parameter into a chosen value
1711 from a list of choice objects, typically Enum values.
1712
1713 The user argument can be a string name that matches the name of a
1714 symbol, or the symbol object itself, or any number of alternate choices
    such as True/False/None, etc.
1716
1717 :param arg: the user argument.
1718 :param choices: dictionary of enum values to lists of possible
1719 entries for each.
1720 :param name: name of the argument. Used in an :class:`.ArgumentError`
1721 that is raised if the parameter doesn't match any available argument.
1722
1723 """
1724 for enum_value, choice in choices.items():
1725 if arg is enum_value:
1726 return enum_value
1727 elif resolve_symbol_names and arg == enum_value.name:
1728 return enum_value
1729 elif arg in choice:
1730 return enum_value
1731
1732 if arg is None:
1733 return None
1734
1735 raise exc.ArgumentError(f"Invalid value for '{name}': {arg!r}")
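
# Comment-only sketch (the enum, choices and argument name are made up):
#
#     class Strategy(enum.Enum):
#         SELECTIN = "selectin"
#         JOINED = "joined"
#
#     choices = {
#         Strategy.SELECTIN: ["selectin"],
#         Strategy.JOINED: ["joined", True],
#     }
#
#     parse_user_argument_for_enum("joined", choices, "lazy")  # Strategy.JOINED
#     parse_user_argument_for_enum(True, choices, "lazy")      # Strategy.JOINED
#     parse_user_argument_for_enum("bogus", choices, "lazy")   # ArgumentError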
1736
1737
1738_creation_order = 1
1739
1740
1741def set_creation_order(instance: Any) -> None:
1742 """Assign a '_creation_order' sequence to the given instance.
1743
1744 This allows multiple instances to be sorted in order of creation
1745 (typically within a single thread; the counter is not particularly
1746 threadsafe).
1747
1748 """
1749 global _creation_order
1750 instance._creation_order = _creation_order
1751 _creation_order += 1
1752
1753
1754def warn_exception(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """executes the given function, catches all exceptions and converts them
    to a warning.
1757
1758 """
1759 try:
1760 return func(*args, **kwargs)
1761 except Exception:
1762 warn("%s('%s') ignored" % sys.exc_info()[0:2])
1763
1764
1765def ellipses_string(value, len_=25):
1766 try:
1767 if len(value) > len_:
1768 return "%s..." % value[0:len_]
1769 else:
1770 return value
1771 except TypeError:
1772 return value
1773
1774
1775class _hash_limit_string(str):
1776 """A string subclass that can only be hashed on a maximum amount
1777 of unique values.
1778
1779 This is used for warnings so that we can send out parameterized warnings
1780 without the __warningregistry__ of the module, or the non-overridable
    "once" registry within warnings.py, overloading memory.
1783
1784 """
1785
1786 _hash: int
1787
1788 def __new__(
1789 cls, value: str, num: int, args: Sequence[Any]
1790 ) -> _hash_limit_string:
1791 interpolated = (value % args) + (
1792 " (this warning may be suppressed after %d occurrences)" % num
1793 )
1794 self = super().__new__(cls, interpolated)
1795 self._hash = hash("%s_%d" % (value, hash(interpolated) % num))
1796 return self
1797
1798 def __hash__(self) -> int:
1799 return self._hash
1800
1801 def __eq__(self, other: Any) -> bool:
1802 return hash(self) == hash(other)
1803
1804
1805def warn(msg: str, code: Optional[str] = None) -> None:
1806 """Issue a warning.
1807
1808 If msg is a string, :class:`.exc.SAWarning` is used as
1809 the category.
1810
1811 """
1812 if code:
1813 _warnings_warn(exc.SAWarning(msg, code=code))
1814 else:
1815 _warnings_warn(msg, exc.SAWarning)
1816
1817
1818def warn_limited(msg: str, args: Sequence[Any]) -> None:
1819 """Issue a warning with a parameterized string, limiting the number
1820 of registrations.
1821
1822 """
1823 if args:
1824 msg = _hash_limit_string(msg, 10, args)
1825 _warnings_warn(msg, exc.SAWarning)
1826
1827
1828_warning_tags: Dict[CodeType, Tuple[str, Type[Warning]]] = {}
1829
1830
1831def tag_method_for_warnings(
1832 message: str, category: Type[Warning]
1833) -> Callable[[_F], _F]:
1834 def go(fn):
1835 _warning_tags[fn.__code__] = (message, category)
1836 return fn
1837
1838 return go
1839
1840
1841_not_sa_pattern = re.compile(r"^(?:sqlalchemy\.(?!testing)|alembic\.)")
1842
1843
1844def _warnings_warn(
1845 message: Union[str, Warning],
1846 category: Optional[Type[Warning]] = None,
1847 stacklevel: int = 2,
1848) -> None:
1849 # adjust the given stacklevel to be outside of SQLAlchemy
1850 try:
1851 frame = sys._getframe(stacklevel)
1852 except ValueError:
1853 # being called from less than 3 (or given) stacklevels, weird,
1854 # but don't crash
1855 stacklevel = 0
1856 except:
1857 # _getframe() doesn't work, weird interpreter issue, weird,
1858 # ok, but don't crash
1859 stacklevel = 0
1860 else:
1861 stacklevel_found = warning_tag_found = False
1862 while frame is not None:
1863 # using __name__ here requires that we have __name__ in the
1864 # __globals__ of the decorated string functions we make also.
1865 # we generate this using {"__name__": fn.__module__}
1866 if not stacklevel_found and not re.match(
1867 _not_sa_pattern, frame.f_globals.get("__name__", "")
1868 ):
1869 # stop incrementing stack level if an out-of-SQLA line
                # was found.
1871 stacklevel_found = True
1872
1873 # however, for the warning tag thing, we have to keep
1874 # scanning up the whole traceback
1875
1876 if frame.f_code in _warning_tags:
1877 warning_tag_found = True
1878 (_suffix, _category) = _warning_tags[frame.f_code]
1879 category = category or _category
1880 message = f"{message} ({_suffix})"
1881
1882 frame = frame.f_back # type: ignore[assignment]
1883
1884 if not stacklevel_found:
1885 stacklevel += 1
1886 elif stacklevel_found and warning_tag_found:
1887 break
1888
1889 if category is not None:
1890 warnings.warn(message, category, stacklevel=stacklevel + 1)
1891 else:
1892 warnings.warn(message, stacklevel=stacklevel + 1)
1893
1894
1895def only_once(
1896 fn: Callable[..., _T], retry_on_exception: bool
1897) -> Callable[..., Optional[_T]]:
1898 """Decorate the given function to be a no-op after it is called exactly
1899 once."""
1900
1901 once = [fn]
1902
1903 def go(*arg: Any, **kw: Any) -> Optional[_T]:
1904 # strong reference fn so that it isn't garbage collected,
1905 # which interferes with the event system's expectations
1906 strong_fn = fn # noqa
1907 if once:
1908 once_fn = once.pop()
1909 try:
1910 return once_fn(*arg, **kw)
1911 except:
1912 if retry_on_exception:
1913 once.insert(0, once_fn)
1914 raise
1915
1916 return None
1917
1918 return go
1919
1920
1921_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
1922_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")
1923
1924
1925def chop_traceback(
1926 tb: List[str],
1927 exclude_prefix: re.Pattern[str] = _UNITTEST_RE,
1928 exclude_suffix: re.Pattern[str] = _SQLA_RE,
1929) -> List[str]:
1930 """Chop extraneous lines off beginning and end of a traceback.
1931
1932 :param tb:
1933 a list of traceback lines as returned by ``traceback.format_stack()``
1934
1935 :param exclude_prefix:
1936 a regular expression object matching lines to skip at beginning of
1937 ``tb``
1938
1939 :param exclude_suffix:
1940 a regular expression object matching lines to skip at end of ``tb``
1941 """
1942 start = 0
1943 end = len(tb) - 1
1944 while start <= end and exclude_prefix.search(tb[start]):
1945 start += 1
1946 while start <= end and exclude_suffix.search(tb[end]):
1947 end -= 1
1948 return tb[start : end + 1]
1949
1950
1951NoneType = type(None)
1952
1953
1954def attrsetter(attrname):
1955 code = "def set(obj, value): obj.%s = value" % attrname
1956 env = locals().copy()
1957 exec(code, env)
1958 return env["set"]
1959
1960
1961_dunders = re.compile("^__.+__$")
1962
1963
1964class TypingOnly:
1965 """A mixin class that marks a class as 'typing only', meaning it has
1966 absolutely no methods, attributes, or runtime functionality whatsoever.
1967
1968 """
1969
1970 __slots__ = ()
1971
1972 def __init_subclass__(cls) -> None:
1973 if TypingOnly in cls.__bases__:
1974 remaining = {
1975 name for name in cls.__dict__ if not _dunders.match(name)
1976 }
1977 if remaining:
1978 raise AssertionError(
1979 f"Class {cls} directly inherits TypingOnly but has "
1980 f"additional attributes {remaining}."
1981 )
1982 super().__init_subclass__()
1983
1984
1985class EnsureKWArg:
1986 r"""Apply translation of functions to accept \**kw arguments if they
1987 don't already.
1988
1989 Used to ensure cross-compatibility with third party legacy code, for things
1990 like compiler visit methods that need to accept ``**kw`` arguments,
1991 but may have been copied from old code that didn't accept them.
1992
1993 """
1994
1995 ensure_kwarg: str
1996 """a regular expression that indicates method names for which the method
1997 should accept ``**kw`` arguments.
1998
1999 The class will scan for methods matching the name template and decorate
2000 them if necessary to ensure ``**kw`` parameters are accepted.
2001
2002 """
2003
2004 def __init_subclass__(cls) -> None:
2005 fn_reg = cls.ensure_kwarg
2006 clsdict = cls.__dict__
2007 if fn_reg:
2008 for key in clsdict:
2009 m = re.match(fn_reg, key)
2010 if m:
2011 fn = clsdict[key]
2012 spec = compat.inspect_getfullargspec(fn)
2013 if not spec.varkw:
2014 wrapped = cls._wrap_w_kw(fn)
2015 setattr(cls, key, wrapped)
2016 super().__init_subclass__()
2017
2018 @classmethod
2019 def _wrap_w_kw(cls, fn: Callable[..., Any]) -> Callable[..., Any]:
2020 def wrap(*arg: Any, **kw: Any) -> Any:
2021 return fn(*arg)
2022
2023 return update_wrapper(wrap, fn)
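
# Comment-only sketch (hypothetical subclass): methods matching ensure_kwarg
# that lack a **kw catch-all are wrapped so extra keywords are accepted and
# silently dropped:
#
#     class MyCompiler(EnsureKWArg):
#         ensure_kwarg = r"visit_\w+"
#
#         def visit_thing(self, element):  # legacy signature without **kw
#             return "THING"
#
#     MyCompiler().visit_thing("x", within_columns_clause=True)  # -> "THING"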
2024
2025
2026def wrap_callable(wrapper, fn):
2027 """Augment functools.update_wrapper() to work with objects with
2028 a ``__call__()`` method.
2029
2030 :param fn:
2031 object with __call__ method
2032
2033 """
2034 if hasattr(fn, "__name__"):
2035 return update_wrapper(wrapper, fn)
2036 else:
2037 _f = wrapper
2038 _f.__name__ = fn.__class__.__name__
2039 if hasattr(fn, "__module__"):
2040 _f.__module__ = fn.__module__
2041
2042 if hasattr(fn.__call__, "__doc__") and fn.__call__.__doc__:
2043 _f.__doc__ = fn.__call__.__doc__
2044 elif fn.__doc__:
2045 _f.__doc__ = fn.__doc__
2046
2047 return _f
2048
2049
2050def quoted_token_parser(value):
2051 """Parse a dotted identifier with accommodation for quoted names.
2052
2053 Includes support for SQL-style double quotes as a literal character.
2054
2055 E.g.::
2056
2057 >>> quoted_token_parser("name")
2058 ["name"]
2059 >>> quoted_token_parser("schema.name")
2060 ["schema", "name"]
2061 >>> quoted_token_parser('"Schema"."Name"')
2062 ['Schema', 'Name']
2063 >>> quoted_token_parser('"Schema"."Name""Foo"')
2064 ['Schema', 'Name""Foo']
2065
2066 """
2067
2068 if '"' not in value:
2069 return value.split(".")
2070
2071 # 0 = outside of quotes
2072 # 1 = inside of quotes
2073 state = 0
2074 result: List[List[str]] = [[]]
2075 idx = 0
2076 lv = len(value)
2077 while idx < lv:
2078 char = value[idx]
2079 if char == '"':
2080 if state == 1 and idx < lv - 1 and value[idx + 1] == '"':
2081 result[-1].append('"')
2082 idx += 1
2083 else:
2084 state ^= 1
2085 elif char == "." and state == 0:
2086 result.append([])
2087 else:
2088 result[-1].append(char)
2089 idx += 1
2090
2091 return ["".join(token) for token in result]
2092
2093
2094def add_parameter_text(params: Any, text: str) -> Callable[[_F], _F]:
2095 params = _collections.to_list(params)
2096
2097 def decorate(fn):
2098 doc = fn.__doc__ is not None and fn.__doc__ or ""
2099 if doc:
2100 doc = inject_param_text(doc, {param: text for param in params})
2101 fn.__doc__ = doc
2102 return fn
2103
2104 return decorate
2105
2106
2107def _dedent_docstring(text: str) -> str:
2108 split_text = text.split("\n", 1)
2109 if len(split_text) == 1:
2110 return text
2111 else:
2112 firstline, remaining = split_text
2113 if not firstline.startswith(" "):
2114 return firstline + "\n" + textwrap.dedent(remaining)
2115 else:
2116 return textwrap.dedent(text)
2117
2118
2119def inject_docstring_text(
2120 given_doctext: Optional[str], injecttext: str, pos: int
2121) -> str:
2122 doctext: str = _dedent_docstring(given_doctext or "")
2123 lines = doctext.split("\n")
2124 if len(lines) == 1:
2125 lines.append("")
2126 injectlines = textwrap.dedent(injecttext).split("\n")
2127 if injectlines[0]:
2128 injectlines.insert(0, "")
2129
2130 blanks = [num for num, line in enumerate(lines) if not line.strip()]
2131 blanks.insert(0, 0)
2132
2133 inject_pos = blanks[min(pos, len(blanks) - 1)]
2134
2135 lines = lines[0:inject_pos] + injectlines + lines[inject_pos:]
2136 return "\n".join(lines)
2137
2138
2139_param_reg = re.compile(r"(\s+):param (.+?):")
2140
2141
2142def inject_param_text(doctext: str, inject_params: Dict[str, str]) -> str:
2143 doclines = collections.deque(doctext.splitlines())
2144 lines = []
2145
2146 # TODO: this is not working for params like ":param case_sensitive=True:"
2147
2148 to_inject = None
2149 while doclines:
2150 line = doclines.popleft()
2151
2152 m = _param_reg.match(line)
2153
2154 if to_inject is None:
2155 if m:
2156 param = m.group(2).lstrip("*")
2157 if param in inject_params:
2158 # default indent to that of :param: plus one
2159 indent = " " * len(m.group(1)) + " "
2160
2161 # but if the next line has text, use that line's
2162 # indentation
2163 if doclines:
2164 m2 = re.match(r"(\s+)\S", doclines[0])
2165 if m2:
2166 indent = " " * len(m2.group(1))
2167
2168 to_inject = indent + inject_params[param]
2169 elif m:
2170 lines.extend(["\n", to_inject, "\n"])
2171 to_inject = None
2172 elif not line.rstrip():
2173 lines.extend([line, to_inject, "\n"])
2174 to_inject = None
2175 elif line.endswith("::"):
2176 # TODO: this still won't cover if the code example itself has
2177 # blank lines in it, need to detect those via indentation.
2178 lines.extend([line, doclines.popleft()])
2179 continue
2180 lines.append(line)
2181
2182 return "\n".join(lines)
2183
2184
2185def repr_tuple_names(names: List[str]) -> Optional[str]:
    """Trims a list of strings from the middle and returns a string of up to
    four elements. Strings greater than 11 characters will be truncated."""
2188 if len(names) == 0:
2189 return None
2190 flag = len(names) <= 4
2191 names = names[0:4] if flag else names[0:3] + names[-1:]
2192 res = ["%s.." % name[:11] if len(name) > 11 else name for name in names]
2193 if flag:
2194 return ", ".join(res)
2195 else:
2196 return "%s, ..., %s" % (", ".join(res[0:3]), res[-1])
2197
2198
2199def has_compiled_ext(raise_=False):
2200 from ._has_cython import HAS_CYEXTENSION
2201
2202 if HAS_CYEXTENSION:
2203 return True
2204 elif raise_:
2205 raise ImportError(
2206 "cython extensions were expected to be installed, "
2207 "but are not present"
2208 )
2209 else:
2210 return False
2211
2212
2213def load_uncompiled_module(module: _M) -> _M:
    """Load the non-compiled version of a module that is also
2215 compiled with cython.
2216 """
2217 full_name = module.__name__
2218 assert module.__spec__
2219 parent_name = module.__spec__.parent
2220 assert parent_name
2221 parent_module = sys.modules[parent_name]
2222 assert parent_module.__spec__
2223 package_path = parent_module.__spec__.origin
2224 assert package_path and package_path.endswith("__init__.py")
2225
2226 name = full_name.split(".")[-1]
2227 module_path = package_path.replace("__init__.py", f"{name}.py")
2228
2229 py_spec = importlib.util.spec_from_file_location(full_name, module_path)
2230 assert py_spec
2231 py_module = importlib.util.module_from_spec(py_spec)
2232 assert py_spec.loader
2233 py_spec.loader.exec_module(py_module)
2234 return cast(_M, py_module)
2235
2236
2237class _Missing(enum.Enum):
2238 Missing = enum.auto()
2239
2240
2241Missing = _Missing.Missing
2242MissingOr = Union[_T, Literal[_Missing.Missing]]