1from typing import Any, Union
2
3from ..exceptions import ConnectionError, InvalidResponse, ResponseError
4from ..typing import EncodableT
5from .base import _AsyncRESPBase, _RESPBase
6from .socket import SERVER_CLOSED_CONNECTION_ERROR
7
8
class _RESP2Parser(_RESPBase):
    """RESP2 protocol implementation (synchronous, pure-Python parser).

    Replies are read line-by-line from ``self._buffer``. Arrays (``*``) are
    parsed recursively, one ``_read_response`` call per element.
    """

    def read_response(self, disable_decoding=False):
        """Read and parse one complete reply.

        :param disable_decoding: when exactly ``False``, the parsed value is
            passed through ``self.encoder.decode``; otherwise it is returned
            as read.
        :returns: the parsed reply. A server error that is not a
            ``ConnectionError`` is *returned* as an exception instance rather
            than raised.
        :raises: anything ``_read_response`` raises. Before re-raising, the
            buffer is rewound to where it was when this call started, so a
            partially consumed reply can be parsed again from its beginning.
        """
        # Remember where this reply starts so that an interrupted parse
        # (timeout, KeyboardInterrupt, protocol error, ...) does not lose the
        # bytes consumed so far. BaseException is caught on purpose.
        pos = self._buffer.get_pos() if self._buffer else None
        try:
            result = self._read_response(disable_decoding=disable_decoding)
        except BaseException:
            if self._buffer:
                self._buffer.rewind(pos)
            raise
        else:
            # Whole reply parsed; purge() presumably drops the consumed bytes.
            self._buffer.purge()
            return result

    @staticmethod
    def _parse_resp_int(raw: bytes, value: bytes) -> int:
        """Convert a RESP integer or length field to ``int``.

        A non-numeric field is reported as ``InvalidResponse`` (the same
        error used for an unknown type byte) instead of leaking a bare
        ``ValueError`` from ``int()``.
        """
        try:
            return int(value)
        except ValueError as exc:
            raise InvalidResponse(f"Protocol Error: {raw!r}") from exc

    def _read_response(self, disable_decoding=False):
        """Parse a single RESP2 value from the buffer.

        :raises ConnectionError: if the buffer yields an empty line, or if
            the server error maps to a ``ConnectionError``.
        :raises InvalidResponse: on an unknown type byte, a non-numeric
            integer/length field, or a negative length other than ``-1``.
        """
        raw = self._buffer.readline()
        if not raw:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        byte, response = raw[:1], raw[1:]

        # server returned an error
        if byte == b"-":
            response = response.decode("utf-8", errors="replace")
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # simple string: the payload is the rest of the line
        elif byte == b"+":
            pass
        # integer value
        elif byte == b":":
            return self._parse_resp_int(raw, response)
        # bulk string: the header carries the payload length
        elif byte == b"$":
            length = self._parse_resp_int(raw, response)
            if length == -1:
                # null bulk string
                return None
            if length < 0:
                # RESP2 defines no negative length other than -1
                raise InvalidResponse(f"Protocol Error: {raw!r}")
            response = self._buffer.read(length)
        # array: the header carries the element count
        elif byte == b"*":
            count = self._parse_resp_int(raw, response)
            if count == -1:
                # null array
                return None
            if count < 0:
                raise InvalidResponse(f"Protocol Error: {raw!r}")
            response = [
                self._read_response(disable_decoding=disable_decoding)
                for _ in range(count)
            ]
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if disable_decoding is False:
            response = self.encoder.decode(response)
        return response
69
70
class _AsyncRESP2Parser(_AsyncRESPBase):
    """Async class for the RESP2 protocol"""

    async def read_response(self, disable_decoding: bool = False):
        """Read and parse one complete reply.

        :param disable_decoding: when exactly ``False``, the parsed value is
            passed through ``self.encoder.decode``.
        :returns: the parsed reply. A server error that is not a
            ``ConnectionError`` is *returned* as an exception instance.
        :raises ConnectionError: if the parser is not connected, or if the
            server error maps to a ``ConnectionError``.
        """
        if not self._connected:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        if self._chunks:
            # augment parsing buffer with previously read data
            self._buffer += b"".join(self._chunks)
            self._chunks.clear()
        # Always parse from the start of the buffered data. _clear() is only
        # called after a successful parse, so bytes left over from an
        # interrupted earlier call are re-parsed from the beginning.
        # (_pos is presumably the read offset used by _readline/_read.)
        self._pos = 0
        response = await self._read_response(disable_decoding=disable_decoding)
        # Successfully parsing a response allows us to clear our parsing buffer
        self._clear()
        return response

    @staticmethod
    def _parse_resp_int(raw: bytes, value: bytes) -> int:
        """Convert a RESP integer or length field to ``int``.

        A non-numeric field is reported as ``InvalidResponse`` (the same
        error used for an unknown type byte) instead of leaking a bare
        ``ValueError`` from ``int()``.
        """
        try:
            return int(value)
        except ValueError as exc:
            raise InvalidResponse(f"Protocol Error: {raw!r}") from exc

    async def _read_response(
        self, disable_decoding: bool = False
    ) -> Union[EncodableT, ResponseError, None]:
        """Parse a single RESP2 value, awaiting more data as needed.

        :raises InvalidResponse: on an unknown type byte, a non-numeric
            integer/length field, or a negative length other than ``-1``.
        """
        raw = await self._readline()
        response: Any
        byte, response = raw[:1], raw[1:]

        # server returned an error
        if byte == b"-":
            response = response.decode("utf-8", errors="replace")
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                self._clear()  # Successful parse
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # simple string: the payload is the rest of the line
        elif byte == b"+":
            pass
        # integer value
        elif byte == b":":
            return self._parse_resp_int(raw, response)
        # bulk string: the header carries the payload length
        elif byte == b"$":
            length = self._parse_resp_int(raw, response)
            if length == -1:
                # null bulk string
                return None
            if length < 0:
                # RESP2 defines no negative length other than -1
                raise InvalidResponse(f"Protocol Error: {raw!r}")
            response = await self._read(length)
        # array: the header carries the element count
        elif byte == b"*":
            count = self._parse_resp_int(raw, response)
            if count == -1:
                # null array
                return None
            if count < 0:
                raise InvalidResponse(f"Protocol Error: {raw!r}")
            response = [
                (await self._read_response(disable_decoding))
                for _ in range(count)
            ]
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if disable_decoding is False:
            response = self.encoder.decode(response)
        return response