Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/payload.py: 60%
220 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
1import asyncio
2import enum
3import io
4import json
5import mimetypes
6import os
7import warnings
8from abc import ABC, abstractmethod
9from itertools import chain
10from typing import (
11 IO,
12 TYPE_CHECKING,
13 Any,
14 ByteString,
15 Dict,
16 Iterable,
17 Optional,
18 TextIO,
19 Tuple,
20 Type,
21 Union,
22)
24from multidict import CIMultiDict
25from typing_extensions import Final
27from . import hdrs
28from .abc import AbstractStreamWriter
29from .helpers import (
30 _SENTINEL,
31 content_disposition_header,
32 guess_filename,
33 parse_mimetype,
34 sentinel,
35)
36from .streams import StreamReader
37from .typedefs import JSONEncoder, _CIMultiDict
# Names exported by ``from <this module> import *``.
__all__ = (
    "PAYLOAD_REGISTRY",
    "get_payload",
    "payload_type",
    "Payload",
    "BytesPayload",
    "StringPayload",
    "IOBasePayload",
    "BytesIOPayload",
    "BufferedReaderPayload",
    "TextIOPayload",
    "StringIOPayload",
    "JsonPayload",
    "AsyncIterablePayload",
)

# Size threshold above which BytesPayload emits a ResourceWarning.
TOO_LARGE_BYTES_BODY: Final[int] = 2**20  # 1 MB

if TYPE_CHECKING:  # pragma: no cover
    # ``List`` is used only in annotations, so it is imported for type
    # checkers only.
    from typing import List
class LookupError(Exception):
    """Raised by ``PayloadRegistry.get`` when no registered type matches.

    Note that inside this module the name shadows the builtin
    ``LookupError``; this class derives directly from ``Exception``.
    """
class Order(str, enum.Enum):
    """Lookup priority bucket chosen when a payload factory is registered.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # Default bucket; searched after ``try_first`` and before ``try_last``.
    normal = "normal"
    # Bucket searched before every other one.
    try_first = "try_first"
    # Bucket searched only after the others produced no match.
    try_last = "try_last"
def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
    """Return a payload for *data* using the module-level registry.

    All extra positional and keyword arguments are forwarded unchanged to
    ``PAYLOAD_REGISTRY.get``.
    """
    registry = PAYLOAD_REGISTRY
    return registry.get(data, *args, **kwargs)
def register_payload(
    factory: Type["Payload"], type: Any, *, order: Order = Order.normal
) -> None:
    """Register *factory* for *type* in the module-level registry.

    *order* selects the lookup bucket and is passed through as a keyword.
    """
    registry = PAYLOAD_REGISTRY
    registry.register(factory, type, order=order)
class payload_type:
    """Class decorator that registers the decorated payload factory.

    The decorator stores the target *type* and *order*; applying it to a
    class calls ``register_payload`` and hands the class back unchanged.
    """

    def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
        self.type, self.order = type, order

    def __call__(self, factory: Type["Payload"]) -> Type["Payload"]:
        target, bucket = self.type, self.order
        register_payload(factory, target, order=bucket)
        # Return the class itself so the decorated name stays bound to it.
        return factory
# A payload factory: the ``Payload`` class or one of its subclasses.
PayloadType = Type["Payload"]
# One registry entry: (factory, type spec tested with ``isinstance``).
_PayloadRegistryItem = Tuple[PayloadType, Any]
class PayloadRegistry:
    """Registry mapping data types to payload factories.

    Entries live in three lists that are searched in the order
    first -> normal -> last; within a list, registration order is kept.

    note: we need zope.interface for more efficient adapter search
    """

    def __init__(self) -> None:
        self._first: List[_PayloadRegistryItem] = []
        self._normal: List[_PayloadRegistryItem] = []
        self._last: List[_PayloadRegistryItem] = []

    def get(
        self,
        data: Any,
        *args: Any,
        _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
        **kwargs: Any,
    ) -> "Payload":
        """Return a payload for *data*.

        A ``Payload`` instance is returned as is. Otherwise the first entry
        whose type spec matches ``isinstance(data, ...)`` builds the payload
        with ``factory(data, *args, **kwargs)``.

        Raises ``LookupError`` (the class defined in this module) when
        nothing matches.
        """
        if isinstance(data, Payload):
            return data

        candidates = _CHAIN(self._first, self._normal, self._last)
        for factory, accepted in candidates:
            if not isinstance(data, accepted):
                continue
            return factory(data, *args, **kwargs)

        raise LookupError()

    def register(
        self, factory: PayloadType, type: Any, *, order: Order = Order.normal
    ) -> None:
        """Append ``(factory, type)`` to the list selected by *order*.

        Raises ``ValueError`` for an *order* that is not an ``Order`` member.
        """
        if order is Order.try_first:
            bucket = self._first
        elif order is Order.normal:
            bucket = self._normal
        elif order is Order.try_last:
            bucket = self._last
        else:
            raise ValueError(f"Unsupported order {order!r}")
        bucket.append((factory, type))
class Payload(ABC):
    """Abstract base class: a value bundled with its headers.

    Subclasses implement ``write`` to send the value to a stream writer.
    """

    # Content type used when none is given and none can be guessed.
    _default_content_type: str = "application/octet-stream"
    # Payload size; stays ``None`` unless a subclass sets it.
    _size: Optional[int] = None

    def __init__(
        self,
        value: Any,
        headers: Optional[
            Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
        ] = None,
        content_type: Union[None, str, _SENTINEL] = sentinel,
        filename: Optional[str] = None,
        encoding: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Store the value and build the initial header set.

        The Content-Type header is taken from *content_type* when it is
        given (neither the sentinel nor ``None``); otherwise it is guessed
        from *filename* via ``mimetypes``; otherwise the class default is
        used. *headers*, if any, are merged in afterwards with ``update``.
        Extra keyword arguments are accepted and ignored.
        """
        self._encoding = encoding
        self._filename = filename
        self._headers: _CIMultiDict = CIMultiDict()
        self._value = value

        resolved: str
        if content_type is sentinel or content_type is None:
            guessed: Optional[str] = None
            if self._filename is not None:
                guessed = mimetypes.guess_type(self._filename)[0]
            resolved = self._default_content_type if guessed is None else guessed
        else:
            assert isinstance(content_type, str)
            resolved = content_type
        self._headers[hdrs.CONTENT_TYPE] = resolved

        self._headers.update(headers or {})

    @property
    def size(self) -> Optional[int]:
        """Size of the payload, or ``None`` when it was never set."""
        return self._size

    @property
    def filename(self) -> Optional[str]:
        """Filename passed to the constructor, if any."""
        return self._filename

    @property
    def headers(self) -> _CIMultiDict:
        """Headers attached to this payload."""
        return self._headers

    @property
    def _binary_headers(self) -> bytes:
        """Headers as UTF-8 ``Name: value`` lines plus a closing blank line."""
        lines = (
            name + ": " + val + "\r\n" for name, val in self.headers.items()
        )
        header_block = "".join(lines).encode("utf-8")
        return header_block + b"\r\n"

    @property
    def encoding(self) -> Optional[str]:
        """Encoding passed to the constructor, if any."""
        return self._encoding

    @property
    def content_type(self) -> str:
        """Current value of the Content-Type header."""
        return self._headers[hdrs.CONTENT_TYPE]

    def set_content_disposition(
        self,
        disptype: str,
        quote_fields: bool = True,
        _charset: str = "utf-8",
        **params: Any,
    ) -> None:
        """Sets ``Content-Disposition`` header.

        The header value is built by ``content_disposition_header`` from
        the given arguments.
        """
        header_value = content_disposition_header(
            disptype, quote_fields=quote_fields, _charset=_charset, **params
        )
        self._headers[hdrs.CONTENT_DISPOSITION] = header_value

    @abstractmethod
    async def write(self, writer: AbstractStreamWriter) -> None:
        """Write payload.

        writer is an AbstractStreamWriter instance:
        """
class BytesPayload(Payload):
    """Payload backed by an in-memory bytes-like object."""

    def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None:
        """Validate *value*, record its size and warn when it is very large.

        Raises ``TypeError`` unless *value* is ``bytes``, ``bytearray`` or
        ``memoryview``. The content type defaults to
        ``application/octet-stream`` when the caller gives none by keyword.
        """
        if not isinstance(value, (bytes, bytearray, memoryview)):
            raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")

        kwargs.setdefault("content_type", "application/octet-stream")

        super().__init__(value, *args, **kwargs)

        # len() of a memoryview counts items, not bytes, hence nbytes.
        self._size = value.nbytes if isinstance(value, memoryview) else len(value)

        if self._size > TOO_LARGE_BYTES_BODY:
            warnings.warn(
                "Sending a large body directly with raw bytes might"
                " lock the event loop. You should probably pass an "
                "io.BytesIO object instead",
                ResourceWarning,
                source=self,
            )

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Send the whole value to *writer* in a single call."""
        body = self._value
        await writer.write(body)
class StringPayload(BytesPayload):
    """Payload for ``str`` values, encoded to bytes at construction time."""

    def __init__(
        self,
        value: str,
        *args: Any,
        encoding: Optional[str] = None,
        content_type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Pick the encoding and content type, then encode *value*.

        An explicit *encoding* wins. Without one, the ``charset`` parameter
        of *content_type* is used (default ``utf-8``). With neither, the
        payload is ``text/plain; charset=utf-8``.
        """
        if encoding is not None:
            real_encoding = encoding
            if content_type is None:
                content_type = "text/plain; charset=%s" % encoding
        elif content_type is not None:
            parsed = parse_mimetype(content_type)
            real_encoding = parsed.parameters.get("charset", "utf-8")
        else:
            real_encoding = "utf-8"
            content_type = "text/plain; charset=utf-8"

        encoded = value.encode(real_encoding)
        super().__init__(
            encoded,
            *args,
            encoding=real_encoding,
            content_type=content_type,
            **kwargs,
        )
class StringIOPayload(StringPayload):
    """String payload whose text is read from a text stream up front."""

    def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:
        # The stream is read in full here; only the text is kept.
        text = value.read()
        super().__init__(text, *args, **kwargs)
class IOBasePayload(Payload):
    """Payload that streams its body from a file-like object."""

    _value: IO[Any]

    def __init__(
        self,
        value: IO[Any],
        disposition: Optional[str] = "attachment",
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Wrap *value* and set a Content-Disposition header when possible.

        When no ``filename`` keyword is given it is derived with
        ``guess_filename(value)``. If a filename is known, *disposition* is
        not ``None`` and no Content-Disposition header was supplied, the
        header is set from *disposition* and the filename.
        """
        if "filename" not in kwargs:
            kwargs["filename"] = guess_filename(value)

        super().__init__(value, *args, **kwargs)

        if self._filename is not None and disposition is not None:
            # Do not override a header the caller passed explicitly.
            if hdrs.CONTENT_DISPOSITION not in self.headers:
                self.set_content_disposition(disposition, filename=self._filename)

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Copy the stream to *writer* in chunks, then close the stream.

        ``read`` and ``close`` run in the loop's default executor so they
        do not run on the event loop thread. The stream is closed even when
        reading or writing fails.
        """
        # get_running_loop() is the preferred call inside a coroutine.
        loop = asyncio.get_running_loop()
        try:
            # 2**16 bytes (64 KiB) are requested per read.
            chunk = await loop.run_in_executor(None, self._value.read, 2**16)
            while chunk:
                await writer.write(chunk)
                chunk = await loop.run_in_executor(None, self._value.read, 2**16)
        finally:
            await loop.run_in_executor(None, self._value.close)
class TextIOPayload(IOBasePayload):
    """Payload that streams text from a file-like object, encoding each chunk."""

    _value: TextIO

    def __init__(
        self,
        value: TextIO,
        *args: Any,
        encoding: Optional[str] = None,
        content_type: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Pick the encoding and content type, then defer to the base class.

        An explicit *encoding* wins. Without one, the ``charset`` parameter
        of *content_type* is used (default ``utf-8``). With neither, the
        payload is ``text/plain; charset=utf-8``.
        """
        if encoding is None:
            if content_type is None:
                encoding = "utf-8"
                content_type = "text/plain; charset=utf-8"
            else:
                mimetype = parse_mimetype(content_type)
                encoding = mimetype.parameters.get("charset", "utf-8")
        else:
            if content_type is None:
                content_type = "text/plain; charset=%s" % encoding

        super().__init__(
            value,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )

    @property
    def size(self) -> Optional[int]:
        """File size minus the current position, or ``None`` on ``OSError``.

        NOTE(review): ``tell()`` on a text stream is documented as an opaque
        number rather than a byte offset, so this may not be an exact byte
        count -- confirm how callers use it.
        """
        try:
            return os.fstat(self._value.fileno()).st_size - self._value.tell()
        except OSError:
            return None

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Encode the stream chunk by chunk into *writer*, then close it.

        ``read`` and ``close`` run in the loop's default executor so they
        do not run on the event loop thread. The stream is closed even when
        reading or writing fails.
        """
        # get_running_loop() is the preferred call inside a coroutine.
        loop = asyncio.get_running_loop()
        try:
            # Up to 2**16 characters are requested per read.
            chunk = await loop.run_in_executor(None, self._value.read, 2**16)
            while chunk:
                # Fall back to str.encode()'s default when no encoding is set.
                data = (
                    chunk.encode(encoding=self._encoding)
                    if self._encoding
                    else chunk.encode()
                )
                await writer.write(data)
                chunk = await loop.run_in_executor(None, self._value.read, 2**16)
        finally:
            await loop.run_in_executor(None, self._value.close)
class BytesIOPayload(IOBasePayload):
    """Payload whose size is measured by seeking in the wrapped stream."""

    @property
    def size(self) -> int:
        """Number of bytes between the current position and the end."""
        stream = self._value
        current = stream.tell()
        # seek() returns the new absolute position, i.e. the total length.
        total = stream.seek(0, os.SEEK_END)
        # Put the cursor back where it was.
        stream.seek(current)
        return total - current
class BufferedReaderPayload(IOBasePayload):
    """Payload whose size comes from ``os.fstat`` on the stream's descriptor."""

    @property
    def size(self) -> Optional[int]:
        """File size minus the current position, or ``None`` on ``OSError``."""
        try:
            stat_result = os.fstat(self._value.fileno())
            return stat_result.st_size - self._value.tell()
        except OSError:
            # data.fileno() is not supported, e.g.
            # io.BufferedReader(io.BytesIO(b'data'))
            return None
class JsonPayload(BytesPayload):
    """Bytes payload holding the serialized form of a Python value."""

    def __init__(
        self,
        value: Any,
        encoding: str = "utf-8",
        content_type: str = "application/json",
        dumps: JSONEncoder = json.dumps,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Serialize *value* with *dumps* and encode it with *encoding*."""
        body = dumps(value).encode(encoding)
        super().__init__(
            body,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )
if TYPE_CHECKING:  # pragma: no cover
    # Type checkers see the aliases parameterized with ``bytes``.
    from typing import AsyncIterable, AsyncIterator

    _AsyncIterator = AsyncIterator[bytes]
    _AsyncIterable = AsyncIterable[bytes]
else:
    # At runtime the plain ABCs are bound; ``AsyncIterable`` is the class
    # used for the ``isinstance`` check in AsyncIterablePayload.
    from collections.abc import AsyncIterable, AsyncIterator

    _AsyncIterator = AsyncIterator
    _AsyncIterable = AsyncIterable
class AsyncIterablePayload(Payload):
    """Payload that streams chunks produced by an async iterable."""

    # Iterator obtained in __init__; reset to None once it is exhausted.
    _iter: Optional[_AsyncIterator] = None

    def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
        """Validate *value* and obtain its async iterator.

        Raises ``TypeError`` when *value* is not an ``AsyncIterable``. The
        content type defaults to ``application/octet-stream`` when the
        caller gives none by keyword.
        """
        if not isinstance(value, AsyncIterable):
            raise TypeError(
                "value argument must support "
                "collections.abc.AsyncIterable interface, "
                "got {!r}".format(type(value))
            )

        if "content_type" not in kwargs:
            kwargs["content_type"] = "application/octet-stream"

        super().__init__(value, *args, **kwargs)

        self._iter = value.__aiter__()

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Send every chunk from the iterator to *writer*.

        Once the iterator is exhausted it is dropped, so a later call
        writes nothing.
        """
        # The check guards against the rare case where the same payload
        # is written twice.
        if self._iter:
            try:
                while True:
                    chunk = await self._iter.__anext__()
                    await writer.write(chunk)
            except StopAsyncIteration:
                self._iter = None
class StreamReaderPayload(AsyncIterablePayload):
    """Async-iterable payload fed by ``StreamReader.iter_any()``."""

    def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
        chunks = value.iter_any()
        super().__init__(chunks, *args, **kwargs)
# Module-level registry used by get_payload() and register_payload().
PAYLOAD_REGISTRY = PayloadRegistry()
# Lookup returns the first entry whose type matches, in registration order,
# so the order of the calls below matters: specific stream types are
# registered before the generic io.IOBase entry.
PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview))
PAYLOAD_REGISTRY.register(StringPayload, str)
PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO)
PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase)
PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO)
PAYLOAD_REGISTRY.register(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase)
PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader)
# try_last for giving a chance to more specialized async iterables like
# multidict.BodyPartReaderPayload override the default
# NOTE(review): "multidict" above may be a slip for another module name --
# confirm where BodyPartReaderPayload is defined.
PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last)