Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/http_writer.py: 31%
118 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
1"""Http related parsers and protocol."""
3import asyncio
4import zlib
5from typing import Any, Awaitable, Callable, NamedTuple, Optional, Union # noqa
7from multidict import CIMultiDict
9from .abc import AbstractStreamWriter
10from .base_protocol import BaseProtocol
11from .compression_utils import ZLibCompressor
12from .helpers import NO_EXTENSIONS
# Names exported by ``from ... import *``; everything else here is internal.
__all__ = (
    "StreamWriter",
    "HttpVersion",
    "HttpVersion10",
    "HttpVersion11",
)
class HttpVersion(NamedTuple):
    """An HTTP protocol version as a ``(major, minor)`` pair.

    Being a tuple, instances compare and order element-wise.
    """

    major: int  # e.g. the ``1`` in HTTP/1.1
    minor: int  # e.g. the second ``1`` in HTTP/1.1
# Pre-built instances for the two HTTP/1.x protocol revisions.
HttpVersion10 = HttpVersion(major=1, minor=0)
HttpVersion11 = HttpVersion(major=1, minor=1)

# Optional coroutine hook that StreamWriter awaits with a body chunk
# before that chunk is processed and written.
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
# Optional coroutine hook that StreamWriter awaits with the header mapping
# before the headers are serialized and written.
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]
class StreamWriter(AbstractStreamWriter):
    """Write HTTP/1.x message heads and bodies to a protocol's transport.

    Body data can optionally be compressed (``enable_compression``), framed
    with chunked transfer-encoding (``enable_chunking``) and truncated to a
    remaining byte budget (``length``).
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        loop: asyncio.AbstractEventLoop,
        on_chunk_sent: _T_OnChunkSent = None,
        on_headers_sent: _T_OnHeadersSent = None,
    ) -> None:
        self._protocol = protocol

        self.loop = loop
        # Remaining body bytes allowed by write(); None means unlimited.
        self.length = None
        self.chunked = False
        # Bytes passed to _write() since write() last reset it before draining.
        self.buffer_size = 0
        # Total bytes passed to _write() (headers and body).
        self.output_size = 0

        self._eof = False
        self._compress: Optional[ZLibCompressor] = None
        self._drain_waiter = None

        self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
        self._on_headers_sent: _T_OnHeadersSent = on_headers_sent

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        """The protocol's current transport, or ``None`` if it has none."""
        return self._protocol.transport

    @property
    def protocol(self) -> BaseProtocol:
        """The protocol this writer sends data through."""
        return self._protocol

    def enable_chunking(self) -> None:
        """Frame subsequent body data with chunked transfer-encoding."""
        self.chunked = True

    def enable_compression(
        self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
    ) -> None:
        """Compress subsequent body data using *encoding* and zlib *strategy*."""
        self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)

    @staticmethod
    def _as_flat_bytes(chunk: Union[bytes, memoryview]) -> Union[bytes, memoryview]:
        """Return *chunk*, recast to a flat byte view if it is a memoryview
        whose items are wider than one byte.

        Without this, ``len()`` and slicing would count items rather than
        bytes, corrupting chunk-size prefixes and length accounting.
        """
        if isinstance(chunk, memoryview) and chunk.nbytes != len(chunk):
            return chunk.cast("c")
        return chunk

    @staticmethod
    def _frame_chunk(chunk: bytes) -> bytes:
        """Wrap *chunk* in one chunked transfer-encoding frame.

        The frame is: hex size, CRLF, the data, CRLF.
        """
        return ("%x\r\n" % len(chunk)).encode("ascii") + chunk + b"\r\n"

    def _write(self, chunk: bytes) -> None:
        """Hand *chunk* to the transport, updating the size counters.

        Raises:
            ConnectionResetError: if the protocol is not connected or the
                transport is missing or closing.
        """
        size = len(chunk)
        self.buffer_size += size
        self.output_size += size
        transport = self.transport
        if not self._protocol.connected or transport is None or transport.is_closing():
            raise ConnectionResetError("Cannot write to closing transport")
        transport.write(chunk)

    async def write(
        self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
    ) -> None:
        """Write a chunk of body data to the stream.

        The chunk is passed to the ``on_chunk_sent`` hook, compressed if
        compression is enabled, truncated to the remaining ``length`` budget
        if one is set, and framed if chunking is enabled. Empty results are
        never written, so an empty chunk cannot emit a premature chunked
        terminator.

        ``write_eof()`` ends the stream; the writer is not meant to be used
        after it (this method does not enforce that).

        When *drain* is true and ``buffer_size`` exceeds *LIMIT*, the counter
        is reset and ``drain()`` is awaited.
        """
        if self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        chunk = self._as_flat_bytes(chunk)

        if self._compress is not None:
            chunk = await self._compress.compress(chunk)
            if not chunk:
                # No compressor output for this input (presumably buffered).
                return

        if self.length is not None:
            chunk_len = len(chunk)
            if self.length >= chunk_len:
                self.length = self.length - chunk_len
            else:
                # Drop bytes beyond the remaining length budget.
                chunk = chunk[: self.length]
                self.length = 0
                if not chunk:
                    return

        if chunk:
            if self.chunked:
                chunk = self._frame_chunk(chunk)

            self._write(chunk)

            if self.buffer_size > LIMIT and drain:
                self.buffer_size = 0
                await self.drain()

    async def write_headers(
        self, status_line: str, headers: "CIMultiDict[str]"
    ) -> None:
        """Write request/response status and headers."""
        if self._on_headers_sent is not None:
            await self._on_headers_sent(headers)

        # status + headers
        buf = _serialize_headers(status_line, headers)
        self._write(buf)

    async def write_eof(self, chunk: bytes = b"") -> None:
        """Write an optional final *chunk* and terminate the stream.

        Flushes the compressor (if enabled), writes the remaining data and,
        in chunked mode, always the zero-size terminating chunk, then drains.
        Once it has completed, further calls are no-ops.
        """
        if self._eof:
            return

        if chunk and self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        chunk = self._as_flat_bytes(chunk)

        if self._compress is not None:
            if chunk:
                chunk = await self._compress.compress(chunk)

            chunk += self._compress.flush()

        if self.chunked:
            # A chunked body must always be terminated, even when there is no
            # final payload (e.g. the compressor flush yielded no bytes);
            # only non-empty data gets its own size-prefixed frame.
            payload = self._frame_chunk(chunk) if chunk else b""
            chunk = payload + b"0\r\n\r\n"

        if chunk:
            self._write(chunk)

        await self.drain()

        self._eof = True

    async def drain(self) -> None:
        """Wait on the protocol's drain helper if a transport is attached.

        Delegates to ``BaseProtocol._drain_helper()``; does nothing when the
        protocol has no transport. The intended use is::

            await w.write(data)
            await w.drain()
        """
        if self._protocol.transport is not None:
            await self._protocol._drain_helper()
174def _safe_header(string: str) -> str:
175 if "\r" in string or "\n" in string:
176 raise ValueError(
177 "Newline or carriage return detected in headers. "
178 "Potential header injection attack."
179 )
180 return string
def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
    """Serialize *status_line* and *headers* into a UTF-8 HTTP/1.x head.

    Produces the status line, then one ``Name: value`` line per header, each
    terminated by CRLF, then a single empty line ending the header section.

    Raises:
        ValueError: if a header name or value contains CR or LF.
    """
    # Terminate every header line individually so an empty header mapping
    # still yields exactly ``status\r\n\r\n`` rather than an extra CRLF that
    # would be sent as the first bytes of the body.
    header_lines = "".join(
        _safe_header(k) + ": " + _safe_header(v) + "\r\n" for k, v in headers.items()
    )
    return (status_line + "\r\n" + header_lines + "\r\n").encode("utf-8")
# Default to the pure-Python serializer; the compiled one replaces it when
# the optional extension imports and extensions are not disabled.
_serialize_headers = _py_serialize_headers

try:
    import aiohttp._http_writer as _http_writer  # type: ignore[import]
except ImportError:
    # Extension not available: keep the pure-Python fallback.
    pass
else:
    _c_serialize_headers = _http_writer._serialize_headers
    if not NO_EXTENSIONS:
        _serialize_headers = _c_serialize_headers