Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/test_utils.py: 53%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Utilities shared by tests."""
3import asyncio
4import contextlib
5import gc
6import ipaddress
7import os
8import socket
9import sys
10from abc import ABC, abstractmethod
11from collections.abc import Callable, Iterator
12from types import TracebackType
13from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast, overload
14from unittest import IsolatedAsyncioTestCase, mock
16from aiosignal import Signal
17from multidict import CIMultiDict, CIMultiDictProxy
18from yarl import URL
20import aiohttp
21from aiohttp.client import (
22 _BaseRequestContextManager,
23 _RequestContextManager,
24 _RequestOptions,
25 _WSRequestContextManager,
26)
28from . import ClientSession, hdrs
29from .abc import AbstractCookieJar, AbstractStreamWriter
30from .client_reqrep import ClientResponse
31from .client_ws import ClientWebSocketResponse
32from .http import HttpVersion, RawRequestMessage
33from .streams import EMPTY_PAYLOAD, StreamReader
34from .typedefs import LooseHeaders, StrOrURL
35from .web import (
36 Application,
37 AppRunner,
38 BaseRequest,
39 BaseRunner,
40 Request,
41 RequestHandler,
42 Server,
43 ServerRunner,
44 SockSite,
45 UrlMappingMatchInfo,
46)
47from .web_protocol import _RequestHandler
49if TYPE_CHECKING:
50 from ssl import SSLContext
51else:
52 SSLContext = Any
54if sys.version_info >= (3, 11) and TYPE_CHECKING:
55 from typing import Unpack
57if sys.version_info >= (3, 11):
58 from typing import Self
59else:
60 Self = Any
# Value-restricted TypeVar: a TestClient's ``app`` is either an Application or None.
_ApplicationNone = TypeVar("_ApplicationNone", Application, None)
# TypeVar for the request type a test server hands to its handlers.
_Request = TypeVar("_Request", bound=BaseRequest)

# Whether get_port_socket() enables SO_REUSEADDR: POSIX only, and never on Cygwin.
REUSE_ADDRESS = sys.platform != "cygwin" and os.name == "posix"
def get_unused_port_socket(
    host: str, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Return a socket of *family* bound to *host* on an OS-chosen port.

    This is get_port_socket() with the port fixed to 0.
    """
    any_port = 0  # binding to port 0 lets the OS pick the port
    return get_port_socket(host, any_port, family)
def get_port_socket(
    host: str, port: int, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Create a TCP socket of *family* and bind it to ``(host, port)``.

    ``SO_REUSEADDR`` is set only when the module-level ``REUSE_ADDRESS`` flag
    is true.  The returned socket is bound but not listening.  If configuring
    or binding the socket fails, it is closed before the error propagates, so
    no file descriptor is leaked.
    """
    s = socket.socket(family, socket.SOCK_STREAM)
    try:
        if REUSE_ADDRESS:
            # Windows has different semantics for SO_REUSEADDR,
            # so don't set it. Ref:
            # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
    except BaseException:
        # The caller never receives the socket on failure, so close it here.
        s.close()
        raise
    return s
def unused_port() -> int:
    """Return a port that is unused on the current host.

    A throwaway IPv4 socket is bound to port 0 on 127.0.0.1 and the port the
    OS assigned is read back.  The socket is closed before returning, so the
    port is only known to have been free at that moment.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("127.0.0.1", 0))
        _address, assigned_port = probe.getsockname()
    return cast(int, assigned_port)
class BaseTestServer(ABC, Generic[_Request]):
    """Base class for in-process HTTP servers used in tests.

    Subclasses implement :meth:`_make_runner`.  This class binds the listening
    socket, starts the site and computes the root URL that :meth:`make_url`
    resolves paths against.  It can be used as an asynchronous context
    manager, which starts the server on entry and closes it on exit.
    """

    # Stop pytest from collecting this class just because of its name.
    __test__ = False

    def __init__(
        self,
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: int | None = None,
        skip_url_asserts: bool = False,
        socket_factory: Callable[
            [str, int, socket.AddressFamily], socket.socket
        ] = get_port_socket,
        **kwargs: Any,
    ) -> None:
        """Store the server configuration; nothing is bound until start_server().

        Extra ``kwargs`` are accepted so subclasses can forward their own
        keyword arguments, but they are not used here.
        """
        self.runner: BaseRunner[_Request] | None = None
        self._root: URL | None = None
        self.host = host
        # 0 means "let the OS choose" when the socket is bound.
        self.port = port or 0
        self._closed = False
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts
        self.socket_factory = socket_factory
        # Replaced by the ``ssl`` keyword argument of start_server().
        self._ssl: SSLContext | None = None

    async def start_server(self, **kwargs: Any) -> None:
        """Create the runner, bind the listening socket and start serving.

        Does nothing if a runner is already set.  An ``ssl`` keyword argument
        is consumed here and used as the site's SSL context.  All other keyword
        arguments are forwarded to :meth:`_make_runner`.
        """
        if self.runner:
            return
        self._ssl = kwargs.pop("ssl", None)
        self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
        await self.runner.setup()
        absolute_host = self.host
        try:
            version = ipaddress.ip_address(self.host).version
        except ValueError:
            # Not an IP literal (e.g. a hostname): treat it as IPv4.
            version = 4
        if version == 6:
            # IPv6 literals must be bracketed inside a URL.
            absolute_host = f"[{self.host}]"
        family = socket.AF_INET6 if version == 6 else socket.AF_INET
        _sock = self.socket_factory(self.host, self.port, family)
        self.host, self.port = _sock.getsockname()[:2]
        site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
        await site.start()
        server = site._server
        assert server is not None
        sockets = server.sockets
        assert sockets is not None
        # Record the port the started server actually listens on.
        self.port = sockets[0].getsockname()[1]
        if not self.scheme:
            self.scheme = "https" if self._ssl else "http"
        self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")

    @abstractmethod
    async def _make_runner(self, **kwargs: Any) -> BaseRunner[_Request]:
        """Return a new runner for the server."""
        # TODO(PY311): Use Unpack to specify Server kwargs.

    def make_url(self, path: StrOrURL) -> URL:
        """Return the absolute URL of *path* on this (started) server.

        Unless ``skip_url_asserts`` is set, *path* must be relative and is
        joined onto the root URL.  With ``skip_url_asserts``, the root URL and
        *path* are concatenated as strings.
        """
        assert self._root is not None
        url = URL(path)
        if not self.skip_url_asserts:
            assert not url.absolute
            return self._root.join(url)
        else:
            return URL(str(self._root) + str(path))

    @property
    def started(self) -> bool:
        """True once a runner has been created by start_server()."""
        return self.runner is not None

    @property
    def closed(self) -> bool:
        """True once close() has shut down a started server."""
        return self._closed

    @property
    def handler(self) -> Server[_Request]:
        """Return the runner's web.Server instance (kept for backward compatibility)."""
        runner = self.runner
        assert runner is not None
        assert runner.server is not None
        return runner.server

    async def close(self) -> None:
        """Shut down the server.

        If the server was started, its runner is cleaned up and the root URL
        and port are reset.  After that point, the server is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as an asynchronous context manager.
        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = 0
            self._closed = True

    async def __aenter__(self) -> Self:
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        await self.close()
class TestServer(BaseTestServer[Request]):
    """Test server that serves an aiohttp Application through an AppRunner."""

    def __init__(
        self,
        app: Application,
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: int | None = None,
        **kwargs: Any,
    ):
        # Kept on the instance because _make_runner() builds the runner from it.
        self.app = app
        super().__init__(**kwargs, scheme=scheme, host=host, port=port)

    async def _make_runner(self, **kwargs: Any) -> AppRunner:
        """Wrap ``self.app`` in an AppRunner configured with *kwargs*."""
        # TODO(PY311): Use Unpack to specify Server kwargs.
        app_runner = AppRunner(self.app, **kwargs)
        return app_runner
class RawTestServer(BaseTestServer[BaseRequest]):
    """Test server that routes every request to a single low-level handler."""

    def __init__(
        self,
        handler: _RequestHandler[BaseRequest],
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: int | None = None,
        **kwargs: Any,
    ) -> None:
        self._handler = handler
        super().__init__(**kwargs, scheme=scheme, host=host, port=port)

    async def _make_runner(self, **kwargs: Any) -> ServerRunner:
        """Build a web.Server around the handler and wrap it in a ServerRunner."""
        # TODO(PY311): Use Unpack to specify Server kwargs.
        # The same keyword arguments are passed to both the Server and its runner.
        return ServerRunner(Server(self._handler, **kwargs), **kwargs)
class TestClient(Generic[_Request, _ApplicationNone]):
    """
    A test client implementation.

    To write functional tests for aiohttp based servers.

    Request helpers take a path relative to the wrapped test server and
    resolve it with :meth:`make_url`.  Every response and websocket opened
    through the client is remembered so that :meth:`close` can release it.
    """

    # Stop pytest from collecting this class just because of its name.
    __test__ = False

    @overload
    def __init__(
        self: "TestClient[Request, Application]",
        server: TestServer,
        *,
        cookie_jar: AbstractCookieJar | None = None,
        **kwargs: Any,
    ) -> None: ...
    @overload
    def __init__(
        self: "TestClient[_Request, None]",
        server: BaseTestServer[_Request],
        *,
        cookie_jar: AbstractCookieJar | None = None,
        **kwargs: Any,
    ) -> None: ...
    def __init__(  # type: ignore[misc]
        self,
        server: BaseTestServer[_Request],
        *,
        cookie_jar: AbstractCookieJar | None = None,
        **kwargs: Any,
    ) -> None:
        """Wrap *server* and create the ClientSession used for requests.

        Extra ``kwargs`` are passed to ClientSession.

        Raises:
            TypeError: if *server* is not a BaseTestServer instance.
        """
        # TODO(PY311): Use Unpack to specify ClientSession kwargs.
        if not isinstance(server, BaseTestServer):
            raise TypeError(
                "server must be TestServer instance, found type: %r" % type(server)
            )
        self._server = server
        if cookie_jar is None:
            # unsafe=True: presumably so cookies set by IP-address hosts (the
            # default 127.0.0.1) are kept -- see aiohttp CookieJar docs.
            cookie_jar = aiohttp.CookieJar(unsafe=True)
        self._session = ClientSession(cookie_jar=cookie_jar, **kwargs)
        # Private ClientSession flag; disables its connection retry.
        self._session._retry_connection = False
        self._closed = False
        self._responses: list[ClientResponse] = []
        self._websockets: list[ClientWebSocketResponse[bool]] = []

    async def start_server(self) -> None:
        """Start the wrapped test server."""
        await self._server.start_server()

    @property
    def scheme(self) -> str:
        """URL scheme of the wrapped server."""
        return self._server.scheme

    @property
    def host(self) -> str:
        """Host of the wrapped server."""
        return self._server.host

    @property
    def port(self) -> int:
        """Port of the wrapped server."""
        return self._server.port

    @property
    def server(self) -> BaseTestServer[_Request]:
        """The wrapped test server."""
        return self._server

    @property
    def app(self) -> _ApplicationNone:
        """The server's ``app`` attribute, or None if the server has none."""
        return getattr(self._server, "app", None)  # type: ignore[return-value]

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.

        """
        return self._session

    def make_url(self, path: StrOrURL) -> URL:
        """Resolve *path* against the wrapped server's root URL."""
        return self._server.make_url(path)

    async def _request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> ClientResponse:
        resp = await self._session.request(method, self.make_url(path), **kwargs)
        # save it to close later
        self._responses.append(resp)
        return resp

    if sys.version_info >= (3, 11) and TYPE_CHECKING:

        def request(
            self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions]
        ) -> _RequestContextManager: ...

        def get(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def options(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def head(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def post(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def put(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def patch(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def delete(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

    else:

        def request(
            self, method: str, path: StrOrURL, **kwargs: Any
        ) -> _RequestContextManager:
            """Routes a request to tested http server.

            The interface is identical to aiohttp.ClientSession.request,
            except that *path* is resolved against the test server with
            make_url().

            """
            return _RequestContextManager(self._request(method, path, **kwargs))

        def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP GET request."""
            return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

        def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP POST request."""
            return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

        def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP OPTIONS request."""
            return _RequestContextManager(
                self._request(hdrs.METH_OPTIONS, path, **kwargs)
            )

        def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP HEAD request."""
            return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

        def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP PUT request."""
            return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

        def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP PATCH request."""
            return _RequestContextManager(
                self._request(hdrs.METH_PATCH, path, **kwargs)
            )

        def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP DELETE request."""
            return _RequestContextManager(
                self._request(hdrs.METH_DELETE, path, **kwargs)
            )

    @overload
    def ws_connect(
        self, path: StrOrURL, *, decode_text: Literal[True] = ..., **kwargs: Any
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ...

    @overload
    def ws_connect(
        self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ...

    @overload
    def ws_connect(
        self, path: StrOrURL, *, decode_text: bool = ..., **kwargs: Any
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ...

    def ws_connect(
        self, path: StrOrURL, *, decode_text: bool = True, **kwargs: Any
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]":
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect, except that
        *path* is resolved against the test server with make_url().  The
        websocket is closed by close().

        """
        return _WSRequestContextManager(
            self._ws_connect(path, decode_text=decode_text, **kwargs)
        )

    @overload
    async def _ws_connect(
        self, path: StrOrURL, *, decode_text: Literal[True] = ..., **kwargs: Any
    ) -> "ClientWebSocketResponse[Literal[True]]": ...

    @overload
    async def _ws_connect(
        self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any
    ) -> "ClientWebSocketResponse[Literal[False]]": ...

    @overload
    async def _ws_connect(
        self, path: StrOrURL, *, decode_text: bool = ..., **kwargs: Any
    ) -> "ClientWebSocketResponse[bool]": ...

    async def _ws_connect(
        self, path: StrOrURL, *, decode_text: bool = True, **kwargs: Any
    ) -> "ClientWebSocketResponse[bool]":
        ws = await self._session.ws_connect(
            self.make_url(path), decode_text=decode_text, **kwargs
        )
        # save it to close later
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        Closes every tracked response and websocket, then the ClientSession,
        and finally the wrapped server.  After that point, the TestClient is
        no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as an asynchronous context manager.

        """
        if not self._closed:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
            await self._session.close()
            await self._server.close()
            self._closed = True

    async def __aenter__(self) -> Self:
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        await self.close()
class AioHTTPTestCase(IsolatedAsyncioTestCase, ABC):
    """A base class to allow for unittest web applications using aiohttp.

    Provides the following:

    * self.client (aiohttp.test_utils.TestClient): an aiohttp test client.
    * self.server (aiohttp.test_utils.TestServer): the server returned by
        self.get_server()
    * self.app (aiohttp.web.Application): the application returned by
        self.get_application()

    Note that the TestClient's methods are asynchronous: call them with
    ``await`` (or ``async with``) from asynchronous test methods.
    """

    @abstractmethod
    async def get_application(self) -> Application:
        """Get application.

        This method should be overridden to return the aiohttp.web.Application
        object to test.
        """

    async def asyncSetUp(self) -> None:
        """Build the application, server and client, then start the server.

        unittest does not call asyncTearDown() when asyncSetUp() raises.  So if
        starting the server fails, the client (which closes its ClientSession
        and the server) is closed here before the error propagates.
        """
        self.app = await self.get_application()
        self.server = await self.get_server(self.app)
        self.client = await self.get_client(self.server)

        try:
            await self.client.start_server()
        except BaseException:
            await self.client.close()
            raise

    async def asyncTearDown(self) -> None:
        """Close the client, which also closes the server."""
        await self.client.close()

    async def get_server(self, app: Application) -> TestServer:
        """Return a TestServer instance."""
        return TestServer(app)

    async def get_client(self, server: TestServer) -> TestClient[Request, Application]:
        """Return a TestClient instance."""
        return TestClient(server)
# Type of a zero-argument callable that creates a new event loop.
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]


@contextlib.contextmanager
def loop_context(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
) -> Iterator[asyncio.AbstractEventLoop]:
    """A contextmanager that creates an event_loop, for test purposes.

    Handles the creation and cleanup of a test loop.  The loop comes from
    setup_test_loop(*loop_factory*).  It is torn down with teardown_test_loop()
    when the ``with`` block exits, including when the block raises.

    *fast* is forwarded to teardown_test_loop(), where a true value skips the
    final ``gc.collect()``.
    """
    loop = setup_test_loop(loop_factory)
    try:
        yield loop
    finally:
        # Always release the loop, even if the body of the ``with`` raised.
        teardown_test_loop(loop, fast=fast)
def setup_test_loop(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
) -> asyncio.AbstractEventLoop:
    """Create an event loop with *loop_factory*, make it current and return it.

    Pair every call with teardown_test_loop() once the loop is no longer needed.
    """
    fresh_loop = loop_factory()
    asyncio.set_event_loop(fresh_loop)
    return fresh_loop
def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
    """Stop, close and detach an event loop created by setup_test_loop.

    A loop that is already closed is left alone.  Otherwise the loop is run
    once more so callbacks already queued can run, and is then closed.  Unless
    *fast* is true, a full garbage collection follows.  The current event loop
    is always reset to None.
    """
    if not loop.is_closed():
        # Queue stop() behind the pending callbacks, then spin the loop so
        # they get to run before it is closed.
        loop.call_soon(loop.stop)
        loop.run_forever()
        loop.close()

    if not fast:
        gc.collect()

    asyncio.set_event_loop(None)
def _create_app_mock() -> mock.MagicMock:
    """Build a MagicMock that stands in for an Application.

    Item access (``app[key]`` and ``app[key] = value``) is backed by a real
    dict kept on the mock.  ``on_response_prepare`` is a frozen Signal with no
    receivers.
    """

    def _lookup(owner: Any, key: str) -> Any:
        return owner.__app_dict[key]

    def _store(owner: Any, key: str, value: Any) -> None:
        owner.__app_dict[key] = value

    app = mock.MagicMock(spec=Application)
    app.__app_dict = {}
    app.__getitem__, app.__setitem__ = _lookup, _store

    prepare_signal = Signal(app)
    prepare_signal.freeze()
    app.on_response_prepare = prepare_signal
    return app
def _create_transport(sslcontext: SSLContext | None = None) -> mock.Mock:
    """Return a mock transport whose get_extra_info() only answers "sslcontext".

    ``get_extra_info("sslcontext")`` returns *sslcontext*; every other key
    returns None.
    """
    fake_transport = mock.Mock()

    def _extra_info(key: str) -> SSLContext | None:
        return sslcontext if key == "sslcontext" else None

    fake_transport.get_extra_info.side_effect = _extra_info
    return fake_transport
def make_mocked_request(
    method: str,
    path: str,
    headers: LooseHeaders | None = None,
    *,
    match_info: dict[str, str] | None = None,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Application | None = None,
    writer: AbstractStreamWriter | None = None,
    protocol: RequestHandler[Request] | None = None,
    transport: asyncio.Transport | None = None,
    payload: StreamReader = EMPTY_PAYLOAD,
    sslcontext: SSLContext | None = None,
    client_max_size: int = 1024**2,
    loop: Any = ...,
) -> Request:
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
    specific conditions and errors are hard to trigger.

    Any of *app*, *transport*, *protocol*, *writer* and *loop* that is not
    supplied is replaced by a mock.  A supplied *protocol* still has its
    ``transport`` attribute set to the request's transport.  *match_info*
    becomes the request's URL match info, bound to *app*.  Requests whose
    *version* is below HTTP/1.1 are always marked as closing.
    """
    task = mock.Mock()
    if loop is ...:
        # No loop passed: use the running one if there is one, because a real
        # loop is needed to create executor jobs when testing with a real
        # executor.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = mock.Mock()
            loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        closing = True

    if headers:
        headers = CIMultiDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
            (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
        )
    else:
        headers = CIMultiDictProxy(CIMultiDict())
        raw_hdrs = ()

    chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()

    message = RawRequestMessage(
        method,
        path,
        version,
        headers,
        raw_hdrs,
        closing,
        None,
        False,
        chunked,
        URL(path),
    )
    if app is None:
        app = _create_app_mock()

    if transport is None:
        transport = _create_transport(sslcontext)

    if protocol is None:
        protocol = mock.Mock()
        protocol.max_field_size = 8190
        protocol.max_line_length = 8190
        protocol.max_headers = 128
        # PropertyMock has to be attached to the mock's type, not the instance.
        type(protocol).peername = mock.PropertyMock(
            return_value=transport.get_extra_info("peername")
        )
        type(protocol).ssl_context = mock.PropertyMock(return_value=sslcontext)

    if writer is None:
        writer = mock.Mock()
        writer.write_headers = mock.AsyncMock(return_value=None)
        writer.write = mock.AsyncMock(return_value=None)
        writer.write_eof = mock.AsyncMock(return_value=None)
        writer.drain = mock.AsyncMock(return_value=None)
        writer.transport = transport

    # Done for mocked and caller-supplied protocols alike.
    protocol.transport = transport

    req = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    url_match_info = UrlMappingMatchInfo(
        {} if match_info is None else match_info, mock.Mock()
    )
    url_match_info.add_app(app)
    req._match_info = url_match_info

    return req