Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/utils.py: 40%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

167 statements  

1import datetime 

2import logging 

3import textwrap 

4from collections.abc import Callable 

5from contextlib import contextmanager 

6from functools import wraps 

7from typing import Any, Dict, List, Mapping, Optional, TypeVar, Union 

8 

9from redis.exceptions import DataError 

10from redis.typing import AbsExpiryT, EncodableT, ExpiryT 

11 

try:
    import hiredis  # noqa

    # Only hiredis >= 3.2 is supported: the check below compares the major
    # and minor components of the version string reported by the package.
    hiredis_version = hiredis.__version__.split(".")
    HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or (
        int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2
    )
    if not HIREDIS_AVAILABLE:
        # Raising ImportError routes a too-old hiredis into the handler
        # below, so it is treated the same as hiredis not being installed.
        raise ImportError("hiredis package should be >= 3.2.0")
except ImportError:
    HIREDIS_AVAILABLE = False

24 

# Record whether the ssl module can be imported in this interpreter.
try:
    import ssl  # noqa

    SSL_AVAILABLE = True
except ImportError:
    SSL_AVAILABLE = False

31 

# Record whether the optional cryptography package can be imported.
try:
    import cryptography  # noqa

    CRYPTOGRAPHY_AVAILABLE = True
except ImportError:
    CRYPTOGRAPHY_AVAILABLE = False

38 

39from importlib import metadata 

40 

41 

def from_url(url, **kwargs):
    """
    Build a Redis client from the given database URL.

    This is a thin convenience wrapper: the URL and every keyword argument
    are forwarded unchanged to ``Redis.from_url``, which does all the work.
    """
    # The client class is imported at call time, not at module import time.
    from redis.client import Redis

    client = Redis.from_url(url, **kwargs)
    return client

52 

53 

@contextmanager
def pipeline(redis_obj):
    """
    Context manager yielding a pipeline created from `redis_obj`.

    The pipeline's ``execute()`` is called when the ``with`` body finishes
    normally. If the body raises, the exception propagates and ``execute()``
    is not called.
    """
    pipe = redis_obj.pipeline()
    yield pipe
    # Only reached when the with-body completed without raising.
    pipe.execute()

59 

60 

def str_if_bytes(value: Union[str, bytes]) -> str:
    """
    Decode `value` as UTF-8 when it is bytes; otherwise return it unchanged.

    Undecodable byte sequences are replaced rather than raising
    (``errors="replace"``).
    """
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return value

65 

66 

def safe_str(value):
    """
    Return `value` as a str.

    Bytes are first decoded via `str_if_bytes`; the result is then passed
    through ``str()``.
    """
    decoded = str_if_bytes(value)
    return str(decoded)

69 

70 

def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]:
    """
    Combine every provided mapping into one new dict.

    *dicts : `dict`
        mappings to merge, applied left to right, so a key present in
        several of them ends up with the value from the last one

    The inputs are not modified.
    """
    combined = {}
    for mapping in dicts:
        combined.update(mapping)
    return combined

83 

84 

def list_keys_to_dict(key_list, callback):
    """
    Map every key in `key_list` to the same `callback` object.
    """
    return {key: callback for key in key_list}

87 

88 

def merge_result(command, res):
    """
    Merge all items in `res` into a single de-duplicated list.

    This is used when a command is sent to multiple nodes and the result
    from each node should be merged into one list.

    command :
        not used by this function
    res : 'dict'
        mapping whose values are iterables of hashable items

    Returns a list containing each distinct item once, in the order the
    items are first encountered while walking ``res.values()``.
    """
    # dict.fromkeys de-duplicates like a set but keeps insertion order, so
    # the output is deterministic instead of depending on hash ordering.
    return list(dict.fromkeys(item for items in res.values() for item in items))

105 

106 

def warn_deprecated(name, reason="", version="", stacklevel=2):
    """
    Emit a DeprecationWarning for `name`.

    `reason` and `version`, when non-empty, are appended to the message.
    `stacklevel` is passed straight to ``warnings.warn``.
    """
    import warnings

    segments = [f"Call to deprecated {name}."]
    if reason:
        segments.append(f" ({reason})")
    if version:
        segments.append(f" -- Deprecated since version {version}.")
    warnings.warn("".join(segments), category=DeprecationWarning, stacklevel=stacklevel)

116 

117 

def deprecated_function(reason="", version="", name=None):
    """
    Decorator factory that marks a function as deprecated.

    Every call to the decorated function first issues a deprecation warning
    through `warn_deprecated` (using `name` if given, otherwise the
    function's own ``__name__``) and then runs the function unchanged.
    """

    def mark(func):
        @wraps(func)
        def deprecated_call(*args, **kwargs):
            # stacklevel=3 is forwarded so the warning is attributed past
            # this wrapper and warn_deprecated itself.
            warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
            return func(*args, **kwargs)

        return deprecated_call

    return mark

132 

133 

def warn_deprecated_arg_usage(
    arg_name: Union[list, str],
    function_name: str,
    reason: str = "",
    version: str = "",
    stacklevel: int = 2,
):
    """
    Emit a DeprecationWarning about deprecated argument usage.

    `arg_name` (a single name or a list of names) and `function_name` are
    interpolated into the message; `reason` and `version` are appended when
    non-empty. `stacklevel` is passed straight to ``warnings.warn``.
    """
    import warnings

    segments = [
        f"Call to '{function_name}' function with deprecated",
        f" usage of input argument/s '{arg_name}'.",
    ]
    if reason:
        segments.append(f" ({reason})")
    if version:
        segments.append(f" -- Deprecated since version {version}.")
    warnings.warn("".join(segments), category=DeprecationWarning, stacklevel=stacklevel)

152 

153 

C = TypeVar("C", bound=Callable)


def deprecated_args(
    args_to_warn: list = ["*"],
    allowed_args: list = [],
    reason: str = "",
    version: str = "",
) -> Callable[[C], C]:
    """
    Decorator factory that marks selected arguments of a function as deprecated.

    On every call, the arguments actually supplied (positional ones matched
    to parameter names, plus keyword ones) are collected; ``self`` and any
    name in `allowed_args` are ignored. A warning is issued for each entry of
    `args_to_warn` that was supplied. The special entry ``"*"`` matches any
    remaining supplied argument and reports all of them in one warning.
    """

    def decorator(func: C) -> C:
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Names of the parameters that can be filled positionally.
            code = func.__code__
            positional_names = code.co_varnames[: code.co_argcount]

            supplied = {**dict(zip(positional_names, args)), **kwargs}

            supplied.pop("self", None)
            for permitted in allowed_args:
                supplied.pop(permitted, None)

            for target in args_to_warn:
                if target == "*":
                    # Wildcard: warn only when something is left after
                    # dropping "self" and the allowed arguments.
                    if supplied:
                        warn_deprecated_arg_usage(
                            list(supplied),
                            func.__name__,
                            reason,
                            version,
                            stacklevel=3,
                        )
                elif target in supplied:
                    warn_deprecated_arg_usage(
                        target, func.__name__, reason, version, stacklevel=3
                    )

            return func(*args, **kwargs)

        return wrapper

    return decorator

200 

201 

202def _set_info_logger(): 

203 """ 

204 Set up a logger that log info logs to stdout. 

205 (This is used by the default push response handler) 

206 """ 

207 if "push_response" not in logging.root.manager.loggerDict.keys(): 

208 logger = logging.getLogger("push_response") 

209 logger.setLevel(logging.INFO) 

210 handler = logging.StreamHandler() 

211 handler.setLevel(logging.INFO) 

212 logger.addHandler(handler) 

213 

214 

def get_lib_version():
    """
    Return the installed version of the "redis" distribution.

    Falls back to the placeholder "99.99.99" when no package metadata for
    "redis" can be found.
    """
    try:
        return metadata.version("redis")
    except metadata.PackageNotFoundError:
        return "99.99.99"

221 

222 

def format_error_message(host_error: str, exception: BaseException) -> str:
    """
    Build a connection-error message for `host_error` from `exception`.

    The exception's first argument (if any) is placed after "Error", and its
    second argument (if any) is appended as a trailing sentence. Arguments
    beyond the second are ignored.
    """
    details = exception.args
    if not details:
        return f"Error connecting to {host_error}."

    headline = f"Error {details[0]} connecting to {host_error}."
    if len(details) == 1:
        return headline
    return f"{headline} {details[1]}."

233 

234 

def compare_versions(version1: str, version2: str) -> int:
    """
    Compare two dot-separated numeric version strings.

    The shorter version is padded with trailing zeros, so "1.2" and "1.2.0"
    compare as equal.

    :return: -1 if version1 > version2
             0 if both versions are equal
             1 if version1 < version2
    """
    left_parts = [int(piece) for piece in version1.split(".")]
    right_parts = [int(piece) for piece in version2.split(".")]

    # Pad both lists to a common length; a non-positive count adds nothing.
    width = max(len(left_parts), len(right_parts))
    left_parts.extend([0] * (width - len(left_parts)))
    right_parts.extend([0] * (width - len(right_parts)))

    for left, right in zip(left_parts, right_parts):
        if left > right:
            return -1
        if left < right:
            return 1

    return 0

263 

264 

def ensure_string(key):
    """
    Return `key` as a str, decoding bytes as UTF-8.

    Raises TypeError when `key` is neither str nor bytes.
    """
    if isinstance(key, str):
        return key
    if isinstance(key, bytes):
        return key.decode("utf-8")
    raise TypeError("Key must be either a string or bytes")

272 

273 

def extract_expire_flags(
    ex: Optional[ExpiryT] = None,
    px: Optional[ExpiryT] = None,
    exat: Optional[AbsExpiryT] = None,
    pxat: Optional[AbsExpiryT] = None,
) -> List[EncodableT]:
    """
    Translate expiry arguments into a list of command option tokens.

    Only the first argument that is not None is used, checked in the order
    `ex`, `px`, `exat`, `pxat`; the others are ignored.

    ex : timedelta (converted to whole seconds), int, or a digit-only str
    px : timedelta (converted to whole milliseconds) or int
    exat : datetime (converted to a whole-second timestamp) or any other
        value, which is passed through unchanged
    pxat : datetime (converted to a whole-millisecond timestamp) or any
        other value, which is passed through unchanged

    Returns an empty list when all four arguments are None, otherwise a
    two-item list of the option name followed by its value.

    Raises DataError when `ex` or `px` has an unsupported type.
    """
    exp_options: list[EncodableT] = []
    if ex is not None:
        exp_options.append("EX")
        if isinstance(ex, datetime.timedelta):
            exp_options.append(int(ex.total_seconds()))
        elif isinstance(ex, int):
            exp_options.append(ex)
        elif isinstance(ex, str) and ex.isdigit():
            exp_options.append(int(ex))
        else:
            raise DataError("ex must be datetime.timedelta or int")
    elif px is not None:
        exp_options.append("PX")
        if isinstance(px, datetime.timedelta):
            # Divide by a one-millisecond timedelta rather than computing
            # total_seconds() * 1000: the float product can land just below
            # the true value (1.001 * 1000 == 1000.9999999999999), which
            # int() would then truncate to one millisecond too few.
            exp_options.append(int(px / datetime.timedelta(milliseconds=1)))
        elif isinstance(px, int):
            exp_options.append(px)
        else:
            raise DataError("px must be datetime.timedelta or int")
    elif exat is not None:
        if isinstance(exat, datetime.datetime):
            exat = int(exat.timestamp())
        exp_options.extend(["EXAT", exat])
    elif pxat is not None:
        if isinstance(pxat, datetime.datetime):
            pxat = int(pxat.timestamp() * 1000)
        exp_options.extend(["PXAT", pxat])

    return exp_options

309 

310 

def truncate_text(txt, max_length=100):
    """
    Shorten `txt` to at most `max_length` characters via ``textwrap.shorten``.

    Text that has to be cut ends with "...".
    """
    shortened = textwrap.shorten(
        txt,
        max_length,
        placeholder="...",
        break_long_words=True,
    )
    return shortened