1# util/compat.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Handle Python version/platform incompatibilities."""
10
11from __future__ import annotations
12
13import base64
14import dataclasses
15import hashlib
16from importlib import metadata as importlib_metadata
17import inspect
18import operator
19import platform
20import sys
21import typing
22from typing import Any
23from typing import Callable
24from typing import Dict
25from typing import Iterable
26from typing import List
27from typing import Mapping
28from typing import Optional
29from typing import Sequence
30from typing import Set
31from typing import Tuple
32from typing import Type
33
# Interpreter version flags: each is True on that Python version or newer.
py312 = sys.version_info[:2] >= (3, 12)
py311 = sys.version_info[:2] >= (3, 11)
py310 = sys.version_info[:2] >= (3, 10)
py39 = sys.version_info[:2] >= (3, 9)

# Interpreter implementation flags.
pypy = platform.python_implementation() == "PyPy"
cpython = platform.python_implementation() == "CPython"

# Operating system / hardware flags.
win32 = sys.platform[:3] == "win"
osx = sys.platform[:6] == "darwin"
# NOTE(review): only machine names containing "aarch" (e.g. "aarch64")
# count as ARM here; other spellings such as "arm64" do not — confirm intent.
arm = platform.machine().lower().find("aarch") >= 0
is64bit = sys.maxsize > (1 << 32)

# Derived from the CPython flag alone.
has_refcount_gc = bool(cpython)

# Alias: build a getter for (possibly dotted) attribute paths.
dottedgetter = operator.attrgetter
49
50
class FullArgSpec(typing.NamedTuple):
    """Argument specification of a Python function, as built by
    :func:`.inspect_getfullargspec`.

    Field names and order match the standard library's
    ``inspect.FullArgSpec``.
    """

    # names of the positional parameters, in declaration order
    args: List[str]
    # name of the ``*args`` parameter, or None if the function has none
    varargs: Optional[str]
    # name of the ``**kwargs`` parameter, or None if the function has none
    varkw: Optional[str]
    # the function's ``__defaults__``: defaults of the trailing positional
    # parameters, or None
    defaults: Optional[Tuple[Any, ...]]
    # names of the keyword-only parameters, in declaration order
    kwonlyargs: List[str]
    # the function's ``__kwdefaults__``: keyword-only name -> default, or None
    kwonlydefaults: Optional[Dict[str, Any]]
    # the function's ``__annotations__`` mapping
    annotations: Dict[str, Any]
59
60
def inspect_getfullargspec(func: Callable[..., Any]) -> FullArgSpec:
    """Return a :class:`.FullArgSpec` for the Python function ``func``.

    Fully vendored version of getfullargspec from Python 3.3: parameter
    names are read straight from the function's code object.

    A bound method is unwrapped to its underlying function first, so its
    first (``self``/``cls``) parameter appears in ``args``.  The returned
    ``annotations`` is the function's own ``__annotations__`` dict, not a
    copy.

    :raises TypeError: if ``func`` is not a Python function (or a bound
      method wrapping one), or if its ``__code__`` is not a code object.
    """
    if inspect.ismethod(func):
        func = func.__func__
    if not inspect.isfunction(func):
        raise TypeError(f"{func!r} is not a Python function")

    code = func.__code__
    if not inspect.iscode(code):
        raise TypeError(f"{code!r} is not a code object")

    # co_varnames lists positional params, then keyword-only params, then
    # the *args name (if any), then the **kwargs name (if any).
    varnames = code.co_varnames
    n_positional = code.co_argcount
    cursor = n_positional + code.co_kwonlyargcount

    positional = list(varnames[:n_positional])
    keyword_only = list(varnames[n_positional:cursor])

    star_args = None
    if code.co_flags & inspect.CO_VARARGS:
        star_args = varnames[cursor]
        cursor += 1

    star_kwargs = None
    if code.co_flags & inspect.CO_VARKEYWORDS:
        star_kwargs = varnames[cursor]

    return FullArgSpec(
        args=positional,
        varargs=star_args,
        varkw=star_kwargs,
        defaults=func.__defaults__,
        kwonlyargs=keyword_only,
        kwonlydefaults=func.__kwdefaults__,
        annotations=func.__annotations__,
    )
97
98
if not py39:

    def md5_not_for_security() -> Any:
        """Return a new MD5 hash object.

        Fallback for Pythons older than 3.9, where ``hashlib.md5`` takes
        no ``usedforsecurity`` argument.
        """
        return hashlib.md5()

else:
    # python stubs don't have a public type for this. not worth
    # making a protocol
    def md5_not_for_security() -> Any:
        """Return a new MD5 hash object flagged ``usedforsecurity=False``."""
        return hashlib.md5(usedforsecurity=False)
109
110
if typing.TYPE_CHECKING or py39:
    # pep 584 dict union
    dict_union = operator.or_  # noqa
else:

    def dict_union(a: dict, b: dict) -> dict:
        """Return a new dict with ``a``'s items updated by ``b``'s.

        Stand-in for the ``a | b`` operator on Pythons older than 3.9;
        neither argument is modified.
        """
        merged = a.copy()
        merged.update(b)
        return merged
120
121
if py310:
    anext_ = anext
else:
    from collections.abc import AsyncIterator

    # sentinel distinguishing "no default given" from an explicit None
    _NOT_PROVIDED = object()

    async def anext_(async_iterator, default=_NOT_PROVIDED):
        """Await and return the next item of ``async_iterator``.

        If the iterator is exhausted, return ``default`` when one was
        given, otherwise let ``StopAsyncIteration`` propagate.

        vendored from https://github.com/python/cpython/pull/8895
        """
        if not isinstance(async_iterator, AsyncIterator):
            raise TypeError(
                f"anext expected an AsyncIterator, got {type(async_iterator)}"
            )
        next_method = type(async_iterator).__anext__
        try:
            return await next_method(async_iterator)
        except StopAsyncIteration:
            if default is not _NOT_PROVIDED:
                return default
            raise
142
143
def importlib_metadata_get(group):
    """Return the installed entry points registered under ``group``.

    Uses ``.select(group=...)`` when the object returned by
    ``entry_points()`` offers it; otherwise treats that object as a
    mapping of group name to entry points, returning ``()`` for an
    unknown group.
    """
    entry_points = importlib_metadata.entry_points()
    if not typing.TYPE_CHECKING and not hasattr(entry_points, "select"):
        return entry_points.get(group, ())
    return entry_points.select(group=group)
150
151
def b(s):
    """Encode the string ``s`` to bytes using latin-1."""
    return s.encode(encoding="latin-1")
154
155
def b64decode(x: str) -> bytes:
    """Decode the base64 text ``x`` (ASCII-encoded first) into bytes."""
    raw = x.encode("ascii")
    return base64.b64decode(raw)
158
159
def b64encode(x: bytes) -> str:
    """Base64-encode ``x`` and return the result as an ASCII ``str``."""
    encoded = base64.b64encode(x)
    return encoded.decode("ascii")
162
163
def decode_backslashreplace(text: bytes, encoding: str) -> str:
    """Decode ``text`` with ``encoding``, never raising on bad input.

    Bytes that cannot be decoded are rendered as backslash escapes
    (``\\xNN``) instead of raising ``UnicodeDecodeError``.
    """
    return text.decode(encoding=encoding, errors="backslashreplace")
166
167
def cmp(a, b):
    """Three-way comparison in the style of Python 2's ``cmp()``.

    For operands whose comparisons return bools, the result is ``1``
    when ``a > b``, ``-1`` when ``a < b`` and ``0`` otherwise.
    """
    greater = a > b
    less = a < b
    return greater - less
170
171
172def _formatannotation(annotation, base_module=None):
173 """vendored from python 3.7"""
174
175 if isinstance(annotation, str):
176 return annotation
177
178 if getattr(annotation, "__module__", None) == "typing":
179 return repr(annotation).replace("typing.", "").replace("~", "")
180 if isinstance(annotation, type):
181 if annotation.__module__ in ("builtins", base_module):
182 return repr(annotation.__qualname__)
183 return annotation.__module__ + "." + annotation.__qualname__
184 elif isinstance(annotation, typing.TypeVar):
185 return repr(annotation).replace("~", "")
186 return repr(annotation).replace("~", "")
187
188
def inspect_formatargspec(
    args: List[str],
    varargs: Optional[str] = None,
    varkw: Optional[str] = None,
    defaults: Optional[Sequence[Any]] = None,
    kwonlyargs: Optional[Sequence[str]] = (),
    kwonlydefaults: Optional[Mapping[str, Any]] = None,
    annotations: Optional[Mapping[str, Any]] = None,
    formatarg: Callable[[str], str] = str,
    formatvarargs: Callable[[str], str] = lambda name: "*" + name,
    formatvarkw: Callable[[str], str] = lambda name: "**" + name,
    formatvalue: Callable[[Any], str] = lambda value: "=" + repr(value),
    formatreturns: Callable[[Any], str] = lambda text: " -> " + str(text),
    formatannotation: Callable[[Any], str] = _formatannotation,
) -> str:
    """Format an argument specification as a parenthesized signature string.

    Copy of ``formatargspec`` from the Python 3.7 standard library.  The
    standard library deprecated that function (and later removed it) in
    favor of :class:`inspect.Signature`.  Rebuilding this logic on top of
    ``Parameter`` objects would add object-creation overhead for no gain,
    so the compatibility routine is kept instead.

    :param args: positional parameter names, in order.
    :param varargs: name of the ``*args`` parameter, if any.
    :param varkw: name of the ``**kwargs`` parameter, if any.
    :param defaults: default values for the *last* ``len(defaults)``
      entries of ``args``.
    :param kwonlyargs: keyword-only parameter names, in order.
    :param kwonlydefaults: defaults for keyword-only parameters, by name;
      ``None`` means no defaults.
    :param annotations: annotations by parameter name; a ``"return"`` key,
      if present, is rendered after the closing parenthesis.  ``None``
      means no annotations.
    :param formatarg: renders a parameter name.
    :param formatvarargs: renders the ``*args`` entry.
    :param formatvarkw: renders the ``**kwargs`` entry.
    :param formatvalue: renders a default value (including the ``=``).
    :param formatreturns: renders the return annotation suffix.
    :param formatannotation: renders a single annotation object.
    :return: a string such as ``"(a, b=1, *args, c, **kw)"``.
    """

    # None and an empty mapping mean the same thing.  Defaulting to None
    # rather than ``{}`` avoids a mutable default shared across calls.
    kwonlydefaults = kwonlydefaults or {}
    annotations = annotations or {}

    def formatargandannotation(arg):
        # "name" or "name: <annotation>"
        result = formatarg(arg)
        if arg in annotations:
            result += ": " + formatannotation(annotations[arg])
        return result

    specs = []

    # index into ``args`` of the first parameter that has a default
    if defaults:
        firstdefault = len(args) - len(defaults)
    else:
        firstdefault = -1

    for i, arg in enumerate(args):
        spec = formatargandannotation(arg)
        if defaults and i >= firstdefault:
            spec = spec + formatvalue(defaults[i - firstdefault])
        specs.append(spec)

    if varargs is not None:
        specs.append(formatvarargs(formatargandannotation(varargs)))
    elif kwonlyargs:
        # Without *args, a bare "*" is needed to introduce the
        # keyword-only parameters.
        specs.append("*")

    for kwonlyarg in kwonlyargs or ():
        spec = formatargandannotation(kwonlyarg)
        if kwonlyarg in kwonlydefaults:
            spec += formatvalue(kwonlydefaults[kwonlyarg])
        specs.append(spec)

    if varkw is not None:
        specs.append(formatvarkw(formatargandannotation(varkw)))

    result = "(" + ", ".join(specs) + ")"
    if "return" in annotations:
        result += formatreturns(formatannotation(annotations["return"]))
    return result
259
260
def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
    """Return the ``dataclasses.Field`` objects of an already-processed
    dataclass, including those inherited from superclasses.

    The class must **already be a dataclass** for Field objects to be
    returned; for anything else an empty list is returned rather than
    raising.
    """
    if not dataclasses.is_dataclass(cls):
        return []
    return dataclasses.fields(cls)
273
274
def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
    """Return the ``dataclasses.Field`` objects of an already-processed
    dataclass, leaving out any Field object that is also reported for
    one of its direct base classes.

    The class must **already be a dataclass** for Field objects to be
    returned; for anything else an empty list is returned.
    """
    if not dataclasses.is_dataclass(cls):
        return []

    # NOTE(review): relies on inherited fields being the very same Field
    # objects that the bases report; a field redefined on ``cls`` is a
    # distinct object and therefore counts as local.
    inherited: Set[dataclasses.Field[Any]] = {
        fld for base in cls.__bases__ for fld in dataclass_fields(base)
    }
    return [fld for fld in dataclasses.fields(cls) if fld not in inherited]