Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/aiohttp/payload.py: 59%
221 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:56 +0000
1import asyncio
2import enum
3import io
4import json
5import mimetypes
6import os
7import warnings
8from abc import ABC, abstractmethod
9from itertools import chain
10from typing import (
11 IO,
12 TYPE_CHECKING,
13 Any,
14 ByteString,
15 Dict,
16 Iterable,
17 Optional,
18 TextIO,
19 Tuple,
20 Type,
21 Union,
22)
24from multidict import CIMultiDict
26from . import hdrs
27from .abc import AbstractStreamWriter
28from .helpers import (
29 PY_36,
30 content_disposition_header,
31 guess_filename,
32 parse_mimetype,
33 sentinel,
34)
35from .streams import StreamReader
36from .typedefs import Final, JSONEncoder, _CIMultiDict
# Public names exported by ``from aiohttp.payload import *``.
__all__ = (
    # registry access
    "PAYLOAD_REGISTRY", "get_payload", "payload_type",
    # base class and concrete payload implementations
    "Payload", "BytesPayload", "StringPayload", "IOBasePayload",
    "BytesIOPayload", "BufferedReaderPayload", "TextIOPayload",
    "StringIOPayload", "JsonPayload", "AsyncIterablePayload",
)
# Bytes-like bodies larger than this (1 MiB) make BytesPayload emit a
# ResourceWarning suggesting an io.BytesIO object instead.
TOO_LARGE_BYTES_BODY: Final[int] = 1 << 20

if TYPE_CHECKING:  # pragma: no cover
    # Used only in annotations; not imported at runtime.
    from typing import List
class LookupError(Exception):
    """Raised by ``PayloadRegistry.get`` when no registered factory matches.

    NOTE: this name shadows the builtin ``LookupError`` within this module.
    It derives directly from ``Exception``, so handlers for the builtin
    ``LookupError`` do not catch it.
    """
class Order(str, enum.Enum):
    """Priority bucket chosen when registering a payload factory.

    The registry searches ``try_first`` entries, then ``normal``, then
    ``try_last``.
    """

    normal = "normal"
    try_first = "try_first"
    try_last = "try_last"
def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
    """Wrap *data* in a Payload using the module-level registry.

    Extra positional and keyword arguments go to the matching factory.
    A Payload instance is returned unchanged. If no factory matches,
    this module's ``LookupError`` is raised.
    """
    registry = PAYLOAD_REGISTRY
    return registry.get(data, *args, **kwargs)
def register_payload(
    factory: Type["Payload"], type: Any, *, order: Order = Order.normal
) -> None:
    """Register *factory* for instances of *type* in the global registry.

    *type* is passed to ``isinstance`` during lookup, so it may be a class or
    a tuple of classes. *order* selects the lookup priority bucket.
    """
    PAYLOAD_REGISTRY.register(factory=factory, type=type, order=order)
class payload_type:
    """Class decorator that registers a Payload subclass as a factory.

    Usage::

        @payload_type(SomeType, order=Order.try_first)
        class SomePayload(Payload):
            ...
    """

    def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
        # Remember the registration arguments until the class is decorated.
        self.type = type
        self.order = order

    def __call__(self, factory: Type["Payload"]) -> Type["Payload"]:
        register_payload(factory, self.type, order=self.order)
        # Hand the class back unchanged so decoration is transparent.
        return factory
# A Payload subclass used as a factory: the registry calls it with the data
# plus any extra arguments to build the payload instance.
PayloadType = Type["Payload"]

# One registry entry: (factory, class-or-tuple-of-classes tested with isinstance).
_PayloadRegistryItem = Tuple[PayloadType, Any]
class PayloadRegistry:
    """Payload registry.

    Maps Python types to Payload factories. Entries are kept in three
    buckets (``try_first``, ``normal``, ``try_last``) that are searched in
    that order. Within a bucket, earlier registrations are tried first.

    note: we need zope.interface for more efficient adapter search
    """

    def __init__(self) -> None:
        self._first: List[_PayloadRegistryItem] = []
        self._normal: List[_PayloadRegistryItem] = []
        self._last: List[_PayloadRegistryItem] = []

    def get(
        self,
        data: Any,
        *args: Any,
        _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
        **kwargs: Any,
    ) -> "Payload":
        """Return a Payload wrapping *data*.

        If *data* is already a Payload, it is returned as-is. Otherwise the
        first entry whose registered type matches ``isinstance(data, ...)``
        is called as ``factory(data, *args, **kwargs)``.

        Raises:
            LookupError: this module's LookupError, when no entry matches.
        """
        if isinstance(data, Payload):
            return data
        # Loop variable is not called ``type`` so the builtin stays usable below.
        for factory, entry_type in _CHAIN(self._first, self._normal, self._last):
            if isinstance(data, entry_type):
                return factory(data, *args, **kwargs)

        raise LookupError(f"No payload factory registered for {type(data)!r}")

    def register(
        self, factory: PayloadType, type: Any, *, order: Order = Order.normal
    ) -> None:
        """Append ``(factory, type)`` to the bucket selected by *order*.

        *order* is compared by identity, so it must be an ``Order`` member.
        Equal plain strings such as ``"normal"`` are rejected.

        Raises:
            ValueError: if *order* is not an ``Order`` member.
        """
        if order is Order.try_first:
            self._first.append((factory, type))
        elif order is Order.normal:
            self._normal.append((factory, type))
        elif order is Order.try_last:
            self._last.append((factory, type))
        else:
            raise ValueError(f"Unsupported order {order!r}")
class Payload(ABC):
    """Abstract base for HTTP bodies.

    Stores the wrapped value, an optional filename and encoding, and a
    case-insensitive header mapping that always holds a ``Content-Type``.
    Subclasses implement :meth:`write`.
    """

    _default_content_type: str = "application/octet-stream"
    _size: Optional[int] = None

    def __init__(
        self,
        value: Any,
        headers: Optional[
            Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
        ] = None,
        content_type: Optional[str] = sentinel,
        filename: Optional[str] = None,
        encoding: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        self._encoding = encoding
        self._filename = filename
        self._headers: _CIMultiDict = CIMultiDict()
        self._value = value

        # Content-Type precedence: explicit argument, then a guess from the
        # filename, then the class default.
        explicit = content_type is not sentinel and content_type is not None
        if not explicit:
            content_type = self._default_content_type
            if self._filename is not None:
                guessed = mimetypes.guess_type(self._filename)[0]
                if guessed is not None:
                    content_type = guessed
        self._headers[hdrs.CONTENT_TYPE] = content_type

        # Caller-supplied headers are merged last.
        self._headers.update(headers or {})

    @property
    def size(self) -> Optional[int]:
        """Size of the payload, or None when it is not known."""
        return self._size

    @property
    def filename(self) -> Optional[str]:
        """Filename associated with the payload, if any."""
        return self._filename

    @property
    def headers(self) -> _CIMultiDict:
        """Custom item headers"""
        return self._headers

    @property
    def _binary_headers(self) -> bytes:
        # One "Name: value\r\n" line per header, then a blank line, as UTF-8.
        header_lines = [
            name + ": " + value + "\r\n" for name, value in self.headers.items()
        ]
        return "".join(header_lines).encode("utf-8") + b"\r\n"

    @property
    def encoding(self) -> Optional[str]:
        """Text encoding of the payload, if any."""
        return self._encoding

    @property
    def content_type(self) -> str:
        """Value of the Content-Type header."""
        return self._headers[hdrs.CONTENT_TYPE]

    def set_content_disposition(
        self,
        disptype: str,
        quote_fields: bool = True,
        _charset: str = "utf-8",
        **params: Any,
    ) -> None:
        """Sets ``Content-Disposition`` header."""
        disposition = content_disposition_header(
            disptype, quote_fields=quote_fields, _charset=_charset, **params
        )
        self._headers[hdrs.CONTENT_DISPOSITION] = disposition

    @abstractmethod
    async def write(self, writer: AbstractStreamWriter) -> None:
        """Write payload.

        writer is an AbstractStreamWriter instance:
        """
class BytesPayload(Payload):
    """Payload for an in-memory ``bytes``, ``bytearray`` or ``memoryview``.

    A ResourceWarning is emitted when the body exceeds TOO_LARGE_BYTES_BODY.
    """

    def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None:
        if not isinstance(value, (bytes, bytearray, memoryview)):
            raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")

        kwargs.setdefault("content_type", "application/octet-stream")
        super().__init__(value, *args, **kwargs)

        # len() of a memoryview counts elements rather than bytes; nbytes is
        # the byte length.
        self._size = value.nbytes if isinstance(value, memoryview) else len(value)

        if self._size > TOO_LARGE_BYTES_BODY:
            # Keep the warning options separate from the constructor's kwargs.
            warn_options: Dict[str, Any] = {"source": self} if PY_36 else {}
            warnings.warn(
                "Sending a large body directly with raw bytes might"
                " lock the event loop. You should probably pass an "
                "io.BytesIO object instead",
                ResourceWarning,
                **warn_options,
            )

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Send the whole buffer with a single ``writer.write`` call."""
        await writer.write(self._value)
class StringPayload(BytesPayload):
    """Payload for a ``str``; the text is encoded to bytes up front."""

    def __init__(
        self,
        value: str,
        *args: Any,
        encoding: Optional[str] = None,
        content_type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # Codec selection:
        #   1. an explicit encoding wins;
        #   2. otherwise the charset of an explicit content type, defaulting
        #      to UTF-8;
        #   3. otherwise UTF-8 with a text/plain content type.
        if encoding is not None:
            real_encoding = encoding
            if content_type is None:
                content_type = "text/plain; charset=%s" % encoding
        elif content_type is not None:
            real_encoding = parse_mimetype(content_type).parameters.get(
                "charset", "utf-8"
            )
        else:
            real_encoding = "utf-8"
            content_type = "text/plain; charset=utf-8"

        super().__init__(
            value.encode(real_encoding),
            *args,
            encoding=real_encoding,
            content_type=content_type,
            **kwargs,
        )
class StringIOPayload(StringPayload):
    """Payload for a text stream whose remaining content is read eagerly."""

    def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:
        text = value.read()
        super().__init__(text, *args, **kwargs)
class IOBasePayload(Payload):
    """Payload streamed from a file-like object.

    Reads and the final ``close()`` go through ``loop.run_in_executor(None, ...)``.
    The object is closed once writing finishes or fails.
    """

    _value: IO[Any]

    def __init__(
        self,
        value: IO[Any],
        disposition: Optional[str] = "attachment",
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Wrap *value*, guessing a filename from it when none is given.

        Args:
            value: the file-like object to stream.
            disposition: Content-Disposition type used when a filename is
                known. ``None`` disables setting the header.
        """
        if "filename" not in kwargs:
            kwargs["filename"] = guess_filename(value)

        super().__init__(value, *args, **kwargs)

        if self._filename is not None and disposition is not None:
            # Do not overwrite a Content-Disposition supplied via headers.
            if hdrs.CONTENT_DISPOSITION not in self.headers:
                self.set_content_disposition(disposition, filename=self._filename)

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Copy the stream to *writer* in 64 KiB reads, then close it."""
        loop = asyncio.get_event_loop()
        try:
            while True:
                chunk = await loop.run_in_executor(None, self._value.read, 2**16)
                if not chunk:
                    break
                await writer.write(chunk)
        finally:
            await loop.run_in_executor(None, self._value.close)
class TextIOPayload(IOBasePayload):
    """Payload streamed from a text file object, encoded chunk by chunk."""

    _value: TextIO

    def __init__(
        self,
        value: TextIO,
        *args: Any,
        encoding: Optional[str] = None,
        content_type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # Same precedence as StringPayload: explicit encoding, then the
        # content type's charset (default UTF-8), then plain UTF-8 text.
        if encoding is not None:
            if content_type is None:
                content_type = "text/plain; charset=%s" % encoding
        elif content_type is not None:
            encoding = parse_mimetype(content_type).parameters.get("charset", "utf-8")
        else:
            encoding = "utf-8"
            content_type = "text/plain; charset=utf-8"

        super().__init__(
            value,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )

    @property
    def size(self) -> Optional[int]:
        # NOTE(review): st_size is the on-disk byte size, but write()
        # re-encodes with self._encoding, and tell() on a text stream is an
        # opaque position. The result is exact only when these agree — confirm.
        try:
            remaining = os.fstat(self._value.fileno()).st_size - self._value.tell()
        except OSError:
            return None
        return remaining

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Read 64 Ki-character chunks, encode them and send; always close."""
        loop = asyncio.get_event_loop()
        try:
            while True:
                text = await loop.run_in_executor(None, self._value.read, 2**16)
                if not text:
                    break
                if self._encoding:
                    data = text.encode(encoding=self._encoding)
                else:
                    data = text.encode()
                await writer.write(data)
        finally:
            await loop.run_in_executor(None, self._value.close)
class BytesIOPayload(IOBasePayload):
    """Payload streamed from an in-memory binary buffer."""

    @property
    def size(self) -> int:
        """Bytes from the current position to the end of the buffer.

        The stream position is restored before returning.
        """
        stream = self._value
        start = stream.tell()
        end = stream.seek(0, os.SEEK_END)
        stream.seek(start)
        return end - start
class BufferedReaderPayload(IOBasePayload):
    """Payload streamed from a buffered binary file object."""

    @property
    def size(self) -> Optional[int]:
        """Bytes between the current position and end of file, or None.

        None is returned when the size cannot be determined. One case is an
        object without a usable file descriptor, such as
        ``io.BufferedReader(io.BytesIO(b'data'))``.
        """
        try:
            total = os.fstat(self._value.fileno()).st_size
            position = self._value.tell()
        except OSError:
            return None
        return total - position
class JsonPayload(BytesPayload):
    """Payload holding *value* serialized by *dumps* and encoded to bytes."""

    def __init__(
        self,
        value: Any,
        encoding: str = "utf-8",
        content_type: str = "application/json",
        dumps: JSONEncoder = json.dumps,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        body = dumps(value).encode(encoding)
        super().__init__(
            body,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )
if not TYPE_CHECKING:
    # Runtime: the collections.abc ABCs, which support isinstance() checks.
    from collections.abc import AsyncIterable, AsyncIterator

    _AsyncIterator = AsyncIterator
    _AsyncIterable = AsyncIterable
else:  # pragma: no cover
    # Static analysis: typing generics parameterized with bytes.
    from typing import AsyncIterable, AsyncIterator

    _AsyncIterator = AsyncIterator[bytes]
    _AsyncIterable = AsyncIterable[bytes]
class AsyncIterablePayload(Payload):
    """Payload streamed from an async iterable of byte chunks.

    The iterator is created once, in ``__init__``. After it has been
    exhausted it is dropped, and further ``write`` calls send nothing.
    """

    _iter: Optional[_AsyncIterator] = None

    def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
        if not isinstance(value, AsyncIterable):
            raise TypeError(
                "value argument must support "
                "collections.abc.AsyncIterable interface, "
                "got {!r}".format(type(value))
            )

        if "content_type" not in kwargs:
            kwargs["content_type"] = "application/octet-stream"

        super().__init__(value, *args, **kwargs)

        self._iter = value.__aiter__()

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Send every remaining chunk from the iterator to *writer*."""
        # An explicit None check (not truthiness) guards against writing a
        # payload whose iterator was already consumed; an iterator object
        # that happens to be falsy must still be written.
        if self._iter is None:
            return
        try:
            while True:
                chunk = await self._iter.__anext__()
                await writer.write(chunk)
        except StopAsyncIteration:
            self._iter = None
class StreamReaderPayload(AsyncIterablePayload):
    """Payload that relays the contents of a StreamReader."""

    def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
        # Adapt the reader to the async-iterable interface via iter_any().
        chunks = value.iter_any()
        super().__init__(chunks, *args, **kwargs)
PAYLOAD_REGISTRY = PayloadRegistry()

# Default factories for the "normal" bucket. Lookup returns the first
# isinstance() match, so more specific types (e.g. io.StringIO) must be
# registered before their bases (io.TextIOBase, io.IOBase).
for _factory, _types in (
    (BytesPayload, (bytes, bytearray, memoryview)),
    (StringPayload, str),
    (StringIOPayload, io.StringIO),
    (TextIOPayload, io.TextIOBase),
    (BytesIOPayload, io.BytesIO),
    (BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom)),
    (IOBasePayload, io.IOBase),
    (StreamReaderPayload, StreamReader),
):
    PAYLOAD_REGISTRY.register(_factory, _types)
del _factory, _types

# Registered try_last so that more specialized async-iterable payloads
# (e.g. multipart's BodyPartReaderPayload) take precedence over this
# generic fallback.
PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last)