Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/test_utils.py: 50%
300 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-26 06:16 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-26 06:16 +0000
1"""Utilities shared by tests."""
3import asyncio
4import contextlib
5import gc
6import inspect
7import ipaddress
8import os
9import socket
10import sys
11from abc import ABC, abstractmethod
12from types import TracebackType
13from typing import (
14 TYPE_CHECKING,
15 Any,
16 Callable,
17 Iterator,
18 List,
19 Optional,
20 Type,
21 Union,
22 cast,
23)
24from unittest import IsolatedAsyncioTestCase, mock
26from aiosignal import Signal
27from multidict import CIMultiDict, CIMultiDictProxy
28from yarl import URL
30import aiohttp
31from aiohttp.client import _RequestContextManager, _WSRequestContextManager
33from . import ClientSession, hdrs
34from .abc import AbstractCookieJar
35from .client_reqrep import ClientResponse
36from .client_ws import ClientWebSocketResponse
37from .helpers import _SENTINEL, sentinel
38from .http import HttpVersion, RawRequestMessage
39from .typedefs import StrOrURL
40from .web import (
41 Application,
42 AppRunner,
43 BaseRunner,
44 Request,
45 Server,
46 ServerRunner,
47 SockSite,
48 UrlMappingMatchInfo,
49)
50from .web_protocol import _RequestHandler
52if TYPE_CHECKING:
53 from ssl import SSLContext
54else:
55 SSLContext = None
# SO_REUSEADDR is only requested on POSIX systems other than Cygwin; on
# Windows the option has different semantics (see get_port_socket).
REUSE_ADDRESS = sys.platform != "cygwin" and os.name == "posix"
def get_unused_port_socket(
    host: str, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Return a TCP socket bound to *host* on a port chosen by the OS."""
    # Port 0 lets the operating system pick any free port.
    any_free_port = 0
    return get_port_socket(host, any_free_port, family)
def get_port_socket(
    host: str, port: int, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Create a TCP socket of *family* bound to ``(host, port)``.

    ``SO_REUSEADDR`` is enabled when ``REUSE_ADDRESS`` is true. The returned
    socket is bound but not listening.

    If configuring or binding the socket fails, the socket is closed before
    the error propagates, so no file descriptor is leaked.
    """
    s = socket.socket(family, socket.SOCK_STREAM)
    try:
        if REUSE_ADDRESS:
            # Windows has different semantics for SO_REUSEADDR,
            # so don't set it. Ref:
            # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
    except BaseException:
        # Don't leak the descriptor when bind()/setsockopt() fails.
        s.close()
        raise
    return s
def unused_port() -> int:
    """Ask the OS for a currently free TCP port on 127.0.0.1 and return it.

    The probing socket is closed before returning, so the port is only
    known to have been free at the time of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("127.0.0.1", 0))
        _addr, free_port = probe.getsockname()
    return cast(int, free_port)
class BaseTestServer(ABC):
    """Shared machinery for the test servers in this module.

    Subclasses implement :meth:`_make_runner`. :meth:`start_server` sets that
    runner up, binds a socket produced by ``socket_factory`` and serves on it,
    after which :meth:`make_url` builds URLs relative to the server root.
    """

    __test__ = False

    def __init__(
        self,
        *,
        scheme: Union[str, _SENTINEL] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        skip_url_asserts: bool = False,
        socket_factory: Callable[
            [str, int, socket.AddressFamily], socket.socket
        ] = get_port_socket,
        **kwargs: Any,
    ) -> None:
        # NOTE(review): extra **kwargs are accepted but not used here.
        self.runner: Optional[BaseRunner] = None
        self._root: Optional[URL] = None
        self.host = host
        self.port = port
        self._closed = False
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts
        self.socket_factory = socket_factory
        # The real value comes from the ``ssl`` kwarg of start_server(); a
        # default here guarantees the attribute exists before starting.
        self._ssl: Optional[SSLContext] = None

    async def start_server(self, **kwargs: Any) -> None:
        """Create and set up the runner and start serving on a bound socket.

        Does nothing if a runner already exists. An ``ssl`` keyword argument
        is removed from *kwargs* and used as the site's SSL context; the
        remaining keyword arguments are forwarded to :meth:`_make_runner`.
        Afterwards ``host``/``port`` reflect the bound address and, if no
        scheme was given, the scheme is ``https`` with SSL and ``http``
        otherwise.
        """
        if self.runner:
            return
        self._ssl = kwargs.pop("ssl", None)
        self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
        await self.runner.setup()
        if not self.port:
            # Port 0 lets the OS choose a free port.
            self.port = 0
        absolute_host = self.host
        try:
            version = ipaddress.ip_address(self.host).version
        except ValueError:
            # Not an IP literal (e.g. a hostname): treat it as IPv4.
            version = 4
        if version == 6:
            # IPv6 literals must be bracketed inside a URL.
            absolute_host = f"[{self.host}]"
        family = socket.AF_INET6 if version == 6 else socket.AF_INET
        _sock = self.socket_factory(self.host, self.port, family)
        self.host, self.port = _sock.getsockname()[:2]
        site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
        await site.start()
        server = site._server
        assert server is not None
        sockets = server.sockets  # type: ignore[attr-defined]
        assert sockets is not None
        # Report the port the listening socket is actually bound to.
        self.port = sockets[0].getsockname()[1]
        if self.scheme is sentinel:
            if self._ssl:
                scheme = "https"
            else:
                scheme = "http"
            self.scheme = scheme
        self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")

    @abstractmethod  # pragma: no cover
    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        """Return the (not yet set up) runner that serves requests."""
        pass

    def make_url(self, path: StrOrURL) -> URL:
        """Build an absolute URL for *path* on this (started) server.

        By default *path* must be relative and is joined onto the root URL.
        With ``skip_url_asserts`` the root URL and *path* are concatenated
        as strings without that check.
        """
        assert self._root is not None
        url = URL(path)
        if not self.skip_url_asserts:
            assert not url.is_absolute()
            return self._root.join(url)
        else:
            return URL(str(self._root) + str(path))

    @property
    def started(self) -> bool:
        """True once start_server() has created a runner."""
        return self.runner is not None

    @property
    def closed(self) -> bool:
        """True once close() has cleaned up a started server."""
        return self._closed

    @property
    def handler(self) -> Server:
        """The runner's web.Server instance (kept for backward compatibility)."""
        runner = self.runner
        assert runner is not None
        assert runner.server is not None
        return runner.server

    async def close(self) -> None:
        """Shut the test server down by cleaning up its runner.

        After that point, the server is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects. Calling it on a server that
        was never started does nothing.

        close is also run on exit when used as an asynchronous context
        manager.
        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = None
            self._closed = True

    async def __aenter__(self) -> "BaseTestServer":
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.close()
class TestServer(BaseTestServer):
    """Test server that serves an :class:`Application` through an AppRunner."""

    def __init__(
        self,
        app: Application,
        *,
        scheme: Union[str, _SENTINEL] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ):
        # Keep the application around; the runner is built lazily on start.
        self.app = app
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        """Wrap ``self.app`` in an AppRunner configured with *kwargs*."""
        app_runner = AppRunner(self.app, **kwargs)
        return app_runner
class RawTestServer(BaseTestServer):
    """Test server that dispatches every request to a bare handler callable."""

    def __init__(
        self,
        handler: _RequestHandler,
        *,
        scheme: Union[str, _SENTINEL] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        self._handler = handler
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> ServerRunner:
        """Build a low-level Server for the handler, wrapped in a ServerRunner.

        The same keyword arguments are forwarded to both constructors.
        """
        return ServerRunner(Server(self._handler, **kwargs), **kwargs)
class TestClient:
    """A test client for functional tests of aiohttp-based servers.

    It owns a :class:`ClientSession` and resolves request paths against the
    root URL of the wrapped test server. Every response and websocket it
    opens is remembered and closed by :meth:`close`.
    """

    __test__ = False

    def __init__(
        self,
        server: BaseTestServer,
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        **kwargs: Any,
    ) -> None:
        """Wrap *server*; extra keyword arguments go to ClientSession.

        Raises:
            TypeError: if *server* is not a BaseTestServer instance.
        """
        if not isinstance(server, BaseTestServer):
            raise TypeError(
                "server must be TestServer " "instance, found type: %r" % type(server)
            )
        self._server = server
        if cookie_jar is None:
            # unsafe=True lets the jar accept cookies from IP-address hosts
            # such as the default 127.0.0.1 test host.
            cookie_jar = aiohttp.CookieJar(unsafe=True)
        self._session = ClientSession(cookie_jar=cookie_jar, **kwargs)
        self._closed = False
        self._responses: List[ClientResponse] = []
        self._websockets: List[ClientWebSocketResponse] = []

    async def start_server(self) -> None:
        """Start the wrapped test server."""
        await self._server.start_server()

    @property
    def scheme(self) -> Union[str, object]:
        return self._server.scheme

    @property
    def host(self) -> str:
        return self._server.host

    @property
    def port(self) -> Optional[int]:
        return self._server.port

    @property
    def server(self) -> BaseTestServer:
        return self._server

    @property
    def app(self) -> Optional[Application]:
        """The server's ``app`` attribute, or None if it has none."""
        return cast(Optional[Application], getattr(self._server, "app", None))

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.

        """
        return self._session

    def make_url(self, path: StrOrURL) -> URL:
        """Resolve *path* against the test server's root URL."""
        return self._server.make_url(path)

    async def _request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> ClientResponse:
        resp = await self._session.request(method, self.make_url(path), **kwargs)
        # save it to close later
        self._responses.append(resp)
        return resp

    def request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> _RequestContextManager:
        """Routes a request to tested http server.

        The interface is identical to aiohttp.ClientSession.request,
        except that *path* is resolved against the test server's root URL
        (see make_url) and the response is closed by close().

        """
        return _RequestContextManager(self._request(method, path, **kwargs))

    def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP GET request."""
        return self.request(hdrs.METH_GET, path, **kwargs)

    def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP POST request."""
        return self.request(hdrs.METH_POST, path, **kwargs)

    def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP OPTIONS request."""
        return self.request(hdrs.METH_OPTIONS, path, **kwargs)

    def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP HEAD request."""
        return self.request(hdrs.METH_HEAD, path, **kwargs)

    def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PUT request."""
        return self.request(hdrs.METH_PUT, path, **kwargs)

    def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PATCH request."""
        return self.request(hdrs.METH_PATCH, path, **kwargs)

    def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP DELETE request."""
        return self.request(hdrs.METH_DELETE, path, **kwargs)

    def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect; *path* is
        resolved against the test server's root URL.

        """
        return _WSRequestContextManager(self._ws_connect(path, **kwargs))

    async def _ws_connect(
        self, path: StrOrURL, **kwargs: Any
    ) -> ClientWebSocketResponse:
        ws = await self._session.ws_connect(self.make_url(path), **kwargs)
        # save it to close later
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        After that point, the TestClient is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as an asynchronous
        context manager.

        """
        if not self._closed:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
            await self._session.close()
            await self._server.close()
            self._closed = True

    async def __aenter__(self) -> "TestClient":
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()
class AioHTTPTestCase(IsolatedAsyncioTestCase, ABC):
    """Base class for unittest-style tests of aiohttp web applications.

    asyncSetUp() populates these attributes:

    * self.app (aiohttp.web.Application): the application returned by
      self.get_application()
    * self.server (TestServer): the server returned by self.get_server()
    * self.client (aiohttp.test_utils.TestClient): the client returned by
      self.get_client(), already started

    The TestClient's request methods are asynchronous, so they must be
    awaited from async test methods.
    """

    @abstractmethod
    async def get_application(self) -> Application:
        """Return the aiohttp.web.Application under test.

        Subclasses must override this.
        """

    async def asyncSetUp(self) -> None:
        # Build app -> server -> client, then start the server via the client.
        self.app = await self.get_application()
        self.server = await self.get_server(self.app)
        self.client = await self.get_client(self.server)

        await self.client.start_server()

    async def asyncTearDown(self) -> None:
        # Closing the client also closes the server it wraps.
        await self.client.close()

    async def get_server(self, app: Application) -> TestServer:
        """Build the TestServer for *app*; override to customize it."""
        server = TestServer(app)
        return server

    async def get_client(self, server: TestServer) -> TestClient:
        """Build the TestClient for *server*; override to customize it."""
        client = TestClient(server)
        return client
# A zero-argument callable returning a new event loop (for example
# ``asyncio.new_event_loop``, the default used by the helpers below).
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]


@contextlib.contextmanager
def loop_context(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
) -> Iterator[asyncio.AbstractEventLoop]:
    """A contextmanager that creates an event_loop, for test purposes.

    Handles the creation and cleanup of a test loop. The loop is created with
    ``setup_test_loop(loop_factory)`` and torn down with
    ``teardown_test_loop(loop, fast=fast)`` on exit -- including when the
    body of the ``with`` statement raises, so a failing test never leaks its
    loop or leaves it installed as the current event loop.
    """
    loop = setup_test_loop(loop_factory)
    try:
        yield loop
    finally:
        teardown_test_loop(loop, fast=fast)


def setup_test_loop(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
) -> asyncio.AbstractEventLoop:
    """Create an event loop with *loop_factory*, make it current, return it.

    The caller should also call teardown_test_loop,
    once they are done with the loop.
    """
    loop = loop_factory()
    asyncio.set_event_loop(loop)
    return loop


def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
    """Teardown and cleanup an event_loop created by setup_test_loop.

    An open loop is run until a scheduled ``stop()`` fires (callbacks queued
    before it run first) and then closed; an already-closed loop is left as
    is. Unless *fast* is true a full garbage collection follows. Finally the
    current event loop is reset to None.
    """
    closed = loop.is_closed()
    if not closed:
        loop.call_soon(loop.stop)
        loop.run_forever()
        loop.close()

    if not fast:
        gc.collect()

    asyncio.set_event_loop(None)
def _create_app_mock() -> mock.MagicMock:
    """Return a MagicMock standing in for an Application.

    Item access (``app[key]`` and ``app[key] = value``) is backed by a plain
    dict stored on the mock, and ``on_response_prepare`` is a real Signal
    that has already been frozen.
    """

    def _lookup(mocked_app: Any, key: str) -> Any:
        return mocked_app.__app_dict[key]

    def _store(mocked_app: Any, key: str, value: Any) -> None:
        mocked_app.__app_dict[key] = value

    app = mock.MagicMock(spec=Application)
    app.__app_dict = {}
    app.__getitem__ = _lookup
    app.__setitem__ = _store

    response_prepare = Signal(app)
    response_prepare.freeze()
    app.on_response_prepare = response_prepare
    return app
def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
    """Return a mock transport whose get_extra_info() exposes *sslcontext*.

    ``get_extra_info("sslcontext")`` yields *sslcontext*; every other key
    yields None.
    """

    def get_extra_info(key: str) -> Optional[SSLContext]:
        return sslcontext if key == "sslcontext" else None

    mocked_transport = mock.Mock()
    mocked_transport.get_extra_info.side_effect = get_extra_info
    return mocked_transport
def make_mocked_request(
    method: str,
    path: str,
    headers: Any = None,
    *,
    match_info: Any = sentinel,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Any = None,
    writer: Any = sentinel,
    protocol: Any = sentinel,
    transport: Any = sentinel,
    payload: Any = sentinel,
    sslcontext: Optional[SSLContext] = None,
    client_max_size: int = 1024**2,
    loop: Any = ...,
) -> Request:
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
    specific conditions and errors are hard to trigger.

    Collaborators that are not supplied are replaced with mocks: ``loop``
    (when left as ``...``), ``app``, ``transport`` (whose ``get_extra_info``
    exposes *sslcontext*), ``protocol``, ``writer`` (with coroutine-mock
    write methods) and ``payload``. ``match_info`` defaults to an empty
    mapping. An HTTP *version* below 1.1 forces ``closing=True``.

    *headers* is wrapped in a read-only ``CIMultiDictProxy``; its keys and
    values are UTF-8 encoded to build the raw header tuple, so they must be
    ``str``.
    """
    task = mock.Mock()
    if loop is ...:
        loop = mock.Mock()
        # NOTE(review): the mocked loop's create_future() returns an empty
        # tuple rather than a Future -- confirm callers rely on this.
        loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        # Connections are treated as non-persistent before HTTP/1.1.
        closing = True

    if headers:
        headers = CIMultiDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
            (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
        )
    else:
        headers = CIMultiDictProxy(CIMultiDict())
        raw_hdrs = ()

    chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()

    # NOTE(review): the positional None/False presumably fill the
    # compression and upgrade fields -- confirm against RawRequestMessage.
    message = RawRequestMessage(
        method,
        path,
        version,
        headers,
        raw_hdrs,
        closing,
        None,
        False,
        chunked,
        URL(path),
    )
    if app is None:
        app = _create_app_mock()

    if transport is sentinel:
        transport = _create_transport(sslcontext)

    if protocol is sentinel:
        protocol = mock.Mock()

    if writer is sentinel:
        writer = mock.Mock()
        writer.write_headers = make_mocked_coro(None)
        writer.write = make_mocked_coro(None)
        writer.write_eof = make_mocked_coro(None)
        writer.drain = make_mocked_coro(None)
        writer.transport = transport

    # Wire transport and writer into the protocol, whether it was supplied
    # by the caller or mocked above.
    protocol.transport = transport
    protocol.writer = writer

    if payload is sentinel:
        payload = mock.Mock()

    req = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    match_info = UrlMappingMatchInfo(
        {} if match_info is sentinel else match_info, mock.Mock()
    )
    match_info.add_app(app)
    req._match_info = match_info

    return req
def make_mocked_coro(
    return_value: Any = sentinel, raise_exception: Any = sentinel
) -> Any:
    """Create a mock whose calls return a coroutine.

    The result is a ``mock.Mock`` wrapping an ``async def``, so calls are
    recorded as usual and each call returns a coroutine. Awaiting it:

    * raises *raise_exception* if one was given;
    * otherwise, if *return_value* is awaitable, awaits it and returns its
      result;
    * otherwise returns *return_value* unchanged (the ``sentinel`` object
      when no value was given).

    A coroutine object can only be awaited once, so a mock configured with
    one as *return_value* should be awaited at most once.
    """

    async def mock_coro(*args: Any, **kwargs: Any) -> Any:
        if raise_exception is not sentinel:
            raise raise_exception
        if not inspect.isawaitable(return_value):
            return return_value
        # Propagate the awaited result instead of silently discarding it.
        return await return_value

    return mock.Mock(wraps=mock_coro)