Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/utils.py: 38%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

274 statements  

1import datetime 

2import inspect 

3import logging 

4import textwrap 

5import warnings 

6from collections.abc import Callable 

7from contextlib import contextmanager 

8from functools import wraps 

9from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, TypeVar, Union 

10 

11from redis.exceptions import DataError 

12from redis.typing import AbsExpiryT, EncodableT, ExpiryT 

13 

14if TYPE_CHECKING: 

15 from redis.client import Redis 

16 

# Detect the optional hiredis accelerator. HIREDIS_AVAILABLE ends up True only
# when hiredis can be imported AND its version is at least 3.2.
try:
    import hiredis  # noqa

    # Only support hiredis >= 3.2.0 (matches the check and the error message below).
    hiredis_version = hiredis.__version__.split(".")
    try:
        HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or (
            int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2
        )
    except (ValueError, IndexError):
        # A version string that cannot be parsed as "<int>.<int>..." is treated
        # as unsupported instead of making the import of this module fail.
        HIREDIS_AVAILABLE = False
    if not HIREDIS_AVAILABLE:
        # Routed through the handler below so the flag is reset in one place.
        raise ImportError("hiredis package should be >= 3.2.0")
except ImportError:
    HIREDIS_AVAILABLE = False

29 

# Feature flag: True when the ``ssl`` module can be imported.
try:
    import ssl  # noqa
except ImportError:
    SSL_AVAILABLE = False
else:
    SSL_AVAILABLE = True

# Feature flag: True when the ``cryptography`` package can be imported.
try:
    import cryptography  # noqa
except ImportError:
    CRYPTOGRAPHY_AVAILABLE = False
else:
    CRYPTOGRAPHY_AVAILABLE = True

43 

44from importlib import metadata 

45 

# Shared marker for omitted arguments, especially where None is a valid
# explicit value. Import this object from redis.utils instead of creating local
# sentinels, and compare it by identity only (`is` / `is not`).
# It is a bare ``object()`` instance, so it is only ever identical to itself.
SENTINEL = object()

50 

51 

def from_url(url: str, **kwargs: Any) -> "Redis":
    """
    Return a Redis client built from the given database URL.

    This is a thin wrapper: ``url`` and every keyword argument are forwarded
    unchanged to ``Redis.from_url`` and its result is returned.

    Will attempt to extract the database id from the path url fragment, if
    none is provided.
    """
    # Imported inside the function rather than at module import time.
    from redis.client import Redis as _client_cls

    client = _client_cls.from_url(url, **kwargs)
    return client

62 

63 

@contextmanager
def pipeline(redis_obj):
    """
    Context manager that yields ``redis_obj.pipeline()``.

    When the ``with`` body finishes without raising, ``execute()`` is called
    on the pipeline. If the body raises, ``execute()`` is not called and the
    exception propagates.
    """
    pipe = redis_obj.pipeline()
    yield pipe
    pipe.execute()

69 

70 

def str_if_bytes(value: Union[str, bytes]) -> str:
    """
    Decode ``value`` as UTF-8 when it is ``bytes``; otherwise return it as is.

    Undecodable byte sequences are substituted (``errors="replace"``) rather
    than raising. Non-bytes inputs are returned unchanged, whatever their type.
    """
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return value

75 

76 

def safe_str(value):
    """
    Return ``str()`` of ``value`` after passing it through ``str_if_bytes``.

    Bytes are therefore decoded as UTF-8 first; any other object is converted
    with its own ``__str__``.
    """
    decoded = str_if_bytes(value)
    return str(decoded)

79 

80 

def decode_field_value(value, key=None, field_encodings=None):
    """Decode one field value, honouring optional per-field encodings.

    - Values that are not ``bytes`` are returned untouched.
    - When *field_encodings* is non-empty and contains *key*, that entry is
      the encoding to use; an entry of ``None`` means "keep the raw bytes".
      Undecodable bytes are replaced rather than raising.
    - In every other case the value goes through :func:`str_if_bytes`.
    """
    if not isinstance(value, bytes):
        return value

    has_override = (
        bool(field_encodings) and key is not None and key in field_encodings
    )
    if not has_override:
        return str_if_bytes(value)

    codec = field_encodings[key]
    return value if codec is None else value.decode(codec, "replace")

96 

97 

def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]:
    """
    Combine every mapping passed in into one new dict.

    Mappings are applied left to right, so for a key present in several of
    them the value from the last one wins. The inputs are not modified.

    *dicts : `dict`
        dictionaries to merge
    """
    combined = {}
    for mapping in dicts:
        combined |= mapping
    return combined

110 

111 

def list_keys_to_dict(key_list, callback):
    """Return a dict mapping every key in ``key_list`` to the same ``callback``."""
    return {key: callback for key in key_list}

114 

115 

def merge_result(command, res):
    """
    Merge all items in `res` into a single list of unique items.

    This is used when sending a command to multiple nodes and the result
    from each node should be merged into a single list.

    Duplicates are dropped and the remaining items keep the order in which
    they were first seen (iterating ``res.values()`` in dict order), so the
    result is deterministic for a given input.

    command : unused; present to match the callback signature
    res : 'dict'
        maps a node identifier to an iterable of hashable items
    """
    # dict.fromkeys de-duplicates like a set but preserves insertion order.
    unique_in_order = dict.fromkeys(
        item for node_items in res.values() for item in node_items
    )
    return list(unique_in_order)

132 

133 

def warn_deprecated(name, reason="", version="", stacklevel=2):
    """
    Emit a ``DeprecationWarning`` for ``name``.

    The message always starts with "Call to deprecated <name>."; a non-empty
    ``reason`` is appended in parentheses and a non-empty ``version`` is
    appended as the release since which the item is deprecated.
    ``stacklevel`` is passed straight to ``warnings.warn``.
    """
    import warnings

    pieces = [f"Call to deprecated {name}."]
    if reason:
        pieces.append(f" ({reason})")
    if version:
        pieces.append(f" -- Deprecated since version {version}.")
    warnings.warn("".join(pieces), category=DeprecationWarning, stacklevel=stacklevel)

143 

144 

def deprecated_function(reason="", version="", name=None):
    """
    Decorator factory that marks a function as deprecated.

    Each call of the decorated function first emits a deprecation warning via
    ``warn_deprecated`` (using ``name`` when given, else the function's own
    ``__name__``) and then runs the original function. Coroutine functions get
    an ``async`` wrapper so they remain awaitable.
    """

    def decorator(func):
        if not inspect.iscoroutinefunction(func):
            # Plain function: synchronous wrapper.
            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
                return func(*args, **kwargs)

            return sync_wrapper

        # Coroutine function: the wrapper must itself be a coroutine function.
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator

169 

170 

def warn_deprecated_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    reason: str = "",
    version: str = "",
    stacklevel: int = 2,
):
    """
    Emit a ``DeprecationWarning`` about deprecated argument usage.

    ``arg_name`` (a single name or a list of names) and ``function_name`` are
    interpolated into the message; a non-empty ``reason`` and ``version`` are
    appended. ``stacklevel`` is passed straight to ``warnings.warn``.
    """
    import warnings

    pieces = [
        f"Call to '{function_name}' function with deprecated",
        f" usage of input argument/s '{arg_name}'.",
    ]
    if reason:
        pieces.append(f" ({reason})")
    if version:
        pieces.append(f" -- Deprecated since version {version}.")
    warnings.warn("".join(pieces), category=DeprecationWarning, stacklevel=stacklevel)

189 

190 

# Type variable bound to Callable. The decorator factories in this module are
# annotated ``Callable[[C], C]``: the decorator returns the same callable type
# it was given.
C = TypeVar("C", bound=Callable)

192 

193 

def _get_filterable_args(
    func: Callable, args: tuple, kwargs: dict, allowed_args: Optional[List[str]] = None
) -> dict:
    """
    Map the arguments of one call to ``func`` by parameter name.

    Positional values are paired with the function's positional parameter
    names (taken from its code object), then ``kwargs`` are layered on top.
    The ``self`` entry and every name listed in ``allowed_args`` are removed,
    leaving only the arguments that deprecation/experimental checks should
    look at.
    """
    code = func.__code__
    positional_names = code.co_varnames[: code.co_argcount]

    collected = {**dict(zip(positional_names, args)), **kwargs}

    # Drop 'self' first, then anything the caller explicitly allowed.
    for excluded in ("self", *(allowed_args or ())):
        collected.pop(excluded, None)
    return collected

209 

210 

def deprecated_args(
    args_to_warn: Optional[List[str]] = None,
    allowed_args: Optional[List[str]] = None,
    reason: str = "",
    version: str = "",
) -> Callable[[C], C]:
    """
    Decorator factory that marks selected arguments of a function as deprecated.

    ``args_to_warn`` lists the argument names that trigger a warning when
    supplied; it defaults to ``["*"]``, and ``"*"`` means "any supplied
    argument". Names in ``allowed_args`` are never reported. ``reason`` and
    ``version`` are included in the warning text. Coroutine functions get an
    ``async`` wrapper.
    """
    warn_targets = ["*"] if args_to_warn is None else args_to_warn
    exempt_args = [] if allowed_args is None else allowed_args

    def _check_deprecated_args(func, filterable_args):
        """Warn once per entry of ``warn_targets`` that matches the call."""
        for target in warn_targets:
            if target == "*" and len(filterable_args) > 0:
                flagged = list(filterable_args.keys())
            elif target in filterable_args:
                flagged = target
            else:
                continue
            # NOTE(review): experimental_args uses stacklevel=4 for the same
            # call depth; confirm that 5 is intended here.
            warn_deprecated_arg_usage(
                flagged, func.__name__, reason, version, stacklevel=5
            )

    def decorator(func: C) -> C:
        if not inspect.iscoroutinefunction(func):

            @wraps(func)
            def wrapper(*args, **kwargs):
                supplied = _get_filterable_args(func, args, kwargs, exempt_args)
                _check_deprecated_args(func, supplied)
                return func(*args, **kwargs)

            return wrapper

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            supplied = _get_filterable_args(func, args, kwargs, exempt_args)
            _check_deprecated_args(func, supplied)
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator

263 

264 

def _set_info_logger():
    """
    Configure the "push_response" logger for INFO output, once.

    If no logger named "push_response" is registered yet, one is created, set
    to INFO, and given a ``StreamHandler`` (also at INFO) built with its
    default stream, which the logging documentation gives as ``sys.stderr``.
    If the name is already registered the function does nothing.
    (This is used by the default push response handler)
    """
    if "push_response" in logging.root.manager.loggerDict:
        return

    push_logger = logging.getLogger("push_response")
    push_logger.setLevel(logging.INFO)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    push_logger.addHandler(stream_handler)

276 

277 

#: Default RESP protocol version used on the wire when the user does not
#: supply an explicit ``protocol`` to the client / connection / pool. Lives
#: in ``redis.utils`` so both ``redis.connection`` (for the HELLO handshake)
#: and ``check_protocol_version`` (for protocol-gated features) can read it
#: without a circular import. ``check_protocol_version`` substitutes this
#: value when it is called with ``protocol=None``.
DEFAULT_RESP_VERSION = 3

284 

285 

def check_protocol_version(
    protocol: Optional[Union[str, int]], expected_version: int = 3
) -> bool:
    """
    Report whether ``protocol`` equals ``expected_version``.

    ``None`` stands for ``DEFAULT_RESP_VERSION``. A string is converted with
    ``int()`` first; a string that is not a valid integer yields ``False``.
    """
    effective = DEFAULT_RESP_VERSION if protocol is None else protocol
    if isinstance(effective, str):
        try:
            effective = int(effective)
        except ValueError:
            return False
    return effective == expected_version

297 

298 

def get_lib_version():
    """
    Return the installed version of the ``redis`` distribution.

    Falls back to the placeholder ``"99.99.99"`` when package metadata for
    ``redis`` cannot be found.
    """
    try:
        return metadata.version("redis")
    except metadata.PackageNotFoundError:
        return "99.99.99"

305 

306 

def format_error_message(host_error: str, exception: BaseException) -> str:
    """
    Build a connection-error message for ``host_error`` from ``exception``.

    The text depends on how many ``args`` the exception carries: none gives a
    bare message, one adds the first arg after "Error", and two or more also
    append the second arg as a trailing sentence.
    """
    details = exception.args
    if not details:
        return f"Error connecting to {host_error}."

    headline = f"Error {details[0]} connecting to {host_error}."
    if len(details) == 1:
        return headline
    return f"{headline} {details[1]}."

317 

318 

def compare_versions(version1: str, version2: str) -> int:
    """
    Compare two dotted numeric version strings.

    The shorter version is padded with trailing zeros, so "1" and "1.0.0"
    compare equal. Note the sign convention of the result.

    :return: -1 if version1 > version2
             0 if both versions are equal
             1 if version1 < version2
    """
    left = [int(part) for part in version1.split(".")]
    right = [int(part) for part in version2.split(".")]

    # Pad both to the same length; the extension is empty for the longer one.
    width = max(len(left), len(right))
    left.extend([0] * (width - len(left)))
    right.extend([0] * (width - len(right)))

    for ours, theirs in zip(left, right):
        if ours > theirs:
            return -1
        if ours < theirs:
            return 1
    return 0

347 

348 

def ensure_string(key):
    """
    Return ``key`` as ``str``.

    ``str`` is returned unchanged and ``bytes`` is decoded as UTF-8 (strictly,
    so invalid bytes raise ``UnicodeDecodeError``).

    Raises ``TypeError`` for any other type.
    """
    if isinstance(key, str):
        return key
    if isinstance(key, bytes):
        return key.decode("utf-8")
    raise TypeError("Key must be either a string or bytes")

356 

357 

def extract_expire_flags(
    ex: Optional[ExpiryT] = None,
    px: Optional[ExpiryT] = None,
    exat: Optional[AbsExpiryT] = None,
    pxat: Optional[AbsExpiryT] = None,
) -> List[EncodableT]:
    """
    Translate expiry keyword arguments into a ``[FLAG, value]`` list.

    Only the first non-``None`` argument, checked in the order ``ex``, ``px``,
    ``exat``, ``pxat``, is used; an empty list is returned when all are
    ``None``.

    - ``ex``: ``timedelta`` (whole seconds), ``int``, or an all-digit string.
    - ``px``: ``timedelta`` (converted to milliseconds) or ``int``.
    - ``exat`` / ``pxat``: a ``datetime`` is converted to an integer timestamp
      in seconds / milliseconds; any other value is passed through unchanged.

    Raises ``DataError`` when ``ex`` or ``px`` has an unsupported type.
    """
    if ex is not None:
        if isinstance(ex, datetime.timedelta):
            seconds = int(ex.total_seconds())
        elif isinstance(ex, int):
            seconds = ex
        elif isinstance(ex, str) and ex.isdigit():
            seconds = int(ex)
        else:
            raise DataError("ex must be datetime.timedelta or int")
        return ["EX", seconds]

    if px is not None:
        if isinstance(px, datetime.timedelta):
            millis = int(px.total_seconds() * 1000)
        elif isinstance(px, int):
            millis = px
        else:
            raise DataError("px must be datetime.timedelta or int")
        return ["PX", millis]

    if exat is not None:
        at_seconds = (
            int(exat.timestamp()) if isinstance(exat, datetime.datetime) else exat
        )
        return ["EXAT", at_seconds]

    if pxat is not None:
        at_millis = (
            int(pxat.timestamp() * 1000)
            if isinstance(pxat, datetime.datetime)
            else pxat
        )
        return ["PXAT", at_millis]

    return []

393 

394 

def truncate_text(txt, max_length=100):
    """
    Shorten ``txt`` to at most ``max_length`` characters with ``textwrap.shorten``.

    Whitespace is collapsed first; when the text still does not fit, trailing
    words are dropped and "..." is appended as the placeholder.
    """
    shorten_options = {"placeholder": "...", "break_long_words": True}
    return textwrap.shorten(text=txt, width=max_length, **shorten_options)

399 

400 

def dummy_fail():
    """
    No-op failure callback.

    Fake function for a Retry object if you don't need to handle each failure.
    Takes no arguments and returns ``None``.
    """

406 

407 

async def dummy_fail_async():
    """
    No-op asynchronous failure callback.

    Async fake function for a Retry object if you don't need to handle each
    failure. Awaiting it yields ``None``.
    """

413 

414 

def experimental(cls):
    """
    Class decorator that marks a class as experimental.

    The class's ``__init__`` is replaced by a wrapper that emits a
    ``UserWarning`` (attributed to the code instantiating the class) and then
    calls the original ``__init__``. The same class object is returned.
    """
    wrapped_init = cls.__init__

    @wraps(wrapped_init)
    def init_with_warning(self, *args, **kwargs):
        message = f"{cls.__name__} is an experimental and may change or be removed in future versions."
        warnings.warn(message, category=UserWarning, stacklevel=2)
        wrapped_init(self, *args, **kwargs)

    cls.__init__ = init_with_warning
    return cls

432 

433 

def warn_experimental(name, stacklevel=2):
    """
    Emit a ``UserWarning`` saying that method ``name`` is experimental.

    ``stacklevel`` is passed straight to ``warnings.warn``.
    """
    import warnings

    text = "".join(
        (
            f"Call to experimental method {name}. ",
            "Be aware that the function arguments can ",
            "change or be removed in future versions.",
        )
    )
    warnings.warn(text, category=UserWarning, stacklevel=stacklevel)

443 

444 

def experimental_method() -> Callable[[C], C]:
    """
    Decorator factory that marks a function as experimental.

    Each call of the decorated function first emits a ``UserWarning`` via
    ``warn_experimental`` and then runs the original function. Coroutine
    functions get an ``async`` wrapper so they remain awaitable.

    The warning is attributed to the code that calls the decorated function,
    the same attribution ``deprecated_function`` uses.
    """

    def decorator(func: C) -> C:
        if inspect.iscoroutinefunction(func):
            # Create async wrapper for async functions
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                # stacklevel=3: 1 = warn_experimental, 2 = this wrapper,
                # 3 = the caller of the decorated function.
                warn_experimental(func.__name__, stacklevel=3)
                return await func(*args, **kwargs)

            return async_wrapper
        else:
            # Create regular wrapper for sync functions
            @wraps(func)
            def wrapper(*args, **kwargs):
                # stacklevel=3: 1 = warn_experimental, 2 = this wrapper,
                # 3 = the caller of the decorated function.
                warn_experimental(func.__name__, stacklevel=3)
                return func(*args, **kwargs)

            return wrapper

    return decorator

469 

470 

def warn_experimental_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    stacklevel: int = 2,
):
    """
    Emit a ``UserWarning`` about experimental argument usage.

    ``arg_name`` (a single name or a list of names) and ``function_name`` are
    interpolated into the message. ``stacklevel`` is passed straight to
    ``warnings.warn``.
    """
    import warnings

    text = "".join(
        (
            f"Call to '{function_name}' method with experimental",
            f" usage of input argument/s '{arg_name}'.",
        )
    )
    warnings.warn(text, category=UserWarning, stacklevel=stacklevel)

483 

484 

def experimental_args(
    args_to_warn: Optional[List[str]] = None,
) -> Callable[[C], C]:
    """
    Decorator factory that marks selected arguments of a function as experimental.

    ``args_to_warn`` lists the argument names that trigger a warning when
    supplied; it defaults to ``["*"]``, and ``"*"`` means "any supplied
    argument". Coroutine functions get an ``async`` wrapper.
    """
    warn_targets = ["*"] if args_to_warn is None else args_to_warn

    def _check_experimental_args(func, filterable_args):
        """Warn once per entry of ``warn_targets`` that matches the call."""
        for target in warn_targets:
            if target == "*" and len(filterable_args) > 0:
                flagged = list(filterable_args.keys())
            elif target in filterable_args:
                flagged = target
            else:
                continue
            warn_experimental_arg_usage(flagged, func.__name__, stacklevel=4)

    def decorator(func: C) -> C:
        if not inspect.iscoroutinefunction(func):

            @wraps(func)
            def wrapper(*args, **kwargs):
                supplied = _get_filterable_args(func, args, kwargs)
                if len(supplied) > 0:
                    _check_experimental_args(func, supplied)
                return func(*args, **kwargs)

            return wrapper

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            supplied = _get_filterable_args(func, args, kwargs)
            if len(supplied) > 0:
                _check_experimental_args(func, supplied)
            return await func(*args, **kwargs)

        return async_wrapper

    return decorator