Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/aiohttp/multipart.py: 19%
560 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:56 +0000
1import base64
2import binascii
3import json
4import re
5import uuid
6import warnings
7import zlib
8from collections import deque
9from types import TracebackType
10from typing import (
11 TYPE_CHECKING,
12 Any,
13 AsyncIterator,
14 Deque,
15 Dict,
16 Iterator,
17 List,
18 Mapping,
19 Optional,
20 Sequence,
21 Tuple,
22 Type,
23 Union,
24 cast,
25)
26from urllib.parse import parse_qsl, unquote, urlencode
28from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping
30from .hdrs import (
31 CONTENT_DISPOSITION,
32 CONTENT_ENCODING,
33 CONTENT_LENGTH,
34 CONTENT_TRANSFER_ENCODING,
35 CONTENT_TYPE,
36)
37from .helpers import CHAR, TOKEN, parse_mimetype, reify
38from .http import HeadersParser
39from .payload import (
40 JsonPayload,
41 LookupError,
42 Order,
43 Payload,
44 StringPayload,
45 get_payload,
46 payload_type,
47)
48from .streams import StreamReader
# Public API of this module: the streaming readers, the writer, and the
# Content-Disposition parsing helpers together with their warning types.
__all__ = (
    # readers / writer
    "MultipartReader",
    "MultipartWriter",
    "BodyPartReader",
    # Content-Disposition support
    "BadContentDispositionHeader",
    "BadContentDispositionParam",
    "parse_content_disposition",
    "content_disposition_filename",
)
61if TYPE_CHECKING: # pragma: no cover
62 from .client_reqrep import ClientResponse
class BadContentDispositionHeader(RuntimeWarning):
    """Warning for a Content-Disposition header that could not be parsed.

    :func:`parse_content_disposition` emits it with the whole header as its
    argument and then gives up, returning ``(None, {})``.
    """
class BadContentDispositionParam(RuntimeWarning):
    """Warning for a single Content-Disposition parameter that was skipped.

    :func:`parse_content_disposition` emits it with the offending
    ``key=value`` item as its argument and continues with the next parameter.
    """
def parse_content_disposition(
    header: Optional[str],
) -> Tuple[Optional[str], Dict[str, str]]:
    """Parse a Content-Disposition header into ``(disposition_type, params)``.

    The disposition type and the parameter keys are lower-cased.  Quoted
    values are unescaped, and extended values (``key*=charset'lang'value``)
    are percent-decoded with their charset (UTF-8 when empty).  Continuation
    parameters (``key*0``, ``key*1*`` ...) are stored under their own keys.

    A missing or empty header returns ``(None, {})`` without a warning.  A
    malformed header emits :class:`BadContentDispositionHeader` and returns
    ``(None, {})``.  A single unusable parameter emits
    :class:`BadContentDispositionParam` and is skipped.
    """

    def is_token(string: str) -> bool:
        return bool(string) and TOKEN >= set(string)

    def is_quoted(string: str) -> bool:
        # A quoted value needs both surrounding quotes.  The length check also
        # prevents an IndexError on an empty value such as ``filename=``.
        return len(string) >= 2 and string[0] == string[-1] == '"'

    def is_rfc5987(string: str) -> bool:
        return is_token(string) and string.count("'") == 2

    def is_extended_param(string: str) -> bool:
        return string.endswith("*")

    def is_continuous_param(string: str) -> bool:
        # True for keys like "name*0" or "name*0*".
        pos = string.find("*") + 1
        if not pos:
            return False
        substring = string[pos:-1] if string.endswith("*") else string[pos:]
        return substring.isdigit()

    def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str:
        # Drop the backslash from quoted-pair sequences.
        return re.sub(f"\\\\([{chars}])", "\\1", text)

    if not header:
        return None, {}

    disptype, *parts = header.split(";")
    if not is_token(disptype):
        warnings.warn(BadContentDispositionHeader(header))
        return None, {}

    params: Dict[str, str] = {}
    while parts:
        item = parts.pop(0)

        if "=" not in item:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        key, value = item.split("=", 1)
        key = key.lower().strip()
        value = value.lstrip()

        if key in params:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        if not is_token(key):
            warnings.warn(BadContentDispositionParam(item))
            continue

        elif is_continuous_param(key):
            if is_quoted(value):
                value = unescape(value[1:-1])
            elif not is_token(value):
                warnings.warn(BadContentDispositionParam(item))
                continue

        elif is_extended_param(key):
            if is_rfc5987(value):
                encoding, _, value = value.split("'", 2)
                encoding = encoding or "utf-8"
            else:
                warnings.warn(BadContentDispositionParam(item))
                continue

            try:
                value = unquote(value, encoding, "strict")
            except (UnicodeDecodeError, LookupError):
                # LookupError: the declared charset is not a known codec.
                # Treat it like any other undecodable parameter instead of
                # letting a client-supplied header raise.
                warnings.warn(BadContentDispositionParam(item))
                continue

        else:
            failed = True
            if is_quoted(value):
                failed = False
                # Leading backslashes/slashes are dropped from quoted values.
                value = unescape(value[1:-1].lstrip("\\/"))
            elif is_token(value):
                failed = False
            elif parts:
                # maybe just ; in filename, in any case this is just
                # one case fix, for proper fix we need to redesign parser
                _value = f"{value};{parts[0]}"
                if is_quoted(_value):
                    parts.pop(0)
                    value = unescape(_value[1:-1].lstrip("\\/"))
                    failed = False

            if failed:
                warnings.warn(BadContentDispositionHeader(header))
                return None, {}

        params[key] = value

    return disptype.lower(), params
def content_disposition_filename(
    params: Mapping[str, str], name: str = "filename"
) -> Optional[str]:
    """Return parameter *name* from parsed Content-Disposition params.

    Lookup order:

    1. ``name*`` (an extended value, returned as stored);
    2. plain ``name``;
    3. RFC 2231 continuations ``name*0``, ``name*1``, ... (each optionally
       with a trailing ``*``), joined in numeric order up to the first gap.
       When section 0 is an extended section (``name*0*``) and the joined
       value has the form ``charset'language'value``, the value is
       percent-decoded with that charset (UTF-8 when the charset is empty).

    Returns ``None`` when *params* is empty or nothing matches.
    """
    name_suf = "%s*" % name
    if not params:
        return None
    if name_suf in params:
        return params[name_suf]
    if name in params:
        return params[name]

    # Map continuation index -> value.  Indices are compared as integers.
    # Sorting the raw keys would place "name*10" before "name*2" and
    # silently truncate values split into more than ten sections.
    sections: Dict[int, str] = {}
    first_is_extended = False
    for key, value in sorted(params.items()):
        if not key.startswith(name_suf):
            continue
        index = key[len(name_suf):]
        extended = index.endswith("*")
        if extended:
            index = index[:-1]
        # Accept only canonical decimal indices ("0", "1", "12"; no "01").
        if re.fullmatch(r"0|[1-9][0-9]*", index) is None:
            continue
        position = int(index)
        if position in sections:
            continue
        sections[position] = value
        if position == 0:
            first_is_extended = extended

    parts = []
    position = 0
    while position in sections:
        parts.append(sections[position])
        position += 1
    if not parts:
        return None

    value = "".join(parts)
    # Only an extended first section carries the charset'language' prefix.
    # Checking the apostrophe count avoids a ValueError on plain values that
    # merely contain an apostrophe (e.g. "it's").
    if first_is_extended and value.count("'") >= 2:
        encoding, _, value = value.split("'", 2)
        return unquote(value, encoding or "utf-8", "strict")
    return value
class MultipartResponseWrapper:
    """Pair a MultipartReader with the client response it reads from.

    Iterating the wrapper yields the reader's parts.  Whenever the reader
    reports end-of-file after a fetch, the response is released, so the
    underlying connection is handled without extra calls by the user.
    """

    def __init__(
        self,
        resp: "ClientResponse",
        stream: "MultipartReader",
    ) -> None:
        self.resp = resp
        self.stream = stream

    def __aiter__(self) -> "MultipartResponseWrapper":
        return self

    async def __anext__(
        self,
    ) -> Union["MultipartReader", "BodyPartReader"]:
        upcoming = await self.next()
        if upcoming is not None:
            return upcoming
        raise StopAsyncIteration

    def at_eof(self) -> bool:
        """Tell whether the response body has been consumed entirely."""
        body = self.resp.content
        return body.at_eof()

    async def next(
        self,
    ) -> Optional[Union["MultipartReader", "BodyPartReader"]]:
        """Fetch the next part from the wrapped reader.

        The response is released as soon as the reader is at end-of-file.
        """
        reader = self.stream
        fetched = await reader.next()
        finished = reader.at_eof()
        if finished:
            await self.release()
        return fetched

    async def release(self) -> None:
        """Release the response; remaining content is read to the void."""
        await self.resp.release()
class BodyPartReader:
    """Multipart reader for single body part.

    Reads the part's payload from the shared *content* stream up to the next
    *boundary*.  With a non-zero Content-Length header, exactly that many
    bytes are read.  Otherwise the stream is scanned for CRLF followed by the
    boundary.
    """

    chunk_size = 8192

    def __init__(
        self, boundary: bytes, headers: "CIMultiDictProxy[str]", content: StreamReader
    ) -> None:
        self.headers = headers
        self._boundary = boundary
        self._content = content
        self._at_eof = False
        length = self.headers.get(CONTENT_LENGTH, None)
        self._length = int(length) if length is not None else None
        self._read_bytes = 0
        # Lines read ahead by readline(); the boundary line ends up here when
        # the part finishes, and the owning MultipartReader collects them.
        self._unread: Deque[bytes] = deque()
        # Chunk held back by _read_chunk_from_stream (boundary look-behind).
        self._prev_chunk: Optional[bytes] = None
        # Number of stream reads that observed the underlying stream at EOF.
        self._content_eof = 0
        # NOTE(review): presumably backing storage for the @reify properties
        # (name/filename); confirm against helpers.reify.
        self._cache: Dict[str, Any] = {}

    def __aiter__(self) -> AsyncIterator["BodyPartReader"]:
        return self  # type: ignore[return-value]

    async def __anext__(self) -> bytes:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    async def next(self) -> Optional[bytes]:
        """Read the rest of the part; ``None`` once nothing is left."""
        item = await self.read()
        if not item:
            return None
        return item

    async def read(self, *, decode: bool = False) -> bytes:
        """Read all remaining body part data.

        :param decode: when true, undo Content-Transfer-Encoding and
            Content-Encoding (see :meth:`decode`).  Otherwise the data is
            returned untouched.
        """
        if self._at_eof:
            return b""
        data = bytearray()
        while not self._at_eof:
            data.extend(await self.read_chunk(self.chunk_size))
        if decode:
            return self.decode(data)
        return data

    async def read_chunk(self, size: int = chunk_size) -> bytes:
        """Read the next chunk of the body part (*size* bounds each read).

        Returns ``b""`` once the part is exhausted.  When the end of the part
        is reached, the following line is consumed and asserted to be CRLF.
        """
        if self._at_eof:
            return b""
        if self._length:
            chunk = await self._read_chunk_from_length(size)
        else:
            chunk = await self._read_chunk_from_stream(size)

        self._read_bytes += len(chunk)
        if self._read_bytes == self._length:
            self._at_eof = True
        if self._at_eof:
            clrf = await self._content.readline()
            assert (
                b"\r\n" == clrf
            ), "reader did not read all the data or it is malformed"
        return chunk

    async def _read_chunk_from_length(self, size: int) -> bytes:
        # Reads body part content chunk of the specified size.
        # The body part must have a Content-Length header with proper value.
        assert self._length is not None, "Content-Length required for chunked read"
        chunk_size = min(size, self._length - self._read_bytes)
        chunk = await self._content.read(chunk_size)
        if self._content.at_eof():
            # The stream ran dry before the declared Content-Length plus the
            # trailing CRLF/boundary arrived.  Without this, read()/release()
            # would loop forever on empty reads.  read_chunk() then runs its
            # CRLF assertion, which flags the truncated body.
            self._at_eof = True
        return chunk

    async def _read_chunk_from_stream(self, size: int) -> bytes:
        # Reads content chunk of body part with unknown length.
        # The Content-Length header for body part is not necessary.
        # A window of (previous chunk + current chunk) is searched, so a
        # boundary split across two reads is still found.  The previous chunk
        # is returned; the current one is held back for the next call.
        assert (
            size >= len(self._boundary) + 2
        ), "Chunk size must be greater or equal than boundary length + 2"
        first_chunk = self._prev_chunk is None
        if first_chunk:
            self._prev_chunk = await self._content.read(size)

        chunk = await self._content.read(size)
        self._content_eof += int(self._content.at_eof())
        assert self._content_eof < 3, "Reading after EOF"
        assert self._prev_chunk is not None
        window = self._prev_chunk + chunk
        sub = b"\r\n" + self._boundary
        if first_chunk:
            idx = window.find(sub)
        else:
            idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
        if idx >= 0:
            # pushing boundary back to content
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)
                self._content.unread_data(window[idx:])
            if size > idx:
                self._prev_chunk = self._prev_chunk[:idx]
            chunk = window[len(self._prev_chunk) : idx]
            if not chunk:
                self._at_eof = True
        result = self._prev_chunk
        self._prev_chunk = chunk
        return result

    async def readline(self) -> bytes:
        """Reads body part by line by line.

        Returns ``b""`` once the boundary line is reached; that line is kept
        in ``_unread`` rather than discarded.  A line directly followed by the
        boundary has its trailing two bytes (CRLF) removed.
        """
        if self._at_eof:
            return b""

        if self._unread:
            line = self._unread.popleft()
        else:
            line = await self._content.readline()

        if line.startswith(self._boundary):
            # the very last boundary may not come with \r\n,
            # so set single rules for everyone
            sline = line.rstrip(b"\r\n")
            boundary = self._boundary
            last_boundary = self._boundary + b"--"
            # ensure that we read exactly the boundary, not something alike
            if sline == boundary or sline == last_boundary:
                self._at_eof = True
                self._unread.append(line)
                return b""
        else:
            # Look one line ahead to learn whether this is the last data line.
            next_line = await self._content.readline()
            if next_line.startswith(self._boundary):
                line = line[:-2]  # strip CRLF but only once
            self._unread.append(next_line)

        return line

    async def release(self) -> None:
        """Like read(), but reads all the data to the void."""
        if self._at_eof:
            return
        while not self._at_eof:
            await self.read_chunk(self.chunk_size)

    async def text(self, *, encoding: Optional[str] = None) -> str:
        """Like read(), but assumes that body part contains text data.

        The decoded bytes are converted with *encoding*, else the
        Content-Type charset, else UTF-8.
        """
        data = await self.read(decode=True)
        # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm # NOQA
        # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send # NOQA
        encoding = encoding or self.get_charset(default="utf-8")
        return data.decode(encoding)

    async def json(self, *, encoding: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Like read(), but assumes that body parts contains JSON data.

        Returns ``None`` for an empty part.
        """
        data = await self.read(decode=True)
        if not data:
            return None
        encoding = encoding or self.get_charset(default="utf-8")
        return cast(Dict[str, Any], json.loads(data.decode(encoding)))

    async def form(self, *, encoding: Optional[str] = None) -> List[Tuple[str, str]]:
        """Like read(), but assumes that body parts contain form urlencoded data.

        Blank values are kept; an empty part yields ``[]``.
        """
        data = await self.read(decode=True)
        if not data:
            return []
        if encoding is not None:
            real_encoding = encoding
        else:
            real_encoding = self.get_charset(default="utf-8")
        return parse_qsl(
            data.rstrip().decode(real_encoding),
            keep_blank_values=True,
            encoding=real_encoding,
        )

    def at_eof(self) -> bool:
        """Returns True if the boundary was reached or False otherwise."""
        return self._at_eof

    def decode(self, data: bytes) -> bytes:
        """Decodes data.

        Content-Transfer-Encoding is undone first, then Content-Encoding.
        Each step is skipped when its header is absent.
        """
        if CONTENT_TRANSFER_ENCODING in self.headers:
            data = self._decode_content_transfer(data)
        if CONTENT_ENCODING in self.headers:
            return self._decode_content(data)
        return data

    def _decode_content(self, data: bytes) -> bytes:
        # Undo Content-Encoding: raw deflate, gzip, or identity.
        encoding = self.headers.get(CONTENT_ENCODING, "").lower()

        if encoding == "deflate":
            return zlib.decompress(data, -zlib.MAX_WBITS)
        elif encoding == "gzip":
            return zlib.decompress(data, 16 + zlib.MAX_WBITS)
        elif encoding == "identity":
            return data
        else:
            raise RuntimeError(f"unknown content encoding: {encoding}")

    def _decode_content_transfer(self, data: bytes) -> bytes:
        # Undo Content-Transfer-Encoding; binary/8bit/7bit are passed as-is.
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()

        if encoding == "base64":
            return base64.b64decode(data)
        elif encoding == "quoted-printable":
            return binascii.a2b_qp(data)
        elif encoding in ("binary", "8bit", "7bit"):
            return data
        else:
            raise RuntimeError(
                "unknown content transfer encoding: {}" "".format(encoding)
            )

    def get_charset(self, default: str) -> str:
        """Returns charset parameter from Content-Type header or default."""
        ctype = self.headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)
        return mimetype.parameters.get("charset", default)

    @reify
    def name(self) -> Optional[str]:
        """Returns name specified in Content-Disposition header.

        If the header is missing or malformed, returns None.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "name")

    @reify
    def filename(self) -> Optional[str]:
        """Returns filename specified in Content-Disposition header.

        Returns None if the header is missing or malformed.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "filename")
@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
    """Payload adapter for :class:`BodyPartReader` objects.

    The part's ``name``/``filename`` (when present) are carried over into an
    ``attachment`` Content-Disposition, and the body is written decoded.
    """

    def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
        super().__init__(value, *args, **kwargs)

        params: Dict[str, str] = {}
        if value.name is not None:
            params["name"] = value.name
        if value.filename is not None:
            params["filename"] = value.filename

        if params:
            self.set_content_disposition("attachment", True, **params)

    async def write(self, writer: Any) -> None:
        field = self._value
        headers = field.headers
        if CONTENT_TRANSFER_ENCODING in headers or CONTENT_ENCODING in headers:
            # base64 / quoted-printable / gzip / deflate cannot be decoded
            # chunk by chunk: a 64 KiB read can split a base64 quantum or a
            # compressed stream, and b64decode()/zlib.decompress() need the
            # complete input.  Buffer the encoded part and decode it once.
            data = await field.read()
            if data:
                await writer.write(field.decode(data))
            return

        # No encoding headers: decode() would return the data unchanged,
        # so the body can be streamed as-is.
        chunk = await field.read_chunk(size=2**16)
        while chunk:
            await writer.write(chunk)
            chunk = await field.read_chunk(size=2**16)
class MultipartReader:
    """Multipart body reader."""

    #: Response wrapper, used when multipart readers constructs from response.
    response_wrapper_cls = MultipartResponseWrapper
    #: Multipart reader class, used to handle multipart/* body parts.
    #: None points to type(self)
    multipart_reader_cls = None
    #: Body part reader class for non multipart/* content types.
    part_reader_cls = BodyPartReader

    def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
        self.headers = headers
        self._boundary = ("--" + self._get_boundary()).encode()
        self._content = content
        self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None
        self._at_eof = False
        self._at_bof = True
        # Lines pushed back for re-reading; _readline() pops from the end.
        self._unread: List[bytes] = []

    def __aiter__(
        self,
    ) -> AsyncIterator["BodyPartReader"]:
        return self  # type: ignore[return-value]

    async def __anext__(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    @classmethod
    def from_response(
        cls,
        response: "ClientResponse",
    ) -> MultipartResponseWrapper:
        """Constructs reader instance from HTTP response.

        :param response: :class:`~aiohttp.client.ClientResponse` instance
        """
        obj = cls.response_wrapper_cls(
            response, cls(response.headers, response.content)
        )
        return obj

    def at_eof(self) -> bool:
        """Returns True if the final boundary was reached, false otherwise."""
        return self._at_eof

    async def next(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        """Emits the next multipart body part.

        The previously returned part is drained first.  Returns ``None`` once
        the closing boundary has been read.
        """
        # So, if we're at BOF, we need to skip till the boundary.
        if self._at_eof:
            return None
        await self._maybe_release_last_part()
        if self._at_bof:
            await self._read_until_first_boundary()
            self._at_bof = False
        else:
            await self._read_boundary()
        if self._at_eof:  # we just read the last boundary, nothing to do there
            return None
        self._last_part = await self.fetch_next_part()
        return self._last_part

    async def release(self) -> None:
        """Reads all the body parts to the void till the final boundary."""
        while not self._at_eof:
            item = await self.next()
            if item is None:
                break
            await item.release()

    async def fetch_next_part(
        self,
    ) -> Union["MultipartReader", BodyPartReader]:
        """Returns the next body part reader."""
        headers = await self._read_headers()
        return self._get_part_reader(headers)

    def _get_part_reader(
        self,
        headers: "CIMultiDictProxy[str]",
    ) -> Union["MultipartReader", BodyPartReader]:
        """Dispatches the response by the `Content-Type` header.

        Returns a suitable reader instance.

        :param dict headers: Response headers
        """
        ctype = headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)

        if mimetype.type == "multipart":
            if self.multipart_reader_cls is None:
                return type(self)(headers, self._content)
            return self.multipart_reader_cls(headers, self._content)
        else:
            return self.part_reader_cls(self._boundary, headers, self._content)

    def _get_boundary(self) -> str:
        """Return the ``boundary`` parameter of this reader's Content-Type.

        :raises ValueError: if the parameter is missing or longer than 70
            characters.
        """
        mimetype = parse_mimetype(self.headers[CONTENT_TYPE])

        assert mimetype.type == "multipart", "multipart/* content type expected"

        if "boundary" not in mimetype.parameters:
            raise ValueError(
                "boundary missed for Content-Type: %s" % self.headers[CONTENT_TYPE]
            )

        boundary = mimetype.parameters["boundary"]
        if len(boundary) > 70:
            raise ValueError("boundary %r is too long (70 chars max)" % boundary)

        return boundary

    async def _readline(self) -> bytes:
        # Pushed-back lines (LIFO) take priority over the stream.
        if self._unread:
            return self._unread.pop()
        return await self._content.readline()

    async def _read_until_first_boundary(self) -> None:
        # Skip everything before the first delimiter line.
        while True:
            chunk = await self._readline()
            if chunk == b"":
                raise ValueError(
                    "Could not find starting boundary %r" % (self._boundary)
                )
            chunk = chunk.rstrip()
            if chunk == self._boundary:
                return
            elif chunk == self._boundary + b"--":
                self._at_eof = True
                return

    async def _read_boundary(self) -> None:
        chunk = (await self._readline()).rstrip()
        if chunk == self._boundary:
            pass
        elif chunk == self._boundary + b"--":
            self._at_eof = True
            epilogue = await self._readline()
            next_line = await self._readline()

            # the epilogue is expected and then either the end of input or the
            # parent multipart boundary, if the parent boundary is found then
            # it should be marked as unread and handed to the parent for
            # processing
            if next_line[:2] == b"--":
                self._unread.append(next_line)
            # otherwise the request is likely missing an epilogue and both
            # lines should be passed to the parent for processing
            # (this handles the old behavior gracefully)
            else:
                self._unread.extend([next_line, epilogue])
        else:
            raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}")

    async def _read_headers(self) -> "CIMultiDictProxy[str]":
        lines = [b""]
        while True:
            # Read through _readline(), not the raw stream.  A nested reader
            # that ended without an epilogue hands back the parent boundary
            # and the next part's first header line via _unread.  Reading the
            # stream directly would skip that header line.
            chunk = await self._readline()
            chunk = chunk.strip()
            lines.append(chunk)
            if not chunk:
                break
        parser = HeadersParser()
        headers, raw_headers = parser.parse_headers(lines)
        return headers

    async def _maybe_release_last_part(self) -> None:
        """Ensures that the last read body part is read completely."""
        if self._last_part is not None:
            if not self._last_part.at_eof():
                await self._last_part.release()
            # Take over lines the part read past its end (e.g. the boundary).
            self._unread.extend(self._last_part._unread)
            self._last_part = None
# A queued part: (payload, content encoding, transfer encoding).  A falsy
# encoding slot (None or "") means that transformation is not applied.
_Part = Tuple[Payload, Optional[str], Optional[str]]


class MultipartWriter(Payload):
    """Multipart body writer."""

    def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None:
        boundary = boundary if boundary is not None else uuid.uuid4().hex
        # The underlying Payload API demands a str (utf-8), not bytes,
        # so we need to ensure we don't lose anything during conversion.
        # As a result, require the boundary to be ASCII only.
        try:
            self._boundary = boundary.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("boundary should contain ASCII only chars") from None
        ctype = f"multipart/{subtype}; boundary={self._boundary_value}"

        super().__init__(None, content_type=ctype)

        self._parts: List[_Part] = []

    def __enter__(self) -> "MultipartWriter":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Nothing to clean up; returning None lets exceptions propagate.
        pass

    def __iter__(self) -> Iterator[_Part]:
        return iter(self._parts)

    def __len__(self) -> int:
        return len(self._parts)

    def __bool__(self) -> bool:
        # Always truthy, even with no parts (overrides __len__ truthiness).
        return True

    _valid_tchar_regex = re.compile(rb"\A[!#$%&'*+\-.^_`|~\w]+\Z")
    _invalid_qdtext_char_regex = re.compile(rb"[\x00-\x08\x0A-\x1F\x7F]")

    @property
    def _boundary_value(self) -> str:
        """Wrap boundary parameter value in quotes, if necessary.

        Reads self.boundary and returns a unicode string.
        """
        # Refer to RFCs 7231, 7230, 5234.
        #
        # parameter      = token "=" ( token / quoted-string )
        # token          = 1*tchar
        # quoted-string  = DQUOTE *( qdtext / quoted-pair ) DQUOTE
        # qdtext         = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
        # obs-text       = %x80-FF
        # quoted-pair    = "\" ( HTAB / SP / VCHAR / obs-text )
        # tchar          = "!" / "#" / "$" / "%" / "&" / "'" / "*"
        #                  / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
        #                  / DIGIT / ALPHA
        #                  ; any VCHAR, except delimiters
        # VCHAR           = %x21-7E
        value = self._boundary
        if re.match(self._valid_tchar_regex, value):
            return value.decode("ascii")  # cannot fail

        if re.search(self._invalid_qdtext_char_regex, value):
            raise ValueError("boundary value contains invalid characters")

        # escape %x5C and %x22
        quoted_value_content = value.replace(b"\\", b"\\\\")
        quoted_value_content = quoted_value_content.replace(b'"', b'\\"')

        return '"' + quoted_value_content.decode("ascii") + '"'

    @property
    def boundary(self) -> str:
        return self._boundary.decode("ascii")

    def append(self, obj: Any, headers: Optional[MultiMapping[str]] = None) -> Payload:
        """Add *obj* as a new part and return its payload.

        A Payload instance gets *headers* merged into its own headers.  Any
        other object is wrapped via ``get_payload``; ``TypeError`` is raised
        when no payload type matches.
        """
        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Payload):
            obj.headers.update(headers)
            return self.append_payload(obj)
        else:
            try:
                payload = get_payload(obj, headers=headers)
            except LookupError:
                raise TypeError("Cannot create payload from %r" % obj)
            else:
                return self.append_payload(payload)

    def append_payload(self, payload: Payload) -> Payload:
        """Adds a new body part to multipart writer.

        :raises RuntimeError: for an unsupported Content-Encoding or
            Content-Transfer-Encoding header value.
        """
        # compression
        encoding: Optional[str] = payload.headers.get(
            CONTENT_ENCODING,
            "",
        ).lower()
        if encoding and encoding not in ("deflate", "gzip", "identity"):
            raise RuntimeError(f"unknown content encoding: {encoding}")
        if encoding == "identity":
            encoding = None

        # te encoding
        te_encoding: Optional[str] = payload.headers.get(
            CONTENT_TRANSFER_ENCODING,
            "",
        ).lower()
        if te_encoding not in (
            "",
            "base64",
            "quoted-printable",
            "binary",
            "8bit",
            "7bit",
        ):
            raise RuntimeError(
                "unknown content transfer encoding: {}" "".format(te_encoding)
            )
        if te_encoding in ("binary", "8bit", "7bit"):
            # Identity transfer encodings (also accepted by BodyPartReader):
            # the data is written unchanged, so no encoder is needed.
            te_encoding = None

        # size: only known up front when the data is written untransformed
        size = payload.size
        if size is not None and not (encoding or te_encoding):
            payload.headers[CONTENT_LENGTH] = str(size)

        self._parts.append((payload, encoding, te_encoding))
        return payload

    def append_json(
        self, obj: Any, headers: Optional[MultiMapping[str]] = None
    ) -> Payload:
        """Helper to append JSON part."""
        if headers is None:
            headers = CIMultiDict()

        return self.append_payload(JsonPayload(obj, headers=headers))

    def append_form(
        self,
        obj: Union[Sequence[Tuple[str, str]], Mapping[str, str]],
        headers: Optional[MultiMapping[str]] = None,
    ) -> Payload:
        """Helper to append form urlencoded part."""
        assert isinstance(obj, (Sequence, Mapping))

        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Mapping):
            obj = list(obj.items())
        data = urlencode(obj, doseq=True)

        return self.append_payload(
            StringPayload(
                data, headers=headers, content_type="application/x-www-form-urlencoded"
            )
        )

    @property
    def size(self) -> Optional[int]:
        """Size of the payload.

        ``None`` when any part has an unknown size or is compressed or
        transfer-encoded (its encoded length is not known in advance).
        """
        total = 0
        for part, encoding, te_encoding in self._parts:
            if encoding or te_encoding or part.size is None:
                return None

            total += int(
                2
                + len(self._boundary)
                + 2  # b'--'+self._boundary+b'\r\n'
                + part.size
                + len(part._binary_headers)
                + 2  # b'\r\n'
            )

        total += 2 + len(self._boundary) + 4  # b'--'+self._boundary+b'--\r\n'
        return total

    async def write(self, writer: Any, close_boundary: bool = True) -> None:
        """Write body.

        Each part is written as the delimiter line, the part's headers, its
        data (through MultipartPayloadWriter when compressed or
        transfer-encoded) and a CRLF.  The closing delimiter is written only
        when *close_boundary* is true.
        """
        for part, encoding, te_encoding in self._parts:
            await writer.write(b"--" + self._boundary + b"\r\n")
            await writer.write(part._binary_headers)

            if encoding or te_encoding:
                w = MultipartPayloadWriter(writer)
                if encoding:
                    w.enable_compression(encoding)
                if te_encoding:
                    w.enable_encoding(te_encoding)
                await part.write(w)  # type: ignore[arg-type]
                await w.write_eof()
            else:
                await part.write(writer)

            await writer.write(b"\r\n")

        if close_boundary:
            await writer.write(b"--" + self._boundary + b"--\r\n")
class MultipartPayloadWriter:
    """Writer adapter that compresses and/or transfer-encodes part data.

    Data goes through optional zlib compression first, then through the
    optional base64 or quoted-printable encoder, and is then handed to the
    wrapped writer.  :meth:`write_eof` flushes whatever is still pending.
    """

    def __init__(self, writer: Any) -> None:
        self._writer = writer
        self._encoding: Optional[str] = None
        self._compress: Any = None
        # base64 only: input bytes not yet forming a complete 3-byte group.
        self._encoding_buffer: Optional[bytearray] = None

    def enable_encoding(self, encoding: str) -> None:
        """Select a transfer encoding; other names are silently ignored."""
        if encoding == "quoted-printable":
            self._encoding = "quoted-printable"
        elif encoding == "base64":
            self._encoding_buffer = bytearray()
            self._encoding = encoding

    def enable_compression(
        self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
    ) -> None:
        """Compress written data: gzip framing for "gzip", raw deflate otherwise."""
        if encoding == "gzip":
            window_bits = 16 + zlib.MAX_WBITS
        else:
            window_bits = -zlib.MAX_WBITS
        self._compress = zlib.compressobj(wbits=window_bits, strategy=strategy)

    async def write_eof(self) -> None:
        """Flush the compressor and any buffered base64 remainder."""
        compressor = self._compress
        if compressor is not None:
            tail = compressor.flush()
            if tail:
                self._compress = None
                await self.write(tail)

        if self._encoding == "base64" and self._encoding_buffer:
            await self._writer.write(base64.b64encode(self._encoding_buffer))

    async def write(self, chunk: bytes) -> None:
        """Pass *chunk* through compression/encoding to the wrapped writer."""
        if self._compress is not None and chunk:
            chunk = self._compress.compress(chunk)
            if not chunk:
                return

        if self._encoding == "base64":
            pending = self._encoding_buffer
            assert pending is not None
            pending.extend(chunk)

            if pending:
                cut = len(pending) - len(pending) % 3
                ready, self._encoding_buffer = pending[:cut], pending[cut:]
                if ready:
                    await self._writer.write(base64.b64encode(ready))
        elif self._encoding == "quoted-printable":
            await self._writer.write(binascii.b2a_qp(chunk))
        else:
            await self._writer.write(chunk)