Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/aiohttp/http_writer.py: 31%

118 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:40 +0000

1"""Http related parsers and protocol.""" 

2 

3import asyncio 

4import zlib 

5from typing import Any, Awaitable, Callable, NamedTuple, Optional, Union # noqa 

6 

7from multidict import CIMultiDict 

8 

9from .abc import AbstractStreamWriter 

10from .base_protocol import BaseProtocol 

11from .compression_utils import ZLibCompressor 

12from .helpers import NO_EXTENSIONS 

13 

# Names exported by ``from <module> import *``.
__all__ = (
    "StreamWriter",
    "HttpVersion",
    "HttpVersion10",
    "HttpVersion11",
)

15 

16 

class HttpVersion(NamedTuple):
    """HTTP protocol version expressed as a ``(major, minor)`` pair."""

    major: int
    minor: int


# Ready-made instances for HTTP/1.0 and HTTP/1.1.
HttpVersion10 = HttpVersion(major=1, minor=0)
HttpVersion11 = HttpVersion(major=1, minor=1)

24 

25 

# Optional async callback; StreamWriter awaits it with each body chunk
# before that chunk is processed and written.
_T_OnChunkSent = Union[Callable[[bytes], Awaitable[None]], None]
# Optional async callback; StreamWriter awaits it with the header mapping
# before the headers are serialized and written.
_T_OnHeadersSent = Union[Callable[["CIMultiDict[str]"], Awaitable[None]], None]

28 

29 

class StreamWriter(AbstractStreamWriter):
    """Write an HTTP message head and body to a protocol's transport.

    Body data can optionally be compressed (``enable_compression``), framed
    with chunked transfer coding (``enable_chunking``) and capped at a
    remaining byte budget (``length``).
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        loop: asyncio.AbstractEventLoop,
        on_chunk_sent: _T_OnChunkSent = None,
        on_headers_sent: _T_OnHeadersSent = None,
    ) -> None:
        self._protocol = protocol

        self.loop = loop
        # Remaining body bytes write() may still send; None disables the cap.
        self.length: Optional[int] = None
        self.chunked = False
        # Bytes passed to _write() since write() last triggered a drain.
        self.buffer_size = 0
        # Total bytes passed to _write() (head, body and chunk framing).
        self.output_size = 0

        self._eof = False
        self._compress: Optional[ZLibCompressor] = None
        # NOTE(review): not read anywhere in this class; kept in case other
        # code relies on the attribute existing.
        self._drain_waiter = None

        self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
        self._on_headers_sent: _T_OnHeadersSent = on_headers_sent

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        """Transport of the underlying protocol, or None if it has none."""
        return self._protocol.transport

    @property
    def protocol(self) -> BaseProtocol:
        """The protocol this writer sends data through."""
        return self._protocol

    def enable_chunking(self) -> None:
        """Frame all subsequent body data with chunked transfer coding."""
        self.chunked = True

    def enable_compression(
        self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
    ) -> None:
        """Compress subsequent body data with the given encoding and zlib strategy."""
        self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)

    def _write(self, chunk: bytes) -> None:
        """Hand *chunk* to the transport and update the byte counters.

        Raises:
            ConnectionResetError: if the protocol is not connected or the
                transport is missing or closing.
        """
        size = len(chunk)
        # Counters are updated before the connection check, so bytes of a
        # failed write are still counted.
        self.buffer_size += size
        self.output_size += size
        transport = self.transport
        if not self._protocol.connected or transport is None or transport.is_closing():
            raise ConnectionResetError("Cannot write to closing transport")
        transport.write(chunk)

    async def write(
        self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
    ) -> None:
        """Write a chunk of body data to the stream.

        ``on_chunk_sent`` (if set) is awaited with the raw chunk first. The
        chunk is then compressed (if enabled), truncated to the remaining
        ``length`` (if set) and framed as a transfer-coding chunk (if chunking
        is enabled). If *drain* is true and more than *LIMIT* bytes were
        written since the last such drain, the transport is drained.

        Finish the stream with write_eof(); the writer must not be used
        after that.
        """
        if self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        if isinstance(chunk, memoryview):
            if chunk.nbytes != len(chunk):
                # len() of a memoryview counts items, not bytes; recast to a
                # flat view of single bytes so sizes below are byte counts.
                chunk = chunk.cast("c")

        if self._compress is not None:
            chunk = await self._compress.compress(chunk)
            if not chunk:
                # The compressor produced no output for this input
                # (presumably buffered internally); nothing to send yet.
                return

        if self.length is not None:
            chunk_len = len(chunk)
            if self.length >= chunk_len:
                self.length = self.length - chunk_len
            else:
                # Drop whatever exceeds the remaining byte budget.
                chunk = chunk[: self.length]
                self.length = 0
                if not chunk:
                    return

        if chunk:
            if self.chunked:
                chunk_len_pre = ("%x\r\n" % len(chunk)).encode("ascii")
                chunk = chunk_len_pre + chunk + b"\r\n"

            self._write(chunk)

            if self.buffer_size > LIMIT and drain:
                self.buffer_size = 0
                await self.drain()

    async def write_headers(
        self, status_line: str, headers: "CIMultiDict[str]"
    ) -> None:
        """Serialize and write the status/request line and headers.

        ``on_headers_sent`` (if set) is awaited with *headers* first.
        """
        if self._on_headers_sent is not None:
            await self._on_headers_sent(headers)

        # status + headers
        buf = _serialize_headers(status_line, headers)
        self._write(buf)

    async def write_eof(self, chunk: bytes = b"") -> None:
        """Write an optional final chunk and finish the stream.

        Flushes the compressor (if enabled), emits the terminating
        zero-length chunk when chunking is enabled, drains the transport and
        marks the writer finished. Calls after a successful write_eof() are
        no-ops. Unlike write(), the ``length`` cap is not applied here.
        """
        if self._eof:
            return

        if chunk and self._on_chunk_sent is not None:
            await self._on_chunk_sent(chunk)

        if isinstance(chunk, memoryview) and chunk.nbytes != len(chunk):
            # Same normalization as write(): the chunk-size prefix below must
            # be a byte count, not an item count.
            chunk = chunk.cast("c")

        if self._compress is not None:
            # Start from b"" rather than the (possibly memoryview) input so
            # the flush output can always be concatenated.
            compressed = await self._compress.compress(chunk) if chunk else b""
            chunk = compressed + self._compress.flush()

        if self.chunked:
            if chunk:
                chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
                # Last data chunk immediately followed by the terminator.
                chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
            else:
                # A chunked body must always end with the zero-length chunk,
                # even when there is no final data (including when the
                # compressor flushed nothing).
                chunk = b"0\r\n\r\n"

        if chunk:
            self._write(chunk)

        await self.drain()

        self._eof = True

    async def drain(self) -> None:
        """Flush the write buffer.

        Delegates to the protocol's ``_drain_helper()`` when a transport is
        present; otherwise does nothing. Intended use::

            await w.write(data)
            await w.drain()
        """
        if self._protocol.transport is not None:
            await self._protocol._drain_helper()

172 

173 

def _safe_header(string: str) -> str:
    """Return *string* unchanged if it is safe to emit in an HTTP header.

    Raises:
        ValueError: if *string* contains a carriage return or line feed,
            which would allow injecting additional header lines.
    """
    if "\r" in string or "\n" in string:
        raise ValueError(
            "Newline or carriage return detected in headers. "
            "Potential header injection attack."
        )
    return string


def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
    """Serialize a status/request line plus headers into a UTF-8 message head.

    The output is ``status_line CRLF``, then one ``name: value CRLF`` line per
    header, then the blank ``CRLF`` line that ends the head. With no headers
    the result is ``status_line CRLF CRLF``.

    Raises:
        ValueError: if any header name or value contains CR or LF.
    """
    # Terminate each header line individually; joining with CRLF and then
    # appending CRLF CRLF produced a spurious third CRLF for an empty mapping.
    header_lines = "".join(
        _safe_header(k) + ": " + _safe_header(v) + "\r\n" for k, v in headers.items()
    )
    return (status_line + "\r\n" + header_lines + "\r\n").encode("utf-8")

187 

188 

# Default to the pure-Python serializer; swap in the accelerated one from
# the optional ``_http_writer`` extension module when it imports and
# extensions are not disabled via NO_EXTENSIONS.
_serialize_headers = _py_serialize_headers

try:
    import aiohttp._http_writer as _http_writer  # type: ignore[import-not-found]
except ImportError:
    # Extension module unavailable: keep the pure-Python implementation.
    pass
else:
    _c_serialize_headers = _http_writer._serialize_headers
    if not NO_EXTENSIONS:
        _serialize_headers = _c_serialize_headers

198 pass