Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/http_writer.py: 34%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Http related parsers and protocol."""
3import asyncio
4import sys
5from typing import ( # noqa
6 Any,
7 Awaitable,
8 Callable,
9 Iterable,
10 List,
11 NamedTuple,
12 Optional,
13 Union,
14)
16from multidict import CIMultiDict
18from .abc import AbstractStreamWriter
19from .base_protocol import BaseProtocol
20from .client_exceptions import ClientConnectionResetError
21from .compression_utils import ZLibCompressor
22from .helpers import NO_EXTENSIONS
# Public API of this module (what ``from ... import *`` exposes).
__all__ = (
    "StreamWriter",
    "HttpVersion",
    "HttpVersion10",
    "HttpVersion11",
)
# Multi-chunk writes whose total size is below this many bytes are joined and
# sent with one ``transport.write`` call instead of ``transport.writelines``.
MIN_PAYLOAD_FOR_WRITELINES = 2048

# ``transport.writelines`` is not safe to use on Python 3.12 before 3.12.9 or
# on Python 3.13 before 3.13.2 (CVE-2024-12254,
# https://github.com/python/cpython/pull/127656), and on older versions it is
# not any faster than ``write``.  In all of those cases fall back to ``write``.
IS_PY313_BEFORE_313_2 = sys.version_info >= (3, 13, 0) and sys.version_info < (
    3,
    13,
    2,
)
IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
SKIP_WRITELINES = any((IS_PY313_BEFORE_313_2, IS_PY_BEFORE_312_9))
class HttpVersion(NamedTuple):
    """HTTP protocol version as a ``(major, minor)`` pair."""

    major: int  # e.g. the first "1" in HTTP/1.1
    minor: int  # e.g. the second "1" in HTTP/1.1


HttpVersion10 = HttpVersion(major=1, minor=0)  # HTTP/1.0
HttpVersion11 = HttpVersion(major=1, minor=1)  # HTTP/1.1

# Body payload types that may be handed to the chunk-sent callback.
_T_ChunkPayload = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]

# Optional coroutine callback, awaited with each body chunk as passed in by
# the caller (before any compression or framing).
_T_OnChunkSent = Optional[Callable[[_T_ChunkPayload], Awaitable[None]]]

# Optional coroutine callback, awaited with the headers before they are
# serialized and written.
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]
class StreamWriter(AbstractStreamWriter):
    """Write an HTTP/1.x message head and body to a protocol's transport.

    Body data can optionally be compressed (:meth:`enable_compression`),
    framed with chunked transfer coding (:meth:`enable_chunking`) and, in
    :meth:`write`, truncated once ``length`` bytes have been written.
    """

    # Number of body bytes write() may still emit; ``None`` means no limit.
    length: Optional[int] = None
    # Whether body data is framed with chunked transfer coding.
    chunked: bool = False
    # Set when the message is complete; write_eof() is a no-op afterwards.
    _eof: bool = False
    # Compressor installed by enable_compression(); ``None`` = no compression.
    _compress: Optional[ZLibCompressor] = None

    def __init__(
        self,
        protocol: BaseProtocol,
        loop: asyncio.AbstractEventLoop,
        on_chunk_sent: _T_OnChunkSent = None,
        on_headers_sent: _T_OnHeadersSent = None,
    ) -> None:
        """Create a writer bound to *protocol*.

        :param protocol: protocol whose ``transport`` receives the bytes.
        :param loop: event loop, stored as ``self.loop``.
        :param on_chunk_sent: optional coroutine callback awaited with each
            body chunk before it is compressed/framed/written.
        :param on_headers_sent: optional coroutine callback awaited with the
            headers before they are serialized and written.
        """
        self._protocol = protocol
        self.loop = loop
        self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
        self._on_headers_sent: _T_OnHeadersSent = on_headers_sent

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        """The protocol's current transport, or ``None`` if it has none."""
        return self._protocol.transport

    @property
    def protocol(self) -> BaseProtocol:
        """The protocol this writer was created for."""
        return self._protocol

    def enable_chunking(self) -> None:
        """Frame all subsequent body data with chunked transfer coding."""
        self.chunked = True

    def enable_compression(
        self, encoding: str = "deflate", strategy: Optional[int] = None
    ) -> None:
        """Compress subsequent body data using *encoding*.

        :param encoding: content coding name passed to ``ZLibCompressor``.
        :param strategy: optional compression strategy, forwarded as-is to
            ``ZLibCompressor``.
        """
        self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)

    def _write(
        self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
    ) -> None:
        """Send *chunk* to the transport and update the size counters.

        :raises ClientConnectionResetError: if there is no transport or it is
            closing.
        """
        size = len(chunk)
        # The counters are bumped before the transport check, so they also
        # include data whose write then fails.
        self.buffer_size += size
        self.output_size += size
        transport = self._protocol.transport
        if transport is None or transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")
        transport.write(chunk)  # type: ignore[arg-type]

    def _writelines(
        self,
        chunks: Iterable[
            Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
        ],
    ) -> None:
        """Send several chunks to the transport as one logical write.

        Small payloads (below ``MIN_PAYLOAD_FOR_WRITELINES``), and all
        payloads on Python versions where ``transport.writelines`` is unsafe
        (``SKIP_WRITELINES``), are joined and sent with a single
        ``transport.write``.  Larger ones use ``transport.writelines``.

        :raises ClientConnectionResetError: if there is no transport or it is
            closing.
        """
        # ``chunks`` is consumed twice (once for sizing, once for sending).
        # Materialize it so a one-shot iterator is not exhausted by the size
        # computation, which would count bytes that are then never sent.
        # This is cheap for the small tuples/lists passed in by this class.
        chunks = tuple(chunks)
        size = sum(map(len, chunks))
        self.buffer_size += size
        self.output_size += size
        transport = self._protocol.transport
        if transport is None or transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")
        if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
            transport.write(b"".join(chunks))
        else:
            transport.writelines(chunks)  # type: ignore[arg-type]

    async def write(
        self,
        chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"],
        *,
        drain: bool = True,
        LIMIT: int = 0x10000,
    ) -> None:
        """Write a chunk of body data to the stream.

        The chunk is first reported to ``on_chunk_sent`` (if set).  It is then
        compressed (if enabled), truncated to the remaining ``length`` (if
        set) and written, framed as a single HTTP chunk when chunked coding is
        enabled.

        When *drain* is true and more than *LIMIT* bytes have been written
        since the buffer counter was last reset, the transport is drained
        before returning.

        Call :meth:`write_eof` to finish the message.  The writer must not be
        used to write body data after that.

        :raises ClientConnectionResetError: if there is no transport or it is
            closing.
        """
        if self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        if isinstance(chunk, memoryview):
            if chunk.nbytes != len(chunk):
                # len() of a multi-byte-item view counts items, not bytes.
                # Reinterpret it as a flat byte view so that len() and slicing
                # below operate on bytes.
                chunk = chunk.cast("c")

        if self._compress is not None:
            chunk = await self._compress.compress(chunk)
            if not chunk:
                # The compressor buffered everything; nothing to send yet.
                return

        if self.length is not None:
            chunk_len = len(chunk)
            if self.length >= chunk_len:
                self.length = self.length - chunk_len
            else:
                # Drop anything beyond the remaining allowed length.
                chunk = chunk[: self.length]
                self.length = 0
                if not chunk:
                    return

        if chunk:
            if self.chunked:
                # Chunked framing: "<hex size>\r\n<data>\r\n".
                self._writelines(
                    (f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n")
                )
            else:
                self._write(chunk)

            if self.buffer_size > LIMIT and drain:
                self.buffer_size = 0
                await self.drain()

    async def write_headers(
        self, status_line: str, headers: "CIMultiDict[str]"
    ) -> None:
        """Write the request/response status line and headers.

        ``on_headers_sent`` (if set) is awaited with *headers* first.  The
        serialized head is written without draining.
        """
        if self._on_headers_sent is not None:
            await self._on_headers_sent(headers)

        # status + headers
        buf = _serialize_headers(status_line, headers)
        self._write(buf)

    def set_eof(self) -> None:
        """Mark the message as complete without writing anything.

        Later calls to :meth:`write_eof` become no-ops.
        """
        self._eof = True

    async def write_eof(self, chunk: bytes = b"") -> None:
        """Write an optional final body *chunk* and finish the message.

        Flushes the compressor (if enabled) and, with chunked coding, emits
        the terminating zero-size chunk.  It then drains the transport and
        marks the writer as finished.  Calls after the message is finished
        return immediately.

        Unlike :meth:`write`, ``length`` is not applied to *chunk* here.
        """
        if self._eof:
            return

        if chunk and self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        if self._compress:
            chunks: List[bytes] = []
            chunks_len = 0
            if chunk and (compressed_chunk := await self._compress.compress(chunk)):
                chunks_len = len(compressed_chunk)
                chunks.append(compressed_chunk)

            flush_chunk = self._compress.flush()
            chunks_len += len(flush_chunk)
            chunks.append(flush_chunk)
            # NOTE(review): relies on the compressor flush always producing
            # output.  An empty final chunk here would corrupt chunked framing.
            assert chunks_len

            if self.chunked:
                # Last data chunk followed directly by the terminating chunk.
                chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
                self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
            elif len(chunks) > 1:
                self._writelines(chunks)
            else:
                self._write(chunks[0])
        elif self.chunked:
            if chunk:
                chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
                self._writelines((chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
            else:
                # Only the terminating zero-size chunk.
                self._write(b"0\r\n\r\n")
        elif chunk:
            self._write(chunk)

        await self.drain()

        self._eof = True

    async def drain(self) -> None:
        """Wait until the protocol is ready to accept more data.

        If the protocol still has a transport and is currently paused, this
        awaits the protocol's drain helper.  Otherwise it returns immediately.
        Typical use::

            await w.write(data)
            await w.drain()
        """
        protocol = self._protocol
        if protocol.transport is not None and protocol._paused:
            await protocol._drain_helper()
def _safe_header(string: str) -> str:
    """Return *string* unchanged if it is safe to put in an HTTP head.

    :raises ValueError: if *string* contains a carriage return or a line
        feed, which could otherwise be used to inject extra header lines.
    """
    if "\r" in string or "\n" in string:
        raise ValueError(
            "Newline or carriage return detected in headers. "
            "Potential header injection attack."
        )
    return string


def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
    """Serialize a status/request line and headers into an HTTP/1.x head.

    The output is *status_line*, then one ``name: value`` line per header,
    each terminated by CRLF, then an empty line (CRLF) ending the head.  The
    text is encoded as UTF-8.

    :raises ValueError: if a header name or value contains CR or LF.
    """
    # NOTE(review): status_line itself is not checked for CR/LF here.
    # Callers are presumably responsible for validating it.
    lines = [status_line]
    lines.extend(
        _safe_header(name) + ": " + _safe_header(value)
        for name, value in headers.items()
    )
    # Join the status line together with the header lines so that every line
    # gets exactly one CRLF.  Appending a separately joined header string
    # would emit a stray extra CRLF when there are no headers.
    return ("\r\n".join(lines) + "\r\n\r\n").encode("utf-8")
# Start with the pure-Python serializer; switch to the C accelerator when it
# can be imported and extensions are not disabled via NO_EXTENSIONS.
_serialize_headers = _py_serialize_headers

try:
    import aiohttp._http_writer as _http_writer  # type: ignore[import-not-found]
except ImportError:
    # C extension not available: keep the pure-Python implementation.
    pass
else:
    # Keep a module-level handle on the C implementation even when it is not
    # selected (presumably so both implementations can be exercised).
    _c_serialize_headers = _http_writer._serialize_headers
    if not NO_EXTENSIONS:
        _serialize_headers = _c_serialize_headers