Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/aiohttp/multipart.py: 19%

562 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:40 +0000

1import base64 

2import binascii 

3import json 

4import re 

5import uuid 

6import warnings 

7import zlib 

8from collections import deque 

9from types import TracebackType 

10from typing import ( 

11 TYPE_CHECKING, 

12 Any, 

13 AsyncIterator, 

14 Deque, 

15 Dict, 

16 Iterator, 

17 List, 

18 Mapping, 

19 Optional, 

20 Sequence, 

21 Tuple, 

22 Type, 

23 Union, 

24 cast, 

25) 

26from urllib.parse import parse_qsl, unquote, urlencode 

27 

28from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping 

29 

30from .compression_utils import ZLibCompressor, ZLibDecompressor 

31from .hdrs import ( 

32 CONTENT_DISPOSITION, 

33 CONTENT_ENCODING, 

34 CONTENT_LENGTH, 

35 CONTENT_TRANSFER_ENCODING, 

36 CONTENT_TYPE, 

37) 

38from .helpers import CHAR, TOKEN, parse_mimetype, reify 

39from .http import HeadersParser 

40from .payload import ( 

41 JsonPayload, 

42 LookupError, 

43 Order, 

44 Payload, 

45 StringPayload, 

46 get_payload, 

47 payload_type, 

48) 

49from .streams import StreamReader 

50 

# Names exported by a star-import of this module.
__all__ = (
    "MultipartReader",
    "MultipartWriter",
    "BodyPartReader",
    "BadContentDispositionHeader",
    "BadContentDispositionParam",
    "parse_content_disposition",
    "content_disposition_filename",
)


if TYPE_CHECKING:  # pragma: no cover
    # Needed for annotations only; the annotations below refer to it by string.
    from .client_reqrep import ClientResponse

64 

65 

class BadContentDispositionHeader(RuntimeWarning):
    """Warning category for a Content-Disposition header rejected as a whole."""

68 

69 

class BadContentDispositionParam(RuntimeWarning):
    """Warning category for a single rejected Content-Disposition parameter."""

72 

73 

def parse_content_disposition(
    header: Optional[str],
) -> Tuple[Optional[str], Dict[str, str]]:
    """Parse a Content-Disposition header value.

    Returns a ``(disposition_type, params)`` pair. The disposition type is
    lower-cased; parameter names are lower-cased and stripped.

    A missing or empty header yields ``(None, {})``. A header that is rejected
    as a whole also yields ``(None, {})`` after a
    ``BadContentDispositionHeader`` warning. Some invalid parameters are merely
    skipped with a ``BadContentDispositionParam`` warning.
    """

    def is_token(string: str) -> bool:
        return bool(string) and TOKEN >= set(string)

    def is_quoted(string: str) -> bool:
        # An empty value (e.g. "filename=") must not be indexed: it is simply
        # not a quoted string, which lets the caller warn instead of crashing
        # with IndexError.
        return bool(string) and string[0] == string[-1] == '"'

    def is_rfc5987(string: str) -> bool:
        return is_token(string) and string.count("'") == 2

    def is_extended_param(string: str) -> bool:
        return string.endswith("*")

    def is_continuous_param(string: str) -> bool:
        pos = string.find("*") + 1
        if not pos:
            return False
        substring = string[pos:-1] if string.endswith("*") else string[pos:]
        return substring.isdigit()

    def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str:
        # Drop the backslash from backslash-escaped characters.
        return re.sub(f"\\\\([{chars}])", "\\1", text)

    if not header:
        return None, {}

    disptype, *parts = header.split(";")
    if not is_token(disptype):
        warnings.warn(BadContentDispositionHeader(header))
        return None, {}

    params: Dict[str, str] = {}
    while parts:
        item = parts.pop(0)

        if "=" not in item:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        key, value = item.split("=", 1)
        key = key.lower().strip()
        value = value.lstrip()

        # A repeated parameter name invalidates the whole header.
        if key in params:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        if not is_token(key):
            warnings.warn(BadContentDispositionParam(item))
            continue

        elif is_continuous_param(key):
            # "name*N" / "name*N*" segment: quoted string or bare token.
            if is_quoted(value):
                value = unescape(value[1:-1])
            elif not is_token(value):
                warnings.warn(BadContentDispositionParam(item))
                continue

        elif is_extended_param(key):
            # "name*" parameter: charset'language'percent-encoded-text.
            if is_rfc5987(value):
                encoding, _, value = value.split("'", 2)
                encoding = encoding or "utf-8"
            else:
                warnings.warn(BadContentDispositionParam(item))
                continue

            try:
                value = unquote(value, encoding, "strict")
            except UnicodeDecodeError:  # pragma: nocover
                warnings.warn(BadContentDispositionParam(item))
                continue

        else:
            failed = True
            if is_quoted(value):
                failed = False
                # Leading backslashes and slashes are stripped from the value.
                value = unescape(value[1:-1].lstrip("\\/"))
            elif is_token(value):
                failed = False
            elif parts:
                # maybe just ; in filename, in any case this is just
                # one case fix, for proper fix we need to redesign parser
                _value = f"{value};{parts[0]}"
                if is_quoted(_value):
                    parts.pop(0)
                    value = unescape(_value[1:-1].lstrip("\\/"))
                    failed = False

            if failed:
                warnings.warn(BadContentDispositionHeader(header))
                return None, {}

        params[key] = value

    return disptype.lower(), params

171 

172 

def content_disposition_filename(
    params: Mapping[str, str], name: str = "filename"
) -> Optional[str]:
    """Pick the value of *name* out of parsed Content-Disposition params.

    Lookup order: the extended parameter ``<name>*``, then the plain
    ``<name>``, then the concatenation of continuation segments
    ``<name>*0``, ``<name>*1``, ... (each optionally suffixed with ``*``).
    Segments are joined in numeric order and joining stops at the first
    missing index. Returns ``None`` when nothing usable is present.

    If the joined value has the ``charset'language'text`` form (at least two
    apostrophes), the text is percent-decoded with that charset (``utf-8``
    when the charset is empty).
    """
    name_suf = "%s*" % name
    if not params:
        return None
    if name_suf in params:
        return params[name_suf]
    if name in params:
        return params[name]

    # Index segments by their number text. Iterating in sorted key order keeps
    # the plain "<name>*N" ahead of "<name>*N*" if both are present.
    segments: Dict[str, str] = {}
    for key, value in sorted(params.items()):
        if not key.startswith(name_suf):
            continue
        tail = key[len(name_suf):]
        if tail.endswith("*"):
            tail = tail[:-1]
        segments.setdefault(tail, value)

    # Walk 0, 1, 2, ... numerically; a lexicographic sort would put "10"
    # before "2" and truncate names split into more than ten segments.
    parts = []
    num = 0
    while str(num) in segments:
        parts.append(segments[str(num)])
        num += 1
    if not parts:
        return None

    value = "".join(parts)
    # Only unpack charset'language'text when both apostrophes are there; a
    # lone apostrophe (e.g. "it's.txt") is ordinary file name text.
    if value.count("'") >= 2:
        encoding, _, value = value.split("'", 2)
        encoding = encoding or "utf-8"
        return unquote(value, encoding, "strict")
    return value

204 

205 

class MultipartResponseWrapper:
    """Pairs a response object with the multipart reader built on its body.

    Parts are pulled from the reader; whenever the reader reports EOF after a
    pull, the response is released.
    """

    def __init__(self, resp: "ClientResponse", stream: "MultipartReader") -> None:
        self.resp = resp
        self.stream = stream

    def __aiter__(self) -> "MultipartResponseWrapper":
        return self

    async def __anext__(self) -> Union["MultipartReader", "BodyPartReader"]:
        upcoming = await self.next()
        if upcoming is not None:
            return upcoming
        raise StopAsyncIteration

    def at_eof(self) -> bool:
        """Tell whether the response content stream is at EOF."""
        content = self.resp.content
        return content.at_eof()

    async def next(self) -> Optional[Union["MultipartReader", "BodyPartReader"]]:
        """Return the next part from the reader, or None when it has none."""
        upcoming = await self.stream.next()
        exhausted = self.stream.at_eof()
        if exhausted:
            await self.release()
        return upcoming

    async def release(self) -> None:
        """Release the underlying response."""
        await self.resp.release()

251 

252 

class BodyPartReader:
    """Multipart reader for single body part.

    Reads the bytes of one part from a shared content stream, stopping at the
    given boundary line, and offers helpers to decode the part according to
    its own headers.
    """

    # Default number of bytes requested per read_chunk() call.
    chunk_size = 8192

    def __init__(
        self, boundary: bytes, headers: "CIMultiDictProxy[str]", content: StreamReader
    ) -> None:
        self.headers = headers
        self._boundary = boundary
        self._content = content
        self._at_eof = False
        length = self.headers.get(CONTENT_LENGTH, None)
        # None when the part carries no Content-Length header.
        self._length = int(length) if length is not None else None
        self._read_bytes = 0
        # Lines read from the stream but not consumed by this part.
        self._unread: Deque[bytes] = deque()
        # Chunk held back by _read_chunk_from_stream() between calls.
        self._prev_chunk: Optional[bytes] = None
        # How many stream reads have ended with the stream reporting EOF.
        self._content_eof = 0
        # NOTE(review): presumably the storage used by ``reify`` — confirm.
        self._cache: Dict[str, Any] = {}

    def __aiter__(self) -> AsyncIterator["BodyPartReader"]:
        return self  # type: ignore[return-value]

    async def __anext__(self) -> bytes:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    async def next(self) -> Optional[bytes]:
        """Read the remaining part body; return None if nothing was read."""
        item = await self.read()
        if not item:
            return None
        return item

    async def read(self, *, decode: bool = False) -> bytes:
        """Reads body part data.

        decode: Decodes data following by encoding
                method from Content-Encoding header. If it missed
                data remains untouched
        """
        if self._at_eof:
            return b""
        data = bytearray()
        while not self._at_eof:
            data.extend(await self.read_chunk(self.chunk_size))
        if decode:
            return self.decode(data)
        # Note: the accumulated bytearray itself is returned here.
        return data

    async def read_chunk(self, size: int = chunk_size) -> bytes:
        """Reads body part content chunk of the specified size.

        size: chunk size
        """
        if self._at_eof:
            return b""
        # A missing or zero Content-Length falls through to boundary scanning.
        if self._length:
            chunk = await self._read_chunk_from_length(size)
        else:
            chunk = await self._read_chunk_from_stream(size)

        self._read_bytes += len(chunk)
        if self._read_bytes == self._length:
            self._at_eof = True
        if self._at_eof:
            # The body must be followed by a bare CRLF line.
            clrf = await self._content.readline()
            assert (
                b"\r\n" == clrf
            ), "reader did not read all the data or it is malformed"
        return chunk

    async def _read_chunk_from_length(self, size: int) -> bytes:
        # Reads body part content chunk of the specified size.
        # The body part must has Content-Length header with proper value.
        assert self._length is not None, "Content-Length required for chunked read"
        # Never request more than what is left of the declared length.
        chunk_size = min(size, self._length - self._read_bytes)
        chunk = await self._content.read(chunk_size)
        return chunk

    async def _read_chunk_from_stream(self, size: int) -> bytes:
        # Reads content chunk of body part with unknown length.
        # The Content-Length header for body part is not necessary.
        assert (
            size >= len(self._boundary) + 2
        ), "Chunk size must be greater or equal than boundary length + 2"
        first_chunk = self._prev_chunk is None
        if first_chunk:
            # Prime the held-back chunk on the first call.
            self._prev_chunk = await self._content.read(size)

        chunk = await self._content.read(size)
        self._content_eof += int(self._content.at_eof())
        assert self._content_eof < 3, "Reading after EOF"
        assert self._prev_chunk is not None
        # Search the held-back chunk together with the fresh one.
        window = self._prev_chunk + chunk
        sub = b"\r\n" + self._boundary
        if first_chunk:
            idx = window.find(sub)
        else:
            # Start len(sub) bytes before the end of the held-back chunk.
            idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
        if idx >= 0:
            # pushing boundary back to content
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)
                self._content.unread_data(window[idx:])
            if size > idx:
                self._prev_chunk = self._prev_chunk[:idx]
            chunk = window[len(self._prev_chunk) : idx]
            if not chunk:
                self._at_eof = True
        # Hand out the held-back chunk and hold the new one back instead.
        result = self._prev_chunk
        self._prev_chunk = chunk
        return result

    async def readline(self) -> bytes:
        """Reads body part by line by line."""
        if self._at_eof:
            return b""

        # Serve a previously looked-ahead line before touching the stream.
        if self._unread:
            line = self._unread.popleft()
        else:
            line = await self._content.readline()

        if line.startswith(self._boundary):
            # the very last boundary may not come with \r\n,
            # so set single rules for everyone
            sline = line.rstrip(b"\r\n")
            boundary = self._boundary
            last_boundary = self._boundary + b"--"
            # ensure that we read exactly the boundary, not something alike
            if sline == boundary or sline == last_boundary:
                self._at_eof = True
                self._unread.append(line)
                return b""
        else:
            # Look one line ahead: if it starts with the boundary, the last
            # two bytes of this line are dropped.
            next_line = await self._content.readline()
            if next_line.startswith(self._boundary):
                line = line[:-2]  # strip CRLF but only once
            self._unread.append(next_line)

        return line

    async def release(self) -> None:
        """Like read(), but reads all the data to the void."""
        if self._at_eof:
            return
        while not self._at_eof:
            await self.read_chunk(self.chunk_size)

    async def text(self, *, encoding: Optional[str] = None) -> str:
        """Like read(), but assumes that body part contains text data."""
        data = await self.read(decode=True)
        # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm
        # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send
        encoding = encoding or self.get_charset(default="utf-8")
        return data.decode(encoding)

    async def json(self, *, encoding: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Like read(), but assumes that body parts contains JSON data.

        Returns None when the part body is empty.
        """
        data = await self.read(decode=True)
        if not data:
            return None
        encoding = encoding or self.get_charset(default="utf-8")
        return cast(Dict[str, Any], json.loads(data.decode(encoding)))

    async def form(self, *, encoding: Optional[str] = None) -> List[Tuple[str, str]]:
        """Like read(), but assumes that body parts contain form urlencoded data.

        Returns an empty list for an empty body. Raises ValueError when the
        body cannot be decoded with the chosen encoding.
        """
        data = await self.read(decode=True)
        if not data:
            return []
        if encoding is not None:
            real_encoding = encoding
        else:
            real_encoding = self.get_charset(default="utf-8")
        try:
            decoded_data = data.rstrip().decode(real_encoding)
        except UnicodeDecodeError:
            raise ValueError("data cannot be decoded with %s encoding" % real_encoding)

        return parse_qsl(
            decoded_data,
            keep_blank_values=True,
            encoding=real_encoding,
        )

    def at_eof(self) -> bool:
        """Returns True if the boundary was reached or False otherwise."""
        return self._at_eof

    def decode(self, data: bytes) -> bytes:
        """Decodes data.

        Decoding is done according the specified Content-Encoding
        or Content-Transfer-Encoding headers value.
        """
        # Transfer encoding is undone first, then content encoding.
        if CONTENT_TRANSFER_ENCODING in self.headers:
            data = self._decode_content_transfer(data)
        if CONTENT_ENCODING in self.headers:
            return self._decode_content(data)
        return data

    def _decode_content(self, data: bytes) -> bytes:
        # Undo Content-Encoding; raises RuntimeError for an unknown value.
        encoding = self.headers.get(CONTENT_ENCODING, "").lower()
        if encoding == "identity":
            return data
        if encoding in {"deflate", "gzip"}:
            return ZLibDecompressor(
                encoding=encoding,
                suppress_deflate_header=True,
            ).decompress_sync(data)

        raise RuntimeError(f"unknown content encoding: {encoding}")

    def _decode_content_transfer(self, data: bytes) -> bytes:
        # Undo Content-Transfer-Encoding; raises RuntimeError for an unknown value.
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()

        if encoding == "base64":
            return base64.b64decode(data)
        elif encoding == "quoted-printable":
            return binascii.a2b_qp(data)
        elif encoding in ("binary", "8bit", "7bit"):
            return data
        else:
            raise RuntimeError(
                "unknown content transfer encoding: {}" "".format(encoding)
            )

    def get_charset(self, default: str) -> str:
        """Returns charset parameter from Content-Type header or default."""
        ctype = self.headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)
        return mimetype.parameters.get("charset", default)

    @reify
    def name(self) -> Optional[str]:
        """Returns name specified in Content-Disposition header.

        If the header is missing or malformed, returns None.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "name")

    @reify
    def filename(self) -> Optional[str]:
        """Returns filename specified in Content-Disposition header.

        Returns None if the header is missing or malformed.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "filename")

505 

506 

@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
    """Payload that forwards the content of a ``BodyPartReader`` to a writer."""

    def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
        super().__init__(value, *args, **kwargs)

        # Carry over whichever of name / filename the source part declares.
        disposition: Dict[str, str] = {}
        part_name = value.name
        if part_name is not None:
            disposition["name"] = part_name
        part_filename = value.filename
        if part_filename is not None:
            disposition["filename"] = part_filename

        if not disposition:
            return
        self.set_content_disposition("attachment", True, **disposition)

    async def write(self, writer: Any) -> None:
        """Copy the part to *writer* chunk by chunk, decoding each chunk."""
        field = self._value
        while True:
            chunk = await field.read_chunk(size=2**16)
            if not chunk:
                break
            await writer.write(field.decode(chunk))

527 

528 

class MultipartReader:
    """Reader that splits a multipart body into its parts."""

    #: Wrapper class returned by from_response().
    response_wrapper_cls = MultipartResponseWrapper
    #: Reader class used for nested multipart/* parts.
    #: None means type(self).
    multipart_reader_cls = None
    #: Reader class used for parts that are not multipart/*.
    part_reader_cls = BodyPartReader

    def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
        self.headers = headers
        # Delimiter lines are "--" followed by the boundary parameter.
        self._boundary = ("--" + self._get_boundary()).encode()
        self._content = content
        self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None
        self._at_eof = False
        self._at_bof = True
        self._unread: List[bytes] = []

    def __aiter__(
        self,
    ) -> AsyncIterator["BodyPartReader"]:
        return self  # type: ignore[return-value]

    async def __anext__(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        upcoming = await self.next()
        if upcoming is not None:
            return upcoming
        raise StopAsyncIteration

    @classmethod
    def from_response(
        cls,
        response: "ClientResponse",
    ) -> MultipartResponseWrapper:
        """Build a reader over *response* and wrap it.

        :param response: :class:`~aiohttp.client.ClientResponse` instance
        """
        return cls.response_wrapper_cls(
            response, cls(response.headers, response.content)
        )

    def at_eof(self) -> bool:
        """Tell whether the closing boundary has been read."""
        return self._at_eof

    async def next(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        """Return the next body part, or None after the closing boundary."""
        if self._at_eof:
            return None
        await self._maybe_release_last_part()
        if not self._at_bof:
            await self._read_boundary()
        else:
            # First call: skip everything before the first boundary line.
            await self._read_until_first_boundary()
            self._at_bof = False
        if self._at_eof:
            # The boundary just read was the closing one.
            return None
        upcoming = await self.fetch_next_part()
        self._last_part = upcoming
        return upcoming

    async def release(self) -> None:
        """Consume and discard every remaining part."""
        while not self._at_eof:
            part = await self.next()
            if part is None:
                return
            await part.release()

    async def fetch_next_part(
        self,
    ) -> Union["MultipartReader", BodyPartReader]:
        """Read the part headers and return a reader for that part."""
        part_headers = await self._read_headers()
        return self._get_part_reader(part_headers)

    def _get_part_reader(
        self,
        headers: "CIMultiDictProxy[str]",
    ) -> Union["MultipartReader", BodyPartReader]:
        """Choose the reader for a part from its ``Content-Type`` header.

        multipart/* parts get a nested multipart reader; everything else gets
        ``part_reader_cls``.

        :param dict headers: Part headers
        """
        mimetype = parse_mimetype(headers.get(CONTENT_TYPE, ""))

        if mimetype.type != "multipart":
            return self.part_reader_cls(self._boundary, headers, self._content)
        if self.multipart_reader_cls is None:
            return type(self)(headers, self._content)
        return self.multipart_reader_cls(headers, self._content)

    def _get_boundary(self) -> str:
        # Extract the boundary parameter from this reader's Content-Type.
        mimetype = parse_mimetype(self.headers[CONTENT_TYPE])

        assert mimetype.type == "multipart", "multipart/* content type expected"

        ctype_params = mimetype.parameters
        if "boundary" not in ctype_params:
            raise ValueError(
                "boundary missed for Content-Type: %s" % self.headers[CONTENT_TYPE]
            )

        boundary = ctype_params["boundary"]
        if len(boundary) > 70:
            raise ValueError("boundary %r is too long (70 chars max)" % boundary)

        return boundary

    async def _readline(self) -> bytes:
        # Pushed-back lines are served last-in first-out before the stream.
        if not self._unread:
            return await self._content.readline()
        return self._unread.pop()

    async def _read_until_first_boundary(self) -> None:
        # Discard lines until a boundary line; fail if input ends first.
        while True:
            line = await self._readline()
            if line == b"":
                raise ValueError(
                    "Could not find starting boundary %r" % (self._boundary)
                )
            line = line.rstrip()
            if line == self._boundary:
                return
            if line == self._boundary + b"--":
                self._at_eof = True
                return

    async def _read_boundary(self) -> None:
        # Consume the delimiter line that follows a part.
        line = (await self._readline()).rstrip()
        if line == self._boundary:
            return
        if line != self._boundary + b"--":
            raise ValueError(f"Invalid boundary {line!r}, expected {self._boundary!r}")

        self._at_eof = True
        epilogue = await self._readline()
        following = await self._readline()

        # After the closing boundary an epilogue line is expected, followed by
        # either end of input or the parent's boundary. A parent boundary is
        # pushed back so the parent reader can process it.
        if following[:2] == b"--":
            self._unread.append(following)
        # Otherwise the epilogue is likely missing: push both lines back so
        # the parent sees them (epilogue first, as _readline pops from the end).
        else:
            self._unread.extend([following, epilogue])

    async def _read_headers(self) -> "CIMultiDictProxy[str]":
        # Collect stripped lines up to and including the first empty one.
        lines = [b""]
        while True:
            raw = await self._content.readline()
            stripped = raw.strip()
            lines.append(stripped)
            if not stripped:
                break
        headers, _raw_headers = HeadersParser().parse_headers(lines)
        return headers

    async def _maybe_release_last_part(self) -> None:
        """Drain the previously returned part if it was not fully read."""
        previous = self._last_part
        if previous is None:
            return
        if not previous.at_eof():
            await previous.release()
        # Take over any lines the part read ahead but did not consume.
        self._unread.extend(previous._unread)
        self._last_part = None

710 

711 

# One stored writer part: (payload, content encoding, content transfer encoding).
_Part = Tuple[Payload, str, str]

713 

714 

class MultipartWriter(Payload):
    """Payload that serializes a list of parts as a multipart body."""

    def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None:
        if boundary is None:
            boundary = uuid.uuid4().hex
        # The Payload API works with str while the wire format needs bytes;
        # restricting the boundary to ASCII makes the conversion lossless,
        # whether the boundary was supplied or generated.
        try:
            self._boundary = boundary.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("boundary should contain ASCII only chars") from None

        super().__init__(
            None,
            content_type=f"multipart/{subtype}; boundary={self._boundary_value}",
        )

        self._parts: List[_Part] = []

    def __enter__(self) -> "MultipartWriter":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        pass

    def __iter__(self) -> Iterator[_Part]:
        return iter(self._parts)

    def __len__(self) -> int:
        return len(self._parts)

    def __bool__(self) -> bool:
        # Truthy even when no part has been appended (len() would be 0).
        return True

    _valid_tchar_regex = re.compile(rb"\A[!#$%&'*+\-.^_`|~\w]+\Z")
    _invalid_qdtext_char_regex = re.compile(rb"[\x00-\x08\x0A-\x1F\x7F]")

    @property
    def _boundary_value(self) -> str:
        """Return the boundary as a header parameter value, quoted if needed.

        Reads self._boundary and returns a str.
        """
        # Refer to RFCs 7231, 7230, 5234.
        #
        # parameter      = token "=" ( token / quoted-string )
        # token          = 1*tchar
        # quoted-string  = DQUOTE *( qdtext / quoted-pair ) DQUOTE
        # qdtext         = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
        # obs-text       = %x80-FF
        # quoted-pair    = "\" ( HTAB / SP / VCHAR / obs-text )
        # tchar          = "!" / "#" / "$" / "%" / "&" / "'" / "*"
        #                  / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
        #                  / DIGIT / ALPHA
        #                  ; any VCHAR, except delimiters
        # VCHAR           = %x21-7E
        raw = self._boundary
        if self._valid_tchar_regex.match(raw):
            return raw.decode("ascii")  # cannot fail

        if self._invalid_qdtext_char_regex.search(raw):
            raise ValueError("boundary value contains invalid characters")

        # escape %x5C and %x22
        escaped = raw.replace(b"\\", b"\\\\").replace(b'"', b'\\"')
        text = escaped.decode("ascii")
        return '"' + text + '"'

    @property
    def boundary(self) -> str:
        return self._boundary.decode("ascii")

    def append(self, obj: Any, headers: Optional[MultiMapping[str]] = None) -> Payload:
        """Append *obj* as a part, converting it to a payload when needed.

        Raises TypeError when no payload can be created from *obj*.
        """
        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Payload):
            obj.headers.update(headers)
            return self.append_payload(obj)

        try:
            payload = get_payload(obj, headers=headers)
        except LookupError:
            raise TypeError("Cannot create payload from %r" % obj)
        return self.append_payload(payload)

    def append_payload(self, payload: Payload) -> Payload:
        """Register *payload* as a new body part and return it.

        Raises RuntimeError for an unsupported Content-Encoding or
        Content-Transfer-Encoding header on the payload.
        """
        # compression
        encoding: Optional[str] = payload.headers.get(CONTENT_ENCODING, "").lower()
        if encoding and encoding not in ("deflate", "gzip", "identity"):
            raise RuntimeError(f"unknown content encoding: {encoding}")
        if encoding == "identity":
            encoding = None

        # te encoding
        te_encoding: Optional[str] = payload.headers.get(
            CONTENT_TRANSFER_ENCODING, ""
        ).lower()
        if te_encoding not in ("", "base64", "quoted-printable", "binary"):
            raise RuntimeError(
                "unknown content transfer encoding: {}".format(te_encoding)
            )
        if te_encoding == "binary":
            te_encoding = None

        # size: only set when the payload is written out unmodified
        size = payload.size
        if size is not None and not encoding and not te_encoding:
            payload.headers[CONTENT_LENGTH] = str(size)

        self._parts.append((payload, encoding, te_encoding))  # type: ignore[arg-type]
        return payload

    def append_json(
        self, obj: Any, headers: Optional[MultiMapping[str]] = None
    ) -> Payload:
        """Append *obj* as a JSON part."""
        if headers is None:
            headers = CIMultiDict()

        json_payload = JsonPayload(obj, headers=headers)
        return self.append_payload(json_payload)

    def append_form(
        self,
        obj: Union[Sequence[Tuple[str, str]], Mapping[str, str]],
        headers: Optional[MultiMapping[str]] = None,
    ) -> Payload:
        """Append *obj* as a form urlencoded part."""
        assert isinstance(obj, (Sequence, Mapping))

        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Mapping):
            obj = list(obj.items())
        data = urlencode(obj, doseq=True)

        form_payload = StringPayload(
            data, headers=headers, content_type="application/x-www-form-urlencoded"
        )
        return self.append_payload(form_payload)

    @property
    def size(self) -> Optional[int]:
        """Total body size in bytes, or None when it cannot be known upfront."""
        total = 0
        for part, encoding, te_encoding in self._parts:
            # Encoded parts and parts of unknown size make the total unknown.
            if encoding or te_encoding or part.size is None:
                return None

            total += int(
                2 + len(self._boundary) + 2  # b'--'+self._boundary+b'\r\n'
                + part.size
                + len(part._binary_headers)
                + 2  # b'\r\n'
            )

        total += 2 + len(self._boundary) + 4  # b'--'+self._boundary+b'--\r\n'
        return total

    async def write(self, writer: Any, close_boundary: bool = True) -> None:
        """Write every part to *writer*, then optionally the closing boundary."""
        opening = b"--" + self._boundary + b"\r\n"
        for part, encoding, te_encoding in self._parts:
            await writer.write(opening)
            await writer.write(part._binary_headers)

            if not (encoding or te_encoding):
                await part.write(writer)
            else:
                # Route the part through an encoding/compressing writer.
                w = MultipartPayloadWriter(writer)
                if encoding:
                    w.enable_compression(encoding)
                if te_encoding:
                    w.enable_encoding(te_encoding)
                await part.write(w)  # type: ignore[arg-type]
                await w.write_eof()

            await writer.write(b"\r\n")

        if close_boundary:
            await writer.write(b"--" + self._boundary + b"--\r\n")

912 

913 

class MultipartPayloadWriter:
    """Writer wrapper that compresses and/or transfer-encodes what it is given."""

    def __init__(self, writer: Any) -> None:
        self._writer = writer
        self._encoding: Optional[str] = None
        self._compress: Optional[ZLibCompressor] = None
        # Bytes waiting to complete a 3-byte group (base64 only).
        self._encoding_buffer: Optional[bytearray] = None

    def enable_encoding(self, encoding: str) -> None:
        """Turn on base64 or quoted-printable; other values are ignored."""
        if encoding == "base64":
            self._encoding = encoding
            self._encoding_buffer = bytearray()
            return
        if encoding == "quoted-printable":
            self._encoding = "quoted-printable"

    def enable_compression(
        self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
    ) -> None:
        """Compress subsequently written data with the given encoding."""
        self._compress = ZLibCompressor(
            encoding=encoding,
            suppress_deflate_header=True,
            strategy=strategy,
        )

    async def write_eof(self) -> None:
        """Flush the compressor and any buffered base64 remainder."""
        compressor = self._compress
        if compressor is not None:
            tail = compressor.flush()
            if tail:
                # Drop the compressor first so the tail is not compressed again.
                self._compress = None
                await self.write(tail)

        if self._encoding == "base64" and self._encoding_buffer:
            await self._writer.write(base64.b64encode(self._encoding_buffer))

    async def write(self, chunk: bytes) -> None:
        """Compress (if enabled), encode (if enabled) and forward *chunk*."""
        if self._compress is not None and chunk:
            chunk = await self._compress.compress(chunk)
            if not chunk:
                return

        if self._encoding == "quoted-printable":
            await self._writer.write(binascii.b2a_qp(chunk))
            return
        if self._encoding != "base64":
            await self._writer.write(chunk)
            return

        pending = self._encoding_buffer
        assert pending is not None
        pending.extend(chunk)
        if not pending:
            return

        # Encode only whole 3-byte groups; keep the remainder for later.
        cut = (len(pending) // 3) * 3
        ready, self._encoding_buffer = pending[:cut], pending[cut:]
        if ready:
            await self._writer.write(base64.b64encode(ready))