Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/client_reqrep.py: 32%
Shortcuts on this page
r m x  toggle line displays
j k    next / previous highlighted chunk
0 (zero)  top of page
1 (one)   first highlighted chunk
1import asyncio
2import codecs
3import contextlib
4import functools
5import io
6import re
7import sys
8import traceback
9import warnings
10from collections.abc import Callable, Iterable, Sequence
11from hashlib import md5, sha1, sha256
12from http.cookies import BaseCookie, SimpleCookie
13from types import MappingProxyType, TracebackType
14from typing import TYPE_CHECKING, Any, NamedTuple, TypedDict
16from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
17from yarl import URL, Query
19from . import hdrs, multipart, payload
20from ._cookie_helpers import (
21 parse_cookie_header,
22 parse_set_cookie_headers,
23 preserve_morsel_with_coded_value,
24)
25from .abc import AbstractStreamWriter
26from .base_protocol import BaseProtocol
27from .client_exceptions import (
28 ClientConnectionError,
29 ClientOSError,
30 ClientResponseError,
31 ContentTypeError,
32 InvalidURL,
33 ServerFingerprintMismatch,
34)
35from .compression_utils import HAS_BROTLI, HAS_ZSTD
36from .formdata import FormData
37from .helpers import (
38 _SENTINEL,
39 BaseTimerContext,
40 BasicAuth,
41 HeadersMixin,
42 TimerNoop,
43 frozen_dataclass_decorator,
44 is_expected_content_type,
45 parse_mimetype,
46 reify,
47 sentinel,
48 set_exception,
49 set_result,
50)
51from .http import (
52 SERVER_SOFTWARE,
53 HttpProcessingError,
54 HttpVersion,
55 HttpVersion10,
56 HttpVersion11,
57 StreamWriter,
58)
59from .streams import StreamReader
60from .typedefs import DEFAULT_JSON_DECODER, JSONDecoder, RawHeaders
62try:
63 import ssl
64 from ssl import SSLContext
65except ImportError: # pragma: no cover
66 ssl = None # type: ignore[assignment]
67 SSLContext = object # type: ignore[misc,assignment]
70__all__ = ("ClientRequest", "ClientResponse", "RequestInfo", "Fingerprint")
73if TYPE_CHECKING:
74 from .client import ClientSession
75 from .connector import Connection
76 from .tracing import Trace
79_CONNECTION_CLOSED_EXCEPTION = ClientConnectionError("Connection closed")
80_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")
def _gen_default_accept_encoding() -> str:
    """Build the default ``Accept-Encoding`` header value.

    Always advertises gzip and deflate; brotli and zstd are appended only
    when the optional decompression backends are available.
    """
    encodings = ["gzip", "deflate"]
    encodings.extend(
        name
        for available, name in ((HAS_BROTLI, "br"), (HAS_ZSTD, "zstd"))
        if available
    )
    return ", ".join(encodings)
@frozen_dataclass_decorator
class ContentDisposition:
    """Parsed ``Content-Disposition`` header: disposition type, its
    parameters (read-only mapping), and the extracted filename, if any."""

    type: str | None
    parameters: "MappingProxyType[str, str]"
    filename: str | None
class _RequestInfo(NamedTuple):
    # Base NamedTuple carrying request metadata; RequestInfo subclasses it
    # to make ``real_url`` optional in the constructor.
    url: URL
    method: str
    headers: "CIMultiDictProxy[str]"
    real_url: URL
class RequestInfo(_RequestInfo):
    """Request metadata (url, method, headers) attached to responses/errors."""

    def __new__(
        cls,
        url: URL,
        method: str,
        headers: "CIMultiDictProxy[str]",
        real_url: URL | _SENTINEL = sentinel,
    ) -> "RequestInfo":
        """Create a new RequestInfo instance.

        For backwards compatibility, the real_url parameter is optional.
        """
        # Build the tuple directly via tuple.__new__, substituting ``url``
        # for ``real_url`` when the caller did not supply one.
        return tuple.__new__(
            cls, (url, method, headers, url if real_url is sentinel else real_url)
        )
class Fingerprint:
    """Pin a TLS peer certificate by digest and verify it on connect.

    Only sha256 (32-byte) fingerprints are accepted; md5 and sha1 are
    rejected as insecure.
    """

    HASHFUNC_BY_DIGESTLEN = {
        16: md5,
        20: sha1,
        32: sha256,
    }

    def __init__(self, fingerprint: bytes) -> None:
        hashfunc = self.HASHFUNC_BY_DIGESTLEN.get(len(fingerprint))
        if not hashfunc:
            raise ValueError("fingerprint has invalid length")
        if hashfunc is md5 or hashfunc is sha1:
            raise ValueError("md5 and sha1 are insecure and not supported. Use sha256.")
        self._hashfunc = hashfunc
        self._fingerprint = fingerprint

    @property
    def fingerprint(self) -> bytes:
        """The expected certificate digest, as raw bytes."""
        return self._fingerprint

    def check(self, transport: asyncio.Transport) -> None:
        """Verify the transport's peer certificate against the pinned digest.

        No-op for non-TLS transports; raises ServerFingerprintMismatch
        when the digest differs.
        """
        if not transport.get_extra_info("sslcontext"):
            return
        sslobj = transport.get_extra_info("ssl_object")
        cert = sslobj.getpeercert(binary_form=True)
        digest = self._hashfunc(cert).digest()
        if digest == self._fingerprint:
            return
        host, port, *_ = transport.get_extra_info("peername")
        raise ServerFingerprintMismatch(self._fingerprint, digest, host, port)
# Types accepted for the ``ssl`` request argument at runtime.
if ssl is not None:
    SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint)
else:  # pragma: no cover
    SSL_ALLOWED_TYPES = (bool,)  # type: ignore[unreachable]


# NOTE(review): this re-binds the same sentinel already created near the top
# of the module; the duplicate definition is harmless but redundant.
_CONNECTION_CLOSED_EXCEPTION = ClientConnectionError("Connection closed")
# URL schemes that require a TLS connection.
_SSL_SCHEMES = frozenset(("https", "wss"))
# ConnectionKey is a NamedTuple because it is used as a key in a dict
# and a set in the connector. Since a NamedTuple is a tuple it uses
# the fast native tuple __hash__ and __eq__ implementation in CPython.
class ConnectionKey(NamedTuple):
    """Pooling key identifying connections that may be safely reused."""

    # the key should contain an information about used proxy / TLS
    # to prevent reusing wrong connections from a pool
    host: str
    port: int | None
    is_ssl: bool
    ssl: SSLContext | bool | Fingerprint
    proxy: URL | None
    proxy_auth: BasicAuth | None
    proxy_headers_hash: int | None  # hash(CIMultiDict)
class ClientResponse(HeadersMixin):
    """HTTP response object produced by a client request.

    Some of these attributes are None when created, but will be set by the
    start() method.  As the end user will likely never see the None values,
    we cheat the types below.
    """

    # from the Status-Line of the response
    version: HttpVersion | None = None  # HTTP-Version
    status: int = None  # type: ignore[assignment]  # Status-Code
    reason: str | None = None  # Reason-Phrase

    content: StreamReader = None  # type: ignore[assignment]  # Payload stream
    _body: bytes | None = None
    _headers: CIMultiDictProxy[str] = None  # type: ignore[assignment]
    _history: tuple["ClientResponse", ...] = ()
    _raw_headers: RawHeaders = None  # type: ignore[assignment]

    _connection: "Connection | None" = None  # current connection
    _cookies: SimpleCookie | None = None
    _raw_cookie_headers: tuple[str, ...] | None = None
    _continue: asyncio.Future[bool] | None = None
    _source_traceback: traceback.StackSummary | None = None
    _session: "ClientSession | None" = None
    # set up by ClientRequest after ClientResponse object creation
    # post-init stage allows to not change ctor signature
    _closed = True  # to allow __del__ for non-initialized properly response
    _released = False
    _in_context = False

    # Fallback charset resolver; replaced by the session's resolver in
    # __init__ when a session is supplied.
    _resolve_charset: Callable[["ClientResponse", bytes], str] = lambda *_: "utf-8"

    __writer: asyncio.Task[None] | None = None

    def __init__(
        self,
        method: str,
        url: URL,
        *,
        writer: asyncio.Task[None] | None,
        continue100: asyncio.Future[bool] | None,
        timer: BaseTimerContext | None,
        traces: Sequence["Trace"],
        loop: asyncio.AbstractEventLoop,
        session: "ClientSession | None",
        request_headers: CIMultiDict[str],
        original_url: URL,
        **kwargs: object,
    ) -> None:
        """Initialize the response shell; body/headers arrive via start()."""
        # kwargs exists so authors of subclasses should expect to pass through unknown
        # arguments. This allows us to safely add new arguments in future releases.
        # But, we should never receive unknown arguments here in the parent class, this
        # would indicate an argument has been named wrong or similar in the subclass.
        assert not kwargs, "Unexpected arguments to ClientResponse"
        # URL forbids subclasses, so a simple type check is enough.
        assert type(url) is URL

        self.method = method

        self._real_url = url
        self._url = url.with_fragment(None) if url.raw_fragment else url
        if writer is not None:
            self._writer = writer
        if continue100 is not None:
            self._continue = continue100
        self._request_headers = request_headers
        self._original_url = original_url
        self._timer = timer if timer is not None else TimerNoop()
        self._cache: dict[str, Any] = {}
        self._traces = traces
        self._loop = loop
        # Save reference to _resolve_charset, so that get_encoding() will still
        # work after the response has finished reading the body.
        if session is not None:
            # store a reference to session #1985
            self._session = session
            self._resolve_charset = session._resolve_charset
        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

    def __reset_writer(self, _: object = None) -> None:
        # Done-callback for the writer task: drop the reference once it ends.
        self.__writer = None

    @property
    def _writer(self) -> asyncio.Task[None] | None:
        """The writer task for streaming data.

        _writer is only provided for backwards compatibility
        for subclasses that may need to access it.
        """
        return self.__writer

    @_writer.setter
    def _writer(self, writer: asyncio.Task[None] | None) -> None:
        """Set the writer task for streaming data."""
        if self.__writer is not None:
            self.__writer.remove_done_callback(self.__reset_writer)
        self.__writer = writer
        if writer is None:
            return
        if writer.done():
            # The writer is already done, so we can clear it immediately.
            self.__writer = None
        else:
            writer.add_done_callback(self.__reset_writer)

    @property
    def cookies(self) -> SimpleCookie:
        """Cookies set by the server, parsed lazily from raw headers."""
        if self._cookies is None:
            if self._raw_cookie_headers is not None:
                # Parse cookies for response.cookies (SimpleCookie for backward compatibility)
                cookies = SimpleCookie()
                # Use parse_set_cookie_headers for more lenient parsing that handles
                # malformed cookies better than SimpleCookie.load
                cookies.update(parse_set_cookie_headers(self._raw_cookie_headers))
                self._cookies = cookies
            else:
                self._cookies = SimpleCookie()
        return self._cookies

    @cookies.setter
    def cookies(self, cookies: SimpleCookie) -> None:
        self._cookies = cookies
        # Generate raw cookie headers from the SimpleCookie
        if cookies:
            self._raw_cookie_headers = tuple(
                morsel.OutputString() for morsel in cookies.values()
            )
        else:
            self._raw_cookie_headers = None

    @reify
    def url(self) -> URL:
        """Request URL with the fragment stripped."""
        return self._url

    @reify
    def real_url(self) -> URL:
        """Request URL as given, fragment included."""
        return self._real_url

    @reify
    def host(self) -> str:
        """Host portion of the request URL."""
        assert self._url.host is not None
        return self._url.host

    @reify
    def headers(self) -> "CIMultiDictProxy[str]":
        """Response headers (case-insensitive, read-only view)."""
        return self._headers

    @reify
    def raw_headers(self) -> RawHeaders:
        """Response headers as raw (name, value) byte pairs."""
        return self._raw_headers

    @reify
    def request_info(self) -> RequestInfo:
        # Build RequestInfo lazily from components
        headers = CIMultiDictProxy(self._request_headers)
        return tuple.__new__(
            RequestInfo, (self._url, self.method, headers, self._original_url)
        )

    @reify
    def content_disposition(self) -> ContentDisposition | None:
        """Parsed Content-Disposition header, or None if absent."""
        raw = self._headers.get(hdrs.CONTENT_DISPOSITION)
        if raw is None:
            return None
        disposition_type, params_dct = multipart.parse_content_disposition(raw)
        params = MappingProxyType(params_dct)
        filename = multipart.content_disposition_filename(params)
        return ContentDisposition(disposition_type, params, filename)

    def __del__(self, _warnings: Any = warnings) -> None:
        # Best-effort cleanup for responses never closed explicitly.
        if self._closed:
            return

        if self._connection is not None:
            self._connection.release()
            self._cleanup_writer()

            if self._loop.get_debug():
                _warnings.warn(
                    f"Unclosed response {self!r}", ResourceWarning, source=self
                )
                context = {"client_response": self, "message": "Unclosed response"}
                if self._source_traceback:
                    context["source_traceback"] = self._source_traceback
                self._loop.call_exception_handler(context)

    def __repr__(self) -> str:
        out = io.StringIO()
        ascii_encodable_url = str(self.url)
        if self.reason:
            # Keep repr ASCII-safe regardless of the server's reason phrase.
            ascii_encodable_reason = self.reason.encode(
                "ascii", "backslashreplace"
            ).decode("ascii")
        else:
            ascii_encodable_reason = "None"
        print(
            f"<ClientResponse({ascii_encodable_url}) [{self.status} {ascii_encodable_reason}]>",
            file=out,
        )
        print(self.headers, file=out)
        return out.getvalue()

    @property
    def connection(self) -> "Connection | None":
        """The underlying connection, if still attached."""
        return self._connection

    @reify
    def history(self) -> tuple["ClientResponse", ...]:
        """A sequence of responses, if redirects occurred."""
        return self._history

    @reify
    def links(self) -> "MultiDictProxy[MultiDictProxy[str | URL]]":
        """RFC 8288 ``Link`` headers parsed into a multidict keyed by rel."""
        links_str = ", ".join(self.headers.getall("link", []))

        if not links_str:
            return MultiDictProxy(MultiDict())

        links: MultiDict[MultiDictProxy[str | URL]] = MultiDict()

        for val in re.split(r",(?=\s*<)", links_str):
            match = re.match(r"\s*<(.*)>(.*)", val)
            if match is None:  # Malformed link
                continue
            url, params_str = match.groups()
            params = params_str.split(";")[1:]

            link: MultiDict[str | URL] = MultiDict()

            for param in params:
                match = re.match(r"^\s*(\S*)\s*=\s*(['\"]?)(.*?)(\2)\s*$", param, re.M)
                if match is None:  # Malformed param
                    continue
                key, _, value, _ = match.groups()

                link.add(key, value)

            # Key by "rel" when present, otherwise by the target URL itself.
            key = link.get("rel", url)

            link.add("url", self.url.join(URL(url)))

            links.add(str(key), MultiDictProxy(link))

        return MultiDictProxy(links)

    async def start(self, connection: "Connection") -> "ClientResponse":
        """Start response processing."""
        self._closed = False
        self._protocol = connection.protocol
        self._connection = connection

        with self._timer:
            while True:
                # read response
                try:
                    protocol = self._protocol
                    message, payload = await protocol.read()  # type: ignore[union-attr]
                except HttpProcessingError as exc:
                    raise ClientResponseError(
                        self.request_info,
                        self.history,
                        status=exc.code,
                        message=exc.message,
                        headers=exc.headers,
                    ) from exc

                # Loop past informational 1xx responses (except 101 Upgrade).
                if message.code < 100 or message.code > 199 or message.code == 101:
                    break

                if self._continue is not None:
                    set_result(self._continue, True)
                    self._continue = None

        # payload eof handler
        payload.on_eof(self._response_eof)

        # response status
        self.version = message.version
        self.status = message.code
        self.reason = message.reason

        # headers
        self._headers = message.headers  # type is CIMultiDictProxy
        self._raw_headers = message.raw_headers  # type is Tuple[bytes, bytes]

        # payload
        self.content = payload

        # cookies
        if cookie_hdrs := self.headers.getall(hdrs.SET_COOKIE, ()):
            # Store raw cookie headers for CookieJar
            self._raw_cookie_headers = tuple(cookie_hdrs)
        return self

    def _response_eof(self) -> None:
        # Called when the payload stream hits EOF: close and release.
        if self._closed:
            return

        # protocol could be None because connection could be detached
        protocol = self._connection and self._connection.protocol
        if protocol is not None and protocol.upgraded:
            return

        self._closed = True
        self._cleanup_writer()
        self._release_connection()

    @property
    def closed(self) -> bool:
        """True once the response has been closed or fully consumed."""
        return self._closed

    def close(self) -> None:
        """Hard-close the response, tearing down the connection."""
        if not self._released:
            self._notify_content()

        self._closed = True
        if self._loop.is_closed():
            return

        self._cleanup_writer()
        if self._connection is not None:
            self._connection.close()
            self._connection = None

    def release(self) -> None:
        """Release the connection back to the pool for reuse."""
        if not self._released:
            self._notify_content()

        self._closed = True

        self._cleanup_writer()
        self._release_connection()

    @property
    def ok(self) -> bool:
        """Returns ``True`` if ``status`` is less than ``400``, ``False`` if not.

        This is **not** a check for ``200 OK`` but a check that the response
        status is under 400.
        """
        return 400 > self.status

    def raise_for_status(self) -> None:
        """Raise ClientResponseError for 4xx/5xx statuses."""
        if not self.ok:
            # reason should always be not None for a started response
            assert self.reason is not None

            # If we're in a context we can rely on __aexit__() to release as the
            # exception propagates.
            if not self._in_context:
                self.release()

            raise ClientResponseError(
                self.request_info,
                self.history,
                status=self.status,
                message=self.reason,
                headers=self.headers,
            )

    def _release_connection(self) -> None:
        # Release immediately if the writer is done; otherwise defer the
        # release until the writer task completes.
        if self._connection is not None:
            if self.__writer is None:
                self._connection.release()
                self._connection = None
            else:
                self.__writer.add_done_callback(lambda f: self._release_connection())

    async def _wait_released(self) -> None:
        # Wait for the writer to finish, then release the connection.
        if self.__writer is not None:
            try:
                await self.__writer
            except asyncio.CancelledError:
                # Only propagate cancellation when the current task itself is
                # being cancelled (3.11+ uncancel semantics).
                if (
                    sys.version_info >= (3, 11)
                    and (task := asyncio.current_task())
                    and task.cancelling()
                ):
                    raise
        self._release_connection()

    def _cleanup_writer(self) -> None:
        # Cancel any in-flight writer task and drop the session reference.
        if self.__writer is not None:
            self.__writer.cancel()
        self._session = None

    def _notify_content(self) -> None:
        # Mark the payload stream broken so pending readers fail fast.
        content = self.content
        # content can be None here, but the types are cheated elsewhere.
        if content and content.exception() is None:  # type: ignore[truthy-bool]
            set_exception(content, _CONNECTION_CLOSED_EXCEPTION)
        self._released = True

    async def wait_for_close(self) -> None:
        """Wait for the writer to finish, then release the response."""
        if self.__writer is not None:
            try:
                await self.__writer
            except asyncio.CancelledError:
                # See _wait_released: only re-raise genuine outer cancellation.
                if (
                    sys.version_info >= (3, 11)
                    and (task := asyncio.current_task())
                    and task.cancelling()
                ):
                    raise
        self.release()

    async def read(self) -> bytes:
        """Read response payload."""
        if self._body is None:
            try:
                self._body = await self.content.read()
                for trace in self._traces:
                    await trace.send_response_chunk_received(
                        self.method, self.url, self._body
                    )
            except BaseException:
                self.close()
                raise
        elif self._released:  # Response explicitly released
            raise ClientConnectionError("Connection closed")

        protocol = self._connection and self._connection.protocol
        if protocol is None or not protocol.upgraded:
            await self._wait_released()  # Underlying connection released
        return self._body

    def get_encoding(self) -> str:
        """Determine the body's charset, falling back to the resolver.

        Raises RuntimeError if no charset is declared and the body has not
        been read yet.
        """
        ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
        mimetype = parse_mimetype(ctype)

        encoding = mimetype.parameters.get("charset")
        if encoding:
            with contextlib.suppress(LookupError, ValueError):
                return codecs.lookup(encoding).name

        if mimetype.type == "application" and (
            mimetype.subtype == "json" or mimetype.subtype == "rdap"
        ):
            # RFC 7159 states that the default encoding is UTF-8.
            # RFC 7483 defines application/rdap+json
            return "utf-8"

        if self._body is None:
            raise RuntimeError(
                "Cannot compute fallback encoding of a not yet read body"
            )

        return self._resolve_charset(self, self._body)

    async def text(self, encoding: str | None = None, errors: str = "strict") -> str:
        """Read response payload and decode."""
        await self.read()

        if encoding is None:
            encoding = self.get_encoding()

        return self._body.decode(encoding, errors=errors)  # type: ignore[union-attr]

    async def json(
        self,
        *,
        encoding: str | None = None,
        loads: JSONDecoder = DEFAULT_JSON_DECODER,
        content_type: str | None = "application/json",
    ) -> Any:
        """Read and decodes JSON response."""
        await self.read()

        if content_type:
            if not is_expected_content_type(self.content_type, content_type):
                raise ContentTypeError(
                    self.request_info,
                    self.history,
                    status=self.status,
                    message=(
                        "Attempt to decode JSON with "
                        "unexpected mimetype: %s" % self.content_type
                    ),
                    headers=self.headers,
                )

        if encoding is None:
            encoding = self.get_encoding()

        return loads(self._body.decode(encoding))  # type: ignore[union-attr]

    async def __aenter__(self) -> "ClientResponse":
        self._in_context = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._in_context = False
        # similar to _RequestContextManager, we do not need to check
        # for exceptions, response object can close connection
        # if state is broken
        self.release()
        await self.wait_for_close()
class ClientRequestBase:
    """An internal class for proxy requests."""

    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}

    auth = None
    proxy: URL | None = None
    response_class = ClientResponse
    server_hostname: str | None = None  # Needed in connector.py
    version = HttpVersion11
    _response = None

    # These class defaults help create_autospec() work correctly.
    # If autospec is improved in future, maybe these can be removed.
    url = URL()
    method = "GET"

    _writer_task: asyncio.Task[None] | None = None  # async task for streaming data

    _skip_auto_headers: "CIMultiDict[None] | None" = None

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(
        self,
        method: str,
        url: URL,
        *,
        headers: CIMultiDict[str],
        auth: BasicAuth | None,
        loop: asyncio.AbstractEventLoop,
        ssl: SSLContext | bool | Fingerprint,
        trust_env: bool = False,
    ):
        """Validate the method/URL and initialize request state."""
        if match := _CONTAINS_CONTROL_CHAR_RE.search(method):
            raise ValueError(
                f"Method cannot contain non-token characters {method!r} "
                f"(found at least {match.group()!r})"
            )
        # URL forbids subclasses, so a simple type check is enough.
        assert type(url) is URL, url
        self.original_url = url
        self.url = url.with_fragment(None) if url.raw_fragment else url
        self.method = method.upper()
        self.loop = loop
        self._ssl = ssl

        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        self._update_host(url)
        self._update_headers(headers)
        self._update_auth(auth, trust_env)

    def _reset_writer(self, _: object = None) -> None:
        # Done-callback for the writer task: drop the reference once it ends.
        self._writer_task = None

    def _get_content_length(self) -> int | None:
        """Extract and validate Content-Length header value.

        Returns parsed Content-Length value or None if not set.
        Raises ValueError if header exists but cannot be parsed as an integer.
        """
        if hdrs.CONTENT_LENGTH not in self.headers:
            return None

        content_length_hdr = self.headers[hdrs.CONTENT_LENGTH]
        try:
            return int(content_length_hdr)
        except ValueError:
            raise ValueError(
                f"Invalid Content-Length header: {content_length_hdr}"
            ) from None

    @property
    def _writer(self) -> asyncio.Task[None] | None:
        """The current writer task streaming the request body, if any."""
        return self._writer_task

    @_writer.setter
    def _writer(self, writer: asyncio.Task[None]) -> None:
        if self._writer_task is not None:
            self._writer_task.remove_done_callback(self._reset_writer)
        self._writer_task = writer
        writer.add_done_callback(self._reset_writer)

    def is_ssl(self) -> bool:
        """True when the request URL uses a TLS scheme (https/wss)."""
        return self.url.scheme in _SSL_SCHEMES

    @property
    def ssl(self) -> "SSLContext | bool | Fingerprint":
        """The ``ssl`` argument supplied for this request."""
        return self._ssl

    @property
    def connection_key(self) -> ConnectionKey:
        """Pooling key for this request (no proxy information at this level)."""
        url = self.url
        # tuple.__new__ skips NamedTuple argument processing for speed.
        return tuple.__new__(
            ConnectionKey,
            (
                url.raw_host or "",
                url.port,
                url.scheme in _SSL_SCHEMES,
                self._ssl,
                None,
                None,
                None,
            ),
        )

    def _update_auth(self, auth: BasicAuth | None, trust_env: bool = False) -> None:
        """Set basic auth."""
        if auth is None:
            auth = self.auth
        if auth is None:
            return

        if not isinstance(auth, BasicAuth):
            raise TypeError("BasicAuth() tuple is required instead")

        self.headers[hdrs.AUTHORIZATION] = auth.encode()

    def _update_host(self, url: URL) -> None:
        """Update destination host, port and connection type (ssl)."""
        # get host/port
        if not url.raw_host:
            raise InvalidURL(url)

        # basic auth info
        if url.raw_user or url.raw_password:
            self.auth = BasicAuth(url.user or "", url.password or "")

    def _update_headers(self, headers: CIMultiDict[str]) -> None:
        """Update request headers."""
        self.headers: CIMultiDict[str] = CIMultiDict()

        # Build the host header
        host = self.url.host_port_subcomponent

        # host_port_subcomponent is None when the URL is a relative URL.
        # but we know we do not have a relative URL here.
        assert host is not None
        self.headers[hdrs.HOST] = headers.pop(hdrs.HOST, host)
        self.headers.extend(headers)

    def _create_response(self, task: asyncio.Task[None] | None) -> ClientResponse:
        # Build the response object bound to this request's metadata.
        return self.response_class(
            self.method,
            self.original_url,
            writer=task,
            continue100=None,
            timer=TimerNoop(),
            traces=(),
            loop=self.loop,
            session=None,
            request_headers=self.headers,
            original_url=self.original_url,
        )

    def _create_writer(self, protocol: BaseProtocol) -> StreamWriter:
        # Hook point: subclasses may customize the stream writer.
        return StreamWriter(protocol, self.loop)

    def _should_write(self, protocol: BaseProtocol) -> bool:
        # Base class has no body; a write task is only needed when the
        # protocol has writing paused.
        return protocol.writing_paused

    async def _send(self, conn: "Connection") -> ClientResponse:
        """Write the request line + headers and kick off body streaming."""
        # Specify request target:
        # - CONNECT request must send authority form URI
        # - not CONNECT proxy must send absolute form URI
        # - most common is origin form URI
        if self.method == hdrs.METH_CONNECT:
            connect_host = self.url.host_subcomponent
            assert connect_host is not None
            path = f"{connect_host}:{self.url.port}"
        elif self.proxy and not self.is_ssl():
            path = str(self.url)
        else:
            path = self.url.raw_path_qs

        protocol = conn.protocol
        assert protocol is not None
        writer = self._create_writer(protocol)

        # set default content-type
        if (
            self.method in self.POST_METHODS
            and (
                self._skip_auto_headers is None
                or hdrs.CONTENT_TYPE not in self._skip_auto_headers
            )
            and hdrs.CONTENT_TYPE not in self.headers
        ):
            self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"

        v = self.version
        if hdrs.CONNECTION not in self.headers:
            if conn._connector.force_close:
                if v == HttpVersion11:
                    self.headers[hdrs.CONNECTION] = "close"
            elif v == HttpVersion10:
                self.headers[hdrs.CONNECTION] = "keep-alive"

        # status + headers
        status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}"

        # Buffer headers for potential coalescing with body
        await writer.write_headers(status_line, self.headers)

        task: asyncio.Task[None] | None
        if self._should_write(protocol):
            coro = self._write_bytes(writer, conn, self._get_content_length())
            if sys.version_info >= (3, 12):
                # Optimization for Python 3.12, try to write
                # bytes immediately to avoid having to schedule
                # the task on the event loop.
                task = asyncio.Task(coro, loop=self.loop, eager_start=True)
            else:
                task = self.loop.create_task(coro)
            if task.done():
                task = None
            else:
                self._writer = task
        else:
            # We have nothing to write because
            # - there is no body
            # - the protocol does not have writing paused
            # - we are not waiting for a 100-continue response
            protocol.start_timeout()
            writer.set_eof()
            task = None
        self._response = self._create_response(task)
        return self._response

    async def _write_bytes(
        self,
        writer: AbstractStreamWriter,
        conn: "Connection",
        content_length: int | None,
    ) -> None:
        # Base class never has a body, this will never be run.
        assert False
class ClientRequestArgs(TypedDict, total=False):
    """Keyword arguments accepted by ClientRequest (all optional)."""

    params: Query
    headers: CIMultiDict[str]
    skip_auto_headers: Iterable[str] | None
    data: Any
    cookies: BaseCookie[str]
    auth: BasicAuth | None
    version: HttpVersion
    compress: str | bool
    chunked: bool | None
    expect100: bool
    loop: asyncio.AbstractEventLoop
    response_class: type[ClientResponse]
    proxy: URL | None
    proxy_auth: BasicAuth | None
    timer: BaseTimerContext
    session: "ClientSession"
    ssl: SSLContext | bool | Fingerprint
    proxy_headers: CIMultiDict[str] | None
    traces: list["Trace"]
    trust_env: bool
    server_hostname: str | None
954class ClientRequest(ClientRequestBase):
955 _EMPTY_BODY = payload.PAYLOAD_REGISTRY.get(b"", disposition=None)
956 _body = _EMPTY_BODY
957 _continue = None # waiter future for '100 Continue' response
959 GET_METHODS = {
960 hdrs.METH_GET,
961 hdrs.METH_HEAD,
962 hdrs.METH_OPTIONS,
963 hdrs.METH_TRACE,
964 }
965 DEFAULT_HEADERS = {
966 hdrs.ACCEPT: "*/*",
967 hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(),
968 }
    def __init__(
        self,
        method: str,
        url: URL,
        *,
        params: Query,
        headers: CIMultiDict[str],
        skip_auto_headers: Iterable[str] | None,
        data: Any,
        cookies: BaseCookie[str],
        auth: BasicAuth | None,
        version: HttpVersion,
        compress: str | bool,
        chunked: bool | None,
        expect100: bool,
        loop: asyncio.AbstractEventLoop,
        response_class: type[ClientResponse],
        proxy: URL | None,
        proxy_auth: BasicAuth | None,
        timer: BaseTimerContext,
        session: "ClientSession",
        ssl: SSLContext | bool | Fingerprint,
        proxy_headers: CIMultiDict[str] | None,
        traces: list["Trace"],
        trust_env: bool,
        server_hostname: str | None,
        **kwargs: object,
    ):
        """Build a full client request (headers, cookies, body, proxy)."""
        # kwargs exists so authors of subclasses should expect to pass through unknown
        # arguments. This allows us to safely add new arguments in future releases.
        # But, we should never receive unknown arguments here in the parent class, this
        # would indicate an argument has been named wrong or similar in the subclass.
        assert not kwargs, "Unexpected arguments to ClientRequest"

        if params:
            url = url.extend_query(params)
        super().__init__(method, url, headers=headers, auth=auth, loop=loop, ssl=ssl)

        if proxy is not None:
            assert type(proxy) is URL, proxy
        self._session = session
        self.chunked = chunked
        self.response_class = response_class
        self._timer = timer
        self.server_hostname = server_hostname
        self.version = version

        # NOTE: the order of these updates matters — e.g. content encoding
        # must be settled before the body/transfer-encoding is computed.
        self._update_auto_headers(skip_auto_headers)
        self._update_cookies(cookies)
        self._update_content_encoding(data, compress)
        self._update_proxy(proxy, proxy_auth, proxy_headers)

        self._update_body_from_data(data)
        if data is not None or self.method not in self.GET_METHODS:
            self._update_transfer_encoding()
        self._update_expect_continue(expect100)
        self._traces = traces
    @property
    def body(self) -> payload.Payload:
        """The request body as a Payload object."""
        return self._body
1032 @property
1033 def skip_auto_headers(self) -> CIMultiDict[None]:
1034 return self._skip_auto_headers or CIMultiDict()
    @property
    def connection_key(self) -> ConnectionKey:
        """Pooling key for this request, including proxy information."""
        # Fold proxy headers into the key so requests with different proxy
        # headers never share a pooled connection.
        if proxy_headers := self.proxy_headers:
            h: int | None = hash(tuple(proxy_headers.items()))
        else:
            h = None
        url = self.url
        # tuple.__new__ skips NamedTuple argument processing for speed.
        return tuple.__new__(
            ConnectionKey,
            (
                url.raw_host or "",
                url.port,
                url.scheme in _SSL_SCHEMES,
                self._ssl,
                self.proxy,
                self.proxy_auth,
                h,
            ),
        )
    @property
    def session(self) -> "ClientSession":
        """Return the ClientSession instance.

        This property provides access to the ClientSession that initiated
        this request, allowing middleware to make additional requests
        using the same session.
        """
        return self._session
    def _update_auto_headers(self, skip_auto_headers: Iterable[str] | None) -> None:
        """Add default headers (Accept, Accept-Encoding, User-Agent) unless
        already present or explicitly skipped."""
        if skip_auto_headers is not None:
            self._skip_auto_headers = CIMultiDict(
                (hdr, None) for hdr in sorted(skip_auto_headers)
            )
            used_headers = self.headers.copy()
            used_headers.extend(self._skip_auto_headers)  # type: ignore[arg-type]
        else:
            # Fast path when there are no headers to skip
            # which is the most common case.
            used_headers = self.headers

        for hdr, val in self.DEFAULT_HEADERS.items():
            if hdr not in used_headers:
                self.headers[hdr] = val

        if hdrs.USER_AGENT not in used_headers:
            self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE
    def _update_cookies(self, cookies: BaseCookie[str]) -> None:
        """Update request cookies header."""
        if not cookies:
            return

        c = SimpleCookie()
        if hdrs.COOKIE in self.headers:
            # parse_cookie_header for RFC 6265 compliant Cookie header parsing
            c.update(parse_cookie_header(self.headers.get(hdrs.COOKIE, "")))
            del self.headers[hdrs.COOKIE]

        for name, value in cookies.items():
            # Use helper to preserve coded_value exactly as sent by server
            c[name] = preserve_morsel_with_coded_value(value)

        self.headers[hdrs.COOKIE] = c.output(header="", sep=";").strip()
1102 def _update_content_encoding(self, data: Any, compress: bool | str) -> None:
1103 """Set request content encoding."""
1104 self.compress = None
1105 if not data:
1106 return
1108 if self.headers.get(hdrs.CONTENT_ENCODING):
1109 if compress:
1110 raise ValueError(
1111 "compress can not be set if Content-Encoding header is set"
1112 )
1113 elif compress:
1114 self.compress = compress if isinstance(compress, str) else "deflate"
1115 self.headers[hdrs.CONTENT_ENCODING] = self.compress
1116 self.chunked = True # enable chunked, no need to deal with length
1118 def _update_transfer_encoding(self) -> None:
1119 """Analyze transfer-encoding header."""
1120 te = self.headers.get(hdrs.TRANSFER_ENCODING, "").lower()
1122 if "chunked" in te:
1123 if self.chunked:
1124 raise ValueError(
1125 "chunked can not be set "
1126 'if "Transfer-Encoding: chunked" header is set'
1127 )
1129 elif self.chunked:
1130 if hdrs.CONTENT_LENGTH in self.headers:
1131 raise ValueError(
1132 "chunked can not be set if Content-Length header is set"
1133 )
1135 self.headers[hdrs.TRANSFER_ENCODING] = "chunked"
1137 def _update_body_from_data(self, body: Any) -> None:
1138 """Update request body from data."""
1139 if body is None:
1140 self._body = self._EMPTY_BODY
1141 # Set Content-Length to 0 when body is None for methods that expect a body
1142 if (
1143 self.method not in self.GET_METHODS
1144 and not self.chunked
1145 and hdrs.CONTENT_LENGTH not in self.headers
1146 ):
1147 self.headers[hdrs.CONTENT_LENGTH] = "0"
1148 return
1150 # FormData
1151 if isinstance(body, FormData):
1152 body = body()
1153 else:
1154 try:
1155 body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
1156 except payload.LookupError:
1157 boundary = None
1158 if hdrs.CONTENT_TYPE in self.headers:
1159 boundary = parse_mimetype(
1160 self.headers[hdrs.CONTENT_TYPE]
1161 ).parameters.get("boundary")
1162 body = FormData(body, boundary=boundary)()
1164 self._body = body
1166 # enable chunked encoding if needed
1167 if not self.chunked and hdrs.CONTENT_LENGTH not in self.headers:
1168 if (size := body.size) is not None:
1169 self.headers[hdrs.CONTENT_LENGTH] = str(size)
1170 else:
1171 self.chunked = True
1173 # copy payload headers
1174 assert body.headers
1175 headers = self.headers
1176 skip_headers = self._skip_auto_headers
1177 for key, value in body.headers.items():
1178 if key in headers or (skip_headers is not None and key in skip_headers):
1179 continue
1180 headers[key] = value
1182 def _update_body(self, body: Any) -> None:
1183 """Update request body after its already been set."""
1184 # Remove existing Content-Length header since body is changing
1185 if hdrs.CONTENT_LENGTH in self.headers:
1186 del self.headers[hdrs.CONTENT_LENGTH]
1188 # Remove existing Transfer-Encoding header to avoid conflicts
1189 if self.chunked and hdrs.TRANSFER_ENCODING in self.headers:
1190 del self.headers[hdrs.TRANSFER_ENCODING]
1192 # Now update the body using the existing method
1193 self._update_body_from_data(body)
1195 # Update transfer encoding headers if needed (same logic as __init__)
1196 if body is not None or self.method not in self.GET_METHODS:
1197 self._update_transfer_encoding()
1199 async def update_body(self, body: Any) -> None:
1200 """
1201 Update request body and close previous payload if needed.
1203 This method safely updates the request body by first closing any existing
1204 payload to prevent resource leaks, then setting the new body.
1206 IMPORTANT: Always use this method instead of setting request.body directly.
1207 Direct assignment to request.body will leak resources if the previous body
1208 contains file handles, streams, or other resources that need cleanup.
1210 Args:
1211 body: The new body content. Can be:
1212 - bytes/bytearray: Raw binary data
1213 - str: Text data (will be encoded using charset from Content-Type)
1214 - FormData: Form data that will be encoded as multipart/form-data
1215 - Payload: A pre-configured payload object
1216 - AsyncIterable: An async iterable of bytes chunks
1217 - File-like object: Will be read and sent as binary data
1218 - None: Clears the body
1220 Usage:
1221 # CORRECT: Use update_body
1222 await request.update_body(b"new request data")
1224 # WRONG: Don't set body directly
1225 # request.body = b"new request data" # This will leak resources!
1227 # Update with form data
1228 form_data = FormData()
1229 form_data.add_field('field', 'value')
1230 await request.update_body(form_data)
1232 # Clear body
1233 await request.update_body(None)
1235 Note:
1236 This method is async because it may need to close file handles or
1237 other resources associated with the previous payload. Always await
1238 this method to ensure proper cleanup.
1240 Warning:
1241 Setting request.body directly is highly discouraged and can lead to:
1242 - Resource leaks (unclosed file handles, streams)
1243 - Memory leaks (unreleased buffers)
1244 - Unexpected behavior with streaming payloads
1246 It is not recommended to change the payload type in middleware. If the
1247 body was already set (e.g., as bytes), it's best to keep the same type
1248 rather than converting it (e.g., to str) as this may result in unexpected
1249 behavior.
1251 See Also:
1252 - update_body_from_data: Synchronous body update without cleanup
1253 - body property: Direct body access (STRONGLY DISCOURAGED)
1255 """
1256 # Close existing payload if it exists and needs closing
1257 if self._body is not None:
1258 await self._body.close()
1259 self._update_body(body)
1261 def _update_expect_continue(self, expect: bool = False) -> None:
1262 if expect:
1263 self.headers[hdrs.EXPECT] = "100-continue"
1264 elif (
1265 hdrs.EXPECT in self.headers
1266 and self.headers[hdrs.EXPECT].lower() == "100-continue"
1267 ):
1268 expect = True
1270 if expect:
1271 self._continue = self.loop.create_future()
1273 def _update_proxy(
1274 self,
1275 proxy: URL | None,
1276 proxy_auth: BasicAuth | None,
1277 proxy_headers: CIMultiDict[str] | None,
1278 ) -> None:
1279 self.proxy = proxy
1280 if proxy is None:
1281 self.proxy_auth = None
1282 self.proxy_headers = None
1283 return
1285 if proxy_auth and not isinstance(proxy_auth, BasicAuth):
1286 raise ValueError("proxy_auth must be None or BasicAuth() tuple")
1287 self.proxy_auth = proxy_auth
1288 self.proxy_headers = proxy_headers
1290 def _create_response(self, task: asyncio.Task[None] | None) -> ClientResponse:
1291 return self.response_class(
1292 self.method,
1293 self.original_url,
1294 writer=task,
1295 continue100=self._continue,
1296 timer=self._timer,
1297 traces=self._traces,
1298 loop=self.loop,
1299 session=self._session,
1300 request_headers=self.headers,
1301 original_url=self.original_url,
1302 )
1304 def _create_writer(self, protocol: BaseProtocol) -> StreamWriter:
1305 writer = StreamWriter(
1306 protocol,
1307 self.loop,
1308 on_chunk_sent=(
1309 functools.partial(self._on_chunk_request_sent, self.method, self.url)
1310 if self._traces
1311 else None
1312 ),
1313 on_headers_sent=(
1314 functools.partial(self._on_headers_request_sent, self.method, self.url)
1315 if self._traces
1316 else None
1317 ),
1318 )
1320 if self.compress:
1321 writer.enable_compression(self.compress)
1323 if self.chunked is not None:
1324 writer.enable_chunking()
1325 return writer
1327 def _should_write(self, protocol: BaseProtocol) -> bool:
1328 return (
1329 self.body.size != 0 or self._continue is not None or protocol.writing_paused
1330 )
    async def _write_bytes(
        self,
        writer: AbstractStreamWriter,
        conn: "Connection",
        content_length: int | None,
    ) -> None:
        """
        Write the request body to the connection stream.

        This method handles writing different types of request bodies:
        1. Payload objects (using their specialized write_with_length method)
        2. Bytes/bytearray objects
        3. Iterable body content

        Args:
            writer: The stream writer to write the body to
            conn: The connection being used for this request
            content_length: Optional maximum number of bytes to write from the body
                (None means write the entire body)

        The method properly handles:
        - Waiting for 100-Continue responses if required
        - Content length constraints for chunked encoding
        - Error handling for network issues, cancellation, and other exceptions
        - Signaling EOF and timeout management

        Raises:
            ClientOSError: When there's an OS-level error writing the body
            ClientConnectionError: When there's a general connection error
            asyncio.CancelledError: When the operation is cancelled

        """
        # 100 response: wait for the server's go-ahead before sending the body.
        if self._continue is not None:
            # Force headers to be sent before waiting for 100-continue
            writer.send_headers()
            await writer.drain()
            await self._continue

        protocol = conn.protocol
        assert protocol is not None
        try:
            await self._body.write_with_length(writer, content_length)
        except OSError as underlying_exc:
            # Write failures are reported via set_exception on the protocol
            # (so the reader side observes them) rather than raised here.
            reraised_exc = underlying_exc

            # Distinguish between timeout and other OS errors for better error reporting
            exc_is_not_timeout = underlying_exc.errno is not None or not isinstance(
                underlying_exc, asyncio.TimeoutError
            )
            if exc_is_not_timeout:
                reraised_exc = ClientOSError(
                    underlying_exc.errno,
                    f"Can not write request body for {self.url !s}",
                )

            set_exception(protocol, reraised_exc, underlying_exc)
        except asyncio.CancelledError:
            # Body hasn't been fully sent, so connection can't be reused
            conn.close()
            raise
        except Exception as underlying_exc:
            # Any other failure is wrapped so callers see a connection error
            # with the original exception chained as its cause.
            set_exception(
                protocol,
                ClientConnectionError(
                    "Failed to send bytes into the underlying connection "
                    f"{conn !s}: {underlying_exc!r}",
                ),
                underlying_exc,
            )
        else:
            # Successfully wrote the body, signal EOF and start response timeout
            await writer.write_eof()
            protocol.start_timeout()
1407 async def _close(self) -> None:
1408 if self._writer_task is not None:
1409 try:
1410 await self._writer_task
1411 except asyncio.CancelledError:
1412 if (
1413 sys.version_info >= (3, 11)
1414 and (task := asyncio.current_task())
1415 and task.cancelling()
1416 ):
1417 raise
1419 def _terminate(self) -> None:
1420 if self._writer_task is not None:
1421 if not self.loop.is_closed():
1422 self._writer_task.cancel()
1423 self._writer_task.remove_done_callback(self._reset_writer)
1424 self._writer_task = None
1426 async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
1427 for trace in self._traces:
1428 await trace.send_request_chunk_sent(method, url, chunk)
1430 async def _on_headers_request_sent(
1431 self, method: str, url: URL, headers: "CIMultiDict[str]"
1432 ) -> None:
1433 for trace in self._traces:
1434 await trace.send_request_headers(method, url, headers)