Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/utils.py: 39%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

165 statements  

1import datetime 

2import logging 

3import textwrap 

4from contextlib import contextmanager 

5from functools import wraps 

6from typing import Any, Dict, List, Mapping, Optional, Union 

7 

8from redis.exceptions import DataError 

9from redis.typing import AbsExpiryT, EncodableT, ExpiryT 

10 

try:
    import hiredis  # noqa

    # Only hiredis >= 3.2 is supported: the check below accepts a major
    # version above 3, or major version 3 with a minor version of at least 2.
    hiredis_version = hiredis.__version__.split(".")
    HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or (
        int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2
    )
    if not HIREDIS_AVAILABLE:
        # Raised so that a too-old hiredis takes the same path as a missing
        # one: the handler below sets the flag to False.
        raise ImportError("hiredis package should be >= 3.2.0")
except ImportError:
    # hiredis is either not importable or older than the supported version.
    HIREDIS_AVAILABLE = False

23 

# Feature flag: True when the ``ssl`` module can be imported.
try:
    import ssl  # noqa

    SSL_AVAILABLE = True
except ImportError:
    SSL_AVAILABLE = False

# Feature flag: True when the ``cryptography`` package can be imported.
try:
    import cryptography  # noqa

    CRYPTOGRAPHY_AVAILABLE = True
except ImportError:
    CRYPTOGRAPHY_AVAILABLE = False

37 

38from importlib import metadata 

39 

40 

def from_url(url, **kwargs):
    """
    Return a Redis client built from the given database URL.

    Thin convenience wrapper: ``url`` and every keyword argument are
    forwarded unchanged to ``Redis.from_url`` and its result is returned.
    Will attempt to extract the database id from the path url fragment, if
    none is provided.
    """
    # Imported at call time rather than at module import time.
    from redis.client import Redis as _Redis

    client = _Redis.from_url(url, **kwargs)
    return client

51 

52 

@contextmanager
def pipeline(redis_obj):
    """
    Context manager that yields a pipeline and executes it on clean exit.

    ``redis_obj.pipeline()`` is called once on entry and its result is
    yielded to the ``with`` body. ``execute()`` is called on that pipeline
    only if the body finishes without raising.
    """
    pipe = redis_obj.pipeline()
    yield pipe
    # Not reached when the ``with`` body raises: the exception is re-raised
    # at the ``yield`` above and propagates out of this generator.
    pipe.execute()

58 

59 

def str_if_bytes(value: Union[str, bytes]) -> str:
    """
    Decode ``value`` as UTF-8 when it is ``bytes``; otherwise return it as is.

    Undecodable byte sequences are substituted with the Unicode replacement
    character instead of raising. Non-``bytes`` inputs are returned untouched.
    """
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return value

64 

65 

def safe_str(value):
    """
    Return ``value`` as a ``str``.

    ``bytes`` input is first decoded by ``str_if_bytes``; the result is then
    passed through ``str()``.
    """
    decoded = str_if_bytes(value)
    return str(decoded)

68 

69 

def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]:
    """
    Combine every mapping in ``dicts`` into a single new dict.

    Mappings are applied from left to right, so when a key appears more
    than once the value from the last mapping wins. The inputs are not
    modified; with no arguments an empty dict is returned.
    """
    combined: Dict[str, Any] = {}
    for mapping in dicts:
        combined.update(mapping)
    return combined

82 

83 

def list_keys_to_dict(key_list, callback):
    """
    Build a dict that maps every key in ``key_list`` to ``callback``.

    All keys share the same ``callback`` object; duplicate keys collapse
    into a single entry.
    """
    return {key: callback for key in key_list}

86 

87 

def merge_result(command, res):
    """
    Flatten the per-node results in ``res`` into one list of unique items.

    This is used when a command is sent to multiple nodes and the result
    from each node should be merged into a single list.

    ``res`` is a dict whose values are iterables of hashable items. Items
    are collected through a set, so duplicates are dropped and the order of
    the returned list is not defined. ``command`` is accepted but not used.
    """
    unique_items = {item for node_items in res.values() for item in node_items}
    return list(unique_items)

104 

105 

def warn_deprecated(name, reason="", version="", stacklevel=2):
    """
    Emit a ``DeprecationWarning`` for the item called ``name``.

    The message always starts with "Call to deprecated <name>."; a non-empty
    ``reason`` is appended in parentheses and a non-empty ``version`` is
    appended as "Deprecated since version <version>.". ``stacklevel`` is
    passed straight through to ``warnings.warn``.
    """
    import warnings

    parts = [f"Call to deprecated {name}."]
    if reason:
        parts.append(f" ({reason})")
    if version:
        parts.append(f" -- Deprecated since version {version}.")
    warnings.warn(
        "".join(parts), category=DeprecationWarning, stacklevel=stacklevel
    )

115 

116 

def deprecated_function(reason="", version="", name=None):
    """
    Decorator factory that marks a function as deprecated.

    Every call to the decorated function first issues a deprecation warning
    through ``warn_deprecated`` and then runs the function unchanged. The
    warning names the function by ``name`` when that is truthy, otherwise by
    the function's own ``__name__``.
    """

    def decorator(func):
        @wraps(func)
        def inner(*args, **kwargs):
            label = name or func.__name__
            # stacklevel=3 skips warn_deprecated and this wrapper so the
            # warning is attributed to the code calling the decorated function.
            warn_deprecated(label, reason, version, stacklevel=3)
            return func(*args, **kwargs)

        return inner

    return decorator

131 

132 

def warn_deprecated_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    reason: str = "",
    version: str = "",
    stacklevel: int = 2,
):
    """
    Emit a ``DeprecationWarning`` about deprecated argument usage.

    The message names ``function_name`` and ``arg_name`` (a single name or a
    list of names, formatted as given). A non-empty ``reason`` is appended
    in parentheses and a non-empty ``version`` is appended as "Deprecated
    since version <version>.". ``stacklevel`` is passed straight through to
    ``warnings.warn``.
    """
    import warnings

    parts = [
        f"Call to '{function_name}' function with deprecated",
        f" usage of input argument/s '{arg_name}'.",
    ]
    if reason:
        parts.append(f" ({reason})")
    if version:
        parts.append(f" -- Deprecated since version {version}.")
    warnings.warn(
        "".join(parts), category=DeprecationWarning, stacklevel=stacklevel
    )

151 

152 

def deprecated_args(
    args_to_warn: list = ["*"],
    allowed_args: list = [],
    reason: str = "",
    version: str = "",
):
    """
    Decorator factory that marks specific arguments of a function as deprecated.

    On every call the wrapper works out which arguments the caller supplied
    (positionally or by keyword), ignoring ``self`` and anything listed in
    ``allowed_args``. It then issues a deprecation warning through
    ``warn_deprecated_arg_usage`` for each entry of ``args_to_warn`` that was
    supplied. The entry ``"*"`` stands for every remaining supplied argument
    and produces a single warning that lists all of them.

    The decorated function is always called with the original arguments.
    """

    def decorator(func):
        @wraps(func)
        def inner(*args, **kwargs):
            # Names of the parameters that can be filled positionally.
            code = func.__code__
            positional_names = code.co_varnames[: code.co_argcount]

            # Map supplied positional values to their names, then overlay
            # the keyword arguments.
            supplied = {**dict(zip(positional_names, args)), **kwargs}

            for ignored in ("self", *allowed_args):
                supplied.pop(ignored, None)

            for deprecated in args_to_warn:
                if deprecated == "*" and supplied:
                    # stacklevel=3 attributes the warning to the caller of
                    # the decorated function.
                    warn_deprecated_arg_usage(
                        list(supplied),
                        func.__name__,
                        reason,
                        version,
                        stacklevel=3,
                    )
                elif deprecated in supplied:
                    warn_deprecated_arg_usage(
                        deprecated, func.__name__, reason, version, stacklevel=3
                    )

            return func(*args, **kwargs)

        return inner

    return decorator

196 

197 

198def _set_info_logger(): 

199 """ 

200 Set up a logger that log info logs to stdout. 

201 (This is used by the default push response handler) 

202 """ 

203 if "push_response" not in logging.root.manager.loggerDict.keys(): 

204 logger = logging.getLogger("push_response") 

205 logger.setLevel(logging.INFO) 

206 handler = logging.StreamHandler() 

207 handler.setLevel(logging.INFO) 

208 logger.addHandler(handler) 

209 

210 

def get_lib_version():
    """
    Return the installed version string of the "redis" distribution.

    Falls back to the placeholder "99.99.99" when no package metadata for
    "redis" can be found.
    """
    try:
        return metadata.version("redis")
    except metadata.PackageNotFoundError:
        return "99.99.99"

217 

218 

def format_error_message(host_error: str, exception: BaseException) -> str:
    """
    Build a human-readable connection error message.

    The wording depends on how many ``args`` the exception carries:

    * none: only ``host_error`` is mentioned;
    * one: the single argument is reported before the host;
    * two or more: the first argument is reported before the host and the
      second one is appended as a trailing sentence (further arguments are
      ignored).
    """
    details = exception.args
    if not details:
        return f"Error connecting to {host_error}."
    if len(details) == 1:
        return f"Error {details[0]} connecting to {host_error}."
    return f"Error {details[0]} connecting to {host_error}. {details[1]}."

229 

230 

def compare_versions(version1: str, version2: str) -> int:
    """
    Compare two dot-separated numeric version strings.

    The shorter version is padded with zero components, so "1.2" and
    "1.2.0" compare equal. Every component must be parseable by ``int``.

    Note the sign convention, which is the reverse of a conventional
    ``cmp``:

    :return: -1 if version1 > version2
             0 if both versions are equal
             1 if version1 < version2
    """
    left_parts = [int(piece) for piece in version1.split(".")]
    right_parts = [int(piece) for piece in version2.split(".")]

    # Pad the shorter list with zeros so both have the same length.
    width = max(len(left_parts), len(right_parts))
    left_parts += [0] * (width - len(left_parts))
    right_parts += [0] * (width - len(right_parts))

    # The first differing component decides the result.
    for left, right in zip(left_parts, right_parts):
        if left > right:
            return -1
        if left < right:
            return 1

    return 0

259 

260 

def ensure_string(key):
    """
    Return ``key`` as a ``str``.

    ``str`` input is returned unchanged and ``bytes`` input is decoded as
    strict UTF-8 (invalid bytes raise ``UnicodeDecodeError``).

    :raises TypeError: if ``key`` is neither ``str`` nor ``bytes``.
    """
    if isinstance(key, str):
        return key
    if isinstance(key, bytes):
        return key.decode("utf-8")
    raise TypeError("Key must be either a string or bytes")

268 

269 

def extract_expire_flags(
    ex: Optional[ExpiryT] = None,
    px: Optional[ExpiryT] = None,
    exat: Optional[AbsExpiryT] = None,
    pxat: Optional[AbsExpiryT] = None,
) -> List[EncodableT]:
    """
    Translate expiry keyword arguments into a command option list.

    Only the first argument that is not ``None`` is used, checked in the
    order ``ex``, ``px``, ``exat``, ``pxat``; the others are ignored.

    * ``ex``: ``["EX", seconds]`` from a ``timedelta``, an ``int``, or a
      string of digits.
    * ``px``: ``["PX", milliseconds]`` from a ``timedelta`` or an ``int``.
    * ``exat``: ``["EXAT", value]``; a ``datetime`` is converted to whole
      seconds since the epoch, any other value is passed through as given.
    * ``pxat``: ``["PXAT", value]``; a ``datetime`` is converted to whole
      milliseconds since the epoch, any other value is passed through.

    Returns an empty list when all four arguments are ``None``.

    :raises DataError: if ``ex`` or ``px`` has an unsupported type.
    """
    if ex is not None:
        if isinstance(ex, datetime.timedelta):
            seconds = int(ex.total_seconds())
        elif isinstance(ex, int):
            seconds = ex
        elif isinstance(ex, str) and ex.isdigit():
            seconds = int(ex)
        else:
            raise DataError("ex must be datetime.timedelta or int")
        return ["EX", seconds]

    if px is not None:
        if isinstance(px, datetime.timedelta):
            milliseconds = int(px.total_seconds() * 1000)
        elif isinstance(px, int):
            milliseconds = px
        else:
            raise DataError("px must be datetime.timedelta or int")
        return ["PX", milliseconds]

    if exat is not None:
        if isinstance(exat, datetime.datetime):
            exat = int(exat.timestamp())
        return ["EXAT", exat]

    if pxat is not None:
        if isinstance(pxat, datetime.datetime):
            pxat = int(pxat.timestamp() * 1000)
        return ["PXAT", pxat]

    return []

305 

306 

def truncate_text(txt, max_length=100):
    """
    Shorten ``txt`` to at most ``max_length`` characters.

    Delegates to ``textwrap.shorten``, using "..." as the placeholder that
    marks where text was cut and with ``break_long_words`` enabled.
    """
    shorten_options = {"placeholder": "...", "break_long_words": True}
    return textwrap.shorten(txt, max_length, **shorten_options)