Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/test_utils.py: 48%

319 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:52 +0000

1"""Utilities shared by tests.""" 

2 

3import asyncio 

4import contextlib 

5import gc 

6import inspect 

7import ipaddress 

8import os 

9import socket 

10import sys 

11from abc import ABC, abstractmethod 

12from types import TracebackType 

13from typing import ( 

14 TYPE_CHECKING, 

15 Any, 

16 Callable, 

17 Iterator, 

18 List, 

19 Optional, 

20 Type, 

21 Union, 

22 cast, 

23) 

24from unittest import mock 

25 

26from aiosignal import Signal 

27from multidict import CIMultiDict, CIMultiDictProxy 

28from yarl import URL 

29 

30import aiohttp 

31from aiohttp.client import _RequestContextManager, _WSRequestContextManager 

32 

33from . import ClientSession, hdrs 

34from .abc import AbstractCookieJar 

35from .client_reqrep import ClientResponse 

36from .client_ws import ClientWebSocketResponse 

37from .helpers import _SENTINEL, PY_38, sentinel 

38from .http import HttpVersion, RawRequestMessage 

39from .typedefs import StrOrURL 

40from .web import ( 

41 Application, 

42 AppRunner, 

43 BaseRunner, 

44 Request, 

45 Server, 

46 ServerRunner, 

47 SockSite, 

48 UrlMappingMatchInfo, 

49) 

50from .web_protocol import _RequestHandler 

51 

52if TYPE_CHECKING: # pragma: no cover 

53 from ssl import SSLContext 

54else: 

55 SSLContext = None 

56 

57if PY_38: 

58 from unittest import IsolatedAsyncioTestCase as TestCase 

59else: 

60 from asynctest import TestCase # type: ignore[no-redef] 

61 

# SO_REUSEADDR is only enabled on POSIX platforms other than Cygwin;
# get_port_socket consults this flag before calling setsockopt.
REUSE_ADDRESS = sys.platform != "cygwin" and os.name == "posix"

63 

64 

def get_unused_port_socket(
    host: str, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Return a socket bound to *host* on a port chosen by the OS.

    Binding to port 0 delegates the port choice to the operating system;
    the chosen port can be read back with ``sock.getsockname()``.
    """
    ephemeral_port = 0
    return get_port_socket(host, ephemeral_port, family)

69 

70 

def get_port_socket(
    host: str, port: int, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Create a TCP socket of address *family* bound to ``(host, port)``.

    ``SO_REUSEADDR`` is set only when ``REUSE_ADDRESS`` is true. The
    returned socket is bound but not listening; the caller owns it.

    If setting the option or binding fails, the socket is closed before
    the exception propagates so its file descriptor is not leaked.
    """
    s = socket.socket(family, socket.SOCK_STREAM)
    try:
        if REUSE_ADDRESS:
            # Windows has different semantics for SO_REUSEADDR,
            # so don't set it. Ref:
            # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
    except BaseException:
        # Don't leak the descriptor when e.g. the address is already in use.
        s.close()
        raise
    return s

82 

83 

def unused_port() -> int:
    """Return a TCP port number that is currently free on 127.0.0.1.

    A throwaway socket is bound to port 0 so the OS picks a free port.
    The socket is closed before returning, so the port is not reserved
    and could be taken by someone else before it is used.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("127.0.0.1", 0))
        _, port = probe.getsockname()
    return cast(int, port)

89 

90 

class BaseTestServer(ABC):
    """Abstract base for servers used in aiohttp tests.

    Subclasses implement :meth:`_make_runner`. :meth:`start_server` sets the
    runner up, binds a socket obtained from ``socket_factory`` and records
    the root URL that :meth:`make_url` resolves paths against.
    """

    # Presumably keeps pytest from collecting this class (pytest honours
    # a false ``__test__`` attribute).
    __test__ = False

    def __init__(
        self,
        *,
        scheme: Union[str, _SENTINEL] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        skip_url_asserts: bool = False,
        socket_factory: Callable[
            [str, int, socket.AddressFamily], socket.socket
        ] = get_port_socket,
        **kwargs: Any,
    ) -> None:
        """Store the server configuration; nothing is started yet.

        :param scheme: URL scheme. When left as the sentinel it is chosen in
            :meth:`start_server`: "https" if an ``ssl`` context was given,
            otherwise "http".
        :param host: address to bind.
        :param port: port to bind; ``None`` (or 0) lets the OS choose one.
        :param skip_url_asserts: if true, :meth:`make_url` appends the path
            to the root URL as a string instead of asserting it is relative.
        :param socket_factory: callable ``(host, port, family)`` returning
            the bound socket to serve on.
        """
        # NOTE(review): extra **kwargs are accepted but not used here.
        self.runner: Optional[BaseRunner] = None
        self._root: Optional[URL] = None
        self.host = host
        self.port = port
        self._closed = False
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts
        self.socket_factory = socket_factory

    async def start_server(self, **kwargs: Any) -> None:
        """Set up the runner and start serving on a freshly bound socket.

        Does nothing if a runner already exists. An ``ssl`` keyword is
        removed from *kwargs* and used as the site's SSL context; the
        remaining keywords are passed to :meth:`_make_runner` together with
        ``handler_cancellation=True``.
        """
        if self.runner:
            return
        self._ssl = kwargs.pop("ssl", None)
        self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
        await self.runner.setup()
        if not self.port:
            # Port 0 asks the OS for any free port.
            self.port = 0
        absolute_host = self.host
        try:
            version = ipaddress.ip_address(self.host).version
        except ValueError:
            # Not an IP literal (e.g. a hostname): treat it as IPv4.
            version = 4
        if version == 6:
            # IPv6 literals must be bracketed inside a URL authority.
            absolute_host = f"[{self.host}]"
        family = socket.AF_INET6 if version == 6 else socket.AF_INET
        _sock = self.socket_factory(self.host, self.port, family)
        # Read back the address actually bound (resolves port 0).
        self.host, self.port = _sock.getsockname()[:2]
        site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
        await site.start()
        server = site._server
        assert server is not None
        sockets = server.sockets  # type: ignore[attr-defined]
        assert sockets is not None
        self.port = sockets[0].getsockname()[1]
        if self.scheme is sentinel:
            if self._ssl:
                scheme = "https"
            else:
                scheme = "http"
            self.scheme = scheme
        self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")

    @abstractmethod  # pragma: no cover
    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        """Build the (not yet set up) runner that will serve requests."""
        pass

    def make_url(self, path: StrOrURL) -> URL:
        """Return the absolute URL of *path* on this (started) server.

        Unless ``skip_url_asserts`` is set, *path* must be relative and is
        joined onto the root URL; otherwise it is appended verbatim.
        """
        assert self._root is not None
        url = URL(path)
        if not self.skip_url_asserts:
            assert not url.is_absolute()
            return self._root.join(url)
        else:
            return URL(str(self._root) + str(path))

    @property
    def started(self) -> bool:
        """True once :meth:`start_server` has created a runner."""
        return self.runner is not None

    @property
    def closed(self) -> bool:
        """True once :meth:`close` has cleaned up a started server."""
        return self._closed

    @property
    def handler(self) -> Server:
        """The runner's ``web.Server`` instance."""
        # for backward compatibility
        # web.Server instance
        runner = self.runner
        assert runner is not None
        assert runner.server is not None
        return runner.server

    async def close(self) -> None:
        """Stop the server by cleaning up its runner.

        After that point, the server is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects. Calling it on a server that
        was never started does nothing.

        close is also run on exit when used as an asynchronous context
        manager.
        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = None
            self._closed = True

    async def __aenter__(self) -> "BaseTestServer":
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.close()

208 

209 

class TestServer(BaseTestServer):
    """Test server that serves an :class:`aiohttp.web.Application`."""

    def __init__(
        self,
        app: Application,
        *,
        scheme: Union[str, _SENTINEL] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Remember *app*; the remaining options go to BaseTestServer."""
        self.app = app
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        """Wrap the application in an :class:`AppRunner` built with *kwargs*."""
        return AppRunner(self.app, **kwargs)

225 

226 

class RawTestServer(BaseTestServer):
    """Test server that serves a bare request handler, without an Application."""

    def __init__(
        self,
        handler: _RequestHandler,
        *,
        scheme: Union[str, _SENTINEL] = sentinel,
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Remember *handler*; the remaining options go to BaseTestServer."""
        self._handler = handler
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> ServerRunner:
        """Wrap the handler in a low-level Server and that in a ServerRunner.

        The same keyword arguments are forwarded to both constructors.
        """
        return ServerRunner(Server(self._handler, **kwargs), **kwargs)

243 

244 

class TestClient:
    """
    A test client implementation.

    To write functional tests for aiohttp based servers.

    Wraps a BaseTestServer and a ClientSession. The request helpers resolve
    paths against the server's root URL, and every response and websocket
    opened through them is tracked so that close() can release it.
    """

    # Presumably keeps pytest from collecting this class (pytest honours
    # a false ``__test__`` attribute).
    __test__ = False

    def __init__(
        self,
        server: BaseTestServer,
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        **kwargs: Any,
    ) -> None:
        """Wrap *server*; extra *kwargs* are forwarded to ClientSession.

        Raises TypeError if *server* is not a BaseTestServer. When no
        *cookie_jar* is given, ``aiohttp.CookieJar(unsafe=True)`` is used.
        """
        if not isinstance(server, BaseTestServer):
            raise TypeError(
                "server must be TestServer " "instance, found type: %r" % type(server)
            )
        self._server = server
        if cookie_jar is None:
            # unsafe=True: per aiohttp's CookieJar docs this accepts cookies
            # from IP-address hosts such as the default 127.0.0.1.
            cookie_jar = aiohttp.CookieJar(unsafe=True)
        self._session = ClientSession(cookie_jar=cookie_jar, **kwargs)
        self._closed = False
        self._responses: List[ClientResponse] = []
        self._websockets: List[ClientWebSocketResponse] = []

    async def start_server(self) -> None:
        """Start the wrapped test server."""
        await self._server.start_server()

    @property
    def scheme(self) -> Union[str, object]:
        """Scheme of the wrapped server (may be a sentinel before start)."""
        return self._server.scheme

    @property
    def host(self) -> str:
        """Host of the wrapped server."""
        return self._server.host

    @property
    def port(self) -> Optional[int]:
        """Port of the wrapped server, or None when it has none assigned."""
        return self._server.port

    @property
    def server(self) -> BaseTestServer:
        """The wrapped test server."""
        return self._server

    @property
    def app(self) -> Optional[Application]:
        """The server's ``app`` attribute, or None if it has none."""
        return cast(Optional[Application], getattr(self._server, "app", None))

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.

        """
        return self._session

    def make_url(self, path: StrOrURL) -> URL:
        """Return the absolute URL of *path* on the wrapped server."""
        return self._server.make_url(path)

    async def _request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> ClientResponse:
        resp = await self._session.request(method, self.make_url(path), **kwargs)
        # save it to close later
        self._responses.append(resp)
        return resp

    def request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> _RequestContextManager:
        """Routes a request to tested http server.

        The interface is identical to aiohttp.ClientSession.request,
        except that *path* is resolved against the test server's root URL
        (see make_url) and the response is closed by close().

        """
        return _RequestContextManager(self._request(method, path, **kwargs))

    def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP GET request."""
        return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

    def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP POST request."""
        return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

    def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP OPTIONS request."""
        return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs))

    def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP HEAD request."""
        return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

    def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PUT request."""
        return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

    def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PATCH request."""
        return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs))

    def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP DELETE request."""
        return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs))

    def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect, with *path*
        resolved against the test server's root URL.

        """
        return _WSRequestContextManager(self._ws_connect(path, **kwargs))

    async def _ws_connect(
        self, path: StrOrURL, **kwargs: Any
    ) -> ClientWebSocketResponse:
        ws = await self._session.ws_connect(self.make_url(path), **kwargs)
        # save it to close later
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        After that point, the TestClient is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as a(n) (asynchronous)
        context manager.

        Tracked responses and websockets are closed first, then the client
        session, then the wrapped server.

        """
        if not self._closed:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
            await self._session.close()
            await self._server.close()
            self._closed = True

    async def __aenter__(self) -> "TestClient":
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

406 

407 

class AioHTTPTestCase(TestCase, ABC):
    """Base class for unittest-style tests of aiohttp web applications.

    Before each test the following attributes are prepared:

    * self.app (aiohttp.web.Application): the application returned by
      self.get_application()
    * self.server (TestServer): the server returned by self.get_server()
    * self.client (aiohttp.test_utils.TestClient): the client returned by
      self.get_client(), with its server already started

    The TestClient's request methods are coroutines, so they have to be
    awaited from asynchronous test methods.
    """

    @abstractmethod
    async def get_application(self) -> Application:
        """Get application.

        Override this to return the aiohttp.web.Application under test.
        """

    async def get_server(self, app: Application) -> TestServer:
        """Return a TestServer instance."""
        return TestServer(app)

    async def get_client(self, server: TestServer) -> TestClient:
        """Return a TestClient instance."""
        return TestClient(server)

    def setUp(self) -> None:
        # With PY_38 the base class is IsolatedAsyncioTestCase, which calls
        # asyncSetUp on its own; otherwise (asynctest fallback) drive it here.
        if PY_38:
            return
        asyncio.get_event_loop().run_until_complete(self.asyncSetUp())

    async def asyncSetUp(self) -> None:
        app = await self.get_application()
        self.app = app
        server = await self.get_server(app)
        self.server = server
        client = await self.get_client(server)
        self.client = client

        await client.start_server()

    def tearDown(self) -> None:
        # Mirror of setUp: only the non-PY_38 fallback needs manual driving.
        if PY_38:
            return
        asyncio.get_event_loop().run_until_complete(self.asyncTearDown())

    async def asyncTearDown(self) -> None:
        await self.client.close()

454 

455 

# Zero-argument callable returning a new event loop.
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]


@contextlib.contextmanager
def loop_context(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
) -> Iterator[asyncio.AbstractEventLoop]:
    """A contextmanager that creates an event_loop, for test purposes.

    Handles the creation and cleanup of a test loop: the loop comes from
    ``setup_test_loop(loop_factory)`` and is always handed to
    ``teardown_test_loop(loop, fast=fast)`` on exit, including when the
    ``with`` body raises.
    """
    loop = setup_test_loop(loop_factory)
    try:
        yield loop
    finally:
        # Without this, an exception in the with-body would leak the loop
        # and leave it installed as the current event loop.
        teardown_test_loop(loop, fast=fast)

470 

471 

def setup_test_loop(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
) -> asyncio.AbstractEventLoop:
    """Create and return an asyncio.BaseEventLoop instance.

    The new loop is installed as the current event loop. On non-Windows
    platforms, unless the loop class lives in a uvloop module, a child
    watcher is attached to it, provided this Python still offers one.

    The caller should also call teardown_test_loop,
    once they are done with the loop.
    """
    loop = loop_factory()
    try:
        module = loop.__class__.__module__
        skip_watcher = "uvloop" in module
    except AttributeError:  # pragma: no cover
        # Just in case
        skip_watcher = True
    asyncio.set_event_loop(loop)
    if sys.platform != "win32" and not skip_watcher:
        # Refs:
        # * https://github.com/pytest-dev/pytest-xdist/issues/620
        # * https://stackoverflow.com/a/58614689/595220
        # * https://bugs.python.org/issue35621
        # * https://github.com/python/cpython/pull/14344
        #
        # ThreadedChildWatcher exists on Python >= 3.8, SafeChildWatcher is
        # the fallback for older versions. Pythons that removed the
        # child-watcher API provide neither; then there is nothing to install
        # (previously this raised an uncaught AttributeError).
        watcher_cls = getattr(asyncio, "ThreadedChildWatcher", None)
        if watcher_cls is None:
            watcher_cls = getattr(asyncio, "SafeChildWatcher", None)
        if watcher_cls is not None:
            policy = asyncio.get_event_loop_policy()
            watcher = watcher_cls()
            watcher.attach_loop(loop)
            with contextlib.suppress(NotImplementedError):
                policy.set_child_watcher(watcher)
    return loop

504 

505 

def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
    """Stop and close *loop*, then clear the current event loop.

    A loop that is already closed is not touched. Unless *fast* is true,
    a full garbage-collection pass runs afterwards.
    """
    if not loop.is_closed():
        # Queue stop() behind anything already scheduled so those callbacks
        # get one final iteration before the loop is closed.
        loop.call_soon(loop.stop)
        loop.run_forever()
        loop.close()

    if not fast:
        gc.collect()

    asyncio.set_event_loop(None)

518 

519 

def _create_app_mock() -> mock.MagicMock:
    """Return a MagicMock specced as Application with working item access.

    ``app[key]`` and ``app[key] = value`` read and write a private dict
    stored on the mock, and ``on_response_prepare`` is a real, frozen,
    empty Signal.
    """
    app = mock.MagicMock(spec=Application)
    app.__app_dict = {}

    def _lookup(mocked: Any, key: str) -> Any:
        return mocked.__app_dict[key]

    def _store(mocked: Any, key: str, value: Any) -> None:
        mocked.__app_dict[key] = value

    # Magic methods configured on a MagicMock receive the mock as first arg.
    app.__getitem__ = _lookup
    app.__setitem__ = _store

    prepare_signal = Signal(app)
    prepare_signal.freeze()
    app.on_response_prepare = prepare_signal
    return app

535 

536 

def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
    """Return a Mock transport whose get_extra_info() only knows "sslcontext".

    ``get_extra_info("sslcontext")`` returns *sslcontext*. Any other key
    returns the optional *default* argument (None when omitted), mirroring
    the ``BaseTransport.get_extra_info(name, default=None)`` signature.
    """
    transport = mock.Mock()

    def get_extra_info(key: str, default: Any = None) -> Any:
        if key == "sslcontext":
            return sslcontext
        else:
            return default

    transport.get_extra_info.side_effect = get_extra_info
    return transport

548 

549 

def make_mocked_request(
    method: str,
    path: str,
    headers: Any = None,
    *,
    match_info: Any = sentinel,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Any = None,
    writer: Any = sentinel,
    protocol: Any = sentinel,
    transport: Any = sentinel,
    payload: Any = sentinel,
    sslcontext: Optional[SSLContext] = None,
    client_max_size: int = 1024**2,
    loop: Any = ...,
) -> Request:
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
    specific conditions and errors are hard to trigger.

    Collaborators left at their defaults (app, transport, protocol, writer,
    payload, loop) are replaced by mocks; *sslcontext* is only used when the
    transport is mocked. Requests with an HTTP version below 1.1 are always
    marked as closing, and *match_info* defaults to an empty mapping.
    """
    task = mock.Mock()
    if loop is ...:
        loop = mock.Mock()
        # NOTE(review): the mocked loop's create_future() returns an empty
        # tuple rather than a future; kept for compatibility.
        loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        closing = True

    if headers:
        headers = CIMultiDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
            (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
        )
    else:
        headers = CIMultiDictProxy(CIMultiDict())
        raw_hdrs = ()

    chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()

    message = RawRequestMessage(
        method,
        path,
        version,
        headers,
        raw_hdrs,
        closing,
        None,
        False,
        chunked,
        URL(path),
    )
    if app is None:
        app = _create_app_mock()

    if transport is sentinel:
        transport = _create_transport(sslcontext)

    if protocol is sentinel:
        protocol = mock.Mock()

    if writer is sentinel:
        writer = mock.Mock()
        writer.write_headers = make_mocked_coro(None)
        writer.write = make_mocked_coro(None)
        writer.write_eof = make_mocked_coro(None)
        writer.drain = make_mocked_coro(None)
        writer.transport = transport

    # Wire the protocol (mocked or caller-supplied) to the transport and
    # writer in use; this single assignment replaces an earlier duplicate.
    protocol.transport = transport
    protocol.writer = writer

    if payload is sentinel:
        payload = mock.Mock()

    req = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    match_info = UrlMappingMatchInfo(
        {} if match_info is sentinel else match_info, mock.Mock()
    )
    match_info.add_app(app)
    req._match_info = match_info

    return req

638 

639 

def make_mocked_coro(
    return_value: Any = sentinel, raise_exception: Any = sentinel
) -> Any:
    """Creates a coroutine mock.

    Returns a ``mock.Mock`` wrapping an ``async def``, so calls are recorded
    on the mock and each call produces a coroutine. When awaited, that
    coroutine raises *raise_exception* if one was given. Otherwise it
    returns *return_value*. If *return_value* is awaitable, it awaits it
    and returns the awaited result.

    Note that a coroutine object can only be awaited once, so passing one
    as *return_value* makes the mock usable for a single call.
    """

    async def mock_coro(*args: Any, **kwargs: Any) -> Any:
        if raise_exception is not sentinel:
            raise raise_exception
        if not inspect.isawaitable(return_value):
            return return_value
        # Propagate the awaited result instead of silently returning None.
        return await return_value

    return mock.Mock(wraps=mock_coro)