Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/aiohttp/payload.py: 59%
219 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:40 +0000
1import asyncio
2import enum
3import io
4import json
5import mimetypes
6import os
7import warnings
8from abc import ABC, abstractmethod
9from itertools import chain
10from typing import (
11 IO,
12 TYPE_CHECKING,
13 Any,
14 ByteString,
15 Dict,
16 Final,
17 Iterable,
18 Optional,
19 TextIO,
20 Tuple,
21 Type,
22 Union,
23)
25from multidict import CIMultiDict
27from . import hdrs
28from .abc import AbstractStreamWriter
29from .helpers import (
30 _SENTINEL,
31 content_disposition_header,
32 guess_filename,
33 parse_mimetype,
34 sentinel,
35)
36from .streams import StreamReader
37from .typedefs import JSONEncoder, _CIMultiDict
# Names exported by a star-import of this module.
__all__ = (
    "PAYLOAD_REGISTRY",
    "get_payload",
    "payload_type",
    "Payload",
    "BytesPayload",
    "StringPayload",
    "IOBasePayload",
    "BytesIOPayload",
    "BufferedReaderPayload",
    "TextIOPayload",
    "StringIOPayload",
    "JsonPayload",
    "AsyncIterablePayload",
)
# Size threshold in bytes; BytesPayload emits a ResourceWarning for bodies
# strictly larger than this.
TOO_LARGE_BYTES_BODY: Final[int] = 1 << 20  # 1 MiB
57if TYPE_CHECKING: # pragma: no cover
58 from typing import List
class LookupError(Exception):
    """Raised by the payload registry when no factory accepts the data.

    Note: inside this module the name shadows the builtin ``LookupError``;
    this class derives from ``Exception`` directly, so ``KeyError`` and
    ``IndexError`` are *not* subclasses of it.
    """
class Order(str, enum.Enum):
    """Lookup priority of a registered payload factory.

    The registry consults ``try_first`` entries, then ``normal`` ones, and
    finally ``try_last`` ones. Members are also ``str`` instances equal to
    their own name.
    """

    normal = "normal"
    try_first = "try_first"
    try_last = "try_last"
def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
    """Look *data* up in the module-level ``PAYLOAD_REGISTRY``.

    Extra positional and keyword arguments are forwarded unchanged to the
    registry's ``get`` method, whose result is returned.
    """
    registry = PAYLOAD_REGISTRY
    return registry.get(data, *args, **kwargs)
def register_payload(
    factory: Type["Payload"], type: Any, *, order: Order = Order.normal
) -> None:
    """Register *factory* for *type* in the module-level ``PAYLOAD_REGISTRY``.

    ``type`` is whatever ``isinstance`` accepts as its second argument (a class
    or a tuple of classes); ``order`` selects the priority bucket.
    """
    registry = PAYLOAD_REGISTRY
    registry.register(factory, type, order=order)
class payload_type:
    """Class decorator that registers the decorated payload factory.

    ``@payload_type(T, order=...)`` calls ``register_payload`` with the
    decorated class, ``T`` and the chosen order, then hands the class back
    unchanged.
    """

    def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
        # Remember the registration parameters until the decorator is applied.
        self.type = type
        self.order = order

    def __call__(self, factory: Type["Payload"]) -> Type["Payload"]:
        register_payload(factory, self.type, order=self.order)
        return factory
# A payload factory: any class derived from Payload.
PayloadType = Type["Payload"]
# One registry entry: (factory, type-or-tuple-of-types tested with isinstance).
_PayloadRegistryItem = Tuple[PayloadType, Any]
class PayloadRegistry:
    """Registry mapping Python types to payload factories.

    Entries live in three lists that are scanned linearly in priority
    order: ``try_first``, ``normal``, ``try_last``.

    note: something like zope.interface would give a more efficient
    adapter search.
    """

    def __init__(self) -> None:
        self._first: List[_PayloadRegistryItem] = []
        self._normal: List[_PayloadRegistryItem] = []
        self._last: List[_PayloadRegistryItem] = []

    def get(
        self,
        data: Any,
        *args: Any,
        _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
        **kwargs: Any,
    ) -> "Payload":
        """Return a payload for *data*.

        A ``Payload`` instance is returned as is. Otherwise the first
        registered entry whose type matches ``isinstance(data, type)`` is
        used and its factory is called with ``data, *args, **kwargs``.

        Raises:
            LookupError: no registered entry matches *data*.
        """
        if isinstance(data, Payload):
            return data

        candidates = _CHAIN(self._first, self._normal, self._last)
        for factory, accepted in candidates:
            if isinstance(data, accepted):
                return factory(data, *args, **kwargs)

        raise LookupError()

    def register(
        self, factory: PayloadType, type: Any, *, order: Order = Order.normal
    ) -> None:
        """Append ``(factory, type)`` to the bucket selected by *order*.

        Raises:
            ValueError: *order* is not one of the ``Order`` members.
        """
        # Identity comparison on purpose: only genuine Order members are accepted.
        if order is Order.try_first:
            bucket = self._first
        elif order is Order.normal:
            bucket = self._normal
        elif order is Order.try_last:
            bucket = self._last
        else:
            raise ValueError(f"Unsupported order {order!r}")
        bucket.append((factory, type))
class Payload(ABC):
    """Abstract base class for payloads.

    Stores the wrapped value together with its headers, filename and
    encoding; subclasses implement ``write`` to emit the value.
    """

    # Content type used when none is given and none can be guessed.
    _default_content_type: str = "application/octet-stream"
    # Size reported by ``size``; subclasses set it when it is known.
    _size: Optional[int] = None

    def __init__(
        self,
        value: Any,
        headers: Optional[
            Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
        ] = None,
        content_type: Union[str, None, _SENTINEL] = sentinel,
        filename: Optional[str] = None,
        encoding: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Store *value* and build the initial header set.

        The Content-Type header is, in order of preference: the explicit
        *content_type*; a guess from *filename* via ``mimetypes``; the class
        default. Entries from *headers* are applied afterwards with
        ``update``. Extra keyword arguments are accepted and ignored here.
        """
        self._encoding = encoding
        self._filename = filename
        self._headers: _CIMultiDict = CIMultiDict()
        self._value = value

        if content_type is sentinel or content_type is None:
            guessed = None
            if filename is not None:
                guessed = mimetypes.guess_type(filename)[0]
            content_type = self._default_content_type if guessed is None else guessed
        self._headers[hdrs.CONTENT_TYPE] = content_type

        if headers:
            self._headers.update(headers)

    @property
    def size(self) -> Optional[int]:
        """Size of the payload."""
        return self._size

    @property
    def filename(self) -> Optional[str]:
        """Filename of the payload."""
        return self._filename

    @property
    def headers(self) -> _CIMultiDict:
        """Custom item headers"""
        return self._headers

    @property
    def _binary_headers(self) -> bytes:
        """Headers as ``Name: value`` lines, UTF-8 encoded, plus a blank line."""
        header_lines = "".join(
            name + ": " + value + "\r\n" for name, value in self.headers.items()
        )
        return header_lines.encode("utf-8") + b"\r\n"

    @property
    def encoding(self) -> Optional[str]:
        """Payload encoding"""
        return self._encoding

    @property
    def content_type(self) -> str:
        """Content type"""
        return self._headers[hdrs.CONTENT_TYPE]

    def set_content_disposition(
        self,
        disptype: str,
        quote_fields: bool = True,
        _charset: str = "utf-8",
        **params: Any,
    ) -> None:
        """Sets ``Content-Disposition`` header."""
        header_value = content_disposition_header(
            disptype, quote_fields=quote_fields, _charset=_charset, **params
        )
        self._headers[hdrs.CONTENT_DISPOSITION] = header_value

    @abstractmethod
    async def write(self, writer: AbstractStreamWriter) -> None:
        """Write payload.

        writer is an AbstractStreamWriter instance:
        """
class BytesPayload(Payload):
    """Payload backed by an in-memory ``bytes``, ``bytearray`` or ``memoryview``."""

    def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None:
        """Validate *value*, record its size and warn when it is very large.

        Raises:
            TypeError: *value* is not bytes, bytearray or memoryview.
        """
        if not isinstance(value, (bytes, bytearray, memoryview)):
            raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")

        kwargs.setdefault("content_type", "application/octet-stream")
        super().__init__(value, *args, **kwargs)

        # len() of a memoryview counts items, not bytes; nbytes is the byte size.
        self._size = value.nbytes if isinstance(value, memoryview) else len(value)

        if self._size > TOO_LARGE_BYTES_BODY:
            warnings.warn(
                "Sending a large body directly with raw bytes might"
                " lock the event loop. You should probably pass an "
                "io.BytesIO object instead",
                ResourceWarning,
                source=self,
            )

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Send the whole value to *writer* in a single ``write`` call."""
        await writer.write(self._value)
class StringPayload(BytesPayload):
    """Payload for a ``str``; the text is encoded once, up front."""

    def __init__(
        self,
        value: str,
        *args: Any,
        encoding: Optional[str] = None,
        content_type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Resolve encoding and content type, then store the encoded bytes.

        - explicit *encoding*: used as is; a missing *content_type* becomes
          ``text/plain`` with that charset.
        - only *content_type*: its ``charset`` parameter is used, falling
          back to UTF-8.
        - neither: UTF-8 and ``text/plain; charset=utf-8``.
        """
        if encoding is not None:
            real_encoding = encoding
            if content_type is None:
                content_type = "text/plain; charset=%s" % encoding
        elif content_type is not None:
            mimetype = parse_mimetype(content_type)
            real_encoding = mimetype.parameters.get("charset", "utf-8")
        else:
            real_encoding = "utf-8"
            content_type = "text/plain; charset=utf-8"

        encoded = value.encode(real_encoding)
        super().__init__(
            encoded,
            *args,
            encoding=real_encoding,
            content_type=content_type,
            **kwargs,
        )
class StringIOPayload(StringPayload):
    """String payload whose text is read from a text stream at construction."""

    def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:
        # The stream is consumed immediately; only the resulting str is kept.
        text = value.read()
        super().__init__(text, *args, **kwargs)
class IOBasePayload(Payload):
    """Payload that streams the contents of a file-like object."""

    _value: IO[Any]

    def __init__(
        self,
        value: IO[Any],
        disposition: Optional[str] = "attachment",
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Wrap *value* and, when a filename is known, set Content-Disposition.

        If ``filename`` is not supplied it is derived with ``guess_filename``.
        A Content-Disposition header is added only when a filename is
        available, *disposition* is not ``None`` and the caller's headers do
        not already contain one.
        """
        if "filename" not in kwargs:
            kwargs["filename"] = guess_filename(value)

        super().__init__(value, *args, **kwargs)

        if self._filename is not None and disposition is not None:
            if hdrs.CONTENT_DISPOSITION not in self.headers:
                self.set_content_disposition(disposition, filename=self._filename)

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Copy the stream to *writer* in 64 KiB chunks, then close it.

        Reads and the final ``close`` run in the loop's default executor.
        The stream is closed even when reading or writing fails.
        """
        # We are inside a coroutine, so a loop is running; get_running_loop()
        # is the documented way to obtain it here.
        loop = asyncio.get_running_loop()
        chunk_size = 2**16
        try:
            while True:
                chunk = await loop.run_in_executor(None, self._value.read, chunk_size)
                if not chunk:
                    break
                await writer.write(chunk)
        finally:
            await loop.run_in_executor(None, self._value.close)
class TextIOPayload(IOBasePayload):
    """Payload that streams a text-mode stream, encoding each chunk."""

    _value: TextIO

    def __init__(
        self,
        value: TextIO,
        *args: Any,
        encoding: Optional[str] = None,
        content_type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Resolve encoding and content type, then defer to ``IOBasePayload``.

        - explicit *encoding*: used as is; a missing *content_type* becomes
          ``text/plain`` with that charset.
        - only *content_type*: its ``charset`` parameter is used, falling
          back to UTF-8.
        - neither: UTF-8 and ``text/plain; charset=utf-8``.
        """
        if encoding is not None:
            if content_type is None:
                content_type = "text/plain; charset=%s" % encoding
        elif content_type is not None:
            mimetype = parse_mimetype(content_type)
            encoding = mimetype.parameters.get("charset", "utf-8")
        else:
            encoding = "utf-8"
            content_type = "text/plain; charset=utf-8"

        super().__init__(
            value,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )

    @property
    def size(self) -> Optional[int]:
        """File size minus the current position, or ``None`` without a fileno.

        NOTE(review): ``st_size`` is a byte count, while ``tell()`` on a text
        stream is documented as an opaque number and the data is re-encoded
        in ``write`` -- confirm callers do not need an exact byte length.
        """
        try:
            return os.fstat(self._value.fileno()).st_size - self._value.tell()
        except OSError:
            return None

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Read text in 64 Ki-character chunks, encode, write; then close.

        Reads and the final ``close`` run in the loop's default executor.
        The stream is closed even when reading or writing fails.
        """
        # We are inside a coroutine, so a loop is running; get_running_loop()
        # is the documented way to obtain it here.
        loop = asyncio.get_running_loop()
        chunk_size = 2**16
        try:
            while True:
                chunk = await loop.run_in_executor(None, self._value.read, chunk_size)
                if not chunk:
                    break
                # An empty/None encoding falls back to str.encode()'s default.
                data = (
                    chunk.encode(encoding=self._encoding)
                    if self._encoding
                    else chunk.encode()
                )
                await writer.write(data)
        finally:
            await loop.run_in_executor(None, self._value.close)
class BytesIOPayload(IOBasePayload):
    """File-like payload whose remaining size is found by seeking."""

    @property
    def size(self) -> int:
        """Number of bytes between the current position and the end.

        The stream position is restored before returning.
        """
        stream = self._value
        current = stream.tell()
        remaining = stream.seek(0, os.SEEK_END) - current
        stream.seek(current)
        return remaining
class BufferedReaderPayload(IOBasePayload):
    """File-like payload whose remaining size comes from ``os.fstat``."""

    @property
    def size(self) -> Optional[int]:
        """File size minus the current position, or ``None`` if unavailable."""
        try:
            total = os.fstat(self._value.fileno()).st_size
            return total - self._value.tell()
        except OSError:
            # data.fileno() is not supported, e.g.
            # io.BufferedReader(io.BytesIO(b'data'))
            return None
class JsonPayload(BytesPayload):
    """Bytes payload holding the JSON serialization of a value."""

    def __init__(
        self,
        value: Any,
        encoding: str = "utf-8",
        content_type: str = "application/json",
        dumps: JSONEncoder = json.dumps,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Serialize *value* with *dumps*, encode it and store the bytes."""
        body = dumps(value).encode(encoding)
        super().__init__(
            body,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )
# For type checkers the aliases are parametrized with ``bytes``; at runtime
# the plain collections.abc classes are used (they are what isinstance()
# checks below rely on).
if TYPE_CHECKING:  # pragma: no cover
    from typing import AsyncIterable, AsyncIterator

    _AsyncIterator = AsyncIterator[bytes]
    _AsyncIterable = AsyncIterable[bytes]
else:
    from collections.abc import AsyncIterable, AsyncIterator

    _AsyncIterator = AsyncIterator
    _AsyncIterable = AsyncIterable
class AsyncIterablePayload(Payload):
    """Payload that forwards the chunks produced by an async iterable."""

    # Iterator obtained from the wrapped iterable; reset to None once it has
    # been exhausted by write().
    _iter: Optional[_AsyncIterator] = None

    def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
        """Validate *value* and obtain its async iterator.

        Raises:
            TypeError: *value* is not a ``collections.abc.AsyncIterable``.
        """
        if not isinstance(value, AsyncIterable):
            raise TypeError(
                "value argument must support "
                "collections.abc.AsyncIterable interface, "
                "got {!r}".format(type(value))
            )

        kwargs.setdefault("content_type", "application/octet-stream")
        super().__init__(value, *args, **kwargs)

        self._iter = value.__aiter__()

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Write every remaining chunk to *writer*.

        Once the iterator is exhausted it is dropped, so a later call is a
        no-op.
        """
        # Explicit None check: an iterator object that happens to be falsy
        # (defines __bool__/__len__) must still be consumed. The check also
        # covers the rare case where the payload is written twice.
        if self._iter is not None:
            try:
                while True:
                    chunk = await self._iter.__anext__()
                    await writer.write(chunk)
            except StopAsyncIteration:
                self._iter = None
class StreamReaderPayload(AsyncIterablePayload):
    """Async-iterable payload fed by a ``StreamReader``'s ``iter_any()``."""

    def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
        chunks = value.iter_any()
        super().__init__(chunks, *args, **kwargs)
# Module-level registry used by get_payload() / register_payload().
# Lookup returns the first matching entry within a priority bucket, so the
# registration order below matters: specific types come before the generic
# io.TextIOBase / io.IOBase fallbacks.
PAYLOAD_REGISTRY = PayloadRegistry()
register_payload(BytesPayload, (bytes, bytearray, memoryview))
register_payload(StringPayload, str)
register_payload(StringIOPayload, io.StringIO)
register_payload(TextIOPayload, io.TextIOBase)
register_payload(BytesIOPayload, io.BytesIO)
register_payload(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
register_payload(IOBasePayload, io.IOBase)
register_payload(StreamReaderPayload, StreamReader)
# try_last gives more specialized async iterables (e.g. a
# BodyPartReaderPayload) a chance to override this default.
register_payload(AsyncIterablePayload, AsyncIterable, order=Order.try_last)