Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/test_utils.py: 48%
319 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
1"""Utilities shared by tests."""
3import asyncio
4import contextlib
5import gc
6import inspect
7import ipaddress
8import os
9import socket
10import sys
11from abc import ABC, abstractmethod
12from types import TracebackType
13from typing import (
14 TYPE_CHECKING,
15 Any,
16 Callable,
17 Iterator,
18 List,
19 Optional,
20 Type,
21 Union,
22 cast,
23)
24from unittest import mock
26from aiosignal import Signal
27from multidict import CIMultiDict, CIMultiDictProxy
28from yarl import URL
30import aiohttp
31from aiohttp.client import _RequestContextManager, _WSRequestContextManager
33from . import ClientSession, hdrs
34from .abc import AbstractCookieJar
35from .client_reqrep import ClientResponse
36from .client_ws import ClientWebSocketResponse
37from .helpers import _SENTINEL, PY_38, sentinel
38from .http import HttpVersion, RawRequestMessage
39from .typedefs import StrOrURL
40from .web import (
41 Application,
42 AppRunner,
43 BaseRunner,
44 Request,
45 Server,
46 ServerRunner,
47 SockSite,
48 UrlMappingMatchInfo,
49)
50from .web_protocol import _RequestHandler
52if TYPE_CHECKING: # pragma: no cover
53 from ssl import SSLContext
54else:
55 SSLContext = None
57if PY_38:
58 from unittest import IsolatedAsyncioTestCase as TestCase
59else:
60 from asynctest import TestCase # type: ignore[no-redef]
# SO_REUSEADDR is only requested on POSIX systems other than Cygwin; Windows
# gives that option different semantics (see get_port_socket).
REUSE_ADDRESS = sys.platform != "cygwin" and os.name == "posix"
def get_unused_port_socket(
    host: str, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Return a socket of *family* bound to *host* on an OS-chosen port.

    Delegates to ``get_port_socket`` with port ``0``, which asks the
    operating system to pick a free ephemeral port.
    """
    any_free_port = 0
    return get_port_socket(host, any_free_port, family)
def get_port_socket(
    host: str, port: int, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Create a TCP socket of *family* bound to ``(host, port)``.

    The returned socket is bound but not listening. ``SO_REUSEADDR`` is set
    only when ``REUSE_ADDRESS`` is true. If setting the option or binding
    fails, the socket is closed before the exception propagates, so no
    file descriptor is leaked.

    :param host: address to bind to.
    :param port: port to bind to; ``0`` lets the OS choose a free port.
    :param family: address family, e.g. ``socket.AF_INET`` or
        ``socket.AF_INET6``.
    :returns: the bound socket; the caller owns it and must close it.
    :raises OSError: if the option cannot be set or the bind fails.
    """
    s = socket.socket(family, socket.SOCK_STREAM)
    try:
        if REUSE_ADDRESS:
            # Windows has different semantics for SO_REUSEADDR,
            # so don't set it. Ref:
            # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
    except BaseException:
        # Don't leak the descriptor when the bind fails (e.g. port in use).
        s.close()
        raise
    return s
def unused_port() -> int:
    """Return a TCP port number that is currently free on 127.0.0.1.

    A temporary IPv4 socket is bound to port 0 so the OS assigns a free
    port. The socket is closed before returning, so the port is not reserved
    and may be taken by someone else before the caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("127.0.0.1", 0))
        _, assigned_port = probe.getsockname()
    return cast(int, assigned_port)
class BaseTestServer(ABC):
    """Base class for test servers listening on a real local socket.

    Subclasses implement :meth:`_make_runner` to build the runner that serves
    requests. :meth:`start_server` binds a socket (via ``socket_factory``)
    and starts serving on it. :meth:`make_url` builds URLs relative to the
    server's root.
    """

    # Prevent pytest from collecting this class as a test.
    __test__ = False

    def __init__(
        self,
        *,
        scheme: Union[str, _SENTINEL] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        skip_url_asserts: bool = False,
        socket_factory: Callable[
            [str, int, socket.AddressFamily], socket.socket
        ] = get_port_socket,
        **kwargs: Any,
    ) -> None:
        """Store configuration; nothing is bound until start_server().

        :param scheme: URL scheme; when left as ``sentinel`` it is chosen in
            start_server() ("https" if an ``ssl`` context is given, else
            "http").
        :param host: address to bind to.
        :param port: port to bind to; ``None``/``0`` lets the OS choose.
        :param skip_url_asserts: if true, make_url() accepts absolute paths
            and concatenates them textually onto the root URL.
        :param socket_factory: callable ``(host, port, family) -> socket``
            returning a bound socket.
        """
        # NOTE(review): **kwargs is accepted but unused here; presumably
        # kept so callers/subclasses can pass extra options — confirm.
        self.runner: Optional[BaseRunner] = None
        self._root: Optional[URL] = None
        self.host = host
        self.port = port
        self._closed = False
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts
        self.socket_factory = socket_factory

    async def start_server(self, **kwargs: Any) -> None:
        """Create the runner, bind a socket and start serving.

        Does nothing if a runner already exists. ``ssl`` is popped from
        *kwargs* and used as the site's SSL context; the remaining keyword
        arguments are forwarded to :meth:`_make_runner`. On success
        ``host``, ``port``, ``scheme`` and the root URL reflect the actual
        listening address.
        """
        if self.runner:
            return
        self._ssl = kwargs.pop("ssl", None)
        self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
        await self.runner.setup()
        if not self.port:
            # Port 0 asks the OS for a free ephemeral port.
            self.port = 0
        absolute_host = self.host
        try:
            version = ipaddress.ip_address(self.host).version
        except ValueError:
            # Not an IP literal (e.g. a host name): treat it as IPv4.
            version = 4
        if version == 6:
            # IPv6 literals must be bracketed inside a URL authority.
            absolute_host = f"[{self.host}]"
        family = socket.AF_INET6 if version == 6 else socket.AF_INET
        _sock = self.socket_factory(self.host, self.port, family)
        try:
            self.host, self.port = _sock.getsockname()[:2]
            site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
            await site.start()
        except BaseException:
            # The site never took ownership of the socket; don't leak it.
            _sock.close()
            raise
        server = site._server
        assert server is not None
        sockets = server.sockets  # type: ignore[attr-defined]
        assert sockets is not None
        self.port = sockets[0].getsockname()[1]
        if self.scheme is sentinel:
            if self._ssl:
                scheme = "https"
            else:
                scheme = "http"
            self.scheme = scheme
        self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")

    @abstractmethod  # pragma: no cover
    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        """Return the (not yet set up) runner that will serve requests."""
        pass

    def make_url(self, path: StrOrURL) -> URL:
        """Build an absolute URL for *path* on this server.

        Must be called after start_server(). By default *path* must be
        relative (asserted) and is joined onto the root URL with
        ``URL.join``. With ``skip_url_asserts`` it is appended textually to
        the root URL instead.
        """
        assert self._root is not None
        url = URL(path)
        if not self.skip_url_asserts:
            assert not url.is_absolute()
            return self._root.join(url)
        else:
            return URL(str(self._root) + str(path))

    @property
    def started(self) -> bool:
        """True once start_server() has created a runner.

        It stays True after close(), which does not reset the runner.
        """
        return self.runner is not None

    @property
    def closed(self) -> bool:
        """True once close() has cleaned up a started server."""
        return self._closed

    @property
    def handler(self) -> Server:
        """The runner's low-level ``web.Server`` (kept for backward compatibility)."""
        runner = self.runner
        assert runner is not None
        assert runner.server is not None
        return runner.server

    async def close(self) -> None:
        """Stop the server and release its resources.

        After this call the server is no longer usable.

        This is idempotent: calling close multiple times has no additional
        effect, and it is a no-op for a server that was never started.

        close is also run on exit when the server is used as an
        asynchronous context manager.
        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = None
            self._closed = True

    async def __aenter__(self) -> "BaseTestServer":
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.close()
class TestServer(BaseTestServer):
    """Test server that serves a full ``aiohttp.web.Application``."""

    def __init__(
        self,
        app: Application,
        *,
        scheme: Union[str, _SENTINEL] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        self.app = app
        base_options = dict(scheme=scheme, host=host, port=port)
        super().__init__(**base_options, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        # The application is wrapped in an AppRunner; kwargs pass straight through.
        runner = AppRunner(self.app, **kwargs)
        return runner
class RawTestServer(BaseTestServer):
    """Test server that hands every request to one low-level handler."""

    def __init__(
        self,
        handler: _RequestHandler,
        *,
        scheme: Union[str, _SENTINEL] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        self._handler = handler
        base_options = dict(scheme=scheme, host=host, port=port)
        super().__init__(**base_options, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> ServerRunner:
        # The same keyword arguments go to both the low-level Server and
        # the ServerRunner that wraps it.
        return ServerRunner(Server(self._handler, **kwargs), **kwargs)
class TestClient:
    """
    A test client implementation.

    To write functional tests for aiohttp based servers. Wraps a
    :class:`BaseTestServer` together with an :class:`aiohttp.ClientSession`.
    Paths given to the request helpers are resolved against the server's
    root URL. Every response and websocket opened through the client is
    tracked so that :meth:`close` can release it.
    """

    # Prevent pytest from collecting this class as a test.
    __test__ = False

    def __init__(
        self,
        server: BaseTestServer,
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        **kwargs: Any,
    ) -> None:
        """Wrap *server*; extra *kwargs* are passed to ClientSession.

        :param server: the test server to talk to.
        :param cookie_jar: cookie jar for the session; defaults to
            ``aiohttp.CookieJar(unsafe=True)``.
        :raises TypeError: if *server* is not a BaseTestServer.
        """
        if not isinstance(server, BaseTestServer):
            raise TypeError(
                "server must be TestServer " "instance, found type: %r" % type(server)
            )
        self._server = server
        if cookie_jar is None:
            # unsafe=True lets the jar accept cookies from IP-address hosts
            # such as the default 127.0.0.1.
            cookie_jar = aiohttp.CookieJar(unsafe=True)
        self._session = ClientSession(cookie_jar=cookie_jar, **kwargs)
        self._closed = False
        self._responses: List[ClientResponse] = []
        self._websockets: List[ClientWebSocketResponse] = []

    async def start_server(self) -> None:
        """Start the wrapped test server."""
        await self._server.start_server()

    @property
    def scheme(self) -> Union[str, object]:
        """The wrapped server's URL scheme."""
        return self._server.scheme

    @property
    def host(self) -> str:
        """The wrapped server's host."""
        return self._server.host

    @property
    def port(self) -> Optional[int]:
        """The wrapped server's port (None before start / after close)."""
        return self._server.port

    @property
    def server(self) -> BaseTestServer:
        """The wrapped test server."""
        return self._server

    @property
    def app(self) -> Optional[Application]:
        """The server's application, or None if the server has no ``app``."""
        return cast(Optional[Application], getattr(self._server, "app", None))

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.

        """
        return self._session

    def make_url(self, path: StrOrURL) -> URL:
        """Resolve *path* against the wrapped server's root URL."""
        return self._server.make_url(path)

    async def _request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> ClientResponse:
        resp = await self._session.request(method, self.make_url(path), **kwargs)
        # save it to close later
        self._responses.append(resp)
        return resp

    def request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> _RequestContextManager:
        """Routes a request to tested http server.

        The interface is identical to aiohttp.ClientSession.request,
        except that *path* is resolved against the test server's root URL
        (see make_url) and the response is closed by close().

        """
        return _RequestContextManager(self._request(method, path, **kwargs))

    def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP GET request."""
        return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

    def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP POST request."""
        return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

    def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP OPTIONS request."""
        return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs))

    def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP HEAD request."""
        return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

    def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PUT request."""
        return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

    def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PATCH request."""
        return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs))

    def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP DELETE request."""
        return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs))

    def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect, with *path*
        resolved against the test server's root URL.

        """
        return _WSRequestContextManager(self._ws_connect(path, **kwargs))

    async def _ws_connect(
        self, path: StrOrURL, **kwargs: Any
    ) -> ClientWebSocketResponse:
        ws = await self._session.ws_connect(self.make_url(path), **kwargs)
        # save it to close later
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        After that point, the TestClient is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as a(n) (asynchronous)
        context manager.

        """
        if not self._closed:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
            await self._session.close()
            await self._server.close()
            self._closed = True

    async def __aenter__(self) -> "TestClient":
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()
class AioHTTPTestCase(TestCase, ABC):
    """Base class for unittest-style tests of aiohttp web applications.

    Before each test the following attributes are prepared:

    * ``self.app`` -- the aiohttp.web.Application from get_application();
    * ``self.server`` -- the TestServer from get_server();
    * ``self.client`` -- the aiohttp.test_utils.TestClient from get_client(),
      with its server already started.

    The client's request methods are asynchronous, so they must be awaited
    from asynchronous test methods.
    """

    @abstractmethod
    async def get_application(self) -> Application:
        """Return the aiohttp.web.Application under test.

        Subclasses must override this.
        """

    def setUp(self) -> None:
        # With PY_38 the base class is IsolatedAsyncioTestCase, which calls
        # asyncSetUp itself; otherwise drive it on the current event loop.
        if PY_38:
            return
        asyncio.get_event_loop().run_until_complete(self.asyncSetUp())

    async def asyncSetUp(self) -> None:
        application = await self.get_application()
        self.app = application
        test_server = await self.get_server(application)
        self.server = test_server
        test_client = await self.get_client(test_server)
        self.client = test_client

        await test_client.start_server()

    def tearDown(self) -> None:
        # Mirror of setUp: only needed when asyncTearDown is not run for us.
        if PY_38:
            return
        asyncio.get_event_loop().run_until_complete(self.asyncTearDown())

    async def asyncTearDown(self) -> None:
        await self.client.close()

    async def get_server(self, app: Application) -> TestServer:
        """Return a TestServer instance."""
        return TestServer(app)

    async def get_client(self, server: TestServer) -> TestClient:
        """Return a TestClient instance."""
        return TestClient(server)
# A zero-argument callable that creates a new event loop, e.g.
# ``asyncio.new_event_loop``; used by the loop helpers in this module.
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]


@contextlib.contextmanager
def loop_context(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
) -> Iterator[asyncio.AbstractEventLoop]:
    """A contextmanager that creates an event_loop, for test purposes.

    Handles the creation and cleanup of a test loop: the loop comes from
    setup_test_loop(loop_factory) and is always passed to
    teardown_test_loop(loop, fast=fast) on exit, even when the body of the
    ``with`` statement raises.

    :param loop_factory: callable creating the loop.
    :param fast: forwarded to teardown_test_loop (skips ``gc.collect()``).
    """
    loop = setup_test_loop(loop_factory)
    try:
        yield loop
    finally:
        # Without finally, an exception in the with-body would leave the
        # loop unclosed and still installed as the current event loop.
        teardown_test_loop(loop, fast=fast)
def setup_test_loop(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
) -> asyncio.AbstractEventLoop:
    """Create, install and return a new event loop for tests.

    The loop comes from *loop_factory* and is set as the current event
    loop. On non-Windows platforms, unless the loop class comes from a
    ``uvloop`` module, a child watcher is also attached to the loop and
    registered with the event loop policy. This step is skipped when the
    running Python no longer provides child watcher classes.

    The caller should also call teardown_test_loop,
    once they are done with the loop.
    """
    loop = loop_factory()
    try:
        module = loop.__class__.__module__
        skip_watcher = "uvloop" in module
    except AttributeError:  # pragma: no cover
        # Just in case
        skip_watcher = True
    asyncio.set_event_loop(loop)
    if sys.platform != "win32" and not skip_watcher:
        # Prefer ThreadedChildWatcher (Python >= 3.8), fall back to
        # SafeChildWatcher; newer Pythons provide neither. Refs:
        # * https://github.com/pytest-dev/pytest-xdist/issues/620
        # * https://stackoverflow.com/a/58614689/595220
        # * https://bugs.python.org/issue35621
        # * https://github.com/python/cpython/pull/14344
        watcher_cls = getattr(asyncio, "ThreadedChildWatcher", None) or getattr(
            asyncio, "SafeChildWatcher", None
        )
        if watcher_cls is not None:
            policy = asyncio.get_event_loop_policy()
            watcher = watcher_cls()
            watcher.attach_loop(loop)
            with contextlib.suppress(NotImplementedError):
                policy.set_child_watcher(watcher)
    return loop
def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
    """Close a loop created by setup_test_loop and uninstall it.

    If the loop is still open, it runs until a stop scheduled via
    ``call_soon`` fires, then it is closed. Unless *fast* is set, a full
    ``gc.collect()`` follows. Finally the current event loop is reset to
    None.
    """
    if loop.is_closed():
        pass
    else:
        loop.call_soon(loop.stop)
        loop.run_forever()
        loop.close()

    if not fast:
        gc.collect()

    asyncio.set_event_loop(None)
def _create_app_mock() -> mock.MagicMock:
    """Return a MagicMock standing in for an ``Application``.

    Item get/set on the mock is backed by a plain dict stored on the mock
    itself. ``on_response_prepare`` is a real Signal, already frozen.
    """
    fake_app = mock.MagicMock(spec=Application)
    fake_app.__app_dict = {}

    def _getitem(owner: Any, key: str) -> Any:
        return owner.__app_dict[key]

    def _setitem(owner: Any, key: str, value: Any) -> None:
        owner.__app_dict[key] = value

    fake_app.__getitem__ = _getitem
    fake_app.__setitem__ = _setitem

    fake_app.on_response_prepare = Signal(fake_app)
    fake_app.on_response_prepare.freeze()
    return fake_app
def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
    """Return a mock transport whose ``get_extra_info`` knows one key.

    ``get_extra_info("sslcontext")`` yields *sslcontext*; every other key
    yields None.
    """

    def _extra_info(key: str) -> Optional[SSLContext]:
        return sslcontext if key == "sslcontext" else None

    fake_transport = mock.Mock()
    fake_transport.get_extra_info.side_effect = _extra_info
    return fake_transport
def make_mocked_request(
    method: str,
    path: str,
    headers: Any = None,
    *,
    match_info: Any = sentinel,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Any = None,
    writer: Any = sentinel,
    protocol: Any = sentinel,
    transport: Any = sentinel,
    payload: Any = sentinel,
    sslcontext: Optional[SSLContext] = None,
    client_max_size: int = 1024**2,
    loop: Any = ...,
) -> Request:
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
    specific conditions and errors are hard to trigger.

    :param method: HTTP method of the request.
    :param path: request path; also parsed into the message URL.
    :param headers: anything accepted by ``CIMultiDict``; names and values
        must be ``str`` (they are UTF-8 encoded into the raw headers).
    :param match_info: dict for the UrlMappingMatchInfo; defaults to ``{}``.
    :param version: HTTP version; versions below 1.1 force ``closing``.
    :param closing: whether the connection should be closed.
    :param app: application; defaults to a mocked Application.
    :param writer: payload writer; defaults to a Mock with coroutine-mock
        ``write_headers``/``write``/``write_eof``/``drain``.
    :param protocol: protocol; defaults to a Mock. Its ``transport`` and
        ``writer`` attributes are always overwritten.
    :param transport: transport; defaults to a mock whose
        ``get_extra_info("sslcontext")`` returns *sslcontext*.
    :param payload: request body stream; defaults to a Mock.
    :param sslcontext: SSL context reported by the default transport.
    :param client_max_size: forwarded to Request.
    :param loop: event loop; ``...`` (the default) means a Mock loop.
    """
    task = mock.Mock()
    if loop is ...:
        loop = mock.Mock()
        # NOTE(review): presumably a harmless placeholder for code that
        # calls loop.create_future() on the mock loop — confirm.
        loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        # HTTP/1.0 and older cannot keep the connection alive by default.
        closing = True

    if headers:
        headers = CIMultiDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
            (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
        )
    else:
        headers = CIMultiDictProxy(CIMultiDict())
        raw_hdrs = ()

    chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()

    # NOTE(review): the positional None/False presumably are the
    # compression and upgrade fields of RawRequestMessage — confirm.
    message = RawRequestMessage(
        method,
        path,
        version,
        headers,
        raw_hdrs,
        closing,
        None,
        False,
        chunked,
        URL(path),
    )
    if app is None:
        app = _create_app_mock()

    if transport is sentinel:
        transport = _create_transport(sslcontext)

    if protocol is sentinel:
        protocol = mock.Mock()

    if writer is sentinel:
        writer = mock.Mock()
        writer.write_headers = make_mocked_coro(None)
        writer.write = make_mocked_coro(None)
        writer.write_eof = make_mocked_coro(None)
        writer.drain = make_mocked_coro(None)
        writer.transport = transport

    # Wire the (possibly caller-supplied) protocol to the chosen
    # transport and writer.
    protocol.transport = transport
    protocol.writer = writer

    if payload is sentinel:
        payload = mock.Mock()

    req = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    match_info = UrlMappingMatchInfo(
        {} if match_info is sentinel else match_info, mock.Mock()
    )
    match_info.add_app(app)
    req._match_info = match_info

    return req
def make_mocked_coro(
    return_value: Any = sentinel, raise_exception: Any = sentinel
) -> Any:
    """Creates a coroutine mock.

    Returns a :class:`unittest.mock.Mock` wrapping an ``async def``
    function, so calls are recorded and each call produces a coroutine.
    When awaited, that coroutine:

    * raises *raise_exception* if one was given;
    * otherwise, if *return_value* is awaitable, awaits it and returns its
      result;
    * otherwise returns *return_value* unchanged (including the
      ``sentinel`` default when no value was given).

    A coroutine object can only be awaited once, so an awaitable
    *return_value* of that kind supports a single awaited call.
    """

    async def mock_coro(*args: Any, **kwargs: Any) -> Any:
        if raise_exception is not sentinel:
            raise raise_exception
        if inspect.isawaitable(return_value):
            # Previously the awaited result was dropped and None returned.
            return await return_value
        return return_value

    return mock.Mock(wraps=mock_coro)