Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/aiohttp/http_writer.py: 31%
118 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:56 +0000
1"""Http related parsers and protocol."""
3import asyncio
4import zlib
5from typing import Any, Awaitable, Callable, NamedTuple, Optional, Union # noqa
7from multidict import CIMultiDict
9from .abc import AbstractStreamWriter
10from .base_protocol import BaseProtocol
11from .helpers import NO_EXTENSIONS
13__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")
class HttpVersion(NamedTuple):
    """HTTP protocol version expressed as a ``(major, minor)`` pair."""

    major: int
    minor: int


# Pre-built instances for the two HTTP/1.x versions.
HttpVersion10 = HttpVersion(major=1, minor=0)
HttpVersion11 = HttpVersion(major=1, minor=1)


# Optional async hooks accepted by StreamWriter: the first is awaited with
# each body chunk given to write()/write_eof(), the second with the header
# mapping given to write_headers().
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]
class StreamWriter(AbstractStreamWriter):
    """Write an HTTP message head and body to a protocol's transport.

    Body data can optionally be compressed (see enable_compression()) and/or
    framed with chunked transfer encoding (see enable_chunking()).
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        loop: asyncio.AbstractEventLoop,
        on_chunk_sent: _T_OnChunkSent = None,
        on_headers_sent: _T_OnHeadersSent = None,
    ) -> None:
        self._protocol = protocol

        self.loop = loop
        # Remaining number of body bytes write() may still send (counted
        # after compression); None means no limit.
        self.length = None
        # When true, body data is framed as chunked transfer encoding.
        self.chunked = False
        # Bytes handed to the transport since write() last reset it for a drain.
        self.buffer_size = 0
        # Total bytes handed to the transport, including head and framing.
        self.output_size = 0

        self._eof = False
        # zlib compress object once enable_compression() is called, else None.
        self._compress: Any = None
        # NOTE(review): not referenced by any method of this class.
        self._drain_waiter = None

        self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
        self._on_headers_sent: _T_OnHeadersSent = on_headers_sent

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        """The underlying protocol's transport, or None."""
        return self._protocol.transport

    @property
    def protocol(self) -> BaseProtocol:
        """The protocol this writer sends data through."""
        return self._protocol

    def enable_chunking(self) -> None:
        """Frame subsequent body data with chunked transfer encoding."""
        self.chunked = True

    def enable_compression(
        self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
    ) -> None:
        """Compress subsequent body data.

        ``encoding="gzip"`` selects the gzip container; any other value
        selects the zlib container.
        """
        # wbits = 16 + MAX_WBITS asks zlib for a gzip header/trailer;
        # plain MAX_WBITS produces a zlib-wrapped stream.
        zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS
        self._compress = zlib.compressobj(wbits=zlib_mode, strategy=strategy)

    @staticmethod
    def _frame_chunk(chunk: bytes) -> bytes:
        """Wrap *chunk* in chunked-encoding framing: hex size, CRLF, data, CRLF."""
        chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
        return chunk_len + chunk + b"\r\n"

    def _write(self, chunk: bytes) -> None:
        """Hand *chunk* to the transport and update the byte counters.

        The counters are incremented before the connection check, so they
        also count a chunk that is then refused.

        Raises:
            ConnectionResetError: if the protocol is not connected, there is
                no transport, or the transport is closing.
        """
        size = len(chunk)
        self.buffer_size += size
        self.output_size += size
        transport = self.transport
        if not self._protocol.connected or transport is None or transport.is_closing():
            raise ConnectionResetError("Cannot write to closing transport")
        transport.write(chunk)

    async def write(
        self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
    ) -> None:
        """Write a chunk of body data to the stream.

        The on_chunk_sent hook (if any) is awaited with the raw chunk. The
        chunk is then compressed (if enabled), truncated to the remaining
        ``length`` (if set), and framed (if chunking is enabled) before it
        is sent. When *drain* is true and ``buffer_size`` exceeds *LIMIT*,
        the counter is reset and drain() is awaited.

        Call write_eof() to end the stream; the writer must not be used
        after write_eof() has been called.
        """
        if self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        if isinstance(chunk, memoryview):
            if chunk.nbytes != len(chunk):
                # len() counts items, not bytes, for multi-byte formats or
                # multi-dimensional views: reshape to a flat byte view.
                chunk = chunk.cast("c")

        if self._compress is not None:
            chunk = self._compress.compress(chunk)
            if not chunk:
                # The compressor buffered the input; nothing to send yet.
                return

        if self.length is not None:
            chunk_len = len(chunk)
            if self.length >= chunk_len:
                self.length = self.length - chunk_len
            else:
                # Send only what still fits within the declared length.
                chunk = chunk[: self.length]
                self.length = 0
                if not chunk:
                    return

        if chunk:
            if self.chunked:
                chunk = self._frame_chunk(chunk)

            self._write(chunk)

            if self.buffer_size > LIMIT and drain:
                self.buffer_size = 0
                await self.drain()

    async def write_headers(
        self, status_line: str, headers: "CIMultiDict[str]"
    ) -> None:
        """Write the request/response status line and headers.

        The on_headers_sent hook (if any) is awaited with *headers* first.
        """
        if self._on_headers_sent is not None:
            await self._on_headers_sent(headers)

        # status + headers
        buf = _serialize_headers(status_line, headers)
        self._write(buf)

    async def write_eof(self, chunk: bytes = b"") -> None:
        """Write an optional final *chunk* and finish the stream.

        Flushes the compressor when compression is enabled and, in chunked
        mode, always appends the zero-length terminating chunk. The stream
        is then drained. Calls after the first successful one do nothing.
        """
        if self._eof:
            return

        if chunk and self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        if self._compress:
            # An empty chunk contributes b"" (never a memoryview), so the
            # concatenation with the flushed tail is always bytes + bytes.
            compressed = self._compress.compress(chunk) if chunk else b""
            chunk = compressed + self._compress.flush()

        if self.chunked:
            # Frame any remaining payload, then always terminate the body
            # with the zero-length chunk, even if no payload remains.
            if chunk:
                chunk = self._frame_chunk(chunk) + b"0\r\n\r\n"
            else:
                chunk = b"0\r\n\r\n"

        if chunk:
            self._write(chunk)

        await self.drain()

        self._eof = True

    async def drain(self) -> None:
        """Flush the write buffer.

        The intended use is to write

          await w.write(data)
          await w.drain()

        Does nothing when the protocol has no transport.
        """
        if self._protocol.transport is not None:
            await self._protocol._drain_helper()
def _safe_header(string: str) -> str:
    """Return *string* unchanged after checking it is safe in a header line.

    Raises:
        ValueError: if *string* contains a carriage return or line feed,
            which would let it end the header line early (header injection).
    """
    if "\r" in string or "\n" in string:
        raise ValueError(
            "Newline or carriage return detected in headers. "
            "Potential header injection attack."
        )
    return string


def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
    """Serialize a status/request line plus headers into an HTTP message head.

    Each header name and value is validated with _safe_header(). The status
    line and every header line end with CRLF, and the head is closed by a
    single empty line.

    Returns:
        The head encoded as UTF-8.

    Raises:
        ValueError: if any header name or value contains CR or LF.
    """
    # Terminate each header line individually (instead of joining with CRLF)
    # so an empty *headers* still yields exactly one blank terminating line.
    header_lines = "".join(
        _safe_header(k) + ": " + _safe_header(v) + "\r\n" for k, v in headers.items()
    )
    return (status_line + "\r\n" + header_lines + "\r\n").encode("utf-8")


_serialize_headers = _py_serialize_headers
# Prefer the C-accelerated header serializer when the extension module is
# importable and extensions have not been disabled via NO_EXTENSIONS.
try:
    import aiohttp._http_writer as _http_writer  # type: ignore[import]
except ImportError:
    pass
else:
    _c_serialize_headers = _http_writer._serialize_headers
    if not NO_EXTENSIONS:
        _serialize_headers = _c_serialize_headers