Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/multipart.py: 18%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import base64
2import binascii
3import json
4import re
5import sys
6import uuid
7import warnings
8from collections import deque
9from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
10from types import TracebackType
11from typing import TYPE_CHECKING, Any, Union, cast
12from urllib.parse import parse_qsl, unquote, urlencode
14from multidict import CIMultiDict, CIMultiDictProxy
16from .abc import AbstractStreamWriter
17from .compression_utils import (
18 DEFAULT_MAX_DECOMPRESS_SIZE,
19 ZLibCompressor,
20 ZLibDecompressor,
21)
22from .hdrs import (
23 CONTENT_DISPOSITION,
24 CONTENT_ENCODING,
25 CONTENT_LENGTH,
26 CONTENT_TRANSFER_ENCODING,
27 CONTENT_TYPE,
28)
29from .helpers import CHAR, TOKEN, parse_mimetype, reify
30from .http import HeadersParser
31from .http_exceptions import BadHttpMessage
32from .log import internal_logger
33from .payload import (
34 JsonPayload,
35 LookupError,
36 Order,
37 Payload,
38 StringPayload,
39 get_payload,
40 payload_type,
41)
42from .streams import StreamReader
if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing import TypeVar

    # typing.Self is 3.11+; approximate it with a bound TypeVar on older runtimes.
    Self = TypeVar("Self", bound="BodyPartReader")
# Names that form the public API of this module.
__all__ = (
    "MultipartReader",
    "MultipartWriter",
    "BodyPartReader",
    "BadContentDispositionHeader",
    "BadContentDispositionParam",
    "parse_content_disposition",
    "content_disposition_filename",
)
62if TYPE_CHECKING:
63 from .client_reqrep import ClientResponse
class BadContentDispositionHeader(RuntimeWarning):
    """Warning emitted when a Content-Disposition header is malformed."""
class BadContentDispositionParam(RuntimeWarning):
    """Warning emitted when a single Content-Disposition parameter is malformed."""
def parse_content_disposition(
    header: str | None,
) -> tuple[str | None, dict[str, str]]:
    """Parse a Content-Disposition header into ``(disposition_type, params)``.

    Handles plain token values, quoted strings, RFC 2231 continuations
    ("name*0", "name*1*", ...) and RFC 5987 extended parameters ("name*").
    A malformed header emits a BadContentDispositionHeader warning and
    returns ``(None, {})``; a single malformed parameter emits
    BadContentDispositionParam and is skipped.
    """

    def is_token(string: str) -> bool:
        # Non-empty and composed exclusively of token characters.
        return bool(string) and TOKEN >= set(string)

    def is_quoted(string: str) -> bool:
        return string[0] == string[-1] == '"'

    def is_rfc5987(string: str) -> bool:
        # RFC 5987 ext-value: charset'language'percent-encoded-value
        # (exactly two single quotes).
        return is_token(string) and string.count("'") == 2

    def is_extended_param(string: str) -> bool:
        # Extended parameter names end with "*", e.g. "filename*".
        return string.endswith("*")

    def is_continuous_param(string: str) -> bool:
        # RFC 2231 continuation names: "name*0", "name*1*", ...
        # (digits after the "*", optionally followed by a trailing "*").
        pos = string.find("*") + 1
        if not pos:
            return False
        substring = string[pos:-1] if string.endswith("*") else string[pos:]
        return substring.isdigit()

    def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str:
        # Undo quoted-pair escaping: "\X" -> "X" for any CHAR.
        return re.sub(f"\\\\([{chars}])", "\\1", text)

    if not header:
        return None, {}

    disptype, *parts = header.split(";")
    if not is_token(disptype):
        warnings.warn(BadContentDispositionHeader(header))
        return None, {}

    params: dict[str, str] = {}
    while parts:
        item = parts.pop(0)

        if not item:  # To handle trailing semicolons
            warnings.warn(BadContentDispositionHeader(header))
            continue

        if "=" not in item:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        key, value = item.split("=", 1)
        key = key.lower().strip()
        value = value.lstrip()

        if key in params:
            # Duplicate parameter names make the header ambiguous.
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        if not is_token(key):
            warnings.warn(BadContentDispositionParam(item))
            continue

        elif is_continuous_param(key):
            if is_quoted(value):
                value = unescape(value[1:-1])
            elif not is_token(value):
                warnings.warn(BadContentDispositionParam(item))
                continue

        elif is_extended_param(key):
            if is_rfc5987(value):
                encoding, _, value = value.split("'", 2)
                # Empty charset in the ext-value defaults to UTF-8.
                encoding = encoding or "utf-8"
            else:
                warnings.warn(BadContentDispositionParam(item))
                continue

            try:
                value = unquote(value, encoding, "strict")
            except UnicodeDecodeError:  # pragma: nocover
                warnings.warn(BadContentDispositionParam(item))
                continue

        else:
            failed = True
            if is_quoted(value):
                failed = False
                value = unescape(value[1:-1].lstrip("\\/"))
            elif is_token(value):
                failed = False
            elif parts:
                # maybe just ; in filename, in any case this is just
                # one case fix, for proper fix we need to redesign parser
                _value = f"{value};{parts[0]}"
                if is_quoted(_value):
                    parts.pop(0)
                    value = unescape(_value[1:-1].lstrip("\\/"))
                    failed = False

            if failed:
                warnings.warn(BadContentDispositionHeader(header))
                return None, {}

        params[key] = value

    return disptype.lower(), params
177def content_disposition_filename(
178 params: Mapping[str, str], name: str = "filename"
179) -> str | None:
180 name_suf = "%s*" % name
181 if not params:
182 return None
183 elif name_suf in params:
184 return params[name_suf]
185 elif name in params:
186 return params[name]
187 else:
188 parts = []
189 fnparams = sorted(
190 (key, value) for key, value in params.items() if key.startswith(name_suf)
191 )
192 for num, (key, value) in enumerate(fnparams):
193 _, tail = key.split("*", 1)
194 if tail.endswith("*"):
195 tail = tail[:-1]
196 if tail == str(num):
197 parts.append(value)
198 else:
199 break
200 if not parts:
201 return None
202 value = "".join(parts)
203 if "'" in value:
204 encoding, _, value = value.split("'", 2)
205 encoding = encoding or "utf-8"
206 return unquote(value, encoding, "strict")
207 return value
210class MultipartResponseWrapper:
211 """Wrapper around the MultipartReader.
213 It takes care about
214 underlying connection and close it when it needs in.
215 """
217 def __init__(
218 self,
219 resp: "ClientResponse",
220 stream: "MultipartReader",
221 ) -> None:
222 self.resp = resp
223 self.stream = stream
225 def __aiter__(self) -> "MultipartResponseWrapper":
226 return self
228 async def __anext__(
229 self,
230 ) -> Union["MultipartReader", "BodyPartReader"]:
231 part = await self.next()
232 if part is None:
233 raise StopAsyncIteration
234 return part
236 def at_eof(self) -> bool:
237 """Returns True when all response data had been read."""
238 return self.resp.content.at_eof()
240 async def next(
241 self,
242 ) -> Union["MultipartReader", "BodyPartReader"] | None:
243 """Emits next multipart reader object."""
244 item = await self.stream.next()
245 if self.stream.at_eof():
246 await self.release()
247 return item
249 async def release(self) -> None:
250 """Release the connection gracefully.
252 All remaining content is read to the void.
253 """
254 self.resp.release()
class BodyPartReader:
    """Multipart reader for single body part."""

    # Default size, in bytes, used by read()/release() when pulling chunks.
    chunk_size = 8192

    def __init__(
        self,
        boundary: bytes,
        headers: "CIMultiDictProxy[str]",
        content: StreamReader,
        *,
        subtype: str = "mixed",
        default_charset: str | None = None,
        max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE,
    ) -> None:
        self.headers = headers
        self._boundary = boundary
        self._boundary_len = len(boundary) + 2  # Boundary + \r\n
        self._content = content
        self._default_charset = default_charset
        self._at_eof = False
        self._is_form_data = subtype == "form-data"
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        # form-data parts must not rely on Content-Length, so it is ignored.
        length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None)
        self._length = int(length) if length is not None else None
        self._read_bytes = 0
        # Lines pushed back by readline() (e.g. a peeked boundary line).
        self._unread: deque[bytes] = deque()
        # Trailing window kept between stream reads so a boundary split
        # across two reads can still be detected.
        self._prev_chunk: bytes | None = None
        # Counts consecutive reads at stream EOF; guards against busy-looping.
        self._content_eof = 0
        self._cache: dict[str, Any] = {}
        self._max_decompress_size = max_decompress_size

    def __aiter__(self) -> Self:
        return self

    async def __anext__(self) -> bytes:
        # Async iteration yields the whole decoded-less body once, then stops.
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    async def next(self) -> bytes | None:
        # Read the remaining body; None signals exhaustion to __anext__.
        item = await self.read()
        if not item:
            return None
        return item

    async def read(self, *, decode: bool = False) -> bytes:
        """Reads body part data.

        decode: Decodes data following by encoding
            method from Content-Encoding header. If it missed
            data remains untouched
        """
        if self._at_eof:
            return b""
        data = bytearray()
        while not self._at_eof:
            data.extend(await self.read_chunk(self.chunk_size))
        # https://github.com/python/mypy/issues/17537
        if decode:  # type: ignore[unreachable]
            decoded_data = bytearray()
            async for d in self.decode_iter(data):
                decoded_data.extend(d)
            return decoded_data
        return data

    async def read_chunk(self, size: int = chunk_size) -> bytes:
        """Reads body part content chunk of the specified size.

        size: chunk size
        """
        if self._at_eof:
            return b""
        if self._length:
            chunk = await self._read_chunk_from_length(size)
        else:
            chunk = await self._read_chunk_from_stream(size)

        # For the case of base64 data, we must read a fragment of size with a
        # remainder of 0 by dividing by 4 for string without symbols \n or \r
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING)
        if encoding and encoding.lower() == "base64":
            stripped_chunk = b"".join(chunk.split())
            remainder = len(stripped_chunk) % 4

            # Top up the chunk until the base64 payload length is a multiple
            # of 4 so it can be decoded without padding errors.
            while remainder != 0 and not self.at_eof():
                over_chunk_size = 4 - remainder
                over_chunk = b""

                if self._prev_chunk:
                    # Prefer bytes already buffered from the previous read.
                    over_chunk = self._prev_chunk[:over_chunk_size]
                    self._prev_chunk = self._prev_chunk[len(over_chunk) :]

                if len(over_chunk) != over_chunk_size:
                    over_chunk += await self._content.read(4 - len(over_chunk))

                if not over_chunk:
                    self._at_eof = True

                stripped_chunk += b"".join(over_chunk.split())
                chunk += over_chunk
                remainder = len(stripped_chunk) % 4

        self._read_bytes += len(chunk)
        if self._read_bytes == self._length:
            self._at_eof = True
        if self._at_eof and await self._content.readline() != b"\r\n":
            # The part body must be followed by exactly one CRLF before the
            # next boundary line.
            raise ValueError("Reader did not read all the data or it is malformed")
        return chunk

    async def _read_chunk_from_length(self, size: int) -> bytes:
        # Reads body part content chunk of the specified size.
        # The body part must has Content-Length header with proper value.
        assert self._length is not None, "Content-Length required for chunked read"
        # Never read past the declared part length.
        chunk_size = min(size, self._length - self._read_bytes)
        chunk = await self._content.read(chunk_size)
        if self._content.at_eof():
            self._at_eof = True
        return chunk

    async def _read_chunk_from_stream(self, size: int) -> bytes:
        # Reads content chunk of body part with unknown length.
        # The Content-Length header for body part is not necessary.
        assert (
            size >= self._boundary_len
        ), "Chunk size must be greater or equal than boundary length + 2"
        first_chunk = self._prev_chunk is None
        if first_chunk:
            # We need to re-add the CRLF that got removed from headers parsing.
            self._prev_chunk = b"\r\n" + await self._content.read(size)

        chunk = b""
        # content.read() may return less than size, so we need to loop to ensure
        # we have enough data to detect the boundary.
        while len(chunk) < self._boundary_len:
            chunk += await self._content.read(size)
            self._content_eof += int(self._content.at_eof())
            if self._content_eof > 2:
                raise ValueError("Reading after EOF")
            if self._content_eof:
                break
        if len(chunk) > size:
            # Keep at most `size` bytes; push the surplus back to the stream.
            self._content.unread_data(chunk[size:])
            chunk = chunk[:size]

        assert self._prev_chunk is not None
        # Search the previous chunk plus the new one so a boundary that
        # straddles the read boundary is still found.
        window = self._prev_chunk + chunk
        sub = b"\r\n" + self._boundary
        if first_chunk:
            idx = window.find(sub)
        else:
            # The boundary cannot start earlier than len(prev) - len(sub):
            # anything before that was already scanned last time.
            idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
        if idx >= 0:
            # pushing boundary back to content
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)
                self._content.unread_data(window[idx:])
            self._prev_chunk = self._prev_chunk[:idx]
            chunk = window[len(self._prev_chunk) : idx]
            if not chunk:
                self._at_eof = True
        result = self._prev_chunk[2 if first_chunk else 0 :]  # Strip initial CRLF
        self._prev_chunk = chunk
        return result

    async def readline(self) -> bytes:
        """Reads body part by line by line."""
        if self._at_eof:
            return b""

        if self._unread:
            # Serve a line previously pushed back (a peeked boundary or line).
            line = self._unread.popleft()
        else:
            line = await self._content.readline()

        if line.startswith(self._boundary):
            # the very last boundary may not come with \r\n,
            # so set single rules for everyone
            sline = line.rstrip(b"\r\n")
            boundary = self._boundary
            last_boundary = self._boundary + b"--"
            # ensure that we read exactly the boundary, not something alike
            if sline == boundary or sline == last_boundary:
                self._at_eof = True
                self._unread.append(line)
                return b""
        else:
            # Peek the next line: if it is a boundary, the CRLF on this line
            # belongs to the boundary delimiter, not to the body.
            next_line = await self._content.readline()
            if next_line.startswith(self._boundary):
                line = line[:-2]  # strip CRLF but only once
                self._unread.append(next_line)

        return line

    async def release(self) -> None:
        """Like read(), but reads all the data to the void."""
        if self._at_eof:
            return
        while not self._at_eof:
            await self.read_chunk(self.chunk_size)

    async def text(self, *, encoding: str | None = None) -> str:
        """Like read(), but assumes that body part contains text data."""
        data = await self.read(decode=True)
        # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm
        # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send
        encoding = encoding or self.get_charset(default="utf-8")
        return data.decode(encoding)

    async def json(self, *, encoding: str | None = None) -> dict[str, Any] | None:
        """Like read(), but assumes that body parts contains JSON data."""
        data = await self.read(decode=True)
        if not data:
            return None
        encoding = encoding or self.get_charset(default="utf-8")
        return cast(dict[str, Any], json.loads(data.decode(encoding)))

    async def form(self, *, encoding: str | None = None) -> list[tuple[str, str]]:
        """Like read(), but assumes that body parts contain form urlencoded data."""
        data = await self.read(decode=True)
        if not data:
            return []
        if encoding is not None:
            real_encoding = encoding
        else:
            real_encoding = self.get_charset(default="utf-8")
        try:
            decoded_data = data.rstrip().decode(real_encoding)
        except UnicodeDecodeError:
            raise ValueError("data cannot be decoded with %s encoding" % real_encoding)

        return parse_qsl(
            decoded_data,
            keep_blank_values=True,
            encoding=real_encoding,
        )

    def at_eof(self) -> bool:
        """Returns True if the boundary was reached or False otherwise."""
        return self._at_eof

    def _apply_content_transfer_decoding(self, data: bytes) -> bytes:
        """Apply Content-Transfer-Encoding decoding if header is present."""
        if CONTENT_TRANSFER_ENCODING in self.headers:
            return self._decode_content_transfer(data)
        return data

    def _needs_content_decoding(self) -> bool:
        """Check if Content-Encoding decoding should be applied."""
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        # form-data parts must not use Content-Encoding, so skip it there.
        return not self._is_form_data and CONTENT_ENCODING in self.headers

    def decode(self, data: bytes) -> bytes:
        """Decodes data synchronously.

        Decodes data according the specified Content-Encoding
        or Content-Transfer-Encoding headers value.

        Note: For large payloads, consider using decode_iter() instead
        to avoid blocking the event loop during decompression.
        """
        data = self._apply_content_transfer_decoding(data)
        if self._needs_content_decoding():
            return self._decode_content(data)
        return data

    async def decode_iter(self, data: bytes) -> AsyncIterator[bytes]:
        """Async generator that yields decoded data chunks.

        Decodes data according the specified Content-Encoding
        or Content-Transfer-Encoding headers value.

        This method offloads decompression to an executor for large payloads
        to avoid blocking the event loop.
        """
        data = self._apply_content_transfer_decoding(data)
        if self._needs_content_decoding():
            async for d in self._decode_content_async(data):
                yield d
        else:
            yield data

    def _decode_content(self, data: bytes) -> bytes:
        # Synchronous Content-Encoding decoding; raises on unknown encodings.
        encoding = self.headers.get(CONTENT_ENCODING, "").lower()
        if encoding == "identity":
            return data
        if encoding in {"deflate", "gzip"}:
            # Bounded decompression guards against decompression bombs.
            return ZLibDecompressor(
                encoding=encoding,
                suppress_deflate_header=True,
            ).decompress_sync(data, max_length=self._max_decompress_size)

        raise RuntimeError(f"unknown content encoding: {encoding}")

    async def _decode_content_async(self, data: bytes) -> AsyncIterator[bytes]:
        # Async counterpart of _decode_content; yields the decoded payload.
        encoding = self.headers.get(CONTENT_ENCODING, "").lower()
        if encoding == "identity":
            yield data
        elif encoding in {"deflate", "gzip"}:
            d = ZLibDecompressor(
                encoding=encoding,
                suppress_deflate_header=True,
            )
            yield await d.decompress(data, max_length=self._max_decompress_size)
        else:
            raise RuntimeError(f"unknown content encoding: {encoding}")

    def _decode_content_transfer(self, data: bytes) -> bytes:
        # Content-Transfer-Encoding decoding; "binary"/"8bit"/"7bit" pass through.
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()

        if encoding == "base64":
            return base64.b64decode(data)
        elif encoding == "quoted-printable":
            return binascii.a2b_qp(data)
        elif encoding in ("binary", "8bit", "7bit"):
            return data
        else:
            raise RuntimeError(f"unknown content transfer encoding: {encoding}")

    def get_charset(self, default: str) -> str:
        """Returns charset parameter from Content-Type header or default."""
        ctype = self.headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)
        return mimetype.parameters.get("charset", self._default_charset or default)

    @reify
    def name(self) -> str | None:
        """Returns name specified in Content-Disposition header.

        If the header is missing or malformed, returns None.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "name")

    @reify
    def filename(self) -> str | None:
        """Returns filename specified in Content-Disposition header.

        Returns None if the header is missing or malformed.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "filename")
@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
    """Payload adapter that streams a :class:`BodyPartReader` via write()."""

    _value: BodyPartReader
    # _autoclose = False (inherited) - Streaming reader that may have resources

    def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
        super().__init__(value, *args, **kwargs)

        # Mirror the source part's name/filename onto this payload's
        # Content-Disposition, skipping whichever is absent.
        disposition = {
            key: val
            for key, val in (("name", value.name), ("filename", value.filename))
            if val is not None
        }
        if disposition:
            self.set_content_disposition("attachment", True, **disposition)

    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
        raise TypeError("Unable to decode.")

    async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
        """Raises TypeError as body parts should be consumed via write().

        This is intentional: BodyPartReader payloads are designed for streaming
        large data (potentially gigabytes) and must be consumed only once via
        the write() method to avoid memory exhaustion. They cannot be buffered
        in memory for reuse.
        """
        raise TypeError("Unable to read body part as bytes. Use write() to consume.")

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Stream the body part to *writer* in 256 KiB chunks, decoding as it goes."""
        reader = self._value
        while data := await reader.read_chunk(size=2**18):
            async for decoded in reader.decode_iter(data):
                await writer.write(decoded)
class MultipartReader:
    """Multipart body reader."""

    #: Response wrapper, used when multipart readers constructs from response.
    response_wrapper_cls = MultipartResponseWrapper
    #: Multipart reader class, used to handle multipart/* body parts.
    #: None points to type(self)
    multipart_reader_cls: type["MultipartReader"] | None = None
    #: Body part reader class for non multipart/* content types.
    part_reader_cls = BodyPartReader

    def __init__(
        self,
        headers: Mapping[str, str],
        content: StreamReader,
        *,
        max_field_size: int = 8190,
        max_headers: int = 128,
    ) -> None:
        self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
        assert self._mimetype.type == "multipart", "multipart/* content type expected"
        if "boundary" not in self._mimetype.parameters:
            raise ValueError(
                "boundary missed for Content-Type: %s" % headers[CONTENT_TYPE]
            )

        self.headers = headers
        self._boundary = ("--" + self._get_boundary()).encode()
        self._content = content
        # Charset declared by a leading "_charset_" form-data part, if any.
        self._default_charset: str | None = None
        self._last_part: MultipartReader | BodyPartReader | None = None
        self._max_field_size = max_field_size
        self._max_headers = max_headers
        self._at_eof = False
        # True until the first boundary line has been consumed.
        self._at_bof = True
        # Lines handed back from nested readers / epilogue look-ahead.
        self._unread: list[bytes] = []

    def __aiter__(self) -> Self:
        return self

    async def __anext__(
        self,
    ) -> Union["MultipartReader", BodyPartReader] | None:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    @classmethod
    def from_response(
        cls,
        response: "ClientResponse",
    ) -> MultipartResponseWrapper:
        """Constructs reader instance from HTTP response.

        :param response: :class:`~aiohttp.client.ClientResponse` instance
        """
        obj = cls.response_wrapper_cls(
            response, cls(response.headers, response.content)
        )
        return obj

    def at_eof(self) -> bool:
        """Returns True if the final boundary was reached, false otherwise."""
        return self._at_eof

    async def next(
        self,
    ) -> Union["MultipartReader", BodyPartReader] | None:
        """Emits the next multipart body part."""
        # So, if we're at BOF, we need to skip till the boundary.
        if self._at_eof:
            return None
        await self._maybe_release_last_part()
        if self._at_bof:
            await self._read_until_first_boundary()
            self._at_bof = False
        else:
            await self._read_boundary()
        if self._at_eof:  # we just read the last boundary, nothing to do there
            # https://github.com/python/mypy/issues/17537
            return None  # type: ignore[unreachable]

        part = await self.fetch_next_part()
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.6
        # A leading "_charset_" part sets the default charset for the rest
        # of the form and is itself consumed, not returned to the caller.
        if (
            self._last_part is None
            and self._mimetype.subtype == "form-data"
            and isinstance(part, BodyPartReader)
        ):
            _, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION))
            if params.get("name") == "_charset_":
                # Longest encoding in https://encoding.spec.whatwg.org/encodings.json
                # is 19 characters, so 32 should be more than enough for any valid encoding.
                charset = await part.read_chunk(32)
                if len(charset) > 31:
                    raise RuntimeError("Invalid default charset")
                self._default_charset = charset.strip().decode()
                part = await self.fetch_next_part()
        self._last_part = part
        return self._last_part

    async def release(self) -> None:
        """Reads all the body parts to the void till the final boundary."""
        while not self._at_eof:
            item = await self.next()
            if item is None:
                break
            await item.release()

    async def fetch_next_part(
        self,
    ) -> Union["MultipartReader", BodyPartReader]:
        """Returns the next body part reader."""
        headers = await self._read_headers()
        return self._get_part_reader(headers)

    def _get_part_reader(
        self,
        headers: "CIMultiDictProxy[str]",
    ) -> Union["MultipartReader", BodyPartReader]:
        """Dispatches the response by the `Content-Type` header.

        Returns a suitable reader instance.

        :param dict headers: Response headers
        """
        ctype = headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)

        if mimetype.type == "multipart":
            # Nested multipart: delegate to a reader of the configured class
            # (or this very class when none is configured).
            if self.multipart_reader_cls is None:
                return type(self)(headers, self._content)
            return self.multipart_reader_cls(
                headers,
                self._content,
                max_field_size=self._max_field_size,
                max_headers=self._max_headers,
            )
        else:
            return self.part_reader_cls(
                self._boundary,
                headers,
                self._content,
                subtype=self._mimetype.subtype,
                default_charset=self._default_charset,
            )

    def _get_boundary(self) -> str:
        boundary = self._mimetype.parameters["boundary"]
        if len(boundary) > 70:
            raise ValueError("boundary %r is too long (70 chars max)" % boundary)

        return boundary

    async def _readline(self) -> bytes:
        # Serve pushed-back lines before reading fresh data from the stream.
        if self._unread:
            return self._unread.pop()
        return await self._content.readline()

    async def _read_until_first_boundary(self) -> None:
        # Skip the preamble until the opening (or closing) boundary line.
        while True:
            chunk = await self._readline()
            if chunk == b"":
                raise ValueError(f"Could not find starting boundary {self._boundary!r}")
            chunk = chunk.rstrip()
            if chunk == self._boundary:
                return
            elif chunk == self._boundary + b"--":
                self._at_eof = True
                return

    async def _read_boundary(self) -> None:
        chunk = (await self._readline()).rstrip()
        if chunk == self._boundary:
            pass
        elif chunk == self._boundary + b"--":
            self._at_eof = True
            epilogue = await self._readline()
            next_line = await self._readline()

            # the epilogue is expected and then either the end of input or the
            # parent multipart boundary, if the parent boundary is found then
            # it should be marked as unread and handed to the parent for
            # processing
            if next_line[:2] == b"--":
                self._unread.append(next_line)
            # otherwise the request is likely missing an epilogue and both
            # lines should be passed to the parent for processing
            # (this handles the old behavior gracefully)
            else:
                self._unread.extend([next_line, epilogue])
        else:
            raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}")

    async def _read_headers(self) -> "CIMultiDictProxy[str]":
        # Collect header lines up to the blank line that ends the header block.
        lines = []
        while True:
            chunk = await self._content.readline(max_line_length=self._max_field_size)
            chunk = chunk.rstrip(b"\r\n")
            lines.append(chunk)
            if not chunk:
                break
            if len(lines) > self._max_headers:
                raise BadHttpMessage("Too many headers received")
        parser = HeadersParser(max_field_size=self._max_field_size)
        headers, raw_headers = parser.parse_headers(lines)
        return headers

    async def _maybe_release_last_part(self) -> None:
        """Ensures that the last read body part is read completely."""
        if self._last_part is not None:
            if not self._last_part.at_eof():
                await self._last_part.release()
            # Adopt any lines the part peeked past its own boundary.
            self._unread.extend(self._last_part._unread)
            self._last_part = None
857_Part = tuple[Payload, str, str]
860class MultipartWriter(Payload):
861 """Multipart body writer."""
863 _value: None
864 # _consumed = False (inherited) - Can be encoded multiple times
865 _autoclose = True # No file handles, just collects parts in memory
867 def __init__(self, subtype: str = "mixed", boundary: str | None = None) -> None:
868 boundary = boundary if boundary is not None else uuid.uuid4().hex
869 # The underlying Payload API demands a str (utf-8), not bytes,
870 # so we need to ensure we don't lose anything during conversion.
871 # As a result, require the boundary to be ASCII only.
872 # In both situations.
874 try:
875 self._boundary = boundary.encode("ascii")
876 except UnicodeEncodeError:
877 raise ValueError("boundary should contain ASCII only chars") from None
879 if len(boundary) > 70:
880 raise ValueError("boundary %r is too long (70 chars max)" % boundary)
882 ctype = f"multipart/{subtype}; boundary={self._boundary_value}"
884 super().__init__(None, content_type=ctype)
886 self._parts: list[_Part] = []
887 self._is_form_data = subtype == "form-data"
    def __enter__(self) -> "MultipartWriter":
        # Context-manager support; entering is a no-op and returns the writer.
        return self
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Nothing to clean up: parts are plain in-memory payload references.
        pass
    def __iter__(self) -> Iterator[_Part]:
        # Iterate over the stored (payload, encoding, te_encoding) triples.
        return iter(self._parts)
    def __len__(self) -> int:
        # Number of parts appended so far.
        return len(self._parts)
    def __bool__(self) -> bool:
        # Always truthy: without this, an empty writer would be falsy via
        # __len__, which is not the desired truth value for a payload.
        return True
    # Boundary made only of token characters needs no quoting (see
    # _boundary_value below for the grammar these regexes implement).
    _valid_tchar_regex = re.compile(rb"\A[!#$%&'*+\-.^_`|~\w]+\Z")
    # Control bytes that may not appear even inside a quoted string.
    _invalid_qdtext_char_regex = re.compile(rb"[\x00-\x08\x0A-\x1F\x7F]")
    @property
    def _boundary_value(self) -> str:
        """Wrap boundary parameter value in quotes, if necessary.

        Reads self.boundary and returns a unicode string.
        """
        # Refer to RFCs 7231, 7230, 5234.
        #
        # parameter      = token "=" ( token / quoted-string )
        # token          = 1*tchar
        # quoted-string  = DQUOTE *( qdtext / quoted-pair ) DQUOTE
        # qdtext         = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
        # obs-text       = %x80-FF
        # quoted-pair    = "\" ( HTAB / SP / VCHAR / obs-text )
        # tchar          = "!" / "#" / "$" / "%" / "&" / "'" / "*"
        #                  / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
        #                  / DIGIT / ALPHA
        #                  ; any VCHAR, except delimiters
        # VCHAR          = %x21-7E
        value = self._boundary
        if re.match(self._valid_tchar_regex, value):
            # Pure token: usable as-is, no quoting needed.
            return value.decode("ascii")  # cannot fail

        if re.search(self._invalid_qdtext_char_regex, value):
            raise ValueError("boundary value contains invalid characters")

        # escape %x5C and %x22
        quoted_value_content = value.replace(b"\\", b"\\\\")
        quoted_value_content = quoted_value_content.replace(b'"', b'\\"')

        return '"' + quoted_value_content.decode("ascii") + '"'
    @property
    def boundary(self) -> str:
        """The boundary as str (ASCII-only; enforced at construction time)."""
        return self._boundary.decode("ascii")
948 def append(self, obj: Any, headers: Mapping[str, str] | None = None) -> Payload:
949 if headers is None:
950 headers = CIMultiDict()
952 if isinstance(obj, Payload):
953 obj.headers.update(headers)
954 return self.append_payload(obj)
955 else:
956 try:
957 payload = get_payload(obj, headers=headers)
958 except LookupError:
959 raise TypeError("Cannot create payload from %r" % obj)
960 else:
961 return self.append_payload(payload)
    def append_payload(self, payload: Payload) -> Payload:
        """Adds a new body part to multipart writer."""
        encoding: str | None = None
        te_encoding: str | None = None
        if self._is_form_data:
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.7
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
            # form-data parts must not carry any of these headers.
            assert (
                not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING}
                & payload.headers.keys()
            )
            # Set default Content-Disposition in case user doesn't create one
            if CONTENT_DISPOSITION not in payload.headers:
                name = f"section-{len(self._parts)}"
                payload.set_content_disposition("form-data", name=name)
        else:
            # compression
            encoding = payload.headers.get(CONTENT_ENCODING, "").lower()
            if encoding and encoding not in ("deflate", "gzip", "identity"):
                raise RuntimeError(f"unknown content encoding: {encoding}")
            if encoding == "identity":
                # "identity" means no transformation; normalize to None.
                encoding = None

            # te encoding
            te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
            if te_encoding not in ("", "base64", "quoted-printable", "binary"):
                raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")
            if te_encoding == "binary":
                # "binary" is a pass-through; normalize to None.
                te_encoding = None

        # size
        # Content-Length is only meaningful when the part is written verbatim
        # (no compression and no transfer encoding applied).
        size = payload.size
        if size is not None and not (encoding or te_encoding):
            payload.headers[CONTENT_LENGTH] = str(size)

        self._parts.append((payload, encoding, te_encoding))  # type: ignore[arg-type]
        return payload
1001 def append_json(
1002 self, obj: Any, headers: Mapping[str, str] | None = None
1003 ) -> Payload:
1004 """Helper to append JSON part."""
1005 if headers is None:
1006 headers = CIMultiDict()
1008 return self.append_payload(JsonPayload(obj, headers=headers))
1010 def append_form(
1011 self,
1012 obj: Sequence[tuple[str, str]] | Mapping[str, str],
1013 headers: Mapping[str, str] | None = None,
1014 ) -> Payload:
1015 """Helper to append form urlencoded part."""
1016 assert isinstance(obj, (Sequence, Mapping))
1018 if headers is None:
1019 headers = CIMultiDict()
1021 if isinstance(obj, Mapping):
1022 obj = list(obj.items())
1023 data = urlencode(obj, doseq=True)
1025 return self.append_payload(
1026 StringPayload(
1027 data, headers=headers, content_type="application/x-www-form-urlencoded"
1028 )
1029 )
1031 @property
1032 def size(self) -> int | None:
1033 """Size of the payload."""
1034 total = 0
1035 for part, encoding, te_encoding in self._parts:
1036 part_size = part.size
1037 if encoding or te_encoding or part_size is None:
1038 return None
1040 total += int(
1041 2
1042 + len(self._boundary)
1043 + 2
1044 + part_size # b'--'+self._boundary+b'\r\n'
1045 + len(part._binary_headers)
1046 + 2 # b'\r\n'
1047 )
1049 total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n'
1050 return total
1052 def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
1053 """Return string representation of the multipart data.
1055 WARNING: This method may do blocking I/O if parts contain file payloads.
1056 It should not be called in the event loop. Use as_bytes().decode() instead.
1057 """
1058 return "".join(
1059 "--"
1060 + self.boundary
1061 + "\r\n"
1062 + part._binary_headers.decode(encoding, errors)
1063 + part.decode()
1064 for part, _e, _te in self._parts
1065 )
1067 async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
1068 """Return bytes representation of the multipart data.
1070 This method is async-safe and calls as_bytes on underlying payloads.
1071 """
1072 parts: list[bytes] = []
1074 # Process each part
1075 for part, _e, _te in self._parts:
1076 # Add boundary
1077 parts.append(b"--" + self._boundary + b"\r\n")
1079 # Add headers
1080 parts.append(part._binary_headers)
1082 # Add payload content using as_bytes for async safety
1083 part_bytes = await part.as_bytes(encoding, errors)
1084 parts.append(part_bytes)
1086 # Add trailing CRLF
1087 parts.append(b"\r\n")
1089 # Add closing boundary
1090 parts.append(b"--" + self._boundary + b"--\r\n")
1092 return b"".join(parts)
1094 async def write(
1095 self, writer: AbstractStreamWriter, close_boundary: bool = True
1096 ) -> None:
1097 """Write body."""
1098 for part, encoding, te_encoding in self._parts:
1099 if self._is_form_data:
1100 # https://datatracker.ietf.org/doc/html/rfc7578#section-4.2
1101 assert CONTENT_DISPOSITION in part.headers
1102 assert "name=" in part.headers[CONTENT_DISPOSITION]
1104 await writer.write(b"--" + self._boundary + b"\r\n")
1105 await writer.write(part._binary_headers)
1107 if encoding or te_encoding:
1108 w = MultipartPayloadWriter(writer)
1109 if encoding:
1110 w.enable_compression(encoding)
1111 if te_encoding:
1112 w.enable_encoding(te_encoding)
1113 await part.write(w) # type: ignore[arg-type]
1114 await w.write_eof()
1115 else:
1116 await part.write(writer)
1118 await writer.write(b"\r\n")
1120 if close_boundary:
1121 await writer.write(b"--" + self._boundary + b"--\r\n")
1123 async def close(self) -> None:
1124 """
1125 Close all part payloads that need explicit closing.
1127 IMPORTANT: This method must not await anything that might not finish
1128 immediately, as it may be called during cleanup/cancellation. Schedule
1129 any long-running operations without awaiting them.
1130 """
1131 if self._consumed:
1132 return
1133 self._consumed = True
1135 # Close all parts that need explicit closing
1136 # We catch and log exceptions to ensure all parts get a chance to close
1137 # we do not use asyncio.gather() here because we are not allowed
1138 # to suspend given we may be called during cleanup
1139 for idx, (part, _, _) in enumerate(self._parts):
1140 if not part.autoclose and not part.consumed:
1141 try:
1142 await part.close()
1143 except Exception as exc:
1144 internal_logger.error(
1145 "Failed to close multipart part %d: %s", idx, exc, exc_info=True
1146 )
1149class MultipartPayloadWriter:
1150 def __init__(self, writer: AbstractStreamWriter) -> None:
1151 self._writer = writer
1152 self._encoding: str | None = None
1153 self._compress: ZLibCompressor | None = None
1154 self._encoding_buffer: bytearray | None = None
1156 def enable_encoding(self, encoding: str) -> None:
1157 if encoding == "base64":
1158 self._encoding = encoding
1159 self._encoding_buffer = bytearray()
1160 elif encoding == "quoted-printable":
1161 self._encoding = "quoted-printable"
1163 def enable_compression(
1164 self, encoding: str = "deflate", strategy: int | None = None
1165 ) -> None:
1166 self._compress = ZLibCompressor(
1167 encoding=encoding,
1168 suppress_deflate_header=True,
1169 strategy=strategy,
1170 )
1172 async def write_eof(self) -> None:
1173 if self._compress is not None:
1174 chunk = self._compress.flush()
1175 if chunk:
1176 self._compress = None
1177 await self.write(chunk)
1179 if self._encoding == "base64":
1180 if self._encoding_buffer:
1181 await self._writer.write(base64.b64encode(self._encoding_buffer))
1183 async def write(self, chunk: bytes) -> None:
1184 if self._compress is not None:
1185 if chunk:
1186 chunk = await self._compress.compress(chunk)
1187 if not chunk:
1188 return
1190 if self._encoding == "base64":
1191 buf = self._encoding_buffer
1192 assert buf is not None
1193 buf.extend(chunk)
1195 if buf:
1196 div, mod = divmod(len(buf), 3)
1197 enc_chunk, self._encoding_buffer = (buf[: div * 3], buf[div * 3 :])
1198 if enc_chunk:
1199 b64chunk = base64.b64encode(enc_chunk)
1200 await self._writer.write(b64chunk)
1201 elif self._encoding == "quoted-printable":
1202 await self._writer.write(binascii.b2a_qp(chunk))
1203 else:
1204 await self._writer.write(chunk)