Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/utils.py: 38%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import datetime
2import inspect
3import logging
4import textwrap
5import warnings
6from collections.abc import Callable
7from contextlib import contextmanager
8from functools import wraps
9from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, TypeVar, Union
11from redis.exceptions import DataError
12from redis.typing import AbsExpiryT, EncodableT, ExpiryT
14if TYPE_CHECKING:
15 from redis.client import Redis
try:
    import hiredis  # noqa

    # Only support hiredis >= 3.2.0. Older or unparseable versions take the
    # ImportError path below and are reported as unavailable.
    hiredis_version = hiredis.__version__.split(".")
    try:
        _hiredis_major_minor = (int(hiredis_version[0]), int(hiredis_version[1]))
    except (IndexError, ValueError):
        # e.g. "3" or "3.2rc1": treat as unsupported rather than letting the
        # parse error abort the import of this module.
        raise ImportError(
            f"unrecognised hiredis version {hiredis.__version__!r}"
        ) from None
    # Tuple comparison == (major > 3) or (major == 3 and minor >= 2).
    HIREDIS_AVAILABLE = _hiredis_major_minor >= (3, 2)
    if not HIREDIS_AVAILABLE:
        raise ImportError("hiredis package should be >= 3.2.0")
except ImportError:
    HIREDIS_AVAILABLE = False
# Capability flags for optional modules, probed once when this module loads.
try:
    import ssl  # noqa
except ImportError:
    SSL_AVAILABLE = False
else:
    SSL_AVAILABLE = True

try:
    import cryptography  # noqa
except ImportError:
    CRYPTOGRAPHY_AVAILABLE = False
else:
    CRYPTOGRAPHY_AVAILABLE = True
44from importlib import metadata
# Module-wide "argument not supplied" marker, for parameters where None is a
# meaningful explicit value. Import this object from redis.utils instead of
# defining local sentinels, and compare against it by identity only
# (`is` / `is not`).
SENTINEL: object = object()
def from_url(url: str, **kwargs: Any) -> "Redis":
    """
    Create a Redis client from a connection URL.

    Thin wrapper around ``Redis.from_url``: *url* and all keyword arguments are
    forwarded unchanged. If the URL does not name a database id explicitly,
    an attempt is made to extract it from the URL path.
    """
    # Deferred import; at module level Redis is only imported under
    # TYPE_CHECKING, presumably to avoid a circular import.
    from redis.client import Redis as _RedisClient

    return _RedisClient.from_url(url, **kwargs)
@contextmanager
def pipeline(redis_obj):
    """
    Yield ``redis_obj.pipeline()`` and call ``execute()`` on it afterwards.

    ``execute()`` runs only when the ``with`` body completes normally. If the
    body raises, the exception propagates out of the ``yield`` and
    ``execute()`` is never called.
    """
    pipe = redis_obj.pipeline()
    yield pipe
    # Reached only after a clean exit from the with-body.
    pipe.execute()
def str_if_bytes(value: Union[str, bytes]) -> str:
    """
    Decode *value* as UTF-8 when it is ``bytes``; return anything else as is.

    Invalid UTF-8 sequences are replaced with U+FFFD instead of raising.
    """
    if not isinstance(value, bytes):
        return value
    return value.decode("utf-8", errors="replace")
def safe_str(value):
    """Return ``str(value)``, first decoding ``bytes`` via :func:`str_if_bytes`."""
    decoded = str_if_bytes(value)
    return str(decoded)
def decode_field_value(value, key=None, field_encodings=None):
    """Decode a field value, honouring optional per-field encoding overrides.

    :param value: the raw value; anything that is not ``bytes`` is returned
        unchanged.
    :param key: field name used to look up an override in *field_encodings*.
    :param field_encodings: optional mapping of field name -> encoding. An
        encoding of ``None`` means "keep the raw bytes".
    :return: for an overridden field, the bytes decoded with that encoding
        (errors replaced) or the raw bytes; otherwise :func:`str_if_bytes`.
    """
    if isinstance(value, bytes):
        has_override = (
            bool(field_encodings) and key is not None and key in field_encodings
        )
        if has_override:
            encoding = field_encodings[key]
            return value if encoding is None else value.decode(encoding, "replace")
        return str_if_bytes(value)
    return value
def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]:
    """
    Combine the given mappings into one new dict.

    Mappings are applied from left to right, so for duplicate keys the value
    from the right-most mapping wins. The inputs are left unmodified.
    """
    combined: Dict[str, Any] = {}
    for mapping in dicts:
        combined.update(mapping)
    return combined
def list_keys_to_dict(key_list, callback):
    """Return a dict mapping every key in *key_list* to the same *callback*."""
    return {key: callback for key in key_list}
def merge_result(command, res):
    """
    Merge the per-node results in *res* into a single de-duplicated list.

    Used when a command is sent to multiple nodes and the result from each
    node should be merged into a single list.

    :param command: the command name; not used by the merge itself.
    :param res: ``dict`` whose values are iterables of hashable items.
    :return: list of distinct items in the order they were first seen.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order, so the
    # output is deterministic (a set's iteration order depends on hash
    # randomisation and can change between interpreter runs).
    merged = dict.fromkeys(item for node_reply in res.values() for item in node_reply)
    return list(merged)
def warn_deprecated(name, reason="", version="", stacklevel=2):
    """
    Emit a ``DeprecationWarning`` stating that *name* is deprecated.

    *reason* and *version* are appended to the message only when non-empty;
    *stacklevel* is passed straight to :func:`warnings.warn`.
    """
    pieces = [f"Call to deprecated {name}."]
    if reason:
        pieces.append(f" ({reason})")
    if version:
        pieces.append(f" -- Deprecated since version {version}.")
    warnings.warn(
        "".join(pieces), category=DeprecationWarning, stacklevel=stacklevel
    )
def deprecated_function(reason="", version="", name=None):
    """
    Decorator factory that marks a sync or ``async`` function as deprecated.

    Every call of the decorated function emits a warning through
    :func:`warn_deprecated`, naming *name* (or the function's ``__name__``
    when *name* is not given) plus the optional *reason* and *version*.
    """

    def decorator(func):
        if not inspect.iscoroutinefunction(func):

            @wraps(func)
            def wrapper(*args, **kwargs):
                # stacklevel=3 skips warn_deprecated and this wrapper so the
                # warning is attributed to the caller.
                warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
                return func(*args, **kwargs)

            return wrapper

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator
def warn_deprecated_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    reason: str = "",
    version: str = "",
    stacklevel: int = 2,
):
    """
    Emit a ``DeprecationWarning`` for a call that used deprecated argument(s).

    *arg_name* (one name or a list of names) is interpolated into the message
    verbatim. *reason* and *version* are appended only when non-empty, and
    *stacklevel* is passed to :func:`warnings.warn`.
    """
    text = (
        f"Call to '{function_name}' function with deprecated"
        f" usage of input argument/s '{arg_name}'."
    )
    text += f" ({reason})" if reason else ""
    text += f" -- Deprecated since version {version}." if version else ""
    warnings.warn(text, category=DeprecationWarning, stacklevel=stacklevel)
191C = TypeVar("C", bound=Callable)
194def _get_filterable_args(
195 func: Callable, args: tuple, kwargs: dict, allowed_args: Optional[List[str]] = None
196) -> dict:
197 """
198 Extract arguments from function call that should be checked for deprecation/experimental warnings.
199 Excludes 'self' and any explicitly allowed args.
200 """
201 arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
202 filterable_args = dict(zip(arg_names, args))
203 filterable_args.update(kwargs)
204 filterable_args.pop("self", None)
205 if allowed_args:
206 for allowed_arg in allowed_args:
207 filterable_args.pop(allowed_arg, None)
208 return filterable_args
def deprecated_args(
    args_to_warn: Optional[List[str]] = None,
    allowed_args: Optional[List[str]] = None,
    reason: str = "",
    version: str = "",
) -> Callable[[C], C]:
    """
    Decorator to mark specified args of a function as deprecated.

    :param args_to_warn: names of deprecated arguments. If it contains
        ``"*"`` (the default), a warning is emitted whenever any argument other
        than ``self`` and *allowed_args* is passed.
    :param allowed_args: argument names that never trigger a warning.
    :param reason: optional explanation appended to the warning message.
    :param version: optional version since which the arguments are deprecated.

    Works for both sync and ``async`` functions. Warnings are emitted through
    :func:`warn_deprecated_arg_usage` on every offending call.
    """
    if args_to_warn is None:
        args_to_warn = ["*"]
    if allowed_args is None:
        allowed_args = []

    # Frames between warnings.warn() and user code are:
    #   warn_deprecated_arg_usage (1) -> _check_deprecated_args (2)
    #   -> wrapper / async_wrapper (3) -> caller (4).
    # Level 4 attributes the warning to the caller. The previous value of 5
    # pointed one frame past it.
    warn_stacklevel = 4

    def _check_deprecated_args(func, filterable_args):
        """Check and warn about deprecated arguments."""
        for arg in args_to_warn:
            if arg == "*" and len(filterable_args) > 0:
                warn_deprecated_arg_usage(
                    list(filterable_args.keys()),
                    func.__name__,
                    reason,
                    version,
                    stacklevel=warn_stacklevel,
                )
            elif arg in filterable_args:
                warn_deprecated_arg_usage(
                    arg, func.__name__, reason, version, stacklevel=warn_stacklevel
                )

    def decorator(func: C) -> C:
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
                _check_deprecated_args(func, filterable_args)
                return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def wrapper(*args, **kwargs):
            filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
            _check_deprecated_args(func, filterable_args)
            return func(*args, **kwargs)

        return wrapper

    return decorator
265def _set_info_logger():
266 """
267 Set up a logger that log info logs to stdout.
268 (This is used by the default push response handler)
269 """
270 if "push_response" not in logging.root.manager.loggerDict.keys():
271 logger = logging.getLogger("push_response")
272 logger.setLevel(logging.INFO)
273 handler = logging.StreamHandler()
274 handler.setLevel(logging.INFO)
275 logger.addHandler(handler)
#: RESP protocol version spoken on the wire when the caller gives no explicit
#: ``protocol`` to the client / connection / pool. It lives in ``redis.utils``
#: so that ``redis.connection`` (HELLO handshake) and
#: ``check_protocol_version`` (protocol-gated features) can both read it
#: without a circular import.
DEFAULT_RESP_VERSION: int = 3


def check_protocol_version(
    protocol: Optional[Union[str, int]], expected_version: int = 3
) -> bool:
    """
    Return True if *protocol* designates RESP version *expected_version*.

    ``None`` is treated as :data:`DEFAULT_RESP_VERSION`. Strings are converted
    with ``int()``; a string that is not a valid integer gives False rather
    than raising.
    """
    effective = DEFAULT_RESP_VERSION if protocol is None else protocol
    if not isinstance(effective, str):
        return effective == expected_version
    try:
        parsed = int(effective)
    except ValueError:
        return False
    return parsed == expected_version
def get_lib_version():
    """
    Return the version string of the installed ``redis`` distribution.

    Falls back to the placeholder ``"99.99.99"`` when the package metadata
    cannot be found.
    """
    try:
        return metadata.version("redis")
    except metadata.PackageNotFoundError:
        return "99.99.99"
def format_error_message(host_error: str, exception: BaseException) -> str:
    """
    Build a readable "error connecting to <host>" message.

    Only the first two entries of ``exception.args`` are used. The first is
    placed before "connecting to", the second is appended as extra detail,
    and any further entries are ignored.
    """
    details = exception.args
    if not details:
        return f"Error connecting to {host_error}."
    message = f"Error {details[0]} connecting to {host_error}."
    if len(details) == 1:
        return message
    return f"{message} {details[1]}."
def compare_versions(version1: str, version2: str) -> int:
    """
    Compare two dotted numeric version strings.

    Missing trailing components count as ``0``, so ``"7.2"`` equals
    ``"7.2.0"``. Every component is converted with ``int()``, so a
    non-numeric component raises ``ValueError``.

    :return: -1 if version1 > version2
             0 if both versions are equal
             1 if version1 < version2
    """
    left = [int(part) for part in version1.split(".")]
    right = [int(part) for part in version2.split(".")]

    # Zero-pad the shorter list so both have the same length; equal-length
    # lists then compare component by component.
    width = max(len(left), len(right))
    left.extend([0] * (width - len(left)))
    right.extend([0] * (width - len(right)))

    # Note the inverted convention: a *greater* version1 yields -1.
    if left > right:
        return -1
    if left < right:
        return 1
    return 0
def ensure_string(key):
    """
    Return *key* as ``str``, decoding ``bytes`` as UTF-8.

    :raises TypeError: if *key* is neither ``str`` nor ``bytes``.
    """
    if isinstance(key, str):
        return key
    if isinstance(key, bytes):
        return key.decode("utf-8")
    raise TypeError("Key must be either a string or bytes")
def extract_expire_flags(
    ex: Optional[ExpiryT] = None,
    px: Optional[ExpiryT] = None,
    exat: Optional[AbsExpiryT] = None,
    pxat: Optional[AbsExpiryT] = None,
) -> List[EncodableT]:
    """
    Build the expiry arguments (``EX`` / ``PX`` / ``EXAT`` / ``PXAT``).

    Only the first option that is not ``None`` is used, checked in the order
    ex, px, exat, pxat; the rest are ignored.

    :param ex: relative expiry in seconds, as a ``datetime.timedelta``, an
        ``int`` or a string of digits.
    :param px: relative expiry in milliseconds; same accepted types as *ex*.
    :param exat: absolute expiry as a ``datetime.datetime`` (converted to a
        seconds timestamp); any other value is passed through unchecked.
    :param pxat: absolute expiry as a ``datetime.datetime`` (converted to a
        milliseconds timestamp); any other value is passed through unchecked.
    :return: ``[]`` or a two-element list ``[flag, value]``.
    :raises DataError: if *ex* or *px* has an unsupported type.
    """
    exp_options: list[EncodableT] = []
    if ex is not None:
        exp_options.append("EX")
        if isinstance(ex, datetime.timedelta):
            exp_options.append(int(ex.total_seconds()))
        elif isinstance(ex, int):
            exp_options.append(ex)
        elif isinstance(ex, str) and ex.isdigit():
            exp_options.append(int(ex))
        else:
            raise DataError("ex must be datetime.timedelta or int")
    elif px is not None:
        exp_options.append("PX")
        if isinstance(px, datetime.timedelta):
            # Exact integer arithmetic. int(total_seconds() * 1000) can drop a
            # millisecond to float rounding (1.001 * 1000 == 1000.999...).
            exp_options.append(px // datetime.timedelta(milliseconds=1))
        elif isinstance(px, int):
            exp_options.append(px)
        elif isinstance(px, str) and px.isdigit():
            # Accept digit strings, consistent with ``ex``.
            exp_options.append(int(px))
        else:
            raise DataError("px must be datetime.timedelta or int")
    elif exat is not None:
        if isinstance(exat, datetime.datetime):
            exat = int(exat.timestamp())
        exp_options.extend(["EXAT", exat])
    elif pxat is not None:
        if isinstance(pxat, datetime.datetime):
            # Same float-rounding issue as px with timestamp() * 1000, so use
            # timedelta arithmetic. astimezone() treats a naive datetime as
            # local time, the same assumption timestamp() makes.
            epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
            utc_pxat = pxat.astimezone(datetime.timezone.utc)
            pxat = (utc_pxat - epoch) // datetime.timedelta(milliseconds=1)
        exp_options.extend(["PXAT", pxat])

    return exp_options
def truncate_text(txt, max_length=100):
    """
    Collapse whitespace in *txt* and shorten it to at most *max_length* chars.

    Delegates to :func:`textwrap.shorten`. When text has to be dropped, the
    result ends with ``"..."``; words longer than the width may be broken.
    """
    options = {"width": max_length, "placeholder": "...", "break_long_words": True}
    return textwrap.shorten(txt, **options)
def dummy_fail():
    """
    No-op failure callback for a Retry object, for callers that don't need to
    react to each individual failure.
    """
    return None
async def dummy_fail_async():
    """
    Async no-op failure callback for a Retry object, for callers that don't
    need to react to each individual failure.
    """
    return None
def experimental(cls):
    """
    Class decorator that flags *cls* as experimental.

    The class's ``__init__`` is replaced in place (the class object itself is
    returned) by a wrapper. The wrapper emits a ``UserWarning`` and then runs
    the original initialiser with the same arguments.
    """
    wrapped_init = cls.__init__

    def _warning_init(self, *args, **kwargs):
        # stacklevel=2 attributes the warning to whoever invoked __init__.
        warnings.warn(
            f"{cls.__name__} is an experimental and may change or be removed in future versions.",
            category=UserWarning,
            stacklevel=2,
        )
        wrapped_init(self, *args, **kwargs)

    cls.__init__ = wraps(wrapped_init)(_warning_init)
    return cls
def warn_experimental(name, stacklevel=2):
    """
    Emit a ``UserWarning`` announcing that *name* is experimental.

    *stacklevel* is passed straight to :func:`warnings.warn`.
    """
    warnings.warn(
        f"Call to experimental method {name}. "
        "Be aware that the function arguments can "
        "change or be removed in future versions.",
        category=UserWarning,
        stacklevel=stacklevel,
    )
def experimental_method() -> Callable[[C], C]:
    """
    Decorator to mark a function (sync or ``async``) as experimental.

    Every call of the decorated function emits a ``UserWarning`` through
    :func:`warn_experimental`, attributed to the code that made the call.
    """

    def decorator(func: C) -> C:
        if inspect.iscoroutinefunction(func):
            # Create async wrapper for async functions
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                # stacklevel=3 skips warn_experimental (1) and this wrapper (2)
                # so the warning points at the caller. The previous value of 2
                # pointed at this wrapper inside redis.utils.
                warn_experimental(func.__name__, stacklevel=3)
                return await func(*args, **kwargs)

            return async_wrapper

        # Create regular wrapper for sync functions
        @wraps(func)
        def wrapper(*args, **kwargs):
            # See async_wrapper for the stacklevel rationale.
            warn_experimental(func.__name__, stacklevel=3)
            return func(*args, **kwargs)

        return wrapper

    return decorator
def warn_experimental_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    stacklevel: int = 2,
):
    """
    Emit a ``UserWarning`` for a call that used experimental argument(s).

    *arg_name* (one name or a list of names) is interpolated verbatim, and
    *stacklevel* is passed to :func:`warnings.warn`.
    """
    text = (
        f"Call to '{function_name}' method with experimental"
        f" usage of input argument/s '{arg_name}'."
    )
    warnings.warn(text, category=UserWarning, stacklevel=stacklevel)
def experimental_args(
    args_to_warn: Optional[List[str]] = None,
) -> Callable[[C], C]:
    """
    Decorator factory flagging selected arguments of a function as experimental.

    :param args_to_warn: names of experimental arguments. If it contains
        ``"*"`` (the default), a warning is emitted whenever any argument
        other than ``self`` is passed.

    Works for both sync and ``async`` functions. Warnings are emitted through
    :func:`warn_experimental_arg_usage`.
    """
    if args_to_warn is None:
        args_to_warn = ["*"]

    def _check_experimental_args(func, filterable_args):
        """Emit one warning for each entry of ``args_to_warn`` that matches."""
        # stacklevel=4: warn_experimental_arg_usage -> this helper -> wrapper
        # -> caller.
        for arg in args_to_warn:
            if arg == "*" and len(filterable_args) > 0:
                warn_experimental_arg_usage(
                    list(filterable_args.keys()), func.__name__, stacklevel=4
                )
            elif arg in filterable_args:
                warn_experimental_arg_usage(arg, func.__name__, stacklevel=4)

    def decorator(func: C) -> C:
        if not inspect.iscoroutinefunction(func):

            @wraps(func)
            def wrapper(*args, **kwargs):
                passed = _get_filterable_args(func, args, kwargs)
                if passed:
                    _check_experimental_args(func, passed)
                return func(*args, **kwargs)

            return wrapper

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            passed = _get_filterable_args(func, args, kwargs)
            if passed:
                _check_experimental_args(func, passed)
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator