Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/http_writer.py: 34%

143 statements  

1"""Http related parsers and protocol.""" 

2 

3import asyncio 

4import sys 

5from typing import ( # noqa 

6 Any, 

7 Awaitable, 

8 Callable, 

9 Iterable, 

10 List, 

11 NamedTuple, 

12 Optional, 

13 Union, 

14) 

15 

16from multidict import CIMultiDict 

17 

18from .abc import AbstractStreamWriter 

19from .base_protocol import BaseProtocol 

20from .client_exceptions import ClientConnectionResetError 

21from .compression_utils import ZLibCompressor 

22from .helpers import NO_EXTENSIONS 

23 

24__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11") 

25 

26 

27MIN_PAYLOAD_FOR_WRITELINES = 2048 

28IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2) 

29IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9) 

30SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9 

31# writelines is not safe for use 

32# on Python 3.12+ until 3.12.9 

33# on Python 3.13+ until 3.13.2 

34# and on older versions it not any faster than write 

35# CVE-2024-12254: https://github.com/python/cpython/pull/127656 


class HttpVersion(NamedTuple):
    major: int
    minor: int


HttpVersion10 = HttpVersion(1, 0)
HttpVersion11 = HttpVersion(1, 1)


_T_OnChunkSent = Optional[
    Callable[
        [Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]],
        Awaitable[None],
    ]
]
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]


class StreamWriter(AbstractStreamWriter):

    length: Optional[int] = None
    chunked: bool = False
    _eof: bool = False
    _compress: Optional[ZLibCompressor] = None

    def __init__(
        self,
        protocol: BaseProtocol,
        loop: asyncio.AbstractEventLoop,
        on_chunk_sent: _T_OnChunkSent = None,
        on_headers_sent: _T_OnHeadersSent = None,
    ) -> None:
        self._protocol = protocol
        self.loop = loop
        self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
        self._on_headers_sent: _T_OnHeadersSent = on_headers_sent

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        return self._protocol.transport

    @property
    def protocol(self) -> BaseProtocol:
        return self._protocol

    def enable_chunking(self) -> None:
        self.chunked = True

    def enable_compression(
        self, encoding: str = "deflate", strategy: Optional[int] = None
    ) -> None:
        self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)

    def _write(
        self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
    ) -> None:
        size = len(chunk)
        self.buffer_size += size
        self.output_size += size
        transport = self._protocol.transport
        if transport is None or transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")
        transport.write(chunk)  # type: ignore[arg-type]

    def _writelines(
        self,
        chunks: Iterable[
            Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
        ],
    ) -> None:
        size = 0
        for chunk in chunks:
            size += len(chunk)
        self.buffer_size += size
        self.output_size += size
        transport = self._protocol.transport
        if transport is None or transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")
        if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
            transport.write(b"".join(chunks))
        else:
            transport.writelines(chunks)  # type: ignore[arg-type]

    async def write(
        self,
        chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"],
        *,
        drain: bool = True,
        LIMIT: int = 0x10000,
    ) -> None:
        """Write a chunk of data to the stream.

        write_eof() indicates the end of the stream;
        the writer can't be used after write_eof() has been called.
        write() returns a drain future.
        """
        if self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        if isinstance(chunk, memoryview):
            if chunk.nbytes != len(chunk):
                # just reshape it
                chunk = chunk.cast("c")

        if self._compress is not None:
            chunk = await self._compress.compress(chunk)
            if not chunk:
                return

        if self.length is not None:
            chunk_len = len(chunk)
            if self.length >= chunk_len:
                self.length = self.length - chunk_len
            else:
                chunk = chunk[: self.length]
                self.length = 0
                if not chunk:
                    return

        if chunk:
            if self.chunked:
                self._writelines(
                    (f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n")
                )
            else:
                self._write(chunk)

            if self.buffer_size > LIMIT and drain:
                self.buffer_size = 0
                await self.drain()
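
    # Illustrative example: with chunking enabled and no compression,
    # write(b"hello") puts b"5\r\nhello\r\n" on the wire, i.e. the chunk
    # size in hex, CRLF, the payload, CRLF.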

    async def write_headers(
        self, status_line: str, headers: "CIMultiDict[str]"
    ) -> None:
        """Write request/response status and headers."""
        if self._on_headers_sent is not None:
            await self._on_headers_sent(headers)

        # status + headers
        buf = _serialize_headers(status_line, headers)
        self._write(buf)

    def set_eof(self) -> None:
        """Indicate that the message is complete."""
        self._eof = True

    async def write_eof(self, chunk: bytes = b"") -> None:
        if self._eof:
            return

        if chunk and self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        if self._compress:
            chunks: List[bytes] = []
            chunks_len = 0
            if chunk and (compressed_chunk := await self._compress.compress(chunk)):
                chunks_len = len(compressed_chunk)
                chunks.append(compressed_chunk)

            flush_chunk = self._compress.flush()
            chunks_len += len(flush_chunk)
            chunks.append(flush_chunk)
            assert chunks_len

            if self.chunked:
                chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
                self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
            elif len(chunks) > 1:
                self._writelines(chunks)
            else:
                self._write(chunks[0])
        elif self.chunked:
            if chunk:
                chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
                self._writelines((chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
            else:
                self._write(b"0\r\n\r\n")
        elif chunk:
            self._write(chunk)

        await self.drain()

        self._eof = True

    async def drain(self) -> None:
        """Flush the write buffer.

        The intended use is to write

          await w.write(data)
          await w.drain()
        """
        protocol = self._protocol
        if protocol.transport is not None and protocol._paused:
            await protocol._drain_helper()


def _safe_header(string: str) -> str:
    if "\r" in string or "\n" in string:
        raise ValueError(
            "Newline or carriage return detected in headers. "
            "Potential header injection attack."
        )
    return string


def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
    headers_gen = (_safe_header(k) + ": " + _safe_header(v) for k, v in headers.items())
    line = status_line + "\r\n" + "\r\n".join(headers_gen) + "\r\n\r\n"
    return line.encode("utf-8")
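
# Illustrative example: for status_line "HTTP/1.1 200 OK" and the single header
# {"Content-Type": "text/plain"}, _py_serialize_headers returns
# b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n".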

_serialize_headers = _py_serialize_headers

try:
    import aiohttp._http_writer as _http_writer  # type: ignore[import-not-found]

    _c_serialize_headers = _http_writer._serialize_headers
    if not NO_EXTENSIONS:
        _serialize_headers = _c_serialize_headers
except ImportError:
    pass
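
As a rough illustration of how the pieces above fit together, the sketch below drives a StreamWriter by hand to emit a chunked HTTP/1.1 response. It is a minimal sketch, not aiohttp's own usage: it assumes `protocol` is a BaseProtocol whose transport is already connected, and `send_chunked_response` is a hypothetical helper name.

import asyncio

from multidict import CIMultiDict

from aiohttp.base_protocol import BaseProtocol
from aiohttp.http_writer import HttpVersion11, StreamWriter


async def send_chunked_response(protocol: BaseProtocol) -> None:
    """Hypothetical helper: hand-drive a chunked HTTP/1.1 response."""
    loop = asyncio.get_running_loop()
    writer = StreamWriter(protocol, loop)
    writer.enable_chunking()  # body framed as <hex size>\r\n<data>\r\n per chunk

    status_line = f"HTTP/{HttpVersion11.major}.{HttpVersion11.minor} 200 OK"
    headers = CIMultiDict(
        {"Content-Type": "text/plain", "Transfer-Encoding": "chunked"}
    )
    await writer.write_headers(status_line, headers)

    await writer.write(b"hello ")
    await writer.write(b"world")
    await writer.write_eof()  # writes the terminating 0\r\n\r\n and drains

In aiohttp itself these calls are made by the client request and server response machinery rather than by application code.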