import base64
import binascii
import json
import re
import sys
import uuid
import warnings
from collections import deque
from collections.abc import Mapping, Sequence
from types import TracebackType
from typing import (
    TYPE_CHECKING,
    Any,
    Deque,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
)
from urllib.parse import parse_qsl, unquote, urlencode

from multidict import CIMultiDict, CIMultiDictProxy

from .compression_utils import ZLibCompressor, ZLibDecompressor
from .hdrs import (
    CONTENT_DISPOSITION,
    CONTENT_ENCODING,
    CONTENT_LENGTH,
    CONTENT_TRANSFER_ENCODING,
    CONTENT_TYPE,
)
from .helpers import CHAR, TOKEN, parse_mimetype, reify
from .http import HeadersParser
from .log import internal_logger
from .payload import (
    JsonPayload,
    LookupError,
    Order,
    Payload,
    StringPayload,
    get_payload,
    payload_type,
)
from .streams import StreamReader

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing import TypeVar

    Self = TypeVar("Self", bound="BodyPartReader")

__all__ = (
    "MultipartReader",
    "MultipartWriter",
    "BodyPartReader",
    "BadContentDispositionHeader",
    "BadContentDispositionParam",
    "parse_content_disposition",
    "content_disposition_filename",
)


if TYPE_CHECKING:
    from .client_reqrep import ClientResponse


class BadContentDispositionHeader(RuntimeWarning):
    pass


class BadContentDispositionParam(RuntimeWarning):
    pass


def parse_content_disposition(
    header: Optional[str],
) -> Tuple[Optional[str], Dict[str, str]]:
    def is_token(string: str) -> bool:
        return bool(string) and TOKEN >= set(string)

    def is_quoted(string: str) -> bool:
        return string[0] == string[-1] == '"'

    def is_rfc5987(string: str) -> bool:
        return is_token(string) and string.count("'") == 2

    def is_extended_param(string: str) -> bool:
        return string.endswith("*")

    def is_continuous_param(string: str) -> bool:
        pos = string.find("*") + 1
        if not pos:
            return False
        substring = string[pos:-1] if string.endswith("*") else string[pos:]
        return substring.isdigit()

    def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str:
        return re.sub(f"\\\\([{chars}])", "\\1", text)

    if not header:
        return None, {}

    disptype, *parts = header.split(";")
    if not is_token(disptype):
        warnings.warn(BadContentDispositionHeader(header))
        return None, {}

    params: Dict[str, str] = {}
    while parts:
        item = parts.pop(0)

        if not item:  # To handle trailing semicolons
            warnings.warn(BadContentDispositionHeader(header))
            continue

        if "=" not in item:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        key, value = item.split("=", 1)
        key = key.lower().strip()
        value = value.lstrip()

        if key in params:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        if not is_token(key):
            warnings.warn(BadContentDispositionParam(item))
            continue

        elif is_continuous_param(key):
            if is_quoted(value):
                value = unescape(value[1:-1])
            elif not is_token(value):
                warnings.warn(BadContentDispositionParam(item))
                continue

        elif is_extended_param(key):
            if is_rfc5987(value):
                encoding, _, value = value.split("'", 2)
                encoding = encoding or "utf-8"
            else:
                warnings.warn(BadContentDispositionParam(item))
                continue

            try:
                value = unquote(value, encoding, "strict")
            except UnicodeDecodeError:  # pragma: nocover
                warnings.warn(BadContentDispositionParam(item))
                continue

        else:
            failed = True
            if is_quoted(value):
                failed = False
                value = unescape(value[1:-1].lstrip("\\/"))
            elif is_token(value):
                failed = False
            elif parts:
                # the filename itself may contain a ';', so try joining the
                # next fragment; a proper fix would require redesigning the
                # parser
                _value = f"{value};{parts[0]}"
                if is_quoted(_value):
                    parts.pop(0)
                    value = unescape(_value[1:-1].lstrip("\\/"))
                    failed = False

            if failed:
                warnings.warn(BadContentDispositionHeader(header))
                return None, {}

        params[key] = value

    return disptype.lower(), params

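# Illustrative use of parse_content_disposition() (a sketch, not executed here):
#
#     disptype, params = parse_content_disposition(
#         'form-data; name="field1"; filename="report.pdf"'
#     )
#     # disptype == "form-data"
#     # params == {"name": "field1", "filename": "report.pdf"}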

def content_disposition_filename(
    params: Mapping[str, str], name: str = "filename"
) -> Optional[str]:
    name_suf = "%s*" % name
    if not params:
        return None
    elif name_suf in params:
        return params[name_suf]
    elif name in params:
        return params[name]
    else:
        parts = []
        fnparams = sorted(
            (key, value) for key, value in params.items() if key.startswith(name_suf)
        )
        for num, (key, value) in enumerate(fnparams):
            _, tail = key.split("*", 1)
            if tail.endswith("*"):
                tail = tail[:-1]
            if tail == str(num):
                parts.append(value)
            else:
                break
        if not parts:
            return None
        value = "".join(parts)
        if "'" in value:
            encoding, _, value = value.split("'", 2)
            encoding = encoding or "utf-8"
            return unquote(value, encoding, "strict")
        return value

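# Illustrative use of the two helpers together (a sketch): RFC 2231/5987
# extended parameters take precedence over the plain form and are
# percent-decoded with the declared charset.
#
#     _, params = parse_content_disposition(
#         "attachment; filename=plain.txt; filename*=utf-8''n%C3%A4me.txt"
#     )
#     content_disposition_filename(params)  # -> "näme.txt"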

class MultipartResponseWrapper:
    """Wrapper around the MultipartReader.

    It takes care of the underlying connection and closes it when it is no
    longer needed.
    """

    def __init__(
        self,
        resp: "ClientResponse",
        stream: "MultipartReader",
    ) -> None:
        self.resp = resp
        self.stream = stream

    def __aiter__(self) -> "MultipartResponseWrapper":
        return self

    async def __anext__(
        self,
    ) -> Union["MultipartReader", "BodyPartReader"]:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    def at_eof(self) -> bool:
        """Returns True when all response data has been read."""
        return self.resp.content.at_eof()

    async def next(
        self,
    ) -> Optional[Union["MultipartReader", "BodyPartReader"]]:
        """Emits the next multipart reader object."""
        item = await self.stream.next()
        if self.stream.at_eof():
            await self.release()
        return item

    async def release(self) -> None:
        """Release the connection gracefully.

        All remaining content is read to the void.
        """
        await self.resp.release()

class BodyPartReader:
    """Multipart reader for single body part."""

    chunk_size = 8192

    def __init__(
        self,
        boundary: bytes,
        headers: "CIMultiDictProxy[str]",
        content: StreamReader,
        *,
        subtype: str = "mixed",
        default_charset: Optional[str] = None,
    ) -> None:
        self.headers = headers
        self._boundary = boundary
        self._boundary_len = len(boundary) + 2  # Boundary + \r\n
        self._content = content
        self._default_charset = default_charset
        self._at_eof = False
        self._is_form_data = subtype == "form-data"
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None)
        self._length = int(length) if length is not None else None
        self._read_bytes = 0
        self._unread: Deque[bytes] = deque()
        self._prev_chunk: Optional[bytes] = None
        self._content_eof = 0
        self._cache: Dict[str, Any] = {}

    def __aiter__(self: Self) -> Self:
        return self

    async def __anext__(self) -> bytes:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    async def next(self) -> Optional[bytes]:
        item = await self.read()
        if not item:
            return None
        return item

    async def read(self, *, decode: bool = False) -> bytes:
        """Reads body part data.

        decode: if True, decodes the data according to the Content-Encoding
                or Content-Transfer-Encoding headers. If those headers are
                missing, the data is returned untouched.
        """
        if self._at_eof:
            return b""
        data = bytearray()
        while not self._at_eof:
            data.extend(await self.read_chunk(self.chunk_size))
        if decode:
            return self.decode(data)
        return data

    async def read_chunk(self, size: int = chunk_size) -> bytes:
        """Reads body part content chunk of the specified size.

        size: chunk size
        """
        if self._at_eof:
            return b""
        if self._length:
            chunk = await self._read_chunk_from_length(size)
        else:
            chunk = await self._read_chunk_from_stream(size)

        # For base64-encoded data we must return chunks whose length
        # (ignoring \r and \n) is a multiple of 4, so that every chunk
        # can be decoded on its own.
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING)
        if encoding and encoding.lower() == "base64":
            stripped_chunk = b"".join(chunk.split())
            remainder = len(stripped_chunk) % 4

            while remainder != 0 and not self.at_eof():
                over_chunk_size = 4 - remainder
                over_chunk = b""

                if self._prev_chunk:
                    over_chunk = self._prev_chunk[:over_chunk_size]
                    self._prev_chunk = self._prev_chunk[len(over_chunk) :]

                if len(over_chunk) != over_chunk_size:
                    over_chunk += await self._content.read(4 - len(over_chunk))

                if not over_chunk:
                    self._at_eof = True

                stripped_chunk += b"".join(over_chunk.split())
                chunk += over_chunk
                remainder = len(stripped_chunk) % 4

        self._read_bytes += len(chunk)
        if self._read_bytes == self._length:
            self._at_eof = True
        if self._at_eof:
            crlf = await self._content.readline()
            assert (
                b"\r\n" == crlf
            ), "reader did not read all the data or it is malformed"
        return chunk

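    # Why the multiple-of-4 alignment above matters (illustrative sketch):
    # base64 decodes in 4-character groups, so an unaligned fragment cannot
    # be decoded on its own.
    #
    #     base64.b64decode(b"Zm9vYmFy")  # b'foobar'
    #     base64.b64decode(b"Zm9vYm")    # raises binascii.Error (incorrect padding)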
    async def _read_chunk_from_length(self, size: int) -> bytes:
        # Reads body part content chunk of the specified size.
        # The body part must have a Content-Length header with a proper value.
        assert self._length is not None, "Content-Length required for chunked read"
        chunk_size = min(size, self._length - self._read_bytes)
        chunk = await self._content.read(chunk_size)
        if self._content.at_eof():
            self._at_eof = True
        return chunk

    async def _read_chunk_from_stream(self, size: int) -> bytes:
        # Reads a content chunk of a body part with unknown length.
        # The body part does not need a Content-Length header.
        assert (
            size >= self._boundary_len
        ), "Chunk size must be greater than or equal to boundary length + 2"
        first_chunk = self._prev_chunk is None
        if first_chunk:
            self._prev_chunk = await self._content.read(size)

        chunk = b""
        # content.read() may return less than size, so we need to loop to ensure
        # we have enough data to detect the boundary.
        while len(chunk) < self._boundary_len:
            chunk += await self._content.read(size)
            self._content_eof += int(self._content.at_eof())
            assert self._content_eof < 3, "Reading after EOF"
            if self._content_eof:
                break
        if len(chunk) > size:
            self._content.unread_data(chunk[size:])
            chunk = chunk[:size]

        assert self._prev_chunk is not None
        window = self._prev_chunk + chunk
        sub = b"\r\n" + self._boundary
        if first_chunk:
            idx = window.find(sub)
        else:
            idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
        if idx >= 0:
            # pushing boundary back to content
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)
                self._content.unread_data(window[idx:])
            if size > idx:
                self._prev_chunk = self._prev_chunk[:idx]
            chunk = window[len(self._prev_chunk) : idx]
            if not chunk:
                self._at_eof = True
        result = self._prev_chunk
        self._prev_chunk = chunk
        return result

    async def readline(self) -> bytes:
        """Reads the body part line by line."""
        if self._at_eof:
            return b""

        if self._unread:
            line = self._unread.popleft()
        else:
            line = await self._content.readline()

        if line.startswith(self._boundary):
            # the very last boundary may not come with \r\n,
            # so handle both cases uniformly
            sline = line.rstrip(b"\r\n")
            boundary = self._boundary
            last_boundary = self._boundary + b"--"
            # ensure that we read exactly the boundary, not just a line that
            # happens to start with it
            if sline == boundary or sline == last_boundary:
                self._at_eof = True
                self._unread.append(line)
                return b""
        else:
            next_line = await self._content.readline()
            if next_line.startswith(self._boundary):
                line = line[:-2]  # strip CRLF but only once
            self._unread.append(next_line)

        return line

    async def release(self) -> None:
        """Like read(), but reads all the data to the void."""
        if self._at_eof:
            return
        while not self._at_eof:
            await self.read_chunk(self.chunk_size)

    async def text(self, *, encoding: Optional[str] = None) -> str:
        """Like read(), but assumes that the body part contains text data."""
        data = await self.read(decode=True)
        # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm
        # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send
        encoding = encoding or self.get_charset(default="utf-8")
        return data.decode(encoding)

    async def json(self, *, encoding: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Like read(), but assumes that the body part contains JSON data."""
        data = await self.read(decode=True)
        if not data:
            return None
        encoding = encoding or self.get_charset(default="utf-8")
        return cast(Dict[str, Any], json.loads(data.decode(encoding)))

    async def form(self, *, encoding: Optional[str] = None) -> List[Tuple[str, str]]:
        """Like read(), but assumes that the body part contains form-urlencoded data."""
        data = await self.read(decode=True)
        if not data:
            return []
        if encoding is not None:
            real_encoding = encoding
        else:
            real_encoding = self.get_charset(default="utf-8")
        try:
            decoded_data = data.rstrip().decode(real_encoding)
        except UnicodeDecodeError:
            raise ValueError("data cannot be decoded with %s encoding" % real_encoding)

        return parse_qsl(
            decoded_data,
            keep_blank_values=True,
            encoding=real_encoding,
        )

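    # Illustrative server-side usage (a sketch; `request` is a hypothetical
    # aiohttp.web.Request whose body is multipart/form-data):
    #
    #     reader = await request.multipart()
    #     part = await reader.next()        # BodyPartReader or None
    #     if part is not None and part.name == "metadata":
    #         metadata = await part.json()  # decoded using the part's charset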
    def at_eof(self) -> bool:
        """Returns True if the boundary was reached or False otherwise."""
        return self._at_eof

    def decode(self, data: bytes) -> bytes:
        """Decodes data.

        Decoding is done according to the specified Content-Encoding
        or Content-Transfer-Encoding headers value.
        """
        if CONTENT_TRANSFER_ENCODING in self.headers:
            data = self._decode_content_transfer(data)
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        if not self._is_form_data and CONTENT_ENCODING in self.headers:
            return self._decode_content(data)
        return data

    def _decode_content(self, data: bytes) -> bytes:
        encoding = self.headers.get(CONTENT_ENCODING, "").lower()
        if encoding == "identity":
            return data
        if encoding in {"deflate", "gzip"}:
            return ZLibDecompressor(
                encoding=encoding,
                suppress_deflate_header=True,
            ).decompress_sync(data)

        raise RuntimeError(f"unknown content encoding: {encoding}")

    def _decode_content_transfer(self, data: bytes) -> bytes:
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()

        if encoding == "base64":
            return base64.b64decode(data)
        elif encoding == "quoted-printable":
            return binascii.a2b_qp(data)
        elif encoding in ("binary", "8bit", "7bit"):
            return data
        else:
            raise RuntimeError(f"unknown content transfer encoding: {encoding}")

    def get_charset(self, default: str) -> str:
        """Returns charset parameter from Content-Type header or default."""
        ctype = self.headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)
        return mimetype.parameters.get("charset", self._default_charset or default)

    @reify
    def name(self) -> Optional[str]:
        """Returns name specified in Content-Disposition header.

        If the header is missing or malformed, returns None.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "name")

    @reify
    def filename(self) -> Optional[str]:
        """Returns filename specified in Content-Disposition header.

        Returns None if the header is missing or malformed.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "filename")


@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
    _value: BodyPartReader
    # _autoclose = False (inherited) - Streaming reader that may have resources

    def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
        super().__init__(value, *args, **kwargs)

        params: Dict[str, str] = {}
        if value.name is not None:
            params["name"] = value.name
        if value.filename is not None:
            params["filename"] = value.filename

        if params:
            self.set_content_disposition("attachment", True, **params)

    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
        raise TypeError("Unable to decode.")

    async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
        """Raises TypeError as body parts should be consumed via write().

        This is intentional: BodyPartReader payloads are designed for streaming
        large data (potentially gigabytes) and must be consumed only once via
        the write() method to avoid memory exhaustion. They cannot be buffered
        in memory for reuse.
        """
        raise TypeError("Unable to read body part as bytes. Use write() to consume.")

    async def write(self, writer: Any) -> None:
        field = self._value
        chunk = await field.read_chunk(size=2**16)
        while chunk:
            await writer.write(field.decode(chunk))
            chunk = await field.read_chunk(size=2**16)


class MultipartReader:
    """Multipart body reader."""

    #: Response wrapper, used when a multipart reader is constructed from a response.
    response_wrapper_cls = MultipartResponseWrapper
    #: Multipart reader class, used to handle multipart/* body parts.
    #: None points to type(self)
    multipart_reader_cls: Optional[Type["MultipartReader"]] = None
    #: Body part reader class for non multipart/* content types.
    part_reader_cls = BodyPartReader

    def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
        self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
        assert self._mimetype.type == "multipart", "multipart/* content type expected"
        if "boundary" not in self._mimetype.parameters:
            raise ValueError(
                "boundary missing for Content-Type: %s" % headers[CONTENT_TYPE]
            )

        self.headers = headers
        self._boundary = ("--" + self._get_boundary()).encode()
        self._content = content
        self._default_charset: Optional[str] = None
        self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None
        self._at_eof = False
        self._at_bof = True
        self._unread: List[bytes] = []

    def __aiter__(self: Self) -> Self:
        return self

    async def __anext__(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    @classmethod
    def from_response(
        cls,
        response: "ClientResponse",
    ) -> MultipartResponseWrapper:
        """Constructs reader instance from HTTP response.

        :param response: :class:`~aiohttp.client.ClientResponse` instance
        """
        obj = cls.response_wrapper_cls(
            response, cls(response.headers, response.content)
        )
        return obj

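    # Illustrative client-side usage (a sketch; `session` and `url` are
    # hypothetical, and the response is assumed to be multipart):
    #
    #     async with session.get(url) as resp:
    #         reader = MultipartReader.from_response(resp)
    #         async for part in reader:
    #             ...  # each part is a BodyPartReader or a nested MultipartReader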
    def at_eof(self) -> bool:
        """Returns True if the final boundary was reached, False otherwise."""
        return self._at_eof

    async def next(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        """Emits the next multipart body part."""
        # So, if we're at BOF, we need to skip till the boundary.
        if self._at_eof:
            return None
        await self._maybe_release_last_part()
        if self._at_bof:
            await self._read_until_first_boundary()
            self._at_bof = False
        else:
            await self._read_boundary()
        if self._at_eof:  # we just read the last boundary, nothing to do there
            return None

        part = await self.fetch_next_part()
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.6
        if (
            self._last_part is None
            and self._mimetype.subtype == "form-data"
            and isinstance(part, BodyPartReader)
        ):
            _, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION))
            if params.get("name") == "_charset_":
                # Longest encoding in https://encoding.spec.whatwg.org/encodings.json
                # is 19 characters, so 32 should be more than enough for any valid encoding.
                charset = await part.read_chunk(32)
                if len(charset) > 31:
                    raise RuntimeError("Invalid default charset")
                self._default_charset = charset.strip().decode()
                part = await self.fetch_next_part()
        self._last_part = part
        return self._last_part

    async def release(self) -> None:
        """Reads all the body parts to the void till the final boundary."""
        while not self._at_eof:
            item = await self.next()
            if item is None:
                break
            await item.release()

    async def fetch_next_part(
        self,
    ) -> Union["MultipartReader", BodyPartReader]:
        """Returns the next body part reader."""
        headers = await self._read_headers()
        return self._get_part_reader(headers)

    def _get_part_reader(
        self,
        headers: "CIMultiDictProxy[str]",
    ) -> Union["MultipartReader", BodyPartReader]:
        """Dispatches the response by the `Content-Type` header.

        Returns a suitable reader instance.

        :param dict headers: Response headers
        """
        ctype = headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)

        if mimetype.type == "multipart":
            if self.multipart_reader_cls is None:
                return type(self)(headers, self._content)
            return self.multipart_reader_cls(headers, self._content)
        else:
            return self.part_reader_cls(
                self._boundary,
                headers,
                self._content,
                subtype=self._mimetype.subtype,
                default_charset=self._default_charset,
            )

    def _get_boundary(self) -> str:
        boundary = self._mimetype.parameters["boundary"]
        if len(boundary) > 70:
            raise ValueError("boundary %r is too long (70 chars max)" % boundary)

        return boundary

    async def _readline(self) -> bytes:
        if self._unread:
            return self._unread.pop()
        return await self._content.readline()

    async def _read_until_first_boundary(self) -> None:
        while True:
            chunk = await self._readline()
            if chunk == b"":
                raise ValueError(
                    "Could not find starting boundary %r" % (self._boundary)
                )
            chunk = chunk.rstrip()
            if chunk == self._boundary:
                return
            elif chunk == self._boundary + b"--":
                self._at_eof = True
                return

    async def _read_boundary(self) -> None:
        chunk = (await self._readline()).rstrip()
        if chunk == self._boundary:
            pass
        elif chunk == self._boundary + b"--":
            self._at_eof = True
            epilogue = await self._readline()
            next_line = await self._readline()

            # the epilogue is expected and then either the end of input or the
            # parent multipart boundary, if the parent boundary is found then
            # it should be marked as unread and handed to the parent for
            # processing
            if next_line[:2] == b"--":
                self._unread.append(next_line)
            # otherwise the request is likely missing an epilogue and both
            # lines should be passed to the parent for processing
            # (this handles the old behavior gracefully)
            else:
                self._unread.extend([next_line, epilogue])
        else:
            raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}")

    async def _read_headers(self) -> "CIMultiDictProxy[str]":
        # HeadersParser.parse_headers() treats the first line as the start
        # line and skips it, so seed the list with an empty placeholder.
        lines = [b""]
        while True:
            chunk = await self._content.readline()
            chunk = chunk.strip()
            lines.append(chunk)
            if not chunk:
                break
        parser = HeadersParser()
        headers, raw_headers = parser.parse_headers(lines)
        return headers

    async def _maybe_release_last_part(self) -> None:
        """Ensures that the last read body part is read completely."""
        if self._last_part is not None:
            if not self._last_part.at_eof():
                await self._last_part.release()
            self._unread.extend(self._last_part._unread)
            self._last_part = None


_Part = Tuple[Payload, str, str]


class MultipartWriter(Payload):
    """Multipart body writer."""

    _value: None
    # _consumed = False (inherited) - Can be encoded multiple times
    _autoclose = True  # No file handles, just collects parts in memory

    def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None:
        boundary = boundary if boundary is not None else uuid.uuid4().hex
        # The underlying Payload API demands a str (utf-8), not bytes,
        # so we need to ensure we don't lose anything during conversion.
        # As a result, the boundary must be ASCII-only, whether it was
        # supplied by the caller or generated above.

        try:
            self._boundary = boundary.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("boundary should contain ASCII only chars") from None
        ctype = f"multipart/{subtype}; boundary={self._boundary_value}"

        super().__init__(None, content_type=ctype)

        self._parts: List[_Part] = []
        self._is_form_data = subtype == "form-data"

    def __enter__(self) -> "MultipartWriter":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        pass

    def __iter__(self) -> Iterator[_Part]:
        return iter(self._parts)

    def __len__(self) -> int:
        return len(self._parts)

    def __bool__(self) -> bool:
        return True

    _valid_tchar_regex = re.compile(rb"\A[!#$%&'*+\-.^_`|~\w]+\Z")
    _invalid_qdtext_char_regex = re.compile(rb"[\x00-\x08\x0A-\x1F\x7F]")

    @property
    def _boundary_value(self) -> str:
        """Wrap boundary parameter value in quotes, if necessary.

        Reads self.boundary and returns a unicode string.
        """
        # Refer to RFCs 7231, 7230, 5234.
        #
        # parameter      = token "=" ( token / quoted-string )
        # token          = 1*tchar
        # quoted-string  = DQUOTE *( qdtext / quoted-pair ) DQUOTE
        # qdtext         = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
        # obs-text       = %x80-FF
        # quoted-pair    = "\" ( HTAB / SP / VCHAR / obs-text )
        # tchar          = "!" / "#" / "$" / "%" / "&" / "'" / "*"
        #                  / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
        #                  / DIGIT / ALPHA
        #                  ; any VCHAR, except delimiters
        # VCHAR          = %x21-7E
        value = self._boundary
        if re.match(self._valid_tchar_regex, value):
            return value.decode("ascii")  # cannot fail

        if re.search(self._invalid_qdtext_char_regex, value):
            raise ValueError("boundary value contains invalid characters")

        # escape %x5C and %x22
        quoted_value_content = value.replace(b"\\", b"\\\\")
        quoted_value_content = quoted_value_content.replace(b'"', b'\\"')

        return '"' + quoted_value_content.decode("ascii") + '"'

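    # Illustration of the quoting rules above (a sketch): a token-only
    # boundary is emitted as-is, anything else is wrapped in double quotes
    # with '"' and '\' escaped.
    #
    #     MultipartWriter(boundary="simple-boundary")._boundary_value
    #     # -> 'simple-boundary'
    #     MultipartWriter(boundary="gc0p4Jq0M2Yt08j 34c0p")._boundary_value
    #     # -> '"gc0p4Jq0M2Yt08j 34c0p"'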
    @property
    def boundary(self) -> str:
        return self._boundary.decode("ascii")

    def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Payload:
        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Payload):
            obj.headers.update(headers)
            return self.append_payload(obj)
        else:
            try:
                payload = get_payload(obj, headers=headers)
            except LookupError:
                raise TypeError("Cannot create payload from %r" % obj)
            else:
                return self.append_payload(payload)

    def append_payload(self, payload: Payload) -> Payload:
        """Adds a new body part to multipart writer."""
        encoding: Optional[str] = None
        te_encoding: Optional[str] = None
        if self._is_form_data:
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.7
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
            assert (
                not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING}
                & payload.headers.keys()
            )
            # Set default Content-Disposition in case user doesn't create one
            if CONTENT_DISPOSITION not in payload.headers:
                name = f"section-{len(self._parts)}"
                payload.set_content_disposition("form-data", name=name)
        else:
            # compression
            encoding = payload.headers.get(CONTENT_ENCODING, "").lower()
            if encoding and encoding not in ("deflate", "gzip", "identity"):
                raise RuntimeError(f"unknown content encoding: {encoding}")
            if encoding == "identity":
                encoding = None

            # te encoding
            te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
            if te_encoding not in ("", "base64", "quoted-printable", "binary"):
                raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")
            if te_encoding == "binary":
                te_encoding = None

            # size
            size = payload.size
            if size is not None and not (encoding or te_encoding):
                payload.headers[CONTENT_LENGTH] = str(size)

        self._parts.append((payload, encoding, te_encoding))  # type: ignore[arg-type]
        return payload

    def append_json(
        self, obj: Any, headers: Optional[Mapping[str, str]] = None
    ) -> Payload:
        """Helper to append JSON part."""
        if headers is None:
            headers = CIMultiDict()

        return self.append_payload(JsonPayload(obj, headers=headers))

    def append_form(
        self,
        obj: Union[Sequence[Tuple[str, str]], Mapping[str, str]],
        headers: Optional[Mapping[str, str]] = None,
    ) -> Payload:
        """Helper to append form urlencoded part."""
        assert isinstance(obj, (Sequence, Mapping))

        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Mapping):
            obj = list(obj.items())
        data = urlencode(obj, doseq=True)

        return self.append_payload(
            StringPayload(
                data, headers=headers, content_type="application/x-www-form-urlencoded"
            )
        )

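    # Illustrative usage (a sketch; `session` and `url` are hypothetical):
    #
    #     with MultipartWriter("form-data") as mpwriter:
    #         mpwriter.append("plain body text")
    #         mpwriter.append_json({"key": "value"})
    #         mpwriter.append_form({"field": "data"})
    #     await session.post(url, data=mpwriter)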
    @property
    def size(self) -> Optional[int]:
        """Size of the payload."""
        total = 0
        for part, encoding, te_encoding in self._parts:
            part_size = part.size
            if encoding or te_encoding or part_size is None:
                return None

            total += int(
                2
                + len(self._boundary)
                + 2
                + part_size  # b'--'+self._boundary+b'\r\n'
                + len(part._binary_headers)
                + 2  # b'\r\n'
            )

        total += 2 + len(self._boundary) + 4  # b'--'+self._boundary+b'--\r\n'
        return total

    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
        """Return string representation of the multipart data.

        WARNING: This method may do blocking I/O if parts contain file payloads.
        It should not be called in the event loop. Use as_bytes().decode() instead.
        """
        return "".join(
            "--"
            + self.boundary
            + "\r\n"
            + part._binary_headers.decode(encoding, errors)
            + part.decode()
            for part, _e, _te in self._parts
        )

    async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
        """Return bytes representation of the multipart data.

        This method is async-safe and calls as_bytes on underlying payloads.
        """
        parts: List[bytes] = []

        # Process each part
        for part, _e, _te in self._parts:
            # Add boundary
            parts.append(b"--" + self._boundary + b"\r\n")

            # Add headers
            parts.append(part._binary_headers)

            # Add payload content using as_bytes for async safety
            part_bytes = await part.as_bytes(encoding, errors)
            parts.append(part_bytes)

            # Add trailing CRLF
            parts.append(b"\r\n")

        # Add closing boundary
        parts.append(b"--" + self._boundary + b"--\r\n")

        return b"".join(parts)

    async def write(self, writer: Any, close_boundary: bool = True) -> None:
        """Write body."""
        for part, encoding, te_encoding in self._parts:
            if self._is_form_data:
                # https://datatracker.ietf.org/doc/html/rfc7578#section-4.2
                assert CONTENT_DISPOSITION in part.headers
                assert "name=" in part.headers[CONTENT_DISPOSITION]

            await writer.write(b"--" + self._boundary + b"\r\n")
            await writer.write(part._binary_headers)

            if encoding or te_encoding:
                w = MultipartPayloadWriter(writer)
                if encoding:
                    w.enable_compression(encoding)
                if te_encoding:
                    w.enable_encoding(te_encoding)
                await part.write(w)  # type: ignore[arg-type]
                await w.write_eof()
            else:
                await part.write(writer)

            await writer.write(b"\r\n")

        if close_boundary:
            await writer.write(b"--" + self._boundary + b"--\r\n")

    async def close(self) -> None:
        """
        Close all part payloads that need explicit closing.

        IMPORTANT: This method must not await anything that might not finish
        immediately, as it may be called during cleanup/cancellation. Schedule
        any long-running operations without awaiting them.
        """
        if self._consumed:
            return
        self._consumed = True

        # Close all parts that need explicit closing.
        # We catch and log exceptions to ensure all parts get a chance to close.
        # We do not use asyncio.gather() here because we are not allowed
        # to suspend given we may be called during cleanup.
        for idx, (part, _, _) in enumerate(self._parts):
            if not part.autoclose and not part.consumed:
                try:
                    await part.close()
                except Exception as exc:
                    internal_logger.error(
                        "Failed to close multipart part %d: %s", idx, exc, exc_info=True
                    )


class MultipartPayloadWriter:
    def __init__(self, writer: Any) -> None:
        self._writer = writer
        self._encoding: Optional[str] = None
        self._compress: Optional[ZLibCompressor] = None
        self._encoding_buffer: Optional[bytearray] = None

    def enable_encoding(self, encoding: str) -> None:
        if encoding == "base64":
            self._encoding = encoding
            self._encoding_buffer = bytearray()
        elif encoding == "quoted-printable":
            self._encoding = "quoted-printable"

    def enable_compression(
        self, encoding: str = "deflate", strategy: Optional[int] = None
    ) -> None:
        self._compress = ZLibCompressor(
            encoding=encoding,
            suppress_deflate_header=True,
            strategy=strategy,
        )

    async def write_eof(self) -> None:
        if self._compress is not None:
            chunk = self._compress.flush()
            if chunk:
                self._compress = None
                await self.write(chunk)

        if self._encoding == "base64":
            if self._encoding_buffer:
                await self._writer.write(base64.b64encode(self._encoding_buffer))

    async def write(self, chunk: bytes) -> None:
        if self._compress is not None:
            if chunk:
                chunk = await self._compress.compress(chunk)
                if not chunk:
                    return

        if self._encoding == "base64":
            buf = self._encoding_buffer
            assert buf is not None
            buf.extend(chunk)

            if buf:
                div, mod = divmod(len(buf), 3)
                enc_chunk, self._encoding_buffer = (buf[: div * 3], buf[div * 3 :])
                if enc_chunk:
                    b64chunk = base64.b64encode(enc_chunk)
                    await self._writer.write(b64chunk)
        elif self._encoding == "quoted-printable":
            await self._writer.write(binascii.b2a_qp(chunk))
        else:
            await self._writer.write(chunk)
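    # Why write() buffers to 3-byte groups above (illustrative sketch): every
    # 3 input bytes encode to exactly 4 base64 characters, so emitting only
    # complete groups keeps the streamed output identical to encoding the
    # whole payload at once.
    #
    #     base64.b64encode(b"foo") + base64.b64encode(b"bar")  # b'Zm9vYmFy'
    #     base64.b64encode(b"foobar")                          # b'Zm9vYmFy'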