Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/utils.py: 39%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

180 statements  

1import datetime 

2import logging 

3import textwrap 

4import warnings 

5from collections.abc import Callable 

6from contextlib import contextmanager 

7from functools import wraps 

8from typing import Any, Dict, List, Mapping, Optional, TypeVar, Union 

9 

10from redis.exceptions import DataError 

11from redis.typing import AbsExpiryT, EncodableT, ExpiryT 

12 

try:
    import hiredis  # noqa

    # Only support hiredis >= 3.2.0: any major version above 3, or major 3
    # with minor >= 2 (this matches the ImportError message below).
    hiredis_version = hiredis.__version__.split(".")
    try:
        HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or (
            int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2
        )
    except (ValueError, IndexError):
        # A version string we cannot parse (e.g. a bare "3" or a non-numeric
        # component) is treated as unsupported instead of letting the error
        # escape this module's import.
        HIREDIS_AVAILABLE = False
    if not HIREDIS_AVAILABLE:
        raise ImportError("hiredis package should be >= 3.2.0")
except ImportError:
    # hiredis is not installed, or is older than the supported minimum.
    HIREDIS_AVAILABLE = False

25 

# Feature flags recording whether optional modules can be imported.
# The modules themselves are bound at module level as a side effect.
try:
    import ssl  # noqa
except ImportError:
    SSL_AVAILABLE = False
else:
    SSL_AVAILABLE = True

try:
    import cryptography  # noqa
except ImportError:
    CRYPTOGRAPHY_AVAILABLE = False
else:
    CRYPTOGRAPHY_AVAILABLE = True

39 

40from importlib import metadata 

41 

42 

def from_url(url, **kwargs):
    """
    Build and return a Redis client configured from ``url``.

    Thin convenience wrapper: ``url`` and every keyword argument are handed
    unchanged to ``redis.client.Redis.from_url``. That constructor is
    expected to take the database id from the URL path when one is given.
    """
    # Imported at call time rather than module import time (presumably to
    # avoid a circular import between redis.client and redis.utils).
    from redis.client import Redis as _Redis

    build_client = _Redis.from_url
    return build_client(url, **kwargs)

53 

54 

@contextmanager
def pipeline(redis_obj):
    """
    Yield a pipeline from ``redis_obj`` and execute it when the block exits cleanly.

    Usage::

        with pipeline(r) as p:
            p.set("a", 1)
            p.incr("b")
        # p.execute() has run here

    If the ``with`` body raises, the buffered commands are NOT executed and
    the exception propagates. In every case the pipeline's own context
    manager is exited, which in redis-py resets the pipeline. Without that,
    state such as a connection held after ``WATCH`` would be left behind
    when the body fails.

    :param redis_obj: a client whose ``pipeline()`` returns an object that
        supports the context-manager protocol and has ``execute()``.
    """
    with redis_obj.pipeline() as p:
        yield p
        p.execute()

60 

61 

def str_if_bytes(value: Union[str, bytes]) -> str:
    """
    Return ``value`` as text when it is ``bytes``; otherwise return it unchanged.

    Bytes are decoded as UTF-8, and undecodable sequences become U+FFFD
    instead of raising. Non-``bytes`` values (including ``bytearray``) pass
    through untouched.
    """
    if not isinstance(value, bytes):
        return value
    return value.decode("utf-8", errors="replace")

66 

67 

def safe_str(value):
    """
    Return ``str(value)``, decoding ``bytes`` as UTF-8 (with replacement) first.

    Because of the decode step, ``b"x"`` becomes ``"x"`` rather than ``"b'x'"``.
    """
    decoded = str_if_bytes(value)
    return str(decoded)

70 

71 

def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]:
    """
    Combine every mapping in ``dicts`` into one new dict.

    Mappings are applied left to right, so when a key appears in several of
    them the right-most value wins. The input mappings are not modified.

    *dicts : `dict`
        dictionaries to merge
    """
    combined: Dict[str, Any] = {}
    for mapping in dicts:
        combined.update(mapping)
    return combined

def list_keys_to_dict(key_list, callback):
    """
    Return a dict mapping every key in ``key_list`` to the same ``callback``.
    """
    return {key: callback for key in key_list}

88 

89 

def merge_result(command, res):
    """
    Merge the per-node results in ``res`` into a single de-duplicated list.

    This is used when a command is sent to multiple nodes and the result
    from each node should be merged into a single list.

    Duplicates are dropped, keeping each item's first occurrence (nodes are
    visited in ``res`` insertion order, items in each node's order). The
    output is therefore deterministic, which a ``set`` would not guarantee
    under string hash randomization. Items must be hashable.

    command : the executed command; unused here (presumably kept so the
        function fits a response-callback signature)
    res : `dict`
        mapping of node -> iterable of that node's result items
    """
    return list(dict.fromkeys(item for values in res.values() for item in values))

106 

107 

def warn_deprecated(name, reason="", version="", stacklevel=2):
    """
    Emit a ``DeprecationWarning`` announcing that ``name`` is deprecated.

    The message is ``"Call to deprecated <name>."``, followed by
    ``" (<reason>)"`` and ``" -- Deprecated since version <version>."``
    when those arguments are non-empty. ``stacklevel`` is forwarded to
    ``warnings.warn`` unchanged.
    """
    pieces = [f"Call to deprecated {name}."]
    if reason:
        pieces.append(f" ({reason})")
    if version:
        pieces.append(f" -- Deprecated since version {version}.")
    warnings.warn(
        "".join(pieces), category=DeprecationWarning, stacklevel=stacklevel
    )

117 

118 

def deprecated_function(reason="", version="", name=None):
    """
    Decorator factory that marks a function as deprecated.

    Every call to the decorated function first emits a ``DeprecationWarning``
    through ``warn_deprecated``, labelled with ``name`` if it is truthy and
    otherwise with the function's ``__name__``. It then delegates to the
    original function and returns its result. ``stacklevel=3`` skips
    ``warn_deprecated`` and the wrapper, so the warning is attributed to the
    caller of the decorated function.
    """

    def decorator(func):
        @wraps(func)
        def deprecated_wrapper(*args, **kwargs):
            label = name if name else func.__name__
            warn_deprecated(label, reason, version, stacklevel=3)
            return func(*args, **kwargs)

        return deprecated_wrapper

    return decorator

133 

134 

def warn_deprecated_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    reason: str = "",
    version: str = "",
    stacklevel: int = 2,
):
    """
    Emit a ``DeprecationWarning`` about deprecated argument usage.

    ``arg_name`` (a single name or a list of names) is interpolated as-is
    into the message. ``reason`` and ``version`` are appended when
    non-empty, and ``stacklevel`` is forwarded to ``warnings.warn``.
    """
    pieces = [
        f"Call to '{function_name}' function with deprecated"
        f" usage of input argument/s '{arg_name}'."
    ]
    if reason:
        pieces.append(f" ({reason})")
    if version:
        pieces.append(f" -- Deprecated since version {version}.")
    warnings.warn(
        "".join(pieces), category=DeprecationWarning, stacklevel=stacklevel
    )

153 

154 

C = TypeVar("C", bound=Callable)


def deprecated_args(
    args_to_warn: Optional[List[str]] = None,
    allowed_args: Optional[List[str]] = None,
    reason: str = "",
    version: str = "",
) -> Callable[[C], C]:
    """
    Decorator to mark specified args of a function as deprecated.

    If '*' is in ``args_to_warn``, every argument the caller supplies (except
    ``self`` and anything in ``allowed_args``) is reported in a single
    warning. Otherwise one warning is emitted for each listed argument that
    the caller actually supplied, positionally or by keyword.

    :param args_to_warn: argument names to warn about; ``None`` means ``["*"]``.
    :param allowed_args: argument names that never trigger a warning;
        ``None`` means ``[]``.
    :param reason: optional explanation appended to the warning message.
    :param version: optional version in which the arguments were deprecated.
    :return: a decorator whose wrapper carries the wrapped function's
        metadata (via ``functools.wraps``).
    """
    # None sentinels instead of mutable list defaults shared across calls.
    if args_to_warn is None:
        args_to_warn = ["*"]
    if allowed_args is None:
        allowed_args = []

    def decorator(func: C) -> C:
        # Names of the parameters that can be passed positionally. Computed
        # once per decorated function instead of on every call.
        arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]

        @wraps(func)
        def wrapper(*args, **kwargs):
            provided_args = dict(zip(arg_names, args))
            provided_args.update(kwargs)

            provided_args.pop("self", None)
            for allowed_arg in allowed_args:
                provided_args.pop(allowed_arg, None)

            for arg in args_to_warn:
                if arg == "*" and len(provided_args) > 0:
                    warn_deprecated_arg_usage(
                        list(provided_args.keys()),
                        func.__name__,
                        reason,
                        version,
                        stacklevel=3,
                    )
                elif arg in provided_args:
                    warn_deprecated_arg_usage(
                        arg, func.__name__, reason, version, stacklevel=3
                    )

            return func(*args, **kwargs)

        return wrapper

    return decorator

201 

202 

def _set_info_logger():
    """
    Give the ``push_response`` logger an INFO level and an INFO-level handler.

    Used by the default push response handler. The handler is a default
    ``logging.StreamHandler()``, which writes to ``sys.stderr``. Nothing is
    done if a logger (or logging placeholder) named ``push_response`` is
    already registered, so repeated calls do not stack handlers.
    """
    if "push_response" in logging.root.manager.loggerDict:
        return
    push_logger = logging.getLogger("push_response")
    push_logger.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    push_logger.addHandler(stream_handler)

214 

215 

def get_lib_version():
    """
    Return the installed ``redis`` distribution's version string.

    If the package metadata cannot be found (presumably an uninstalled
    source checkout), the sentinel ``"99.99.99"`` is returned instead.
    """
    try:
        return metadata.version("redis")
    except metadata.PackageNotFoundError:
        return "99.99.99"

222 

223 

def format_error_message(host_error: str, exception: BaseException) -> str:
    """
    Build a human-readable "error connecting to <host>" message.

    At most the first two entries of ``exception.args`` are used. The first
    goes right after "Error"; the second is appended as a trailing sentence.
    Any further args are ignored.
    """
    details = exception.args
    if len(details) == 0:
        return f"Error connecting to {host_error}."
    head = f"Error {details[0]} connecting to {host_error}."
    if len(details) == 1:
        return head
    return f"{head} {details[1]}."

234 

235 

def compare_versions(version1: str, version2: str) -> int:
    """
    Compare two dotted numeric version strings.

    The shorter version is padded with zeros, so "1.2" equals "1.2.0".
    Every component must be an integer literal, otherwise ``ValueError``
    is raised.

    :return: -1 if version1 > version2
              0 if both versions are equal
              1 if version1 < version2
    """
    left = [int(part) for part in version1.split(".")]
    right = [int(part) for part in version2.split(".")]

    width = max(len(left), len(right))
    left += [0] * (width - len(left))
    right += [0] * (width - len(right))

    for a, b in zip(left, right):
        if a != b:
            return -1 if a > b else 1
    return 0

264 

265 

def ensure_string(key):
    """
    Return ``key`` as a ``str``.

    ``str`` is returned unchanged and ``bytes`` is strictly decoded as UTF-8,
    so invalid sequences raise ``UnicodeDecodeError``. Any other type raises
    ``TypeError``.
    """
    if isinstance(key, str):
        return key
    if isinstance(key, bytes):
        return key.decode("utf-8")
    raise TypeError("Key must be either a string or bytes")

273 

274 

def extract_expire_flags(
    ex: Optional[ExpiryT] = None,
    px: Optional[ExpiryT] = None,
    exat: Optional[AbsExpiryT] = None,
    pxat: Optional[AbsExpiryT] = None,
) -> List[EncodableT]:
    """
    Translate expiry arguments into the matching Redis command tokens.

    Only the first non-``None`` option is used, checked in the order ``ex``,
    ``px``, ``exat``, ``pxat``. If all are ``None``, an empty list is
    returned.

    :param ex: relative expiry in seconds, as ``int``, digit-only ``str`` or
        ``datetime.timedelta`` -> ``["EX", seconds]``.
    :param px: relative expiry in milliseconds, as ``int``, digit-only ``str``
        or ``datetime.timedelta`` -> ``["PX", milliseconds]``.
    :param exat: absolute expiry as seconds since the epoch, or a
        ``datetime.datetime`` -> ``["EXAT", seconds]``. Other values are
        passed through unchecked.
    :param pxat: absolute expiry as milliseconds since the epoch, or a
        ``datetime.datetime`` -> ``["PXAT", milliseconds]``. Other values are
        passed through unchecked.
    :raises DataError: if ``ex`` or ``px`` has an unsupported type.
    """
    exp_options: List[EncodableT] = []
    if ex is not None:
        exp_options.append("EX")
        if isinstance(ex, datetime.timedelta):
            exp_options.append(int(ex.total_seconds()))
        elif isinstance(ex, int):
            exp_options.append(ex)
        elif isinstance(ex, str) and ex.isdigit():
            exp_options.append(int(ex))
        else:
            raise DataError("ex must be datetime.timedelta or int")
    elif px is not None:
        exp_options.append("PX")
        if isinstance(px, datetime.timedelta):
            # Exact integer arithmetic. `int(px.total_seconds() * 1000)` can
            # lose a millisecond to float rounding (1.001s -> 1000).
            exp_options.append(px // datetime.timedelta(milliseconds=1))
        elif isinstance(px, int):
            exp_options.append(px)
        elif isinstance(px, str) and px.isdigit():
            # Accept digit-only strings, consistent with `ex`.
            exp_options.append(int(px))
        else:
            raise DataError("px must be datetime.timedelta or int")
    elif exat is not None:
        if isinstance(exat, datetime.datetime):
            exat = int(exat.timestamp())
        exp_options.extend(["EXAT", exat])
    elif pxat is not None:
        if isinstance(pxat, datetime.datetime):
            # Whole seconds from timestamp() (exact when there is no
            # sub-second part), plus milliseconds from the integer
            # microsecond field. This avoids float rounding in
            # `timestamp() * 1000`.
            whole_seconds = int(pxat.replace(microsecond=0).timestamp())
            pxat = whole_seconds * 1000 + pxat.microsecond // 1000
        exp_options.extend(["PXAT", pxat])

    return exp_options

310 

311 

def truncate_text(txt, max_length=100):
    """
    Collapse whitespace in ``txt`` and shorten it to at most ``max_length`` characters.

    Thin wrapper around ``textwrap.shorten``. Words that do not fit are
    dropped and replaced by a trailing ``"..."``. ``break_long_words=True``
    is passed so a single over-long word may be split.
    """
    options = {"placeholder": "...", "break_long_words": True}
    return textwrap.shorten(txt, max_length, **options)

316 

317 

def dummy_fail(*args, **kwargs):
    """
    No-op failure callback for a Retry object if you don't need to handle each failure.

    Any positional or keyword arguments are accepted and ignored. It can
    therefore be passed directly where the caller hands the failure (for
    example the raised exception) to the callback, and it can still be
    called with no arguments. Always returns ``None``.
    """
    return None

323 

324 

async def dummy_fail_async(*args, **kwargs):
    """
    Async no-op failure callback for a Retry object if you don't need to handle each failure.

    Any positional or keyword arguments are accepted and ignored. It can
    therefore be used where the caller passes the failure (for example the
    raised exception) to the callback, and it can still be awaited with no
    arguments. The coroutine resolves to ``None``.
    """
    return None

330 

331 

def experimental(cls):
    """
    Class decorator that flags ``cls`` as experimental.

    ``cls.__init__`` is replaced by a wrapper that emits a ``UserWarning`` on
    every instantiation and then runs the original initializer. The wrapper
    carries the original's metadata via ``functools.wraps``. ``stacklevel=2``
    attributes the warning to the frame that invoked ``__init__``, typically
    the code constructing the instance. The same class object is returned.
    """
    wrapped_init = cls.__init__

    def init_with_warning(self, *args, **kwargs):
        warnings.warn(
            f"{cls.__name__} is an experimental and may change or be removed in future versions.",
            category=UserWarning,
            stacklevel=2,
        )
        wrapped_init(self, *args, **kwargs)

    cls.__init__ = wraps(wrapped_init)(init_with_warning)
    return cls