Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/utils.py: 40%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import datetime
2import logging
3import textwrap
4from collections.abc import Callable
5from contextlib import contextmanager
6from functools import wraps
7from typing import Any, Dict, List, Mapping, Optional, TypeVar, Union
9from redis.exceptions import DataError
10from redis.typing import AbsExpiryT, EncodableT, ExpiryT
try:
    import hiredis  # noqa

    # Only hiredis >= 3.2.0 is supported; anything older is treated exactly
    # like a missing package by raising ImportError below.
    hiredis_version = hiredis.__version__.split(".")
    if int(hiredis_version[0]) > 3:
        # Any major release after 3 qualifies; the minor part is not inspected.
        HIREDIS_AVAILABLE = True
    else:
        HIREDIS_AVAILABLE = (
            int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2
        )
    if not HIREDIS_AVAILABLE:
        raise ImportError("hiredis package should be >= 3.2.0")
except ImportError:
    # Covers both "not installed" and "installed but too old".
    HIREDIS_AVAILABLE = False
# Record whether the ssl module can be imported in this interpreter.
try:
    import ssl  # noqa
except ImportError:
    SSL_AVAILABLE = False
else:
    SSL_AVAILABLE = True
# Record whether the optional cryptography package can be imported.
try:
    import cryptography  # noqa
except ImportError:
    CRYPTOGRAPHY_AVAILABLE = False
else:
    CRYPTOGRAPHY_AVAILABLE = True
39from importlib import metadata
def from_url(url, **kwargs):
    """
    Build a Redis client from a connection URL.

    Thin convenience wrapper: ``url`` and every keyword argument are
    forwarded unchanged to ``redis.client.Redis.from_url`` and its result
    is returned.
    """
    # Imported at call time rather than at module load
    # (presumably to avoid a circular import -- confirm).
    from redis.client import Redis as client_cls

    return client_cls.from_url(url, **kwargs)
@contextmanager
def pipeline(redis_obj):
    """
    Context manager that yields ``redis_obj.pipeline()`` and executes it
    once the ``with`` body has finished.
    """
    pipe = redis_obj.pipeline()
    yield pipe
    # There is no try/finally around the yield, so this line is reached only
    # when the with-body completes without raising.
    pipe.execute()
def str_if_bytes(value: Union[str, bytes]) -> str:
    """
    Decode ``value`` as UTF-8 when it is ``bytes``; return anything else
    untouched.

    Undecodable byte sequences are substituted (``errors="replace"``)
    instead of raising.
    """
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return value
def safe_str(value):
    """
    Return ``value`` as a ``str``: bytes are decoded through
    ``str_if_bytes`` first, then the result is passed to ``str()``.
    """
    decoded = str_if_bytes(value)
    return str(decoded)
def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]:
    """
    Combine every mapping passed in into one new dict.

    Mappings are applied left to right, so when a key appears more than
    once the value from the right-most mapping wins.

    *dicts : `dict`
        dictionaries to merge
    """
    combined: Dict[str, Any] = {}
    for mapping in dicts:
        combined |= mapping
    return combined
def list_keys_to_dict(key_list, callback):
    """
    Return a dict in which every entry of ``key_list`` maps to the same
    ``callback`` object.
    """
    return {key: callback for key in key_list}
def merge_result(command, res):
    """
    Flatten the per-node results in ``res`` into one list of unique items.

    Intended for commands that are sent to several nodes, where each
    node's reply (an iterable stored as a value of ``res``) has to be
    combined into a single list. Duplicates are dropped because the items
    are collected in a set; ``command`` is accepted but not used here.

    res : 'dict'
    """
    unique_items = {item for node_reply in res.values() for item in node_reply}
    return list(unique_items)
def warn_deprecated(name, reason="", version="", stacklevel=2):
    """
    Emit a ``DeprecationWarning`` announcing that ``name`` is deprecated.

    ``reason`` and ``version`` are appended to the message only when they
    are non-empty; ``stacklevel`` is handed to ``warnings.warn`` as is.
    """
    import warnings

    fragments = [f"Call to deprecated {name}."]
    if reason:
        fragments.append(f" ({reason})")
    if version:
        fragments.append(f" -- Deprecated since version {version}.")
    warnings.warn(
        "".join(fragments), category=DeprecationWarning, stacklevel=stacklevel
    )
def deprecated_function(reason="", version="", name=None):
    """
    Decorator factory that flags the decorated function as deprecated.

    Each call to the decorated function first goes through
    ``warn_deprecated`` (using ``name`` when given, otherwise the
    function's own ``__name__``) and then runs the function normally.
    """

    def decorator(func):
        @wraps(func)
        def deprecated_call(*args, **kwargs):
            # stacklevel=3 skips this wrapper and warn_deprecated itself, so
            # the warning is attributed to the code calling the function.
            warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
            return func(*args, **kwargs)

        return deprecated_call

    return decorator
def warn_deprecated_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    reason: str = "",
    version: str = "",
    stacklevel: int = 2,
):
    """
    Emit a ``DeprecationWarning`` about deprecated argument usage.

    The message names ``function_name`` and ``arg_name`` (a single name or
    a list of names, formatted as given). ``reason`` and ``version`` are
    appended only when non-empty; ``stacklevel`` goes to ``warnings.warn``.
    """
    import warnings

    fragments = [
        f"Call to '{function_name}' function with deprecated"
        f" usage of input argument/s '{arg_name}'."
    ]
    if reason:
        fragments.append(f" ({reason})")
    if version:
        fragments.append(f" -- Deprecated since version {version}.")
    warnings.warn(
        "".join(fragments), category=DeprecationWarning, stacklevel=stacklevel
    )
# Type variable so the decorator is typed as returning the callable it wraps.
C = TypeVar("C", bound=Callable)


def deprecated_args(
    args_to_warn: list = ["*"],
    allowed_args: list = [],
    reason: str = "",
    version: str = "",
) -> Callable[[C], C]:
    """
    Decorator factory that flags selected arguments of a function as
    deprecated.

    On every call the wrapper works out which arguments the caller
    actually supplied (positional ones are matched to parameter names,
    keyword ones are taken as given), drops ``self`` and everything listed
    in ``allowed_args``, and then walks ``args_to_warn``:

    * ``"*"`` warns once with the list of all remaining supplied
      arguments, provided there is at least one;
    * any other entry warns when that argument was supplied.

    The list defaults are only read, never modified, in this function.
    """

    def decorator(func: C) -> C:
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Names of the positional parameters, in declaration order.
            code = func.__code__
            positional_names = code.co_varnames[: code.co_argcount]

            supplied = {**dict(zip(positional_names, args)), **kwargs}
            for ignored in ("self", *allowed_args):
                supplied.pop(ignored, None)

            for target in args_to_warn:
                if target == "*" and supplied:
                    # stacklevel=3 attributes the warning to the caller of
                    # the decorated function, not to this wrapper.
                    warn_deprecated_arg_usage(
                        list(supplied),
                        func.__name__,
                        reason,
                        version,
                        stacklevel=3,
                    )
                elif target in supplied:
                    warn_deprecated_arg_usage(
                        target, func.__name__, reason, version, stacklevel=3
                    )

            return func(*args, **kwargs)

        return wrapper

    return decorator
def _set_info_logger():
    """
    Register the "push_response" logger at INFO level, once.

    (This is used by the default push response handler.)

    If a logger with that name is already known to the logging manager
    nothing happens. Otherwise the logger is created, set to INFO and
    given an INFO-level ``logging.StreamHandler`` built with its default
    stream.
    """
    if "push_response" in logging.root.manager.loggerDict:
        return

    push_logger = logging.getLogger("push_response")
    push_logger.setLevel(logging.INFO)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    push_logger.addHandler(stream_handler)
def get_lib_version():
    """
    Return the installed version of the "redis" distribution, or the
    placeholder "99.99.99" when no package metadata can be found for it.
    """
    try:
        return metadata.version("redis")
    except metadata.PackageNotFoundError:
        return "99.99.99"
def format_error_message(host_error: str, exception: BaseException) -> str:
    """
    Build a connection-error message for ``host_error``.

    The wording depends on how many ``args`` the exception carries: none
    gives a bare message, one includes ``args[0]``, and two or more
    include ``args[0]`` followed by ``args[1]`` as a trailing sentence.
    """
    details = exception.args
    if len(details) >= 2:
        return (
            f"Error {details[0]} connecting to {host_error}. "
            f"{details[1]}."
        )
    if details:
        return f"Error {details[0]} connecting to {host_error}."
    return f"Error connecting to {host_error}."
def compare_versions(version1: str, version2: str) -> int:
    """
    Compare two dot-separated numeric version strings.

    The shorter version is padded with zero components, so "1.0" and
    "1.0.0" compare equal. Note the sign convention of the result.

    :return: -1 if version1 > version2
             0 if both versions are equal
             1 if version1 < version2
    """
    left = [int(part) for part in version1.split(".")]
    right = [int(part) for part in version2.split(".")]

    # Bring both lists to the same length before comparing pairwise.
    width = max(len(left), len(right))
    left += [0] * (width - len(left))
    right += [0] * (width - len(right))

    for ours, theirs in zip(left, right):
        if ours > theirs:
            return -1
        if ours < theirs:
            return 1

    return 0
def ensure_string(key):
    """
    Return ``key`` as a ``str``.

    A ``str`` is returned unchanged and ``bytes`` are decoded as UTF-8
    (strictly, so invalid bytes raise); any other type raises
    ``TypeError``.
    """
    if isinstance(key, str):
        return key
    if isinstance(key, bytes):
        return key.decode("utf-8")
    raise TypeError("Key must be either a string or bytes")
def extract_expire_flags(
    ex: Optional[ExpiryT] = None,
    px: Optional[ExpiryT] = None,
    exat: Optional[AbsExpiryT] = None,
    pxat: Optional[AbsExpiryT] = None,
) -> List[EncodableT]:
    """
    Translate expiry keyword arguments into a flag/value pair.

    Only the first option that is not ``None`` is used, checked in the
    order ``ex``, ``px``, ``exat``, ``pxat``; the result is
    ``[flag, value]`` for that option or an empty list when none is set.

    * ``ex``   -> "EX":   a timedelta (whole seconds), an int, or a string
      of digits.
    * ``px``   -> "PX":   a timedelta (converted to milliseconds) or an int.
    * ``exat`` -> "EXAT": a datetime is converted to an integer timestamp
      in seconds; any other value is passed through unchanged.
    * ``pxat`` -> "PXAT": a datetime is converted to an integer timestamp
      in milliseconds; any other value is passed through unchanged.

    Raises ``DataError`` when ``ex`` or ``px`` has an unsupported type.
    """
    if ex is not None:
        if isinstance(ex, datetime.timedelta):
            seconds = int(ex.total_seconds())
        elif isinstance(ex, int):
            seconds = ex
        elif isinstance(ex, str) and ex.isdigit():
            seconds = int(ex)
        else:
            raise DataError("ex must be datetime.timedelta or int")
        return ["EX", seconds]

    if px is not None:
        if isinstance(px, datetime.timedelta):
            millis = int(px.total_seconds() * 1000)
        elif isinstance(px, int):
            millis = px
        else:
            raise DataError("px must be datetime.timedelta or int")
        return ["PX", millis]

    if exat is not None:
        if isinstance(exat, datetime.datetime):
            exat = int(exat.timestamp())
        return ["EXAT", exat]

    if pxat is not None:
        if isinstance(pxat, datetime.datetime):
            pxat = int(pxat.timestamp() * 1000)
        return ["PXAT", pxat]

    return []
def truncate_text(txt, max_length=100):
    """
    Shorten ``txt`` to at most ``max_length`` characters using
    ``textwrap.shorten``, marking dropped text with "...".
    """
    return textwrap.shorten(
        txt, max_length, placeholder="...", break_long_words=True
    )