Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/test_utils.py: 43%
Shortcuts on this page
r m x — toggle display of run / missing / excluded lines
j k — jump to next / previous highlighted chunk
0 (zero) — jump to top of page
1 (one) — jump to first highlighted chunk
1"""Utilities shared by tests."""
3import asyncio
4import contextlib
5import gc
6import ipaddress
7import os
8import socket
9import sys
10from abc import ABC, abstractmethod
11from types import TracebackType
12from typing import (
13 TYPE_CHECKING,
14 Any,
15 Callable,
16 Dict,
17 Generic,
18 Iterator,
19 List,
20 Optional,
21 Type,
22 TypeVar,
23 Union,
24 cast,
25 overload,
26)
27from unittest import IsolatedAsyncioTestCase, mock
29from aiosignal import Signal
30from multidict import CIMultiDict, CIMultiDictProxy
31from yarl import URL
33import aiohttp
34from aiohttp.client import (
35 _RequestContextManager,
36 _RequestOptions,
37 _WSRequestContextManager,
38)
40from . import ClientSession, hdrs
41from .abc import AbstractCookieJar, AbstractStreamWriter
42from .client_reqrep import ClientResponse
43from .client_ws import ClientWebSocketResponse
44from .http import HttpVersion, RawRequestMessage
45from .streams import EMPTY_PAYLOAD, StreamReader
46from .typedefs import LooseHeaders, StrOrURL
47from .web import (
48 Application,
49 AppRunner,
50 BaseRequest,
51 BaseRunner,
52 Request,
53 RequestHandler,
54 Server,
55 ServerRunner,
56 SockSite,
57 UrlMappingMatchInfo,
58)
59from .web_protocol import _RequestHandler
# ``SSLContext`` is only imported for type checkers; at runtime the name is
# bound to ``None`` so annotations such as ``Optional[SSLContext]`` still
# evaluate without importing ``ssl``.
if TYPE_CHECKING:
    from ssl import SSLContext
else:
    SSLContext = None

# ``Self`` exists in ``typing`` from 3.11; older interpreters fall back to
# ``Any``.  ``Unpack`` is only needed by the 3.11+ type-checking stubs.
if sys.version_info >= (3, 11):
    from typing import Self

    if TYPE_CHECKING:
        from typing import Unpack
else:
    Self = Any

# Either a real Application or None (used to type TestClient.app).
_ApplicationNone = TypeVar("_ApplicationNone", Application, None)
# The request type handled by a given test server / client.
_Request = TypeVar("_Request", bound=BaseRequest)

# SO_REUSEADDR is only enabled on POSIX systems other than Cygwin.
REUSE_ADDRESS = sys.platform != "cygwin" and os.name == "posix"
def get_unused_port_socket(
    host: str, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Return a TCP socket bound to *host* on an OS-chosen port.

    Binding to port ``0`` lets the operating system pick the port; read it
    back with ``sock.getsockname()``.
    """
    ephemeral_port = 0
    return get_port_socket(host, ephemeral_port, family)
def get_port_socket(
    host: str, port: int, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Create a TCP socket of *family* bound to ``(host, port)``.

    ``SO_REUSEADDR`` is enabled when the module-level ``REUSE_ADDRESS``
    flag is true.  If configuring or binding the socket fails, the socket
    is closed before the exception propagates so no descriptor is leaked.

    Args:
        host: Address to bind to.
        port: Port to bind to; ``0`` lets the OS choose a free port.
        family: Address family of the socket (defaults to ``AF_INET``).

    Returns:
        The bound (not yet listening) socket.
    """
    sock = socket.socket(family, socket.SOCK_STREAM)
    try:
        if REUSE_ADDRESS:
            # Windows has different semantics for SO_REUSEADDR,
            # so don't set it. Ref:
            # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))
    except BaseException:
        # Don't leak the file descriptor when bind (e.g. EADDRINUSE) fails.
        sock.close()
        raise
    return sock
def unused_port() -> int:
    """Return a TCP port number that is free on 127.0.0.1 right now.

    The probe socket is closed before returning, so another process could
    still claim the port before the caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("127.0.0.1", 0))
        _, chosen_port = probe.getsockname()
    return cast(int, chosen_port)
class BaseTestServer(ABC, Generic[_Request]):
    """Base class for test servers listening on a real socket.

    Subclasses implement :meth:`_make_runner`; this class binds the socket,
    starts the site, and builds URLs relative to the server's root.
    """

    # Keep pytest from collecting this class as a test case.
    __test__ = False

    def __init__(
        self,
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        skip_url_asserts: bool = False,
        socket_factory: Callable[
            [str, int, socket.AddressFamily], socket.socket
        ] = get_port_socket,
        **kwargs: Any,
    ) -> None:
        """Configure (but do not start) the server.

        Args:
            scheme: URL scheme for :meth:`make_url`; when empty it is chosen
                in :meth:`start_server` ("https" if an ssl context is given,
                otherwise "http").
            host: Host to bind to.
            port: Port to bind to; ``None``/``0`` lets the OS choose one.
            skip_url_asserts: When true, :meth:`make_url` accepts absolute
                paths and concatenates them to the root as strings.
            socket_factory: Callable ``(host, port, family)`` returning a
                bound socket.
            **kwargs: Accepted for subclass pass-through and ignored here.
        """
        self.runner: Optional[BaseRunner[_Request]] = None
        self._root: Optional[URL] = None
        self.host = host
        self.port = port or 0
        self._closed = False
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts
        self.socket_factory = socket_factory
        # Replaced in start_server(); initialised here so the attribute
        # always exists, even before the server is started.
        self._ssl: Optional[SSLContext] = None

    async def start_server(self, **kwargs: Any) -> None:
        """Create the runner, bind a socket and start serving.

        Does nothing if a runner already exists.  An ``ssl`` keyword, if
        given, becomes the site's SSL context; all other keywords are
        forwarded to :meth:`_make_runner`.
        """
        if self.runner:
            return
        self._ssl = kwargs.pop("ssl", None)
        self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
        await self.runner.setup()
        absolute_host = self.host
        try:
            version = ipaddress.ip_address(self.host).version
        except ValueError:
            # Not an IP literal (e.g. a hostname): treat it as IPv4.
            version = 4
        if version == 6:
            # IPv6 literals must be bracketed in a URL authority.
            absolute_host = f"[{self.host}]"
        family = socket.AF_INET6 if version == 6 else socket.AF_INET
        _sock = self.socket_factory(self.host, self.port, family)
        self.host, self.port = _sock.getsockname()[:2]
        site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
        await site.start()
        server = site._server
        assert server is not None
        sockets = server.sockets
        assert sockets is not None
        self.port = sockets[0].getsockname()[1]
        if not self.scheme:
            self.scheme = "https" if self._ssl else "http"
        self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")

    @abstractmethod
    async def _make_runner(self, **kwargs: Any) -> BaseRunner[_Request]:  # type: ignore[misc]
        """Return a new runner for the server."""
        # TODO(PY311): Use Unpack to specify Server kwargs.

    def make_url(self, path: StrOrURL) -> URL:
        """Return *path* resolved against the server's root URL.

        Requires the server to be started.  Unless ``skip_url_asserts`` is
        set, *path* must be relative.
        """
        assert self._root is not None
        url = URL(path)
        if not self.skip_url_asserts:
            assert not url.absolute
            return self._root.join(url)
        else:
            return URL(str(self._root) + str(path))

    @property
    def started(self) -> bool:
        """True once :meth:`start_server` has created a runner."""
        return self.runner is not None

    @property
    def closed(self) -> bool:
        """True once :meth:`close` has run on a started server."""
        return self._closed

    @property
    def handler(self) -> Server[_Request]:
        """The underlying ``web.Server`` instance (kept for compatibility)."""
        # for backward compatibility
        # web.Server instance
        runner = self.runner
        assert runner is not None
        assert runner.server is not None
        return runner.server

    async def close(self) -> None:
        """Stop the server and release its resources.

        After that point, the server is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as an asynchronous context
        manager.
        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = 0
            self._closed = True

    async def __aenter__(self) -> Self:
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.close()
class TestServer(BaseTestServer[Request]):
    """Test server that serves an :class:`Application` through an AppRunner."""

    def __init__(
        self,
        app: Application,
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        # The application must be stored before the base initialiser runs.
        self.app = app
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> AppRunner:
        # TODO(PY311): Use Unpack to specify Server kwargs.
        app_runner = AppRunner(self.app, **kwargs)
        return app_runner
class RawTestServer(BaseTestServer[BaseRequest]):
    """Test server that hands every request to a bare handler callable."""

    def __init__(
        self,
        handler: _RequestHandler[BaseRequest],
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        self._handler = handler
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> ServerRunner:
        # TODO(PY311): Use Unpack to specify Server kwargs.
        # NOTE(review): the same kwargs are forwarded to both Server and
        # ServerRunner, so each must accept every key — confirm intended.
        low_level_server = Server(self._handler, **kwargs)
        return ServerRunner(low_level_server, **kwargs)
class TestClient(Generic[_Request, _ApplicationNone]):
    """A test client for functional tests of aiohttp-based servers.

    Requests go through an internal :class:`ClientSession`.  The path given
    to each request method is resolved against the wrapped test server's
    root URL.  Responses and websockets opened through the client are
    remembered and released by :meth:`close`.
    """

    # Keep pytest from collecting this class as a test case.
    __test__ = False

    @overload
    def __init__(  # type: ignore[misc]
        self: "TestClient[Request, Application]",
        server: TestServer,
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        **kwargs: Any,
    ) -> None: ...

    @overload
    def __init__(  # type: ignore[misc]
        self: "TestClient[_Request, None]",
        server: BaseTestServer[_Request],
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        **kwargs: Any,
    ) -> None: ...

    def __init__(  # type: ignore[misc]
        self,
        server: BaseTestServer[_Request],
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        **kwargs: Any,
    ) -> None:
        """Wrap *server* and create the internal client session.

        Args:
            server: The test server to send requests to.
            cookie_jar: Cookie jar for the session; defaults to
                ``aiohttp.CookieJar(unsafe=True)``.
            **kwargs: Forwarded to :class:`ClientSession`.

        Raises:
            TypeError: If *server* is not a ``BaseTestServer``.
        """
        # TODO(PY311): Use Unpack to specify ClientSession kwargs.
        if not isinstance(server, BaseTestServer):
            raise TypeError(
                "server must be TestServer instance, found type: %r" % type(server)
            )
        self._server = server
        if cookie_jar is None:
            # Presumably unsafe=True so that cookies from IP-address hosts
            # (the default test host is 127.0.0.1) are kept.
            cookie_jar = aiohttp.CookieJar(unsafe=True)
        self._session = ClientSession(cookie_jar=cookie_jar, **kwargs)
        # NOTE(review): private ClientSession flag; presumably disables the
        # automatic connection retry so failures surface directly in tests.
        self._session._retry_connection = False
        self._closed = False
        self._responses: List[ClientResponse] = []
        self._websockets: List[ClientWebSocketResponse] = []

    async def start_server(self) -> None:
        """Start the wrapped test server."""
        await self._server.start_server()

    @property
    def scheme(self) -> Union[str, object]:
        return self._server.scheme

    @property
    def host(self) -> str:
        return self._server.host

    @property
    def port(self) -> int:
        return self._server.port

    @property
    def server(self) -> BaseTestServer[_Request]:
        return self._server

    @property
    def app(self) -> _ApplicationNone:
        """The server's ``app`` attribute, or None if it has none."""
        return getattr(self._server, "app", None)  # type: ignore[return-value]

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.

        """
        return self._session

    def make_url(self, path: StrOrURL) -> URL:
        """Resolve *path* against the test server's root URL."""
        return self._server.make_url(path)

    async def _request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> ClientResponse:
        """Send the request and track the response so close() releases it."""
        resp = await self._session.request(method, self.make_url(path), **kwargs)
        # save it to close later
        self._responses.append(resp)
        return resp

    if sys.version_info >= (3, 11) and TYPE_CHECKING:

        def request(
            self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions]
        ) -> _RequestContextManager: ...

        def get(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def options(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def head(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def post(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def put(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def patch(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def delete(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

    else:

        def request(
            self, method: str, path: StrOrURL, **kwargs: Any
        ) -> _RequestContextManager:
            """Routes a request to tested http server.

            The interface is identical to aiohttp.ClientSession.request,
            except that *path* is resolved relative to the test server's
            root URL (see make_url()).

            """
            return _RequestContextManager(self._request(method, path, **kwargs))

        def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP GET request."""
            return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

        def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP POST request."""
            return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

        def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP OPTIONS request."""
            return _RequestContextManager(
                self._request(hdrs.METH_OPTIONS, path, **kwargs)
            )

        def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP HEAD request."""
            return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

        def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP PUT request."""
            return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

        def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP PATCH request."""
            return _RequestContextManager(
                self._request(hdrs.METH_PATCH, path, **kwargs)
            )

        def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP DELETE request."""
            return _RequestContextManager(
                self._request(hdrs.METH_DELETE, path, **kwargs)
            )

    def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect.

        """
        return _WSRequestContextManager(self._ws_connect(path, **kwargs))

    async def _ws_connect(
        self, path: StrOrURL, **kwargs: Any
    ) -> ClientWebSocketResponse:
        """Open a websocket and track it so close() releases it."""
        ws = await self._session.ws_connect(self.make_url(path), **kwargs)
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        After that point, the TestClient is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as a(n) (asynchronous)
        context manager.

        The session and the server are closed even if closing a tracked
        response or websocket raises; the error still propagates.
        """
        if self._closed:
            return
        try:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
        finally:
            # Always release the session and the server, even if a
            # websocket close above failed.
            try:
                await self._session.close()
            finally:
                await self._server.close()
                self._closed = True

    async def __aenter__(self) -> Self:
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()
class AioHTTPTestCase(IsolatedAsyncioTestCase, ABC):
    """A base class to allow for unittest web applications using aiohttp.

    Provides the following:

    * self.client (aiohttp.test_utils.TestClient): an aiohttp test client.
    * self.server (aiohttp.test_utils.TestServer): the server wrapping the app.
    * self.app (aiohttp.web.Application): the application returned by
      self.get_application()

    Note that the TestClient's methods are asynchronous: test methods must
    be ``async def`` and await (or ``async with``) the client's calls.
    """

    @abstractmethod
    async def get_application(self) -> Application:
        """Get application.

        This method should be overridden to return the aiohttp.web.Application
        object to test.
        """

    async def asyncSetUp(self) -> None:
        """Build the app, server and client, then start the server."""
        self.app = await self.get_application()
        self.server = await self.get_server(self.app)
        self.client = await self.get_client(self.server)

        try:
            await self.client.start_server()
        except BaseException:
            # unittest does not call asyncTearDown() when asyncSetUp() fails,
            # so release the client session (and anything already started)
            # here before propagating the error.
            await self.client.close()
            raise

    async def asyncTearDown(self) -> None:
        """Close the client, which also closes the server."""
        await self.client.close()

    async def get_server(self, app: Application) -> TestServer:
        """Return a TestServer instance."""
        return TestServer(app)

    async def get_client(self, server: TestServer) -> TestClient[Request, Application]:
        """Return a TestClient instance."""
        return TestClient(server)
# A zero-argument callable that returns a new event loop.
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]


@contextlib.contextmanager
def loop_context(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
) -> Iterator[asyncio.AbstractEventLoop]:
    """A contextmanager that creates an event_loop, for test purposes.

    Handles the creation and cleanup of a test loop.  The loop is torn down
    even if the ``with`` body raises, so a failing test does not leave the
    loop open or installed as the current event loop.

    Args:
        loop_factory: Callable creating the loop (``asyncio.new_event_loop``
            by default).
        fast: Passed to ``teardown_test_loop``; when true the final
            ``gc.collect()`` is skipped.

    Yields:
        The newly created loop, already set as the current event loop.
    """
    loop = setup_test_loop(loop_factory)
    try:
        yield loop
    finally:
        teardown_test_loop(loop, fast=fast)
def setup_test_loop(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
) -> asyncio.AbstractEventLoop:
    """Create a new event loop and install it as the current loop.

    Every call should be paired with teardown_test_loop() once the caller
    has finished with the loop.
    """
    new_loop = loop_factory()
    asyncio.set_event_loop(new_loop)
    return new_loop
def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
    """Stop and close *loop*, then unset the current event loop.

    A loop that is already closed is left alone.  Unless *fast* is true, a
    full ``gc.collect()`` pass runs after the loop is closed.
    """
    if not loop.is_closed():
        # Queue a stop and spin the loop once, so callbacks that were
        # already scheduled get to run before the loop is closed.
        loop.call_soon(loop.stop)
        loop.run_forever()
        loop.close()

    if not fast:
        gc.collect()

    asyncio.set_event_loop(None)
def _create_app_mock() -> mock.MagicMock:
    """Build a MagicMock standing in for an Application.

    Item access (``app[key]`` and ``app[key] = value``) is backed by a real
    dict, and ``on_response_prepare`` is a frozen Signal owned by the mock.
    """

    def _lookup(app: Any, key: str) -> Any:
        return app.__app_dict[key]

    def _store(app: Any, key: str, value: Any) -> None:
        app.__app_dict[key] = value

    app = mock.MagicMock(spec=Application)
    app.__app_dict = {}
    app.__getitem__ = _lookup
    app.__setitem__ = _store

    response_prepare = Signal(app)
    response_prepare.freeze()
    app.on_response_prepare = response_prepare
    return app
def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
    """Return a Mock transport whose get_extra_info() knows one key.

    ``get_extra_info("sslcontext")`` returns *sslcontext*; any other key
    returns ``None``.
    """
    fake_transport = mock.Mock()

    def get_extra_info(key: str) -> Optional[SSLContext]:
        return sslcontext if key == "sslcontext" else None

    fake_transport.get_extra_info.side_effect = get_extra_info
    return fake_transport
def make_mocked_request(
    method: str,
    path: str,
    headers: Optional[LooseHeaders] = None,
    *,
    match_info: Optional[Dict[str, str]] = None,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Optional[Application] = None,
    writer: Optional[AbstractStreamWriter] = None,
    protocol: Optional[RequestHandler[Request]] = None,
    transport: Optional[asyncio.Transport] = None,
    payload: StreamReader = EMPTY_PAYLOAD,
    sslcontext: Optional[SSLContext] = None,
    client_max_size: int = 1024**2,
    loop: Any = ...,
) -> Request:
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
    specific conditions and errors are hard to trigger.

    Args:
        method: HTTP method of the request.
        path: Request path; also parsed into the request URL.
        headers: Request headers.  They are copied into a read-only
            case-insensitive multidict, and their UTF-8 encoding is used as
            the raw headers.  A ``Transfer-Encoding`` containing "chunked"
            marks the message as chunked.
        match_info: Route match values (empty when omitted).
        version: HTTP version; versions below 1.1 force ``closing=True``.
        closing: Whether the message should close the connection.
        app: Application attached to the match info; a MagicMock-based stand-in
            is used when omitted.
        writer: Stream writer; defaults to a Mock whose ``write_headers``,
            ``write``, ``write_eof`` and ``drain`` are AsyncMocks.
        protocol: Request handler; defaults to a Mock exposing ``transport``,
            ``peername`` and ``ssl_context``.
        transport: Transport; defaults to a Mock answering
            ``get_extra_info("sslcontext")`` with *sslcontext*.  It is always
            assigned to ``protocol.transport``, including a caller-supplied
            protocol.
        payload: Body stream (empty by default).
        sslcontext: SSL context reported by the default transport/protocol.
        client_max_size: Forwarded to :class:`Request`.
        loop: Event loop.  When omitted, the running loop is used if there
            is one; otherwise a Mock whose ``create_future()`` returns ``()``.

    Returns:
        The constructed :class:`Request`.
    """
    task = mock.Mock()
    if loop is ...:
        # no loop passed, try to get the current one if
        # it is running as we need a real loop to create
        # executor jobs to be able to do testing
        # with a real executor
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = mock.Mock()
            loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        closing = True

    if headers:
        headers = CIMultiDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
            (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
        )
    else:
        headers = CIMultiDictProxy(CIMultiDict())
        raw_hdrs = ()

    chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()

    message = RawRequestMessage(
        method,
        path,
        version,
        headers,
        raw_hdrs,
        closing,
        None,
        False,
        chunked,
        URL(path),
    )
    if app is None:
        app = _create_app_mock()

    if transport is None:
        transport = _create_transport(sslcontext)

    if protocol is None:
        protocol = mock.Mock()
        # PropertyMocks must live on the mock's type to act as properties.
        type(protocol).peername = mock.PropertyMock(
            return_value=transport.get_extra_info("peername")
        )
        type(protocol).ssl_context = mock.PropertyMock(return_value=sslcontext)

    if writer is None:
        writer = mock.Mock()
        writer.write_headers = mock.AsyncMock(return_value=None)
        writer.write = mock.AsyncMock(return_value=None)
        writer.write_eof = mock.AsyncMock(return_value=None)
        writer.drain = mock.AsyncMock(return_value=None)
        writer.transport = transport

    # Applies to default and caller-supplied protocols alike.
    protocol.transport = transport

    req = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    match_info = UrlMappingMatchInfo(
        {} if match_info is None else match_info, mock.Mock()
    )
    match_info.add_app(app)
    req._match_info = match_info

    return req