1"""Http related parsers and protocol."""

import asyncio
import sys
from typing import (  # noqa
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Union,
)

from multidict import CIMultiDict

from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor
from .helpers import NO_EXTENSIONS

__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")


MIN_PAYLOAD_FOR_WRITELINES = 2048
IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2)
IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9
# writelines is not safe for use
# on Python 3.12+ until 3.12.9
# on Python 3.13+ until 3.13.2
# and on older versions it is not any faster than write
# CVE-2024-12254: https://github.com/python/cpython/pull/127656


class HttpVersion(NamedTuple):
    major: int
    minor: int


HttpVersion10 = HttpVersion(1, 0)
HttpVersion11 = HttpVersion(1, 1)
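# HttpVersion is a NamedTuple, so instances compare like plain tuples;
# e.g. HttpVersion11 > HttpVersion10 is True, which lets callers gate
# version-dependent behaviour with simple ordering checks.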


_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]


class StreamWriter(AbstractStreamWriter):
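    """Write an HTTP message (status line, headers and body) to the transport.

    A minimal usage sketch (assuming ``protocol`` is a connected BaseProtocol):

        writer = StreamWriter(protocol, loop)
        await writer.write_headers("HTTP/1.1 200 OK", headers)
        await writer.write(b"body")
        await writer.write_eof()
    """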

    length: Optional[int] = None
    chunked: bool = False
    _eof: bool = False
    _compress: Optional[ZLibCompressor] = None

    def __init__(
        self,
        protocol: BaseProtocol,
        loop: asyncio.AbstractEventLoop,
        on_chunk_sent: _T_OnChunkSent = None,
        on_headers_sent: _T_OnHeadersSent = None,
    ) -> None:
        self._protocol = protocol
        self.loop = loop
        self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
        self._on_headers_sent: _T_OnHeadersSent = on_headers_sent
        self._headers_buf: Optional[bytes] = None
        self._headers_written: bool = False

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        return self._protocol.transport

    @property
    def protocol(self) -> BaseProtocol:
        return self._protocol

    def enable_chunking(self) -> None:
        self.chunked = True

    def enable_compression(
        self, encoding: str = "deflate", strategy: Optional[int] = None
    ) -> None:
        self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)

    def _write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
        size = len(chunk)
        self.buffer_size += size
        self.output_size += size
        transport = self._protocol.transport
        if transport is None or transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")
        transport.write(chunk)

    def _writelines(self, chunks: Iterable[bytes]) -> None:
        size = 0
        for chunk in chunks:
            size += len(chunk)
        self.buffer_size += size
        self.output_size += size
        transport = self._protocol.transport
        if transport is None or transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")
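        # Small payloads are joined and sent with a single write(); larger ones
        # go through writelines() to avoid an extra copy, except on interpreters
        # affected by CVE-2024-12254 (see the module-level note above).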
        if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
            transport.write(b"".join(chunks))
        else:
            transport.writelines(chunks)

    def _write_chunked_payload(
        self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
    ) -> None:
        """Write a chunk with proper chunked encoding."""
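        # Chunked transfer encoding frames each chunk as
        #   <size in hex>\r\n<data>\r\n
        # e.g. b"hello" is sent as b"5\r\nhello\r\n"; the terminating
        # b"0\r\n\r\n" marker is written at EOF (write_eof()/set_eof()).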
        chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
        self._writelines((chunk_len_pre, chunk, b"\r\n"))

    def _send_headers_with_payload(
        self,
        chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"],
        is_eof: bool,
    ) -> None:
        """Send buffered headers with payload, coalescing into a single write."""
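        # Example of the fully coalesced chunked EOF case: the serialized
        # headers, b"5\r\n", b"hello" and b"\r\n0\r\n\r\n" are passed to one
        # _writelines() call, so a small response can go out in a single write.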
        # Mark headers as written
        self._headers_written = True
        headers_buf = self._headers_buf
        self._headers_buf = None

        if TYPE_CHECKING:
            # Safe because callers (write() and write_eof()) only invoke this method
            # after checking that self._headers_buf is truthy
            assert headers_buf is not None

        if not self.chunked:
            # Non-chunked: coalesce headers with body
            if chunk:
                self._writelines((headers_buf, chunk))
            else:
                self._write(headers_buf)
            return

        # Coalesce headers with chunked data
        if chunk:
            chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
            if is_eof:
                self._writelines((headers_buf, chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
            else:
                self._writelines((headers_buf, chunk_len_pre, chunk, b"\r\n"))
        elif is_eof:
            self._writelines((headers_buf, b"0\r\n\r\n"))
        else:
            self._write(headers_buf)

    async def write(
        self,
        chunk: Union[bytes, bytearray, memoryview],
        *,
        drain: bool = True,
        LIMIT: int = 0x10000,
    ) -> None:
        """
        Write a chunk of data to the stream.

        write_eof() indicates the end of the stream; the writer cannot be
        used after write_eof() has been called.  write() drains the
        transport buffer once it grows past LIMIT.
        """
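        # LIMIT (64 KiB by default) bounds how much data may accumulate in
        # self.buffer_size before drain() is awaited; pass drain=False to
        # defer flushing to a later call.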
        if self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        if isinstance(chunk, memoryview):
            if chunk.nbytes != len(chunk):
                # reshape to a flat byte view so len(chunk) matches its byte size
                chunk = chunk.cast("c")

        if self._compress is not None:
            chunk = await self._compress.compress(chunk)
            if not chunk:
                return

        if self.length is not None:
            chunk_len = len(chunk)
            if self.length >= chunk_len:
                self.length = self.length - chunk_len
            else:
                chunk = chunk[: self.length]
                self.length = 0
                if not chunk:
                    return

        # Handle buffered headers for small payload optimization
        if self._headers_buf and not self._headers_written:
            self._send_headers_with_payload(chunk, False)
            if drain and self.buffer_size > LIMIT:
                self.buffer_size = 0
                await self.drain()
            return

        if chunk:
            if self.chunked:
                self._write_chunked_payload(chunk)
            else:
                self._write(chunk)

            if drain and self.buffer_size > LIMIT:
                self.buffer_size = 0
                await self.drain()

    async def write_headers(
        self, status_line: str, headers: "CIMultiDict[str]"
    ) -> None:
        """Write headers to the stream (buffered until the first body write)."""
        if self._on_headers_sent is not None:
            await self._on_headers_sent(headers)
        # status + headers
        buf = _serialize_headers(status_line, headers)
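        # The serialized block is buffered so it can be coalesced with the
        # first body chunk (or the EOF marker) into a single transport write;
        # write(), send_headers(), set_eof() and write_eof() all flush it.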
        self._headers_written = False
        self._headers_buf = buf

    def send_headers(self) -> None:
        """Force sending buffered headers if not already sent."""
        if not self._headers_buf or self._headers_written:
            return

        self._headers_written = True
        headers_buf = self._headers_buf
        self._headers_buf = None

        if TYPE_CHECKING:
            # Safe because we only enter this block when self._headers_buf is truthy
            assert headers_buf is not None

        self._write(headers_buf)

    def set_eof(self) -> None:
        """Indicate that the message is complete."""
        if self._eof:
            return

        # If headers haven't been sent yet, send them now
        # This handles the case where there's no body at all
        if self._headers_buf and not self._headers_written:
            self._headers_written = True
            headers_buf = self._headers_buf
            self._headers_buf = None

            if TYPE_CHECKING:
                # Safe because we only enter this block when self._headers_buf is truthy
                assert headers_buf is not None

            # Combine headers and chunked EOF marker in a single write
            if self.chunked:
                self._writelines((headers_buf, b"0\r\n\r\n"))
            else:
                self._write(headers_buf)
        elif self.chunked and self._headers_written:
            # Headers already sent, just send the final chunk marker
            self._write(b"0\r\n\r\n")

        self._eof = True

    async def write_eof(self, chunk: bytes = b"") -> None:
        if self._eof:
            return

        if chunk and self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        # Handle body/compression
        if self._compress:
            chunks: List[bytes] = []
            chunks_len = 0
            if chunk and (compressed_chunk := await self._compress.compress(chunk)):
                chunks_len = len(compressed_chunk)
                chunks.append(compressed_chunk)

            flush_chunk = self._compress.flush()
            chunks_len += len(flush_chunk)
            chunks.append(flush_chunk)
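            # flush() emits any buffered compressed data plus the stream
            # trailer, so there is always a final payload even for an empty body.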
            assert chunks_len

            # Send buffered headers with compressed data if not yet sent
            if self._headers_buf and not self._headers_written:
                self._headers_written = True
                headers_buf = self._headers_buf
                self._headers_buf = None

                if self.chunked:
                    # Coalesce headers with compressed chunked data
                    chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
                    self._writelines(
                        (headers_buf, chunk_len_pre, *chunks, b"\r\n0\r\n\r\n")
                    )
                else:
                    # Coalesce headers with compressed data
                    self._writelines((headers_buf, *chunks))
                await self.drain()
                self._eof = True
                return

            # Headers already sent, just write compressed data
            if self.chunked:
                chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
                self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
            elif len(chunks) > 1:
                self._writelines(chunks)
            else:
                self._write(chunks[0])
            await self.drain()
            self._eof = True
            return

        # No compression - send buffered headers if not yet sent
        if self._headers_buf and not self._headers_written:
            # Use helper to send headers with payload
            self._send_headers_with_payload(chunk, True)
            await self.drain()
            self._eof = True
            return

        # Handle remaining body
        if self.chunked:
            if chunk:
                # Write final chunk with EOF marker
                self._writelines(
                    (f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n0\r\n\r\n")
                )
            else:
                self._write(b"0\r\n\r\n")
            await self.drain()
            self._eof = True
            return

        if chunk:
            self._write(chunk)
            await self.drain()

        self._eof = True

    async def drain(self) -> None:
        """Flush the write buffer.

        The intended use is to write

          await w.write(data)
          await w.drain()
        """
        protocol = self._protocol
        if protocol.transport is not None and protocol._paused:
            await protocol._drain_helper()


def _safe_header(string: str) -> str:
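    # Reject values such as "bad\r\nSet-Cookie: x=y" that would smuggle an
    # extra header line into the serialized block (header injection).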
    if "\r" in string or "\n" in string:
        raise ValueError(
            "Newline or carriage return detected in headers. "
            "Potential header injection attack."
        )
    return string


def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
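    # Example: a status line "HTTP/1.1 200 OK" with a single header
    # "Content-Length: 5" serializes to
    # b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n".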
    headers_gen = (_safe_header(k) + ": " + _safe_header(v) for k, v in headers.items())
    line = status_line + "\r\n" + "\r\n".join(headers_gen) + "\r\n\r\n"
    return line.encode("utf-8")


_serialize_headers = _py_serialize_headers

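# Prefer the C header serializer from the aiohttp._http_writer extension when
# it is importable and extensions are not disabled via NO_EXTENSIONS (typically
# driven by the AIOHTTP_NO_EXTENSIONS environment variable); otherwise the
# pure-Python fallback assigned above is used.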
try:
    import aiohttp._http_writer as _http_writer  # type: ignore[import-not-found]

    _c_serialize_headers = _http_writer._serialize_headers
    if not NO_EXTENSIONS:
        _serialize_headers = _c_serialize_headers
except ImportError:
    pass