Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/multipart.py: 18%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import base64
2import binascii
3import json
4import re
5import sys
6import uuid
7import warnings
8from collections import deque
9from types import TracebackType
10from typing import (
11 TYPE_CHECKING,
12 Any,
13 Deque,
14 Dict,
15 Iterator,
16 List,
17 Mapping,
18 Optional,
19 Sequence,
20 Tuple,
21 Type,
22 Union,
23 cast,
24)
25from urllib.parse import parse_qsl, unquote, urlencode
27from multidict import CIMultiDict, CIMultiDictProxy
29from .compression_utils import ZLibCompressor, ZLibDecompressor
30from .hdrs import (
31 CONTENT_DISPOSITION,
32 CONTENT_ENCODING,
33 CONTENT_LENGTH,
34 CONTENT_TRANSFER_ENCODING,
35 CONTENT_TYPE,
36)
37from .helpers import CHAR, TOKEN, parse_mimetype, reify
38from .http import HeadersParser
39from .payload import (
40 JsonPayload,
41 LookupError,
42 Order,
43 Payload,
44 StringPayload,
45 get_payload,
46 payload_type,
47)
48from .streams import StreamReader
50if sys.version_info >= (3, 11):
51 from typing import Self
52else:
53 from typing import TypeVar
55 Self = TypeVar("Self", bound="BodyPartReader")
# Names exported by ``from <this module> import *``; every entry is a class
# or function defined in this module.
__all__ = (
    "MultipartReader",
    "MultipartWriter",
    "BodyPartReader",
    "BadContentDispositionHeader",
    "BadContentDispositionParam",
    "parse_content_disposition",
    "content_disposition_filename",
)
68if TYPE_CHECKING:
69 from .client_reqrep import ClientResponse
class BadContentDispositionHeader(RuntimeWarning):
    """Warning issued when a whole Content-Disposition header is unusable."""
class BadContentDispositionParam(RuntimeWarning):
    """Warning issued when one Content-Disposition parameter is skipped."""
def parse_content_disposition(
    header: Optional[str],
) -> Tuple[Optional[str], Dict[str, str]]:
    """Parse a ``Content-Disposition`` header value.

    Returns a ``(disposition_type, params)`` tuple; the type and the
    parameter names are lower-cased.

    ``(None, {})`` is returned when the header is missing or empty, and also
    when the header as a whole is malformed (in which case a
    ``BadContentDispositionHeader`` warning is issued).  A single parameter
    that cannot be used is skipped with a ``BadContentDispositionParam``
    warning and parsing continues.
    """
    if not header:
        return None, {}

    def is_token(string: str) -> bool:
        return bool(string) and TOKEN >= set(string)

    def is_quoted(string: str) -> bool:
        # A quoted-string needs both an opening and a closing DQUOTE.  The
        # length check keeps an empty value (``filename=``) from raising
        # IndexError and stops a lone ``"`` from being taken as quoted.
        return len(string) >= 2 and string[0] == string[-1] == '"'

    def is_rfc5987(string: str) -> bool:
        return is_token(string) and string.count("'") == 2

    def is_extended_param(string: str) -> bool:
        return string.endswith("*")

    def is_continuous_param(string: str) -> bool:
        pos = string.find("*") + 1
        if not pos:
            return False
        substring = string[pos:-1] if string.endswith("*") else string[pos:]
        return substring.isdigit()

    def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str:
        # Drop the backslash of every quoted-pair.
        return re.sub(f"\\\\([{chars}])", "\\1", text)

    disptype, *parts = header.split(";")
    if not is_token(disptype):
        warnings.warn(BadContentDispositionHeader(header))
        return None, {}

    params: Dict[str, str] = {}
    while parts:
        item = parts.pop(0)

        if "=" not in item:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        key, value = item.split("=", 1)
        key = key.lower().strip()
        value = value.lstrip()

        # A repeated parameter name invalidates the whole header.
        if key in params:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        if not is_token(key):
            warnings.warn(BadContentDispositionParam(item))
            continue

        elif is_continuous_param(key):
            # ``name*N`` / ``name*N*`` continuation segment.
            if is_quoted(value):
                value = unescape(value[1:-1])
            elif not is_token(value):
                warnings.warn(BadContentDispositionParam(item))
                continue

        elif is_extended_param(key):
            # ``name*`` extended value: charset'language'percent-encoded.
            if is_rfc5987(value):
                encoding, _, value = value.split("'", 2)
                encoding = encoding or "utf-8"
            else:
                warnings.warn(BadContentDispositionParam(item))
                continue

            try:
                value = unquote(value, encoding, "strict")
            except (UnicodeDecodeError, LookupError):  # pragma: nocover
                # LookupError: the header named a charset Python does not
                # know; treat it like undecodable data instead of crashing.
                warnings.warn(BadContentDispositionParam(item))
                continue

        else:
            failed = True
            if is_quoted(value):
                failed = False
                value = unescape(value[1:-1].lstrip("\\/"))
            elif is_token(value):
                failed = False
            elif parts:
                # maybe just ; in filename, in any case this is just
                # one case fix, for proper fix we need to redesign parser
                _value = f"{value};{parts[0]}"
                if is_quoted(_value):
                    parts.pop(0)
                    value = unescape(_value[1:-1].lstrip("\\/"))
                    failed = False

            if failed:
                warnings.warn(BadContentDispositionHeader(header))
                return None, {}

        params[key] = value

    return disptype.lower(), params
def content_disposition_filename(
    params: Mapping[str, str], name: str = "filename"
) -> Optional[str]:
    """Pick the value of parameter *name* out of parsed disposition params.

    Lookup order: the extended form ``name*``, then the plain ``name``, then
    the continuation segments ``name*0``, ``name*1``, ... (optionally with a
    trailing ``*``) joined in order.  Returns ``None`` when *params* is empty
    or none of these forms is present.

    A joined continuation value of the shape ``charset'language'data`` is
    percent-decoded with that charset (``utf-8`` when the charset is empty).
    """
    name_suf = "%s*" % name
    if not params:
        return None
    if name_suf in params:
        return params[name_suf]
    if name in params:
        return params[name]

    parts = []
    fnparams = sorted(
        (key, value) for key, value in params.items() if key.startswith(name_suf)
    )
    for num, (key, value) in enumerate(fnparams):
        _, tail = key.split("*", 1)
        if tail.endswith("*"):
            tail = tail[:-1]
        # Segments must be numbered consecutively from 0; stop at a gap.
        if tail != str(num):
            break
        parts.append(value)
    if not parts:
        return None

    value = "".join(parts)
    # Only a value with both apostrophes carries a charset/language prefix.
    # With a single apostrophe (e.g. "it's.txt") the three-way unpacking
    # below would raise ValueError, so such a value is returned unchanged.
    if value.count("'") >= 2:
        encoding, _, value = value.split("'", 2)
        encoding = encoding or "utf-8"
        return unquote(value, encoding, "strict")
    return value
class MultipartResponseWrapper:
    """Pairs a client response with the MultipartReader built on its body.

    Iterating the wrapper yields the reader's parts; the response is
    released as soon as the reader reports end of input.
    """

    def __init__(
        self,
        resp: "ClientResponse",
        stream: "MultipartReader",
    ) -> None:
        self.resp = resp
        self.stream = stream

    def __aiter__(self) -> "MultipartResponseWrapper":
        return self

    async def __anext__(
        self,
    ) -> Union["MultipartReader", "BodyPartReader"]:
        nxt = await self.next()
        if nxt is not None:
            return nxt
        raise StopAsyncIteration

    def at_eof(self) -> bool:
        """Tell whether the response content stream is at EOF."""
        content = self.resp.content
        return content.at_eof()

    async def next(
        self,
    ) -> Optional[Union["MultipartReader", "BodyPartReader"]]:
        """Return the next part from the wrapped reader.

        Releases the response once the wrapped reader is at EOF.
        """
        reader = self.stream
        part = await reader.next()
        if reader.at_eof():
            await self.release()
        return part

    async def release(self) -> None:
        """Release the underlying response."""
        self.resp.release()
class BodyPartReader:
    """Multipart reader for single body part.

    Reads the payload of one part from the shared ``content`` stream, up to
    the boundary line that ends the part, and offers helpers that interpret
    the payload as bytes, text, JSON or urlencoded form data.
    """

    # Default number of bytes requested by read(), release() and read_chunk().
    chunk_size = 8192

    def __init__(
        self,
        boundary: bytes,
        headers: "CIMultiDictProxy[str]",
        content: StreamReader,
        *,
        subtype: str = "mixed",
        default_charset: Optional[str] = None,
    ) -> None:
        """Set up the reader state.

        boundary: delimiter bytes that mark the end of this part.
        headers: headers of this part.
        content: stream the part body is read from.
        subtype: subtype of the enclosing multipart message; "form-data"
            disables Content-Length and Content-Encoding handling.
        default_charset: charset used by get_charset() when the part's
            Content-Type has no charset parameter.
        """
        self.headers = headers
        self._boundary = boundary
        self._boundary_len = len(boundary) + 2  # Boundary + \r\n
        self._content = content
        self._default_charset = default_charset
        self._at_eof = False
        self._is_form_data = subtype == "form-data"
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None)
        self._length = int(length) if length is not None else None
        # Number of body bytes handed out so far by read_chunk().
        self._read_bytes = 0
        # Lines read ahead by readline() that still have to be processed.
        self._unread: Deque[bytes] = deque()
        # Look-behind buffer of _read_chunk_from_stream(); None before the
        # first stream read.
        self._prev_chunk: Optional[bytes] = None
        # Counts how many times the stream reported EOF after a read.
        self._content_eof = 0
        # NOTE(review): presumably the storage used by the @reify properties
        # below -- confirm against helpers.reify.
        self._cache: Dict[str, Any] = {}

    def __aiter__(self) -> Self:
        """Return self; the part is its own async iterator."""
        return self

    async def __anext__(self) -> bytes:
        """Return the next item from next(); stop when it returns None."""
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    async def next(self) -> Optional[bytes]:
        """Return the result of read(), or None when it is empty."""
        item = await self.read()
        if not item:
            return None
        return item

    async def read(self, *, decode: bool = False) -> bytes:
        """Reads body part data.

        decode: Decodes data following by encoding
                method from Content-Encoding header. If it missed
                data remains untouched
        """
        if self._at_eof:
            return b""
        # Chunks are accumulated in a bytearray, which is also what is
        # returned when decode is False.
        data = bytearray()
        while not self._at_eof:
            data.extend(await self.read_chunk(self.chunk_size))
        # https://github.com/python/mypy/issues/17537
        if decode:  # type: ignore[unreachable]
            return self.decode(data)
        return data

    async def read_chunk(self, size: int = chunk_size) -> bytes:
        """Reads body part content chunk of the specified size.

        size: chunk size

        Returns b"" once the end of the part has been reached.  When the
        part ends, the line following the body is consumed and asserted to
        be CRLF.
        """
        if self._at_eof:
            return b""
        # A missing or zero Content-Length falls back to boundary scanning.
        if self._length:
            chunk = await self._read_chunk_from_length(size)
        else:
            chunk = await self._read_chunk_from_stream(size)

        # For the case of base64 data, we must read a fragment of size with a
        # remainder of 0 by dividing by 4 for string without symbols \n or \r
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING)
        if encoding and encoding.lower() == "base64":
            stripped_chunk = b"".join(chunk.split())
            remainder = len(stripped_chunk) % 4

            # Top the chunk up until its whitespace-free length is a
            # multiple of 4 or the part ends.
            while remainder != 0 and not self.at_eof():
                over_chunk_size = 4 - remainder
                over_chunk = b""

                # Take the extra bytes from the look-behind buffer first ...
                if self._prev_chunk:
                    over_chunk = self._prev_chunk[:over_chunk_size]
                    self._prev_chunk = self._prev_chunk[len(over_chunk) :]

                # ... and from the stream if that was not enough.
                if len(over_chunk) != over_chunk_size:
                    over_chunk += await self._content.read(4 - len(over_chunk))

                if not over_chunk:
                    self._at_eof = True

                stripped_chunk += b"".join(over_chunk.split())
                chunk += over_chunk
                remainder = len(stripped_chunk) % 4

        self._read_bytes += len(chunk)
        if self._read_bytes == self._length:
            self._at_eof = True
        if self._at_eof:
            # Consume the line terminator that follows the body.
            clrf = await self._content.readline()
            assert (
                b"\r\n" == clrf
            ), "reader did not read all the data or it is malformed"
        return chunk

    async def _read_chunk_from_length(self, size: int) -> bytes:
        # Reads body part content chunk of the specified size.
        # The body part must has Content-Length header with proper value.
        assert self._length is not None, "Content-Length required for chunked read"
        # Never request more than what is left of the declared length.
        chunk_size = min(size, self._length - self._read_bytes)
        chunk = await self._content.read(chunk_size)
        if self._content.at_eof():
            self._at_eof = True
        return chunk

    async def _read_chunk_from_stream(self, size: int) -> bytes:
        # Reads content chunk of body part with unknown length.
        # The Content-Length header for body part is not necessary.
        #
        # The data returned is always the *previous* read; the newly read
        # chunk is kept in self._prev_chunk so that a boundary spanning two
        # reads can still be found in the combined window.
        assert (
            size >= self._boundary_len
        ), "Chunk size must be greater or equal than boundary length + 2"
        first_chunk = self._prev_chunk is None
        if first_chunk:
            self._prev_chunk = await self._content.read(size)

        chunk = b""
        # content.read() may return less than size, so we need to loop to ensure
        # we have enough data to detect the boundary.
        while len(chunk) < self._boundary_len:
            chunk += await self._content.read(size)
            self._content_eof += int(self._content.at_eof())
            assert self._content_eof < 3, "Reading after EOF"
            if self._content_eof:
                break
        # The loop may have collected more than size bytes; push the excess
        # back onto the stream.
        if len(chunk) > size:
            self._content.unread_data(chunk[size:])
            chunk = chunk[:size]

        assert self._prev_chunk is not None
        window = self._prev_chunk + chunk
        sub = b"\r\n" + self._boundary
        if first_chunk:
            idx = window.find(sub)
        else:
            # Only the tail of the previous chunk can hold the start of a
            # delimiter that was not found on the previous call.
            idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
        if idx >= 0:
            # pushing boundary back to content
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)
                self._content.unread_data(window[idx:])
            if size > idx:
                self._prev_chunk = self._prev_chunk[:idx]
            chunk = window[len(self._prev_chunk) : idx]
            # Nothing is left before the delimiter: this call returns the
            # final data of the part.
            if not chunk:
                self._at_eof = True
        result = self._prev_chunk
        self._prev_chunk = chunk
        return result

    async def readline(self) -> bytes:
        """Reads body part by line by line."""
        if self._at_eof:
            return b""

        # Serve a line read ahead by a previous call before touching the
        # stream again.
        if self._unread:
            line = self._unread.popleft()
        else:
            line = await self._content.readline()

        if line.startswith(self._boundary):
            # the very last boundary may not come with \r\n,
            # so set single rules for everyone
            sline = line.rstrip(b"\r\n")
            boundary = self._boundary
            last_boundary = self._boundary + b"--"
            # ensure that we read exactly the boundary, not something alike
            if sline == boundary or sline == last_boundary:
                self._at_eof = True
                # Keep the boundary line so that it can be handed over via
                # self._unread.
                self._unread.append(line)
                return b""
        else:
            # Peek at the following line: if it is a boundary, the CRLF that
            # ends the current line belongs to the delimiter, not the data.
            next_line = await self._content.readline()
            if next_line.startswith(self._boundary):
                line = line[:-2]  # strip CRLF but only once
            self._unread.append(next_line)

        return line

    async def release(self) -> None:
        """Like read(), but reads all the data to the void."""
        if self._at_eof:
            return
        while not self._at_eof:
            await self.read_chunk(self.chunk_size)

    async def text(self, *, encoding: Optional[str] = None) -> str:
        """Like read(), but assumes that body part contains text data."""
        data = await self.read(decode=True)
        # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm
        # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send
        encoding = encoding or self.get_charset(default="utf-8")
        return data.decode(encoding)

    async def json(self, *, encoding: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Like read(), but assumes that body parts contains JSON data.

        Returns None when the part is empty.
        """
        data = await self.read(decode=True)
        if not data:
            return None
        encoding = encoding or self.get_charset(default="utf-8")
        return cast(Dict[str, Any], json.loads(data.decode(encoding)))

    async def form(self, *, encoding: Optional[str] = None) -> List[Tuple[str, str]]:
        """Like read(), but assumes that body parts contain form urlencoded data.

        Returns an empty list when the part is empty.  Raises ValueError
        when the data cannot be decoded with the selected encoding.
        """
        data = await self.read(decode=True)
        if not data:
            return []
        if encoding is not None:
            real_encoding = encoding
        else:
            real_encoding = self.get_charset(default="utf-8")
        try:
            decoded_data = data.rstrip().decode(real_encoding)
        except UnicodeDecodeError:
            raise ValueError("data cannot be decoded with %s encoding" % real_encoding)

        return parse_qsl(
            decoded_data,
            keep_blank_values=True,
            encoding=real_encoding,
        )

    def at_eof(self) -> bool:
        """Returns True if the boundary was reached or False otherwise."""
        return self._at_eof

    def decode(self, data: bytes) -> bytes:
        """Decodes data.

        Decoding is done according the specified Content-Encoding
        or Content-Transfer-Encoding headers value.

        The transfer encoding is undone first.  Content-Encoding is ignored
        for form-data parts.
        """
        if CONTENT_TRANSFER_ENCODING in self.headers:
            data = self._decode_content_transfer(data)
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        if not self._is_form_data and CONTENT_ENCODING in self.headers:
            return self._decode_content(data)
        return data

    def _decode_content(self, data: bytes) -> bytes:
        # Undo the Content-Encoding ("identity", "deflate" or "gzip");
        # any other value raises RuntimeError.
        encoding = self.headers.get(CONTENT_ENCODING, "").lower()
        if encoding == "identity":
            return data
        if encoding in {"deflate", "gzip"}:
            return ZLibDecompressor(
                encoding=encoding,
                suppress_deflate_header=True,
            ).decompress_sync(data)

        raise RuntimeError(f"unknown content encoding: {encoding}")

    def _decode_content_transfer(self, data: bytes) -> bytes:
        # Undo the Content-Transfer-Encoding; "binary", "8bit" and "7bit"
        # leave the data as is, unknown values raise RuntimeError.
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()

        if encoding == "base64":
            return base64.b64decode(data)
        elif encoding == "quoted-printable":
            return binascii.a2b_qp(data)
        elif encoding in ("binary", "8bit", "7bit"):
            return data
        else:
            raise RuntimeError(f"unknown content transfer encoding: {encoding}")

    def get_charset(self, default: str) -> str:
        """Returns charset parameter from Content-Type header or default.

        The default charset given to the constructor, when set, takes
        precedence over the ``default`` argument.
        """
        ctype = self.headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)
        return mimetype.parameters.get("charset", self._default_charset or default)

    @reify
    def name(self) -> Optional[str]:
        """Returns name specified in Content-Disposition header.

        If the header is missing or malformed, returns None.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "name")

    @reify
    def filename(self) -> Optional[str]:
        """Returns filename specified in Content-Disposition header.

        Returns None if the header is missing or malformed.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "filename")
@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
    """Payload that forwards the content of a BodyPartReader to a writer."""

    _value: BodyPartReader

    def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
        super().__init__(value, *args, **kwargs)

        # Carry over whichever of name / filename the source part declares.
        disposition: Dict[str, str] = {}
        part_name = value.name
        if part_name is not None:
            disposition["name"] = part_name
        part_filename = value.filename
        if part_filename is not None:
            disposition["filename"] = part_filename

        if disposition:
            self.set_content_disposition("attachment", True, **disposition)

    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
        """Not supported for this payload; always raises TypeError."""
        raise TypeError("Unable to decode.")

    async def write(self, writer: Any) -> None:
        """Copy the part to *writer* chunk by chunk, decoding each chunk."""
        source = self._value
        while True:
            piece = await source.read_chunk(size=2**16)
            if not piece:
                break
            await writer.write(source.decode(piece))
class MultipartReader:
    """Multipart body reader.

    Walks a ``multipart/*`` body read from a stream and yields one reader
    per body part: a nested multipart reader for ``multipart/*`` parts and
    ``part_reader_cls`` for everything else.
    """

    #: Response wrapper, used when multipart readers constructs from response.
    response_wrapper_cls = MultipartResponseWrapper
    #: Multipart reader class, used to handle multipart/* body parts.
    #: None points to type(self)
    multipart_reader_cls: Optional[Type["MultipartReader"]] = None
    #: Body part reader class for non multipart/* content types.
    part_reader_cls = BodyPartReader

    def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
        """Set up the reader.

        headers: headers of the multipart entity; Content-Type must be a
            ``multipart/*`` type with a ``boundary`` parameter.
        content: stream the body is read from.

        Raises ValueError when the boundary parameter is missing or longer
        than 70 characters.
        """
        self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
        assert self._mimetype.type == "multipart", "multipart/* content type expected"
        if "boundary" not in self._mimetype.parameters:
            raise ValueError(
                "boundary missed for Content-Type: %s" % headers[CONTENT_TYPE]
            )

        self.headers = headers
        # Delimiter lines carry the boundary parameter prefixed with "--".
        self._boundary = ("--" + self._get_boundary()).encode()
        self._content = content
        # Set from a leading "_charset_" form-data part, see next().
        self._default_charset: Optional[str] = None
        self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None
        self._at_eof = False
        # True until the first boundary has been consumed.
        self._at_bof = True
        # Lines pushed back for re-reading; _readline() pops from the end.
        self._unread: List[bytes] = []

    def __aiter__(self) -> Self:
        """Return self; the reader is its own async iterator."""
        return self

    async def __anext__(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        """Return the next body part; stop when next() returns None."""
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    @classmethod
    def from_response(
        cls,
        response: "ClientResponse",
    ) -> MultipartResponseWrapper:
        """Constructs reader instance from HTTP response.

        :param response: :class:`~aiohttp.client.ClientResponse` instance
        """
        obj = cls.response_wrapper_cls(
            response, cls(response.headers, response.content)
        )
        return obj

    def at_eof(self) -> bool:
        """Returns True if the final boundary was reached, false otherwise."""
        return self._at_eof

    async def next(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        """Emits the next multipart body part.

        Returns None once the final boundary has been read.
        """
        # So, if we're at BOF, we need to skip till the boundary.
        if self._at_eof:
            return None
        # The previous part must be drained before its closing boundary can
        # be read from the stream.
        await self._maybe_release_last_part()
        if self._at_bof:
            await self._read_until_first_boundary()
            self._at_bof = False
        else:
            await self._read_boundary()
        if self._at_eof:  # we just read the last boundary, nothing to do there
            # https://github.com/python/mypy/issues/17537
            return None  # type: ignore[unreachable]

        part = await self.fetch_next_part()
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.6
        # Evaluated whenever no previous part is recorded (e.g. on the very
        # first part): a form-data part named "_charset_" is consumed here
        # as the default charset and is not returned to the caller.
        if (
            self._last_part is None
            and self._mimetype.subtype == "form-data"
            and isinstance(part, BodyPartReader)
        ):
            _, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION))
            if params.get("name") == "_charset_":
                # Longest encoding in https://encoding.spec.whatwg.org/encodings.json
                # is 19 characters, so 32 should be more than enough for any valid encoding.
                charset = await part.read_chunk(32)
                if len(charset) > 31:
                    raise RuntimeError("Invalid default charset")
                self._default_charset = charset.strip().decode()
                part = await self.fetch_next_part()
        self._last_part = part
        return self._last_part

    async def release(self) -> None:
        """Reads all the body parts to the void till the final boundary."""
        while not self._at_eof:
            item = await self.next()
            if item is None:
                break
            await item.release()

    async def fetch_next_part(
        self,
    ) -> Union["MultipartReader", BodyPartReader]:
        """Returns the next body part reader."""
        headers = await self._read_headers()
        return self._get_part_reader(headers)

    def _get_part_reader(
        self,
        headers: "CIMultiDictProxy[str]",
    ) -> Union["MultipartReader", BodyPartReader]:
        """Dispatches the response by the `Content-Type` header.

        Returns a suitable reader instance.

        :param dict headers: Response headers
        """
        ctype = headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)

        if mimetype.type == "multipart":
            # Nested multipart: reuse this reader's class unless a specific
            # one is configured.
            if self.multipart_reader_cls is None:
                return type(self)(headers, self._content)
            return self.multipart_reader_cls(headers, self._content)
        else:
            return self.part_reader_cls(
                self._boundary,
                headers,
                self._content,
                subtype=self._mimetype.subtype,
                default_charset=self._default_charset,
            )

    def _get_boundary(self) -> str:
        """Return the boundary parameter; ValueError if longer than 70 chars."""
        boundary = self._mimetype.parameters["boundary"]
        if len(boundary) > 70:
            raise ValueError("boundary %r is too long (70 chars max)" % boundary)

        return boundary

    async def _readline(self) -> bytes:
        """Return a pushed-back line if any, else read one from the stream."""
        if self._unread:
            return self._unread.pop()
        return await self._content.readline()

    async def _read_until_first_boundary(self) -> None:
        """Skip lines up to and including the first boundary line.

        Sets the EOF flag when the first boundary found is already the
        closing one.  Raises ValueError if input ends before any boundary.
        """
        while True:
            chunk = await self._readline()
            if chunk == b"":
                raise ValueError(f"Could not find starting boundary {self._boundary!r}")
            chunk = chunk.rstrip()
            if chunk == self._boundary:
                return
            elif chunk == self._boundary + b"--":
                self._at_eof = True
                return

    async def _read_boundary(self) -> None:
        """Consume the boundary line that follows a body part.

        Raises ValueError when the line is neither the boundary nor the
        closing boundary.
        """
        chunk = (await self._readline()).rstrip()
        if chunk == self._boundary:
            pass
        elif chunk == self._boundary + b"--":
            self._at_eof = True
            epilogue = await self._readline()
            next_line = await self._readline()

            # the epilogue is expected and then either the end of input or the
            # parent multipart boundary, if the parent boundary is found then
            # it should be marked as unread and handed to the parent for
            # processing
            if next_line[:2] == b"--":
                self._unread.append(next_line)
            # otherwise the request is likely missing an epilogue and both
            # lines should be passed to the parent for processing
            # (this handles the old behavior gracefully)
            else:
                # _readline() pops from the end, so epilogue is handed out
                # first and next_line after it, i.e. in stream order.
                self._unread.extend([next_line, epilogue])
        else:
            raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}")

    async def _read_headers(self) -> "CIMultiDictProxy[str]":
        """Read header lines up to the first blank line and parse them."""
        # NOTE(review): the leading empty entry presumably stands in for the
        # start line HeadersParser skips -- confirm against
        # HeadersParser.parse_headers.
        lines = [b""]
        while True:
            chunk = await self._content.readline()
            chunk = chunk.strip()
            lines.append(chunk)
            if not chunk:
                break
        parser = HeadersParser()
        headers, raw_headers = parser.parse_headers(lines)
        return headers

    async def _maybe_release_last_part(self) -> None:
        """Ensures that the last read body part is read completely."""
        if self._last_part is not None:
            if not self._last_part.at_eof():
                await self._last_part.release()
            # Take over the lines the part read ahead (its boundary line in
            # particular) so that they are processed by this reader.
            self._unread.extend(self._last_part._unread)
            self._last_part = None
# One appended part as stored by MultipartWriter:
# (payload, content encoding, content transfer encoding).
# NOTE(review): the two encoding slots are also filled with None (and the
# append site carries a type: ignore), so Optional[str] looks more accurate.
_Part = Tuple[Payload, str, str]
class MultipartWriter(Payload):
    """Payload that serializes a sequence of parts as a multipart body."""

    _value: None

    # Boundary made only of token characters: usable without quoting.
    _valid_tchar_regex = re.compile(rb"\A[!#$%&'*+\-.^_`|~\w]+\Z")
    # Characters that may not appear in a quoted boundary either.
    _invalid_qdtext_char_regex = re.compile(rb"[\x00-\x08\x0A-\x1F\x7F]")

    def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None:
        if boundary is None:
            boundary = uuid.uuid4().hex
        # The Payload API works with str while the wire format needs bytes;
        # restricting the boundary to ASCII makes the conversion lossless
        # in both directions.
        try:
            self._boundary = boundary.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("boundary should contain ASCII only chars") from None

        if len(boundary) > 70:
            raise ValueError("boundary %r is too long (70 chars max)" % boundary)

        super().__init__(
            None,
            content_type=f"multipart/{subtype}; boundary={self._boundary_value}",
        )

        self._parts: List[_Part] = []
        self._is_form_data = subtype == "form-data"

    def __enter__(self) -> "MultipartWriter":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        return None

    def __iter__(self) -> Iterator[_Part]:
        return iter(self._parts)

    def __len__(self) -> int:
        return len(self._parts)

    def __bool__(self) -> bool:
        # A writer without parts must still be truthy.
        return True

    @property
    def _boundary_value(self) -> str:
        """Boundary as it appears in Content-Type, quoted when required.

        Raises ValueError if the boundary holds characters that are not
        allowed even inside a quoted string.
        """
        # Grammar (RFCs 7231, 7230, 5234):
        #
        # parameter     = token "=" ( token / quoted-string )
        # token         = 1*tchar
        # quoted-string = DQUOTE *( qdtext / quoted-pair ) DQUOTE
        # qdtext        = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
        # obs-text      = %x80-FF
        # quoted-pair   = "\" ( HTAB / SP / VCHAR / obs-text )
        # tchar         = "!" / "#" / "$" / "%" / "&" / "'" / "*"
        #                 / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
        #                 / DIGIT / ALPHA
        #                 ; any VCHAR, except delimiters
        # VCHAR         = %x21-7E
        raw = self._boundary
        if self._valid_tchar_regex.match(raw):
            return raw.decode("ascii")  # cannot fail

        if self._invalid_qdtext_char_regex.search(raw):
            raise ValueError("boundary value contains invalid characters")

        # Backslash (%x5C) first, then the double quote (%x22).
        escaped = raw.replace(b"\\", b"\\\\").replace(b'"', b'\\"')
        return '"' + escaped.decode("ascii") + '"'

    @property
    def boundary(self) -> str:
        return self._boundary.decode("ascii")

    def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Payload:
        """Append *obj* as a new part, wrapping it in a payload if needed.

        Raises TypeError when no payload can be created for *obj*.
        """
        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Payload):
            obj.headers.update(headers)
            return self.append_payload(obj)

        try:
            wrapped = get_payload(obj, headers=headers)
        except LookupError:
            raise TypeError("Cannot create payload from %r" % obj)
        return self.append_payload(wrapped)

    def append_payload(self, payload: Payload) -> Payload:
        """Register *payload* as the next body part and return it."""
        encoding: Optional[str] = None
        te_encoding: Optional[str] = None
        part_headers = payload.headers
        if not self._is_form_data:
            # Content-Encoding: only known codings; identity means none.
            encoding = part_headers.get(CONTENT_ENCODING, "").lower()
            if encoding and encoding not in ("deflate", "gzip", "identity"):
                raise RuntimeError(f"unknown content encoding: {encoding}")
            if encoding == "identity":
                encoding = None

            # Content-Transfer-Encoding: binary means none.
            te_encoding = part_headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
            if te_encoding not in ("", "base64", "quoted-printable", "binary"):
                raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")
            if te_encoding == "binary":
                te_encoding = None

            # Content-Length is only known when the body goes out unchanged.
            size = payload.size
            if size is not None and not (encoding or te_encoding):
                part_headers[CONTENT_LENGTH] = str(size)
        else:
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.7
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
            assert (
                not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING}
                & part_headers.keys()
            )
            # Give the part a default disposition if the user set none.
            if CONTENT_DISPOSITION not in part_headers:
                payload.set_content_disposition(
                    "form-data", name=f"section-{len(self._parts)}"
                )

        self._parts.append((payload, encoding, te_encoding))  # type: ignore[arg-type]
        return payload

    def append_json(
        self, obj: Any, headers: Optional[Mapping[str, str]] = None
    ) -> Payload:
        """Append *obj* as a JSON part."""
        if headers is None:
            headers = CIMultiDict()

        json_part = JsonPayload(obj, headers=headers)
        return self.append_payload(json_part)

    def append_form(
        self,
        obj: Union[Sequence[Tuple[str, str]], Mapping[str, str]],
        headers: Optional[Mapping[str, str]] = None,
    ) -> Payload:
        """Append *obj* as an urlencoded form part."""
        assert isinstance(obj, (Sequence, Mapping))

        if headers is None:
            headers = CIMultiDict()

        pairs = list(obj.items()) if isinstance(obj, Mapping) else obj
        form_part = StringPayload(
            urlencode(pairs, doseq=True),
            headers=headers,
            content_type="application/x-www-form-urlencoded",
        )
        return self.append_payload(form_part)

    @property
    def size(self) -> Optional[int]:
        """Total body size in bytes, or None when it cannot be known."""
        boundary_len = len(self._boundary)
        total = 0
        for part, encoding, te_encoding in self._parts:
            # An encoded part changes length on the wire.
            if encoding or te_encoding:
                return None
            part_size = part.size
            if part_size is None:
                return None

            total += int(
                2 + boundary_len + 2  # b'--'+self._boundary+b'\r\n'
                + part_size
                + len(part._binary_headers)
                + 2  # b'\r\n'
            )

        total += 2 + boundary_len + 4  # b'--'+self._boundary+b'--\r\n'
        return total

    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
        """Render every part (delimiter line, headers, body) as one string."""
        rendered = []
        for part, _e, _te in self._parts:
            head = part._binary_headers.decode(encoding, errors)
            rendered.append("--" + self.boundary + "\r\n" + head + part.decode())
        return "".join(rendered)

    async def write(self, writer: Any, close_boundary: bool = True) -> None:
        """Write body."""
        delimiter = b"--" + self._boundary
        for part, encoding, te_encoding in self._parts:
            if self._is_form_data:
                # https://datatracker.ietf.org/doc/html/rfc7578#section-4.2
                assert CONTENT_DISPOSITION in part.headers
                assert "name=" in part.headers[CONTENT_DISPOSITION]

            await writer.write(delimiter + b"\r\n")
            await writer.write(part._binary_headers)

            if not (encoding or te_encoding):
                await part.write(writer)
            else:
                # Route the part through an encoding/compressing adapter.
                w = MultipartPayloadWriter(writer)
                if encoding:
                    w.enable_compression(encoding)
                if te_encoding:
                    w.enable_encoding(te_encoding)
                await part.write(w)  # type: ignore[arg-type]
                await w.write_eof()

            await writer.write(b"\r\n")

        if close_boundary:
            await writer.write(delimiter + b"--\r\n")
class MultipartPayloadWriter:
    """Writer adapter that compresses and/or transfer-encodes written data."""

    def __init__(self, writer: Any) -> None:
        self._writer = writer
        self._encoding: Optional[str] = None
        self._compress: Optional[ZLibCompressor] = None
        self._encoding_buffer: Optional[bytearray] = None

    def enable_encoding(self, encoding: str) -> None:
        """Turn on base64 or quoted-printable output; other values are ignored."""
        if encoding == "quoted-printable":
            self._encoding = "quoted-printable"
        elif encoding == "base64":
            self._encoding = encoding
            # base64 works on 3-byte groups; leftovers wait in this buffer.
            self._encoding_buffer = bytearray()

    def enable_compression(
        self, encoding: str = "deflate", strategy: Optional[int] = None
    ) -> None:
        """Compress everything written from now on with *encoding*."""
        self._compress = ZLibCompressor(
            encoding=encoding,
            suppress_deflate_header=True,
            strategy=strategy,
        )

    async def write_eof(self) -> None:
        """Flush the compressor and any bytes still waiting for base64."""
        compressor = self._compress
        if compressor is not None:
            tail = compressor.flush()
            if tail:
                # Compression is switched off first so that the flushed
                # bytes are only encoded, not compressed again.
                self._compress = None
                await self.write(tail)

        if self._encoding == "base64" and self._encoding_buffer:
            await self._writer.write(base64.b64encode(self._encoding_buffer))

    async def write(self, chunk: bytes) -> None:
        """Compress (if enabled), encode (if enabled) and forward *chunk*."""
        if self._compress is not None and chunk:
            chunk = await self._compress.compress(chunk)
            if not chunk:
                # The compressor kept the data for later.
                return

        if self._encoding == "base64":
            pending = self._encoding_buffer
            assert pending is not None
            pending.extend(chunk)

            if pending:
                # Encode only whole 3-byte groups; keep the rest for later.
                whole_groups, _rest = divmod(len(pending), 3)
                cut = whole_groups * 3
                ready = pending[:cut]
                self._encoding_buffer = pending[cut:]
                if ready:
                    await self._writer.write(base64.b64encode(ready))
        elif self._encoding == "quoted-printable":
            await self._writer.write(binascii.b2a_qp(chunk))
        else:
            await self._writer.write(chunk)