1import asyncio
2from contextlib import suppress
3from typing import Any, Optional, Tuple, Union
4
5from .base_protocol import BaseProtocol
6from .client_exceptions import (
7 ClientConnectionError,
8 ClientOSError,
9 ClientPayloadError,
10 ServerDisconnectedError,
11 SocketTimeoutError,
12)
13from .helpers import (
14 _EXC_SENTINEL,
15 EMPTY_BODY_STATUS_CODES,
16 BaseTimerContext,
17 set_exception,
18 set_result,
19)
20from .http import HttpResponseParser, RawResponseMessage
21from .http_exceptions import HttpProcessingError
22from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader
23
24
class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamReader]]):
    """Helper class to adapt between Protocol and StreamReader.

    Acts both as the asyncio protocol of a client connection (through
    ``BaseProtocol``) and as a ``DataQueue`` of
    ``(RawResponseMessage, StreamReader)`` pairs.

    Bytes arriving in :meth:`data_received` go to an ``HttpResponseParser``,
    which :meth:`set_response_params` installs. Every parsed response head
    is pushed onto the queue together with its body stream. After
    :meth:`set_parser` installs a custom payload parser (currently always a
    WebSocket reader), that parser consumes the incoming bytes instead.

    An optional read timeout (``read_timeout``) works as follows:

    * It is re-armed on every chunk received.
    * It is cancelled while reading is paused and once a response body
      reaches EOF.
    * When it fires, :meth:`_on_read_timeout` fails the queue and the
      current payload with ``SocketTimeoutError``.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        """Initialise protocol and queue state bound to *loop*."""
        BaseProtocol.__init__(self, loop=loop)
        DataQueue.__init__(self, loop)

        # Forced/requested close flag; one of the inputs of ``should_close``.
        self._should_close = False

        # Body stream of the response currently being read, if any.
        self._payload: Optional[StreamReader] = None
        # When True, parsed responses are queued with EMPTY_PAYLOAD and the
        # HTTP parser is told the response has no body.
        self._skip_payload = False
        # Custom parser installed by set_parser(). While set, it receives
        # all incoming bytes instead of the HTTP parser.
        self._payload_parser = None

        # NOTE(review): never read in this class; the HTTP parser gets its
        # timer via set_response_params(). Possibly vestigial — verify.
        self._timer = None

        # Bytes that arrived while no parser could consume them (no HTTP
        # parser yet, or connection upgraded without a payload parser).
        # Replayed by set_parser() / set_response_params().
        self._tail = b""
        # Set from the HTTP parser's result in data_received().
        self._upgraded = False
        self._parser: Optional[HttpResponseParser] = None

        # Seconds of read inactivity allowed; a falsy value disables the timer.
        self._read_timeout: Optional[float] = None
        # Pending call_later() handle for _on_read_timeout, if armed.
        self._read_timeout_handle: Optional[asyncio.TimerHandle] = None

        # NOTE(review): stored here and in set_response_params() but not
        # read by this class; presumably consumed elsewhere — verify.
        self._timeout_ceil_threshold: Optional[float] = 5

        # Lazily created by the ``closed`` property; resolved or failed in
        # connection_lost().
        self._closed: Union[None, asyncio.Future[None]] = None
        # Lets ``closed`` avoid creating a future nobody can ever resolve.
        self._connection_lost_called = False

    @property
    def closed(self) -> Union[None, asyncio.Future[None]]:
        """Future that is set when the connection is closed.

        This property returns a Future that will be completed when the connection
        is closed. The Future is created lazily on first access to avoid creating
        futures that will never be awaited.

        Returns:
            - A Future[None] if the connection is still open or was closed after
              this property was accessed
            - None if connection_lost() was already called before this property
              was ever accessed (indicating no one is waiting for the closure)
        """
        if self._closed is None and not self._connection_lost_called:
            self._closed = self._loop.create_future()
        return self._closed

    @property
    def upgraded(self) -> bool:
        """Whether the HTTP parser reported a protocol upgrade."""
        return self._upgraded

    @property
    def should_close(self) -> bool:
        """Return True if the underlying connection should be closed.

        This is the case when any of the following holds:

        * A close was forced or requested (``force_close()``,
          ``set_exception()``, a ``Connection: close`` style response, or
          connection loss).
        * The current body stream has not reached EOF.
        * The connection was upgraded.
        * An exception is recorded.
        * A custom payload parser is active.
        * Queued items (``_buffer``) or unparsed bytes (``_tail``) remain.
        """
        return bool(
            self._should_close
            or (self._payload is not None and not self._payload.is_eof())
            or self._upgraded
            or self._exception is not None
            or self._payload_parser is not None
            or self._buffer
            or self._tail
        )

    def force_close(self) -> None:
        """Mark the connection so that ``should_close`` reports True."""
        self._should_close = True

    def close(self) -> None:
        """Close the transport gracefully and drop per-connection state.

        A no-op for the transport if it is already detached.
        """
        self._exception = None  # Break cyclic references
        transport = self.transport
        if transport is not None:
            transport.close()
            self.transport = None
            self._payload = None
            self._drop_timeout()

    def abort(self) -> None:
        """Abort the transport immediately and drop per-connection state.

        Same as :meth:`close` but uses ``transport.abort()``.
        """
        self._exception = None  # Break cyclic references
        transport = self.transport
        if transport is not None:
            transport.abort()
            self.transport = None
            self._payload = None
            self._drop_timeout()

    def is_connected(self) -> bool:
        """Return True if a transport is attached and it is not closing."""
        return self.transport is not None and not self.transport.is_closing()

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Handle loss of the underlying connection.

        Steps, in order:

        1. Resolve (clean close) or fail (error) the ``closed`` future, if
           anyone created it.
        2. Feed EOF to the custom payload parser (errors are suppressed) and
           to the HTTP parser. If the HTTP parser raises, fail the current
           payload with ``ClientPayloadError``.
        3. If the queue has not reached EOF, record an error on it:
           ``ClientOSError`` for an ``OSError``, ``ServerDisconnectedError``
           for a clean close, otherwise *exc* itself.
        4. Reset per-connection state and call ``super().connection_lost()``
           with the possibly-translated exception.

        :param exc: the error that closed the connection, or ``None`` if it
            was closed cleanly.
        """
        self._connection_lost_called = True
        self._drop_timeout()

        original_connection_error = exc
        # Exception propagated to the base classes; may be replaced below
        # by a client-specific wrapper.
        reraised_exc = original_connection_error

        connection_closed_cleanly = original_connection_error is None

        if self._closed is not None:
            # If someone is waiting for the closed future,
            # we should set it to None or an exception. If
            # self._closed is None, it means that
            # connection_lost() was called already
            # or nobody is waiting for it.
            if connection_closed_cleanly:
                set_result(self._closed, None)
            else:
                assert original_connection_error is not None
                set_exception(
                    self._closed,
                    ClientConnectionError(
                        f"Connection lost: {original_connection_error !s}",
                    ),
                    original_connection_error,
                )

        if self._payload_parser is not None:
            with suppress(Exception):  # FIXME: log this somehow?
                self._payload_parser.feed_eof()

        # Whatever the HTTP parser returns from feed_eof(); used as the
        # argument of ServerDisconnectedError below.
        uncompleted = None
        if self._parser is not None:
            try:
                uncompleted = self._parser.feed_eof()
            except Exception as underlying_exc:
                # The parser could not finish the current message: surface
                # that to whoever is reading the body.
                if self._payload is not None:
                    client_payload_exc_msg = (
                        f"Response payload is not completed: {underlying_exc !r}"
                    )
                    if not connection_closed_cleanly:
                        client_payload_exc_msg = (
                            f"{client_payload_exc_msg !s}. "
                            f"{original_connection_error !r}"
                        )
                    set_exception(
                        self._payload,
                        ClientPayloadError(client_payload_exc_msg),
                        underlying_exc,
                    )

        if not self.is_eof():
            # Translate the loss into a client-level exception for waiters
            # on the queue.
            if isinstance(original_connection_error, OSError):
                reraised_exc = ClientOSError(*original_connection_error.args)
            if connection_closed_cleanly:
                reraised_exc = ServerDisconnectedError(uncompleted)
            # assigns self._should_close to True as side effect,
            # we do it anyway below
            underlying_non_eof_exc = (
                _EXC_SENTINEL
                if connection_closed_cleanly
                else original_connection_error
            )
            assert underlying_non_eof_exc is not None
            assert reraised_exc is not None
            self.set_exception(reraised_exc, underlying_non_eof_exc)

        self._should_close = True
        self._parser = None
        self._payload = None
        self._payload_parser = None
        # NOTE(review): flag presumably owned by BaseProtocol; reset because
        # the transport is gone.
        self._reading_paused = False

        super().connection_lost(reraised_exc)

    def eof_received(self) -> None:
        """Handle the peer's end-of-stream; currently only cancels the read timeout."""
        # should call parser.feed_eof() most likely
        self._drop_timeout()

    def pause_reading(self) -> None:
        """Pause reading and cancel the read timeout while paused."""
        super().pause_reading()
        self._drop_timeout()

    def resume_reading(self) -> None:
        """Resume reading and re-arm the read timeout."""
        super().resume_reading()
        self._reschedule_timeout()

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: BaseException = _EXC_SENTINEL,
    ) -> None:
        """Record *exc* on the queue.

        Also marks the connection for closing and cancels the read timeout.

        :param exc: exception delivered to readers of this queue.
        :param exc_cause: underlying cause; ``_EXC_SENTINEL`` means none.
        """
        self._should_close = True
        self._drop_timeout()
        super().set_exception(exc, exc_cause)

    def set_parser(self, parser: Any, payload: Any) -> None:
        """Install a custom payload parser (e.g. after a WebSocket upgrade).

        From now on, incoming bytes go to *parser* instead of the HTTP
        parser. The read timeout is cancelled, and any bytes buffered in
        ``_tail`` are replayed through :meth:`data_received` right away.
        """
        # TODO: actual types are:
        #   parser: WebSocketReader
        #   payload: WebSocketDataQueue
        # but they are not generic enough
        # Need an ABC for both types
        self._payload = payload
        self._payload_parser = parser

        self._drop_timeout()

        if self._tail:
            data, self._tail = self._tail, b""
            self.data_received(data)

    def set_response_params(
        self,
        *,
        timer: Optional[BaseTimerContext] = None,
        skip_payload: bool = False,
        read_until_eof: bool = False,
        auto_decompress: bool = True,
        read_timeout: Optional[float] = None,
        read_bufsize: int = 2**16,
        timeout_ceil_threshold: float = 5,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
    ) -> None:
        """Configure response handling and create the HTTP response parser.

        Any bytes already buffered in ``_tail`` are replayed through the new
        parser.

        :param timer: forwarded to ``HttpResponseParser``.
        :param skip_payload: treat responses as body-less. The parser gets
            ``response_with_body=False`` and messages are queued with
            ``EMPTY_PAYLOAD``.
        :param read_until_eof: forwarded to ``HttpResponseParser``.
        :param auto_decompress: forwarded to ``HttpResponseParser``.
        :param read_timeout: seconds of read inactivity before
            ``SocketTimeoutError``; falsy disables the timer.
        :param read_bufsize: forwarded positionally to ``HttpResponseParser``.
        :param timeout_ceil_threshold: stored on the instance (not used by
            this class).
        :param max_line_size: forwarded to ``HttpResponseParser``.
        :param max_field_size: forwarded to ``HttpResponseParser``.
        """
        self._skip_payload = skip_payload

        self._read_timeout = read_timeout

        self._timeout_ceil_threshold = timeout_ceil_threshold

        self._parser = HttpResponseParser(
            self,
            self._loop,
            read_bufsize,
            timer=timer,
            payload_exception=ClientPayloadError,
            response_with_body=not skip_payload,
            read_until_eof=read_until_eof,
            auto_decompress=auto_decompress,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
        )

        if self._tail:
            data, self._tail = self._tail, b""
            self.data_received(data)

    def _drop_timeout(self) -> None:
        """Cancel and forget the pending read-timeout handle, if any."""
        if self._read_timeout_handle is not None:
            self._read_timeout_handle.cancel()
            self._read_timeout_handle = None

    def _reschedule_timeout(self) -> None:
        """(Re)arm the read timeout from the current ``read_timeout`` value.

        Any pending handle is cancelled first. A falsy ``read_timeout``
        leaves no timer armed.
        """
        timeout = self._read_timeout
        if self._read_timeout_handle is not None:
            self._read_timeout_handle.cancel()

        if timeout:
            self._read_timeout_handle = self._loop.call_later(
                timeout, self._on_read_timeout
            )
        else:
            self._read_timeout_handle = None

    def start_timeout(self) -> None:
        """Arm the read timeout (see :meth:`_reschedule_timeout`)."""
        self._reschedule_timeout()

    @property
    def read_timeout(self) -> Optional[float]:
        """Seconds of read inactivity allowed; falsy means no read timeout."""
        return self._read_timeout

    @read_timeout.setter
    def read_timeout(self, read_timeout: Optional[float]) -> None:
        # Only stores the value. An already-armed timer keeps its old delay;
        # the new value takes effect on the next reschedule.
        self._read_timeout = read_timeout

    def _on_read_timeout(self) -> None:
        """Fail the queue and the current payload with ``SocketTimeoutError``."""
        exc = SocketTimeoutError("Timeout on reading data from socket")
        self.set_exception(exc)
        if self._payload is not None:
            set_exception(self._payload, exc)

    def data_received(self, data: bytes) -> None:
        """Route bytes from the transport to the active parser.

        Every call re-arms the read timeout, even for empty *data*. The
        bytes go to exactly one destination, checked in this order:

        1. The custom payload parser from :meth:`set_parser`, if one is
           installed.
        2. The ``_tail`` buffer, if the connection is upgraded or no HTTP
           parser exists yet.
        3. Otherwise the HTTP response parser. Each parsed
           ``(message, payload)`` pair is pushed onto this queue.

        If the HTTP parser raises, the transport is closed and an
        ``HttpProcessingError`` is recorded on the queue.
        """
        self._reschedule_timeout()

        if not data:
            return

        # custom payload parser - currently always WebSocketReader
        if self._payload_parser is not None:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                # The custom parser is finished; hand any leftover bytes
                # back through the normal routing.
                self._payload = None
                self._payload_parser = None

                if tail:
                    self.data_received(tail)
            return

        if self._upgraded or self._parser is None:
            # i.e. websocket connection, websocket parser is not set yet
            self._tail += data
            return

        # parse http messages
        try:
            messages, upgraded, tail = self._parser.feed_data(data)
        except BaseException as underlying_exc:
            if self.transport is not None:
                # connection.release() could be called BEFORE
                # data_received(), the transport is already
                # closed in this case
                self.transport.close()
            # should_close is True after the call
            if isinstance(underlying_exc, HttpProcessingError):
                # Re-create the error so the queued exception carries the
                # same code/message/headers, with the original as its cause.
                exc = HttpProcessingError(
                    code=underlying_exc.code,
                    message=underlying_exc.message,
                    headers=underlying_exc.headers,
                )
            else:
                exc = HttpProcessingError()
            self.set_exception(exc, underlying_exc)
            return

        self._upgraded = upgraded

        # After the loop, ``payload`` holds the last message's body stream
        # (or None if no complete message was parsed).
        payload: Optional[StreamReader] = None
        for message, payload in messages:
            if message.should_close:
                self._should_close = True

            self._payload = payload

            if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES:
                self.feed_data((message, EMPTY_PAYLOAD), 0)
            else:
                self.feed_data((message, payload), 0)

        if payload is not None:
            # new message(s) was processed
            # register timeout handler unsubscribing
            # either on end-of-stream or immediately for
            # EMPTY_PAYLOAD
            if payload is not EMPTY_PAYLOAD:
                payload.on_eof(self._drop_timeout)
            else:
                self._drop_timeout()

        if upgraded and tail:
            # Bytes after the upgrade response; with _upgraded now set they
            # are routed to the custom parser or buffered in _tail.
            self.data_received(tail)