Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/multipart.py: 18%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import base64
2import binascii
3import json
4import re
5import sys
6import uuid
7import warnings
8from collections import deque
9from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
10from types import TracebackType
11from typing import TYPE_CHECKING, Any, Union, cast
12from urllib.parse import parse_qsl, unquote, urlencode
14from multidict import CIMultiDict
16from .abc import AbstractStreamWriter
17from .compression_utils import ZLibCompressor, ZLibDecompressor
18from .hdrs import (
19 CONTENT_DISPOSITION,
20 CONTENT_ENCODING,
21 CONTENT_LENGTH,
22 CONTENT_TRANSFER_ENCODING,
23 CONTENT_TYPE,
24)
25from .helpers import (
26 CHAR,
27 DEFAULT_CHUNK_SIZE,
28 TOKEN,
29 HeadersDictProxy,
30 parse_mimetype,
31 reify,
32)
33from .http import HeadersParser
34from .http_exceptions import BadHttpMessage
35from .log import internal_logger
36from .payload import (
37 JsonPayload,
38 LookupError,
39 Order,
40 Payload,
41 StringPayload,
42 get_payload,
43 payload_type,
44)
45from .streams import StreamReader
47if sys.version_info >= (3, 11):
48 from typing import Self
49else:
50 from typing import TypeVar
52 Self = TypeVar("Self", bound="BodyPartReader")
# Public names exported by ``from aiohttp.multipart import *``.
__all__ = (
    "MultipartReader",
    "MultipartWriter",
    "BodyPartReader",
    "BadContentDispositionHeader",
    "BadContentDispositionParam",
    "parse_content_disposition",
    "content_disposition_filename",
)
65if TYPE_CHECKING:
66 from .client_reqrep import ClientResponse
class BadContentDispositionHeader(RuntimeWarning):
    """Warning category for a malformed Content-Disposition header as a whole."""
class BadContentDispositionParam(RuntimeWarning):
    """Warning category for a single unusable Content-Disposition parameter."""
def parse_content_disposition(
    header: str | None,
) -> tuple[str | None, dict[str, str]]:
    """Parse a Content-Disposition header value.

    :param header: raw header value, or None.
    :returns: ``(disposition_type, params)``. The type is lower-cased and
        parameter names are lower-cased and stripped. ``(None, {})`` is
        returned for a missing header or when the whole header is rejected.

    Problems are reported as warnings, never raised:

    * :class:`BadContentDispositionHeader` — the header is malformed; in most
      cases parsing stops and ``(None, {})`` is returned (an empty segment from
      a trailing ``;`` is only warned about and skipped).
    * :class:`BadContentDispositionParam` — one parameter is unusable and is
      skipped; the rest are still returned.

    Extended parameters (``name*=charset'lang'pct-encoded``) are
    percent-decoded with the declared charset (UTF-8 if empty). Numbered
    continuation parameters (``name*0``, ``name*1*``, ...) are stored as-is so
    :func:`content_disposition_filename` can reassemble them.
    """
    # This module imports ``LookupError`` from ``.payload``, which shadows the
    # builtin; the codec machinery raises the *builtin* one for unknown charsets.
    import builtins

    def is_token(string: str) -> bool:
        return bool(string) and TOKEN >= set(string)

    def is_quoted(string: str) -> bool:
        # Require both quote characters: an empty value (``filename=``) would
        # otherwise raise IndexError, and a lone '"' is not a quoted string.
        return len(string) >= 2 and string[0] == string[-1] == '"'

    def is_rfc5987(string: str) -> bool:
        return is_token(string) and string.count("'") == 2

    def is_extended_param(string: str) -> bool:
        return string.endswith("*")

    def is_continuous_param(string: str) -> bool:
        pos = string.find("*") + 1
        if not pos:
            return False
        substring = string[pos:-1] if string.endswith("*") else string[pos:]
        return substring.isdigit()

    def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str:
        # Undo quoted-pair escaping (backslash followed by any CHAR).
        return re.sub(f"\\\\([{chars}])", "\\1", text)

    if not header:
        return None, {}

    disptype, *parts = header.split(";")
    if not is_token(disptype):
        warnings.warn(BadContentDispositionHeader(header))
        return None, {}

    params: dict[str, str] = {}
    while parts:
        item = parts.pop(0)

        if not item:  # To handle trailing semicolons
            warnings.warn(BadContentDispositionHeader(header))
            continue

        if "=" not in item:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        key, value = item.split("=", 1)
        key = key.lower().strip()
        value = value.lstrip()

        if key in params:
            # Duplicate parameter names make the whole header ambiguous.
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        if not is_token(key):
            warnings.warn(BadContentDispositionParam(item))
            continue

        elif is_continuous_param(key):
            if is_quoted(value):
                value = unescape(value[1:-1])
            elif not is_token(value):
                warnings.warn(BadContentDispositionParam(item))
                continue

        elif is_extended_param(key):
            if is_rfc5987(value):
                encoding, _, value = value.split("'", 2)
                encoding = encoding or "utf-8"
            else:
                warnings.warn(BadContentDispositionParam(item))
                continue

            try:
                value = unquote(value, encoding, "strict")
            except (UnicodeDecodeError, builtins.LookupError):
                # Bytes invalid for the declared charset, or a charset name the
                # codec registry does not know (attacker-controlled input).
                warnings.warn(BadContentDispositionParam(item))
                continue

        else:
            failed = True
            if is_quoted(value):
                failed = False
                # Leading slashes/backslashes are dropped from quoted values.
                value = unescape(value[1:-1].lstrip("\\/"))
            elif is_token(value):
                failed = False
            elif parts:
                # maybe just ; in filename, in any case this is just
                # one case fix, for proper fix we need to redesign parser
                _value = f"{value};{parts[0]}"
                if is_quoted(_value):
                    parts.pop(0)
                    value = unescape(_value[1:-1].lstrip("\\/"))
                    failed = False

            if failed:
                warnings.warn(BadContentDispositionHeader(header))
                return None, {}

        params[key] = value

    return disptype.lower(), params
180def content_disposition_filename(
181 params: Mapping[str, str], name: str = "filename"
182) -> str | None:
183 name_suf = "%s*" % name
184 if not params:
185 return None
186 elif name_suf in params:
187 return params[name_suf]
188 elif name in params:
189 return params[name]
190 else:
191 parts = []
192 fnparams = sorted(
193 (key, value) for key, value in params.items() if key.startswith(name_suf)
194 )
195 for num, (key, value) in enumerate(fnparams):
196 _, tail = key.split("*", 1)
197 if tail.endswith("*"):
198 tail = tail[:-1]
199 if tail == str(num):
200 parts.append(value)
201 else:
202 break
203 if not parts:
204 return None
205 value = "".join(parts)
206 if "'" in value:
207 encoding, _, value = value.split("'", 2)
208 encoding = encoding or "utf-8"
209 return unquote(value, encoding, "strict")
210 return value
213class MultipartResponseWrapper:
214 """Wrapper around the MultipartReader.
216 It takes care about
217 underlying connection and close it when it needs in.
218 """
220 def __init__(
221 self,
222 resp: "ClientResponse",
223 stream: "MultipartReader",
224 ) -> None:
225 self.resp = resp
226 self.stream = stream
228 def __aiter__(self) -> "MultipartResponseWrapper":
229 return self
231 async def __anext__(
232 self,
233 ) -> Union["MultipartReader", "BodyPartReader"]:
234 part = await self.next()
235 if part is None:
236 raise StopAsyncIteration
237 return part
239 def at_eof(self) -> bool:
240 """Returns True when all response data had been read."""
241 return self.resp.content.at_eof()
243 async def next(
244 self,
245 ) -> Union["MultipartReader", "BodyPartReader"] | None:
246 """Emits next multipart reader object."""
247 item = await self.stream.next()
248 if self.stream.at_eof():
249 await self.release()
250 return item
252 async def release(self) -> None:
253 """Release the connection gracefully.
255 All remaining content is read to the void.
256 """
257 self.resp.release()
class BodyPartReader:
    """Multipart reader for single body part.

    Reads one part's body from the ``content`` stream shared with the parent
    :class:`MultipartReader`. When the part has a Content-Length header (which
    is ignored for ``multipart/form-data``), exactly that many bytes are read.
    Otherwise the stream is scanned for CRLF + ``boundary``, and the boundary
    is pushed back onto the stream for the parent reader.
    """

    # Default chunk size used by read() and release().
    chunk_size = 8192

    def __init__(
        self,
        boundary: bytes,
        headers: HeadersDictProxy,
        content: StreamReader,
        *,
        subtype: str = "mixed",
        default_charset: str | None = None,
        max_decompress_size: int = DEFAULT_CHUNK_SIZE,
        client_max_size: int = sys.maxsize,
        max_size_error_cls: type[Exception] = ValueError,
    ) -> None:
        """
        :param boundary: full delimiter bytes (as built by MultipartReader,
            i.e. including the leading ``--``).
        :param headers: this part's parsed headers.
        :param content: stream shared with the parent reader.
        :param subtype: parent multipart subtype; ``"form-data"`` disables
            Content-Length and Content-Encoding handling (RFC 7578 §4.8).
        :param default_charset: fallback charset (e.g. from a ``_charset_``
            form field) used by get_charset().
        :param max_decompress_size: ``max_length`` passed to the decompressor.
        :param client_max_size: limit on read() output, raw and decoded.
        :param max_size_error_cls: exception type raised when that limit is
            exceeded.
        """
        self.headers = headers
        self._boundary = boundary
        self._boundary_len = len(boundary) + 2  # Boundary + \r\n
        self._content = content
        self._default_charset = default_charset
        self._at_eof = False
        self._is_form_data = subtype == "form-data"
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None)
        self._length = int(length) if length is not None else None
        # Bytes returned by read_chunk() so far; compared with _length for EOF.
        self._read_bytes = 0
        # Lines pushed back by readline(); the parent reader takes them over
        # after the part is released.
        self._unread: deque[bytes] = deque()
        # Look-behind buffer for boundary detection in _read_chunk_from_stream().
        self._prev_chunk: bytes | None = None
        # How many reads in _read_chunk_from_stream() observed content EOF.
        self._content_eof = 0
        # NOTE(review): presumably the per-instance cache used by @reify
        # (name/filename) — confirm against helpers.reify.
        self._cache: dict[str, Any] = {}
        self._max_decompress_size = max_decompress_size
        self._client_max_size = client_max_size
        self._max_size_error_cls = max_size_error_cls

    def __aiter__(self) -> Self:
        return self

    async def __anext__(self) -> bytes:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    async def next(self) -> bytes | None:
        """Read the rest of the body via read(); None if nothing was left."""
        item = await self.read()
        if not item:
            return None
        return item

    async def read(self, *, decode: bool = False) -> bytes:
        """Read the whole remaining body of this part.

        :param decode: if True, run the data through decode_iter(), i.e.
            Content-Transfer-Encoding decoding and (outside form-data)
            Content-Encoding decompression. Otherwise data is returned raw.
        :returns: ``b""`` if already at EOF, otherwise a ``bytearray``.
        :raises: ``max_size_error_cls(client_max_size)`` when the raw or the
            decoded body grows past ``client_max_size``.
        """
        if self._at_eof:
            return b""
        data = bytearray()
        while not self._at_eof:
            data.extend(await self.read_chunk(self.chunk_size))
            if len(data) > self._client_max_size:
                raise self._max_size_error_cls(self._client_max_size)
        # https://github.com/python/mypy/issues/17537
        if decode:  # type: ignore[unreachable]
            decoded_data = bytearray()
            async for d in self.decode_iter(data):
                decoded_data.extend(d)
                if len(decoded_data) > self._client_max_size:
                    raise self._max_size_error_cls(self._client_max_size)
            return decoded_data
        return data

    async def read_chunk(self, size: int = chunk_size) -> bytes:
        """Read the next raw chunk of the body (about *size* bytes).

        Uses Content-Length when known, otherwise scans for the boundary.
        For base64 transfer encoding the chunk is extended so its
        non-whitespace length is a multiple of 4.

        :raises ValueError: if the line following the body is not a bare CRLF.
        """
        if self._at_eof:
            return b""
        if self._length:
            chunk = await self._read_chunk_from_length(size)
        else:
            chunk = await self._read_chunk_from_stream(size)

        # For the case of base64 data, we must read a fragment of size with a
        # remainder of 0 by dividing by 4 for string without symbols \n or \r
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING)
        if encoding and encoding.lower() == "base64":
            stripped_chunk = b"".join(chunk.split())
            remainder = len(stripped_chunk) % 4

            while remainder != 0 and not self.at_eof():
                over_chunk_size = 4 - remainder
                over_chunk = b""

                # Take bytes from the look-behind buffer first, then the stream.
                if self._prev_chunk:
                    over_chunk = self._prev_chunk[:over_chunk_size]
                    self._prev_chunk = self._prev_chunk[len(over_chunk) :]

                if len(over_chunk) != over_chunk_size:
                    over_chunk += await self._content.read(4 - len(over_chunk))

                if not over_chunk:
                    self._at_eof = True

                stripped_chunk += b"".join(over_chunk.split())
                chunk += over_chunk
                remainder = len(stripped_chunk) % 4

        self._read_bytes += len(chunk)
        if self._read_bytes == self._length:
            self._at_eof = True
        # Once the body ends, the CRLF before the next delimiter must follow.
        if self._at_eof and await self._content.readline() != b"\r\n":
            raise ValueError("Reader did not read all the data or it is malformed")
        return chunk

    async def _read_chunk_from_length(self, size: int) -> bytes:
        # Reads body part content chunk of the specified size.
        # The body part must has Content-Length header with proper value.
        assert self._length is not None, "Content-Length required for chunked read"
        chunk_size = min(size, self._length - self._read_bytes)
        chunk = await self._content.read(chunk_size)
        if self._content.at_eof():
            self._at_eof = True
        return chunk

    async def _read_chunk_from_stream(self, size: int) -> bytes:
        # Reads content chunk of body part with unknown length.
        # The Content-Length header for body part is not necessary.
        #
        # Output lags one read behind: the previous read is kept in
        # _prev_chunk and searched together with the new read, so a
        # CRLF + boundary split across two reads is still found.
        assert (
            size >= self._boundary_len
        ), "Chunk size must be greater or equal than boundary length + 2"
        first_chunk = self._prev_chunk is None
        if first_chunk:
            # We need to re-add the CRLF that got removed from headers parsing.
            self._prev_chunk = b"\r\n" + await self._content.read(size)

        chunk = b""
        # content.read() may return less than size, so we need to loop to ensure
        # we have enough data to detect the boundary.
        while len(chunk) < self._boundary_len:
            chunk += await self._content.read(size)
            self._content_eof += int(self._content.at_eof())
            if self._content_eof > 2:
                raise ValueError("Reading after EOF")
            if self._content_eof:
                break
        if len(chunk) > size:
            self._content.unread_data(chunk[size:])
            chunk = chunk[:size]

        assert self._prev_chunk is not None
        window = self._prev_chunk + chunk
        sub = b"\r\n" + self._boundary
        if first_chunk:
            idx = window.find(sub)
        else:
            # The previous window was already searched; only the tail that
            # could start a straddling delimiter needs rechecking.
            idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
        if idx >= 0:
            # pushing boundary back to content
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)
                self._content.unread_data(window[idx:])
            self._prev_chunk = self._prev_chunk[:idx]
            chunk = window[len(self._prev_chunk) : idx]
            if not chunk:
                self._at_eof = True
        result = self._prev_chunk[2 if first_chunk else 0 :]  # Strip initial CRLF
        self._prev_chunk = chunk
        return result

    async def readline(self) -> bytes:
        """Reads body part by line by line.

        Returns ``b""`` once the boundary line is reached. The boundary line is
        kept in ``_unread`` for the parent reader. The CRLF that precedes a
        boundary is stripped from the last body line.
        """
        if self._at_eof:
            return b""

        if self._unread:
            line = self._unread.popleft()
        else:
            line = await self._content.readline()

        if line.startswith(self._boundary):
            # the very last boundary may not come with \r\n,
            # so set single rules for everyone
            sline = line.rstrip(b"\r\n")
            boundary = self._boundary
            last_boundary = self._boundary + b"--"
            # ensure that we read exactly the boundary, not something alike
            if sline == boundary or sline == last_boundary:
                self._at_eof = True
                self._unread.append(line)
                return b""
        else:
            # Peek one line ahead: if it is the boundary, the CRLF ending this
            # line belongs to the delimiter, not to the body.
            next_line = await self._content.readline()
            if next_line.startswith(self._boundary):
                line = line[:-2]  # strip CRLF but only once
            self._unread.append(next_line)

        return line

    async def release(self) -> None:
        """Like read(), but reads all the data to the void."""
        if self._at_eof:
            return
        while not self._at_eof:
            await self.read_chunk(self.chunk_size)

    async def text(self, *, encoding: str | None = None) -> str:
        """Like read(), but assumes that body part contains text data.

        Decodes with *encoding*, else get_charset(default="utf-8").
        """
        data = await self.read(decode=True)
        # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm
        # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send
        encoding = encoding or self.get_charset(default="utf-8")
        return data.decode(encoding)

    async def json(self, *, encoding: str | None = None) -> dict[str, Any] | None:
        """Like read(), but assumes that body parts contains JSON data.

        Returns None for an empty body.
        """
        data = await self.read(decode=True)
        if not data:
            return None
        encoding = encoding or self.get_charset(default="utf-8")
        return cast(dict[str, Any], json.loads(data.decode(encoding)))

    async def form(self, *, encoding: str | None = None) -> list[tuple[str, str]]:
        """Like read(), but assumes that body parts contain form urlencoded data.

        Returns ``[]`` for an empty body; blank values are kept.

        :raises ValueError: if the body cannot be decoded with the charset.
        """
        data = await self.read(decode=True)
        if not data:
            return []
        if encoding is not None:
            real_encoding = encoding
        else:
            real_encoding = self.get_charset(default="utf-8")
        try:
            decoded_data = data.rstrip().decode(real_encoding)
        except UnicodeDecodeError:
            raise ValueError("data cannot be decoded with %s encoding" % real_encoding)

        return parse_qsl(
            decoded_data,
            keep_blank_values=True,
            encoding=real_encoding,
        )

    def at_eof(self) -> bool:
        """Returns True if the boundary was reached or False otherwise."""
        return self._at_eof

    def _apply_content_transfer_decoding(self, data: bytes) -> bytes:
        """Apply Content-Transfer-Encoding decoding if header is present."""
        if CONTENT_TRANSFER_ENCODING in self.headers:
            return self._decode_content_transfer(data)
        return data

    def _needs_content_decoding(self) -> bool:
        """Check if Content-Encoding decoding should be applied."""
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        return not self._is_form_data and CONTENT_ENCODING in self.headers

    def decode(self, data: bytes) -> bytes:
        """Decodes data synchronously.

        Decodes data according the specified Content-Encoding
        or Content-Transfer-Encoding headers value.

        Note: For large payloads, consider using decode_iter() instead
        to avoid blocking the event loop during decompression.
        """
        data = self._apply_content_transfer_decoding(data)
        if self._needs_content_decoding():
            return self._decode_content(data)
        return data

    async def decode_iter(self, data: bytes) -> AsyncIterator[bytes]:
        """Async generator that yields decoded data chunks.

        Decodes data according the specified Content-Encoding
        or Content-Transfer-Encoding headers value.

        This method offloads decompression to an executor for large payloads
        to avoid blocking the event loop.
        """
        data = self._apply_content_transfer_decoding(data)
        if self._needs_content_decoding():
            async for d in self._decode_content_async(data):
                yield d
        else:
            yield data

    def _decode_content(self, data: bytes) -> bytes:
        # Synchronous Content-Encoding decoding. Output is capped at
        # max_decompress_size via the decompressor's max_length.
        encoding = self.headers.get(CONTENT_ENCODING, "").lower()
        if encoding == "identity":
            return data
        if encoding in {"deflate", "gzip"}:
            return ZLibDecompressor(
                encoding=encoding,
                suppress_deflate_header=True,
            ).decompress_sync(data, max_length=self._max_decompress_size)

        raise RuntimeError(f"unknown content encoding: {encoding}")

    async def _decode_content_async(self, data: bytes) -> AsyncIterator[bytes]:
        # Async Content-Encoding decoding: yields pieces of at most
        # max_decompress_size until the decompressor has no data left.
        encoding = self.headers.get(CONTENT_ENCODING, "").lower()
        if encoding == "identity":
            yield data
        elif encoding in {"deflate", "gzip"}:
            d = ZLibDecompressor(
                encoding=encoding,
                suppress_deflate_header=True,
            )
            yield await d.decompress(data, max_length=self._max_decompress_size)
            while d.data_available:
                yield await d.decompress(b"", max_length=self._max_decompress_size)
        else:
            raise RuntimeError(f"unknown content encoding: {encoding}")

    def _decode_content_transfer(self, data: bytes) -> bytes:
        # base64 / quoted-printable are decoded; binary, 8bit and 7bit are
        # identity encodings and pass through unchanged.
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()

        if encoding == "base64":
            return base64.b64decode(data)
        elif encoding == "quoted-printable":
            return binascii.a2b_qp(data)
        elif encoding in ("binary", "8bit", "7bit"):
            return data
        else:
            raise RuntimeError(f"unknown content transfer encoding: {encoding}")

    def get_charset(self, default: str) -> str:
        """Returns charset parameter from Content-Type header or default.

        Falls back to the reader's default_charset first, then *default*.
        """
        ctype = self.headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)
        return mimetype.parameters.get("charset", self._default_charset or default)

    @reify
    def name(self) -> str | None:
        """Returns name specified in Content-Disposition header.

        If the header is missing or malformed, returns None.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "name")

    @reify
    def filename(self) -> str | None:
        """Returns filename specified in Content-Disposition header.

        Returns None if the header is missing or malformed.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "filename")
@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
    """Payload that streams a received :class:`BodyPartReader` into a writer.

    The source part's ``name``/``filename`` are copied into this payload's
    Content-Disposition header, and the body is written chunk by chunk via
    write() rather than buffered.
    """

    _value: BodyPartReader
    # _autoclose = False (inherited) - Streaming reader that may have resources

    def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
        super().__init__(value, *args, **kwargs)

        # Carry over whichever of name/filename the source part declares.
        disposition = {
            param: setting
            for param, setting in (("name", value.name), ("filename", value.filename))
            if setting is not None
        }
        if disposition:
            self.set_content_disposition("attachment", True, **disposition)

    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
        """Always raises TypeError: a streamed part cannot be decoded here."""
        raise TypeError("Unable to decode.")

    async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
        """Always raises TypeError; consume the part through write() instead.

        Body-part payloads are meant for streaming potentially very large
        data exactly once, so they are never buffered in memory for reuse.
        """
        raise TypeError("Unable to read body part as bytes. Use write() to consume.")

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Copy the part to *writer* in ``DEFAULT_CHUNK_SIZE`` pieces.

        Each raw chunk is passed through ``BodyPartReader.decode_iter()``.

        NOTE(review): decode_iter() builds a fresh decompressor on every call,
        so a gzip/deflate Content-Encoding part spanning several chunks may
        not decompress correctly here — confirm.
        """
        source = self._value
        while True:
            raw = await source.read_chunk(size=DEFAULT_CHUNK_SIZE)
            if not raw:
                break
            async for piece in source.decode_iter(raw):
                await writer.write(piece)
class MultipartReader:
    """Multipart body reader.

    Iterates over the parts of a ``multipart/*`` body read from ``content``.
    Each part is yielded as a :class:`BodyPartReader`, or as a nested
    :class:`MultipartReader` for ``multipart/*`` parts.
    """

    #: Response wrapper, used when multipart readers constructs from response.
    response_wrapper_cls = MultipartResponseWrapper
    #: Multipart reader class, used to handle multipart/* body parts.
    #: None points to type(self)
    multipart_reader_cls: type["MultipartReader"] | None = None
    #: Body part reader class for non multipart/* content types.
    part_reader_cls = BodyPartReader

    def __init__(
        self,
        headers: Mapping[str, str],
        content: StreamReader,
        *,
        client_max_size: int = sys.maxsize,
        max_field_size: int = 8190,
        max_headers: int = 128,
        max_size_error_cls: type[Exception] = ValueError,
    ) -> None:
        """
        :param headers: headers of the multipart body; Content-Type must be
            ``multipart/*`` with a ``boundary`` parameter.
        :param content: stream the body is read from.
        :param client_max_size: passed on to part readers (read() size limit).
        :param max_field_size: max length of a part header line.
        :param max_headers: max number of header lines per part.
        :param max_size_error_cls: exception type raised by part readers when
            client_max_size is exceeded.
        :raises ValueError: if the boundary is missing or longer than 70 chars.
        """
        self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
        assert self._mimetype.type == "multipart", "multipart/* content type expected"
        if "boundary" not in self._mimetype.parameters:
            raise ValueError(
                "boundary missed for Content-Type: %s" % headers[CONTENT_TYPE]
            )

        self.headers = headers
        # Delimiter lines in the body are "--" + boundary.
        self._boundary = ("--" + self._get_boundary()).encode()
        self._client_max_size = client_max_size
        self._content = content
        # Set from a form-data "_charset_" field, if one is seen.
        self._default_charset: str | None = None
        self._last_part: MultipartReader | BodyPartReader | None = None
        self._max_field_size = max_field_size
        self._max_headers = max_headers
        self._max_size_error_cls = max_size_error_cls
        self._at_eof = False
        self._at_bof = True
        # Lines pushed back for _readline(); used as a stack (pop() from end).
        self._unread: list[bytes] = []

    def __aiter__(self) -> Self:
        return self

    async def __anext__(
        self,
    ) -> Union["MultipartReader", BodyPartReader] | None:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    @classmethod
    def from_response(
        cls,
        response: "ClientResponse",
    ) -> MultipartResponseWrapper:
        """Constructs reader instance from HTTP response.

        :param response: :class:`~aiohttp.client.ClientResponse` instance
        """
        obj = cls.response_wrapper_cls(
            response, cls(response.headers, response.content)
        )
        return obj

    def at_eof(self) -> bool:
        """Returns True if the final boundary was reached, false otherwise."""
        return self._at_eof

    async def next(
        self,
    ) -> Union["MultipartReader", BodyPartReader] | None:
        """Emits the next multipart body part.

        Releases the previous part first. Returns None after the final
        boundary.
        """
        # So, if we're at BOF, we need to skip till the boundary.
        if self._at_eof:
            return None
        await self._maybe_release_last_part()
        if self._at_bof:
            await self._read_until_first_boundary()
            self._at_bof = False
        else:
            await self._read_boundary()
        if self._at_eof:  # we just read the last boundary, nothing to do there
            # https://github.com/python/mypy/issues/17537
            return None  # type: ignore[unreachable]

        part = await self.fetch_next_part()
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.6
        # NOTE(review): _maybe_release_last_part() above always resets
        # _last_part to None, so the first condition is always true here —
        # confirm whether "first part only" was intended.
        if (
            self._last_part is None
            and self._mimetype.subtype == "form-data"
            and isinstance(part, BodyPartReader)
        ):
            _, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION))
            if params.get("name") == "_charset_":
                # Longest encoding in https://encoding.spec.whatwg.org/encodings.json
                # is 19 characters, so 32 should be more than enough for any valid encoding.
                charset = await part.read_chunk(32)
                if len(charset) > 31:
                    raise RuntimeError("Invalid default charset")
                self._default_charset = charset.strip().decode()
                # NOTE(review): unlike the normal path, _read_boundary() is not
                # called before fetching the next part; confirm the following
                # part's headers parse correctly.
                part = await self.fetch_next_part()
        self._last_part = part
        return self._last_part

    async def release(self) -> None:
        """Reads all the body parts to the void till the final boundary."""
        while not self._at_eof:
            item = await self.next()
            if item is None:
                break
            await item.release()

    async def fetch_next_part(
        self,
    ) -> Union["MultipartReader", BodyPartReader]:
        """Returns the next body part reader."""
        headers = await self._read_headers()
        return self._get_part_reader(headers)

    def _get_part_reader(
        self,
        headers: HeadersDictProxy,
    ) -> Union["MultipartReader", BodyPartReader]:
        """Dispatches the response by the `Content-Type` header.

        Returns a suitable reader instance.

        :param dict headers: Response headers
        """
        ctype = headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)

        if mimetype.type == "multipart":
            # Nested multipart: same reader class (or the configured one),
            # sharing this reader's stream and limits.
            if self.multipart_reader_cls is None:
                return type(self)(
                    headers,
                    self._content,
                    client_max_size=self._client_max_size,
                    max_field_size=self._max_field_size,
                    max_headers=self._max_headers,
                    max_size_error_cls=self._max_size_error_cls,
                )
            return self.multipart_reader_cls(
                headers,
                self._content,
                client_max_size=self._client_max_size,
                max_field_size=self._max_field_size,
                max_headers=self._max_headers,
                max_size_error_cls=self._max_size_error_cls,
            )
        else:
            return self.part_reader_cls(
                self._boundary,
                headers,
                self._content,
                subtype=self._mimetype.subtype,
                default_charset=self._default_charset,
                client_max_size=self._client_max_size,
                max_size_error_cls=self._max_size_error_cls,
            )

    def _get_boundary(self) -> str:
        # Boundary parameter from Content-Type; RFC 2046 caps it at 70 chars.
        boundary = self._mimetype.parameters["boundary"]
        if len(boundary) > 70:
            raise ValueError("boundary %r is too long (70 chars max)" % boundary)

        return boundary

    async def _readline(self) -> bytes:
        # Pushed-back lines take priority over the stream (LIFO).
        if self._unread:
            return self._unread.pop()
        return await self._content.readline()

    async def _read_until_first_boundary(self) -> None:
        # Skip the preamble up to the first delimiter line.
        while True:
            chunk = await self._readline()
            if chunk == b"":
                raise ValueError(f"Could not find starting boundary {self._boundary!r}")
            chunk = chunk.rstrip()
            if chunk == self._boundary:
                return
            elif chunk == self._boundary + b"--":
                self._at_eof = True
                return

    async def _read_boundary(self) -> None:
        # Consume the delimiter that follows a part: either "--boundary" (more
        # parts follow) or "--boundary--" (close delimiter).
        chunk = (await self._readline()).rstrip()
        if chunk == self._boundary:
            pass
        elif chunk == self._boundary + b"--":
            self._at_eof = True
            epilogue = await self._readline()
            next_line = await self._readline()

            # the epilogue is expected and then either the end of input or the
            # parent multipart boundary, if the parent boundary is found then
            # it should be marked as unread and handed to the parent for
            # processing
            if next_line[:2] == b"--":
                self._unread.append(next_line)
            # otherwise the request is likely missing an epilogue and both
            # lines should be passed to the parent for processing
            # (this handles the old behavior gracefully)
            else:
                # _unread is popped from the end, so epilogue is re-read first.
                self._unread.extend([next_line, epilogue])
        else:
            raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}")

    async def _read_headers(self) -> HeadersDictProxy:
        # Read header lines up to the blank line ending the part headers,
        # enforcing max_field_size per line and max_headers overall.
        lines = []
        while True:
            chunk = await self._content.readline(max_line_length=self._max_field_size)
            chunk = chunk.rstrip(b"\r\n")
            lines.append(chunk)
            if not chunk:
                break
            if len(lines) > self._max_headers:
                raise BadHttpMessage("Too many headers received")
        parser = HeadersParser(max_field_size=self._max_field_size)
        headers, _ = parser.parse_headers(lines)
        return headers

    async def _maybe_release_last_part(self) -> None:
        """Ensures that the last read body part is read completely."""
        if self._last_part is not None:
            if not self._last_part.at_eof():
                await self._last_part.release()
            # Take over lines the part pushed back (e.g. the boundary line).
            self._unread.extend(self._last_part._unread)
            self._last_part = None
# One appended body part: (payload, content-encoding, content-transfer-encoding).
# NOTE: MultipartWriter.append_payload() stores None (or "") in the encoding
# slots when no transformation applies, so they are not always str at runtime.
_Part = tuple[Payload, str, str]
888class MultipartWriter(Payload):
889 """Multipart body writer."""
891 _value: None
892 # _consumed = False (inherited) - Can be encoded multiple times
893 _autoclose = True # No file handles, just collects parts in memory
895 def __init__(self, subtype: str = "mixed", boundary: str | None = None) -> None:
896 boundary = boundary if boundary is not None else uuid.uuid4().hex
897 # The underlying Payload API demands a str (utf-8), not bytes,
898 # so we need to ensure we don't lose anything during conversion.
899 # As a result, require the boundary to be ASCII only.
900 # In both situations.
902 try:
903 self._boundary = boundary.encode("ascii")
904 except UnicodeEncodeError:
905 raise ValueError("boundary should contain ASCII only chars") from None
907 if len(boundary) > 70:
908 raise ValueError("boundary %r is too long (70 chars max)" % boundary)
910 ctype = f"multipart/{subtype}; boundary={self._boundary_value}"
912 super().__init__(None, content_type=ctype)
914 self._parts: list[_Part] = []
915 self._is_form_data = subtype == "form-data"
917 def __enter__(self) -> "MultipartWriter":
918 return self
920 def __exit__(
921 self,
922 exc_type: type[BaseException] | None,
923 exc_val: BaseException | None,
924 exc_tb: TracebackType | None,
925 ) -> None:
926 pass
928 def __iter__(self) -> Iterator[_Part]:
929 return iter(self._parts)
931 def __len__(self) -> int:
932 return len(self._parts)
934 def __bool__(self) -> bool:
935 return True
937 _valid_tchar_regex = re.compile(rb"\A[!#$%&'*+\-.^_`|~\w]+\Z")
938 _invalid_qdtext_char_regex = re.compile(rb"[\x00-\x08\x0A-\x1F\x7F]")
940 @property
941 def _boundary_value(self) -> str:
942 """Wrap boundary parameter value in quotes, if necessary.
944 Reads self.boundary and returns a unicode string.
945 """
946 # Refer to RFCs 7231, 7230, 5234.
947 #
948 # parameter = token "=" ( token / quoted-string )
949 # token = 1*tchar
950 # quoted-string = DQUOTE *( qdtext / quoted-pair ) DQUOTE
951 # qdtext = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
952 # obs-text = %x80-FF
953 # quoted-pair = "\" ( HTAB / SP / VCHAR / obs-text )
954 # tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*"
955 # / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
956 # / DIGIT / ALPHA
957 # ; any VCHAR, except delimiters
958 # VCHAR = %x21-7E
959 value = self._boundary
960 if re.match(self._valid_tchar_regex, value):
961 return value.decode("ascii") # cannot fail
963 if re.search(self._invalid_qdtext_char_regex, value):
964 raise ValueError("boundary value contains invalid characters")
966 # escape %x5C and %x22
967 quoted_value_content = value.replace(b"\\", b"\\\\")
968 quoted_value_content = quoted_value_content.replace(b'"', b'\\"')
970 return '"' + quoted_value_content.decode("ascii") + '"'
972 @property
973 def boundary(self) -> str:
974 return self._boundary.decode("ascii")
976 def append(self, obj: Any, headers: Mapping[str, str] | None = None) -> Payload:
977 if headers is None:
978 headers = CIMultiDict()
980 if isinstance(obj, Payload):
981 obj.headers.update(headers)
982 return self.append_payload(obj)
983 else:
984 try:
985 payload = get_payload(obj, headers=headers)
986 except LookupError:
987 raise TypeError("Cannot create payload from %r" % obj)
988 else:
989 return self.append_payload(payload)
def append_payload(self, payload: Payload) -> Payload:
    """Register *payload* as the next body part of this writer and return it.

    For ``multipart/form-data`` writers the part must not carry
    Content-Encoding, Content-Length or Content-Transfer-Encoding headers. A
    default ``form-data`` Content-Disposition named ``section-<index>`` is
    added when none is present.

    For other multipart types, Content-Encoding (deflate/gzip/identity) and
    Content-Transfer-Encoding (base64/quoted-printable/binary) are validated
    and normalized. Content-Length is set when the size is known and the
    part is neither compressed nor transfer-encoded.

    :raises RuntimeError: on an unsupported content or transfer encoding.
    """
    encoding: str | None = None
    te_encoding: str | None = None

    if not self._is_form_data:
        # Content-Encoding: "identity" means no compression.
        encoding = payload.headers.get(CONTENT_ENCODING, "").lower()
        if encoding == "identity":
            encoding = None
        elif encoding and encoding not in ("deflate", "gzip"):
            raise RuntimeError(f"unknown content encoding: {encoding}")

        # Content-Transfer-Encoding: "binary" means the bytes pass through as-is.
        te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
        if te_encoding == "binary":
            te_encoding = None
        elif te_encoding not in ("", "base64", "quoted-printable"):
            raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")

        # The encoded length cannot be predicted, so only advertise the size
        # for parts that are written verbatim.
        size = payload.size
        if size is not None and not encoding and not te_encoding:
            payload.headers[CONTENT_LENGTH] = str(size)
    else:
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.7
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
        assert (
            not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING}
            & payload.headers.keys()
        )
        # Fall back to a generated field name if the caller gave none.
        if CONTENT_DISPOSITION not in payload.headers:
            payload.set_content_disposition(
                "form-data", name=f"section-{len(self._parts)}"
            )

    self._parts.append((payload, encoding, te_encoding))  # type: ignore[arg-type]
    return payload
def append_json(
    self, obj: Any, headers: Mapping[str, str] | None = None
) -> Payload:
    """Serialize *obj* as JSON and append it as a new body part.

    :param headers: extra part headers; an empty ``CIMultiDict`` when omitted.
    :returns: the ``JsonPayload`` that was appended.
    """
    part_headers = CIMultiDict() if headers is None else headers
    return self.append_payload(JsonPayload(obj, headers=part_headers))
def append_form(
    self,
    obj: Sequence[tuple[str, str]] | Mapping[str, str],
    headers: Mapping[str, str] | None = None,
) -> Payload:
    """Append *obj* as an ``application/x-www-form-urlencoded`` body part.

    :param obj: a mapping or a sequence of ``(name, value)`` pairs; it is
        encoded with ``urlencode(..., doseq=True)``.
    :param headers: extra part headers; an empty ``CIMultiDict`` when omitted.
    :returns: the ``StringPayload`` that was appended.
    """
    assert isinstance(obj, (Sequence, Mapping))

    pairs = list(obj.items()) if isinstance(obj, Mapping) else obj
    encoded = urlencode(pairs, doseq=True)

    if headers is None:
        headers = CIMultiDict()
    form_part = StringPayload(
        encoded, headers=headers, content_type="application/x-www-form-urlencoded"
    )
    return self.append_payload(form_part)
1059 @property
1060 def size(self) -> int | None:
1061 """Size of the payload."""
1062 total = 0
1063 for part, encoding, te_encoding in self._parts:
1064 part_size = part.size
1065 if encoding or te_encoding or part_size is None:
1066 return None
1068 total += int(
1069 2
1070 + len(self._boundary)
1071 + 2
1072 + part_size # b'--'+self._boundary+b'\r\n'
1073 + len(part._binary_headers)
1074 + 2 # b'\r\n'
1075 )
1077 total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n'
1078 return total
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
    """Return string representation of the multipart data.

    Each part is framed as ``--boundary CRLF headers body CRLF`` and the text
    ends with the closing delimiter ``--boundary-- CRLF``, matching the bytes
    produced by ``as_bytes()``.

    WARNING: This method may do blocking I/O if parts contain file payloads.
    It should not be called in the event loop. Use as_bytes().decode() instead.

    :param encoding: codec used to decode both part headers and part bodies.
    :param errors: codec error handler, forwarded alongside *encoding*.
    """
    boundary = self.boundary
    pieces = [
        "--"
        + boundary
        + "\r\n"
        + part._binary_headers.decode(encoding, errors)
        + part.decode(encoding, errors)
        # A CRLF must follow every body: it belongs to the next delimiter.
        + "\r\n"
        for part, _e, _te in self._parts
    ]
    # Closing delimiter, previously missing, which made the output invalid.
    pieces.append("--" + boundary + "--\r\n")
    return "".join(pieces)
async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
    """Return bytes representation of the multipart data.

    Every part's body is obtained through its own awaitable ``as_bytes``, so
    this method is safe to call from the event loop. The output is each part
    framed as ``--boundary CRLF headers body CRLF``, followed by
    ``--boundary-- CRLF``.
    """
    delimiter = b"--" + self._boundary
    out = bytearray()

    for part, _enc, _te_enc in self._parts:
        out += delimiter + b"\r\n"
        out += part._binary_headers
        out += await part.as_bytes(encoding, errors)
        out += b"\r\n"

    out += delimiter + b"--\r\n"
    return bytes(out)
async def write(
    self, writer: AbstractStreamWriter, close_boundary: bool = True
) -> None:
    """Write body.

    Each part is sent as ``--boundary CRLF``, its serialized headers, its
    body and a trailing CRLF. Compressed or transfer-encoded parts are routed
    through a ``MultipartPayloadWriter``, which is flushed via ``write_eof``.
    The closing ``--boundary-- CRLF`` is written only when *close_boundary*
    is true.
    """
    for part, encoding, te_encoding in self._parts:
        if self._is_form_data:
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.2
            assert CONTENT_DISPOSITION in part.headers
            assert "name=" in part.headers[CONTENT_DISPOSITION]

        await writer.write(b"--" + self._boundary + b"\r\n")
        await writer.write(part._binary_headers)

        if not (encoding or te_encoding):
            await part.write(writer)
        else:
            encoder = MultipartPayloadWriter(writer)
            if encoding:
                encoder.enable_compression(encoding)
            if te_encoding:
                encoder.enable_encoding(te_encoding)
            await part.write(encoder)  # type: ignore[arg-type]
            await encoder.write_eof()

        await writer.write(b"\r\n")

    if close_boundary:
        await writer.write(b"--" + self._boundary + b"--\r\n")
async def close(self) -> None:
    """Close every part payload that still needs an explicit close.

    Only the first call does any work; later calls return immediately.
    Parts with ``autoclose`` set, or already ``consumed``, are skipped. An
    exception raised while closing one part is logged and the remaining
    parts are still closed.

    IMPORTANT: This method must not await anything that might not finish
    immediately, as it may be called during cleanup/cancellation. Schedule
    any long-running operations without awaiting them.
    """
    if self._consumed:
        return
    self._consumed = True

    # Parts are closed one at a time instead of via asyncio.gather(): this
    # may run during cleanup, where suspending is not allowed.
    for position, (part, _enc, _te) in enumerate(self._parts):
        if part.autoclose or part.consumed:
            continue
        try:
            await part.close()
        except Exception as exc:
            internal_logger.error(
                "Failed to close multipart part %d: %s", position, exc, exc_info=True
            )
class MultipartPayloadWriter:
    """Writer wrapper that compresses and/or transfer-encodes a part body.

    Data passed to :meth:`write` is first compressed (if
    :meth:`enable_compression` was called). It is then encoded as base64 or
    quoted-printable (if :meth:`enable_encoding` selected one) and finally
    forwarded to the wrapped writer. :meth:`write_eof` flushes the
    compressor's remaining output and any pending base64 bytes.
    """

    def __init__(self, writer: AbstractStreamWriter) -> None:
        self._writer = writer
        # "base64", "quoted-printable", or None for pass-through.
        self._encoding: str | None = None
        self._compress: ZLibCompressor | None = None
        # base64 only: the 0-2 trailing bytes not yet forming a 3-byte group.
        self._encoding_buffer: bytearray | None = None

    def enable_encoding(self, encoding: str) -> None:
        """Select the content transfer encoding applied to written data.

        Only ``"base64"`` and ``"quoted-printable"`` have an effect; any other
        value leaves the data unencoded.
        """
        if encoding == "base64":
            self._encoding = encoding
            self._encoding_buffer = bytearray()
        elif encoding == "quoted-printable":
            self._encoding = "quoted-printable"

    def enable_compression(
        self, encoding: str = "deflate", strategy: int | None = None
    ) -> None:
        """Compress written data with a ``ZLibCompressor`` for *encoding*."""
        self._compress = ZLibCompressor(
            encoding=encoding,
            suppress_deflate_header=True,
            strategy=strategy,
        )

    async def write_eof(self) -> None:
        """Flush pending compressor output and the base64 remainder.

        Repeated calls write nothing further. The compressor is discarded
        after its final flush, even when that flush produced no bytes. The
        base64 buffer is emptied once its remainder has been emitted.
        """
        if self._compress is not None:
            tail = self._compress.flush()
            # Drop the finished compressor before writing its tail so that
            # write() forwards the tail instead of compressing it again.
            self._compress = None
            if tail:
                await self.write(tail)

        if self._encoding == "base64" and self._encoding_buffer:
            await self._writer.write(base64.b64encode(self._encoding_buffer))
            self._encoding_buffer = bytearray()

    async def write(self, chunk: bytes) -> None:
        """Compress and/or transfer-encode *chunk*, then forward the result."""
        if self._compress is not None:
            if chunk:
                chunk = await self._compress.compress(chunk)
                if not chunk:
                    # The compressor is still buffering; nothing to emit yet.
                    return

        if self._encoding == "base64":
            buf = self._encoding_buffer
            assert buf is not None
            buf.extend(chunk)

            if buf:
                # base64 turns every 3 input bytes into 4 output characters.
                # Encode only whole groups so that the concatenated output
                # equals the encoding of the full stream. Keep the rest.
                cut = len(buf) - len(buf) % 3
                enc_chunk, self._encoding_buffer = buf[:cut], buf[cut:]
                if enc_chunk:
                    await self._writer.write(base64.b64encode(enc_chunk))
        elif self._encoding == "quoted-printable":
            await self._writer.write(binascii.b2a_qp(chunk))
        else:
            await self._writer.write(chunk)