Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/utils.py: 40%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

214 statements  

1import datetime 

2import logging 

3import textwrap 

4import warnings 

5from collections.abc import Callable 

6from contextlib import contextmanager 

7from functools import wraps 

8from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, TypeVar, Union 

9 

10from redis.exceptions import DataError 

11from redis.typing import AbsExpiryT, EncodableT, ExpiryT 

12 

13if TYPE_CHECKING: 

14 from redis.client import Redis 

15 

16try: 

17 import hiredis # noqa 

18 

19 # Only support Hiredis >= 3.0: 

20 hiredis_version = hiredis.__version__.split(".") 

21 HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or ( 

22 int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2 

23 ) 

24 if not HIREDIS_AVAILABLE: 

25 raise ImportError("hiredis package should be >= 3.2.0") 

26except ImportError: 

27 HIREDIS_AVAILABLE = False 

28 

29try: 

30 import ssl # noqa 

31 

32 SSL_AVAILABLE = True 

33except ImportError: 

34 SSL_AVAILABLE = False 

35 

36try: 

37 import cryptography # noqa 

38 

39 CRYPTOGRAPHY_AVAILABLE = True 

40except ImportError: 

41 CRYPTOGRAPHY_AVAILABLE = False 

42 

43from importlib import metadata 

44 

45 

46def from_url(url: str, **kwargs: Any) -> "Redis": 

47 """ 

48 Returns an active Redis client generated from the given database URL. 

49 

50 Will attempt to extract the database id from the path url fragment, if 

51 none is provided. 

52 """ 

53 from redis.client import Redis 

54 

55 return Redis.from_url(url, **kwargs) 

56 

57 

58@contextmanager 

59def pipeline(redis_obj): 

60 p = redis_obj.pipeline() 

61 yield p 

62 p.execute() 

63 

64 

65def str_if_bytes(value: Union[str, bytes]) -> str: 

66 return ( 

67 value.decode("utf-8", errors="replace") if isinstance(value, bytes) else value 

68 ) 

69 

70 

71def safe_str(value): 

72 return str(str_if_bytes(value)) 

73 

74 

75def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]: 

76 """ 

77 Merge all provided dicts into 1 dict. 

78 *dicts : `dict` 

79 dictionaries to merge 

80 """ 

81 merged = {} 

82 

83 for d in dicts: 

84 merged.update(d) 

85 

86 return merged 

87 

88 

89def list_keys_to_dict(key_list, callback): 

90 return dict.fromkeys(key_list, callback) 

91 

92 

93def merge_result(command, res): 

94 """ 

95 Merge all items in `res` into a list. 

96 

97 This command is used when sending a command to multiple nodes 

98 and the result from each node should be merged into a single list. 

99 

100 res : 'dict' 

101 """ 

102 result = set() 

103 

104 for v in res.values(): 

105 for value in v: 

106 result.add(value) 

107 

108 return list(result) 

109 

110 

111def warn_deprecated(name, reason="", version="", stacklevel=2): 

112 import warnings 

113 

114 msg = f"Call to deprecated {name}." 

115 if reason: 

116 msg += f" ({reason})" 

117 if version: 

118 msg += f" -- Deprecated since version {version}." 

119 warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel) 

120 

121 

122def deprecated_function(reason="", version="", name=None): 

123 """ 

124 Decorator to mark a function as deprecated. 

125 """ 

126 

127 def decorator(func): 

128 @wraps(func) 

129 def wrapper(*args, **kwargs): 

130 warn_deprecated(name or func.__name__, reason, version, stacklevel=3) 

131 return func(*args, **kwargs) 

132 

133 return wrapper 

134 

135 return decorator 

136 

137 

138def warn_deprecated_arg_usage( 

139 arg_name: Union[list, str], 

140 function_name: str, 

141 reason: str = "", 

142 version: str = "", 

143 stacklevel: int = 2, 

144): 

145 import warnings 

146 

147 msg = ( 

148 f"Call to '{function_name}' function with deprecated" 

149 f" usage of input argument/s '{arg_name}'." 

150 ) 

151 if reason: 

152 msg += f" ({reason})" 

153 if version: 

154 msg += f" -- Deprecated since version {version}." 

155 warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel) 

156 

157 

158C = TypeVar("C", bound=Callable) 

159 

160 

161def deprecated_args( 

162 args_to_warn: list = ["*"], 

163 allowed_args: list = [], 

164 reason: str = "", 

165 version: str = "", 

166) -> Callable[[C], C]: 

167 """ 

168 Decorator to mark specified args of a function as deprecated. 

169 If '*' is in args_to_warn, all arguments will be marked as deprecated. 

170 """ 

171 

172 def decorator(func: C) -> C: 

173 @wraps(func) 

174 def wrapper(*args, **kwargs): 

175 # Get function argument names 

176 arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] 

177 

178 provided_args = dict(zip(arg_names, args)) 

179 provided_args.update(kwargs) 

180 

181 provided_args.pop("self", None) 

182 for allowed_arg in allowed_args: 

183 provided_args.pop(allowed_arg, None) 

184 

185 for arg in args_to_warn: 

186 if arg == "*" and len(provided_args) > 0: 

187 warn_deprecated_arg_usage( 

188 list(provided_args.keys()), 

189 func.__name__, 

190 reason, 

191 version, 

192 stacklevel=3, 

193 ) 

194 elif arg in provided_args: 

195 warn_deprecated_arg_usage( 

196 arg, func.__name__, reason, version, stacklevel=3 

197 ) 

198 

199 return func(*args, **kwargs) 

200 

201 return wrapper 

202 

203 return decorator 

204 

205 

206def _set_info_logger(): 

207 """ 

208 Set up a logger that log info logs to stdout. 

209 (This is used by the default push response handler) 

210 """ 

211 if "push_response" not in logging.root.manager.loggerDict.keys(): 

212 logger = logging.getLogger("push_response") 

213 logger.setLevel(logging.INFO) 

214 handler = logging.StreamHandler() 

215 handler.setLevel(logging.INFO) 

216 logger.addHandler(handler) 

217 

218 

219def get_lib_version(): 

220 try: 

221 libver = metadata.version("redis") 

222 except metadata.PackageNotFoundError: 

223 libver = "99.99.99" 

224 return libver 

225 

226 

227def format_error_message(host_error: str, exception: BaseException) -> str: 

228 if not exception.args: 

229 return f"Error connecting to {host_error}." 

230 elif len(exception.args) == 1: 

231 return f"Error {exception.args[0]} connecting to {host_error}." 

232 else: 

233 return ( 

234 f"Error {exception.args[0]} connecting to {host_error}. " 

235 f"{exception.args[1]}." 

236 ) 

237 

238 

239def compare_versions(version1: str, version2: str) -> int: 

240 """ 

241 Compare two versions. 

242 

243 :return: -1 if version1 > version2 

244 0 if both versions are equal 

245 1 if version1 < version2 

246 """ 

247 

248 num_versions1 = list(map(int, version1.split("."))) 

249 num_versions2 = list(map(int, version2.split("."))) 

250 

251 if len(num_versions1) > len(num_versions2): 

252 diff = len(num_versions1) - len(num_versions2) 

253 for _ in range(diff): 

254 num_versions2.append(0) 

255 elif len(num_versions1) < len(num_versions2): 

256 diff = len(num_versions2) - len(num_versions1) 

257 for _ in range(diff): 

258 num_versions1.append(0) 

259 

260 for i, ver in enumerate(num_versions1): 

261 if num_versions1[i] > num_versions2[i]: 

262 return -1 

263 elif num_versions1[i] < num_versions2[i]: 

264 return 1 

265 

266 return 0 

267 

268 

269def ensure_string(key): 

270 if isinstance(key, bytes): 

271 return key.decode("utf-8") 

272 elif isinstance(key, str): 

273 return key 

274 else: 

275 raise TypeError("Key must be either a string or bytes") 

276 

277 

278def extract_expire_flags( 

279 ex: Optional[ExpiryT] = None, 

280 px: Optional[ExpiryT] = None, 

281 exat: Optional[AbsExpiryT] = None, 

282 pxat: Optional[AbsExpiryT] = None, 

283) -> List[EncodableT]: 

284 exp_options: list[EncodableT] = [] 

285 if ex is not None: 

286 exp_options.append("EX") 

287 if isinstance(ex, datetime.timedelta): 

288 exp_options.append(int(ex.total_seconds())) 

289 elif isinstance(ex, int): 

290 exp_options.append(ex) 

291 elif isinstance(ex, str) and ex.isdigit(): 

292 exp_options.append(int(ex)) 

293 else: 

294 raise DataError("ex must be datetime.timedelta or int") 

295 elif px is not None: 

296 exp_options.append("PX") 

297 if isinstance(px, datetime.timedelta): 

298 exp_options.append(int(px.total_seconds() * 1000)) 

299 elif isinstance(px, int): 

300 exp_options.append(px) 

301 else: 

302 raise DataError("px must be datetime.timedelta or int") 

303 elif exat is not None: 

304 if isinstance(exat, datetime.datetime): 

305 exat = int(exat.timestamp()) 

306 exp_options.extend(["EXAT", exat]) 

307 elif pxat is not None: 

308 if isinstance(pxat, datetime.datetime): 

309 pxat = int(pxat.timestamp() * 1000) 

310 exp_options.extend(["PXAT", pxat]) 

311 

312 return exp_options 

313 

314 

315def truncate_text(txt, max_length=100): 

316 return textwrap.shorten( 

317 text=txt, width=max_length, placeholder="...", break_long_words=True 

318 ) 

319 

320 

321def dummy_fail(): 

322 """ 

323 Fake function for a Retry object if you don't need to handle each failure. 

324 """ 

325 pass 

326 

327 

328async def dummy_fail_async(): 

329 """ 

330 Async fake function for a Retry object if you don't need to handle each failure. 

331 """ 

332 pass 

333 

334 

335def experimental(cls): 

336 """ 

337 Decorator to mark a class as experimental. 

338 """ 

339 original_init = cls.__init__ 

340 

341 @wraps(original_init) 

342 def new_init(self, *args, **kwargs): 

343 warnings.warn( 

344 f"{cls.__name__} is an experimental and may change or be removed in future versions.", 

345 category=UserWarning, 

346 stacklevel=2, 

347 ) 

348 original_init(self, *args, **kwargs) 

349 

350 cls.__init__ = new_init 

351 return cls 

352 

353 

354def warn_experimental(name, stacklevel=2): 

355 import warnings 

356 

357 msg = ( 

358 f"Call to experimental method {name}. " 

359 "Be aware that the function arguments can " 

360 "change or be removed in future versions." 

361 ) 

362 warnings.warn(msg, category=UserWarning, stacklevel=stacklevel) 

363 

364 

365def experimental_method() -> Callable[[C], C]: 

366 """ 

367 Decorator to mark a function as experimental. 

368 """ 

369 

370 def decorator(func: C) -> C: 

371 @wraps(func) 

372 def wrapper(*args, **kwargs): 

373 warn_experimental(func.__name__, stacklevel=2) 

374 return func(*args, **kwargs) 

375 

376 return wrapper 

377 

378 return decorator 

379 

380 

381def warn_experimental_arg_usage( 

382 arg_name: Union[list, str], 

383 function_name: str, 

384 stacklevel: int = 2, 

385): 

386 import warnings 

387 

388 msg = ( 

389 f"Call to '{function_name}' method with experimental" 

390 f" usage of input argument/s '{arg_name}'." 

391 ) 

392 warnings.warn(msg, category=UserWarning, stacklevel=stacklevel) 

393 

394 

395def experimental_args( 

396 args_to_warn: list = ["*"], 

397) -> Callable[[C], C]: 

398 """ 

399 Decorator to mark specified args of a function as experimental. 

400 """ 

401 

402 def decorator(func: C) -> C: 

403 @wraps(func) 

404 def wrapper(*args, **kwargs): 

405 # Get function argument names 

406 arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] 

407 

408 provided_args = dict(zip(arg_names, args)) 

409 provided_args.update(kwargs) 

410 

411 provided_args.pop("self", None) 

412 

413 if len(provided_args) == 0: 

414 return func(*args, **kwargs) 

415 

416 for arg in args_to_warn: 

417 if arg in provided_args: 

418 warn_experimental_arg_usage(arg, func.__name__, stacklevel=3) 

419 

420 return func(*args, **kwargs) 

421 

422 return wrapper 

423 

424 return decorator