Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/utils.py: 40%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import datetime
2import inspect
3import logging
4import textwrap
5import warnings
6from collections.abc import Callable
7from contextlib import contextmanager
8from functools import wraps
9from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, TypeVar, Union
11from redis.exceptions import DataError
12from redis.typing import AbsExpiryT, EncodableT, ExpiryT
14if TYPE_CHECKING:
15 from redis.client import Redis
try:
    import hiredis  # noqa

    # Only support Hiredis >= 3.2. The version is parsed defensively: a
    # version string with fewer than two components, or a non-numeric
    # major/minor part (e.g. a pre-release tag such as "3.2rc1"), would
    # otherwise raise IndexError/ValueError here and make importing this
    # module fail outright. Such versions are treated as unsupported instead.
    hiredis_version = hiredis.__version__.split(".")
    try:
        HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or (
            int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2
        )
    except (IndexError, ValueError):
        HIREDIS_AVAILABLE = False
    if not HIREDIS_AVAILABLE:
        raise ImportError("hiredis package should be >= 3.2.0")
except ImportError:
    # hiredis is either not installed or older than the supported version.
    HIREDIS_AVAILABLE = False

try:
    import ssl  # noqa

    SSL_AVAILABLE = True
except ImportError:
    # Python may be built without the ssl module.
    SSL_AVAILABLE = False

try:
    import cryptography  # noqa

    CRYPTOGRAPHY_AVAILABLE = True
except ImportError:
    CRYPTOGRAPHY_AVAILABLE = False
44from importlib import metadata
def from_url(url: str, **kwargs: Any) -> "Redis":
    """
    Return a Redis client built from the given connection URL.

    This is a thin wrapper around ``Redis.from_url``: ``url`` and any extra
    keyword arguments are forwarded unchanged. If the URL's path fragment
    carries a database id, it is extracted; otherwise the default is used.
    """
    # Local import; the module-level import of Redis is guarded by
    # TYPE_CHECKING, presumably to avoid an import cycle with redis.client.
    from redis.client import Redis as _RedisClient

    client = _RedisClient.from_url(url, **kwargs)
    return client
@contextmanager
def pipeline(redis_obj):
    """
    Context manager that yields ``redis_obj.pipeline()`` and executes it.

    ``execute()`` is called once the ``with`` body finishes normally. If the
    body raises, the exception propagates out of the ``yield`` and
    ``execute()`` is never reached, so queued commands are not sent.
    """
    pipe = redis_obj.pipeline()
    yield pipe
    pipe.execute()
def str_if_bytes(value: Union[str, bytes]) -> str:
    """
    Decode ``bytes`` as UTF-8; return any other value unchanged.

    Invalid UTF-8 sequences are replaced with U+FFFD instead of raising.
    """
    if not isinstance(value, bytes):
        return value
    return value.decode("utf-8", errors="replace")
def safe_str(value):
    """
    Return ``value`` as ``str``, decoding ``bytes`` as UTF-8 first.

    Bytes become their decoded text (invalid sequences replaced) rather than
    their ``b'...'`` repr; every other value goes through ``str()``.
    """
    decoded = str_if_bytes(value)
    return str(decoded)
def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]:
    """
    Merge all provided dicts into one new dict.

    *dicts : `dict`
        dictionaries to merge, applied left to right; when a key appears in
        several of them, the value from the last one wins
    """
    combined: Dict[str, Any] = {}
    for source in dicts:
        combined.update(source)
    return combined
def list_keys_to_dict(key_list, callback):
    """Map every key in ``key_list`` to the same ``callback`` object."""
    return {key: callback for key in key_list}
def merge_result(command, res):
    """
    Merge all items in `res` into a single de-duplicated list.

    This is used when a command is sent to multiple nodes and the result
    from each node should be merged into a single list. ``command`` is not
    used by this function. The order of the returned list is unspecified
    because the items are collected in a set.

    res : 'dict'
        mapping whose values are iterables of hashable items
    """
    unique_items = set()
    for node_values in res.values():
        unique_items.update(node_values)
    return list(unique_items)
def warn_deprecated(name, reason="", version="", stacklevel=2):
    """
    Emit a DeprecationWarning saying that ``name`` is deprecated.

    ``reason`` and ``version`` are appended to the message when non-empty.
    ``stacklevel`` is passed straight to ``warnings.warn``.
    """
    pieces = [f"Call to deprecated {name}."]
    if reason:
        pieces.append(f" ({reason})")
    if version:
        pieces.append(f" -- Deprecated since version {version}.")
    warnings.warn("".join(pieces), category=DeprecationWarning, stacklevel=stacklevel)
def deprecated_function(reason="", version="", name=None):
    """
    Decorator to mark a function (sync or async) as deprecated.

    Each call emits a DeprecationWarning via ``warn_deprecated`` and then
    delegates to the wrapped function. ``name`` overrides the function name
    shown in the message.
    """

    def decorator(func):
        # stacklevel=3 skips warn_deprecated and the wrapper so the warning
        # is attributed to the code calling the decorated function.
        if not inspect.iscoroutinefunction(func):

            @wraps(func)
            def wrapper(*args, **kwargs):
                warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
                return func(*args, **kwargs)

            return wrapper

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator
def warn_deprecated_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    reason: str = "",
    version: str = "",
    stacklevel: int = 2,
):
    """
    Emit a DeprecationWarning about deprecated argument(s) of a function.

    ``arg_name`` (a single name or a list of names) is interpolated into the
    message as-is. ``reason`` and ``version`` are appended when non-empty.
    """
    pieces = [
        f"Call to '{function_name}' function with deprecated",
        f" usage of input argument/s '{arg_name}'.",
    ]
    if reason:
        pieces.append(f" ({reason})")
    if version:
        pieces.append(f" -- Deprecated since version {version}.")
    warnings.warn("".join(pieces), category=DeprecationWarning, stacklevel=stacklevel)
# Type variable for the decorator factories in this module: annotating them
# as Callable[[C], C] tells type checkers that the decorator returns a
# callable of the same type it receives.
C = TypeVar("C", bound=Callable)
172def _get_filterable_args(
173 func: Callable, args: tuple, kwargs: dict, allowed_args: Optional[List[str]] = None
174) -> dict:
175 """
176 Extract arguments from function call that should be checked for deprecation/experimental warnings.
177 Excludes 'self' and any explicitly allowed args.
178 """
179 arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
180 filterable_args = dict(zip(arg_names, args))
181 filterable_args.update(kwargs)
182 filterable_args.pop("self", None)
183 if allowed_args:
184 for allowed_arg in allowed_args:
185 filterable_args.pop(allowed_arg, None)
186 return filterable_args
def deprecated_args(
    args_to_warn: Optional[List[str]] = None,
    allowed_args: Optional[List[str]] = None,
    reason: str = "",
    version: str = "",
) -> Callable[[C], C]:
    """
    Decorator to mark specified args of a function as deprecated.

    :param args_to_warn: argument names that trigger a DeprecationWarning
        when passed. If '*' is in the list, passing any argument (other than
        ``self`` and ``allowed_args``) triggers one warning listing all of
        them. Defaults to ``["*"]``.
    :param allowed_args: argument names that never trigger a warning.
    :param reason: optional explanation appended to the warning message.
    :param version: optional version since which the args are deprecated.
    """
    if args_to_warn is None:
        args_to_warn = ["*"]
    if allowed_args is None:
        allowed_args = []

    # Frames between warnings.warn() and the user's call site:
    #   1: warn_deprecated_arg_usage
    #   2: _check_deprecated_args
    #   3: wrapper / async_wrapper
    #   4: the code that called the decorated function  <- report here
    # (The previous value of 5 pointed one frame too far up the stack.)
    warn_stacklevel = 4

    def _check_deprecated_args(func, filterable_args):
        """Check and warn about deprecated arguments."""
        for arg in args_to_warn:
            if arg == "*" and len(filterable_args) > 0:
                warn_deprecated_arg_usage(
                    list(filterable_args.keys()),
                    func.__name__,
                    reason,
                    version,
                    stacklevel=warn_stacklevel,
                )
            elif arg in filterable_args:
                warn_deprecated_arg_usage(
                    arg, func.__name__, reason, version, stacklevel=warn_stacklevel
                )

    def decorator(func: C) -> C:
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
                _check_deprecated_args(func, filterable_args)
                return await func(*args, **kwargs)

            return async_wrapper
        else:

            @wraps(func)
            def wrapper(*args, **kwargs):
                filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
                _check_deprecated_args(func, filterable_args)
                return func(*args, **kwargs)

            return wrapper

    return decorator
243def _set_info_logger():
244 """
245 Set up a logger that log info logs to stdout.
246 (This is used by the default push response handler)
247 """
248 if "push_response" not in logging.root.manager.loggerDict.keys():
249 logger = logging.getLogger("push_response")
250 logger.setLevel(logging.INFO)
251 handler = logging.StreamHandler()
252 handler.setLevel(logging.INFO)
253 logger.addHandler(handler)
def check_protocol_version(
    protocol: Optional[Union[str, int]], expected_version: int = 3
) -> bool:
    """
    Return True when ``protocol`` denotes ``expected_version`` (3 by default).

    Numeric strings are converted with ``int()``. ``None`` and non-numeric
    strings never match.
    """
    if isinstance(protocol, str):
        try:
            protocol = int(protocol)
        except ValueError:
            return False
    return protocol is not None and protocol == expected_version
def get_lib_version():
    """
    Return the installed ``redis`` distribution version.

    Falls back to the placeholder "99.99.99" when the package metadata
    cannot be found.
    """
    try:
        return metadata.version("redis")
    except metadata.PackageNotFoundError:
        return "99.99.99"
def format_error_message(host_error: str, exception: BaseException) -> str:
    """
    Build a "connecting to <host>" error message from ``exception.args``.

    The first arg (if any) is placed before "connecting to"; a second arg
    (if any) is appended as a trailing sentence. Further args are ignored.
    """
    exc_args = exception.args
    if not exc_args:
        return f"Error connecting to {host_error}."
    message = f"Error {exc_args[0]} connecting to {host_error}."
    if len(exc_args) > 1:
        message = f"{message} {exc_args[1]}."
    return message
def compare_versions(version1: str, version2: str) -> int:
    """
    Compare two dotted numeric version strings.

    Missing trailing components count as 0, so "7.2" equals "7.2.0". Note
    that the sign convention is the reverse of a classic ``cmp``:

    :return: -1 if version1 > version2
             0 if both versions are equal
             1 if version1 < version2
    :raises ValueError: if any component is not an integer.
    """
    parts1 = [int(part) for part in version1.split(".")]
    parts2 = [int(part) for part in version2.split(".")]

    # Right-pad the shorter version with zeros so both have equal length;
    # Python list comparison is then component-wise lexicographic.
    width = max(len(parts1), len(parts2))
    parts1 += [0] * (width - len(parts1))
    parts2 += [0] * (width - len(parts2))

    if parts1 > parts2:
        return -1
    if parts1 < parts2:
        return 1
    return 0
def ensure_string(key):
    """
    Return ``key`` as ``str``, decoding ``bytes`` as strict UTF-8.

    :raises TypeError: if ``key`` is neither ``str`` nor ``bytes``.
    """
    if isinstance(key, str):
        return key
    if isinstance(key, bytes):
        return key.decode("utf-8")
    raise TypeError("Key must be either a string or bytes")
def extract_expire_flags(
    ex: Optional[ExpiryT] = None,
    px: Optional[ExpiryT] = None,
    exat: Optional[AbsExpiryT] = None,
    pxat: Optional[AbsExpiryT] = None,
) -> List[EncodableT]:
    """
    Build the expiration arguments for a command.

    Only the first non-None option is used, checked in the order ``ex``,
    ``px``, ``exat``, ``pxat``.

    :param ex: relative expiry in seconds: int, timedelta, or a string of
        decimal digits.
    :param px: relative expiry in milliseconds: int, timedelta, or a string
        of decimal digits.
    :param exat: absolute expiry; a datetime is converted to an int POSIX
        timestamp in seconds, any other value is passed through unchanged.
    :param pxat: absolute expiry; a datetime is converted to an int POSIX
        timestamp in milliseconds, any other value is passed through unchanged.
    :return: e.g. ``["EX", 10]``, or an empty list when no option is given.
    :raises DataError: if ``ex`` or ``px`` has an unsupported type.
    """

    def _relative_expiry(value, name, units_per_second):
        """Convert a relative expiry to an int in the unit its flag expects."""
        if isinstance(value, datetime.timedelta):
            return int(value.total_seconds() * units_per_second)
        if isinstance(value, int):
            return value
        # Digit strings are tolerated. isdecimal() (unlike isdigit()) only
        # admits characters int() can parse, so input like "²" raises
        # DataError instead of leaking a ValueError from int().
        if isinstance(value, str) and value.isdecimal():
            return int(value)
        raise DataError(f"{name} must be datetime.timedelta or int")

    if ex is not None:
        return ["EX", _relative_expiry(ex, "ex", 1)]
    if px is not None:
        return ["PX", _relative_expiry(px, "px", 1000)]
    if exat is not None:
        if isinstance(exat, datetime.datetime):
            exat = int(exat.timestamp())
        return ["EXAT", exat]
    if pxat is not None:
        if isinstance(pxat, datetime.datetime):
            pxat = int(pxat.timestamp() * 1000)
        return ["PXAT", pxat]
    return []
def truncate_text(txt, max_length=100):
    """
    Collapse whitespace in ``txt`` and shorten it to at most ``max_length``
    characters, marking any truncation with "...".
    """
    shortened = textwrap.shorten(
        txt, max_length, placeholder="...", break_long_words=True
    )
    return shortened
def dummy_fail():
    """
    No-op failure callback for a Retry object when individual failures do
    not need handling.
    """
    return None
async def dummy_fail_async():
    """
    Async no-op failure callback for a Retry object when individual failures
    do not need handling.
    """
    return None
def experimental(cls):
    """
    Class decorator that marks ``cls`` as experimental.

    ``cls.__init__`` is replaced by a wrapper that emits a UserWarning every
    time an instance is initialised and then delegates to the original
    ``__init__``. The same class object is returned.
    """
    wrapped_init = cls.__init__

    def warning_init(self, *args, **kwargs):
        # stacklevel=2 attributes the warning to whoever invoked __init__.
        warnings.warn(
            f"{cls.__name__} is an experimental and may change or be removed in future versions.",
            category=UserWarning,
            stacklevel=2,
        )
        wrapped_init(self, *args, **kwargs)

    cls.__init__ = wraps(wrapped_init)(warning_init)
    return cls
def warn_experimental(name, stacklevel=2):
    """
    Emit a UserWarning saying that method ``name`` is experimental.

    ``stacklevel`` is passed straight to ``warnings.warn``.
    """
    warnings.warn(
        f"Call to experimental method {name}. Be aware that the function "
        "arguments can change or be removed in future versions.",
        category=UserWarning,
        stacklevel=stacklevel,
    )
def experimental_method() -> Callable[[C], C]:
    """
    Decorator to mark a function (sync or async) as experimental.

    Every call of the decorated function emits a UserWarning via
    ``warn_experimental`` before delegating to it.
    """

    # Frames seen by warnings.warn(): 1 = warn_experimental, 2 = the wrapper
    # below, 3 = the code that called the decorated function. The previous
    # value of 2 attributed the warning to the wrapper inside this module;
    # 3 points at the caller, matching deprecated_function.
    warn_stacklevel = 3

    def decorator(func: C) -> C:
        if inspect.iscoroutinefunction(func):
            # Create async wrapper for async functions
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                warn_experimental(func.__name__, stacklevel=warn_stacklevel)
                return await func(*args, **kwargs)

            return async_wrapper
        else:
            # Create regular wrapper for sync functions
            @wraps(func)
            def wrapper(*args, **kwargs):
                warn_experimental(func.__name__, stacklevel=warn_stacklevel)
                return func(*args, **kwargs)

            return wrapper

    return decorator
def warn_experimental_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    stacklevel: int = 2,
):
    """
    Emit a UserWarning about experimental argument(s) of a method.

    ``arg_name`` (a single name or a list of names) is interpolated into the
    message as-is. ``stacklevel`` is passed straight to ``warnings.warn``.
    """
    text = "".join(
        (
            f"Call to '{function_name}' method with experimental",
            f" usage of input argument/s '{arg_name}'.",
        )
    )
    warnings.warn(text, category=UserWarning, stacklevel=stacklevel)
def experimental_args(
    args_to_warn: Optional[List[str]] = None,
) -> Callable[[C], C]:
    """
    Decorator to mark specified args of a function as experimental.

    :param args_to_warn: argument names that trigger a UserWarning when
        passed. If '*' is in the list, passing any argument (other than
        ``self``) triggers one warning listing all of them. Defaults to
        ``["*"]``.
    """
    flagged = ["*"] if args_to_warn is None else args_to_warn

    def _check_experimental_args(func, filterable_args):
        """Warn about each flagged argument present in ``filterable_args``."""
        # stacklevel=4 skips warn_experimental_arg_usage, this helper and the
        # wrapper, so the warning points at the decorated function's caller.
        for arg in flagged:
            if arg == "*" and filterable_args:
                warn_experimental_arg_usage(
                    list(filterable_args), func.__name__, stacklevel=4
                )
            elif arg in filterable_args:
                warn_experimental_arg_usage(arg, func.__name__, stacklevel=4)

    def decorator(func: C) -> C:
        if not inspect.iscoroutinefunction(func):

            @wraps(func)
            def wrapper(*args, **kwargs):
                filterable_args = _get_filterable_args(func, args, kwargs)
                if filterable_args:
                    _check_experimental_args(func, filterable_args)
                return func(*args, **kwargs)

            return wrapper

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            filterable_args = _get_filterable_args(func, args, kwargs)
            if filterable_args:
                _check_experimental_args(func, filterable_args)
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator