import base64
import binascii
import json
import re
import sys
import uuid
import warnings
from collections import deque
from collections.abc import Mapping, Sequence
from types import TracebackType
from typing import (
    TYPE_CHECKING,
    Any,
    Deque,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
)
from urllib.parse import parse_qsl, unquote, urlencode

from multidict import CIMultiDict, CIMultiDictProxy

from .compression_utils import ZLibCompressor, ZLibDecompressor
from .hdrs import (
    CONTENT_DISPOSITION,
    CONTENT_ENCODING,
    CONTENT_LENGTH,
    CONTENT_TRANSFER_ENCODING,
    CONTENT_TYPE,
)
from .helpers import CHAR, TOKEN, parse_mimetype, reify
from .http import HeadersParser
from .log import internal_logger
from .payload import (
    JsonPayload,
    LookupError,
    Order,
    Payload,
    StringPayload,
    get_payload,
    payload_type,
)
from .streams import StreamReader

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing import TypeVar

    Self = TypeVar("Self", bound="BodyPartReader")

__all__ = (
    "MultipartReader",
    "MultipartWriter",
    "BodyPartReader",
    "BadContentDispositionHeader",
    "BadContentDispositionParam",
    "parse_content_disposition",
    "content_disposition_filename",
)


if TYPE_CHECKING:
    from .client_reqrep import ClientResponse


class BadContentDispositionHeader(RuntimeWarning):
    pass


class BadContentDispositionParam(RuntimeWarning):
    pass


def parse_content_disposition(
    header: Optional[str],
) -> Tuple[Optional[str], Dict[str, str]]:
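    # Example (illustrative): the header 'attachment; filename="report.pdf"'
    # parses to ("attachment", {"filename": "report.pdf"}); a malformed header
    # yields (None, {}) and emits a BadContentDispositionHeader warning.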
    def is_token(string: str) -> bool:
        return bool(string) and TOKEN >= set(string)

    def is_quoted(string: str) -> bool:
        return string[0] == string[-1] == '"'

    def is_rfc5987(string: str) -> bool:
        return is_token(string) and string.count("'") == 2

    def is_extended_param(string: str) -> bool:
        return string.endswith("*")

    def is_continuous_param(string: str) -> bool:
        pos = string.find("*") + 1
        if not pos:
            return False
        substring = string[pos:-1] if string.endswith("*") else string[pos:]
        return substring.isdigit()

    def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str:
        return re.sub(f"\\\\([{chars}])", "\\1", text)

    if not header:
        return None, {}

    disptype, *parts = header.split(";")
    if not is_token(disptype):
        warnings.warn(BadContentDispositionHeader(header))
        return None, {}

    params: Dict[str, str] = {}
    while parts:
        item = parts.pop(0)

        if "=" not in item:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        key, value = item.split("=", 1)
        key = key.lower().strip()
        value = value.lstrip()

        if key in params:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        if not is_token(key):
            warnings.warn(BadContentDispositionParam(item))
            continue

        elif is_continuous_param(key):
            if is_quoted(value):
                value = unescape(value[1:-1])
            elif not is_token(value):
                warnings.warn(BadContentDispositionParam(item))
                continue

        elif is_extended_param(key):
            if is_rfc5987(value):
                encoding, _, value = value.split("'", 2)
                encoding = encoding or "utf-8"
            else:
                warnings.warn(BadContentDispositionParam(item))
                continue

            try:
                value = unquote(value, encoding, "strict")
            except UnicodeDecodeError:  # pragma: nocover
                warnings.warn(BadContentDispositionParam(item))
                continue

        else:
            failed = True
            if is_quoted(value):
                failed = False
                value = unescape(value[1:-1].lstrip("\\/"))
            elif is_token(value):
                failed = False
            elif parts:
                # The value may contain a literal ";" (e.g. in a filename); this is
                # a one-off fix, a proper fix would require redesigning the parser.
                _value = f"{value};{parts[0]}"
                if is_quoted(_value):
                    parts.pop(0)
                    value = unescape(_value[1:-1].lstrip("\\/"))
                    failed = False

            if failed:
                warnings.warn(BadContentDispositionHeader(header))
                return None, {}

        params[key] = value

    return disptype.lower(), params


def content_disposition_filename(
    params: Mapping[str, str], name: str = "filename"
) -> Optional[str]:
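    # Lookup order: the extended parameter ("filename*") first, then the plain
    # parameter ("filename"), then RFC 2231-style numbered continuations
    # ("filename*0", "filename*1", ...), which are joined in order.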
    name_suf = "%s*" % name
    if not params:
        return None
    elif name_suf in params:
        return params[name_suf]
    elif name in params:
        return params[name]
    else:
        parts = []
        fnparams = sorted(
            (key, value) for key, value in params.items() if key.startswith(name_suf)
        )
        for num, (key, value) in enumerate(fnparams):
            _, tail = key.split("*", 1)
            if tail.endswith("*"):
                tail = tail[:-1]
            if tail == str(num):
                parts.append(value)
            else:
                break
        if not parts:
            return None
        value = "".join(parts)
        if "'" in value:
            encoding, _, value = value.split("'", 2)
            encoding = encoding or "utf-8"
            return unquote(value, encoding, "strict")
        return value


class MultipartResponseWrapper:
    """Wrapper around the MultipartReader.

    It takes care of the underlying connection
    and closes it when it is no longer needed.
    """

    def __init__(
        self,
        resp: "ClientResponse",
        stream: "MultipartReader",
    ) -> None:
        self.resp = resp
        self.stream = stream

    def __aiter__(self) -> "MultipartResponseWrapper":
        return self

    async def __anext__(
        self,
    ) -> Union["MultipartReader", "BodyPartReader"]:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    def at_eof(self) -> bool:
        """Returns True when all response data has been read."""
        return self.resp.content.at_eof()

    async def next(
        self,
    ) -> Optional[Union["MultipartReader", "BodyPartReader"]]:
        """Emits the next multipart reader object."""
        item = await self.stream.next()
        if self.stream.at_eof():
            await self.release()
        return item

    async def release(self) -> None:
        """Release the connection gracefully.

        All remaining content is read to the void.
        """
        await self.resp.release()


class BodyPartReader:
    """Multipart reader for a single body part."""

    chunk_size = 8192

    def __init__(
        self,
        boundary: bytes,
        headers: "CIMultiDictProxy[str]",
        content: StreamReader,
        *,
        subtype: str = "mixed",
        default_charset: Optional[str] = None,
    ) -> None:
        self.headers = headers
        self._boundary = boundary
        self._boundary_len = len(boundary) + 2  # Boundary + \r\n
        self._content = content
        self._default_charset = default_charset
        self._at_eof = False
        self._is_form_data = subtype == "form-data"
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None)
        self._length = int(length) if length is not None else None
        self._read_bytes = 0
        self._unread: Deque[bytes] = deque()
        self._prev_chunk: Optional[bytes] = None
        self._content_eof = 0
        self._cache: Dict[str, Any] = {}

    def __aiter__(self: Self) -> Self:
        return self

    async def __anext__(self) -> bytes:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    async def next(self) -> Optional[bytes]:
        item = await self.read()
        if not item:
            return None
        return item

    async def read(self, *, decode: bool = False) -> bytes:
        """Reads body part data.

        decode: if True, decode the data according to the
        Content-Encoding or Content-Transfer-Encoding headers.
        If those headers are missing, the data is returned untouched.
        """
        if self._at_eof:
            return b""
        data = bytearray()
        while not self._at_eof:
            data.extend(await self.read_chunk(self.chunk_size))
        if decode:
            return self.decode(data)
        return data

    async def read_chunk(self, size: int = chunk_size) -> bytes:
        """Reads body part content chunk of the specified size.

        size: chunk size
        """
        if self._at_eof:
            return b""
        if self._length:
            chunk = await self._read_chunk_from_length(size)
        else:
            chunk = await self._read_chunk_from_stream(size)

        # For base64-encoded data the chunk, once stripped of \r and \n,
        # must have a length divisible by 4 so that decoding never splits
        # a 4-character base64 group.
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING)
        if encoding and encoding.lower() == "base64":
            stripped_chunk = b"".join(chunk.split())
            remainder = len(stripped_chunk) % 4

            while remainder != 0 and not self.at_eof():
                over_chunk_size = 4 - remainder
                over_chunk = b""

                if self._prev_chunk:
                    over_chunk = self._prev_chunk[:over_chunk_size]
                    self._prev_chunk = self._prev_chunk[len(over_chunk) :]

                if len(over_chunk) != over_chunk_size:
                    over_chunk += await self._content.read(4 - len(over_chunk))

                if not over_chunk:
                    self._at_eof = True

                stripped_chunk += b"".join(over_chunk.split())
                chunk += over_chunk
                remainder = len(stripped_chunk) % 4

        self._read_bytes += len(chunk)
        if self._read_bytes == self._length:
            self._at_eof = True
        if self._at_eof:
            clrf = await self._content.readline()
            assert (
                b"\r\n" == clrf
            ), "reader did not read all the data or it is malformed"
        return chunk

    async def _read_chunk_from_length(self, size: int) -> bytes:
        # Reads body part content chunk of the specified size.
        # The body part must have a Content-Length header with a proper value.
        assert self._length is not None, "Content-Length required for chunked read"
        chunk_size = min(size, self._length - self._read_bytes)
        chunk = await self._content.read(chunk_size)
        if self._content.at_eof():
            self._at_eof = True
        return chunk

    async def _read_chunk_from_stream(self, size: int) -> bytes:
        # Reads content chunk of body part with unknown length.
        # The Content-Length header for body part is not necessary.
        assert (
            size >= self._boundary_len
        ), "Chunk size must be greater or equal than boundary length + 2"
        first_chunk = self._prev_chunk is None
        if first_chunk:
            self._prev_chunk = await self._content.read(size)

        chunk = b""
        # content.read() may return less than size, so we need to loop to ensure
        # we have enough data to detect the boundary.
        while len(chunk) < self._boundary_len:
            chunk += await self._content.read(size)
            self._content_eof += int(self._content.at_eof())
            assert self._content_eof < 3, "Reading after EOF"
            if self._content_eof:
                break
        if len(chunk) > size:
            self._content.unread_data(chunk[size:])
            chunk = chunk[:size]

        assert self._prev_chunk is not None
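        # The boundary may straddle the edge between the previous chunk and the
        # newly read one, so search for it in the concatenation of the two.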
        window = self._prev_chunk + chunk
        sub = b"\r\n" + self._boundary
        if first_chunk:
            idx = window.find(sub)
        else:
            idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
        if idx >= 0:
            # pushing boundary back to content
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)
                self._content.unread_data(window[idx:])
            if size > idx:
                self._prev_chunk = self._prev_chunk[:idx]
            chunk = window[len(self._prev_chunk) : idx]
            if not chunk:
                self._at_eof = True
        result = self._prev_chunk
        self._prev_chunk = chunk
        return result

    async def readline(self) -> bytes:
        """Reads body part content line by line."""
        if self._at_eof:
            return b""

        if self._unread:
            line = self._unread.popleft()
        else:
            line = await self._content.readline()

        if line.startswith(self._boundary):
            # the very last boundary may not come with \r\n,
            # so set single rules for everyone
            sline = line.rstrip(b"\r\n")
            boundary = self._boundary
            last_boundary = self._boundary + b"--"
            # ensure that we read exactly the boundary, not something alike
            if sline == boundary or sline == last_boundary:
                self._at_eof = True
                self._unread.append(line)
                return b""
        else:
            next_line = await self._content.readline()
            if next_line.startswith(self._boundary):
                line = line[:-2]  # strip CRLF but only once
            self._unread.append(next_line)

        return line

    async def release(self) -> None:
        """Like read(), but reads all the data to the void."""
        if self._at_eof:
            return
        while not self._at_eof:
            await self.read_chunk(self.chunk_size)

    async def text(self, *, encoding: Optional[str] = None) -> str:
        """Like read(), but assumes that body part contains text data."""
        data = await self.read(decode=True)
        # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm
        # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send
        encoding = encoding or self.get_charset(default="utf-8")
        return data.decode(encoding)

    async def json(self, *, encoding: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Like read(), but assumes that the body part contains JSON data."""
        data = await self.read(decode=True)
        if not data:
            return None
        encoding = encoding or self.get_charset(default="utf-8")
        return cast(Dict[str, Any], json.loads(data.decode(encoding)))

    async def form(self, *, encoding: Optional[str] = None) -> List[Tuple[str, str]]:
        """Like read(), but assumes that the body part contains form urlencoded data."""
        data = await self.read(decode=True)
        if not data:
            return []
        if encoding is not None:
            real_encoding = encoding
        else:
            real_encoding = self.get_charset(default="utf-8")
        try:
            decoded_data = data.rstrip().decode(real_encoding)
        except UnicodeDecodeError:
            raise ValueError("data cannot be decoded with %s encoding" % real_encoding)

        return parse_qsl(
            decoded_data,
            keep_blank_values=True,
            encoding=real_encoding,
        )

    def at_eof(self) -> bool:
        """Returns True if the boundary was reached or False otherwise."""
        return self._at_eof

    def decode(self, data: bytes) -> bytes:
        """Decodes data.

        Decoding is done according to the specified Content-Encoding
        or Content-Transfer-Encoding header value.
        """
        if CONTENT_TRANSFER_ENCODING in self.headers:
            data = self._decode_content_transfer(data)
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        if not self._is_form_data and CONTENT_ENCODING in self.headers:
            return self._decode_content(data)
        return data

    def _decode_content(self, data: bytes) -> bytes:
        encoding = self.headers.get(CONTENT_ENCODING, "").lower()
        if encoding == "identity":
            return data
        if encoding in {"deflate", "gzip"}:
            return ZLibDecompressor(
                encoding=encoding,
                suppress_deflate_header=True,
            ).decompress_sync(data)

        raise RuntimeError(f"unknown content encoding: {encoding}")

    def _decode_content_transfer(self, data: bytes) -> bytes:
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()

        if encoding == "base64":
            return base64.b64decode(data)
        elif encoding == "quoted-printable":
            return binascii.a2b_qp(data)
        elif encoding in ("binary", "8bit", "7bit"):
            return data
        else:
            raise RuntimeError(f"unknown content transfer encoding: {encoding}")

    def get_charset(self, default: str) -> str:
        """Returns charset parameter from Content-Type header or default."""
        ctype = self.headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)
        return mimetype.parameters.get("charset", self._default_charset or default)

    @reify
    def name(self) -> Optional[str]:
        """Returns name specified in Content-Disposition header.

        If the header is missing or malformed, returns None.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "name")

    @reify
    def filename(self) -> Optional[str]:
        """Returns filename specified in Content-Disposition header.

        Returns None if the header is missing or malformed.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "filename")


@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
    _value: BodyPartReader
    # _autoclose = False (inherited) - Streaming reader that may have resources

    def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
        super().__init__(value, *args, **kwargs)

        params: Dict[str, str] = {}
        if value.name is not None:
            params["name"] = value.name
        if value.filename is not None:
            params["filename"] = value.filename

        if params:
            self.set_content_disposition("attachment", True, **params)

    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
        raise TypeError("Unable to decode.")

    async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
        """Raises TypeError as body parts should be consumed via write().

        This is intentional: BodyPartReader payloads are designed for streaming
        large data (potentially gigabytes) and must be consumed only once via
        the write() method to avoid memory exhaustion. They cannot be buffered
        in memory for reuse.
        """
        raise TypeError("Unable to read body part as bytes. Use write() to consume.")

    async def write(self, writer: Any) -> None:
        field = self._value
        chunk = await field.read_chunk(size=2**16)
        while chunk:
            await writer.write(field.decode(chunk))
            chunk = await field.read_chunk(size=2**16)


class MultipartReader:
    """Multipart body reader."""

    #: Response wrapper, used when the multipart reader is constructed from a response.
    response_wrapper_cls = MultipartResponseWrapper
    #: Multipart reader class, used to handle multipart/* body parts.
    #: None points to type(self)
    multipart_reader_cls: Optional[Type["MultipartReader"]] = None
    #: Body part reader class for non multipart/* content types.
    part_reader_cls = BodyPartReader

    def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
        self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
        assert self._mimetype.type == "multipart", "multipart/* content type expected"
        if "boundary" not in self._mimetype.parameters:
            raise ValueError(
                "boundary missed for Content-Type: %s" % headers[CONTENT_TYPE]
            )

        self.headers = headers
        self._boundary = ("--" + self._get_boundary()).encode()
        self._content = content
        self._default_charset: Optional[str] = None
        self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None
        self._at_eof = False
        self._at_bof = True
        self._unread: List[bytes] = []

    def __aiter__(self: Self) -> Self:
        return self

    async def __anext__(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    @classmethod
    def from_response(
        cls,
        response: "ClientResponse",
    ) -> MultipartResponseWrapper:
        """Constructs reader instance from HTTP response.

        :param response: :class:`~aiohttp.client.ClientResponse` instance
        """
        obj = cls.response_wrapper_cls(
            response, cls(response.headers, response.content)
        )
        return obj

    def at_eof(self) -> bool:
        """Returns True if the final boundary was reached, False otherwise."""
        return self._at_eof

    async def next(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        """Emits the next multipart body part."""
        # So, if we're at BOF, we need to skip till the boundary.
        if self._at_eof:
            return None
        await self._maybe_release_last_part()
        if self._at_bof:
            await self._read_until_first_boundary()
            self._at_bof = False
        else:
            await self._read_boundary()
        if self._at_eof:  # we just read the last boundary, nothing to do there
            return None

        part = await self.fetch_next_part()
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.6
        if (
            self._last_part is None
            and self._mimetype.subtype == "form-data"
            and isinstance(part, BodyPartReader)
        ):
            _, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION))
            if params.get("name") == "_charset_":
                # Longest encoding in https://encoding.spec.whatwg.org/encodings.json
                # is 19 characters, so 32 should be more than enough for any valid encoding.
                charset = await part.read_chunk(32)
                if len(charset) > 31:
                    raise RuntimeError("Invalid default charset")
                self._default_charset = charset.strip().decode()
                part = await self.fetch_next_part()
        self._last_part = part
        return self._last_part

    async def release(self) -> None:
        """Reads all the body parts to the void until the final boundary."""
        while not self._at_eof:
            item = await self.next()
            if item is None:
                break
            await item.release()

    async def fetch_next_part(
        self,
    ) -> Union["MultipartReader", BodyPartReader]:
        """Returns the next body part reader."""
        headers = await self._read_headers()
        return self._get_part_reader(headers)

    def _get_part_reader(
        self,
        headers: "CIMultiDictProxy[str]",
    ) -> Union["MultipartReader", BodyPartReader]:
        """Dispatches the response by the `Content-Type` header.

        Returns a suitable reader instance.

        :param dict headers: Response headers
        """
        ctype = headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)

        if mimetype.type == "multipart":
            if self.multipart_reader_cls is None:
                return type(self)(headers, self._content)
            return self.multipart_reader_cls(headers, self._content)
        else:
            return self.part_reader_cls(
                self._boundary,
                headers,
                self._content,
                subtype=self._mimetype.subtype,
                default_charset=self._default_charset,
            )

    def _get_boundary(self) -> str:
        boundary = self._mimetype.parameters["boundary"]
        if len(boundary) > 70:
            raise ValueError("boundary %r is too long (70 chars max)" % boundary)

        return boundary

    async def _readline(self) -> bytes:
        if self._unread:
            return self._unread.pop()
        return await self._content.readline()

    async def _read_until_first_boundary(self) -> None:
        while True:
            chunk = await self._readline()
            if chunk == b"":
                raise ValueError(
                    "Could not find starting boundary %r" % (self._boundary)
                )
            chunk = chunk.rstrip()
            if chunk == self._boundary:
                return
            elif chunk == self._boundary + b"--":
                self._at_eof = True
                return

    async def _read_boundary(self) -> None:
        chunk = (await self._readline()).rstrip()
        if chunk == self._boundary:
            pass
        elif chunk == self._boundary + b"--":
            self._at_eof = True
            epilogue = await self._readline()
            next_line = await self._readline()

            # the epilogue is expected and then either the end of input or the
            # parent multipart boundary, if the parent boundary is found then
            # it should be marked as unread and handed to the parent for
            # processing
            if next_line[:2] == b"--":
                self._unread.append(next_line)
            # otherwise the request is likely missing an epilogue and both
            # lines should be passed to the parent for processing
            # (this handles the old behavior gracefully)
            else:
                self._unread.extend([next_line, epilogue])
        else:
            raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}")

    async def _read_headers(self) -> "CIMultiDictProxy[str]":
        lines = [b""]
        while True:
            chunk = await self._content.readline()
            chunk = chunk.strip()
            lines.append(chunk)
            if not chunk:
                break
        parser = HeadersParser()
        headers, raw_headers = parser.parse_headers(lines)
        return headers

    async def _maybe_release_last_part(self) -> None:
        """Ensures that the last read body part is read completely."""
        if self._last_part is not None:
            if not self._last_part.at_eof():
                await self._last_part.release()
                self._unread.extend(self._last_part._unread)
            self._last_part = None


_Part = Tuple[Payload, str, str]


class MultipartWriter(Payload):
    """Multipart body writer."""

    _value: None
    # _consumed = False (inherited) - Can be encoded multiple times
    _autoclose = True  # No file handles, just collects parts in memory

    def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None:
        boundary = boundary if boundary is not None else uuid.uuid4().hex
        # The underlying Payload API demands a str (utf-8), not bytes,
        # so we need to ensure we don't lose anything during conversion.
        # As a result, require the boundary to be ASCII-only in both cases
        # (the generated default and a user-supplied value).

        try:
            self._boundary = boundary.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("boundary should contain ASCII only chars") from None
        ctype = f"multipart/{subtype}; boundary={self._boundary_value}"

        super().__init__(None, content_type=ctype)

        self._parts: List[_Part] = []
        self._is_form_data = subtype == "form-data"

    def __enter__(self) -> "MultipartWriter":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        pass

    def __iter__(self) -> Iterator[_Part]:
        return iter(self._parts)

    def __len__(self) -> int:
        return len(self._parts)

    def __bool__(self) -> bool:
        return True

    _valid_tchar_regex = re.compile(rb"\A[!#$%&'*+\-.^_`|~\w]+\Z")
    _invalid_qdtext_char_regex = re.compile(rb"[\x00-\x08\x0A-\x1F\x7F]")

    @property
    def _boundary_value(self) -> str:
        """Wrap boundary parameter value in quotes, if necessary.

        Reads self.boundary and returns a unicode string.
        """
        # Refer to RFCs 7231, 7230, 5234.
        #
        # parameter      = token "=" ( token / quoted-string )
        # token          = 1*tchar
        # quoted-string  = DQUOTE *( qdtext / quoted-pair ) DQUOTE
        # qdtext         = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
        # obs-text       = %x80-FF
        # quoted-pair    = "\" ( HTAB / SP / VCHAR / obs-text )
        # tchar          = "!" / "#" / "$" / "%" / "&" / "'" / "*"
        #                  / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
        #                  / DIGIT / ALPHA
        #                  ; any VCHAR, except delimiters
        # VCHAR          = %x21-7E
        value = self._boundary
        if re.match(self._valid_tchar_regex, value):
            return value.decode("ascii")  # cannot fail

        if re.search(self._invalid_qdtext_char_regex, value):
            raise ValueError("boundary value contains invalid characters")

        # escape %x5C and %x22
        quoted_value_content = value.replace(b"\\", b"\\\\")
        quoted_value_content = quoted_value_content.replace(b'"', b'\\"')

        return '"' + quoted_value_content.decode("ascii") + '"'

    @property
    def boundary(self) -> str:
        return self._boundary.decode("ascii")

    def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Payload:
        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Payload):
            obj.headers.update(headers)
            return self.append_payload(obj)
        else:
            try:
                payload = get_payload(obj, headers=headers)
            except LookupError:
                raise TypeError("Cannot create payload from %r" % obj)
            else:
                return self.append_payload(payload)

    def append_payload(self, payload: Payload) -> Payload:
        """Adds a new body part to multipart writer."""
        encoding: Optional[str] = None
        te_encoding: Optional[str] = None
        if self._is_form_data:
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.7
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
            assert (
                not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING}
                & payload.headers.keys()
            )
            # Set default Content-Disposition in case user doesn't create one
            if CONTENT_DISPOSITION not in payload.headers:
                name = f"section-{len(self._parts)}"
                payload.set_content_disposition("form-data", name=name)
        else:
            # compression
            encoding = payload.headers.get(CONTENT_ENCODING, "").lower()
            if encoding and encoding not in ("deflate", "gzip", "identity"):
                raise RuntimeError(f"unknown content encoding: {encoding}")
            if encoding == "identity":
                encoding = None

            # te encoding
            te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
            if te_encoding not in ("", "base64", "quoted-printable", "binary"):
                raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")
            if te_encoding == "binary":
                te_encoding = None

            # size
            size = payload.size
            if size is not None and not (encoding or te_encoding):
                payload.headers[CONTENT_LENGTH] = str(size)

        self._parts.append((payload, encoding, te_encoding))  # type: ignore[arg-type]
        return payload

    def append_json(
        self, obj: Any, headers: Optional[Mapping[str, str]] = None
    ) -> Payload:
        """Helper to append JSON part."""
        if headers is None:
            headers = CIMultiDict()

        return self.append_payload(JsonPayload(obj, headers=headers))

    def append_form(
        self,
        obj: Union[Sequence[Tuple[str, str]], Mapping[str, str]],
        headers: Optional[Mapping[str, str]] = None,
    ) -> Payload:
        """Helper to append form urlencoded part."""
        assert isinstance(obj, (Sequence, Mapping))

        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Mapping):
            obj = list(obj.items())
        data = urlencode(obj, doseq=True)

        return self.append_payload(
            StringPayload(
                data, headers=headers, content_type="application/x-www-form-urlencoded"
            )
        )

    @property
    def size(self) -> Optional[int]:
        """Size of the payload."""
        total = 0
        for part, encoding, te_encoding in self._parts:
            if encoding or te_encoding or part.size is None:
                return None

            total += int(
                2
                + len(self._boundary)
                + 2
                + part.size  # b'--'+self._boundary+b'\r\n'
                + len(part._binary_headers)
                + 2  # b'\r\n'
            )

        total += 2 + len(self._boundary) + 4  # b'--'+self._boundary+b'--\r\n'
        return total

    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
        """Return string representation of the multipart data.

        WARNING: This method may do blocking I/O if parts contain file payloads.
        It should not be called in the event loop. Use as_bytes().decode() instead.
        """
        return "".join(
            "--"
            + self.boundary
            + "\r\n"
            + part._binary_headers.decode(encoding, errors)
            + part.decode()
            for part, _e, _te in self._parts
        )

    async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
        """Return bytes representation of the multipart data.

        This method is async-safe and calls as_bytes on underlying payloads.
        """
        parts: List[bytes] = []

        # Process each part
        for part, _e, _te in self._parts:
            # Add boundary
            parts.append(b"--" + self._boundary + b"\r\n")

            # Add headers
            parts.append(part._binary_headers)

            # Add payload content using as_bytes for async safety
            part_bytes = await part.as_bytes(encoding, errors)
            parts.append(part_bytes)

            # Add trailing CRLF
            parts.append(b"\r\n")

        # Add closing boundary
        parts.append(b"--" + self._boundary + b"--\r\n")

        return b"".join(parts)

    async def write(self, writer: Any, close_boundary: bool = True) -> None:
        """Write body."""
        for part, encoding, te_encoding in self._parts:
            if self._is_form_data:
                # https://datatracker.ietf.org/doc/html/rfc7578#section-4.2
                assert CONTENT_DISPOSITION in part.headers
                assert "name=" in part.headers[CONTENT_DISPOSITION]

            await writer.write(b"--" + self._boundary + b"\r\n")
            await writer.write(part._binary_headers)

            if encoding or te_encoding:
                w = MultipartPayloadWriter(writer)
                if encoding:
                    w.enable_compression(encoding)
                if te_encoding:
                    w.enable_encoding(te_encoding)
                await part.write(w)  # type: ignore[arg-type]
                await w.write_eof()
            else:
                await part.write(writer)

            await writer.write(b"\r\n")

        if close_boundary:
            await writer.write(b"--" + self._boundary + b"--\r\n")

    async def close(self) -> None:
        """
        Close all part payloads that need explicit closing.

        IMPORTANT: This method must not await anything that might not finish
        immediately, as it may be called during cleanup/cancellation. Schedule
        any long-running operations without awaiting them.
        """
        if self._consumed:
            return
        self._consumed = True

        # Close all parts that need explicit closing
        # We catch and log exceptions to ensure all parts get a chance to close
        # we do not use asyncio.gather() here because we are not allowed
        # to suspend given we may be called during cleanup
        for idx, (part, _, _) in enumerate(self._parts):
            if not part.autoclose and not part.consumed:
                try:
                    await part.close()
                except Exception as exc:
                    internal_logger.error(
                        "Failed to close multipart part %d: %s", idx, exc, exc_info=True
                    )


class MultipartPayloadWriter:
    def __init__(self, writer: Any) -> None:
        self._writer = writer
        self._encoding: Optional[str] = None
        self._compress: Optional[ZLibCompressor] = None
        self._encoding_buffer: Optional[bytearray] = None

    def enable_encoding(self, encoding: str) -> None:
        if encoding == "base64":
            self._encoding = encoding
            self._encoding_buffer = bytearray()
        elif encoding == "quoted-printable":
            self._encoding = "quoted-printable"

    def enable_compression(
        self, encoding: str = "deflate", strategy: Optional[int] = None
    ) -> None:
        self._compress = ZLibCompressor(
            encoding=encoding,
            suppress_deflate_header=True,
            strategy=strategy,
        )

    async def write_eof(self) -> None:
        if self._compress is not None:
            chunk = self._compress.flush()
            if chunk:
                self._compress = None
                await self.write(chunk)

        if self._encoding == "base64":
            if self._encoding_buffer:
                await self._writer.write(base64.b64encode(self._encoding_buffer))

    async def write(self, chunk: bytes) -> None:
        if self._compress is not None:
            if chunk:
                chunk = await self._compress.compress(chunk)
                if not chunk:
                    return

        if self._encoding == "base64":
            buf = self._encoding_buffer
            assert buf is not None
            buf.extend(chunk)

            if buf:
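                # base64 maps every 3 input bytes to 4 output characters, so emit
                # only complete 3-byte groups now and keep the remainder buffered
                # until more data arrives or write_eof() flushes it.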
                div, mod = divmod(len(buf), 3)
                enc_chunk, self._encoding_buffer = (buf[: div * 3], buf[div * 3 :])
                if enc_chunk:
                    b64chunk = base64.b64encode(enc_chunk)
                    await self._writer.write(b64chunk)
        elif self._encoding == "quoted-printable":
            await self._writer.write(binascii.b2a_qp(chunk))
        else:
            await self._writer.write(chunk)