Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/utils.py: 38%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

273 statements  

1import datetime 

2import inspect 

3import logging 

4import textwrap 

5import warnings 

6from collections.abc import Callable 

7from contextlib import contextmanager 

8from functools import wraps 

9from importlib import metadata 

10from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, TypeVar, Union 

11 

12from redis.exceptions import DataError 

13from redis.typing import AbsExpiryT, EncodableT, ExpiryT 

14 

DEFAULT_RESP_VERSION = 3

if TYPE_CHECKING:
    from redis.client import Redis

try:
    import hiredis  # noqa

    # Only support hiredis >= 3.2.0. A version string that cannot be parsed
    # (fewer than two components, or a non-numeric component such as a
    # pre-release tag fused to the minor number) is treated as unsupported
    # instead of crashing the import of this module.
    hiredis_version = hiredis.__version__.split(".")
    HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or (
        int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2
    )
    if not HIREDIS_AVAILABLE:
        raise ImportError("hiredis package should be >= 3.2.0")
except (ImportError, AttributeError, ValueError, IndexError):
    HIREDIS_AVAILABLE = False

try:
    import ssl  # noqa

    SSL_AVAILABLE = True
except ImportError:
    SSL_AVAILABLE = False

try:
    import cryptography  # noqa

    CRYPTOGRAPHY_AVAILABLE = True
except ImportError:
    CRYPTOGRAPHY_AVAILABLE = False

46 

47 

def from_url(url: str, **kwargs: Any) -> "Redis":
    """
    Build a Redis client from a connection URL.

    Thin wrapper around ``redis.client.Redis.from_url``: ``url`` and every
    extra keyword argument are forwarded unchanged. As the original notes
    state, the database id is taken from the URL path when none is provided.
    """
    # Imported lazily inside the function; presumably this avoids a circular
    # import between redis.utils and redis.client.
    from redis.client import Redis as _RedisClient

    return _RedisClient.from_url(url, **kwargs)

58 

59 

@contextmanager
def pipeline(redis_obj):
    """
    Context manager that yields ``redis_obj.pipeline()``.

    When the ``with`` body finishes normally, the queued commands are sent
    with ``execute()``. If the body raises, the exception propagates out of
    the ``yield`` and ``execute()`` is never called.
    """
    queued = redis_obj.pipeline()
    yield queued
    # Reached only when the with-body exited without an exception.
    queued.execute()

65 

66 

def str_if_bytes(value: Union[str, bytes]) -> str:
    """
    Decode ``value`` as UTF-8 when it is ``bytes``; return anything else as-is.

    Undecodable byte sequences are replaced with U+FFFD instead of raising.
    """
    if not isinstance(value, bytes):
        return value
    return value.decode("utf-8", errors="replace")

71 

72 

def safe_str(value):
    """
    Return ``str(value)``, first decoding ``bytes`` as UTF-8 with replacement.
    """
    if isinstance(value, bytes):
        value = value.decode("utf-8", errors="replace")
    return str(value)

75 

76 

def decode_field_value(value, key=None, field_encodings=None):
    """Decode a field value respecting optional per-field encoding overrides.

    - Non-``bytes`` values are returned unchanged.
    - If *field_encodings* is non-empty and *key* (not ``None``) is one of its
      keys, the mapped encoding is used with ``"replace"`` error handling; a
      mapped value of ``None`` means the raw bytes are returned untouched.
    - Otherwise the bytes are decoded as UTF-8 with replacement (the same
      result :func:`str_if_bytes` gives for bytes).
    """
    if not isinstance(value, bytes):
        return value
    override_applies = (
        bool(field_encodings) and key is not None and key in field_encodings
    )
    if not override_applies:
        return value.decode("utf-8", errors="replace")
    codec = field_encodings[key]
    return value if codec is None else value.decode(codec, "replace")

92 

93 

def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]:
    """
    Combine the given mappings into a brand-new dict.

    Mappings are applied left to right, so when a key occurs in several of
    them the value from the last one wins. The inputs are not modified.
    """
    combined: Dict[str, Any] = dict()
    for mapping in dicts:
        combined.update(mapping)
    return combined

106 

107 

def list_keys_to_dict(key_list, callback):
    """Return a dict mapping every key in ``key_list`` to the same ``callback``."""
    return {key: callback for key in key_list}

110 

111 

def merge_result(command, res):
    """
    Merge the per-node result lists in ``res`` into one de-duplicated list.

    This is used when a command is sent to multiple nodes and the result from
    each node should be combined into a single list.

    :param command: the command that produced the results (not used here).
    :param res: mapping of node -> iterable of hashable result items.
    :return: list of unique items, in first-seen order while walking ``res``
        in its iteration order.
    """
    # dict.fromkeys de-duplicates like a set, but also yields a deterministic
    # first-seen order; set iteration order for str/bytes depends on hash
    # randomization and could differ from one process to the next.
    merged = dict.fromkeys(item for values in res.values() for item in values)
    return list(merged)

128 

129 

def warn_deprecated(name, reason="", version="", stacklevel=2):
    """
    Emit a ``DeprecationWarning`` stating that ``name`` is deprecated.

    ``reason`` (in parentheses) and ``version`` are appended to the message
    only when non-empty. ``stacklevel`` is forwarded to ``warnings.warn``.
    """
    import warnings

    pieces = [f"Call to deprecated {name}."]
    if reason:
        pieces.append(f" ({reason})")
    if version:
        pieces.append(f" -- Deprecated since version {version}.")
    warnings.warn(
        "".join(pieces), category=DeprecationWarning, stacklevel=stacklevel
    )

139 

140 

def deprecated_function(reason="", version="", name=None):
    """
    Decorator factory that marks a function as deprecated.

    Every call of the decorated function emits a ``DeprecationWarning`` via
    ``warn_deprecated`` (naming ``name``, or the function's ``__name__`` when
    ``name`` is not given), then delegates to the original function.
    Coroutine functions receive an ``async`` wrapper, so for them the warning
    is emitted when the returned coroutine starts running.
    """

    def decorator(func):
        if not inspect.iscoroutinefunction(func):

            @wraps(func)
            def wrapper(*args, **kwargs):
                # stacklevel=3: warnings.warn in warn_deprecated -> wrapper -> caller.
                warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
                return func(*args, **kwargs)

            return wrapper

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator

165 

166 

def warn_deprecated_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    reason: str = "",
    version: str = "",
    stacklevel: int = 2,
):
    """
    Emit a ``DeprecationWarning`` about use of deprecated input argument(s).

    :param arg_name: the deprecated argument name, or a list of names
        (rendered with ``str()`` inside the message).
    :param function_name: name of the called function, used in the message.
    :param reason: optional explanation, appended in parentheses.
    :param version: optional version, appended as "Deprecated since version ...".
    :param stacklevel: forwarded to ``warnings.warn``.
    """
    import warnings

    msg = (
        f"Call to '{function_name}' function with deprecated"
        f" usage of input argument/s '{arg_name}'."
    )
    if reason:
        msg += f" ({reason})"
    if version:
        msg += f" -- Deprecated since version {version}."
    warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel)


# Type variable for "any callable": lets the decorators in this module declare
# that they return a callable of the same type they were given.
C = TypeVar("C", bound=Callable)


def _get_filterable_args(
    func: Callable, args: tuple, kwargs: dict, allowed_args: Optional[List[str]] = None
) -> dict:
    """
    Map a call's arguments to parameter names for deprecation/experimental checks.

    Positional ``args`` are paired with ``func``'s positional parameter names
    (extra positionals that would land in ``*args`` are dropped) and ``kwargs``
    are merged on top. ``self`` and every name in ``allowed_args`` are then
    removed.

    ``func`` is unwrapped first (following ``__wrapped__`` as set by
    ``functools.wraps``), so positional arguments are still mapped when the
    decorated function is itself a ``wraps``-based ``(*args, **kwargs)``
    wrapper. A callable without ``__code__`` contributes no positional names.
    """
    target = inspect.unwrap(func)
    code = getattr(target, "__code__", None)
    arg_names = code.co_varnames[: code.co_argcount] if code is not None else ()
    filterable_args = dict(zip(arg_names, args))
    filterable_args.update(kwargs)
    filterable_args.pop("self", None)
    for allowed_arg in allowed_args or ():
        filterable_args.pop(allowed_arg, None)
    return filterable_args


def deprecated_args(
    args_to_warn: Optional[List[str]] = None,
    allowed_args: Optional[List[str]] = None,
    reason: str = "",
    version: str = "",
) -> Callable[[C], C]:
    """
    Decorator to mark specified args of a function as deprecated.

    :param args_to_warn: argument names that trigger a warning when passed.
        If '*' is in the list (the default), passing any argument other than
        ``self`` and those in ``allowed_args`` triggers one warning listing
        all of them.
    :param allowed_args: argument names that are never reported.
    :param reason: optional explanation included in the warning.
    :param version: optional version included in the warning.

    Works for regular and coroutine functions; the warning is attributed to
    the code that called the decorated function.
    """
    if args_to_warn is None:
        args_to_warn = ["*"]
    if allowed_args is None:
        allowed_args = []

    # Call chain when warning: warn_deprecated_arg_usage ->
    # _check_deprecated_args -> (async_)wrapper -> caller. stacklevel=4 thus
    # points at the caller, the same depth experimental_args uses for its
    # identical chain (5 would skip past the caller into its caller).
    warn_stacklevel = 4

    def _check_deprecated_args(func, filterable_args):
        """Check and warn about deprecated arguments."""
        for arg in args_to_warn:
            if arg == "*" and len(filterable_args) > 0:
                warn_deprecated_arg_usage(
                    list(filterable_args.keys()),
                    func.__name__,
                    reason,
                    version,
                    stacklevel=warn_stacklevel,
                )
            elif arg in filterable_args:
                warn_deprecated_arg_usage(
                    arg, func.__name__, reason, version, stacklevel=warn_stacklevel
                )

    def decorator(func: C) -> C:
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
                _check_deprecated_args(func, filterable_args)
                return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def wrapper(*args, **kwargs):
            filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
            _check_deprecated_args(func, filterable_args)
            return func(*args, **kwargs)

        return wrapper

    return decorator

259 

260 

261def _set_info_logger(): 

262 """ 

263 Set up a logger that log info logs to stdout. 

264 (This is used by the default push response handler) 

265 """ 

266 if "push_response" not in logging.root.manager.loggerDict.keys(): 

267 logger = logging.getLogger("push_response") 

268 logger.setLevel(logging.INFO) 

269 handler = logging.StreamHandler() 

270 handler.setLevel(logging.INFO) 

271 logger.addHandler(handler) 

272 

273 

def check_protocol_version(
    protocol: Optional[Union[str, int]], expected_version: int = 3
) -> bool:
    """
    Return True when ``protocol`` equals ``expected_version``.

    ``None`` yields False. A string is converted with ``int()`` first, and a
    string that is not a valid integer yields False.
    """
    if protocol is None:
        return False
    candidate = protocol
    if isinstance(protocol, str):
        try:
            candidate = int(protocol)
        except ValueError:
            return False
    return candidate == expected_version

285 

286 

def get_lib_version():
    """
    Return the installed version string of the ``redis`` distribution.

    Falls back to the placeholder "99.99.99" when ``importlib.metadata``
    cannot find the distribution.
    """
    try:
        return metadata.version("redis")
    except metadata.PackageNotFoundError:
        return "99.99.99"

293 

294 

def format_error_message(host_error: str, exception: BaseException) -> str:
    """
    Build an "Error ... connecting to <host>." message from ``exception``.

    Only the first two entries of ``exception.args`` are used: the first goes
    after "Error", and the second (if any) is appended as a trailing sentence.
    """
    details = exception.args
    if not details:
        return f"Error connecting to {host_error}."
    message = f"Error {details[0]} connecting to {host_error}."
    if len(details) > 1:
        message += f" {details[1]}."
    return message

305 

306 

def compare_versions(version1: str, version2: str) -> int:
    """
    Compare two dotted numeric version strings.

    Missing trailing components count as 0, so "7.2" equals "7.2.0". Note the
    sign convention: the result is positive when ``version1`` is the older
    version.

    :return: -1 if version1 > version2
             0 if both versions are equal
             1 if version1 < version2
    :raises ValueError: if any component is not an integer.
    """
    from itertools import zip_longest

    # Both strings are fully parsed up front, so a malformed component raises
    # ValueError even when an earlier component already decides the result.
    parts1 = [int(part) for part in version1.split(".")]
    parts2 = [int(part) for part in version2.split(".")]

    for left, right in zip_longest(parts1, parts2, fillvalue=0):
        if left > right:
            return -1
        if left < right:
            return 1
    return 0

335 

336 

def ensure_string(key):
    """
    Return ``key`` as ``str``.

    ``str`` is returned unchanged and ``bytes`` is decoded strictly as UTF-8.

    :raises TypeError: for any other type.
    """
    if isinstance(key, str):
        return key
    if isinstance(key, bytes):
        return key.decode("utf-8")
    raise TypeError("Key must be either a string or bytes")

344 

345 

def extract_expire_flags(
    ex: Optional[ExpiryT] = None,
    px: Optional[ExpiryT] = None,
    exat: Optional[AbsExpiryT] = None,
    pxat: Optional[AbsExpiryT] = None,
) -> List[EncodableT]:
    """
    Build the expiration arguments (EX / PX / EXAT / PXAT) for a command.

    Only the first non-``None`` option, checked in the order ex, px, exat,
    pxat, is used; the remaining ones are ignored.

    :param ex: relative expiry in seconds: ``int``, a string of digits, or a
        ``datetime.timedelta`` (truncated to whole seconds).
    :param px: relative expiry in milliseconds: ``int``, a string of digits,
        or a ``datetime.timedelta`` (floored to whole milliseconds).
    :param exat: absolute expiry as a Unix time in seconds, or a
        ``datetime.datetime`` converted with ``timestamp()``.
    :param pxat: absolute expiry as a Unix time in milliseconds, or a
        ``datetime.datetime`` converted with ``timestamp()``.
    :return: e.g. ``["EX", 10]``, or ``[]`` when no option is given.
    :raises DataError: if ``ex`` or ``px`` has an unsupported type.
    """
    exp_options: list[EncodableT] = []
    if ex is not None:
        exp_options.append("EX")
        if isinstance(ex, datetime.timedelta):
            exp_options.append(int(ex.total_seconds()))
        elif isinstance(ex, int):
            exp_options.append(ex)
        elif isinstance(ex, str) and ex.isdigit():
            exp_options.append(int(ex))
        else:
            raise DataError("ex must be datetime.timedelta or int")
    elif px is not None:
        exp_options.append("PX")
        if isinstance(px, datetime.timedelta):
            # Exact integer arithmetic: float total_seconds() * 1000 can land
            # just below the true value and truncate to one millisecond less.
            exp_options.append(px // datetime.timedelta(milliseconds=1))
        elif isinstance(px, int):
            exp_options.append(px)
        elif isinstance(px, str) and px.isdigit():
            # Accept digit strings, consistent with ``ex`` above.
            exp_options.append(int(px))
        else:
            raise DataError("px must be datetime.timedelta or int")
    elif exat is not None:
        if isinstance(exat, datetime.datetime):
            exat = int(exat.timestamp())
        exp_options.extend(["EXAT", exat])
    elif pxat is not None:
        if isinstance(pxat, datetime.datetime):
            pxat = int(pxat.timestamp() * 1000)
        exp_options.extend(["PXAT", pxat])

    return exp_options

381 

382 

def truncate_text(txt, max_length=100):
    """
    Collapse runs of whitespace in ``txt`` and shorten it to fit
    ``max_length`` characters, using "..." as the truncation marker.
    """
    return textwrap.shorten(
        txt,
        max_length,
        placeholder="...",
        break_long_words=True,
    )

387 

388 

def dummy_fail():
    """
    No-op failure callback for a Retry object, for callers that do not need
    to react to individual failures.
    """
    return None

394 

395 

async def dummy_fail_async():
    """
    Async no-op failure callback for a Retry object, for callers that do not
    need to react to individual failures.
    """
    return None

401 

402 

def experimental(cls):
    """
    Class decorator that marks ``cls`` as experimental.

    ``cls.__init__`` is replaced in place by a wrapper that emits a
    ``UserWarning`` on every instantiation and then runs the original
    initializer. The same class object is returned.
    """
    wrapped_init = cls.__init__

    def _warning_init(self, *args, **kwargs):
        warnings.warn(
            f"{cls.__name__} is an experimental and may change or be removed in future versions.",
            category=UserWarning,
            stacklevel=2,
        )
        wrapped_init(self, *args, **kwargs)

    cls.__init__ = wraps(wrapped_init)(_warning_init)
    return cls

420 

421 

def warn_experimental(name, stacklevel=2):
    """
    Emit a ``UserWarning`` stating that method ``name`` is experimental.

    ``stacklevel`` is forwarded to ``warnings.warn``.
    """
    import warnings

    warnings.warn(
        f"Call to experimental method {name}. "
        "Be aware that the function arguments can "
        "change or be removed in future versions.",
        category=UserWarning,
        stacklevel=stacklevel,
    )

431 

432 

def experimental_method() -> Callable[[C], C]:
    """
    Decorator to mark a function as experimental.

    Each call of the decorated function (regular or coroutine) emits a
    ``UserWarning`` via ``warn_experimental`` and then delegates to it.
    """
    # Call chain when warning: warnings.warn inside warn_experimental ->
    # (async_)wrapper -> caller. stacklevel=3 attributes the warning to the
    # caller of the decorated function, as deprecated_function does; 2 would
    # point at the wrapper line inside this module.
    warn_stacklevel = 3

    def decorator(func: C) -> C:
        if inspect.iscoroutinefunction(func):
            # Create async wrapper for async functions
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                warn_experimental(func.__name__, stacklevel=warn_stacklevel)
                return await func(*args, **kwargs)

            return async_wrapper

        # Create regular wrapper for sync functions
        @wraps(func)
        def wrapper(*args, **kwargs):
            warn_experimental(func.__name__, stacklevel=warn_stacklevel)
            return func(*args, **kwargs)

        return wrapper

    return decorator

457 

458 

def warn_experimental_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    stacklevel: int = 2,
):
    """
    Emit a ``UserWarning`` about use of experimental input argument(s).

    :param arg_name: the argument name, or a list of names (rendered with
        ``str()`` inside the message).
    :param function_name: name of the called method, used in the message.
    :param stacklevel: forwarded to ``warnings.warn``.
    """
    import warnings

    message = (
        f"Call to '{function_name}' method with experimental"
        f" usage of input argument/s '{arg_name}'."
    )
    warnings.warn(message, category=UserWarning, stacklevel=stacklevel)

471 

472 

def experimental_args(
    args_to_warn: Optional[List[str]] = None,
) -> Callable[[C], C]:
    """
    Decorator to mark specified args of a function as experimental.

    :param args_to_warn: argument names that trigger a ``UserWarning`` when
        passed. If '*' is in the list (the default), passing any argument
        other than ``self`` triggers one warning listing all of them.
    """
    watched = ["*"] if args_to_warn is None else args_to_warn

    def _check_experimental_args(func, filterable_args):
        """Warn about every watched argument present in ``filterable_args``."""
        for arg in watched:
            if arg == "*":
                if filterable_args:
                    # stacklevel=4: warn_experimental_arg_usage ->
                    # _check_experimental_args -> wrapper -> caller.
                    warn_experimental_arg_usage(
                        list(filterable_args.keys()), func.__name__, stacklevel=4
                    )
            elif arg in filterable_args:
                warn_experimental_arg_usage(arg, func.__name__, stacklevel=4)

    def decorator(func: C) -> C:
        if not inspect.iscoroutinefunction(func):

            @wraps(func)
            def wrapper(*args, **kwargs):
                _check_experimental_args(
                    func, _get_filterable_args(func, args, kwargs)
                )
                return func(*args, **kwargs)

            return wrapper

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            _check_experimental_args(func, _get_filterable_args(func, args, kwargs))
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator