Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/test_utils.py: 50%

300 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-26 06:16 +0000

1"""Utilities shared by tests.""" 

2 

3import asyncio 

4import contextlib 

5import gc 

6import inspect 

7import ipaddress 

8import os 

9import socket 

10import sys 

11from abc import ABC, abstractmethod 

12from types import TracebackType 

13from typing import ( 

14 TYPE_CHECKING, 

15 Any, 

16 Callable, 

17 Iterator, 

18 List, 

19 Optional, 

20 Type, 

21 Union, 

22 cast, 

23) 

24from unittest import IsolatedAsyncioTestCase, mock 

25 

26from aiosignal import Signal 

27from multidict import CIMultiDict, CIMultiDictProxy 

28from yarl import URL 

29 

30import aiohttp 

31from aiohttp.client import _RequestContextManager, _WSRequestContextManager 

32 

33from . import ClientSession, hdrs 

34from .abc import AbstractCookieJar 

35from .client_reqrep import ClientResponse 

36from .client_ws import ClientWebSocketResponse 

37from .helpers import _SENTINEL, sentinel 

38from .http import HttpVersion, RawRequestMessage 

39from .typedefs import StrOrURL 

40from .web import ( 

41 Application, 

42 AppRunner, 

43 BaseRunner, 

44 Request, 

45 Server, 

46 ServerRunner, 

47 SockSite, 

48 UrlMappingMatchInfo, 

49) 

50from .web_protocol import _RequestHandler 

51 

# ``ssl`` is imported only for static type checking; at runtime the name
# ``SSLContext`` is bound to ``None`` so annotations that mention it still
# evaluate.
if not TYPE_CHECKING:
    SSLContext = None
else:
    from ssl import SSLContext

56 

# True only on POSIX hosts that are not Cygwin; consulted before setting
# SO_REUSEADDR on sockets created by this module.
REUSE_ADDRESS = sys.platform != "cygwin" and os.name == "posix"

58 

59 

def get_unused_port_socket(
    host: str, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Return a stream socket bound to ``host`` on port ``0``.

    This is ``get_port_socket`` with the port fixed to ``0``, which asks
    the operating system to pick a free port.
    """
    any_free_port = 0
    return get_port_socket(host, any_free_port, family)

64 

65 

def get_port_socket(
    host: str, port: int, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Create a stream socket of ``family`` bound to ``(host, port)``.

    ``SO_REUSEADDR`` is set first when ``REUSE_ADDRESS`` is true.

    Args:
        host: Address to bind to.
        port: Port to bind to.
        family: Address family of the new socket.

    Returns:
        The bound socket. The caller owns it and must close it.

    Raises:
        OSError: If setting the option or binding fails. The socket is
            closed before the error propagates so no descriptor leaks.
    """
    s = socket.socket(family, socket.SOCK_STREAM)
    try:
        if REUSE_ADDRESS:
            # Windows has different semantics for SO_REUSEADDR,
            # so don't set it. Ref:
            # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
    except BaseException:
        # Nobody else holds a reference yet, so release the descriptor here.
        s.close()
        raise
    return s

77 

78 

def unused_port() -> int:
    """Return a port that is unused on the current host."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        # Port 0 lets the OS choose; the chosen port is read back below.
        probe.bind(("127.0.0.1", 0))
        bound_address = probe.getsockname()
        return cast(int, bound_address[1])

84 

85 

class BaseTestServer(ABC):
    """Base class for servers started on a local socket for tests.

    Subclasses implement ``_make_runner`` to supply the runner that is
    served through a ``SockSite``. Instances can be used as asynchronous
    context managers: entering starts the server, exiting closes it.
    """

    # Conventional marker that keeps test collectors such as pytest from
    # treating this ``Test*``-style class as a test case.
    __test__ = False

    def __init__(
        self,
        *,
        scheme: Union[str, _SENTINEL] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        skip_url_asserts: bool = False,
        socket_factory: Callable[
            [str, int, socket.AddressFamily], socket.socket
        ] = get_port_socket,
        **kwargs: Any,
    ) -> None:
        """Store the server configuration; nothing is started here.

        Args:
            scheme: URL scheme for the root URL. When left as the sentinel
                it is derived in ``start_server`` from whether ``ssl`` was
                given.
            host: Host passed to ``socket_factory``.
            port: Port passed to ``socket_factory``; a falsy value becomes
                ``0`` in ``start_server``.
            skip_url_asserts: When true, ``make_url`` concatenates the root
                and the path as strings instead of asserting that the path
                is relative and joining it.
            socket_factory: Callable ``(host, port, family)`` returning the
                bound socket to serve on.
            **kwargs: Accepted for signature compatibility; not used by
                this initializer.
        """
        self.runner: Optional[BaseRunner] = None
        self._root: Optional[URL] = None
        self.host = host
        self.port = port
        self._closed = False
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts
        self.socket_factory = socket_factory

    async def start_server(self, **kwargs: Any) -> None:
        """Create the runner, bind a socket and start serving.

        Does nothing if a runner already exists. The ``ssl`` keyword is
        removed from ``kwargs`` and used as the site's SSL context; the
        remaining keywords go to ``_make_runner`` together with
        ``handler_cancellation=True``.
        """
        if self.runner:
            return
        self._ssl = kwargs.pop("ssl", None)
        self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
        await self.runner.setup()
        if not self.port:
            self.port = 0
        absolute_host = self.host
        try:
            version = ipaddress.ip_address(self.host).version
        except ValueError:
            # Not an IP literal (e.g. a hostname): treat it as IPv4.
            version = 4
        if version == 6:
            # IPv6 literals need brackets inside a URL.
            absolute_host = f"[{self.host}]"
        family = socket.AF_INET6 if version == 6 else socket.AF_INET
        _sock = self.socket_factory(self.host, self.port, family)
        try:
            self.host, self.port = _sock.getsockname()[:2]
            site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
            await site.start()
        except BaseException:
            # The socket was created here; close it so a failed start does
            # not leak the descriptor.
            _sock.close()
            raise
        server = site._server
        assert server is not None
        sockets = server.sockets  # type: ignore[attr-defined]
        assert sockets is not None
        self.port = sockets[0].getsockname()[1]
        if self.scheme is sentinel:
            if self._ssl:
                scheme = "https"
            else:
                scheme = "http"
            self.scheme = scheme
        self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")

    @abstractmethod  # pragma: no cover
    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        """Return the runner to serve; implemented by subclasses."""
        pass

    def make_url(self, path: StrOrURL) -> URL:
        """Return the URL for ``path`` on this server.

        The server must have been started (the root URL is asserted to be
        set). Unless ``skip_url_asserts`` is true, ``path`` is asserted to
        be relative and is joined onto the root URL; otherwise the root and
        ``path`` are concatenated as strings.
        """
        assert self._root is not None
        url = URL(path)
        if not self.skip_url_asserts:
            assert not url.is_absolute()
            return self._root.join(url)
        else:
            return URL(str(self._root) + str(path))

    @property
    def started(self) -> bool:
        """Whether a runner has been created by ``start_server``."""
        return self.runner is not None

    @property
    def closed(self) -> bool:
        """Whether ``close`` has completed on a started server."""
        return self._closed

    @property
    def handler(self) -> Server:
        """The runner's ``web.Server`` instance (kept for compatibility)."""
        # for backward compatibility
        # web.Server instance
        runner = self.runner
        assert runner is not None
        assert runner.server is not None
        return runner.server

    async def close(self) -> None:
        """Close all fixtures created by the test server.

        After that point, the test server is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when the server is used as an
        asynchronous context manager.
        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = None
            self._closed = True

    async def __aenter__(self) -> "BaseTestServer":
        """Start the server and return it."""
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Close the server; exceptions from the body are not suppressed."""
        await self.close()

203 

204 

class TestServer(BaseTestServer):
    """Test server that serves an ``Application`` through an ``AppRunner``."""

    def __init__(
        self,
        app: Application,
        *,
        scheme: Union[str, _SENTINEL] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ):
        """Remember ``app`` and pass every other option to the base class."""
        self.app = app
        base_options = dict(kwargs, scheme=scheme, host=host, port=port)
        super().__init__(**base_options)

    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        """Build the ``AppRunner`` for the stored application."""
        app_runner = AppRunner(self.app, **kwargs)
        return app_runner

220 

221 

class RawTestServer(BaseTestServer):
    """Test server that serves a bare request handler via ``web.Server``."""

    def __init__(
        self,
        handler: _RequestHandler,
        *,
        scheme: Union[str, _SENTINEL] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Remember ``handler`` and pass every other option to the base class."""
        self._handler = handler
        base_options = dict(kwargs, scheme=scheme, host=host, port=port)
        super().__init__(**base_options)

    async def _make_runner(self, **kwargs: Any) -> ServerRunner:
        """Wrap the handler in a ``Server`` and return its ``ServerRunner``.

        The same keyword arguments are given to both constructors.
        """
        low_level_server = Server(self._handler, **kwargs)
        server_runner = ServerRunner(low_level_server, **kwargs)
        return server_runner

238 

239 

class TestClient:
    """
    A test client implementation.

    To write functional tests for aiohttp based servers.

    """

    # Conventional marker that keeps test collectors such as pytest from
    # treating this ``Test*``-style class as a test case.
    __test__ = False

    def __init__(
        self,
        server: BaseTestServer,
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        **kwargs: Any,
    ) -> None:
        """Create a client session bound to ``server``.

        Args:
            server: The test server requests are sent to.
            cookie_jar: Cookie jar for the session; when ``None`` an
                ``aiohttp.CookieJar(unsafe=True)`` is created.
            **kwargs: Forwarded to ``ClientSession``.

        Raises:
            TypeError: If ``server`` is not a ``BaseTestServer``.
        """
        if not isinstance(server, BaseTestServer):
            raise TypeError(
                "server must be TestServer " "instance, found type: %r" % type(server)
            )
        self._server = server
        if cookie_jar is None:
            cookie_jar = aiohttp.CookieJar(unsafe=True)
        self._session = ClientSession(cookie_jar=cookie_jar, **kwargs)
        self._closed = False
        # Everything opened through this client is tracked so that
        # ``close`` can release it.
        self._responses: List[ClientResponse] = []
        self._websockets: List[ClientWebSocketResponse] = []

    async def start_server(self) -> None:
        """Start the underlying test server."""
        await self._server.start_server()

    @property
    def scheme(self) -> Union[str, object]:
        """The server's scheme (may still be the sentinel before start)."""
        return self._server.scheme

    @property
    def host(self) -> str:
        """The server's host."""
        return self._server.host

    @property
    def port(self) -> Optional[int]:
        """The server's port, or ``None`` if it has none."""
        return self._server.port

    @property
    def server(self) -> BaseTestServer:
        """The test server this client talks to."""
        return self._server

    @property
    def app(self) -> Optional[Application]:
        """The server's ``app`` attribute, or ``None`` if it has none."""
        return cast(Optional[Application], getattr(self._server, "app", None))

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.

        """
        return self._session

    def make_url(self, path: StrOrURL) -> URL:
        """Return the server URL for ``path`` (see the server's ``make_url``)."""
        return self._server.make_url(path)

    async def _request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> ClientResponse:
        """Send a request to the test server and track the response."""
        resp = await self._session.request(method, self.make_url(path), **kwargs)
        # save it to close later
        self._responses.append(resp)
        return resp

    def request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> _RequestContextManager:
        """Routes a request to tested http server.

        The keyword arguments are passed to aiohttp.ClientSession.request;
        ``path`` is resolved against the test server with ``make_url``.

        """
        return _RequestContextManager(self._request(method, path, **kwargs))

    def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP GET request."""
        return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

    def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP POST request."""
        return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

    def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP OPTIONS request."""
        return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs))

    def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP HEAD request."""
        return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

    def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PUT request."""
        return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

    def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PATCH request."""
        return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs))

    def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP DELETE request."""
        return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs))

    def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect.

        """
        return _WSRequestContextManager(self._ws_connect(path, **kwargs))

    async def _ws_connect(
        self, path: StrOrURL, **kwargs: Any
    ) -> ClientWebSocketResponse:
        """Open a websocket to the test server and track it."""
        ws = await self._session.ws_connect(self.make_url(path), **kwargs)
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        After that point, the TestClient is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as a(n) (asynchronous)
        context manager.

        """
        if self._closed:
            return
        try:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
        finally:
            # Release the session and the server even when closing one of
            # the tracked responses or websockets raised.
            try:
                await self._session.close()
            finally:
                await self._server.close()
        self._closed = True

    async def __aenter__(self) -> "TestClient":
        """Start the server and return the client."""
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        """Close the client; exceptions from the body are not suppressed."""
        await self.close()

401 

402 

class AioHTTPTestCase(IsolatedAsyncioTestCase, ABC):
    """Base ``unittest`` test case for aiohttp web applications.

    After ``asyncSetUp`` the following attributes are available:

    * self.app (aiohttp.web.Application): the application returned by
        self.get_application()
    * self.server: the object returned by self.get_server(self.app)
    * self.client (aiohttp.test_utils.TestClient): an aiohttp test client.

    The TestClient's request methods are asynchronous, so they have to be
    awaited from asynchronous test methods.
    """

    @abstractmethod
    async def get_application(self) -> Application:
        """Return the aiohttp.web.Application object to test.

        Subclasses must override this method.
        """

    async def asyncSetUp(self) -> None:
        # Each step reads the attribute set by the previous one, so the
        # order of these assignments matters.
        self.app = await self.get_application()
        self.server = await self.get_server(self.app)
        self.client = await self.get_client(self.server)

        client = self.client
        await client.start_server()

    async def asyncTearDown(self) -> None:
        client = self.client
        await client.close()

    async def get_server(self, app: Application) -> TestServer:
        """Return a TestServer instance."""
        test_server = TestServer(app)
        return test_server

    async def get_client(self, server: TestServer) -> TestClient:
        """Return a TestClient instance."""
        test_client = TestClient(server)
        return test_client

441 

442 

# Type of a zero-argument callable that returns an event loop.
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]

444 

445 

@contextlib.contextmanager
def loop_context(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
) -> Iterator[asyncio.AbstractEventLoop]:
    """A contextmanager that creates an event_loop, for test purposes.

    Handles the creation and cleanup of a test loop.

    Args:
        loop_factory: Zero-argument callable that creates the loop.
        fast: Passed to ``teardown_test_loop``.

    Yields:
        The loop returned by ``setup_test_loop``.
    """
    loop = setup_test_loop(loop_factory)
    try:
        yield loop
    finally:
        # Tear down even when the ``with`` body raises, so a failing test
        # does not leave the loop open and installed.
        teardown_test_loop(loop, fast=fast)

457 

458 

def setup_test_loop(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
) -> asyncio.AbstractEventLoop:
    """Create an event loop, install it as the current loop and return it.

    The loop comes from ``loop_factory``. Callers are expected to pass it
    to teardown_test_loop once they are done with it.
    """
    new_loop = loop_factory()
    asyncio.set_event_loop(new_loop)
    return new_loop

470 

471 

def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
    """Stop and close ``loop`` and clear the current event loop.

    A loop that is still open gets ``loop.stop`` scheduled, runs until that
    callback fires, and is then closed. Unless ``fast`` is true, a garbage
    collection pass runs afterwards. Finally the current event loop is set
    to ``None``.
    """
    still_open = not loop.is_closed()
    if still_open:
        loop.call_soon(loop.stop)
        loop.run_forever()
        loop.close()

    if fast is False or not fast:
        gc.collect()

    asyncio.set_event_loop(None)

484 

485 

def _create_app_mock() -> mock.MagicMock:
    """Return a ``MagicMock`` specced on ``Application`` with item access.

    Item reads and writes go to a plain dict stored on the mock, and
    ``on_response_prepare`` is a frozen ``Signal`` owned by the mock.
    """

    def get_dict(app: Any, key: str) -> Any:
        storage = app.__app_dict
        return storage[key]

    def set_dict(app: Any, key: str, value: Any) -> None:
        storage = app.__app_dict
        storage[key] = value

    app = mock.MagicMock(spec=Application)
    app.__app_dict = {}
    app.__getitem__ = get_dict
    app.__setitem__ = set_dict

    response_prepare_signal = Signal(app)
    app.on_response_prepare = response_prepare_signal
    app.on_response_prepare.freeze()
    return app

501 

502 

def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
    """Return a mock transport whose ``get_extra_info`` knows ``sslcontext``.

    ``get_extra_info("sslcontext")`` returns the given ``sslcontext``; any
    other key returns ``None``.
    """

    def get_extra_info(key: str) -> Optional[SSLContext]:
        return sslcontext if key == "sslcontext" else None

    fake_transport = mock.Mock()
    fake_transport.get_extra_info.side_effect = get_extra_info
    return fake_transport

514 

515 

def make_mocked_request(
    method: str,
    path: str,
    headers: Any = None,
    *,
    match_info: Any = sentinel,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Any = None,
    writer: Any = sentinel,
    protocol: Any = sentinel,
    transport: Any = sentinel,
    payload: Any = sentinel,
    sslcontext: Optional[SSLContext] = None,
    client_max_size: int = 1024**2,
    loop: Any = ...,
) -> Request:
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill
    or specific conditions and errors are hard to trigger.

    Every collaborator left at its default (``app``, ``transport``,
    ``protocol``, ``writer``, ``payload``, ``loop``) is replaced by a mock.
    """
    task = mock.Mock()
    if loop is ...:
        loop = mock.Mock()
        loop.create_future.return_value = ()

    # Versions below HTTP/1.1 always force the closing flag on.
    if version < HttpVersion(1, 1):
        closing = True

    # Normalise the headers into a read-only case-insensitive multidict and
    # derive the raw (bytes) header pairs from it.
    header_store = CIMultiDict(headers) if headers else CIMultiDict()
    headers = CIMultiDictProxy(header_store)
    raw_hdrs = tuple(
        (name.encode("utf-8"), value.encode("utf-8"))
        for name, value in headers.items()
    )

    transfer_encoding = headers.get(hdrs.TRANSFER_ENCODING, "")
    chunked = "chunked" in transfer_encoding.lower()

    message = RawRequestMessage(
        method,
        path,
        version,
        headers,
        raw_hdrs,
        closing,
        None,
        False,
        chunked,
        URL(path),
    )

    if app is None:
        app = _create_app_mock()

    if transport is sentinel:
        transport = _create_transport(sslcontext)

    if protocol is sentinel:
        protocol = mock.Mock()

    if writer is sentinel:
        writer = mock.Mock()
        for coro_name in ("write_headers", "write", "write_eof", "drain"):
            setattr(writer, coro_name, make_mocked_coro(None))
        writer.transport = transport

    # The protocol always ends up pointing at the chosen transport and
    # writer, whether it was supplied by the caller or mocked above.
    protocol.transport = transport
    protocol.writer = writer

    if payload is sentinel:
        payload = mock.Mock()

    req = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    match_dict = {} if match_info is sentinel else match_info
    match_info = UrlMappingMatchInfo(match_dict, mock.Mock())
    match_info.add_app(app)
    req._match_info = match_info

    return req

604 

605 

def make_mocked_coro(
    return_value: Any = sentinel, raise_exception: Any = sentinel
) -> Any:
    """Create a ``Mock`` that wraps a coroutine function.

    When awaited, the wrapped coroutine raises ``raise_exception`` if one
    was given. Otherwise it returns ``return_value``, unless
    ``return_value`` is itself awaitable: then it is awaited and the
    coroutine returns ``None``.
    """

    async def mock_coro(*args: Any, **kwargs: Any) -> Any:
        if raise_exception is not sentinel:
            raise raise_exception
        if inspect.isawaitable(return_value):
            # The awaited result is intentionally not returned.
            await return_value
            return None
        return return_value

    wrapper = mock.Mock(wraps=mock_coro)
    return wrapper