Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/utils.py: 40%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

263 statements  

1import datetime 

2import inspect 

3import logging 

4import textwrap 

5import warnings 

6from collections.abc import Callable 

7from contextlib import contextmanager 

8from functools import wraps 

9from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, TypeVar, Union 

10 

11from redis.exceptions import DataError 

12from redis.typing import AbsExpiryT, EncodableT, ExpiryT 

13 

14if TYPE_CHECKING: 

15 from redis.client import Redis 

16 

try:
    import hiredis  # noqa

    # Only support Hiredis >= 3.2; any older release is treated exactly like a
    # missing hiredis (HIREDIS_AVAILABLE = False).
    # NOTE(review): the parsing below assumes __version__ starts with numeric
    # "MAJOR.MINOR" components; a non-numeric component would raise ValueError
    # and a single-component "3" would raise IndexError, neither of which is
    # caught by the ImportError handler below — confirm this is acceptable.
    hiredis_version = hiredis.__version__.split(".")
    HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or (
        int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2
    )
    if not HIREDIS_AVAILABLE:
        # Deliberately raised so it lands in the ImportError handler below,
        # giving "too old" the same outcome as "not installed".
        raise ImportError("hiredis package should be >= 3.2.0")
except ImportError:
    HIREDIS_AVAILABLE = False

29 

# Optional-dependency probes: each flag records whether the import succeeded.
# The imported module names stay bound at module level, as before.
try:
    import ssl  # noqa
except ImportError:
    SSL_AVAILABLE = False
else:
    SSL_AVAILABLE = True

try:
    import cryptography  # noqa
except ImportError:
    CRYPTOGRAPHY_AVAILABLE = False
else:
    CRYPTOGRAPHY_AVAILABLE = True

43 

44from importlib import metadata 

45 

46 

def from_url(url: str, **kwargs: Any) -> "Redis":
    """
    Build a Redis client from a connection URL.

    Thin convenience wrapper: ``url`` and every keyword argument are handed
    unchanged to ``redis.client.Redis.from_url`` and its result is returned.
    See that method for the accepted URL formats (including how a database
    id in the URL path is handled).
    """
    # Imported at call time. NOTE(review): presumably this avoids a circular
    # import between redis.utils and redis.client (the module-level import is
    # TYPE_CHECKING-only) — confirm.
    from redis.client import Redis as client_cls

    return client_cls.from_url(url, **kwargs)

57 

58 

@contextmanager
def pipeline(redis_obj):
    """
    Yield ``redis_obj.pipeline()`` and call ``execute()`` on it afterwards.

    ``execute()`` runs only when the ``with`` body completes without raising.
    If the body raises, the exception propagates out of the ``yield`` and the
    queued commands are never executed.
    """
    pending = redis_obj.pipeline()
    yield pending
    # Only reached on a clean exit from the with-block.
    pending.execute()

64 

65 

def str_if_bytes(value: Union[str, bytes]) -> str:
    """
    Decode ``value`` as UTF-8 when it is ``bytes``; return anything else as-is.

    Invalid UTF-8 sequences are replaced with U+FFFD instead of raising.
    """
    if not isinstance(value, bytes):
        return value
    return value.decode("utf-8", errors="replace")

70 

71 

def safe_str(value):
    """
    Return ``value`` as a ``str``.

    ``bytes`` are first decoded by ``str_if_bytes`` (UTF-8, invalid sequences
    replaced); the result is then always passed through ``str()``.
    """
    as_text = str_if_bytes(value)
    return str(as_text)

74 

75 

def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]:
    """
    Merge all provided dicts into one new dict.

    Mappings are applied left to right with ``dict.update``, so later ones
    win on key collisions. The inputs themselves are not modified.

    *dicts : `dict`
        dictionaries to merge
    """
    combined: Dict[str, Any] = {}
    for mapping in dicts:
        combined.update(mapping)
    return combined

88 

89 

def list_keys_to_dict(key_list, callback):
    """Map every key in ``key_list`` to the same ``callback`` object."""
    return {key: callback for key in key_list}

92 

93 

def merge_result(command, res):
    """
    Merge all items in `res` into a list.

    This command is used when sending a command to multiple nodes
    and the result from each node should be merged into a single list.

    Items are collected through a set, so duplicates across (and within)
    nodes are collapsed, every item must be hashable, and the order of the
    returned list is not meaningful. ``command`` is not used here.

    res : 'dict'
        values are iterables of result items (keys presumably identify the
        nodes — confirm against callers)
    """
    unique_items = {item for node_items in res.values() for item in node_items}
    return list(unique_items)

110 

111 

def warn_deprecated(name, reason="", version="", stacklevel=2):
    """
    Emit a DeprecationWarning stating that ``name`` is deprecated.

    ``reason`` and ``version`` are appended to the message only when they
    are non-empty. ``stacklevel`` is passed straight to ``warnings.warn``.
    """
    parts = [f"Call to deprecated {name}."]
    if reason:
        parts.append(f" ({reason})")
    if version:
        parts.append(f" -- Deprecated since version {version}.")
    warnings.warn("".join(parts), category=DeprecationWarning, stacklevel=stacklevel)

121 

122 

def deprecated_function(reason="", version="", name=None):
    """
    Decorator to mark a function as deprecated.

    Every call of the decorated function first emits a DeprecationWarning via
    ``warn_deprecated`` and then delegates to the original. Coroutine
    functions receive an ``async`` wrapper so they stay awaitable.

    :param reason: optional explanation appended to the warning.
    :param version: optional version the function was deprecated in.
    :param name: name shown in the warning; falls back to ``func.__name__``.
    """

    def decorator(func):
        if not inspect.iscoroutinefunction(func):
            # Plain (synchronous) callable.
            @wraps(func)
            def wrapper(*args, **kwargs):
                # stacklevel=3 skips warn_deprecated and this wrapper, so the
                # warning is attributed to the caller of the deprecated function.
                warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
                return func(*args, **kwargs)

            return wrapper

        # Coroutine function: keep the wrapper awaitable.
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator

147 

148 

def warn_deprecated_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    reason: str = "",
    version: str = "",
    stacklevel: int = 2,
):
    """
    Emit a DeprecationWarning that ``function_name`` was called with the
    deprecated argument(s) ``arg_name`` (a single name or a list of names).

    ``reason`` and ``version`` are appended only when non-empty;
    ``stacklevel`` is passed straight to ``warnings.warn``.
    """
    pieces = [
        f"Call to '{function_name}' function with deprecated",
        f" usage of input argument/s '{arg_name}'.",
    ]
    if reason:
        pieces.append(f" ({reason})")
    if version:
        pieces.append(f" -- Deprecated since version {version}.")
    warnings.warn("".join(pieces), category=DeprecationWarning, stacklevel=stacklevel)

167 

168 

# Generic type for a decorated callable. Decorators annotated
# ``Callable[[C], C]`` declare that they return something of the same type
# they received, so type checkers keep seeing the wrapped signature.
C = TypeVar("C", bound=Callable)

170 

171 

def _get_filterable_args(
    func: Callable, args: tuple, kwargs: dict, allowed_args: Optional[List[str]] = None
) -> dict:
    """
    Extract arguments from function call that should be checked for deprecation/experimental warnings.

    Positional values are paired with ``func``'s positional parameter names
    (values beyond those names, e.g. ones that would fill ``*args``, are
    dropped). Keyword arguments are merged on top. ``self`` and any name in
    ``allowed_args`` are then removed.
    """
    code = func.__code__
    positional_names = code.co_varnames[: code.co_argcount]
    collected = {**dict(zip(positional_names, args)), **kwargs}
    for excluded in ("self", *(allowed_args or ())):
        collected.pop(excluded, None)
    return collected

187 

188 

def deprecated_args(
    args_to_warn: Optional[List[str]] = None,
    allowed_args: Optional[List[str]] = None,
    reason: str = "",
    version: str = "",
) -> Callable[[C], C]:
    """
    Decorator to mark specified args of a function as deprecated.
    If '*' is in args_to_warn, all arguments will be marked as deprecated.

    :param args_to_warn: names of arguments that trigger a DeprecationWarning
        when supplied; defaults to ``["*"]`` (any argument).
    :param allowed_args: names that are never reported, even under ``"*"``
        (``self`` is always ignored).
    :param reason: optional explanation appended to the warning message.
    :param version: optional version the arguments were deprecated in.
    """
    if args_to_warn is None:
        args_to_warn = ["*"]
    if allowed_args is None:
        allowed_args = []

    # Frames counted by warnings.warn at warning time:
    #   warn_deprecated_arg_usage (1) <- _check_deprecated_args (2)
    #   <- wrapper / async_wrapper (3) <- caller of the decorated function (4)
    # stacklevel=4 therefore attributes the warning to that caller.
    warn_stacklevel = 4

    def _check_deprecated_args(func, filterable_args):
        """Check and warn about deprecated arguments."""
        for arg in args_to_warn:
            if arg == "*" and len(filterable_args) > 0:
                warn_deprecated_arg_usage(
                    list(filterable_args.keys()),
                    func.__name__,
                    reason,
                    version,
                    stacklevel=warn_stacklevel,
                )
            elif arg in filterable_args:
                warn_deprecated_arg_usage(
                    arg, func.__name__, reason, version, stacklevel=warn_stacklevel
                )

    def decorator(func: C) -> C:
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
                _check_deprecated_args(func, filterable_args)
                return await func(*args, **kwargs)

            return async_wrapper
        else:

            @wraps(func)
            def wrapper(*args, **kwargs):
                filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
                _check_deprecated_args(func, filterable_args)
                return func(*args, **kwargs)

            return wrapper

    return decorator

241 

242 

def _set_info_logger():
    """
    Set up a logger that log info logs to stdout.
    (This is used by the default push response handler)

    Configures the "push_response" logger at INFO level with one INFO-level
    ``logging.StreamHandler``. If a logger named "push_response" is already
    registered with the logging manager, nothing is changed, so repeated
    calls do not stack extra handlers.
    """
    if "push_response" in logging.root.manager.loggerDict:
        return
    push_logger = logging.getLogger("push_response")
    push_logger.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    push_logger.addHandler(stream_handler)

254 

255 

def check_protocol_version(
    protocol: Optional[Union[str, int]], expected_version: int = 3
) -> bool:
    """
    Return whether ``protocol`` equals ``expected_version``.

    Strings are converted with ``int()`` first; ``None`` and strings that do
    not parse as integers yield False.
    """
    if isinstance(protocol, str):
        try:
            protocol = int(protocol)
        except ValueError:
            return False
    elif protocol is None:
        return False
    return protocol == expected_version

267 

268 

def get_lib_version():
    """
    Return the installed version of the "redis" distribution.

    Falls back to the placeholder "99.99.99" when importlib.metadata cannot
    find the distribution.
    """
    try:
        return metadata.version("redis")
    except metadata.PackageNotFoundError:
        return "99.99.99"

275 

276 

def format_error_message(host_error: str, exception: BaseException) -> str:
    """
    Build a connection-error message for ``host_error`` from ``exception``.

    At most the first two entries of ``exception.args`` are used: the first
    right after "Error", the second as a trailing sentence.
    """
    details = exception.args
    if not details:
        return f"Error connecting to {host_error}."
    headline = f"Error {details[0]} connecting to {host_error}."
    if len(details) == 1:
        return headline
    return f"{headline} {details[1]}."

287 

288 

289def compare_versions(version1: str, version2: str) -> int: 

290 """ 

291 Compare two versions. 

292 

293 :return: -1 if version1 > version2 

294 0 if both versions are equal 

295 1 if version1 < version2 

296 """ 

297 

298 num_versions1 = list(map(int, version1.split("."))) 

299 num_versions2 = list(map(int, version2.split("."))) 

300 

301 if len(num_versions1) > len(num_versions2): 

302 diff = len(num_versions1) - len(num_versions2) 

303 for _ in range(diff): 

304 num_versions2.append(0) 

305 elif len(num_versions1) < len(num_versions2): 

306 diff = len(num_versions2) - len(num_versions1) 

307 for _ in range(diff): 

308 num_versions1.append(0) 

309 

310 for i, ver in enumerate(num_versions1): 

311 if num_versions1[i] > num_versions2[i]: 

312 return -1 

313 elif num_versions1[i] < num_versions2[i]: 

314 return 1 

315 

316 return 0 

317 

318 

def ensure_string(key):
    """
    Return ``key`` as ``str``: strings pass through, bytes are decoded as
    strict UTF-8.

    :raises TypeError: if ``key`` is neither ``str`` nor ``bytes``.
    """
    if isinstance(key, str):
        return key
    if isinstance(key, bytes):
        return key.decode("utf-8")
    raise TypeError("Key must be either a string or bytes")

326 

327 

def extract_expire_flags(
    ex: Optional[ExpiryT] = None,
    px: Optional[ExpiryT] = None,
    exat: Optional[AbsExpiryT] = None,
    pxat: Optional[AbsExpiryT] = None,
) -> List[EncodableT]:
    """
    Build the expiry arguments for a command.

    Only the first non-None option, checked in the order ``ex``, ``px``,
    ``exat``, ``pxat``, is used; the others are ignored.

    :param ex: relative expiry in seconds (int, digit string or timedelta).
    :param px: relative expiry in milliseconds (int, digit string or timedelta).
    :param exat: absolute expiry in seconds (value used as given) or a
        datetime converted with ``timestamp()``.
    :param pxat: absolute expiry in milliseconds (value used as given) or a
        datetime converted to milliseconds.
    :return: ``[]`` when no option is given, otherwise ``[FLAG, value]``.
    :raises DataError: if ``ex`` or ``px`` has an unsupported type.
    """
    exp_options: list[EncodableT] = []
    if ex is not None:
        exp_options.append("EX")
        if isinstance(ex, datetime.timedelta):
            exp_options.append(int(ex.total_seconds()))
        elif isinstance(ex, int):
            exp_options.append(ex)
        elif isinstance(ex, str) and ex.isdigit():
            exp_options.append(int(ex))
        else:
            raise DataError("ex must be datetime.timedelta or int")
    elif px is not None:
        exp_options.append("PX")
        if isinstance(px, datetime.timedelta):
            # Divide by a 1 ms timedelta instead of total_seconds() * 1000:
            # the float product can land just below an integer (1001 ms gives
            # 1000.9999999999999) and int() would then drop a millisecond.
            exp_options.append(int(px / datetime.timedelta(milliseconds=1)))
        elif isinstance(px, int):
            exp_options.append(px)
        elif isinstance(px, str) and px.isdigit():
            # Digit strings are accepted, consistent with ``ex``.
            exp_options.append(int(px))
        else:
            raise DataError("px must be datetime.timedelta or int")
    elif exat is not None:
        if isinstance(exat, datetime.datetime):
            exat = int(exat.timestamp())
        exp_options.extend(["EXAT", exat])
    elif pxat is not None:
        if isinstance(pxat, datetime.datetime):
            # Combine whole seconds and milliseconds with integer arithmetic;
            # int(timestamp() * 1000) can truncate to the previous millisecond
            # because the float timestamp cannot represent the fraction exactly.
            whole_seconds = int(pxat.replace(microsecond=0).timestamp())
            pxat = whole_seconds * 1000 + pxat.microsecond // 1000
        exp_options.extend(["PXAT", pxat])

    return exp_options

363 

364 

def truncate_text(txt, max_length=100):
    """
    Shorten ``txt`` to fit in ``max_length`` characters via ``textwrap.shorten``.

    Whitespace runs are collapsed to single spaces. If the text is still too
    long, trailing words are dropped and replaced with "...".
    """
    shorten_options = {"placeholder": "...", "break_long_words": True}
    return textwrap.shorten(txt, max_length, **shorten_options)

369 

370 

def dummy_fail():
    """
    No-op failure callback for a Retry object when individual failures need
    no handling. Takes no arguments and returns None.
    """
    return None

376 

377 

async def dummy_fail_async():
    """
    Awaitable no-op failure callback for a Retry object when individual
    failures need no handling. Resolves to None.
    """
    return None

383 

384 

def experimental(cls):
    """
    Decorator to mark a class as experimental.

    ``cls.__init__`` is replaced in place by a wrapper that emits a
    UserWarning on each instantiation and then runs the original initializer.
    The same class object is returned.
    """
    wrapped_init = cls.__init__

    def _warning_init(self, *args, **kwargs):
        # stacklevel=2 points at the code that called the initializer.
        warnings.warn(
            f"{cls.__name__} is an experimental and may change or be removed in future versions.",
            category=UserWarning,
            stacklevel=2,
        )
        wrapped_init(self, *args, **kwargs)

    cls.__init__ = wraps(wrapped_init)(_warning_init)
    return cls

402 

403 

def warn_experimental(name, stacklevel=2):
    """
    Emit a UserWarning that the experimental method ``name`` was called.

    ``stacklevel`` is passed straight to ``warnings.warn``.
    """
    warnings.warn(
        f"Call to experimental method {name}. "
        "Be aware that the function arguments can "
        "change or be removed in future versions.",
        category=UserWarning,
        stacklevel=stacklevel,
    )

413 

414 

def experimental_method() -> Callable[[C], C]:
    """
    Decorator to mark a function as experimental.

    Every call of the decorated function first emits a UserWarning via
    ``warn_experimental`` and then delegates to the original. Coroutine
    functions receive an ``async`` wrapper so they stay awaitable.
    """

    # Frames counted by warnings.warn: warn_experimental (1) <- wrapper (2)
    # <- caller (3). stacklevel=3 attributes the warning to the code calling
    # the experimental function instead of to the wrapper in this module
    # (the same depth deprecated_function uses for its identical call chain).
    warn_stacklevel = 3

    def decorator(func: C) -> C:
        if inspect.iscoroutinefunction(func):
            # Create async wrapper for async functions
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                warn_experimental(func.__name__, stacklevel=warn_stacklevel)
                return await func(*args, **kwargs)

            return async_wrapper
        else:
            # Create regular wrapper for sync functions
            @wraps(func)
            def wrapper(*args, **kwargs):
                warn_experimental(func.__name__, stacklevel=warn_stacklevel)
                return func(*args, **kwargs)

            return wrapper

    return decorator

439 

440 

def warn_experimental_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    stacklevel: int = 2,
):
    """
    Emit a UserWarning that ``function_name`` was called with experimental
    argument(s) ``arg_name`` (a single name or a list of names).

    ``stacklevel`` is passed straight to ``warnings.warn``.
    """
    text = "".join(
        (
            f"Call to '{function_name}' method with experimental",
            f" usage of input argument/s '{arg_name}'.",
        )
    )
    warnings.warn(text, category=UserWarning, stacklevel=stacklevel)

453 

454 

def experimental_args(
    args_to_warn: Optional[List[str]] = None,
) -> Callable[[C], C]:
    """
    Decorator to mark specified args of a function as experimental.
    If '*' is in args_to_warn, all arguments will be marked as experimental.

    :param args_to_warn: names of arguments that trigger a UserWarning when
        supplied; defaults to ``["*"]`` (any argument other than ``self``).
    """
    if args_to_warn is None:
        args_to_warn = ["*"]

    def _check_experimental_args(func, supplied):
        """Warn about each experimental argument present in ``supplied``."""
        for arg in args_to_warn:
            if arg == "*" and len(supplied) > 0:
                # stacklevel=4: warn_experimental_arg_usage <- this helper
                # <- wrapper <- caller of the decorated function.
                warn_experimental_arg_usage(
                    list(supplied), func.__name__, stacklevel=4
                )
            elif arg in supplied:
                warn_experimental_arg_usage(arg, func.__name__, stacklevel=4)

    def decorator(func: C) -> C:
        if not inspect.iscoroutinefunction(func):

            @wraps(func)
            def wrapper(*args, **kwargs):
                supplied = _get_filterable_args(func, args, kwargs)
                if supplied:
                    _check_experimental_args(func, supplied)
                return func(*args, **kwargs)

            return wrapper

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            supplied = _get_filterable_args(func, args, kwargs)
            if supplied:
                _check_experimental_args(func, supplied)
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator