Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/utils.py: 38%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import datetime
2import inspect
3import logging
4import textwrap
5import warnings
6from collections.abc import Callable
7from contextlib import contextmanager
8from functools import wraps
9from importlib import metadata
10from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, TypeVar, Union
12from redis.exceptions import DataError
13from redis.typing import AbsExpiryT, EncodableT, ExpiryT
DEFAULT_RESP_VERSION = 3

if TYPE_CHECKING:
    from redis.client import Redis

# Optional dependencies are probed once, at import time. Each *_AVAILABLE flag
# records whether the corresponding module could be imported (and, for
# hiredis, whether the installed release is recent enough).
try:
    import hiredis  # noqa

    # Only support hiredis >= 3.2.0 (older releases are reported as missing).
    hiredis_version = hiredis.__version__.split(".")
    _hiredis_major = int(hiredis_version[0])
    # Treat a missing minor component as 0 so a terse version string such as
    # "4" cannot raise IndexError and break importing this module.
    _hiredis_minor = int(hiredis_version[1]) if len(hiredis_version) > 1 else 0
    HIREDIS_AVAILABLE = _hiredis_major > 3 or (
        _hiredis_major == 3 and _hiredis_minor >= 2
    )
    if not HIREDIS_AVAILABLE:
        raise ImportError("hiredis package should be >= 3.2.0")
except (ImportError, ValueError):
    # ImportError: hiredis is absent or too old (raised just above).
    # ValueError: a major/minor component is not a plain integer
    # (e.g. "3.2rc1"), so the version cannot be checked; run without hiredis.
    HIREDIS_AVAILABLE = False

try:
    import ssl  # noqa

    SSL_AVAILABLE = True
except ImportError:
    SSL_AVAILABLE = False

try:
    import cryptography  # noqa

    CRYPTOGRAPHY_AVAILABLE = True
except ImportError:
    CRYPTOGRAPHY_AVAILABLE = False
def from_url(url: str, **kwargs: Any) -> "Redis":
    """
    Build a Redis client from a connection URL.

    Thin convenience wrapper around ``Redis.from_url``: ``url`` and every
    keyword argument are forwarded unchanged. The database id is taken from
    the URL path fragment when none is provided.
    """
    # Imported at call time rather than module level (presumably to avoid a
    # circular import with redis.client).
    from redis.client import Redis as _Redis

    client = _Redis.from_url(url, **kwargs)
    return client
@contextmanager
def pipeline(redis_obj):
    """
    Context manager that hands out ``redis_obj.pipeline()`` and executes it.

    The queued commands are executed only when the ``with`` body finishes
    without raising; if the body raises, the exception propagates and
    ``execute()`` is never called.
    """
    pipe = redis_obj.pipeline()
    yield pipe
    # Reached only on a clean exit from the with-body.
    pipe.execute()
def str_if_bytes(value: Union[str, bytes]) -> str:
    """
    Decode ``bytes`` as UTF-8, replacing invalid sequences; return anything
    else unchanged.
    """
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return value
def safe_str(value):
    """Return ``str(value)``, first decoding ``bytes`` as UTF-8 (invalid bytes replaced)."""
    text = str_if_bytes(value)
    return str(text)
def decode_field_value(value, key=None, field_encodings=None):
    """Decode a field value respecting optional per-field encoding overrides.

    - Non-``bytes`` values are returned untouched.
    - If *field_encodings* is non-empty and contains *key*, that entry decides:
      ``None`` keeps the raw bytes, any other value is used as the codec
      (undecodable bytes are replaced).
    - Otherwise the value is decoded by :func:`str_if_bytes`.
    """
    if not isinstance(value, bytes):
        return value

    has_override = (
        bool(field_encodings) and key is not None and key in field_encodings
    )
    if not has_override:
        return str_if_bytes(value)

    encoding = field_encodings[key]
    return value if encoding is None else value.decode(encoding, "replace")
def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]:
    """
    Combine any number of mappings into one new dict.

    Mappings are applied left to right, so later ones win on duplicate keys.
    None of the inputs is modified.

    *dicts : `Mapping`
        mappings to combine
    """
    combined: Dict[str, Any] = dict()
    for mapping in dicts:
        combined.update(mapping)
    return combined
def list_keys_to_dict(key_list, callback):
    """Map every key in ``key_list`` to the same ``callback`` object."""
    return {key: callback for key in key_list}
def merge_result(command, res):
    """
    Flatten the values of ``res`` into a single de-duplicated list.

    This command is used when sending a command to multiple nodes and the
    result from each node should be merged into a single list. ``command`` is
    accepted for interface compatibility but not used. Items pass through a
    set, so duplicates are dropped, items must be hashable, and the output
    order is not guaranteed.

    res : 'dict'
        values are iterables of reply items (keys presumably identify nodes)
    """
    unique_items = {item for node_reply in res.values() for item in node_reply}
    return list(unique_items)
def warn_deprecated(name, reason="", version="", stacklevel=2):
    """
    Emit a ``DeprecationWarning`` saying that ``name`` is deprecated.

    ``reason`` and ``version`` are appended to the message when non-empty;
    ``stacklevel`` is passed straight to :func:`warnings.warn`.
    """
    parts = [f"Call to deprecated {name}."]
    if reason:
        parts.append(f" ({reason})")
    if version:
        parts.append(f" -- Deprecated since version {version}.")
    warnings.warn("".join(parts), category=DeprecationWarning, stacklevel=stacklevel)
def deprecated_function(reason="", version="", name=None):
    """
    Decorator factory that marks a function as deprecated.

    Every call of the decorated function first emits a deprecation warning via
    :func:`warn_deprecated` and then delegates to the original. Coroutine
    functions receive an ``async`` wrapper so they remain awaitable.

    :param reason: optional explanation included in the warning.
    :param version: optional version since which the function is deprecated.
    :param name: name shown in the warning; defaults to ``func.__name__``.
    """

    def decorator(func):
        if not inspect.iscoroutinefunction(func):

            @wraps(func)
            def wrapper(*args, **kwargs):
                # stacklevel=3: warn_deprecated (1) -> wrapper (2) -> caller (3).
                warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
                return func(*args, **kwargs)

            return wrapper

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator
def warn_deprecated_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    reason: str = "",
    version: str = "",
    stacklevel: int = 2,
):
    """
    Emit a ``DeprecationWarning`` about deprecated argument(s) of a function.

    :param arg_name: the offending argument name, or a list of names.
    :param function_name: name of the function that received the argument(s).
    :param reason: optional explanation appended in parentheses.
    :param version: optional version since which the usage is deprecated.
    :param stacklevel: forwarded to :func:`warnings.warn`.
    """
    message = (
        f"Call to '{function_name}' function with deprecated"
        f" usage of input argument/s '{arg_name}'."
    )
    message += f" ({reason})" if reason else ""
    message += f" -- Deprecated since version {version}." if version else ""
    warnings.warn(message, category=DeprecationWarning, stacklevel=stacklevel)
187C = TypeVar("C", bound=Callable)
190def _get_filterable_args(
191 func: Callable, args: tuple, kwargs: dict, allowed_args: Optional[List[str]] = None
192) -> dict:
193 """
194 Extract arguments from function call that should be checked for deprecation/experimental warnings.
195 Excludes 'self' and any explicitly allowed args.
196 """
197 arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
198 filterable_args = dict(zip(arg_names, args))
199 filterable_args.update(kwargs)
200 filterable_args.pop("self", None)
201 if allowed_args:
202 for allowed_arg in allowed_args:
203 filterable_args.pop(allowed_arg, None)
204 return filterable_args
def deprecated_args(
    args_to_warn: Optional[List[str]] = None,
    allowed_args: Optional[List[str]] = None,
    reason: str = "",
    version: str = "",
) -> Callable[[C], C]:
    """
    Decorator to mark specified args of a function as deprecated.
    If '*' is in args_to_warn, all arguments will be marked as deprecated.

    On every call the passed arguments are mapped to parameter names (with
    ``self`` and ``allowed_args`` removed) and a ``DeprecationWarning`` is
    emitted through :func:`warn_deprecated_arg_usage` for each match. The
    warning is attributed to the code that called the decorated function.

    :param args_to_warn: argument names that trigger a warning when passed;
        ``"*"`` matches any remaining argument. Defaults to ``["*"]``.
    :param allowed_args: argument names that never trigger a warning.
    :param reason: optional explanation appended to the warning message.
    :param version: optional version since which the arguments are deprecated.
    """
    if args_to_warn is None:
        args_to_warn = ["*"]
    if allowed_args is None:
        allowed_args = []

    # Frames seen by warnings.warn: warn_deprecated_arg_usage (1) ->
    # _check_deprecated_args (2) -> wrapper/async_wrapper (3) -> the caller of
    # the decorated function (4). A value of 5 would skip past the real call
    # site; experimental_args uses the same depth of 4.
    warn_stacklevel = 4

    def _check_deprecated_args(func, filterable_args):
        """Check and warn about deprecated arguments."""
        for arg in args_to_warn:
            if arg == "*" and len(filterable_args) > 0:
                warn_deprecated_arg_usage(
                    list(filterable_args.keys()),
                    func.__name__,
                    reason,
                    version,
                    stacklevel=warn_stacklevel,
                )
            elif arg in filterable_args:
                warn_deprecated_arg_usage(
                    arg, func.__name__, reason, version, stacklevel=warn_stacklevel
                )

    def decorator(func: C) -> C:
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
                _check_deprecated_args(func, filterable_args)
                return await func(*args, **kwargs)

            return async_wrapper
        else:

            @wraps(func)
            def wrapper(*args, **kwargs):
                filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
                _check_deprecated_args(func, filterable_args)
                return func(*args, **kwargs)

            return wrapper

    return decorator
261def _set_info_logger():
262 """
263 Set up a logger that log info logs to stdout.
264 (This is used by the default push response handler)
265 """
266 if "push_response" not in logging.root.manager.loggerDict.keys():
267 logger = logging.getLogger("push_response")
268 logger.setLevel(logging.INFO)
269 handler = logging.StreamHandler()
270 handler.setLevel(logging.INFO)
271 logger.addHandler(handler)
def check_protocol_version(
    protocol: Optional[Union[str, int]], expected_version: int = 3
) -> bool:
    """
    Return True when ``protocol`` denotes ``expected_version``.

    Strings are converted with ``int()`` first; ``None`` and strings that are
    not valid integers yield False.
    """
    if isinstance(protocol, str):
        try:
            protocol = int(protocol)
        except ValueError:
            return False
    return protocol is not None and protocol == expected_version
def get_lib_version():
    """
    Return the installed ``redis`` distribution version string.

    Falls back to the placeholder ``"99.99.99"`` when package metadata for
    ``redis`` cannot be found.
    """
    try:
        return metadata.version("redis")
    except metadata.PackageNotFoundError:
        return "99.99.99"
def format_error_message(host_error: str, exception: BaseException) -> str:
    """
    Build a human-readable connection error message for ``host_error``.

    Uses up to the first two entries of ``exception.args``: the first as an
    error code/description, the second as additional detail.
    """
    error_args = exception.args
    if len(error_args) >= 2:
        return f"Error {error_args[0]} connecting to {host_error}. {error_args[1]}."
    if len(error_args) == 1:
        return f"Error {error_args[0]} connecting to {host_error}."
    return f"Error connecting to {host_error}."
def compare_versions(version1: str, version2: str) -> int:
    """
    Compare two dotted numeric version strings.

    Missing trailing components count as 0, so "1.0" equals "1.0.0". Every
    component must be an integer, otherwise ``int()`` raises ValueError.

    :return: -1 if version1 > version2
             0 if both versions are equal
             1 if version1 < version2
             (note: the reverse of the usual cmp() sign convention)
    """
    parts1 = [int(piece) for piece in version1.split(".")]
    parts2 = [int(piece) for piece in version2.split(".")]

    width = max(len(parts1), len(parts2))
    parts1 += [0] * (width - len(parts1))
    parts2 += [0] * (width - len(parts2))

    for left, right in zip(parts1, parts2):
        if left != right:
            return -1 if left > right else 1
    return 0
def ensure_string(key):
    """
    Return ``key`` as ``str``: strings pass through, bytes are decoded as
    strict UTF-8 (UnicodeDecodeError on invalid data).

    :raises TypeError: if ``key`` is neither ``str`` nor ``bytes``.
    """
    if isinstance(key, str):
        return key
    if isinstance(key, bytes):
        return key.decode("utf-8")
    raise TypeError("Key must be either a string or bytes")
def _relative_expiry_to_int(value, units_per_second, name):
    """Convert an EX/PX value to an int in the option's unit, or raise DataError."""
    if isinstance(value, datetime.timedelta):
        return int(value.total_seconds() * units_per_second)
    if isinstance(value, int):
        return value
    # Numeric strings are accepted as well. isdecimal() (unlike isdigit())
    # only matches characters int() can parse, e.g. it rejects "\u00b2", so
    # such input ends in DataError instead of a stray ValueError.
    if isinstance(value, str) and value.isdecimal():
        return int(value)
    raise DataError(f"{name} must be datetime.timedelta or int")


def extract_expire_flags(
    ex: Optional[ExpiryT] = None,
    px: Optional[ExpiryT] = None,
    exat: Optional[AbsExpiryT] = None,
    pxat: Optional[AbsExpiryT] = None,
) -> List[EncodableT]:
    """
    Build the expiration option tokens (``EX``/``PX``/``EXAT``/``PXAT``).

    Only the first non-None option is used, checked in the order ``ex``,
    ``px``, ``exat``, ``pxat``; the others are ignored.

    :param ex: relative expiry in seconds: int, decimal-digit string or
        ``datetime.timedelta`` (fractional seconds are truncated).
    :param px: relative expiry in milliseconds: int, decimal-digit string or
        ``datetime.timedelta``.
    :param exat: absolute expiry in Unix seconds, or a ``datetime.datetime``
        converted with ``timestamp()`` (naive datetimes use local time).
    :param pxat: like ``exat`` but in milliseconds.
    :return: e.g. ``["EX", 10]``; an empty list when no option is given.
    :raises DataError: if ``ex`` or ``px`` has an unsupported type/value.
    """
    exp_options: list[EncodableT] = []
    if ex is not None:
        exp_options.extend(["EX", _relative_expiry_to_int(ex, 1, "ex")])
    elif px is not None:
        exp_options.extend(["PX", _relative_expiry_to_int(px, 1000, "px")])
    elif exat is not None:
        if isinstance(exat, datetime.datetime):
            exat = int(exat.timestamp())
        exp_options.extend(["EXAT", exat])
    elif pxat is not None:
        if isinstance(pxat, datetime.datetime):
            pxat = int(pxat.timestamp() * 1000)
        exp_options.extend(["PXAT", pxat])

    return exp_options
def truncate_text(txt, max_length=100):
    """
    Shorten ``txt`` to fit in ``max_length`` characters via ``textwrap.shorten``.

    Runs of whitespace are collapsed to single spaces; if the text is still
    too long, trailing words are dropped and ``"..."`` is appended.
    """
    return textwrap.shorten(txt, max_length, placeholder="...", break_long_words=True)
def dummy_fail(*args, **kwargs):
    """
    Fake function for a Retry object if you don't need to handle each failure.

    Accepts and ignores any positional/keyword arguments (e.g. the error that
    triggered a retry), so it can be passed directly as a failure callback as
    well as called with no arguments. Always returns None.
    """
    return None
async def dummy_fail_async(*args, **kwargs):
    """
    Async fake function for a Retry object if you don't need to handle each failure.

    Accepts and ignores any positional/keyword arguments (e.g. the error that
    triggered a retry), so it can be passed directly as an async failure
    callback as well as awaited with no arguments. Always returns None.
    """
    return None
def experimental(cls):
    """
    Class decorator that marks ``cls`` as experimental.

    ``cls.__init__`` is replaced in place by a wrapper that emits a
    ``UserWarning`` on every instantiation and then runs the original
    initializer. The same class object is returned.
    """
    wrapped_init = cls.__init__

    def _warning_init(self, *args, **kwargs):
        # stacklevel=2 attributes the warning to the code constructing cls.
        warnings.warn(
            f"{cls.__name__} is an experimental and may change or be removed in future versions.",
            category=UserWarning,
            stacklevel=2,
        )
        wrapped_init(self, *args, **kwargs)

    cls.__init__ = wraps(wrapped_init)(_warning_init)
    return cls
def warn_experimental(name, stacklevel=2):
    """
    Emit a ``UserWarning`` stating that method ``name`` is experimental.

    ``stacklevel`` is forwarded to :func:`warnings.warn`.
    """
    warnings.warn(
        "Call to experimental method "
        f"{name}. Be aware that the function arguments can "
        "change or be removed in future versions.",
        category=UserWarning,
        stacklevel=stacklevel,
    )
def experimental_method() -> Callable[[C], C]:
    """
    Decorator to mark a function as experimental.

    Each call of the decorated function (sync or async) first emits a warning
    via :func:`warn_experimental` and then delegates to the original. The
    warning is attributed to the code that called the decorated function.
    """

    def decorator(func: C) -> C:
        if inspect.iscoroutinefunction(func):
            # Create async wrapper for async functions
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                # stacklevel=3: warn_experimental (1) -> this wrapper (2) ->
                # caller (3). The previous value of 2 pointed at this wrapper.
                warn_experimental(func.__name__, stacklevel=3)
                return await func(*args, **kwargs)

            return async_wrapper
        else:
            # Create regular wrapper for sync functions
            @wraps(func)
            def wrapper(*args, **kwargs):
                # Same frame layout as async_wrapper: report at the caller.
                warn_experimental(func.__name__, stacklevel=3)
                return func(*args, **kwargs)

            return wrapper

    return decorator
def warn_experimental_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    stacklevel: int = 2,
):
    """
    Emit a ``UserWarning`` about experimental argument(s) of a method.

    :param arg_name: the argument name, or a list of names.
    :param function_name: name of the method that received the argument(s).
    :param stacklevel: forwarded to :func:`warnings.warn`.
    """
    text = "Call to '{}' method with experimental usage of input argument/s '{}'.".format(
        function_name, arg_name
    )
    warnings.warn(text, category=UserWarning, stacklevel=stacklevel)
def experimental_args(
    args_to_warn: Optional[List[str]] = None,
) -> Callable[[C], C]:
    """
    Decorator to mark specified args of a function as experimental.
    If '*' is in args_to_warn, all arguments will be marked as experimental.

    On every call the passed arguments are mapped to parameter names (minus
    ``self``) and :func:`warn_experimental_arg_usage` is called for each
    match. Defaults to ``["*"]`` when ``args_to_warn`` is None.
    """
    targets = ["*"] if args_to_warn is None else args_to_warn

    def _check_experimental_args(func, filterable_args):
        """Emit one warning per entry of ``targets`` that matches the call."""
        # stacklevel=4: warn_experimental_arg_usage -> this helper ->
        # wrapper -> caller of the decorated function.
        for arg in targets:
            if arg == "*":
                if filterable_args:
                    warn_experimental_arg_usage(
                        list(filterable_args.keys()), func.__name__, stacklevel=4
                    )
            elif arg in filterable_args:
                warn_experimental_arg_usage(arg, func.__name__, stacklevel=4)

    def decorator(func: C) -> C:
        if not inspect.iscoroutinefunction(func):

            @wraps(func)
            def wrapper(*args, **kwargs):
                _check_experimental_args(func, _get_filterable_args(func, args, kwargs))
                return func(*args, **kwargs)

            return wrapper

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            _check_experimental_args(func, _get_filterable_args(func, args, kwargs))
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator