Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/redis/_parsers/resp3.py: 14%

162 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-23 06:16 +0000

1from logging import getLogger 

2from typing import Any, Union 

3 

4from ..exceptions import ConnectionError, InvalidResponse, ResponseError 

5from ..typing import EncodableT 

6from .base import _AsyncRESPBase, _RESPBase 

7from .socket import SERVER_CLOSED_CONNECTION_ERROR 

8 

# Values compared against the first element of a RESP3 push frame (see
# handle_push_response) to route the frame to the invalidation handler instead
# of the pub/sub handler. Both a bytes and a str form are listed because push
# frame elements are parsed with the caller's disable_decoding setting, so the
# marker may arrive either raw or decoded.
_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]

10 

11 

class _RESP3Parser(_RESPBase):
    """RESP3 protocol implementation"""

    def __init__(self, socket_read_size):
        super().__init__(socket_read_size)
        # Handler for push frames that are not invalidation messages.
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        # Handler for invalidation push frames. This name must match the one read
        # in handle_push_response() and written by set_invalidation_push_handler().
        # It was previously initialised under a misspelled name
        # ("invalidations_..."), so an invalidation frame arriving before a
        # handler was installed raised AttributeError.
        self.invalidation_push_handler_func = None

    def handle_pubsub_push_response(self, response):
        """Default push handler: log the parsed frame and return it unchanged."""
        logger = getLogger("push_response")
        # Lazy %-formatting: the message is only rendered if INFO is enabled.
        logger.info("Push response: %s", response)
        return response

    def read_response(self, disable_decoding=False, push_request=False):
        """Read and parse one complete reply from the buffered socket.

        If parsing raises (including BaseException such as KeyboardInterrupt),
        the buffer is rewound to the position recorded before parsing started
        and the exception is re-raised. On success, the consumed data is purged
        from the buffer.
        """
        pos = self._buffer.get_pos() if self._buffer else None
        try:
            result = self._read_response(
                disable_decoding=disable_decoding, push_request=push_request
            )
        except BaseException:
            if self._buffer:
                self._buffer.rewind(pos)
            raise
        else:
            self._buffer.purge()
            return result

    def _read_response(self, disable_decoding=False, push_request=False):
        """Parse a single RESP3 value, recursing into aggregate types.

        The first byte of the line selects the type. Bytes results are decoded
        with ``self.encoder`` unless ``disable_decoding`` is True. Error replies
        are returned as exception instances, except ConnectionError, which is
        raised.

        Raises:
            ConnectionError: the server closed the connection, or an error
                reply parsed into a ConnectionError.
            InvalidResponse: the type byte is not recognised.
        """
        raw = self._buffer.readline()
        if not raw:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        byte, response = raw[:1], raw[1:]

        # server returned an error: "-" simple error, "!" bulk (length-prefixed) error
        if byte in (b"-", b"!"):
            if byte == b"!":
                response = self._buffer.read(int(response))
            response = response.decode("utf-8", errors="replace")
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # simple string: the rest of the line is the value
        elif byte == b"+":
            pass
        # null value
        elif byte == b"_":
            return None
        # int and big int values
        elif byte in (b":", b"("):
            return int(response)
        # double value
        elif byte == b",":
            return float(response)
        # bool value
        elif byte == b"#":
            return response == b"t"
        # bulk response
        elif byte == b"$":
            response = self._buffer.read(int(response))
        # verbatim string response: drop the leading 4-byte format prefix
        # (e.g. b"txt:")
        elif byte == b"=":
            response = self._buffer.read(int(response))[4:]
        # array response
        elif byte == b"*":
            response = [
                self._read_response(disable_decoding=disable_decoding)
                for _ in range(int(response))
            ]
        # set response
        elif byte == b"~":
            # redis can return unhashable types (like dict) in a set,
            # so we need to first convert to a list, and then try to convert it to a set
            response = [
                self._read_response(disable_decoding=disable_decoding)
                for _ in range(int(response))
            ]
            try:
                response = set(response)
            except TypeError:
                pass
        # map response
        elif byte == b"%":
            # We cannot use a dict-comprehension to parse stream.
            # Evaluation order of key:val expression in dict comprehension only
            # became defined to be left-right in version 3.8
            resp_dict = {}
            for _ in range(int(response)):
                key = self._read_response(disable_decoding=disable_decoding)
                resp_dict[key] = self._read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            response = resp_dict
        # push response
        elif byte == b">":
            response = [
                self._read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
                for _ in range(int(response))
            ]
            response = self.handle_push_response(
                response, disable_decoding, push_request
            )
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if isinstance(response, bytes) and disable_decoding is False:
            response = self.encoder.decode(response)
        return response

    def handle_push_response(self, response, disable_decoding, push_request):
        """Dispatch a parsed push frame to the matching handler.

        Frames whose first element is in ``_INVALIDATION_MESSAGE`` go to the
        invalidation handler. If no invalidation handler is installed, the frame
        is dropped and the handler result is None. All other frames, including
        an empty frame, go to the pub/sub handler.

        If ``push_request`` is False, the frame is treated as out-of-band: after
        dispatching it, the next reply is read and returned in its place.
        Otherwise the handler's result is returned.
        """
        if response and response[0] in _INVALIDATION_MESSAGE:
            handler = self.invalidation_push_handler_func
            res = handler(response) if handler is not None else None
        else:
            res = self.pubsub_push_handler_func(response)
        if not push_request:
            return self._read_response(
                disable_decoding=disable_decoding, push_request=push_request
            )
        return res

    def set_pubsub_push_handler(self, pubsub_push_handler_func):
        """Install the callable that receives non-invalidation push frames."""
        self.pubsub_push_handler_func = pubsub_push_handler_func

    def set_invalidation_push_handler(self, invalidations_push_handler_func):
        """Install the callable that receives invalidation push frames."""
        self.invalidation_push_handler_func = invalidations_push_handler_func

147 

148 

class _AsyncRESP3Parser(_AsyncRESPBase):
    """RESP3 protocol implementation for asyncio-based connections."""

    def __init__(self, socket_read_size):
        super().__init__(socket_read_size)
        # Handler for push frames that are not invalidation messages.
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        # Handler for invalidation push frames. This name must match the one read
        # in handle_push_response() and written by set_invalidation_push_handler().
        # It was previously initialised under a misspelled name
        # ("invalidations_..."), so an invalidation frame arriving before a
        # handler was installed raised AttributeError.
        self.invalidation_push_handler_func = None

    def handle_pubsub_push_response(self, response):
        """Default push handler: log the parsed frame and return it unchanged."""
        logger = getLogger("push_response")
        # Lazy %-formatting: the message is only rendered if INFO is enabled.
        logger.info("Push response: %s", response)
        return response

    async def read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ):
        """Read and parse one complete reply from the stream.

        Previously buffered chunks are appended to the parse buffer first, and
        parsing restarts at position 0. The buffer is cleared only after a reply
        has been parsed successfully.
        """
        if self._chunks:
            # augment parsing buffer with previously read data
            self._buffer += b"".join(self._chunks)
            self._chunks.clear()
        self._pos = 0
        response = await self._read_response(
            disable_decoding=disable_decoding, push_request=push_request
        )
        # Successfully parsing a response allows us to clear our parsing buffer
        self._clear()
        return response

    async def _read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ) -> Union[EncodableT, ResponseError, None]:
        """Parse a single RESP3 value, recursing into aggregate types.

        The first byte of the line selects the type. Bytes results are decoded
        with ``self.encoder`` unless ``disable_decoding`` is True. Error replies
        are returned as exception instances, except ConnectionError, which is
        raised.

        Raises:
            ConnectionError: there is no stream or encoder, or an error reply
                parsed into a ConnectionError.
            InvalidResponse: the type byte is not recognised.
        """
        if not self._stream or not self.encoder:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        raw = await self._readline()
        response: Any
        byte, response = raw[:1], raw[1:]

        # server returned an error: "-" simple error, "!" bulk (length-prefixed) error
        if byte in (b"-", b"!"):
            if byte == b"!":
                response = await self._read(int(response))
            response = response.decode("utf-8", errors="replace")
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                # the error frame was fully consumed; reset the parse buffer first
                self._clear()
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # simple string: the rest of the line is the value
        elif byte == b"+":
            pass
        # null value
        elif byte == b"_":
            return None
        # int and big int values
        elif byte in (b":", b"("):
            return int(response)
        # double value
        elif byte == b",":
            return float(response)
        # bool value
        elif byte == b"#":
            return response == b"t"
        # bulk response
        elif byte == b"$":
            response = await self._read(int(response))
        # verbatim string response: drop the leading 4-byte format prefix
        # (e.g. b"txt:")
        elif byte == b"=":
            response = (await self._read(int(response)))[4:]
        # array response
        elif byte == b"*":
            response = [
                (await self._read_response(disable_decoding=disable_decoding))
                for _ in range(int(response))
            ]
        # set response
        elif byte == b"~":
            # redis can return unhashable types (like dict) in a set,
            # so we need to first convert to a list, and then try to convert it to a set
            response = [
                (await self._read_response(disable_decoding=disable_decoding))
                for _ in range(int(response))
            ]
            try:
                response = set(response)
            except TypeError:
                pass
        # map response
        elif byte == b"%":
            # We cannot use a dict-comprehension to parse stream.
            # Evaluation order of key:val expression in dict comprehension only
            # became defined to be left-right in version 3.8
            resp_dict = {}
            for _ in range(int(response)):
                key = await self._read_response(disable_decoding=disable_decoding)
                resp_dict[key] = await self._read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            response = resp_dict
        # push response
        elif byte == b">":
            response = [
                (
                    await self._read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
                )
                for _ in range(int(response))
            ]
            response = await self.handle_push_response(
                response, disable_decoding, push_request
            )
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if isinstance(response, bytes) and disable_decoding is False:
            response = self.encoder.decode(response)
        return response

    async def handle_push_response(self, response, disable_decoding, push_request):
        """Dispatch a parsed push frame to the matching handler.

        Frames whose first element is in ``_INVALIDATION_MESSAGE`` go to the
        invalidation handler. If no invalidation handler is installed, the frame
        is dropped and the handler result is None. All other frames, including
        an empty frame, go to the pub/sub handler. Handlers are called
        synchronously; they are not awaited.

        If ``push_request`` is False, the frame is treated as out-of-band: after
        dispatching it, the next reply is read and returned in its place.
        Otherwise the handler's result is returned.
        """
        if response and response[0] in _INVALIDATION_MESSAGE:
            handler = self.invalidation_push_handler_func
            res = handler(response) if handler is not None else None
        else:
            res = self.pubsub_push_handler_func(response)
        if not push_request:
            return await self._read_response(
                disable_decoding=disable_decoding, push_request=push_request
            )
        return res

    def set_pubsub_push_handler(self, pubsub_push_handler_func):
        """Install the callable that receives non-invalidation push frames."""
        self.pubsub_push_handler_func = pubsub_push_handler_func

    def set_invalidation_push_handler(self, invalidations_push_handler_func):
        """Install the callable that receives invalidation push frames."""
        self.invalidation_push_handler_func = invalidations_push_handler_func