1import base64
2import binascii
3import json
4import re
5import uuid
6import warnings
7import zlib
8from collections import deque
9from types import TracebackType
10from typing import (
11 TYPE_CHECKING,
12 Any,
13 AsyncIterator,
14 Deque,
15 Dict,
16 Iterator,
17 List,
18 Mapping,
19 Optional,
20 Sequence,
21 Tuple,
22 Type,
23 Union,
24 cast,
25)
26from urllib.parse import parse_qsl, unquote, urlencode
27
28from multidict import CIMultiDict, CIMultiDictProxy
29
30from .compression_utils import ZLibCompressor, ZLibDecompressor
31from .hdrs import (
32 CONTENT_DISPOSITION,
33 CONTENT_ENCODING,
34 CONTENT_LENGTH,
35 CONTENT_TRANSFER_ENCODING,
36 CONTENT_TYPE,
37)
38from .helpers import CHAR, TOKEN, parse_mimetype, reify
39from .http import HeadersParser
40from .payload import (
41 JsonPayload,
42 LookupError,
43 Order,
44 Payload,
45 StringPayload,
46 get_payload,
47 payload_type,
48)
49from .streams import StreamReader
50
# Names exported by ``from <this module> import *``.
__all__ = (
    "MultipartReader", "MultipartWriter", "BodyPartReader",
    "BadContentDispositionHeader", "BadContentDispositionParam",
    "parse_content_disposition", "content_disposition_filename",
)
60
61
62if TYPE_CHECKING:
63 from .client_reqrep import ClientResponse
64
65
class BadContentDispositionHeader(RuntimeWarning):
    """Warning category for a Content-Disposition header rejected as a whole.

    ``parse_content_disposition`` emits it (with the raw header as message)
    right before giving up and returning ``(None, {})``.
    """
68
69
class BadContentDispositionParam(RuntimeWarning):
    """Warning category for a single malformed Content-Disposition parameter.

    ``parse_content_disposition`` emits it (with the offending ``key=value``
    item as message) and skips that parameter, continuing with the rest.
    """
72
73
def parse_content_disposition(
    header: Optional[str],
) -> Tuple[Optional[str], Dict[str, str]]:
    """Parse a Content-Disposition header value.

    Returns ``(disposition_type, params)``. The disposition type and the
    parameter names are lower-cased. Quoted values are unquoted and
    unescaped. RFC 5987 extended parameters (``name*=charset'lang'value``)
    are percent-decoded, using UTF-8 when the charset is empty. RFC 2231
    continuation parameters (``name*0``, ``name*1*``, ...) are stored
    undecoded under their own keys.

    A missing or empty header yields ``(None, {})``. When the header as a
    whole cannot be parsed, a :class:`BadContentDispositionHeader` warning
    is emitted and ``(None, {})`` is returned. A single malformed parameter
    emits :class:`BadContentDispositionParam` and is skipped.
    """
    # This module imports a ``LookupError`` from ``.payload`` that shadows the
    # builtin; ``unquote`` raises the *builtin* LookupError for an unknown
    # charset, so reference it explicitly.
    import builtins

    def is_token(string: str) -> bool:
        return bool(string) and TOKEN >= set(string)

    def is_quoted(string: str) -> bool:
        # A quoted-string needs both an opening and a closing DQUOTE. The
        # length check also keeps an empty value (``filename=``) from raising
        # IndexError, and stops a lone ``"`` from being taken as quoted.
        return len(string) >= 2 and string[0] == string[-1] == '"'

    def is_rfc5987(string: str) -> bool:
        return is_token(string) and string.count("'") == 2

    def is_extended_param(string: str) -> bool:
        return string.endswith("*")

    def is_continuous_param(string: str) -> bool:
        # Continuations look like ``name*0`` or ``name*0*``.
        pos = string.find("*") + 1
        if not pos:
            return False
        substring = string[pos:-1] if string.endswith("*") else string[pos:]
        return substring.isdigit()

    def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str:
        # Undo quoted-pair escaping: a backslash followed by a CHAR.
        return re.sub(f"\\\\([{chars}])", "\\1", text)

    if not header:
        return None, {}

    disptype, *parts = header.split(";")
    if not is_token(disptype):
        warnings.warn(BadContentDispositionHeader(header))
        return None, {}

    params: Dict[str, str] = {}
    while parts:
        item = parts.pop(0)

        if "=" not in item:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        key, value = item.split("=", 1)
        key = key.lower().strip()
        value = value.lstrip()

        if key in params:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        if not is_token(key):
            warnings.warn(BadContentDispositionParam(item))
            continue

        elif is_continuous_param(key):
            if is_quoted(value):
                value = unescape(value[1:-1])
            elif not is_token(value):
                warnings.warn(BadContentDispositionParam(item))
                continue

        elif is_extended_param(key):
            if is_rfc5987(value):
                encoding, _, value = value.split("'", 2)
                encoding = encoding or "utf-8"
            else:
                warnings.warn(BadContentDispositionParam(item))
                continue

            try:
                value = unquote(value, encoding, "strict")
            except (UnicodeDecodeError, builtins.LookupError):  # pragma: nocover
                # Undecodable bytes or an unknown charset name: drop the param
                # instead of letting client-supplied input raise.
                warnings.warn(BadContentDispositionParam(item))
                continue

        else:
            failed = True
            if is_quoted(value):
                failed = False
                value = unescape(value[1:-1].lstrip("\\/"))
            elif is_token(value):
                failed = False
            elif parts:
                # maybe just ; in filename, in any case this is just
                # one case fix, for proper fix we need to redesign parser
                _value = f"{value};{parts[0]}"
                if is_quoted(_value):
                    parts.pop(0)
                    value = unescape(_value[1:-1].lstrip("\\/"))
                    failed = False

            if failed:
                warnings.warn(BadContentDispositionHeader(header))
                return None, {}

        params[key] = value

    return disptype.lower(), params
171
172
def content_disposition_filename(
    params: Mapping[str, str], name: str = "filename"
) -> Optional[str]:
    """Return the value of parameter *name* from parsed Content-Disposition params.

    Parameters are looked up in this order:

    1. ``name*``, the extended value (returned as stored);
    2. plain ``name``;
    3. RFC 2231 continuations ``name*0``, ``name*1``, ... (each optionally
       ending in ``*``). These are joined in numeric order up to the first
       missing index. If the joined text has a ``charset'language'`` prefix
       (two apostrophes), the rest is percent-decoded with that charset,
       using UTF-8 when the charset is empty.

    Returns None when *params* is empty or holds no matching parameter.
    """
    name_suf = "%s*" % name
    if not params:
        return None
    if name_suf in params:
        return params[name_suf]
    if name in params:
        return params[name]

    # Map each continuation index string ("0", "1", ...) to its value. Keys
    # are visited in sorted order so that when both ``name*N`` and
    # ``name*N*`` exist the plain one is kept.
    segments: Dict[str, str] = {}
    for key, value in sorted(params.items()):
        if not key.startswith(name_suf):
            continue
        tail = key[len(name_suf):]
        if tail.endswith("*"):
            tail = tail[:-1]
        segments.setdefault(tail, value)

    # Walk indices numerically. Lexical ordering would put "10" before "2"
    # and silently truncate values split into more than ten segments.
    parts = []
    num = 0
    while str(num) in segments:
        parts.append(segments[str(num)])
        num += 1
    if not parts:
        return None

    value = "".join(parts)
    # Only a ``charset'language'data`` value has two apostrophes; a single
    # one (e.g. "O'Brien") is literal text and would break the 3-way unpack.
    if value.count("'") >= 2:
        encoding, _, value = value.split("'", 2)
        encoding = encoding or "utf-8"
        return unquote(value, encoding, "strict")
    return value
204
205
class MultipartResponseWrapper:
    """Pair a MultipartReader with the ClientResponse it is reading from.

    Iterating the wrapper yields whatever the wrapped reader yields. As soon
    as the reader reports EOF after producing an item, the response is
    released.
    """

    def __init__(
        self,
        resp: "ClientResponse",
        stream: "MultipartReader",
    ) -> None:
        self.resp = resp
        self.stream = stream

    def __aiter__(self) -> "MultipartResponseWrapper":
        return self

    async def __anext__(
        self,
    ) -> Union["MultipartReader", "BodyPartReader"]:
        item = await self.next()
        if item is not None:
            return item
        raise StopAsyncIteration

    def at_eof(self) -> bool:
        """Return whether the response body stream has been fully consumed."""
        body = self.resp.content
        return body.at_eof()

    async def next(
        self,
    ) -> Optional[Union["MultipartReader", "BodyPartReader"]]:
        """Fetch the next item from the wrapped reader.

        The response is released whenever the reader is at EOF after the
        fetch; the fetched item (possibly None) is returned either way.
        """
        reader = self.stream
        fetched = await reader.next()
        if reader.at_eof():
            await self.release()
        return fetched

    async def release(self) -> None:
        """Release the wrapped response by delegating to ``resp.release()``."""
        response = self.resp
        await response.release()
251
252
class BodyPartReader:
    """Multipart reader for single body part.

    Reads one part's body from the ``StreamReader`` shared with the parent
    multipart reader. When the part has a Content-Length header (ignored for
    ``form-data`` parts, per RFC 7578 section 4.8), that many bytes are read.
    Otherwise the stream is scanned for CRLF followed by the boundary.
    """

    #: Chunk size used by read() and release(); also read_chunk()'s default.
    chunk_size = 8192

    def __init__(
        self,
        boundary: bytes,
        headers: "CIMultiDictProxy[str]",
        content: StreamReader,
        *,
        subtype: str = "mixed",
        default_charset: Optional[str] = None,
    ) -> None:
        """Create a reader for one body part.

        boundary: delimiter bytes (MultipartReader passes ``--`` + boundary).
        headers: the part's own headers.
        content: stream positioned at the start of the part body.
        subtype: the parent's multipart subtype; ``"form-data"`` disables
            Content-Length and Content-Encoding handling.
        default_charset: used by get_charset() when Content-Type has none.
        """
        self.headers = headers
        self._boundary = boundary
        self._content = content
        self._default_charset = default_charset
        self._at_eof = False
        self._is_form_data = subtype == "form-data"
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None)
        self._length = int(length) if length is not None else None
        self._read_bytes = 0
        # Lines pushed back by readline(); MultipartReader moves them into its
        # own unread buffer when it finishes with this part.
        self._unread: Deque[bytes] = deque()
        # One chunk of look-behind for _read_chunk_from_stream, so a delimiter
        # split across two reads is still found.
        self._prev_chunk: Optional[bytes] = None
        # Number of stream reads that observed EOF on the underlying stream.
        self._content_eof = 0
        # NOTE(review): presumably the storage used by @reify (name/filename)
        # - confirm against helpers.reify.
        self._cache: Dict[str, Any] = {}

    def __aiter__(self) -> AsyncIterator["BodyPartReader"]:
        # Iteration actually yields bytes (see __anext__); the annotation does
        # not say so, hence the ignore.
        return self  # type: ignore[return-value]

    async def __anext__(self) -> bytes:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    async def next(self) -> Optional[bytes]:
        """Return the remaining body, or None when nothing is left.

        This delegates to read(), which consumes the whole rest of the part,
        so iteration yields at most one non-empty item.
        """
        item = await self.read()
        if not item:
            return None
        return item

    async def read(self, *, decode: bool = False) -> bytes:
        """Read the remaining body of the part.

        decode: when true, pass the data through decode() (Content-Transfer-
            Encoding first, then Content-Encoding for non form-data parts).

        Returns b"" if the part was already exhausted. Note that the data is
        accumulated in, and returned as, a bytearray.
        """
        if self._at_eof:
            return b""
        data = bytearray()
        while not self._at_eof:
            data.extend(await self.read_chunk(self.chunk_size))
        if decode:
            return self.decode(data)
        return data

    async def read_chunk(self, size: int = chunk_size) -> bytes:
        """Read up to *size* bytes of the part body (not decoded).

        Returns b"" once the part is exhausted. When the end of the part is
        reached, the CRLF that precedes the next delimiter is consumed too.
        """
        if self._at_eof:
            return b""
        # NOTE(review): a Content-Length of 0 is falsy, so such parts fall
        # back to boundary scanning.
        if self._length:
            chunk = await self._read_chunk_from_length(size)
        else:
            chunk = await self._read_chunk_from_stream(size)

        self._read_bytes += len(chunk)
        if self._read_bytes == self._length:
            self._at_eof = True
        if self._at_eof:
            # Validated with assert, so this check disappears under -O.
            clrf = await self._content.readline()
            assert (
                b"\r\n" == clrf
            ), "reader did not read all the data or it is malformed"
        return chunk

    async def _read_chunk_from_length(self, size: int) -> bytes:
        # Reads body part content chunk of the specified size.
        # The body part must has Content-Length header with proper value.
        # Never reads past the declared length; EOF on the stream also ends
        # the part.
        assert self._length is not None, "Content-Length required for chunked read"
        chunk_size = min(size, self._length - self._read_bytes)
        chunk = await self._content.read(chunk_size)
        if self._content.at_eof():
            self._at_eof = True
        return chunk

    async def _read_chunk_from_stream(self, size: int) -> bytes:
        # Reads content chunk of body part with unknown length.
        # The Content-Length header for body part is not necessary.
        # The previous chunk is kept as look-behind and searched together with
        # the freshly read one. The value returned is the previous chunk; the
        # new chunk becomes the next look-behind.
        assert (
            size >= len(self._boundary) + 2
        ), "Chunk size must be greater or equal than boundary length + 2"
        first_chunk = self._prev_chunk is None
        if first_chunk:
            self._prev_chunk = await self._content.read(size)

        chunk = await self._content.read(size)
        self._content_eof += int(self._content.at_eof())
        assert self._content_eof < 3, "Reading after EOF"
        assert self._prev_chunk is not None
        window = self._prev_chunk + chunk
        sub = b"\r\n" + self._boundary
        if first_chunk:
            idx = window.find(sub)
        else:
            # A delimiter lying entirely inside the previous chunk would have
            # been found by the previous call, so skip those positions.
            idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
        if idx >= 0:
            # pushing boundary back to content
            # (DeprecationWarning is silenced around unread_data(); presumably
            # that StreamReader method is deprecated - confirm.)
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)
                self._content.unread_data(window[idx:])
            if size > idx:
                self._prev_chunk = self._prev_chunk[:idx]
            chunk = window[len(self._prev_chunk) : idx]
            if not chunk:
                self._at_eof = True
        result = self._prev_chunk
        self._prev_chunk = chunk
        return result

    async def readline(self) -> bytes:
        """Read the next line of the part body.

        Returns b"" at the end of the part. The delimiter line itself is not
        consumed; it is pushed onto self._unread.
        """
        if self._at_eof:
            return b""

        if self._unread:
            line = self._unread.popleft()
        else:
            line = await self._content.readline()

        if line.startswith(self._boundary):
            # the very last boundary may not come with \r\n,
            # so set single rules for everyone
            sline = line.rstrip(b"\r\n")
            boundary = self._boundary
            last_boundary = self._boundary + b"--"
            # ensure that we read exactly the boundary, not something alike
            if sline == boundary or sline == last_boundary:
                self._at_eof = True
                self._unread.append(line)
                return b""
        else:
            # Peek one line ahead: the CRLF before a delimiter belongs to the
            # delimiter, not to the body.
            next_line = await self._content.readline()
            if next_line.startswith(self._boundary):
                line = line[:-2]  # strip CRLF but only once
            self._unread.append(next_line)

        return line

    async def release(self) -> None:
        """Like read(), but reads all the data to the void."""
        if self._at_eof:
            return
        while not self._at_eof:
            await self.read_chunk(self.chunk_size)

    async def text(self, *, encoding: Optional[str] = None) -> str:
        """Read and decode the body as text.

        encoding: overrides the charset from Content-Type (or the reader's
            default charset, falling back to utf-8).
        """
        data = await self.read(decode=True)
        # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm
        # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send
        encoding = encoding or self.get_charset(default="utf-8")
        return data.decode(encoding)

    async def json(self, *, encoding: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Read the body and parse it as JSON; returns None for an empty body."""
        data = await self.read(decode=True)
        if not data:
            return None
        encoding = encoding or self.get_charset(default="utf-8")
        return cast(Dict[str, Any], json.loads(data.decode(encoding)))

    async def form(self, *, encoding: Optional[str] = None) -> List[Tuple[str, str]]:
        """Read the body as urlencoded form data into (name, value) pairs.

        Blank values are kept. Raises ValueError if the body cannot be decoded
        with the chosen encoding.
        """
        data = await self.read(decode=True)
        if not data:
            return []
        if encoding is not None:
            real_encoding = encoding
        else:
            real_encoding = self.get_charset(default="utf-8")
        try:
            decoded_data = data.rstrip().decode(real_encoding)
        except UnicodeDecodeError:
            raise ValueError("data cannot be decoded with %s encoding" % real_encoding)

        return parse_qsl(
            decoded_data,
            keep_blank_values=True,
            encoding=real_encoding,
        )

    def at_eof(self) -> bool:
        """Returns True if the boundary was reached or False otherwise."""
        return self._at_eof

    def decode(self, data: bytes) -> bytes:
        """Decodes data.

        Applies Content-Transfer-Encoding first, then Content-Encoding (the
        latter is skipped for form-data parts, per RFC 7578 section 4.8).
        """
        if CONTENT_TRANSFER_ENCODING in self.headers:
            data = self._decode_content_transfer(data)
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        if not self._is_form_data and CONTENT_ENCODING in self.headers:
            return self._decode_content(data)
        return data

    def _decode_content(self, data: bytes) -> bytes:
        # Undo Content-Encoding: identity, deflate or gzip; anything else
        # raises RuntimeError.
        encoding = self.headers.get(CONTENT_ENCODING, "").lower()
        if encoding == "identity":
            return data
        if encoding in {"deflate", "gzip"}:
            return ZLibDecompressor(
                encoding=encoding,
                suppress_deflate_header=True,
            ).decompress_sync(data)

        raise RuntimeError(f"unknown content encoding: {encoding}")

    def _decode_content_transfer(self, data: bytes) -> bytes:
        # Undo Content-Transfer-Encoding: base64 and quoted-printable are
        # decoded; binary/8bit/7bit pass through; anything else raises.
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()

        if encoding == "base64":
            return base64.b64decode(data)
        elif encoding == "quoted-printable":
            return binascii.a2b_qp(data)
        elif encoding in ("binary", "8bit", "7bit"):
            return data
        else:
            raise RuntimeError(
                "unknown content transfer encoding: {}" "".format(encoding)
            )

    def get_charset(self, default: str) -> str:
        """Return the Content-Type charset, else the reader default, else *default*."""
        ctype = self.headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)
        return mimetype.parameters.get("charset", self._default_charset or default)

    @reify
    def name(self) -> Optional[str]:
        """Returns name specified in Content-Disposition header.

        If the header is missing or malformed, returns None.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "name")

    @reify
    def filename(self) -> Optional[str]:
        """Returns filename specified in Content-Disposition header.

        Returns None if the header is missing or malformed.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "filename")
517
518
@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
    """Payload that forwards a BodyPartReader's (decoded) body to a writer.

    The part's ``name`` and ``filename``, when present, are copied into an
    ``attachment`` Content-Disposition on this payload.
    """

    def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
        super().__init__(value, *args, **kwargs)

        candidates = (("name", value.name), ("filename", value.filename))
        disposition = {key: val for key, val in candidates if val is not None}
        if disposition:
            self.set_content_disposition("attachment", True, **disposition)

    async def write(self, writer: Any) -> None:
        """Stream the part in 64 KiB chunks, decoding each chunk separately."""
        # NOTE(review): each chunk is passed to decode() on its own; confirm
        # that is safe for encodings whose units may straddle chunk edges.
        part_reader = self._value
        read_size = 2**16
        while True:
            piece = await part_reader.read_chunk(size=read_size)
            if not piece:
                break
            await writer.write(part_reader.decode(piece))
539
540
class MultipartReader:
    """Multipart body reader.

    Walks a ``multipart/*`` body delimited by ``--<boundary>`` lines. It
    yields a BodyPartReader for ordinary parts, or a nested reader for
    ``multipart/*`` parts.
    """

    #: Response wrapper, used when multipart readers constructs from response.
    response_wrapper_cls = MultipartResponseWrapper
    #: Multipart reader class, used to handle multipart/* body parts.
    #: None points to type(self)
    multipart_reader_cls = None
    #: Body part reader class for non multipart/* content types.
    part_reader_cls = BodyPartReader

    def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
        """Create a reader for a multipart body.

        Raises ValueError if the Content-Type has no boundary parameter or the
        boundary is longer than 70 characters.
        """
        self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
        assert self._mimetype.type == "multipart", "multipart/* content type expected"
        if "boundary" not in self._mimetype.parameters:
            raise ValueError(
                "boundary missed for Content-Type: %s" % headers[CONTENT_TYPE]
            )

        self.headers = headers
        self._boundary = ("--" + self._get_boundary()).encode()
        self._content = content
        self._default_charset: Optional[str] = None
        self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None
        self._at_eof = False
        self._at_bof = True
        # Lines read ahead that belong to this reader (or were handed back by
        # a finished part / nested reader); consumed before the stream.
        self._unread: List[bytes] = []

    def __aiter__(
        self,
    ) -> AsyncIterator["BodyPartReader"]:
        return self  # type: ignore[return-value]

    async def __anext__(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    @classmethod
    def from_response(
        cls,
        response: "ClientResponse",
    ) -> MultipartResponseWrapper:
        """Constructs reader instance from HTTP response.

        :param response: :class:`~aiohttp.client.ClientResponse` instance
        """
        obj = cls.response_wrapper_cls(
            response, cls(response.headers, response.content)
        )
        return obj

    def at_eof(self) -> bool:
        """Returns True if the final boundary was reached, false otherwise."""
        return self._at_eof

    async def next(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        """Emits the next multipart body part.

        Any unread remainder of the previous part is drained first. Returns
        None once the closing delimiter has been read.
        """
        # So, if we're at BOF, we need to skip till the boundary.
        if self._at_eof:
            return None
        await self._maybe_release_last_part()
        if self._at_bof:
            await self._read_until_first_boundary()
            self._at_bof = False
        else:
            await self._read_boundary()
        if self._at_eof:  # we just read the last boundary, nothing to do there
            return None

        part = await self.fetch_next_part()
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.6
        if (
            self._last_part is None
            and self._mimetype.subtype == "form-data"
            and isinstance(part, BodyPartReader)
        ):
            _, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION))
            if params.get("name") == "_charset_":
                self._default_charset = await self._read_default_charset(part)
                # The ``_charset_`` part is consumed here and not returned.
                # Record it as the last part and go through next() again so
                # the delimiter that follows it is read by _read_boundary().
                # Fetching headers directly would parse the delimiter line as
                # a header, and would miss a closing delimiter.
                self._last_part = part
                return await self.next()
        self._last_part = part
        return self._last_part

    async def _read_default_charset(self, part: BodyPartReader) -> str:
        """Read a ``_charset_`` form field completely and return its value.

        Raises RuntimeError if the value is longer than 31 bytes.
        """
        # Longest encoding in https://encoding.spec.whatwg.org/encodings.json
        # is 19 characters, so 31 bytes is more than enough for any valid
        # encoding. Boundary scanning requires chunk size >= len(boundary) + 2,
        # so a fixed 32-byte read failed for ordinary (long) boundaries.
        chunk_size = max(32, len(self._boundary) + 2)
        charset = bytearray()
        while not part.at_eof():
            charset.extend(await part.read_chunk(chunk_size))
            if len(charset) > 31:
                raise RuntimeError("Invalid default charset")
        return charset.strip().decode()

    async def release(self) -> None:
        """Reads all the body parts to the void till the final boundary."""
        while not self._at_eof:
            item = await self.next()
            if item is None:
                break
            await item.release()

    async def fetch_next_part(
        self,
    ) -> Union["MultipartReader", BodyPartReader]:
        """Returns the next body part reader."""
        headers = await self._read_headers()
        return self._get_part_reader(headers)

    def _get_part_reader(
        self,
        headers: "CIMultiDictProxy[str]",
    ) -> Union["MultipartReader", BodyPartReader]:
        """Dispatches the response by the `Content-Type` header.

        Returns a suitable reader instance.

        :param dict headers: Response headers
        """
        ctype = headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)

        if mimetype.type == "multipart":
            if self.multipart_reader_cls is None:
                return type(self)(headers, self._content)
            return self.multipart_reader_cls(headers, self._content)
        else:
            return self.part_reader_cls(
                self._boundary,
                headers,
                self._content,
                subtype=self._mimetype.subtype,
                default_charset=self._default_charset,
            )

    def _get_boundary(self) -> str:
        # Boundary from the Content-Type; more than 70 chars is rejected.
        boundary = self._mimetype.parameters["boundary"]
        if len(boundary) > 70:
            raise ValueError("boundary %r is too long (70 chars max)" % boundary)

        return boundary

    async def _readline(self) -> bytes:
        # Pushed-back lines take precedence over the stream.
        if self._unread:
            return self._unread.pop()
        return await self._content.readline()

    async def _read_until_first_boundary(self) -> None:
        # Skip preamble lines until the first delimiter (or a closing one).
        while True:
            chunk = await self._readline()
            if chunk == b"":
                raise ValueError(
                    "Could not find starting boundary %r" % (self._boundary)
                )
            chunk = chunk.rstrip()
            if chunk == self._boundary:
                return
            elif chunk == self._boundary + b"--":
                self._at_eof = True
                return

    async def _read_boundary(self) -> None:
        # Consume the delimiter line that follows a finished part.
        chunk = (await self._readline()).rstrip()
        if chunk == self._boundary:
            pass
        elif chunk == self._boundary + b"--":
            self._at_eof = True
            epilogue = await self._readline()
            next_line = await self._readline()

            # the epilogue is expected and then either the end of input or the
            # parent multipart boundary, if the parent boundary is found then
            # it should be marked as unread and handed to the parent for
            # processing
            if next_line[:2] == b"--":
                self._unread.append(next_line)
            # otherwise the request is likely missing an epilogue and both
            # lines should be passed to the parent for processing
            # (this handles the old behavior gracefully)
            else:
                self._unread.extend([next_line, epilogue])
        else:
            raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}")

    async def _read_headers(self) -> "CIMultiDictProxy[str]":
        # Read header lines up to the first empty line. The leading b"" fills
        # the slot the parser reserves for a start line.
        lines = [b""]
        while True:
            chunk = await self._content.readline()
            chunk = chunk.strip()
            lines.append(chunk)
            if not chunk:
                break
        parser = HeadersParser()
        headers, raw_headers = parser.parse_headers(lines)
        return headers

    async def _maybe_release_last_part(self) -> None:
        """Ensures that the last read body part is read completely."""
        if self._last_part is not None:
            if not self._last_part.at_eof():
                await self._last_part.release()
            # Take over any lines the part read ahead.
            self._unread.extend(self._last_part._unread)
            self._last_part = None
744
745
# One entry of MultipartWriter._parts: (payload, content encoding,
# content-transfer encoding). NOTE(review): MultipartWriter.append_payload can
# store None in both encoding slots even though they are annotated as str.
_Part = Tuple[Payload, str, str]
747
748
class MultipartWriter(Payload):
    """Multipart body writer.

    Collects body parts and serializes them as a ``multipart/<subtype>``
    body, with each part preceded by a ``--<boundary>`` line. For
    ``form-data`` the RFC 7578 header restrictions are enforced when parts
    are appended. For other subtypes, a part's Content-Encoding and
    Content-Transfer-Encoding are applied while writing.
    """

    def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None:
        boundary = boundary if boundary is not None else uuid.uuid4().hex
        # The underlying Payload API demands a str (utf-8), not bytes,
        # so we need to ensure we don't lose anything during conversion.
        # As a result, require the boundary to be ASCII only.
        try:
            self._boundary = boundary.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("boundary should contain ASCII only chars") from None
        ctype = f"multipart/{subtype}; boundary={self._boundary_value}"

        super().__init__(None, content_type=ctype)

        self._parts: List[_Part] = []
        self._is_form_data = subtype == "form-data"

    def __enter__(self) -> "MultipartWriter":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        pass

    def __iter__(self) -> Iterator[_Part]:
        return iter(self._parts)

    def __len__(self) -> int:
        return len(self._parts)

    def __bool__(self) -> bool:
        # Truthy even with no parts (otherwise __len__ would decide).
        return True

    _valid_tchar_regex = re.compile(rb"\A[!#$%&'*+\-.^_`|~\w]+\Z")
    _invalid_qdtext_char_regex = re.compile(rb"[\x00-\x08\x0A-\x1F\x7F]")

    @property
    def _boundary_value(self) -> str:
        """Wrap boundary parameter value in quotes, if necessary.

        Reads self.boundary and returns a unicode string.
        """
        # Refer to RFCs 7231, 7230, 5234.
        #
        # parameter      = token "=" ( token / quoted-string )
        # token          = 1*tchar
        # quoted-string  = DQUOTE *( qdtext / quoted-pair ) DQUOTE
        # qdtext         = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
        # obs-text       = %x80-FF
        # quoted-pair    = "\" ( HTAB / SP / VCHAR / obs-text )
        # tchar          = "!" / "#" / "$" / "%" / "&" / "'" / "*"
        #                  / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
        #                  / DIGIT / ALPHA
        #                  ; any VCHAR, except delimiters
        # VCHAR           = %x21-7E
        value = self._boundary
        if re.match(self._valid_tchar_regex, value):
            return value.decode("ascii")  # cannot fail

        if re.search(self._invalid_qdtext_char_regex, value):
            raise ValueError("boundary value contains invalid characters")

        # escape %x5C and %x22
        quoted_value_content = value.replace(b"\\", b"\\\\")
        quoted_value_content = quoted_value_content.replace(b'"', b'\\"')

        return '"' + quoted_value_content.decode("ascii") + '"'

    @property
    def boundary(self) -> str:
        return self._boundary.decode("ascii")

    def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Payload:
        """Append *obj* as a new part, wrapping it in a Payload if needed.

        Raises TypeError when no payload type can be found for *obj*.
        """
        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Payload):
            obj.headers.update(headers)
            return self.append_payload(obj)
        else:
            try:
                payload = get_payload(obj, headers=headers)
            except LookupError:
                raise TypeError("Cannot create payload from %r" % obj)
            else:
                return self.append_payload(payload)

    def append_payload(self, payload: Payload) -> Payload:
        """Adds a new body part to multipart writer.

        For form-data, Content-Encoding/-Length/-Transfer-Encoding headers are
        not allowed and a default Content-Disposition is added when missing.
        Otherwise the part's encodings are validated (RuntimeError on unknown
        values) and Content-Length is set when the size is known and no
        transformation applies.
        """
        encoding: Optional[str] = None
        te_encoding: Optional[str] = None
        if self._is_form_data:
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.7
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
            assert (
                not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING}
                & payload.headers.keys()
            )
            # Set default Content-Disposition in case user doesn't create one
            if CONTENT_DISPOSITION not in payload.headers:
                name = f"section-{len(self._parts)}"
                payload.set_content_disposition("form-data", name=name)
        else:
            # compression
            encoding = payload.headers.get(CONTENT_ENCODING, "").lower()
            if encoding and encoding not in ("deflate", "gzip", "identity"):
                raise RuntimeError(f"unknown content encoding: {encoding}")
            if encoding == "identity":
                encoding = None

            # te encoding
            te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
            if te_encoding not in (
                "",
                "base64",
                "quoted-printable",
                "binary",
                "8bit",
                "7bit",
            ):
                raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")
            # RFC 2045 identity encodings need no transformation; this is the
            # same set BodyPartReader passes through unchanged when reading.
            if te_encoding in ("binary", "8bit", "7bit"):
                te_encoding = None

            # size
            size = payload.size
            if size is not None and not (encoding or te_encoding):
                payload.headers[CONTENT_LENGTH] = str(size)

        self._parts.append((payload, encoding, te_encoding))  # type: ignore[arg-type]
        return payload

    def append_json(
        self, obj: Any, headers: Optional[Mapping[str, str]] = None
    ) -> Payload:
        """Helper to append JSON part."""
        if headers is None:
            headers = CIMultiDict()

        return self.append_payload(JsonPayload(obj, headers=headers))

    def append_form(
        self,
        obj: Union[Sequence[Tuple[str, str]], Mapping[str, str]],
        headers: Optional[Mapping[str, str]] = None,
    ) -> Payload:
        """Helper to append form urlencoded part."""
        assert isinstance(obj, (Sequence, Mapping))

        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Mapping):
            obj = list(obj.items())
        data = urlencode(obj, doseq=True)

        return self.append_payload(
            StringPayload(
                data, headers=headers, content_type="application/x-www-form-urlencoded"
            )
        )

    @property
    def size(self) -> Optional[int]:
        """Size of the payload.

        None if any part is encoded or has an unknown size.
        """
        total = 0
        for part, encoding, te_encoding in self._parts:
            if encoding or te_encoding or part.size is None:
                return None

            total += int(
                2
                + len(self._boundary)
                + 2
                + part.size  # b'--'+self._boundary+b'\r\n'
                + len(part._binary_headers)
                + 2  # b'\r\n'
            )

        total += 2 + len(self._boundary) + 4  # b'--'+self._boundary+b'--\r\n'
        return total

    async def write(self, writer: Any, close_boundary: bool = True) -> None:
        """Write body.

        close_boundary: when true, finish with the closing ``--boundary--``.
        """
        for part, encoding, te_encoding in self._parts:
            if self._is_form_data:
                # https://datatracker.ietf.org/doc/html/rfc7578#section-4.2
                assert CONTENT_DISPOSITION in part.headers
                assert "name=" in part.headers[CONTENT_DISPOSITION]

            await writer.write(b"--" + self._boundary + b"\r\n")
            await writer.write(part._binary_headers)

            if encoding or te_encoding:
                w = MultipartPayloadWriter(writer)
                if encoding:
                    w.enable_compression(encoding)
                if te_encoding:
                    w.enable_encoding(te_encoding)
                await part.write(w)  # type: ignore[arg-type]
                await w.write_eof()
            else:
                await part.write(writer)

            await writer.write(b"\r\n")

        if close_boundary:
            await writer.write(b"--" + self._boundary + b"--\r\n")
958
959
class MultipartPayloadWriter:
    """Writer adapter that optionally compresses, then transfer-encodes, data.

    Bytes given to write() are compressed first (after enable_compression()).
    They are then base64- or quoted-printable-encoded (after
    enable_encoding()) and finally forwarded to the wrapped writer.
    """

    def __init__(self, writer: Any) -> None:
        self._writer = writer
        self._encoding: Optional[str] = None
        self._compress: Optional[ZLibCompressor] = None
        # Holds the 0-2 trailing bytes not yet base64-encoded.
        self._encoding_buffer: Optional[bytearray] = None

    def enable_encoding(self, encoding: str) -> None:
        """Select base64 or quoted-printable; other values are ignored."""
        if encoding == "quoted-printable":
            self._encoding = "quoted-printable"
        elif encoding == "base64":
            self._encoding = encoding
            self._encoding_buffer = bytearray()

    def enable_compression(
        self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
    ) -> None:
        """Compress subsequent writes with a ZLibCompressor for *encoding*."""
        self._compress = ZLibCompressor(
            encoding=encoding,
            suppress_deflate_header=True,
            strategy=strategy,
        )

    async def write_eof(self) -> None:
        """Flush the compressor and any base64 bytes still buffered."""
        compressor = self._compress
        if compressor is not None:
            tail = compressor.flush()
            if tail:
                self._compress = None
                await self.write(tail)

        if self._encoding == "base64" and self._encoding_buffer:
            await self._writer.write(base64.b64encode(self._encoding_buffer))

    async def write(self, chunk: bytes) -> None:
        """Compress/encode *chunk* as configured and forward the result."""
        if self._compress is not None:
            if chunk:
                chunk = await self._compress.compress(chunk)
            if not chunk:
                return

        mode = self._encoding
        if mode == "base64":
            pending = self._encoding_buffer
            assert pending is not None
            pending.extend(chunk)

            if pending:
                # Only whole 3-byte groups can be encoded without padding;
                # keep the remainder for the next write (or write_eof).
                usable = len(pending) - len(pending) % 3
                ready = pending[:usable]
                self._encoding_buffer = pending[usable:]
                if ready:
                    await self._writer.write(base64.b64encode(ready))
        elif mode == "quoted-printable":
            await self._writer.write(binascii.b2a_qp(chunk))
        else:
            await self._writer.write(chunk)