1import asyncio
2import codecs
3import contextlib
4import functools
5import io
6import re
7import sys
8import traceback
9import warnings
10from hashlib import md5, sha1, sha256
11from http.cookies import CookieError, Morsel, SimpleCookie
12from types import MappingProxyType, TracebackType
13from typing import (
14 TYPE_CHECKING,
15 Any,
16 Callable,
17 Dict,
18 Iterable,
19 List,
20 Mapping,
21 Optional,
22 Tuple,
23 Type,
24 Union,
25 cast,
26)
27
28import attr
29from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
30from yarl import URL
31
32from . import hdrs, helpers, http, multipart, payload
33from .abc import AbstractStreamWriter
34from .client_exceptions import (
35 ClientConnectionError,
36 ClientOSError,
37 ClientResponseError,
38 ContentTypeError,
39 InvalidURL,
40 ServerFingerprintMismatch,
41)
42from .compression_utils import HAS_BROTLI
43from .formdata import FormData
44from .helpers import (
45 BaseTimerContext,
46 BasicAuth,
47 HeadersMixin,
48 TimerNoop,
49 basicauth_from_netrc,
50 netrc_from_env,
51 noop,
52 reify,
53 set_exception,
54 set_result,
55)
56from .http import (
57 SERVER_SOFTWARE,
58 HttpVersion,
59 HttpVersion10,
60 HttpVersion11,
61 StreamWriter,
62)
63from .log import client_logger
64from .streams import StreamReader
65from .typedefs import (
66 DEFAULT_JSON_DECODER,
67 JSONDecoder,
68 LooseCookies,
69 LooseHeaders,
70 RawHeaders,
71)
72
73try:
74 import ssl
75 from ssl import SSLContext
76except ImportError: # pragma: no cover
77 ssl = None # type: ignore[assignment]
78 SSLContext = object # type: ignore[misc,assignment]
79
80
81__all__ = ("ClientRequest", "ClientResponse", "RequestInfo", "Fingerprint")
82
83
84if TYPE_CHECKING:
85 from .client import ClientSession
86 from .connector import Connection
87 from .tracing import Trace
88
89
# Matches any single character that is NOT in the listed token-character set.
# Despite its name it is used to reject request methods that contain
# non-token characters (see ClientRequest.__init__).
_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")
# Matches content types starting with "application/json" or
# "application/<prefix>+json".
json_re = re.compile(r"^application/(?:[\w.+-]+?\+)?json")
92
93
def _gen_default_accept_encoding() -> str:
    """Return the default ``Accept-Encoding`` value.

    ``br`` is appended to the list only when ``HAS_BROTLI`` is true.
    """
    if HAS_BROTLI:
        return "gzip, deflate, br"
    return "gzip, deflate"
96
97
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ContentDisposition:
    """Immutable holder for a parsed ``Content-Disposition`` header."""

    # Disposition type as returned by the header parser.
    type: Optional[str]
    # Read-only view of the header parameters.
    parameters: "MappingProxyType[str, str]"
    # File name derived from the parameters, if any.
    filename: Optional[str]
103
104
@attr.s(auto_attribs=True, frozen=True, slots=True)
class RequestInfo:
    """Immutable summary of a request: URL, method and headers."""

    url: URL
    method: str
    headers: "CIMultiDictProxy[str]"
    # Defaults to ``url`` when not supplied (see ``real_url_default``).
    real_url: URL = attr.ib()

    @real_url.default
    def real_url_default(self) -> URL:
        """Use ``url`` as the default value of ``real_url``."""
        return self.url
115
116
class Fingerprint:
    """Expected digest of a server certificate, checked against a transport.

    Only 32-byte (sha256) digests are accepted. 16-byte (md5) and 20-byte
    (sha1) lengths are recognised solely so they can be rejected with a
    dedicated error message.
    """

    HASHFUNC_BY_DIGESTLEN = {
        16: md5,
        20: sha1,
        32: sha256,
    }

    def __init__(self, fingerprint: bytes) -> None:
        """Pick the hash function matching the length of *fingerprint*.

        Raises:
            ValueError: if the length matches no known digest size, or
                matches md5 / sha1.
        """
        hashfunc = self.HASHFUNC_BY_DIGESTLEN.get(len(fingerprint))
        if not hashfunc:
            raise ValueError("fingerprint has invalid length")
        insecure = hashfunc is md5 or hashfunc is sha1
        if insecure:
            raise ValueError("md5 and sha1 are insecure and not supported. Use sha256.")
        self._hashfunc = hashfunc
        self._fingerprint = fingerprint

    @property
    def fingerprint(self) -> bytes:
        """The expected digest, exactly as passed to the constructor."""
        return self._fingerprint

    def check(self, transport: asyncio.Transport) -> None:
        """Compare the peer certificate digest with the expected fingerprint.

        Does nothing when the transport reports no ``sslcontext``.

        Raises:
            ServerFingerprintMismatch: if the digests differ.
        """
        if not transport.get_extra_info("sslcontext"):
            return
        ssl_object = transport.get_extra_info("ssl_object")
        peer_cert = ssl_object.getpeercert(binary_form=True)
        actual = self._hashfunc(peer_cert).digest()
        if actual == self._fingerprint:
            return
        host, port, *_ = transport.get_extra_info("peername")
        raise ServerFingerprintMismatch(self._fingerprint, actual, host, port)
149
150
# Types accepted for the ``ssl`` argument; checked in ``_merge_ssl_params``.
if ssl is not None:
    SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None))
else:  # pragma: no cover
    # The ssl module could not be imported: only bool / None are accepted.
    SSL_ALLOWED_TYPES = (bool, type(None))
155
156
def _merge_ssl_params(
    ssl: Union["SSLContext", bool, Fingerprint],
    verify_ssl: Optional[bool],
    ssl_context: Optional["SSLContext"],
    fingerprint: Optional[bytes],
) -> Union["SSLContext", bool, Fingerprint]:
    """Fold the deprecated TLS arguments into a single ``ssl`` value.

    Each deprecated argument that is in effect emits a ``DeprecationWarning``
    and replaces ``ssl``; it may only be combined with ``ssl=True`` (``None``
    is treated as ``True``).

    Raises:
        ValueError: if a deprecated argument is used while ``ssl`` already
            holds a non-default value.
        TypeError: if the resulting value is not one of ``SSL_ALLOWED_TYPES``.
    """
    exclusive_msg = (
        "verify_ssl, ssl_context, fingerprint and ssl "
        "parameters are mutually exclusive"
    )

    if ssl is None:
        ssl = True  # Double check for backwards compatibility

    # The checks run in sequence on purpose: once one deprecated argument has
    # replaced ``ssl``, any further one fails the ``ssl is not True`` test.
    if verify_ssl is not None and not verify_ssl:
        warnings.warn(
            "verify_ssl is deprecated, use ssl=False instead",
            DeprecationWarning,
            stacklevel=3,
        )
        if ssl is not True:
            raise ValueError(exclusive_msg)
        ssl = False

    if ssl_context is not None:
        warnings.warn(
            "ssl_context is deprecated, use ssl=context instead",
            DeprecationWarning,
            stacklevel=3,
        )
        if ssl is not True:
            raise ValueError(exclusive_msg)
        ssl = ssl_context

    if fingerprint is not None:
        warnings.warn(
            "fingerprint is deprecated, use ssl=Fingerprint(fingerprint) instead",
            DeprecationWarning,
            stacklevel=3,
        )
        if ssl is not True:
            raise ValueError(exclusive_msg)
        ssl = Fingerprint(fingerprint)

    if isinstance(ssl, SSL_ALLOWED_TYPES):
        return ssl
    raise TypeError(
        "ssl should be SSLContext, bool, Fingerprint or None, "
        "got {!r} instead.".format(ssl)
    )
210
211
@attr.s(auto_attribs=True, slots=True, frozen=True)
class ConnectionKey:
    """Immutable key describing the endpoint, TLS and proxy settings of a request."""

    # the key should contain an information about used proxy / TLS
    # to prevent reusing wrong connections from a pool
    host: str
    port: Optional[int]
    is_ssl: bool
    ssl: Union[SSLContext, bool, Fingerprint]
    proxy: Optional[URL]
    proxy_auth: Optional[BasicAuth]
    # None when there are no proxy headers.
    proxy_headers_hash: Optional[int]  # hash(CIMultiDict)
223
224
225def _is_expected_content_type(
226 response_content_type: str, expected_content_type: str
227) -> bool:
228 if expected_content_type == "application/json":
229 return json_re.match(response_content_type) is not None
230 return expected_content_type in response_content_type
231
232
233class ClientRequest:
234 GET_METHODS = {
235 hdrs.METH_GET,
236 hdrs.METH_HEAD,
237 hdrs.METH_OPTIONS,
238 hdrs.METH_TRACE,
239 }
240 POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
241 ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE})
242
243 DEFAULT_HEADERS = {
244 hdrs.ACCEPT: "*/*",
245 hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(),
246 }
247
248 body = b""
249 auth = None
250 response = None
251
252 __writer = None # async task for streaming data
253 _continue = None # waiter future for '100 Continue' response
254
    # N.B.
    # Adding a __del__ method that closes self._writer doesn't make sense:
    # the writer task runs an instance method, so it keeps a reference to self.
    # The finalizer will not be called until the writer has finished.
259
260 def __init__(
261 self,
262 method: str,
263 url: URL,
264 *,
265 params: Optional[Mapping[str, str]] = None,
266 headers: Optional[LooseHeaders] = None,
267 skip_auto_headers: Iterable[str] = frozenset(),
268 data: Any = None,
269 cookies: Optional[LooseCookies] = None,
270 auth: Optional[BasicAuth] = None,
271 version: http.HttpVersion = http.HttpVersion11,
272 compress: Optional[str] = None,
273 chunked: Optional[bool] = None,
274 expect100: bool = False,
275 loop: Optional[asyncio.AbstractEventLoop] = None,
276 response_class: Optional[Type["ClientResponse"]] = None,
277 proxy: Optional[URL] = None,
278 proxy_auth: Optional[BasicAuth] = None,
279 timer: Optional[BaseTimerContext] = None,
280 session: Optional["ClientSession"] = None,
281 ssl: Union[SSLContext, bool, Fingerprint] = True,
282 proxy_headers: Optional[LooseHeaders] = None,
283 traces: Optional[List["Trace"]] = None,
284 trust_env: bool = False,
285 server_hostname: Optional[str] = None,
286 ):
287 if loop is None:
288 loop = asyncio.get_event_loop()
289
290 match = _CONTAINS_CONTROL_CHAR_RE.search(method)
291 if match:
292 raise ValueError(
293 f"Method cannot contain non-token characters {method!r} "
294 "(found at least {match.group()!r})"
295 )
296
297 assert isinstance(url, URL), url
298 assert isinstance(proxy, (URL, type(None))), proxy
299 # FIXME: session is None in tests only, need to fix tests
300 # assert session is not None
301 self._session = cast("ClientSession", session)
302 if params:
303 q = MultiDict(url.query)
304 url2 = url.with_query(params)
305 q.extend(url2.query)
306 url = url.with_query(q)
307 self.original_url = url
308 self.url = url.with_fragment(None)
309 self.method = method.upper()
310 self.chunked = chunked
311 self.compress = compress
312 self.loop = loop
313 self.length = None
314 if response_class is None:
315 real_response_class = ClientResponse
316 else:
317 real_response_class = response_class
318 self.response_class: Type[ClientResponse] = real_response_class
319 self._timer = timer if timer is not None else TimerNoop()
320 self._ssl = ssl if ssl is not None else True
321 self.server_hostname = server_hostname
322
323 if loop.get_debug():
324 self._source_traceback = traceback.extract_stack(sys._getframe(1))
325
326 self.update_version(version)
327 self.update_host(url)
328 self.update_headers(headers)
329 self.update_auto_headers(skip_auto_headers)
330 self.update_cookies(cookies)
331 self.update_content_encoding(data)
332 self.update_auth(auth, trust_env)
333 self.update_proxy(proxy, proxy_auth, proxy_headers)
334
335 self.update_body_from_data(data)
336 if data is not None or self.method not in self.GET_METHODS:
337 self.update_transfer_encoding()
338 self.update_expect_continue(expect100)
339 if traces is None:
340 traces = []
341 self._traces = traces
342
    def __reset_writer(self, _: object = None) -> None:
        """Done-callback for the writer task: forget the stored task."""
        self.__writer = None
345
346 @property
347 def _writer(self) -> Optional["asyncio.Task[None]"]:
348 return self.__writer
349
350 @_writer.setter
351 def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
352 if self.__writer is not None:
353 self.__writer.remove_done_callback(self.__reset_writer)
354 self.__writer = writer
355 if writer is not None:
356 writer.add_done_callback(self.__reset_writer)
357
358 def is_ssl(self) -> bool:
359 return self.url.scheme in ("https", "wss")
360
    @property
    def ssl(self) -> Union["SSLContext", bool, Fingerprint]:
        """The TLS setting stored by ``__init__`` (``True`` when ``None`` was given)."""
        return self._ssl
364
365 @property
366 def connection_key(self) -> ConnectionKey:
367 proxy_headers = self.proxy_headers
368 if proxy_headers:
369 h: Optional[int] = hash(tuple((k, v) for k, v in proxy_headers.items()))
370 else:
371 h = None
372 return ConnectionKey(
373 self.host,
374 self.port,
375 self.is_ssl(),
376 self.ssl,
377 self.proxy,
378 self.proxy_auth,
379 h,
380 )
381
382 @property
383 def host(self) -> str:
384 ret = self.url.raw_host
385 assert ret is not None
386 return ret
387
    @property
    def port(self) -> Optional[int]:
        """The port of the request URL, as reported by ``self.url.port``."""
        return self.url.port
391
392 @property
393 def request_info(self) -> RequestInfo:
394 headers: CIMultiDictProxy[str] = CIMultiDictProxy(self.headers)
395 return RequestInfo(self.url, self.method, headers, self.original_url)
396
397 def update_host(self, url: URL) -> None:
398 """Update destination host, port and connection type (ssl)."""
399 # get host/port
400 if not url.raw_host:
401 raise InvalidURL(url)
402
403 # basic auth info
404 username, password = url.user, url.password
405 if username:
406 self.auth = helpers.BasicAuth(username, password or "")
407
408 def update_version(self, version: Union[http.HttpVersion, str]) -> None:
409 """Convert request version to two elements tuple.
410
411 parser HTTP version '1.1' => (1, 1)
412 """
413 if isinstance(version, str):
414 v = [part.strip() for part in version.split(".", 1)]
415 try:
416 version = http.HttpVersion(int(v[0]), int(v[1]))
417 except ValueError:
418 raise ValueError(
419 f"Can not parse http version number: {version}"
420 ) from None
421 self.version = version
422
423 def update_headers(self, headers: Optional[LooseHeaders]) -> None:
424 """Update request headers."""
425 self.headers: CIMultiDict[str] = CIMultiDict()
426
427 # add host
428 netloc = cast(str, self.url.raw_host)
429 if helpers.is_ipv6_address(netloc):
430 netloc = f"[{netloc}]"
431 # See https://github.com/aio-libs/aiohttp/issues/3636.
432 netloc = netloc.rstrip(".")
433 if self.url.port is not None and not self.url.is_default_port():
434 netloc += ":" + str(self.url.port)
435 self.headers[hdrs.HOST] = netloc
436
437 if headers:
438 if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
439 headers = headers.items() # type: ignore[assignment]
440
441 for key, value in headers: # type: ignore[misc]
442 # A special case for Host header
443 if key.lower() == "host":
444 self.headers[key] = value
445 else:
446 self.headers.add(key, value)
447
448 def update_auto_headers(self, skip_auto_headers: Iterable[str]) -> None:
449 self.skip_auto_headers = CIMultiDict(
450 (hdr, None) for hdr in sorted(skip_auto_headers)
451 )
452 used_headers = self.headers.copy()
453 used_headers.extend(self.skip_auto_headers) # type: ignore[arg-type]
454
455 for hdr, val in self.DEFAULT_HEADERS.items():
456 if hdr not in used_headers:
457 self.headers.add(hdr, val)
458
459 if hdrs.USER_AGENT not in used_headers:
460 self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE
461
462 def update_cookies(self, cookies: Optional[LooseCookies]) -> None:
463 """Update request cookies header."""
464 if not cookies:
465 return
466
467 c = SimpleCookie()
468 if hdrs.COOKIE in self.headers:
469 c.load(self.headers.get(hdrs.COOKIE, ""))
470 del self.headers[hdrs.COOKIE]
471
472 if isinstance(cookies, Mapping):
473 iter_cookies = cookies.items()
474 else:
475 iter_cookies = cookies # type: ignore[assignment]
476 for name, value in iter_cookies:
477 if isinstance(value, Morsel):
478 # Preserve coded_value
479 mrsl_val = value.get(value.key, Morsel())
480 mrsl_val.set(value.key, value.value, value.coded_value)
481 c[name] = mrsl_val
482 else:
483 c[name] = value # type: ignore[assignment]
484
485 self.headers[hdrs.COOKIE] = c.output(header="", sep=";").strip()
486
487 def update_content_encoding(self, data: Any) -> None:
488 """Set request content encoding."""
489 if data is None:
490 return
491
492 enc = self.headers.get(hdrs.CONTENT_ENCODING, "").lower()
493 if enc:
494 if self.compress:
495 raise ValueError(
496 "compress can not be set " "if Content-Encoding header is set"
497 )
498 elif self.compress:
499 if not isinstance(self.compress, str):
500 self.compress = "deflate"
501 self.headers[hdrs.CONTENT_ENCODING] = self.compress
502 self.chunked = True # enable chunked, no need to deal with length
503
504 def update_transfer_encoding(self) -> None:
505 """Analyze transfer-encoding header."""
506 te = self.headers.get(hdrs.TRANSFER_ENCODING, "").lower()
507
508 if "chunked" in te:
509 if self.chunked:
510 raise ValueError(
511 "chunked can not be set "
512 'if "Transfer-Encoding: chunked" header is set'
513 )
514
515 elif self.chunked:
516 if hdrs.CONTENT_LENGTH in self.headers:
517 raise ValueError(
518 "chunked can not be set " "if Content-Length header is set"
519 )
520
521 self.headers[hdrs.TRANSFER_ENCODING] = "chunked"
522 else:
523 if hdrs.CONTENT_LENGTH not in self.headers:
524 self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
525
526 def update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> None:
527 """Set basic auth."""
528 if auth is None:
529 auth = self.auth
530 if auth is None and trust_env and self.url.host is not None:
531 netrc_obj = netrc_from_env()
532 with contextlib.suppress(LookupError):
533 auth = basicauth_from_netrc(netrc_obj, self.url.host)
534 if auth is None:
535 return
536
537 if not isinstance(auth, helpers.BasicAuth):
538 raise TypeError("BasicAuth() tuple is required instead")
539
540 self.headers[hdrs.AUTHORIZATION] = auth.encode()
541
542 def update_body_from_data(self, body: Any) -> None:
543 if body is None:
544 return
545
546 # FormData
547 if isinstance(body, FormData):
548 body = body()
549
550 try:
551 body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
552 except payload.LookupError:
553 body = FormData(body)()
554
555 self.body = body
556
557 # enable chunked encoding if needed
558 if not self.chunked:
559 if hdrs.CONTENT_LENGTH not in self.headers:
560 size = body.size
561 if size is None:
562 self.chunked = True
563 else:
564 if hdrs.CONTENT_LENGTH not in self.headers:
565 self.headers[hdrs.CONTENT_LENGTH] = str(size)
566
567 # copy payload headers
568 assert body.headers
569 for (key, value) in body.headers.items():
570 if key in self.headers:
571 continue
572 if key in self.skip_auto_headers:
573 continue
574 self.headers[key] = value
575
576 def update_expect_continue(self, expect: bool = False) -> None:
577 if expect:
578 self.headers[hdrs.EXPECT] = "100-continue"
579 elif self.headers.get(hdrs.EXPECT, "").lower() == "100-continue":
580 expect = True
581
582 if expect:
583 self._continue = self.loop.create_future()
584
585 def update_proxy(
586 self,
587 proxy: Optional[URL],
588 proxy_auth: Optional[BasicAuth],
589 proxy_headers: Optional[LooseHeaders],
590 ) -> None:
591 if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
592 raise ValueError("proxy_auth must be None or BasicAuth() tuple")
593 self.proxy = proxy
594 self.proxy_auth = proxy_auth
595 self.proxy_headers = proxy_headers
596
597 def keep_alive(self) -> bool:
598 if self.version < HttpVersion10:
599 # keep alive not supported at all
600 return False
601 if self.version == HttpVersion10:
602 if self.headers.get(hdrs.CONNECTION) == "keep-alive":
603 return True
604 else: # no headers means we close for Http 1.0
605 return False
606 elif self.headers.get(hdrs.CONNECTION) == "close":
607 return False
608
609 return True
610
    async def write_bytes(
        self, writer: AbstractStreamWriter, conn: "Connection"
    ) -> None:
        """Write the request body to *writer*, reporting failures on the protocol.

        Errors raised while writing are not propagated to the caller; they
        are passed to ``set_exception`` on ``conn.protocol`` instead.
        """
        # 100 response
        if self._continue is not None:
            # Wait for the "100 Continue" waiter before sending the body;
            # give up silently if cancelled while waiting.
            try:
                await writer.drain()
                await self._continue
            except asyncio.CancelledError:
                return

        protocol = conn.protocol
        assert protocol is not None
        try:
            if isinstance(self.body, payload.Payload):
                await self.body.write(writer)
            else:
                # A single bytes-like body is wrapped so it can be iterated
                # like any other iterable of chunks.
                if isinstance(self.body, (bytes, bytearray)):
                    self.body = (self.body,)  # type: ignore[assignment]

                for chunk in self.body:
                    await writer.write(chunk)  # type: ignore[arg-type]
        except OSError as underlying_exc:
            reraised_exc = underlying_exc

            # Anything with an errno, or any OSError that is not a timeout,
            # is wrapped in ClientOSError; an errno-less timeout is passed
            # through unchanged.
            exc_is_not_timeout = underlying_exc.errno is not None or not isinstance(
                underlying_exc, asyncio.TimeoutError
            )
            if exc_is_not_timeout:
                reraised_exc = ClientOSError(
                    underlying_exc.errno,
                    f"Can not write request body for {self.url !s}",
                )

            set_exception(protocol, reraised_exc, underlying_exc)
        except asyncio.CancelledError:
            # Cancelled mid-body: still terminate the body stream.
            await writer.write_eof()
        except Exception as underlying_exc:
            set_exception(
                protocol,
                ClientConnectionError(
                    f"Failed to send bytes into the underlying connection {conn !s}",
                ),
                underlying_exc,
            )
        else:
            # Body fully written: finish it and start the protocol timeout.
            await writer.write_eof()
            protocol.start_timeout()
660
    async def send(self, conn: "Connection") -> "ClientResponse":
        """Write the request line and headers, start the body writer task.

        Returns the response object created from ``self.response_class``; the
        body is written concurrently by the task stored in ``self._writer``.
        """
        # Specify request target:
        # - CONNECT request must send authority form URI
        # - not CONNECT proxy must send absolute form URI
        # - most common is origin form URI
        if self.method == hdrs.METH_CONNECT:
            connect_host = self.url.raw_host
            assert connect_host is not None
            if helpers.is_ipv6_address(connect_host):
                connect_host = f"[{connect_host}]"
            path = f"{connect_host}:{self.url.port}"
        elif self.proxy and not self.is_ssl():
            path = str(self.url)
        else:
            path = self.url.raw_path
            if self.url.raw_query_string:
                path += "?" + self.url.raw_query_string

        protocol = conn.protocol
        assert protocol is not None
        # Method and URL are bound up front so the trace callbacks receive them.
        writer = StreamWriter(
            protocol,
            self.loop,
            on_chunk_sent=functools.partial(
                self._on_chunk_request_sent, self.method, self.url
            ),
            on_headers_sent=functools.partial(
                self._on_headers_request_sent, self.method, self.url
            ),
        )

        if self.compress:
            writer.enable_compression(self.compress)

        # NOTE(review): this tests ``is not None``, so an explicit
        # ``chunked=False`` also enables chunking here - confirm this is intended.
        if self.chunked is not None:
            writer.enable_chunking()

        # set default content-type
        if (
            self.method in self.POST_METHODS
            and hdrs.CONTENT_TYPE not in self.skip_auto_headers
            and hdrs.CONTENT_TYPE not in self.headers
        ):
            self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"

        # set the connection header
        connection = self.headers.get(hdrs.CONNECTION)
        if not connection:
            # Only state the non-default: keep-alive for HTTP/1.0,
            # close for HTTP/1.1.
            if self.keep_alive():
                if self.version == HttpVersion10:
                    connection = "keep-alive"
            else:
                if self.version == HttpVersion11:
                    connection = "close"

        if connection is not None:
            self.headers[hdrs.CONNECTION] = connection

        # status + headers
        status_line = "{0} {1} HTTP/{v.major}.{v.minor}".format(
            self.method, path, v=self.version
        )
        await writer.write_headers(status_line, self.headers)

        # The body is written in a separate task so the response can be
        # created and returned right away.
        self._writer = self.loop.create_task(self.write_bytes(writer, conn))

        response_class = self.response_class
        assert response_class is not None
        self.response = response_class(
            self.method,
            self.original_url,
            writer=self._writer,
            continue100=self._continue,
            timer=self._timer,
            request_info=self.request_info,
            traces=self._traces,
            loop=self.loop,
            session=self._session,
        )
        return self.response
741
742 async def close(self) -> None:
743 if self._writer is not None:
744 with contextlib.suppress(asyncio.CancelledError):
745 await self._writer
746
747 def terminate(self) -> None:
748 if self._writer is not None:
749 if not self.loop.is_closed():
750 self._writer.cancel()
751 self._writer.remove_done_callback(self.__reset_writer)
752 self._writer = None
753
    async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
        """Notify every trace in ``self._traces`` that a body chunk was sent."""
        for trace in self._traces:
            await trace.send_request_chunk_sent(method, url, chunk)
757
    async def _on_headers_request_sent(
        self, method: str, url: URL, headers: "CIMultiDict[str]"
    ) -> None:
        """Notify every trace in ``self._traces`` that the headers were sent."""
        for trace in self._traces:
            await trace.send_request_headers(method, url, headers)
763
764
765class ClientResponse(HeadersMixin):
766
767 # Some of these attributes are None when created,
768 # but will be set by the start() method.
769 # As the end user will likely never see the None values, we cheat the types below.
770 # from the Status-Line of the response
771 version: Optional[HttpVersion] = None # HTTP-Version
772 status: int = None # type: ignore[assignment] # Status-Code
773 reason: Optional[str] = None # Reason-Phrase
774
775 content: StreamReader = None # type: ignore[assignment] # Payload stream
776 _headers: CIMultiDictProxy[str] = None # type: ignore[assignment]
777 _raw_headers: RawHeaders = None # type: ignore[assignment]
778
779 _connection = None # current connection
780 _source_traceback: Optional[traceback.StackSummary] = None
781 # set up by ClientRequest after ClientResponse object creation
782 # post-init stage allows to not change ctor signature
783 _closed = True # to allow __del__ for non-initialized properly response
784 _released = False
785 __writer = None
786
787 def __init__(
788 self,
789 method: str,
790 url: URL,
791 *,
792 writer: "asyncio.Task[None]",
793 continue100: Optional["asyncio.Future[bool]"],
794 timer: BaseTimerContext,
795 request_info: RequestInfo,
796 traces: List["Trace"],
797 loop: asyncio.AbstractEventLoop,
798 session: "ClientSession",
799 ) -> None:
800 assert isinstance(url, URL)
801
802 self.method = method
803 self.cookies = SimpleCookie()
804
805 self._real_url = url
806 self._url = url.with_fragment(None)
807 self._body: Any = None
808 self._writer: Optional[asyncio.Task[None]] = writer
809 self._continue = continue100 # None by default
810 self._closed = True
811 self._history: Tuple[ClientResponse, ...] = ()
812 self._request_info = request_info
813 self._timer = timer if timer is not None else TimerNoop()
814 self._cache: Dict[str, Any] = {}
815 self._traces = traces
816 self._loop = loop
817 # store a reference to session #1985
818 self._session: Optional[ClientSession] = session
819 # Save reference to _resolve_charset, so that get_encoding() will still
820 # work after the response has finished reading the body.
821 if session is None:
822 # TODO: Fix session=None in tests (see ClientRequest.__init__).
823 self._resolve_charset: Callable[
824 ["ClientResponse", bytes], str
825 ] = lambda *_: "utf-8"
826 else:
827 self._resolve_charset = session._resolve_charset
828 if loop.get_debug():
829 self._source_traceback = traceback.extract_stack(sys._getframe(1))
830
    def __reset_writer(self, _: object = None) -> None:
        """Done-callback for the writer task: forget the stored task."""
        self.__writer = None
833
834 @property
835 def _writer(self) -> Optional["asyncio.Task[None]"]:
836 return self.__writer
837
838 @_writer.setter
839 def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
840 if self.__writer is not None:
841 self.__writer.remove_done_callback(self.__reset_writer)
842 self.__writer = writer
843 if writer is not None:
844 writer.add_done_callback(self.__reset_writer)
845
    @reify
    def url(self) -> URL:
        """The response URL with the fragment removed."""
        return self._url
849
    @reify
    def url_obj(self) -> URL:
        """Deprecated alias of ``url``; emits a ``DeprecationWarning``."""
        warnings.warn("Deprecated, use .url #1654", DeprecationWarning, stacklevel=2)
        return self._url
854
    @reify
    def real_url(self) -> URL:
        """The URL exactly as passed to the constructor (fragment included)."""
        return self._real_url
858
859 @reify
860 def host(self) -> str:
861 assert self._url.host is not None
862 return self._url.host
863
    @reify
    def headers(self) -> "CIMultiDictProxy[str]":
        """Response headers; set by ``start()``."""
        return self._headers
867
    @reify
    def raw_headers(self) -> RawHeaders:
        """Raw response headers; set by ``start()``."""
        return self._raw_headers
871
    @reify
    def request_info(self) -> RequestInfo:
        """The ``RequestInfo`` passed to the constructor."""
        return self._request_info
875
876 @reify
877 def content_disposition(self) -> Optional[ContentDisposition]:
878 raw = self._headers.get(hdrs.CONTENT_DISPOSITION)
879 if raw is None:
880 return None
881 disposition_type, params_dct = multipart.parse_content_disposition(raw)
882 params = MappingProxyType(params_dct)
883 filename = multipart.content_disposition_filename(params)
884 return ContentDisposition(disposition_type, params, filename)
885
    def __del__(self, _warnings: Any = warnings) -> None:
        """Release a connection still held by a response that was never closed.

        ``_warnings`` is bound as a default argument so the module stays
        reachable from the finalizer.
        """
        if self._closed:
            return

        if self._connection is not None:
            self._connection.release()
            self._cleanup_writer()

            # The leak is only reported when the loop is in debug mode.
            if self._loop.get_debug():
                kwargs = {"source": self}
                _warnings.warn(f"Unclosed response {self!r}", ResourceWarning, **kwargs)
                context = {"client_response": self, "message": "Unclosed response"}
                if self._source_traceback:
                    context["source_traceback"] = self._source_traceback
                self._loop.call_exception_handler(context)
901
902 def __repr__(self) -> str:
903 out = io.StringIO()
904 ascii_encodable_url = str(self.url)
905 if self.reason:
906 ascii_encodable_reason = self.reason.encode(
907 "ascii", "backslashreplace"
908 ).decode("ascii")
909 else:
910 ascii_encodable_reason = "None"
911 print(
912 "<ClientResponse({}) [{} {}]>".format(
913 ascii_encodable_url, self.status, ascii_encodable_reason
914 ),
915 file=out,
916 )
917 print(self.headers, file=out)
918 return out.getvalue()
919
    @property
    def connection(self) -> Optional["Connection"]:
        """The connection currently attached to the response, if any."""
        return self._connection
923
    @reify
    def history(self) -> Tuple["ClientResponse", ...]:
        """A sequence of responses, if redirects occurred."""
        return self._history
928
929 @reify
930 def links(self) -> "MultiDictProxy[MultiDictProxy[Union[str, URL]]]":
931 links_str = ", ".join(self.headers.getall("link", []))
932
933 if not links_str:
934 return MultiDictProxy(MultiDict())
935
936 links: MultiDict[MultiDictProxy[Union[str, URL]]] = MultiDict()
937
938 for val in re.split(r",(?=\s*<)", links_str):
939 match = re.match(r"\s*<(.*)>(.*)", val)
940 if match is None: # pragma: no cover
941 # the check exists to suppress mypy error
942 continue
943 url, params_str = match.groups()
944 params = params_str.split(";")[1:]
945
946 link: MultiDict[Union[str, URL]] = MultiDict()
947
948 for param in params:
949 match = re.match(r"^\s*(\S*)\s*=\s*(['\"]?)(.*?)(\2)\s*$", param, re.M)
950 if match is None: # pragma: no cover
951 # the check exists to suppress mypy error
952 continue
953 key, _, value, _ = match.groups()
954
955 link.add(key, value)
956
957 key = link.get("rel", url)
958
959 link.add("url", self.url.join(URL(url)))
960
961 links.add(str(key), MultiDictProxy(link))
962
963 return MultiDictProxy(links)
964
    async def start(self, connection: "Connection") -> "ClientResponse":
        """Start response processing.

        Reads messages from the connection's protocol until a final one
        arrives, then fills in status, headers, payload and cookies.

        Raises:
            ClientResponseError: if the protocol reports an
                ``HttpProcessingError`` while reading.
        """
        self._closed = False
        self._protocol = connection.protocol
        self._connection = connection

        with self._timer:
            while True:
                # read response
                try:
                    protocol = self._protocol
                    message, payload = await protocol.read()  # type: ignore[union-attr]
                except http.HttpProcessingError as exc:
                    raise ClientResponseError(
                        self.request_info,
                        self.history,
                        status=exc.code,
                        message=exc.message,
                        headers=exc.headers,
                    ) from exc

                # Stop on anything outside 100-199, and also on 101.
                if message.code < 100 or message.code > 199 or message.code == 101:
                    break

                # An interim response resolves the "100 Continue" waiter, if any.
                if self._continue is not None:
                    set_result(self._continue, True)
                    self._continue = None

        # payload eof handler
        payload.on_eof(self._response_eof)

        # response status
        self.version = message.version
        self.status = message.code
        self.reason = message.reason

        # headers
        self._headers = message.headers  # type is CIMultiDictProxy
        self._raw_headers = message.raw_headers  # type is Tuple[bytes, bytes]

        # payload
        self.content = payload

        # cookies
        for hdr in self.headers.getall(hdrs.SET_COOKIE, ()):
            try:
                self.cookies.load(hdr)
            except CookieError as exc:
                # A malformed Set-Cookie header is logged, not raised.
                client_logger.warning("Can not load response cookies: %s", exc)
        return self
1015
1016 def _response_eof(self) -> None:
1017 if self._closed:
1018 return
1019
1020 # protocol could be None because connection could be detached
1021 protocol = self._connection and self._connection.protocol
1022 if protocol is not None and protocol.upgraded:
1023 return
1024
1025 self._closed = True
1026 self._cleanup_writer()
1027 self._release_connection()
1028
    @property
    def closed(self) -> bool:
        """Whether the response is currently marked closed."""
        return self._closed
1032
1033 def close(self) -> None:
1034 if not self._released:
1035 self._notify_content()
1036
1037 self._closed = True
1038 if self._loop is None or self._loop.is_closed():
1039 return
1040
1041 self._cleanup_writer()
1042 if self._connection is not None:
1043 self._connection.close()
1044 self._connection = None
1045
    def release(self) -> Any:
        """Mark the response closed, cancel the writer and release the connection.

        Returns the result of ``noop()``.
        """
        # NOTE(review): noop() presumably returns an awaitable so that
        # ``await resp.release()`` keeps working - confirm in helpers.
        if not self._released:
            self._notify_content()

        self._closed = True

        self._cleanup_writer()
        self._release_connection()
        return noop()
1055
    @property
    def ok(self) -> bool:
        """Return ``True`` if ``status`` is less than ``400``, ``False`` otherwise.

        This is **not** a check for ``200 OK`` but a check that the response
        status is under 400.
        """
        return 400 > self.status
1064
1065 def raise_for_status(self) -> None:
1066 if not self.ok:
1067 # reason should always be not None for a started response
1068 assert self.reason is not None
1069 self.release()
1070 raise ClientResponseError(
1071 self.request_info,
1072 self.history,
1073 status=self.status,
1074 message=self.reason,
1075 headers=self.headers,
1076 )
1077
1078 def _release_connection(self) -> None:
1079 if self._connection is not None:
1080 if self._writer is None:
1081 self._connection.release()
1082 self._connection = None
1083 else:
1084 self._writer.add_done_callback(lambda f: self._release_connection())
1085
    async def _wait_released(self) -> None:
        """Wait for the writer task (if any), then release the connection."""
        if self._writer is not None:
            await self._writer
        self._release_connection()
1090
    def _cleanup_writer(self) -> None:
        """Cancel the writer task if one is set and drop the session reference."""
        if self._writer is not None:
            self._writer.cancel()
        self._session = None
1095
    def _notify_content(self) -> None:
        """Flag the body stream as closed and mark the response released.

        The stream gets a ``ClientConnectionError`` only if it does not
        already carry an exception.
        """
        content = self.content
        if content and content.exception() is None:
            set_exception(content, ClientConnectionError("Connection closed"))
        self._released = True
1101
    async def wait_for_close(self) -> None:
        """Wait for the writer task (if any), then release the response."""
        if self._writer is not None:
            await self._writer
        self.release()
1106
    async def read(self) -> bytes:
        """Read response payload.

        The body is read once and cached in ``self._body``; later calls
        return the cached bytes.

        Raises:
            ClientConnectionError: if the body is already cached and the
                response has been released.
        """
        if self._body is None:
            try:
                self._body = await self.content.read()
                for trace in self._traces:
                    await trace.send_response_chunk_received(
                        self.method, self.url, self._body
                    )
            except BaseException:
                # Any failure (including cancellation) closes the response
                # before the exception propagates.
                self.close()
                raise
        elif self._released:  # Response explicitly released
            raise ClientConnectionError("Connection closed")

        # The connection is kept when the protocol has been upgraded.
        protocol = self._connection and self._connection.protocol
        if protocol is None or not protocol.upgraded:
            await self._wait_released()  # Underlying connection released
        return self._body  # type: ignore[no-any-return]
1126
1127 def get_encoding(self) -> str:
1128 ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
1129 mimetype = helpers.parse_mimetype(ctype)
1130
1131 encoding = mimetype.parameters.get("charset")
1132 if encoding:
1133 with contextlib.suppress(LookupError):
1134 return codecs.lookup(encoding).name
1135
1136 if mimetype.type == "application" and (
1137 mimetype.subtype == "json" or mimetype.subtype == "rdap"
1138 ):
1139 # RFC 7159 states that the default encoding is UTF-8.
1140 # RFC 7483 defines application/rdap+json
1141 return "utf-8"
1142
1143 if self._body is None:
1144 raise RuntimeError(
1145 "Cannot compute fallback encoding of a not yet read body"
1146 )
1147
1148 return self._resolve_charset(self, self._body)
1149
1150 async def text(self, encoding: Optional[str] = None, errors: str = "strict") -> str:
1151 """Read response payload and decode."""
1152 if self._body is None:
1153 await self.read()
1154
1155 if encoding is None:
1156 encoding = self.get_encoding()
1157
1158 return self._body.decode( # type: ignore[no-any-return,union-attr]
1159 encoding, errors=errors
1160 )
1161
1162 async def json(
1163 self,
1164 *,
1165 encoding: Optional[str] = None,
1166 loads: JSONDecoder = DEFAULT_JSON_DECODER,
1167 content_type: Optional[str] = "application/json",
1168 ) -> Any:
1169 """Read and decodes JSON response."""
1170 if self._body is None:
1171 await self.read()
1172
1173 if content_type:
1174 ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
1175 if not _is_expected_content_type(ctype, content_type):
1176 raise ContentTypeError(
1177 self.request_info,
1178 self.history,
1179 message=(
1180 "Attempt to decode JSON with " "unexpected mimetype: %s" % ctype
1181 ),
1182 headers=self.headers,
1183 )
1184
1185 stripped = self._body.strip() # type: ignore[union-attr]
1186 if not stripped:
1187 return None
1188
1189 if encoding is None:
1190 encoding = self.get_encoding()
1191
1192 return loads(stripped.decode(encoding))
1193
    async def __aenter__(self) -> "ClientResponse":
        """Enter the async context; the response itself is the context value."""
        return self
1196
    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Release the response and wait for the writer task on context exit."""
        # similar to _RequestContextManager, we do not need to check
        # for exceptions, response object can close connection
        # if state is broken
        self.release()
        await self.wait_for_close()