Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/payload.py: 60%

220 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:52 +0000

1import asyncio 

2import enum 

3import io 

4import json 

5import mimetypes 

6import os 

7import warnings 

8from abc import ABC, abstractmethod 

9from itertools import chain 

10from typing import ( 

11 IO, 

12 TYPE_CHECKING, 

13 Any, 

14 ByteString, 

15 Dict, 

16 Iterable, 

17 Optional, 

18 TextIO, 

19 Tuple, 

20 Type, 

21 Union, 

22) 

23 

24from multidict import CIMultiDict 

25from typing_extensions import Final 

26 

27from . import hdrs 

28from .abc import AbstractStreamWriter 

29from .helpers import ( 

30 _SENTINEL, 

31 content_disposition_header, 

32 guess_filename, 

33 parse_mimetype, 

34 sentinel, 

35) 

36from .streams import StreamReader 

37from .typedefs import JSONEncoder, _CIMultiDict 

38 

# Public names exported by ``from <module> import *``.
__all__ = (
    # registry access
    "PAYLOAD_REGISTRY",
    "get_payload",
    "payload_type",
    # payload classes
    "Payload",
    "BytesPayload",
    "StringPayload",
    "IOBasePayload",
    "BytesIOPayload",
    "BufferedReaderPayload",
    "TextIOPayload",
    "StringIOPayload",
    "JsonPayload",
    "AsyncIterablePayload",
)

# In-memory byte bodies larger than this (1 MiB) make BytesPayload emit a
# ResourceWarning.
TOO_LARGE_BYTES_BODY: Final[int] = 1024 * 1024

56 

57if TYPE_CHECKING: # pragma: no cover 

58 from typing import List 

59 

60 

class LookupError(Exception):
    """Raised when the payload registry has no factory for a value.

    NOTE: this name shadows the builtin ``LookupError`` inside this module;
    it derives directly from ``Exception``.
    """

63 

64 

class Order(str, enum.Enum):
    """Priority bucket a payload factory is registered under.

    The registry consults ``try_first`` entries, then ``normal`` ones, then
    ``try_last`` ones.  Members are also plain ``str`` values.
    """

    # Default bucket.
    normal = "normal"
    # Consulted before the default bucket.
    try_first = "try_first"
    # Consulted after the default bucket.
    try_last = "try_last"

69 

70 

def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
    """Wrap *data* in a payload using the module-level registry.

    Extra positional and keyword arguments are forwarded unchanged to
    ``PAYLOAD_REGISTRY.get``.
    """
    registry = PAYLOAD_REGISTRY
    return registry.get(data, *args, **kwargs)

73 

74 

def register_payload(
    factory: Type["Payload"], type: Any, *, order: Order = Order.normal
) -> None:
    """Register *factory* for values of *type* in the module-level registry.

    *order* selects the priority bucket and is forwarded to
    ``PAYLOAD_REGISTRY.register``.
    """
    registry = PAYLOAD_REGISTRY
    registry.register(factory, type, order=order)

79 

80 

class payload_type:
    """Class decorator that registers the decorated payload factory.

    The instance stores the target *type* and *order*; applying it to a
    factory calls ``register_payload`` and hands the factory back unchanged.
    """

    def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
        self.order = order
        self.type = type

    def __call__(self, factory: Type["Payload"]) -> Type["Payload"]:
        # Register first, then return the very same class so the decorator
        # is transparent to the decorated definition.
        register_payload(factory, self.type, order=self.order)
        return factory

89 

90 

# A payload factory: a class object for ``Payload`` (forward reference).
PayloadType = Type["Payload"]
# One registry entry: the factory paired with the type (or tuple of types)
# it is registered for.
_PayloadRegistryItem = Tuple[Type["Payload"], Any]

93 

94 

class PayloadRegistry:
    """Ordered collection of ``(factory, type)`` pairs.

    Entries live in three lists (first / normal / last).  Lookup is a linear
    ``isinstance`` scan over the three lists in that order.

    note: an adapter registry such as zope.interface would give a more
    efficient search.
    """

    def __init__(self) -> None:
        self._first: List[_PayloadRegistryItem] = []
        self._normal: List[_PayloadRegistryItem] = []
        self._last: List[_PayloadRegistryItem] = []

    def get(
        self,
        data: Any,
        *args: Any,
        _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
        **kwargs: Any,
    ) -> "Payload":
        """Return a payload for *data*.

        A value that already is a ``Payload`` is returned untouched.
        Otherwise the first registered factory whose type matches *data* is
        called with ``(data, *args, **kwargs)``.

        Raises ``LookupError`` (the module-level class) when nothing matches.
        """
        if isinstance(data, Payload):
            return data

        candidates = _CHAIN(self._first, self._normal, self._last)
        for make_payload, accepted in candidates:
            if not isinstance(data, accepted):
                continue
            return make_payload(data, *args, **kwargs)

        raise LookupError()

    def register(
        self, factory: PayloadType, type: Any, *, order: Order = Order.normal
    ) -> None:
        """Append ``(factory, type)`` to the list selected by *order*.

        Raises ``ValueError`` when *order* is not one of the ``Order``
        members (compared by identity).
        """
        if order is Order.try_first:
            bucket = self._first
        elif order is Order.normal:
            bucket = self._normal
        elif order is Order.try_last:
            bucket = self._last
        else:
            raise ValueError(f"Unsupported order {order!r}")
        bucket.append((factory, type))

132 

133 

class Payload(ABC):
    """Abstract base class for body payloads.

    Holds the wrapped value, an optional filename and encoding, and a
    case-insensitive header mapping that always carries a ``Content-Type``.
    Subclasses implement ``write``.
    """

    # Content type used when none is given and none can be guessed.
    _default_content_type: str = "application/octet-stream"
    # Size reported by the ``size`` property unless a subclass sets/overrides it.
    _size: Optional[int] = None

    def __init__(
        self,
        value: Any,
        headers: Optional[
            Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
        ] = None,
        content_type: Union[None, str, _SENTINEL] = sentinel,
        filename: Optional[str] = None,
        encoding: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Store the value and build the initial headers.

        The ``Content-Type`` header is, in order of preference: the explicit
        *content_type*; a type guessed from *filename*; the class default.
        Entries from *headers* are applied last and therefore win.
        """
        self._value = value
        self._filename = filename
        self._encoding = encoding
        self._headers: _CIMultiDict = CIMultiDict()

        if content_type is sentinel or content_type is None:
            guessed = None
            if filename is not None:
                guessed = mimetypes.guess_type(filename)[0]
            resolved = self._default_content_type if guessed is None else guessed
        else:
            assert isinstance(content_type, str)
            resolved = content_type

        self._headers[hdrs.CONTENT_TYPE] = resolved
        self._headers.update(headers or {})

    @property
    def size(self) -> Optional[int]:
        """Payload size as stored in ``_size`` (``None`` when not set)."""
        return self._size

    @property
    def filename(self) -> Optional[str]:
        """Filename given at construction, if any."""
        return self._filename

    @property
    def headers(self) -> _CIMultiDict:
        """Header mapping of this payload."""
        return self._headers

    @property
    def _binary_headers(self) -> bytes:
        # One "Name: value\r\n" line per header, UTF-8 encoded, followed by
        # the blank line that terminates the header section.
        lines = (
            name + ": " + value + "\r\n" for name, value in self.headers.items()
        )
        return "".join(lines).encode("utf-8") + b"\r\n"

    @property
    def encoding(self) -> Optional[str]:
        """Encoding given at construction, if any."""
        return self._encoding

    @property
    def content_type(self) -> str:
        """Current value of the ``Content-Type`` header."""
        return self._headers[hdrs.CONTENT_TYPE]

    def set_content_disposition(
        self,
        disptype: str,
        quote_fields: bool = True,
        _charset: str = "utf-8",
        **params: Any,
    ) -> None:
        """Set the ``Content-Disposition`` header.

        The header value is produced by ``content_disposition_header`` from
        *disptype* and *params*.
        """
        header_value = content_disposition_header(
            disptype, quote_fields=quote_fields, _charset=_charset, **params
        )
        self._headers[hdrs.CONTENT_DISPOSITION] = header_value

    @abstractmethod
    async def write(self, writer: AbstractStreamWriter) -> None:
        """Write the payload to *writer*.

        *writer* is an ``AbstractStreamWriter`` instance.
        """

217 

218 

class BytesPayload(Payload):
    """Payload backed by an in-memory ``bytes``, ``bytearray`` or ``memoryview``."""

    def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None:
        """Validate *value*, record its size and warn when it is very large.

        Raises ``TypeError`` when *value* is not one of the accepted types.
        ``content_type`` defaults to ``application/octet-stream``.
        """
        if not isinstance(value, (bytes, bytearray, memoryview)):
            raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")

        kwargs.setdefault("content_type", "application/octet-stream")
        super().__init__(value, *args, **kwargs)

        # A memoryview reports its byte count through ``nbytes``.
        self._size = value.nbytes if isinstance(value, memoryview) else len(value)

        if self._size > TOO_LARGE_BYTES_BODY:
            warnings.warn(
                "Sending a large body directly with raw bytes might"
                " lock the event loop. You should probably pass an "
                "io.BytesIO object instead",
                ResourceWarning,
                source=self,
            )

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Send the whole value to *writer* in a single call."""
        body = self._value
        await writer.write(body)

245 

246 

class StringPayload(BytesPayload):
    """Payload for a ``str`` value, encoded to bytes at construction time."""

    def __init__(
        self,
        value: str,
        *args: Any,
        encoding: Optional[str] = None,
        content_type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Resolve encoding and content type, then encode *value*.

        * explicit *encoding*: used as is; a missing *content_type* becomes
          ``text/plain`` with that charset.
        * only *content_type*: the encoding is its ``charset`` parameter,
          falling back to UTF-8.
        * neither: UTF-8 and ``text/plain; charset=utf-8``.
        """
        if encoding is not None:
            real_encoding = encoding
            if content_type is None:
                content_type = "text/plain; charset=%s" % encoding
        elif content_type is not None:
            parsed = parse_mimetype(content_type)
            real_encoding = parsed.parameters.get("charset", "utf-8")
        else:
            real_encoding = "utf-8"
            content_type = "text/plain; charset=utf-8"

        encoded = value.encode(real_encoding)
        super().__init__(
            encoded,
            *args,
            encoding=real_encoding,
            content_type=content_type,
            **kwargs,
        )

275 

276 

class StringIOPayload(StringPayload):
    """String payload whose text is read from a text stream up front."""

    def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:
        # The stream is drained once here; the resulting text is then handled
        # exactly like a plain string payload.
        text = value.read()
        super().__init__(text, *args, **kwargs)

280 

281 

class IOBasePayload(Payload):
    """Payload that streams its body from a file-like object."""

    _value: IO[Any]

    def __init__(
        self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any
    ) -> None:
        """Wrap *value* and, when a filename is known, set Content-Disposition.

        If no ``filename`` keyword is supplied it is derived with
        ``guess_filename(value)``.  The ``Content-Disposition`` header is only
        set when a filename exists, *disposition* is not ``None`` and the
        header is not already present.
        """
        if "filename" not in kwargs:
            kwargs["filename"] = guess_filename(value)

        super().__init__(value, *args, **kwargs)

        if self._filename is not None and disposition is not None:
            if hdrs.CONTENT_DISPOSITION not in self.headers:
                self.set_content_disposition(disposition, filename=self._filename)

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Copy the stream to *writer* in 64 KiB chunks, then close the stream.

        Reads and the final ``close`` run in the loop's default executor so
        blocking file I/O does not run on the event loop thread.  The stream
        is closed even when reading or writing fails.
        """
        # We are inside a coroutine, so a loop is guaranteed to be running;
        # get_running_loop() is the API intended for that situation.
        loop = asyncio.get_running_loop()
        try:
            chunk = await loop.run_in_executor(None, self._value.read, 2**16)
            while chunk:
                await writer.write(chunk)
                chunk = await loop.run_in_executor(None, self._value.read, 2**16)
        finally:
            await loop.run_in_executor(None, self._value.close)

306 

307 

class TextIOPayload(IOBasePayload):
    """Payload that streams text from a text-mode file object, encoding on the fly."""

    _value: TextIO

    def __init__(
        self,
        value: TextIO,
        *args: Any,
        encoding: Optional[str] = None,
        content_type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Resolve encoding and content type before delegating to the base class.

        * explicit *encoding*: kept; a missing *content_type* becomes
          ``text/plain`` with that charset.
        * only *content_type*: the encoding is its ``charset`` parameter,
          falling back to UTF-8.
        * neither: UTF-8 and ``text/plain; charset=utf-8``.
        """
        if encoding is None:
            if content_type is None:
                encoding = "utf-8"
                content_type = "text/plain; charset=utf-8"
            else:
                mimetype = parse_mimetype(content_type)
                encoding = mimetype.parameters.get("charset", "utf-8")
        elif content_type is None:
            content_type = "text/plain; charset=%s" % encoding

        super().__init__(
            value,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )

    @property
    def size(self) -> Optional[int]:
        """Size from ``os.fstat`` minus the current ``tell()`` position.

        Returns ``None`` when ``OSError`` is raised while computing it.
        """
        # NOTE(review): for text streams ``tell()`` returns an opaque cookie,
        # not necessarily a byte offset - confirm this subtraction is the
        # intended size.
        try:
            return os.fstat(self._value.fileno()).st_size - self._value.tell()
        except OSError:
            return None

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Read text in chunks, encode each chunk and send it to *writer*.

        Reads and the final ``close`` run in the loop's default executor.
        The stream is closed even when reading or writing fails.
        """
        # We are inside a coroutine, so a loop is guaranteed to be running;
        # get_running_loop() is the API intended for that situation.
        loop = asyncio.get_running_loop()
        try:
            chunk = await loop.run_in_executor(None, self._value.read, 2**16)
            while chunk:
                # Use the configured encoding when there is one, otherwise
                # fall back to str.encode()'s default.
                data = (
                    chunk.encode(encoding=self._encoding)
                    if self._encoding
                    else chunk.encode()
                )
                await writer.write(data)
                chunk = await loop.run_in_executor(None, self._value.read, 2**16)
        finally:
            await loop.run_in_executor(None, self._value.close)

359 

360 

class BytesIOPayload(IOBasePayload):
    """IO payload whose remaining size is measured by seeking."""

    @property
    def size(self) -> int:
        """Number of bytes between the current position and the end.

        The stream position is restored before returning.
        """
        start = self._value.tell()
        stream_end = self._value.seek(0, os.SEEK_END)
        # Put the cursor back where it was.
        self._value.seek(start)
        return stream_end - start

368 

369 

class BufferedReaderPayload(IOBasePayload):
    """IO payload whose remaining size comes from the file descriptor."""

    @property
    def size(self) -> Optional[int]:
        """Size from ``os.fstat`` minus the current position, or ``None``.

        ``None`` is returned when ``OSError`` is raised, e.g. because
        ``fileno()`` is not supported, as with
        ``io.BufferedReader(io.BytesIO(b'data'))``.
        """
        stream = self._value
        try:
            total = os.fstat(stream.fileno()).st_size
            return total - stream.tell()
        except OSError:
            return None

379 

380 

class JsonPayload(BytesPayload):
    """Bytes payload holding a JSON-serialized value."""

    def __init__(
        self,
        value: Any,
        encoding: str = "utf-8",
        content_type: str = "application/json",
        dumps: JSONEncoder = json.dumps,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Serialize *value* with *dumps* and encode the text with *encoding*.

        The resulting bytes, together with *content_type* and *encoding*,
        are handed to ``BytesPayload``.
        """
        serialized = dumps(value)
        body = serialized.encode(encoding)
        super().__init__(
            body,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )

398 

399 

if not TYPE_CHECKING:
    # At runtime the ABCs from collections.abc are used, so that
    # isinstance() checks work against them.
    from collections.abc import AsyncIterable, AsyncIterator

    _AsyncIterator = AsyncIterator
    _AsyncIterable = AsyncIterable
else:  # pragma: no cover
    # Static type checkers see the typing generics, parametrized with bytes.
    from typing import AsyncIterable, AsyncIterator

    _AsyncIterator = AsyncIterator[bytes]
    _AsyncIterable = AsyncIterable[bytes]

410 

411 

class AsyncIterablePayload(Payload):
    """Payload that streams chunks produced by an async iterable."""

    # Iterator obtained from the wrapped iterable; reset to None once exhausted.
    _iter: Optional[_AsyncIterator] = None

    def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
        """Wrap *value* and obtain its async iterator.

        Raises ``TypeError`` when *value* is not an ``AsyncIterable``.
        ``content_type`` defaults to ``application/octet-stream``.
        """
        if not isinstance(value, AsyncIterable):
            raise TypeError(
                "value argument must support "
                "collections.abc.AsyncIterable interface, "
                "got {!r}".format(type(value))
            )

        if "content_type" not in kwargs:
            kwargs["content_type"] = "application/octet-stream"

        super().__init__(value, *args, **kwargs)

        self._iter = value.__aiter__()

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Send every chunk from the iterator to *writer*.

        Once the iterator is exhausted it is dropped, so a second call
        writes nothing.
        """
        # Explicit None check (rather than truthiness) so that an iterator
        # object that happens to be falsy is still consumed; None only means
        # the iterable has already been used up by an earlier call.
        if self._iter is not None:
            try:
                while True:
                    chunk = await self._iter.__anext__()
                    await writer.write(chunk)
            except StopAsyncIteration:
                self._iter = None

440 

441 

class StreamReaderPayload(AsyncIterablePayload):
    """Async-iterable payload fed from a ``StreamReader``."""

    def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
        # The reader's ``iter_any()`` result is what the base class iterates.
        chunks = value.iter_any()
        super().__init__(chunks, *args, **kwargs)

445 

446 

# Module-level registry used by get_payload() / register_payload().
PAYLOAD_REGISTRY = PayloadRegistry()

# Registration order matters: lookup returns the first entry whose type
# matches, so the specific IO types are registered before io.IOBase.
register_payload(BytesPayload, (bytes, bytearray, memoryview))
register_payload(StringPayload, str)
register_payload(StringIOPayload, io.StringIO)
register_payload(TextIOPayload, io.TextIOBase)
register_payload(BytesIOPayload, io.BytesIO)
register_payload(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
register_payload(IOBasePayload, io.IOBase)
register_payload(StreamReaderPayload, StreamReader)
# try_last gives more specialized async iterables (such as
# BodyPartReaderPayload) a chance to override this default.
register_payload(AsyncIterablePayload, AsyncIterable, order=Order.try_last)