Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/test_utils.py: 56%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Utilities shared by tests."""
3import asyncio
4import ipaddress
5import os
6import socket
7import sys
8from abc import ABC, abstractmethod
9from collections.abc import Callable
10from types import TracebackType
11from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload
12from unittest import IsolatedAsyncioTestCase, mock
14from aiosignal import Signal
15from multidict import CIMultiDict
16from yarl import URL
18import aiohttp
19from aiohttp.client import (
20 _BaseRequestContextManager,
21 _RequestContextManager,
22 _RequestOptions,
23 _WSRequestContextManager,
24)
26from . import ClientSession, hdrs
27from .abc import AbstractCookieJar, AbstractStreamWriter
28from .client_reqrep import ClientResponse
29from .client_ws import ClientWebSocketResponse
30from .helpers import HeadersDictProxy
31from .http import HttpVersion, RawRequestMessage
32from .streams import EMPTY_PAYLOAD, StreamReader
33from .typedefs import LooseHeaders, StrOrURL
34from .web import (
35 Application,
36 AppRunner,
37 BaseRequest,
38 BaseRunner,
39 Request,
40 RequestHandler,
41 Server,
42 ServerRunner,
43 SockSite,
44 UrlMappingMatchInfo,
45)
46from .web_protocol import _RequestHandler
# ``ssl`` is only imported for type checking; at runtime ``SSLContext`` is an
# alias for ``Any`` so annotations that mention it (evaluated when functions
# are defined) still work.
if TYPE_CHECKING:
    from ssl import SSLContext
else:
    SSLContext = Any

# ``Unpack`` is only needed by the type-checker-only request method
# signatures on TestClient, so it is never imported at runtime.
if sys.version_info >= (3, 11) and TYPE_CHECKING:
    from typing import Unpack

# ``typing.Self`` is available from Python 3.11; older versions fall back to
# ``Any``.
if sys.version_info >= (3, 11):
    from typing import Self
else:
    Self = Any

# Constrained to exactly ``Application`` or ``None``: TestClient uses it to
# type its ``app`` property according to the kind of server it wraps.
_ApplicationNone = TypeVar("_ApplicationNone", Application, None)
# Request type served by a test server (web.Request or BaseRequest).
_Request = TypeVar("_Request", bound=BaseRequest)

# True on POSIX platforms other than Cygwin. BaseTestServer's default socket
# factory passes it as ``reuse_port`` to ``socket.create_server``.
# NOTE(review): the name suggests SO_REUSEADDR, but ``reuse_port`` enables
# SO_REUSEPORT — confirm which option is intended.
REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin"
class BaseTestServer(ABC, Generic[_Request]):
    """Base class for servers listening on a real socket in functional tests.

    Subclasses implement :meth:`_make_runner` to build the runner that
    serves requests. This class binds the listening socket, starts the site
    and computes the root URL used by :meth:`make_url`.
    """

    # Opt out of test collection (pytest honours ``__test__``).
    __test__ = False

    def __init__(
        self,
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: int | None = None,
        skip_url_asserts: bool = False,
        socket_factory: Callable[
            [str, int, socket.AddressFamily], socket.socket
        ] = lambda h, p, f: socket.create_server(
            (h, p), family=f, reuse_port=REUSE_ADDRESS
        ),
        **kwargs: Any,
    ) -> None:
        """Configure the server; nothing is bound until :meth:`start_server`.

        :param scheme: URL scheme used by :meth:`make_url`. When empty it is
            chosen on start: "https" if an ``ssl`` context was given,
            otherwise "http".
        :param host: Address to bind.
        :param port: Port to bind. ``None`` or ``0`` lets the OS pick one; the
            bound port is read back from the socket on start.
        :param skip_url_asserts: If true, :meth:`make_url` accepts any path
            and appends it to the root URL as a plain string.
        :param socket_factory: Callable ``(host, port, family)`` returning the
            listening socket.
        :param kwargs: Currently ignored.
        """
        self.runner: BaseRunner[_Request] | None = None
        self._root: URL | None = None
        self.host = host
        self.port = port or 0
        self._closed = False
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts
        self.socket_factory = socket_factory

    async def start_server(self, **kwargs: Any) -> None:
        """Set up the runner, bind the socket and start serving.

        Returns immediately if a runner is already set. An ``ssl`` keyword,
        if given, becomes the site's SSL context. The remaining keyword
        arguments are forwarded to :meth:`_make_runner`.
        """
        if self.runner:
            return
        self._ssl = kwargs.pop("ssl", None)
        self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
        await self.runner.setup()
        absolute_host = self.host
        try:
            version = ipaddress.ip_address(self.host).version
        except ValueError:
            # Not an IP literal (e.g. a hostname): treat it as IPv4.
            version = 4
        if version == 6:
            # IPv6 literals must be bracketed inside a URL authority.
            absolute_host = f"[{self.host}]"
        family = socket.AF_INET6 if version == 6 else socket.AF_INET
        _sock = self.socket_factory(self.host, self.port, family)
        self.host, self.port = _sock.getsockname()[:2]
        site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
        await site.start()
        server = site._server
        assert server is not None
        sockets = server.sockets
        assert sockets is not None
        self.port = sockets[0].getsockname()[1]
        if not self.scheme:
            self.scheme = "https" if self._ssl else "http"
        self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")

    @abstractmethod
    async def _make_runner(self, **kwargs: Any) -> BaseRunner[_Request]:
        """Return a new runner for the server."""
        # TODO(PY311): Use Unpack to specify Server kwargs.

    def make_url(self, path: StrOrURL) -> URL:
        """Build an absolute URL for *path* on this server.

        Requires the server to be started. Unless ``skip_url_asserts`` is
        set, *path* must be relative and is joined onto the root URL.
        """
        assert self._root is not None
        url = URL(path)
        if not self.skip_url_asserts:
            assert not url.absolute
            return self._root.join(url)
        else:
            return URL(str(self._root) + str(path))

    @property
    def started(self) -> bool:
        """Whether a runner has been created by :meth:`start_server`."""
        return self.runner is not None

    @property
    def closed(self) -> bool:
        """Whether :meth:`close` has shut the server down."""
        return self._closed

    @property
    def handler(self) -> Server[_Request]:
        """The runner's low-level ``web.Server`` (kept for backward compatibility)."""
        # for backward compatibility
        # web.Server instance
        runner = self.runner
        assert runner is not None
        assert runner.server is not None
        return runner.server

    async def close(self) -> None:
        """Stop the server and release its runner.

        After that point, the server is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects. It does nothing if the
        server was never started.

        close is also run on exit when used as an asynchronous context
        manager.
        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = 0
            self._closed = True

    async def __aenter__(self) -> Self:
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        await self.close()
class TestServer(BaseTestServer[Request]):
    """Test server that serves an :class:`~aiohttp.web.Application`."""

    def __init__(
        self,
        app: Application,
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Wrap *app*; the remaining arguments go to :class:`BaseTestServer`."""
        self.app = app
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> AppRunner:
        """Return an :class:`AppRunner` for :attr:`app`, built with *kwargs*."""
        # TODO(PY311): Use Unpack to specify Server kwargs.
        return AppRunner(self.app, **kwargs)
class RawTestServer(BaseTestServer[BaseRequest]):
    """Test server that hands every request to a single low-level handler."""

    def __init__(
        self,
        handler: _RequestHandler[BaseRequest],
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Remember *handler*; other arguments go to :class:`BaseTestServer`."""
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)
        self._handler = handler

    async def _make_runner(self, **kwargs: Any) -> ServerRunner:
        """Wrap the handler in a ``web.Server`` and return its runner."""
        # TODO(PY311): Use Unpack to specify Server kwargs.
        # The same keyword arguments are given to the server and the runner.
        return ServerRunner(Server(self._handler, **kwargs), **kwargs)
class TestClient(Generic[_Request, _ApplicationNone]):
    """A test client implementation.

    To write functional tests for aiohttp based servers.

    Request helpers take a path relative to the wrapped test server. The
    responses and websockets they open are tracked and released by
    :meth:`close`.
    """

    # Opt out of test collection (pytest honours ``__test__``).
    __test__ = False

    @overload
    def __init__(
        self: "TestClient[Request, Application]",
        server: TestServer,
        *,
        cookie_jar: AbstractCookieJar | None = None,
        **kwargs: Any,
    ) -> None: ...
    @overload
    def __init__(
        self: "TestClient[_Request, None]",
        server: BaseTestServer[_Request],
        *,
        cookie_jar: AbstractCookieJar | None = None,
        **kwargs: Any,
    ) -> None: ...
    def __init__(  # type: ignore[misc]
        self,
        server: BaseTestServer[_Request],
        *,
        cookie_jar: AbstractCookieJar | None = None,
        **kwargs: Any,
    ) -> None:
        """Wrap *server* and create the underlying ClientSession.

        :param server: Test server that requests are sent to.
        :param cookie_jar: Cookie jar for the session. Defaults to
            ``aiohttp.CookieJar(unsafe=True)``.
        :param kwargs: Forwarded to :class:`aiohttp.ClientSession`.
        :raises TypeError: If *server* is not a BaseTestServer.
        """
        # TODO(PY311): Use Unpack to specify ClientSession kwargs.
        if not isinstance(server, BaseTestServer):
            raise TypeError(
                "server must be TestServer instance, found type: %r" % type(server)
            )
        self._server = server
        if cookie_jar is None:
            # unsafe=True lets the jar accept cookies from IP-address hosts
            # such as the default 127.0.0.1 test server.
            cookie_jar = aiohttp.CookieJar(unsafe=True)
        self._session = ClientSession(cookie_jar=cookie_jar, **kwargs)
        # NOTE(review): presumably turns off the session's automatic
        # connection retry so failures surface directly in tests — confirm.
        self._session._retry_connection = False
        self._closed = False
        # Tracked so that close() can release them.
        self._responses: list[ClientResponse] = []
        self._websockets: list[ClientWebSocketResponse[bool]] = []

    async def start_server(self) -> None:
        """Start the wrapped test server (no-op if it is already running)."""
        await self._server.start_server()

    @property
    def scheme(self) -> str:
        """URL scheme of the wrapped server."""
        return self._server.scheme

    @property
    def host(self) -> str:
        """Host the wrapped server is bound to."""
        return self._server.host

    @property
    def port(self) -> int:
        """Port of the wrapped server."""
        return self._server.port

    @property
    def server(self) -> BaseTestServer[_Request]:
        """The wrapped test server."""
        return self._server

    @property
    def app(self) -> _ApplicationNone:
        """The server's application, or ``None`` if the server has none."""
        return getattr(self._server, "app", None)  # type: ignore[return-value]

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.

        """
        return self._session

    def make_url(self, path: StrOrURL) -> URL:
        """Resolve *path* against the wrapped server's root URL."""
        return self._server.make_url(path)

    async def _request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> ClientResponse:
        """Send *method* to *path* on the server and track the response."""
        resp = await self._session.request(method, self.make_url(path), **kwargs)
        # save it to close later
        self._responses.append(resp)
        return resp

    if sys.version_info >= (3, 11) and TYPE_CHECKING:

        def request(
            self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions]
        ) -> _RequestContextManager: ...

        def get(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def options(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def head(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def post(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def put(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def patch(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def delete(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

    else:

        def request(
            self, method: str, path: StrOrURL, **kwargs: Any
        ) -> _RequestContextManager:
            """Routes a request to tested http server.

            The interface is identical to aiohttp.ClientSession.request,
            except that *path* is resolved relative to the test server's
            root URL.

            """
            return _RequestContextManager(self._request(method, path, **kwargs))

        def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP GET request."""
            return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

        def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP POST request."""
            return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

        def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP OPTIONS request."""
            return _RequestContextManager(
                self._request(hdrs.METH_OPTIONS, path, **kwargs)
            )

        def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP HEAD request."""
            return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

        def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP PUT request."""
            return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

        def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP PATCH request."""
            return _RequestContextManager(
                self._request(hdrs.METH_PATCH, path, **kwargs)
            )

        def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP DELETE request."""
            return _RequestContextManager(
                self._request(hdrs.METH_DELETE, path, **kwargs)
            )

    @overload
    def ws_connect(
        self, path: StrOrURL, *, decode_text: Literal[True] = ..., **kwargs: Any
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ...

    @overload
    def ws_connect(
        self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ...

    @overload
    def ws_connect(
        self, path: StrOrURL, *, decode_text: bool = ..., **kwargs: Any
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ...

    def ws_connect(
        self, path: StrOrURL, *, decode_text: bool = True, **kwargs: Any
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]":
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect, with *path*
        resolved relative to the test server's root URL.

        """
        return _WSRequestContextManager(
            self._ws_connect(path, decode_text=decode_text, **kwargs)
        )

    @overload
    async def _ws_connect(
        self, path: StrOrURL, *, decode_text: Literal[True] = ..., **kwargs: Any
    ) -> "ClientWebSocketResponse[Literal[True]]": ...

    @overload
    async def _ws_connect(
        self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any
    ) -> "ClientWebSocketResponse[Literal[False]]": ...

    @overload
    async def _ws_connect(
        self, path: StrOrURL, *, decode_text: bool = ..., **kwargs: Any
    ) -> "ClientWebSocketResponse[bool]": ...

    async def _ws_connect(
        self, path: StrOrURL, *, decode_text: bool = True, **kwargs: Any
    ) -> "ClientWebSocketResponse[bool]":
        """Open a websocket to *path* on the server and track it."""
        ws = await self._session.ws_connect(
            self.make_url(path), decode_text=decode_text, **kwargs
        )
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        Tracked responses and websockets are closed first, then the client
        session, then the wrapped server.

        After that point, the TestClient is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as an asynchronous context
        manager.

        """
        if not self._closed:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
            await self._session.close()
            await self._server.close()
            self._closed = True

    async def __aenter__(self) -> Self:
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        await self.close()
class AioHTTPTestCase(IsolatedAsyncioTestCase, ABC):
    """Base class for unittest-style tests of aiohttp web applications.

    Before every test the following attributes are prepared:

    * self.app (aiohttp.web.Application): the object returned by
      self.get_application()
    * self.server (aiohttp.test_utils.TestServer): built by self.get_server()
    * self.client (aiohttp.test_utils.TestClient): built by self.get_client(),
      with its server already started

    The TestClient's request methods are asynchronous, so they have to be
    awaited from asynchronous test methods.
    """

    @abstractmethod
    async def get_application(self) -> Application:
        """Return the aiohttp.web.Application under test.

        Subclasses must override this.
        """

    async def asyncSetUp(self) -> None:
        """Build the app, server and client, then start the server."""
        application = await self.get_application()
        self.app = application
        test_server = await self.get_server(application)
        self.server = test_server
        test_client = await self.get_client(test_server)
        self.client = test_client
        await test_client.start_server()

    async def asyncTearDown(self) -> None:
        """Close the client, which also closes its server."""
        await self.client.close()

    async def get_server(self, app: Application) -> TestServer:
        """Return the TestServer used to serve *app*."""
        return TestServer(app)

    async def get_client(self, server: TestServer) -> TestClient[Request, Application]:
        """Return the TestClient that talks to *server*."""
        return TestClient(server)
def _create_app_mock() -> mock.MagicMock:
    """Return a MagicMock that stands in for an ``Application``.

    Item access (``app[key]`` / ``app[key] = value``) is backed by a plain
    dict stored on the mock. ``on_response_prepare`` is a real, frozen
    Signal with no receivers.
    """
    fake_app = mock.MagicMock(spec=Application)
    fake_app.__app_dict = {}

    # Magic methods assigned to a mock receive the mock itself first.
    def _getitem(target: Any, key: str) -> Any:
        return target.__app_dict[key]

    def _setitem(target: Any, key: str, value: Any) -> None:
        target.__app_dict[key] = value

    fake_app.__getitem__ = _getitem
    fake_app.__setitem__ = _setitem

    prepare_signal = Signal(fake_app)
    fake_app.on_response_prepare = prepare_signal
    prepare_signal.freeze()
    return fake_app
550def _create_transport(sslcontext: SSLContext | None = None) -> mock.Mock:
551 transport = mock.Mock()
553 def get_extra_info(key: str) -> SSLContext | tuple[str, int] | None:
554 if key == "sslcontext":
555 return sslcontext
556 return ("127.0.0.1", 80) if key == "sockname" else None
558 transport.get_extra_info.side_effect = get_extra_info
559 return transport
def make_mocked_request(
    method: str,
    path: str,
    headers: LooseHeaders | None = None,
    *,
    match_info: dict[str, str] | None = None,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Application | None = None,
    writer: AbstractStreamWriter | None = None,
    protocol: RequestHandler[Request] | None = None,
    transport: asyncio.Transport | None = None,
    payload: StreamReader = EMPTY_PAYLOAD,
    sslcontext: SSLContext | None = None,
    client_max_size: int = 1024**2,
    loop: Any = ...,
) -> Request:
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
    specific conditions and errors are hard to trigger.

    Collaborators that are not supplied are replaced with mocks:

    * *app* becomes a mock application.
    * *transport* becomes a mock transport that reports *sslcontext*.
    * *protocol* and *writer* become ``mock.Mock`` objects; the writer's
      coroutine methods resolve to ``None``.
    * *loop* becomes the running event loop if there is one, otherwise a
      mock.

    :param method: HTTP method of the request.
    :param path: Request target; also parsed into the request URL.
    :param headers: Request headers; ``None`` or empty means no headers.
    :param match_info: Route match values; defaults to an empty mapping.
    :param version: HTTP version; versions below 1.1 force *closing*.
    :param closing: Whether the message is marked as closing the connection.
    :param payload: Body stream of the request.
    :param client_max_size: Passed through to ``Request``.
    :returns: A web.Request wired to the given or mocked collaborators.
    """
    task = mock.Mock()
    if loop is ...:
        # no loop passed, try to get the current one if
        # it is running as we need a real loop to create
        # executor jobs to be able to do testing
        # with a real executor
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = mock.Mock()
            # NOTE(review): futures from the mock loop are an empty tuple —
            # presumably just a harmless placeholder; confirm.
            loop.create_future.return_value = ()

    # Requests below HTTP/1.1 are always treated as closing the connection.
    if version < HttpVersion(1, 1):
        closing = True

    if headers:
        headers = HeadersDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
            (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
        )
    else:
        headers = HeadersDictProxy(CIMultiDict())
        raw_hdrs = ()

    # Flag the message as chunked when Transfer-Encoding contains "chunked".
    chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()

    # NOTE(review): positional RawRequestMessage fields; the None/False pair
    # presumably maps to compression/upgrade — confirm against its definition.
    message = RawRequestMessage(
        method,
        path,
        version,
        headers,
        raw_hdrs,
        closing,
        None,
        False,
        chunked,
        URL(path),
    )
    if app is None:
        app = _create_app_mock()

    if transport is None:
        transport = _create_transport(sslcontext)

    if protocol is None:
        protocol = mock.Mock()
        protocol.max_field_size = 8190
        protocol.max_line_length = 8190
        protocol.max_headers = 128
        # NOTE(review): repeated unconditionally below. Assigning it here,
        # before writer.transport, makes *protocol* (not *writer*) the mock
        # parent of a default transport — keep unless that is confirmed
        # irrelevant.
        protocol.transport = transport
        type(protocol).peername = mock.PropertyMock(
            return_value=transport.get_extra_info("peername")
        )
        type(protocol).sockname = mock.PropertyMock(
            return_value=transport.get_extra_info("sockname")
        )
        type(protocol).ssl_context = mock.PropertyMock(return_value=sslcontext)

    if writer is None:
        writer = mock.Mock()
        # Awaiting any of these write operations is a no-op returning None.
        writer.write_headers = mock.AsyncMock(return_value=None)
        writer.write = mock.AsyncMock(return_value=None)
        writer.write_eof = mock.AsyncMock(return_value=None)
        writer.drain = mock.AsyncMock(return_value=None)
        writer.transport = transport

    # Applied even to a caller-supplied protocol, overriding its transport.
    protocol.transport = transport

    req = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    match_info = UrlMappingMatchInfo(
        {} if match_info is None else match_info, mock.Mock()
    )
    match_info.add_app(app)
    req._match_info = match_info

    return req