import abc
import asyncio
import re
import string
from contextlib import suppress
from enum import IntEnum
from typing import (
    Any,
    ClassVar,
    Final,
    Generic,
    List,
    Literal,
    NamedTuple,
    Optional,
    Pattern,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from multidict import CIMultiDict, CIMultiDictProxy, istr
from yarl import URL

from . import hdrs
from .base_protocol import BaseProtocol
from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor
from .helpers import (
    _EXC_SENTINEL,
    DEBUG,
    EMPTY_BODY_METHODS,
    EMPTY_BODY_STATUS_CODES,
    NO_EXTENSIONS,
    BaseTimerContext,
    set_exception,
)
from .http_exceptions import (
    BadHttpMessage,
    BadHttpMethod,
    BadStatusLine,
    ContentEncodingError,
    ContentLengthError,
    InvalidHeader,
    InvalidURLError,
    LineTooLong,
    TransferEncodingError,
)
from .http_writer import HttpVersion, HttpVersion10
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import RawHeaders

__all__ = (
    "HeadersParser",
    "HttpParser",
    "HttpRequestParser",
    "HttpResponseParser",
    "RawRequestMessage",
    "RawResponseMessage",
)

_SEP = Literal[b"\r\n", b"\n"]

ASCIISET: Final[Set[str]] = set(string.printable)

# See https://www.rfc-editor.org/rfc/rfc9110.html#name-overview
# and https://www.rfc-editor.org/rfc/rfc9110.html#name-tokens
#
# method = token
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
#         "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
# token = 1*tchar
_TCHAR_SPECIALS: Final[str] = re.escape("!#$%&'*+-.^_`|~")
TOKENRE: Final[Pattern[str]] = re.compile(f"[0-9A-Za-z{_TCHAR_SPECIALS}]+")
VERSRE: Final[Pattern[str]] = re.compile(r"HTTP/(\d)\.(\d)", re.ASCII)
DIGITS: Final[Pattern[str]] = re.compile(r"\d+", re.ASCII)
HEXDIGITS: Final[Pattern[bytes]] = re.compile(rb"[0-9a-fA-F]+")


class RawRequestMessage(NamedTuple):
    method: str
    path: str
    version: HttpVersion
    headers: "CIMultiDictProxy[str]"
    raw_headers: RawHeaders
    should_close: bool
    compression: Optional[str]
    upgrade: bool
    chunked: bool
    url: URL


class RawResponseMessage(NamedTuple):
    version: HttpVersion
    code: int
    reason: str
    headers: CIMultiDictProxy[str]
    raw_headers: RawHeaders
    should_close: bool
    compression: Optional[str]
    upgrade: bool
    chunked: bool


_MsgT = TypeVar("_MsgT", RawRequestMessage, RawResponseMessage)


class ParseState(IntEnum):

    PARSE_NONE = 0
    PARSE_LENGTH = 1
    PARSE_CHUNKED = 2
    PARSE_UNTIL_EOF = 3


class ChunkState(IntEnum):
    PARSE_CHUNKED_SIZE = 0
    PARSE_CHUNKED_CHUNK = 1
    PARSE_CHUNKED_CHUNK_EOF = 2
    PARSE_MAYBE_TRAILERS = 3
    PARSE_TRAILERS = 4


class HeadersParser:
    def __init__(
        self,
        max_line_size: int = 8190,
        max_headers: int = 32768,
        max_field_size: int = 8190,
        lax: bool = False,
    ) -> None:
        self.max_line_size = max_line_size
        self.max_headers = max_headers
        self.max_field_size = max_field_size
        self._lax = lax

    def parse_headers(
        self, lines: List[bytes]
    ) -> Tuple["CIMultiDictProxy[str]", RawHeaders]:
        headers: CIMultiDict[str] = CIMultiDict()
        # note: "raw" does not mean inclusion of OWS before/after the field value
        raw_headers = []

        lines_idx = 0
        line = lines[lines_idx]
        line_count = len(lines)

        while line:
            # Parse initial header name : value pair.
            try:
                bname, bvalue = line.split(b":", 1)
            except ValueError:
                raise InvalidHeader(line) from None

            if len(bname) == 0:
                raise InvalidHeader(bname)

            # https://www.rfc-editor.org/rfc/rfc9112.html#section-5.1-2
            if {bname[0], bname[-1]} & {32, 9}:  # {" ", "\t"}
                raise InvalidHeader(line)

            bvalue = bvalue.lstrip(b" \t")
            if len(bname) > self.max_field_size:
                raise LineTooLong(
                    "request header name {}".format(
                        bname.decode("utf8", "backslashreplace")
                    ),
                    str(self.max_field_size),
                    str(len(bname)),
                )
            name = bname.decode("utf-8", "surrogateescape")
            if not TOKENRE.fullmatch(name):
                raise InvalidHeader(bname)

            header_length = len(bvalue)

            # next line
            lines_idx += 1
            line = lines[lines_idx]

            # consume continuation lines
            continuation = self._lax and line and line[0] in (32, 9)  # (' ', '\t')

            # Deprecated: https://www.rfc-editor.org/rfc/rfc9112.html#name-obsolete-line-folding
            if continuation:
                bvalue_lst = [bvalue]
                while continuation:
                    header_length += len(line)
                    if header_length > self.max_field_size:
                        raise LineTooLong(
                            "request header field {}".format(
                                bname.decode("utf8", "backslashreplace")
                            ),
                            str(self.max_field_size),
                            str(header_length),
                        )
                    bvalue_lst.append(line)

                    # next line
                    lines_idx += 1
                    if lines_idx < line_count:
                        line = lines[lines_idx]
                        if line:
                            continuation = line[0] in (32, 9)  # (' ', '\t')
                    else:
                        line = b""
                        break
                bvalue = b"".join(bvalue_lst)
            else:
                if header_length > self.max_field_size:
                    raise LineTooLong(
                        "request header field {}".format(
                            bname.decode("utf8", "backslashreplace")
                        ),
                        str(self.max_field_size),
                        str(header_length),
                    )

            bvalue = bvalue.strip(b" \t")
            value = bvalue.decode("utf-8", "surrogateescape")

            # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-5
            if "\n" in value or "\r" in value or "\x00" in value:
                raise InvalidHeader(bvalue)

            headers.add(name, value)
            raw_headers.append((bname, bvalue))

        return (CIMultiDictProxy(headers), tuple(raw_headers))

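# A minimal usage sketch of HeadersParser (illustrative only, not part of the
# public API surface): parse_headers() expects the header block already split
# into lines and terminated by an empty line, and returns an immutable
# case-insensitive multidict plus the raw name/value byte pairs.
#
#     parser = HeadersParser()
#     headers, raw_headers = parser.parse_headers(
#         [b"Host: example.com", b"Content-Type: text/plain", b""]
#     )
#     assert headers["host"] == "example.com"  # lookups are case-insensitive
#     assert raw_headers[0] == (b"Host", b"example.com")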

def _is_supported_upgrade(headers: CIMultiDictProxy[str]) -> bool:
    """Check if the upgrade header is supported."""
    return headers.get(hdrs.UPGRADE, "").lower() in {"tcp", "websocket"}


class HttpParser(abc.ABC, Generic[_MsgT]):
    lax: ClassVar[bool] = False

    def __init__(
        self,
        protocol: Optional[BaseProtocol] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        limit: int = 2**16,
        max_line_size: int = 8190,
        max_headers: int = 32768,
        max_field_size: int = 8190,
        timer: Optional[BaseTimerContext] = None,
        code: Optional[int] = None,
        method: Optional[str] = None,
        payload_exception: Optional[Type[BaseException]] = None,
        response_with_body: bool = True,
        read_until_eof: bool = False,
        auto_decompress: bool = True,
    ) -> None:
        self.protocol = protocol
        self.loop = loop
        self.max_line_size = max_line_size
        self.max_headers = max_headers
        self.max_field_size = max_field_size
        self.timer = timer
        self.code = code
        self.method = method
        self.payload_exception = payload_exception
        self.response_with_body = response_with_body
        self.read_until_eof = read_until_eof

        self._lines: List[bytes] = []
        self._tail = b""
        self._upgraded = False
        self._payload = None
        self._payload_parser: Optional[HttpPayloadParser] = None
        self._auto_decompress = auto_decompress
        self._limit = limit
        self._headers_parser = HeadersParser(
            max_line_size, max_headers, max_field_size, self.lax
        )

    @abc.abstractmethod
    def parse_message(self, lines: List[bytes]) -> _MsgT: ...

    @abc.abstractmethod
    def _is_chunked_te(self, te: str) -> bool: ...

    def feed_eof(self) -> Optional[_MsgT]:
        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None
        else:
            # try to extract partial message
            if self._tail:
                self._lines.append(self._tail)

            if self._lines:
                if self._lines[-1] != b"\r\n":
                    self._lines.append(b"")
                with suppress(Exception):
                    return self.parse_message(self._lines)
        return None

    def feed_data(
        self,
        data: bytes,
        SEP: _SEP = b"\r\n",
        EMPTY: bytes = b"",
        CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH,
        METH_CONNECT: str = hdrs.METH_CONNECT,
        SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1,
    ) -> Tuple[List[Tuple[_MsgT, StreamReader]], bool, bytes]:

        messages = []

        if self._tail:
            data, self._tail = self._tail + data, b""

        data_len = len(data)
        start_pos = 0
        loop = self.loop

        should_close = False
        while start_pos < data_len:

            # read HTTP message (request/response line + headers), \r\n\r\n
            # and split by lines
            if self._payload_parser is None and not self._upgraded:
                pos = data.find(SEP, start_pos)
                # consume \r\n
                if pos == start_pos and not self._lines:
                    start_pos = pos + len(SEP)
                    continue

                if pos >= start_pos:
                    if should_close:
                        raise BadHttpMessage("Data after `Connection: close`")

                    # line found
                    line = data[start_pos:pos]
                    if SEP == b"\n":  # For lax response parsing
                        line = line.rstrip(b"\r")
                    self._lines.append(line)
                    start_pos = pos + len(SEP)

                    # \r\n\r\n found
                    if self._lines[-1] == EMPTY:
                        try:
                            msg: _MsgT = self.parse_message(self._lines)
                        finally:
                            self._lines.clear()

                        def get_content_length() -> Optional[int]:
                            # payload length
                            length_hdr = msg.headers.get(CONTENT_LENGTH)
                            if length_hdr is None:
                                return None

                            # Shouldn't allow +/- or other number formats.
                            # https://www.rfc-editor.org/rfc/rfc9110#section-8.6-2
                            # msg.headers is already stripped of leading/trailing wsp
                            if not DIGITS.fullmatch(length_hdr):
                                raise InvalidHeader(CONTENT_LENGTH)

                            return int(length_hdr)

                        length = get_content_length()
                        # do not support old websocket spec
                        if SEC_WEBSOCKET_KEY1 in msg.headers:
                            raise InvalidHeader(SEC_WEBSOCKET_KEY1)

                        self._upgraded = msg.upgrade and _is_supported_upgrade(
                            msg.headers
                        )

                        method = getattr(msg, "method", self.method)
                        # code is only present on responses
                        code = getattr(msg, "code", 0)

                        assert self.protocol is not None
                        # calculate payload
                        empty_body = code in EMPTY_BODY_STATUS_CODES or bool(
                            method and method in EMPTY_BODY_METHODS
                        )
                        if not empty_body and (
                            ((length is not None and length > 0) or msg.chunked)
                            and not self._upgraded
                        ):
                            payload = StreamReader(
                                self.protocol,
                                timer=self.timer,
                                loop=loop,
                                limit=self._limit,
                            )
                            payload_parser = HttpPayloadParser(
                                payload,
                                length=length,
                                chunked=msg.chunked,
                                method=method,
                                compression=msg.compression,
                                code=self.code,
                                response_with_body=self.response_with_body,
                                auto_decompress=self._auto_decompress,
                                lax=self.lax,
                                headers_parser=self._headers_parser,
                            )
                            if not payload_parser.done:
                                self._payload_parser = payload_parser
                        elif method == METH_CONNECT:
                            assert isinstance(msg, RawRequestMessage)
                            payload = StreamReader(
                                self.protocol,
                                timer=self.timer,
                                loop=loop,
                                limit=self._limit,
                            )
                            self._upgraded = True
                            self._payload_parser = HttpPayloadParser(
                                payload,
                                method=msg.method,
                                compression=msg.compression,
                                auto_decompress=self._auto_decompress,
                                lax=self.lax,
                                headers_parser=self._headers_parser,
                            )
                        elif not empty_body and length is None and self.read_until_eof:
                            payload = StreamReader(
                                self.protocol,
                                timer=self.timer,
                                loop=loop,
                                limit=self._limit,
                            )
                            payload_parser = HttpPayloadParser(
                                payload,
                                length=length,
                                chunked=msg.chunked,
                                method=method,
                                compression=msg.compression,
                                code=self.code,
                                response_with_body=self.response_with_body,
                                auto_decompress=self._auto_decompress,
                                lax=self.lax,
                                headers_parser=self._headers_parser,
                            )
                            if not payload_parser.done:
                                self._payload_parser = payload_parser
                        else:
                            payload = EMPTY_PAYLOAD

                        messages.append((msg, payload))
                        should_close = msg.should_close
                else:
                    self._tail = data[start_pos:]
                    data = EMPTY
                    break

            # no parser, just store
            elif self._payload_parser is None and self._upgraded:
                assert not self._lines
                break

            # feed payload
            elif data and start_pos < data_len:
                assert not self._lines
                assert self._payload_parser is not None
                try:
                    eof, data = self._payload_parser.feed_data(data[start_pos:], SEP)
                except BaseException as underlying_exc:
                    reraised_exc = underlying_exc
                    if self.payload_exception is not None:
                        reraised_exc = self.payload_exception(str(underlying_exc))

                    set_exception(
                        self._payload_parser.payload,
                        reraised_exc,
                        underlying_exc,
                    )

                    eof = True
                    data = b""
                    if isinstance(
                        underlying_exc, (InvalidHeader, TransferEncodingError)
                    ):
                        raise

                if eof:
                    start_pos = 0
                    data_len = len(data)
                    self._payload_parser = None
                    continue
                else:
                    break

        if data and start_pos < data_len:
            data = data[start_pos:]
        else:
            data = EMPTY

        return messages, self._upgraded, data

    def parse_headers(
        self, lines: List[bytes]
    ) -> Tuple[
        "CIMultiDictProxy[str]", RawHeaders, Optional[bool], Optional[str], bool, bool
    ]:
504 """Parses RFC 5322 headers from a stream.
505
506 Line continuations are supported. Returns list of header name
507 and value pairs. Header name is in upper case.
508 """
        headers, raw_headers = self._headers_parser.parse_headers(lines)
        close_conn = None
        encoding = None
        upgrade = False
        chunked = False

        # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-6
        # https://www.rfc-editor.org/rfc/rfc9110.html#name-collected-abnf
        singletons = (
            hdrs.CONTENT_LENGTH,
            hdrs.CONTENT_LOCATION,
            hdrs.CONTENT_RANGE,
            hdrs.CONTENT_TYPE,
            hdrs.ETAG,
            hdrs.HOST,
            hdrs.MAX_FORWARDS,
            hdrs.SERVER,
            hdrs.TRANSFER_ENCODING,
            hdrs.USER_AGENT,
        )
        bad_hdr = next((h for h in singletons if len(headers.getall(h, ())) > 1), None)
        if bad_hdr is not None:
            raise BadHttpMessage(f"Duplicate '{bad_hdr}' header found.")

        # keep-alive
        conn = headers.get(hdrs.CONNECTION)
        if conn:
            v = conn.lower()
            if v == "close":
                close_conn = True
            elif v == "keep-alive":
                close_conn = False
            # https://www.rfc-editor.org/rfc/rfc9110.html#name-101-switching-protocols
            elif v == "upgrade" and headers.get(hdrs.UPGRADE):
                upgrade = True

        # encoding
        enc = headers.get(hdrs.CONTENT_ENCODING)
        if enc:
            enc = enc.lower()
            if enc in ("gzip", "deflate", "br"):
                encoding = enc

        # chunking
        te = headers.get(hdrs.TRANSFER_ENCODING)
        if te is not None:
            if self._is_chunked_te(te):
                chunked = True

            if hdrs.CONTENT_LENGTH in headers:
                raise BadHttpMessage(
                    "Transfer-Encoding can't be present with Content-Length",
                )

        return (headers, raw_headers, close_conn, encoding, upgrade, chunked)

    def set_upgraded(self, val: bool) -> None:
        """Set connection upgraded (to websocket) mode.

        :param bool val: new state.
        """
        self._upgraded = val


class HttpRequestParser(HttpParser[RawRequestMessage]):
574 """Read request status line.
575
576 Exception .http_exceptions.BadStatusLine
577 could be raised in case of any errors in status line.
578 Returns RawRequestMessage.
579 """

    def parse_message(self, lines: List[bytes]) -> RawRequestMessage:
        # request line
        line = lines[0].decode("utf-8", "surrogateescape")
        try:
            method, path, version = line.split(" ", maxsplit=2)
        except ValueError:
            raise BadHttpMethod(line) from None

        if len(path) > self.max_line_size:
            raise LineTooLong(
                "Status line is too long", str(self.max_line_size), str(len(path))
            )

        # method
        if not TOKENRE.fullmatch(method):
            raise BadHttpMethod(method)

        # version
        match = VERSRE.fullmatch(version)
        if match is None:
            raise BadStatusLine(line)
        version_o = HttpVersion(int(match.group(1)), int(match.group(2)))

        if method == "CONNECT":
            # authority-form,
            # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3
            url = URL.build(authority=path, encoded=True)
        elif path.startswith("/"):
            # origin-form,
            # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.1
            path_part, _hash_separator, url_fragment = path.partition("#")
            path_part, _question_mark_separator, qs_part = path_part.partition("?")

            # NOTE: `yarl.URL.build()` is used to mimic what the Cython-based
            # NOTE: parser does; otherwise the same HTTP Request-Line input
            # NOTE: would produce different `yarl.URL()` objects
            url = URL.build(
                path=path_part,
                query_string=qs_part,
                fragment=url_fragment,
                encoded=True,
            )
        elif path == "*" and method == "OPTIONS":
            # asterisk-form,
            url = URL(path, encoded=True)
        else:
            # absolute-form for proxy maybe,
            # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2
            url = URL(path, encoded=True)
            if url.scheme == "":
                # not absolute-form
                raise InvalidURLError(
                    path.encode(errors="surrogateescape").decode("latin1")
                )

        # read headers
        (
            headers,
            raw_headers,
            close,
            compression,
            upgrade,
            chunked,
        ) = self.parse_headers(lines[1:])

        if close is None:  # then the headers weren't set in the request
            if version_o <= HttpVersion10:  # HTTP 1.0 must ask to not close
                close = True
            else:  # HTTP 1.1 must ask to close.
                close = False

        return RawRequestMessage(
            method,
            path,
            version_o,
            headers,
            raw_headers,
            close,
            compression,
            upgrade,
            chunked,
            url,
        )

    def _is_chunked_te(self, te: str) -> bool:
        if te.rsplit(",", maxsplit=1)[-1].strip(" \t").lower() == "chunked":
            return True
        # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.3
        raise BadHttpMessage("Request has invalid `Transfer-Encoding`")

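# A minimal feeding sketch for HttpRequestParser (illustrative only; `protocol`
# and `loop` are assumed to be a connected BaseProtocol and a running event
# loop). feed_data() returns the parsed messages, an "upgraded" flag, and any
# unconsumed tail bytes:
#
#     parser = HttpRequestParser(protocol, loop, 2**16)
#     messages, upgraded, tail = parser.feed_data(
#         b"GET /path?q=1 HTTP/1.1\r\nHost: example.com\r\n\r\n"
#     )
#     msg, payload = messages[0]  # RawRequestMessage and its payload stream
#     # msg.method == "GET", msg.url.path == "/path", tail == b"",
#     # and payload is the empty payload since this request carries no body.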

class HttpResponseParser(HttpParser[RawResponseMessage]):
674 """Read response status line and headers.
675
676 BadStatusLine could be raised in case of any errors in status line.
677 Returns RawResponseMessage.
678 """

    # Lax mode should only be enabled on response parser.
    lax = not DEBUG

    def feed_data(
        self,
        data: bytes,
        SEP: Optional[_SEP] = None,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[List[Tuple[RawResponseMessage, StreamReader]], bool, bytes]:
        if SEP is None:
            SEP = b"\r\n" if DEBUG else b"\n"
        return super().feed_data(data, SEP, *args, **kwargs)

    def parse_message(self, lines: List[bytes]) -> RawResponseMessage:
        line = lines[0].decode("utf-8", "surrogateescape")
        try:
            version, status = line.split(maxsplit=1)
        except ValueError:
            raise BadStatusLine(line) from None

        try:
            status, reason = status.split(maxsplit=1)
        except ValueError:
            status = status.strip()
            reason = ""

        if len(reason) > self.max_line_size:
            raise LineTooLong(
                "Status line is too long", str(self.max_line_size), str(len(reason))
            )

        # version
        match = VERSRE.fullmatch(version)
        if match is None:
            raise BadStatusLine(line)
        version_o = HttpVersion(int(match.group(1)), int(match.group(2)))

        # The status code is a three-digit ASCII number, no padding
        if len(status) != 3 or not DIGITS.fullmatch(status):
            raise BadStatusLine(line)
        status_i = int(status)

        # read headers
        (
            headers,
            raw_headers,
            close,
            compression,
            upgrade,
            chunked,
        ) = self.parse_headers(lines[1:])

        if close is None:
            if version_o <= HttpVersion10:
                close = True
            # https://www.rfc-editor.org/rfc/rfc9112.html#name-message-body-length
            elif 100 <= status_i < 200 or status_i in {204, 304}:
                close = False
            elif hdrs.CONTENT_LENGTH in headers or hdrs.TRANSFER_ENCODING in headers:
                close = False
            else:
                # https://www.rfc-editor.org/rfc/rfc9112.html#section-6.3-2.8
                close = True

        return RawResponseMessage(
            version_o,
            status_i,
            reason.strip(),
            headers,
            raw_headers,
            close,
            compression,
            upgrade,
            chunked,
        )

    def _is_chunked_te(self, te: str) -> bool:
        # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.2
        return te.rsplit(",", maxsplit=1)[-1].strip(" \t").lower() == "chunked"


class HttpPayloadParser:
    def __init__(
        self,
        payload: StreamReader,
        length: Optional[int] = None,
        chunked: bool = False,
        compression: Optional[str] = None,
        code: Optional[int] = None,
        method: Optional[str] = None,
        response_with_body: bool = True,
        auto_decompress: bool = True,
        lax: bool = False,
        *,
        headers_parser: HeadersParser,
    ) -> None:
        self._length = 0
        self._type = ParseState.PARSE_UNTIL_EOF
        self._chunk = ChunkState.PARSE_CHUNKED_SIZE
        self._chunk_size = 0
        self._chunk_tail = b""
        self._auto_decompress = auto_decompress
        self._lax = lax
        self._headers_parser = headers_parser
        self._trailer_lines: list[bytes] = []
        self.done = False

        # payload decompression wrapper
        if response_with_body and compression and self._auto_decompress:
            real_payload: Union[StreamReader, DeflateBuffer] = DeflateBuffer(
                payload, compression
            )
        else:
            real_payload = payload

        # payload parser
        if not response_with_body:
            # don't parse payload if it's not expected to be received
            self._type = ParseState.PARSE_NONE
            real_payload.feed_eof()
            self.done = True
        elif chunked:
            self._type = ParseState.PARSE_CHUNKED
        elif length is not None:
            self._type = ParseState.PARSE_LENGTH
            self._length = length
            if self._length == 0:
                real_payload.feed_eof()
                self.done = True

        self.payload = real_payload

    def feed_eof(self) -> None:
        if self._type == ParseState.PARSE_UNTIL_EOF:
            self.payload.feed_eof()
        elif self._type == ParseState.PARSE_LENGTH:
            raise ContentLengthError(
                "Not enough data to satisfy content length header."
            )
        elif self._type == ParseState.PARSE_CHUNKED:
            raise TransferEncodingError(
                "Not enough data to satisfy transfer length header."
            )

    def feed_data(
        self, chunk: bytes, SEP: _SEP = b"\r\n", CHUNK_EXT: bytes = b";"
    ) -> Tuple[bool, bytes]:
        # Read specified amount of bytes
        if self._type == ParseState.PARSE_LENGTH:
            required = self._length
            chunk_len = len(chunk)

            if required >= chunk_len:
                self._length = required - chunk_len
                self.payload.feed_data(chunk, chunk_len)
                if self._length == 0:
                    self.payload.feed_eof()
                    return True, b""
            else:
                self._length = 0
                self.payload.feed_data(chunk[:required], required)
                self.payload.feed_eof()
                return True, chunk[required:]

        # Chunked transfer encoding parser
        elif self._type == ParseState.PARSE_CHUNKED:
            if self._chunk_tail:
                chunk = self._chunk_tail + chunk
                self._chunk_tail = b""

            while chunk:

                # read next chunk size
                if self._chunk == ChunkState.PARSE_CHUNKED_SIZE:
                    pos = chunk.find(SEP)
                    if pos >= 0:
                        i = chunk.find(CHUNK_EXT, 0, pos)
                        if i >= 0:
                            size_b = chunk[:i]  # strip chunk-extensions
                            # Verify no LF in the chunk-extension
                            if b"\n" in (ext := chunk[i:pos]):
                                exc = TransferEncodingError(
                                    f"Unexpected LF in chunk-extension: {ext!r}"
                                )
                                set_exception(self.payload, exc)
                                raise exc
                        else:
                            size_b = chunk[:pos]

                        if self._lax:  # Allow whitespace in lax mode.
                            size_b = size_b.strip()

                        if not re.fullmatch(HEXDIGITS, size_b):
                            exc = TransferEncodingError(
                                chunk[:pos].decode("ascii", "surrogateescape")
                            )
                            set_exception(self.payload, exc)
                            raise exc
                        size = int(bytes(size_b), 16)

                        chunk = chunk[pos + len(SEP) :]
                        if size == 0:  # eof marker
                            self._chunk = ChunkState.PARSE_TRAILERS
                            if self._lax and chunk.startswith(b"\r"):
                                chunk = chunk[1:]
                        else:
                            self._chunk = ChunkState.PARSE_CHUNKED_CHUNK
                            self._chunk_size = size
                            self.payload.begin_http_chunk_receiving()
                    else:
                        self._chunk_tail = chunk
                        return False, b""

                # read chunk and feed buffer
                if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK:
                    required = self._chunk_size
                    chunk_len = len(chunk)

                    if required > chunk_len:
                        self._chunk_size = required - chunk_len
                        self.payload.feed_data(chunk, chunk_len)
                        return False, b""
                    else:
                        self._chunk_size = 0
                        self.payload.feed_data(chunk[:required], required)
                        chunk = chunk[required:]
                        self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF
                        self.payload.end_http_chunk_receiving()

                # toss the CRLF at the end of the chunk
                if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF:
                    if self._lax and chunk.startswith(b"\r"):
                        chunk = chunk[1:]
                    if chunk[: len(SEP)] == SEP:
                        chunk = chunk[len(SEP) :]
                        self._chunk = ChunkState.PARSE_CHUNKED_SIZE
                    else:
                        self._chunk_tail = chunk
                        return False, b""

                if self._chunk == ChunkState.PARSE_TRAILERS:
                    pos = chunk.find(SEP)
                    if pos < 0:  # No line found
                        self._chunk_tail = chunk
                        return False, b""

                    line = chunk[:pos]
                    chunk = chunk[pos + len(SEP) :]
                    if SEP == b"\n":  # For lax response parsing
                        line = line.rstrip(b"\r")
                    self._trailer_lines.append(line)

                    # \r\n\r\n found, end of stream
                    if self._trailer_lines[-1] == b"":
                        # Headers and trailers are defined the same way,
                        # so we reuse the HeadersParser here.
                        try:
                            trailers, raw_trailers = self._headers_parser.parse_headers(
                                self._trailer_lines
                            )
                        finally:
                            self._trailer_lines.clear()
                        self.payload.feed_eof()
                        return True, chunk

        # Read all bytes until eof
        elif self._type == ParseState.PARSE_UNTIL_EOF:
            self.payload.feed_data(chunk, len(chunk))

        return False, b""

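# A minimal chunked-body sketch (illustrative only; `protocol` and `loop` are
# assumed to exist as above). A zero-size chunk followed by an empty trailer
# line terminates the body, so feed_data() reports EOF and hands back any
# leftover bytes:
#
#     payload = StreamReader(protocol, 2**16, loop=loop)
#     p = HttpPayloadParser(payload, chunked=True, headers_parser=HeadersParser())
#     done, tail = p.feed_data(b"4\r\nWiki\r\n0\r\n\r\n")
#     # done is True, tail == b"", and the payload stream now buffers b"Wiki"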

class DeflateBuffer:
954 """DeflateStream decompress stream and feed data into specified stream."""

    decompressor: Any

    def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
        self.out = out
        self.size = 0
        self.encoding = encoding
        self._started_decoding = False

        self.decompressor: Union[BrotliDecompressor, ZLibDecompressor]
        if encoding == "br":
            if not HAS_BROTLI:  # pragma: no cover
                raise ContentEncodingError(
                    "Can not decode content-encoding: brotli (br). "
                    "Please install `Brotli`"
                )
            self.decompressor = BrotliDecompressor()
        else:
            self.decompressor = ZLibDecompressor(encoding=encoding)

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: BaseException = _EXC_SENTINEL,
    ) -> None:
        set_exception(self.out, exc, exc_cause)

    def feed_data(self, chunk: bytes, size: int) -> None:
        if not size:
            return

        self.size += size

        # RFC1950
        # bits 0..3 = CM = 0b1000 = 8 = "deflate"
        # bits 4..7 = CINFO = 1..7 = window size.
        if (
            not self._started_decoding
            and self.encoding == "deflate"
            and chunk[0] & 0xF != 8
        ):
            # Change the decoder to decompress incorrectly compressed data
            # Actually we should issue a warning about non-RFC-compliant data.
            self.decompressor = ZLibDecompressor(
                encoding=self.encoding, suppress_deflate_header=True
            )

        try:
            chunk = self.decompressor.decompress_sync(chunk)
        except Exception:
            raise ContentEncodingError(
                "Can not decode content-encoding: %s" % self.encoding
            )

        self._started_decoding = True

        if chunk:
            self.out.feed_data(chunk, len(chunk))

    def feed_eof(self) -> None:
        chunk = self.decompressor.flush()

        if chunk or self.size > 0:
            self.out.feed_data(chunk, len(chunk))
            if self.encoding == "deflate" and not self.decompressor.eof:
                raise ContentEncodingError("deflate")

        self.out.feed_eof()

    def begin_http_chunk_receiving(self) -> None:
        self.out.begin_http_chunk_receiving()

    def end_http_chunk_receiving(self) -> None:
        self.out.end_http_chunk_receiving()


HttpRequestParserPy = HttpRequestParser
HttpResponseParserPy = HttpResponseParser
RawRequestMessagePy = RawRequestMessage
RawResponseMessagePy = RawResponseMessage

try:
    if not NO_EXTENSIONS:
        from ._http_parser import (  # type: ignore[import-not-found,no-redef]
            HttpRequestParser,
            HttpResponseParser,
            RawRequestMessage,
            RawResponseMessage,
        )

        HttpRequestParserC = HttpRequestParser
        HttpResponseParserC = HttpResponseParser
        RawRequestMessageC = RawRequestMessage
        RawResponseMessageC = RawResponseMessage
except ImportError:  # pragma: no cover
    pass