Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/aiohttp/payload.py: 59%

221 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:56 +0000

1import asyncio 

2import enum 

3import io 

4import json 

5import mimetypes 

6import os 

7import warnings 

8from abc import ABC, abstractmethod 

9from itertools import chain 

10from typing import ( 

11 IO, 

12 TYPE_CHECKING, 

13 Any, 

14 ByteString, 

15 Dict, 

16 Iterable, 

17 Optional, 

18 TextIO, 

19 Tuple, 

20 Type, 

21 Union, 

22) 

23 

24from multidict import CIMultiDict 

25 

26from . import hdrs 

27from .abc import AbstractStreamWriter 

28from .helpers import ( 

29 PY_36, 

30 content_disposition_header, 

31 guess_filename, 

32 parse_mimetype, 

33 sentinel, 

34) 

35from .streams import StreamReader 

36from .typedefs import Final, JSONEncoder, _CIMultiDict 

37 

# Public names exported by ``from aiohttp.payload import *``.
__all__ = (
    "PAYLOAD_REGISTRY",
    "get_payload",
    "payload_type",
    "Payload",
    "BytesPayload",
    "StringPayload",
    "IOBasePayload",
    "BytesIOPayload",
    "BufferedReaderPayload",
    "TextIOPayload",
    "StringIOPayload",
    "JsonPayload",
    "AsyncIterablePayload",
)

# Raw bytes bodies larger than this (1 MiB) make BytesPayload emit a
# ResourceWarning suggesting an io.BytesIO object instead.
TOO_LARGE_BYTES_BODY: Final[int] = 1 << 20

if TYPE_CHECKING:  # pragma: no cover
    # Only referenced by attribute annotations inside function bodies, which
    # are not evaluated at runtime, so a type-checker-only import suffices.
    from typing import List

58 

59 

class LookupError(Exception):
    """Raised by ``PayloadRegistry.get`` when no registered type matches.

    Note: within this module this name shadows the builtin ``LookupError``;
    this class derives directly from ``Exception``, not from the builtin.
    """

62 

63 

class Order(str, enum.Enum):
    """Priority bucket a payload factory is registered into.

    The registry consults ``try_first`` entries, then ``normal`` ones, then
    ``try_last`` ones.
    """

    normal = "normal"  # default bucket
    try_first = "try_first"  # consulted before everything else
    try_last = "try_last"  # fallback for broad matches

68 

69 

def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
    """Wrap *data* in a Payload using the module-wide registry.

    Extra positional and keyword arguments are forwarded unchanged to the
    matching factory. A Payload instance is returned as-is; if nothing
    matches, this module's ``LookupError`` propagates.
    """
    registry = PAYLOAD_REGISTRY
    return registry.get(data, *args, **kwargs)

72 

73 

def register_payload(
    factory: Type["Payload"], type: Any, *, order: Order = Order.normal
) -> None:
    """Register *factory* for values that are instances of *type*.

    *type* may be a class or a tuple of classes (anything ``isinstance``
    accepts as its second argument); *order* picks the priority bucket.
    """
    registry = PAYLOAD_REGISTRY
    registry.register(factory, type, order=order)

78 

79 

class payload_type:
    """Class decorator registering a Payload subclass for a given type.

    Usage::

        @payload_type(SomeType, order=Order.try_first)
        class SomePayload(Payload):
            ...
    """

    def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
        self.type, self.order = type, order

    def __call__(self, factory: Type["Payload"]) -> Type["Payload"]:
        """Register *factory* and return it unchanged."""
        register_payload(factory, self.type, order=self.order)
        # Hand the class back so the decorated name still refers to it.
        return factory

88 

89 

# A payload factory: a Payload subclass, called as ``factory(data, ...)``.
PayloadType = Type["Payload"]
# One registry entry: (factory, class-or-tuple-of-classes for isinstance).
_PayloadRegistryItem = Tuple[PayloadType, Any]

92 

93 

class PayloadRegistry:
    """Maps Python types to payload factories, consulted in priority order.

    Lookup walks the ``try_first``, ``normal`` and ``try_last`` buckets in
    that order and, within a bucket, in registration order; the first entry
    whose type matches via ``isinstance`` wins.

    note: we need zope.interface for more efficient adapter search
    """

    def __init__(self) -> None:
        self._first: List[_PayloadRegistryItem] = []
        self._normal: List[_PayloadRegistryItem] = []
        self._last: List[_PayloadRegistryItem] = []

    def get(
        self,
        data: Any,
        *args: Any,
        _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
        **kwargs: Any,
    ) -> "Payload":
        """Return a Payload for *data*.

        An existing Payload is returned unchanged. Otherwise the first
        matching factory is called as ``factory(data, *args, **kwargs)``.

        Raises:
            LookupError: this module's class, when no registered type matches.
        """
        if isinstance(data, Payload):
            return data
        # ``_CHAIN`` is a private keyword defaulting to itertools.chain;
        # presumably bound as a default to make the lookup local.
        entries = _CHAIN(self._first, self._normal, self._last)
        for factory, accepted in entries:
            if not isinstance(data, accepted):
                continue
            return factory(data, *args, **kwargs)
        raise LookupError()

    def register(
        self, factory: PayloadType, type: Any, *, order: Order = Order.normal
    ) -> None:
        """Append ``(factory, type)`` to the bucket selected by *order*.

        Raises:
            ValueError: if *order* is not one of the ``Order`` members.
        """
        buckets = (
            (Order.try_first, self._first),
            (Order.normal, self._normal),
            (Order.try_last, self._last),
        )
        for member, bucket in buckets:
            # Identity comparison: a plain string equal to a member's value
            # is not accepted.
            if order is member:
                bucket.append((factory, type))
                return
        raise ValueError(f"Unsupported order {order!r}")

131 

132 

class Payload(ABC):
    """Abstract base class for message bodies.

    Subclasses implement :meth:`write`. The base class keeps the wrapped
    value, filename, encoding, headers and an optional known size.
    """

    _default_content_type: str = "application/octet-stream"
    _size: Optional[int] = None

    def __init__(
        self,
        value: Any,
        headers: Optional[
            Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
        ] = None,
        content_type: Optional[str] = sentinel,
        filename: Optional[str] = None,
        encoding: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Store *value* and build the initial header set.

        Content-Type comes from *content_type* when it is neither the sentinel
        nor None. Otherwise it is guessed from *filename* with ``mimetypes``,
        falling back to ``_default_content_type``. *headers* are merged in
        afterwards. Extra keyword arguments are accepted and ignored.
        """
        self._value = value
        self._filename = filename
        self._encoding = encoding
        self._headers: _CIMultiDict = CIMultiDict()

        if content_type is sentinel or content_type is None:
            guessed = None
            if filename is not None:
                guessed = mimetypes.guess_type(filename)[0]
            content_type = self._default_content_type if guessed is None else guessed

        self._headers[hdrs.CONTENT_TYPE] = content_type
        # Merged last, so a caller-supplied Content-Type replaces the one
        # computed above (CIMultiDict.update overwrites existing keys).
        self._headers.update(headers or {})

    @property
    def size(self) -> Optional[int]:
        """Size of the payload in bytes, or None when unknown."""
        return self._size

    @property
    def filename(self) -> Optional[str]:
        """Filename associated with the payload, if any."""
        return self._filename

    @property
    def headers(self) -> _CIMultiDict:
        """Headers attached to this payload."""
        return self._headers

    @property
    def _binary_headers(self) -> bytes:
        # ``Name: value\r\n`` per header, then the blank line ending the block.
        text = "".join(k + ": " + v + "\r\n" for k, v in self.headers.items())
        return text.encode("utf-8") + b"\r\n"

    @property
    def encoding(self) -> Optional[str]:
        """Text encoding of the payload, if one was set."""
        return self._encoding

    @property
    def content_type(self) -> str:
        """Current value of the Content-Type header."""
        return self._headers[hdrs.CONTENT_TYPE]

    def set_content_disposition(
        self,
        disptype: str,
        quote_fields: bool = True,
        _charset: str = "utf-8",
        **params: Any,
    ) -> None:
        """Set the ``Content-Disposition`` header from *disptype* and *params*."""
        header_value = content_disposition_header(
            disptype, quote_fields=quote_fields, _charset=_charset, **params
        )
        self._headers[hdrs.CONTENT_DISPOSITION] = header_value

    @abstractmethod
    async def write(self, writer: AbstractStreamWriter) -> None:
        """Write the payload body to *writer*, an AbstractStreamWriter."""

216 

217 

class BytesPayload(Payload):
    """Payload backed by an in-memory ``bytes``/``bytearray``/``memoryview``.

    The size is known up front. Content-Type defaults to
    ``application/octet-stream``.
    """

    def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None:
        if not isinstance(value, (bytes, bytearray, memoryview)):
            raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")

        kwargs.setdefault("content_type", "application/octet-stream")
        super().__init__(value, *args, **kwargs)

        # len() of a memoryview counts items along the first dimension, so
        # nbytes is used to get the byte count.
        self._size = value.nbytes if isinstance(value, memoryview) else len(value)

        if self._size > TOO_LARGE_BYTES_BODY:
            warn_kwargs = {"source": self} if PY_36 else {}
            warnings.warn(
                "Sending a large body directly with raw bytes might"
                " lock the event loop. You should probably pass an "
                "io.BytesIO object instead",
                ResourceWarning,
                **warn_kwargs,
            )

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Send the whole buffer with a single writer call."""
        await writer.write(self._value)

248 

249 

class StringPayload(BytesPayload):
    """Payload for a ``str``, encoded to bytes when constructed.

    Encoding resolution:

    * an explicit *encoding* wins; if *content_type* is missing, it becomes
      ``text/plain; charset=<encoding>``;
    * otherwise the ``charset`` parameter of *content_type* is used,
      defaulting to ``utf-8``;
    * with neither given, ``utf-8`` and ``text/plain; charset=utf-8``.
    """

    def __init__(
        self,
        value: str,
        *args: Any,
        encoding: Optional[str] = None,
        content_type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        if encoding is not None:
            real_encoding = encoding
            if content_type is None:
                content_type = "text/plain; charset=%s" % encoding
        elif content_type is not None:
            params = parse_mimetype(content_type).parameters
            real_encoding = params.get("charset", "utf-8")
        else:
            real_encoding = "utf-8"
            content_type = "text/plain; charset=utf-8"

        super().__init__(
            value.encode(real_encoding),
            *args,
            encoding=real_encoding,
            content_type=content_type,
            **kwargs,
        )

279 

280 

class StringIOPayload(StringPayload):
    """StringPayload built from the rest of a text stream.

    The stream is read eagerly, from its current position, at construction.
    """

    def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:
        text = value.read()
        super().__init__(text, *args, **kwargs)

284 

285 

class IOBasePayload(Payload):
    """Payload that streams a file-like object in 2**16-unit reads.

    Reads and the final ``close()`` run through ``loop.run_in_executor`` so
    blocking file I/O happens off the event loop. When a filename is known
    (given, or guessed from *value*) and *disposition* is not None, a
    ``Content-Disposition`` header is added unless one is already present.
    """

    _value: IO[Any]

    def __init__(
        self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any
    ) -> None:
        # Only guess when no filename was passed, so guess_filename() is not
        # invoked needlessly.
        if "filename" not in kwargs:
            kwargs["filename"] = guess_filename(value)

        super().__init__(value, *args, **kwargs)

        wants_disposition = self._filename is not None and disposition is not None
        if wants_disposition and hdrs.CONTENT_DISPOSITION not in self.headers:
            self.set_content_disposition(disposition, filename=self._filename)

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Copy the stream to *writer* chunk by chunk, then close the stream.

        The stream is closed even when reading or writing raises.
        """
        loop = asyncio.get_event_loop()
        try:
            while True:
                chunk = await loop.run_in_executor(None, self._value.read, 2**16)
                if not chunk:
                    break
                await writer.write(chunk)
        finally:
            await loop.run_in_executor(None, self._value.close)

310 

311 

class TextIOPayload(IOBasePayload):
    """Payload that streams a text stream, encoding each chunk as it is sent.

    An explicit *encoding* wins (and builds a ``text/plain`` Content-Type
    when none is given). Otherwise the ``charset`` of *content_type* is used,
    defaulting to ``utf-8``. With neither given, ``utf-8`` and
    ``text/plain; charset=utf-8`` are used.
    """

    _value: TextIO

    def __init__(
        self,
        value: TextIO,
        *args: Any,
        encoding: Optional[str] = None,
        content_type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        if encoding is not None:
            if content_type is None:
                content_type = "text/plain; charset=%s" % encoding
        elif content_type is not None:
            params = parse_mimetype(content_type).parameters
            encoding = params.get("charset", "utf-8")
        else:
            encoding = "utf-8"
            content_type = "text/plain; charset=utf-8"

        super().__init__(
            value,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )

    @property
    def size(self) -> Optional[int]:
        """On-disk bytes after the current position, or None if not a real file.

        NOTE(review): this is the file's byte size minus ``tell()``. It may
        differ from the number of bytes ``write`` actually emits when the
        payload encoding differs from the file's, or when newline translation
        applies. Confirm before treating it as an exact byte count.
        """
        try:
            remaining = os.fstat(self._value.fileno()).st_size - self._value.tell()
        except OSError:
            return None
        return remaining

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Read text chunks, encode them, and send them; always close the stream."""
        loop = asyncio.get_event_loop()
        try:
            while True:
                chunk = await loop.run_in_executor(None, self._value.read, 2**16)
                if not chunk:
                    break
                if self._encoding:
                    data = chunk.encode(encoding=self._encoding)
                else:
                    data = chunk.encode()
                await writer.write(data)
        finally:
            await loop.run_in_executor(None, self._value.close)

364 

365 

class BytesIOPayload(IOBasePayload):
    """IOBasePayload for ``io.BytesIO`` objects, whose size is measurable."""

    @property
    def size(self) -> int:
        """Bytes between the current position and the end of the buffer.

        Seeks to the end to measure, then restores the original position.
        """
        stream = self._value
        start = stream.tell()
        end = stream.seek(0, os.SEEK_END)
        stream.seek(start)
        return end - start

373 

374 

class BufferedReaderPayload(IOBasePayload):
    """IOBasePayload for buffered binary file objects."""

    @property
    def size(self) -> Optional[int]:
        """Bytes remaining in the backing OS file, or None when unknown."""
        try:
            file_size = os.fstat(self._value.fileno()).st_size
            position = self._value.tell()
        except OSError:
            # fileno() is unsupported for wrappers around in-memory streams,
            # e.g. io.BufferedReader(io.BytesIO(b'data')).
            return None
        return file_size - position

384 

385 

class JsonPayload(BytesPayload):
    """BytesPayload holding *value* serialised with *dumps* and then encoded."""

    def __init__(
        self,
        value: Any,
        encoding: str = "utf-8",
        content_type: str = "application/json",
        dumps: JSONEncoder = json.dumps,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        body = dumps(value).encode(encoding)
        super().__init__(
            body,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )

404 

405 

if not TYPE_CHECKING:
    # At runtime use the collections.abc classes: they support isinstance(),
    # which AsyncIterablePayload and the registry rely on.
    from collections.abc import AsyncIterable, AsyncIterator

    _AsyncIterator = AsyncIterator
    _AsyncIterable = AsyncIterable
else:  # pragma: no cover
    # Type checkers see the typing generics parameterised with bytes chunks.
    from typing import AsyncIterable, AsyncIterator

    _AsyncIterator = AsyncIterator[bytes]
    _AsyncIterable = AsyncIterable[bytes]

416 

417 

class AsyncIterablePayload(Payload):
    """Payload that streams the chunks produced by an async iterable.

    The iterator is taken once, at construction time. After it is exhausted
    it is dropped, so further ``write()`` calls send nothing instead of
    touching a spent iterator.
    """

    _iter: Optional[_AsyncIterator] = None

    def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
        """Validate *value* and grab its async iterator.

        Raises:
            TypeError: if *value* is not a ``collections.abc.AsyncIterable``.
        """
        if not isinstance(value, AsyncIterable):
            raise TypeError(
                "value argument must support "
                "collections.abc.AsyncIterable interface, "
                "got {!r}".format(type(value))
            )

        if "content_type" not in kwargs:
            kwargs["content_type"] = "application/octet-stream"

        super().__init__(value, *args, **kwargs)

        self._iter = value.__aiter__()

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Forward every remaining chunk to *writer*."""
        # Explicit None check: a spent iterator is reset to None below, which
        # makes writing the same payload twice a no-op. A truthiness test
        # would also skip iterators that happen to be falsy.
        if self._iter is None:
            return
        try:
            while True:
                chunk = await self._iter.__anext__()
                await writer.write(chunk)
        except StopAsyncIteration:
            self._iter = None

447 

448 

class StreamReaderPayload(AsyncIterablePayload):
    """AsyncIterablePayload over the chunks from ``StreamReader.iter_any()``."""

    def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
        chunks = value.iter_any()
        super().__init__(chunks, *args, **kwargs)

452 

453 

PAYLOAD_REGISTRY = PayloadRegistry()

# Within the normal bucket, registration order is lookup order, so more
# specific types are listed before their bases (e.g. io.StringIO before
# io.TextIOBase, io.BytesIO before io.IOBase).
for _factory, _types in (
    (BytesPayload, (bytes, bytearray, memoryview)),
    (StringPayload, str),
    (StringIOPayload, io.StringIO),
    (TextIOPayload, io.TextIOBase),
    (BytesIOPayload, io.BytesIO),
    (BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom)),
    (IOBasePayload, io.IOBase),
    (StreamReaderPayload, StreamReader),
):
    PAYLOAD_REGISTRY.register(_factory, _types)
del _factory, _types

# AsyncIterable is a very broad match, so it goes in the try_last bucket.
# That gives more specialised async-iterable payloads (such as
# BodyPartReaderPayload) registered in an earlier bucket the chance to
# override this default.
PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last)