Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/test_utils.py: 56%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

297 statements  

1"""Utilities shared by tests.""" 

2 

3import asyncio 

4import ipaddress 

5import os 

6import socket 

7import sys 

8from abc import ABC, abstractmethod 

9from collections.abc import Callable 

10from types import TracebackType 

11from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload 

12from unittest import IsolatedAsyncioTestCase, mock 

13 

14from aiosignal import Signal 

15from multidict import CIMultiDict 

16from yarl import URL 

17 

18import aiohttp 

19from aiohttp.client import ( 

20 _BaseRequestContextManager, 

21 _RequestContextManager, 

22 _RequestOptions, 

23 _WSRequestContextManager, 

24) 

25 

26from . import ClientSession, hdrs 

27from .abc import AbstractCookieJar, AbstractStreamWriter 

28from .client_reqrep import ClientResponse 

29from .client_ws import ClientWebSocketResponse 

30from .helpers import HeadersDictProxy 

31from .http import HttpVersion, RawRequestMessage 

32from .streams import EMPTY_PAYLOAD, StreamReader 

33from .typedefs import LooseHeaders, StrOrURL 

34from .web import ( 

35 Application, 

36 AppRunner, 

37 BaseRequest, 

38 BaseRunner, 

39 Request, 

40 RequestHandler, 

41 Server, 

42 ServerRunner, 

43 SockSite, 

44 UrlMappingMatchInfo, 

45) 

46from .web_protocol import _RequestHandler 

47 

48if TYPE_CHECKING: 

49 from ssl import SSLContext 

50else: 

51 SSLContext = Any 

52 

53if sys.version_info >= (3, 11) and TYPE_CHECKING: 

54 from typing import Unpack 

55 

56if sys.version_info >= (3, 11): 

57 from typing import Self 

58else: 

59 Self = Any 

60 

# Constrained TypeVar for TestClient's second type parameter: the overloads of
# TestClient.__init__ bind it to Application when wrapping a TestServer and to
# None when wrapping any other BaseTestServer.
_ApplicationNone = TypeVar("_ApplicationNone", Application, None)
# Request type handled by a BaseTestServer and the TestClient wrapping it.
_Request = TypeVar("_Request", bound=BaseRequest)

# True on POSIX platforms other than Cygwin. Despite its name, this file passes
# it as ``reuse_port`` to socket.create_server in BaseTestServer's default
# socket factory.
REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin"

65 

66 

class BaseTestServer(ABC, Generic[_Request]):
    """Base class for test servers listening on a real socket.

    Subclasses implement :meth:`_make_runner` to provide the runner that
    serves requests. :meth:`start_server` binds a socket obtained from
    ``socket_factory`` and records the root URL used by :meth:`make_url`.
    """

    # Tells pytest-style collectors not to treat this class as a test case.
    __test__ = False

    def __init__(
        self,
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: int | None = None,
        skip_url_asserts: bool = False,
        socket_factory: Callable[
            [str, int, socket.AddressFamily], socket.socket
        ] = lambda h, p, f: socket.create_server(
            (h, p), family=f, reuse_port=REUSE_ADDRESS
        ),
        **kwargs: Any,
    ) -> None:
        """Configure the server; nothing is bound until start_server().

        :param scheme: URL scheme; when empty it is chosen in start_server()
            ("https" if an ssl context was given, else "http").
        :param host: host to bind to; an IPv6 literal is supported.
        :param port: port to bind to; ``None`` or 0 lets the OS pick one.
        :param skip_url_asserts: when true, make_url() concatenates paths
            instead of asserting they are relative.
        :param socket_factory: called as ``(host, port, family)`` to create
            the listening socket.
        :param kwargs: accepted for subclass compatibility and ignored here.
        """
        self.runner: BaseRunner[_Request] | None = None
        self._root: URL | None = None
        self.host = host
        self.port = port or 0
        self._closed = False
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts
        self.socket_factory = socket_factory

    async def start_server(self, **kwargs: Any) -> None:
        """Create the runner, bind the socket and start serving.

        Does nothing if a runner already exists. An ``ssl`` keyword is
        removed from *kwargs* and used as the site's ssl context; the
        remaining keywords are forwarded to :meth:`_make_runner` together
        with ``handler_cancellation=True``.
        """
        if self.runner:
            return
        self._ssl = kwargs.pop("ssl", None)
        self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
        await self.runner.setup()
        absolute_host = self.host
        try:
            version = ipaddress.ip_address(self.host).version
        except ValueError:
            # Not an IP literal (e.g. a host name): treat it as IPv4.
            version = 4
        if version == 6:
            # IPv6 literals must be bracketed inside a URL authority.
            absolute_host = f"[{self.host}]"
        family = socket.AF_INET6 if version == 6 else socket.AF_INET
        _sock = self.socket_factory(self.host, self.port, family)
        self.host, self.port = _sock.getsockname()[:2]
        site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
        await site.start()
        server = site._server
        assert server is not None
        sockets = server.sockets
        assert sockets is not None
        # Read back the port actually in use (relevant when port 0 was asked).
        self.port = sockets[0].getsockname()[1]
        if not self.scheme:
            self.scheme = "https" if self._ssl else "http"
        self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")

    @abstractmethod
    async def _make_runner(self, **kwargs: Any) -> BaseRunner[_Request]:
        """Return a new runner for the server."""
        # TODO(PY311): Use Unpack to specify Server kwargs.

    def make_url(self, path: StrOrURL) -> URL:
        """Return *path* resolved against the server's root URL.

        Requires a started server. Unless ``skip_url_asserts`` is set,
        *path* must be relative and is joined with the root; otherwise the
        two are concatenated as strings.
        """
        assert self._root is not None
        url = URL(path)
        if not self.skip_url_asserts:
            assert not url.absolute
            return self._root.join(url)
        else:
            return URL(str(self._root) + str(path))

    @property
    def started(self) -> bool:
        """Whether start_server() has created a runner."""
        return self.runner is not None

    @property
    def closed(self) -> bool:
        """Whether close() has shut down a started server."""
        return self._closed

    @property
    def handler(self) -> Server[_Request]:
        """The runner's web.Server instance (kept for backward compatibility)."""
        runner = self.runner
        assert runner is not None
        assert runner.server is not None
        return runner.server

    async def close(self) -> None:
        """Stop the server and clean up its runner.

        After that point, the server is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects. Calling it on a server that
        was never started does nothing.

        close is also run on exit when used as an asynchronous context
        manager.
        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = 0
            self._closed = True

    async def __aenter__(self) -> Self:
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        await self.close()

181 

182 

class TestServer(BaseTestServer[Request]):
    """Test server that serves an :class:`Application` through an AppRunner."""

    def __init__(
        self,
        app: Application,
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Store *app* and forward the remaining options to BaseTestServer."""
        # The application is only consumed later, by _make_runner().
        self.app = app
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> AppRunner:
        """Return an AppRunner for ``self.app``, forwarding *kwargs*."""
        # TODO(PY311): Use Unpack to specify Server kwargs.
        return AppRunner(self.app, **kwargs)

199 

200 

class RawTestServer(BaseTestServer[BaseRequest]):
    """Test server driving a bare request handler through a low-level Server."""

    def __init__(
        self,
        handler: _RequestHandler[BaseRequest],
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: int | None = None,
        **kwargs: Any,
    ) -> None:
        # Remember the coroutine handler; it is wrapped in a Server later.
        self._handler = handler
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> ServerRunner:
        """Wrap the stored handler in a Server and return its runner."""
        # TODO(PY311): Use Unpack to specify Server kwargs.
        # The same keyword arguments are given to both the Server and the runner.
        low_level_server = Server(self._handler, **kwargs)
        runner = ServerRunner(low_level_server, **kwargs)
        return runner

218 

219 

class TestClient(Generic[_Request, _ApplicationNone]):
    """A test client for writing functional tests of aiohttp based servers.

    Requests go through an internal ClientSession to URLs resolved relative
    to the wrapped test server (see :meth:`make_url`). Every response and
    websocket opened through the client is recorded so that :meth:`close`
    can release it.
    """

    # Tells pytest-style collectors not to treat this class as a test case.
    __test__ = False

    @overload
    def __init__(
        self: "TestClient[Request, Application]",
        server: TestServer,
        *,
        cookie_jar: AbstractCookieJar | None = None,
        **kwargs: Any,
    ) -> None: ...
    @overload
    def __init__(
        self: "TestClient[_Request, None]",
        server: BaseTestServer[_Request],
        *,
        cookie_jar: AbstractCookieJar | None = None,
        **kwargs: Any,
    ) -> None: ...
    def __init__(  # type: ignore[misc]
        self,
        server: BaseTestServer[_Request],
        *,
        cookie_jar: AbstractCookieJar | None = None,
        **kwargs: Any,
    ) -> None:
        """Wrap *server* and create the ClientSession used for requests.

        :param server: the test server that requests are routed to.
        :param cookie_jar: cookie jar for the session; defaults to
            ``aiohttp.CookieJar(unsafe=True)``.
        :param kwargs: forwarded to :class:`aiohttp.ClientSession`.
        :raises TypeError: if *server* is not a BaseTestServer.
        """
        # TODO(PY311): Use Unpack to specify ClientSession kwargs.
        if not isinstance(server, BaseTestServer):
            raise TypeError(
                "server must be TestServer instance, found type: %r" % type(server)
            )
        self._server = server
        if cookie_jar is None:
            # unsafe=True lets the jar accept cookies from IP-address hosts
            # such as the default 127.0.0.1.
            cookie_jar = aiohttp.CookieJar(unsafe=True)
        self._session = ClientSession(cookie_jar=cookie_jar, **kwargs)
        # Turn off the session's connection-retry flag (a private attribute).
        self._session._retry_connection = False
        self._closed = False
        # Everything opened through this client, released again by close().
        self._responses: list[ClientResponse] = []
        self._websockets: list[ClientWebSocketResponse[bool]] = []

    async def start_server(self) -> None:
        """Start the wrapped test server."""
        await self._server.start_server()

    @property
    def scheme(self) -> str | object:
        return self._server.scheme

    @property
    def host(self) -> str:
        return self._server.host

    @property
    def port(self) -> int:
        return self._server.port

    @property
    def server(self) -> BaseTestServer[_Request]:
        return self._server

    @property
    def app(self) -> _ApplicationNone:
        """The server's ``app`` attribute, or None if the server has none."""
        return getattr(self._server, "app", None)  # type: ignore[return-value]

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.

        """
        return self._session

    def make_url(self, path: StrOrURL) -> URL:
        """Resolve *path* against the wrapped server's root URL."""
        return self._server.make_url(path)

    async def _request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> ClientResponse:
        resp = await self._session.request(method, self.make_url(path), **kwargs)
        # save it to close later
        self._responses.append(resp)
        return resp

    if sys.version_info >= (3, 11) and TYPE_CHECKING:

        def request(
            self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions]
        ) -> _RequestContextManager: ...

        def get(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def options(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def head(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def post(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def put(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def patch(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def delete(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

    else:

        def request(
            self, method: str, path: StrOrURL, **kwargs: Any
        ) -> _RequestContextManager:
            """Routes a request to tested http server.

            The interface is identical to aiohttp.ClientSession.request,
            except that *path* is resolved against the test server via
            make_url() and the response is recorded so close() can
            release it.

            """
            return _RequestContextManager(self._request(method, path, **kwargs))

        def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP GET request."""
            return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

        def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP POST request."""
            return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

        def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP OPTIONS request."""
            return _RequestContextManager(
                self._request(hdrs.METH_OPTIONS, path, **kwargs)
            )

        def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP HEAD request."""
            return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

        def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP PUT request."""
            return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

        def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP PATCH request."""
            return _RequestContextManager(
                self._request(hdrs.METH_PATCH, path, **kwargs)
            )

        def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP DELETE request."""
            return _RequestContextManager(
                self._request(hdrs.METH_DELETE, path, **kwargs)
            )

    @overload
    def ws_connect(
        self, path: StrOrURL, *, decode_text: Literal[True] = ..., **kwargs: Any
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ...

    @overload
    def ws_connect(
        self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ...

    @overload
    def ws_connect(
        self, path: StrOrURL, *, decode_text: bool = ..., **kwargs: Any
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ...

    def ws_connect(
        self, path: StrOrURL, *, decode_text: bool = True, **kwargs: Any
    ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]":
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect.

        """
        return _WSRequestContextManager(
            self._ws_connect(path, decode_text=decode_text, **kwargs)
        )

    @overload
    async def _ws_connect(
        self, path: StrOrURL, *, decode_text: Literal[True] = ..., **kwargs: Any
    ) -> "ClientWebSocketResponse[Literal[True]]": ...

    @overload
    async def _ws_connect(
        self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any
    ) -> "ClientWebSocketResponse[Literal[False]]": ...

    @overload
    async def _ws_connect(
        self, path: StrOrURL, *, decode_text: bool = ..., **kwargs: Any
    ) -> "ClientWebSocketResponse[bool]": ...

    async def _ws_connect(
        self, path: StrOrURL, *, decode_text: bool = True, **kwargs: Any
    ) -> "ClientWebSocketResponse[bool]":
        ws = await self._session.ws_connect(
            self.make_url(path), decode_text=decode_text, **kwargs
        )
        # save it to close later
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        After that point, the TestClient is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        Responses, websockets, the session and finally the wrapped server
        are closed, in that order.

        close is also run on exit when used as an asynchronous
        context manager.

        """
        if not self._closed:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
            await self._session.close()
            await self._server.close()
            self._closed = True

    async def __aenter__(self) -> Self:
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        await self.close()

491 

492 

class AioHTTPTestCase(IsolatedAsyncioTestCase, ABC):
    """Base class for unittest-style tests of aiohttp web applications.

    Before each test, the following attributes are set up:

    * self.app (aiohttp.web.Application): the application returned by
      self.get_application()
    * self.server (aiohttp.test_utils.TestServer): built by self.get_server()
    * self.client (aiohttp.test_utils.TestClient): an aiohttp test client,
      built by self.get_client() and already started.

    The TestClient's methods are coroutines, so tests must await them.
    """

    @abstractmethod
    async def get_application(self) -> Application:
        """Return the aiohttp.web.Application under test.

        Subclasses must override this.
        """

    async def asyncSetUp(self) -> None:
        # Build app -> server -> client, publishing each one on self before
        # the next factory runs.
        application = await self.get_application()
        self.app = application
        test_server = await self.get_server(application)
        self.server = test_server
        test_client = await self.get_client(test_server)
        self.client = test_client

        await test_client.start_server()

    async def asyncTearDown(self) -> None:
        # Closing the client also closes the server it wraps.
        await self.client.close()

    async def get_server(self, app: Application) -> TestServer:
        """Build the TestServer wrapping *app*; override to customise it."""
        return TestServer(app)

    async def get_client(self, server: TestServer) -> TestClient[Request, Application]:
        """Build the TestClient talking to *server*; override to customise it."""
        return TestClient(server)

531 

532 

def _create_app_mock() -> mock.MagicMock:
    """Return a MagicMock (spec'd on Application) usable as a fake app.

    Item get/set is backed by a plain dict stored on the mock itself, and
    ``on_response_prepare`` is a real Signal that has already been frozen.
    """
    fake_app = mock.MagicMock(spec=Application)
    # Module-level function, so the double underscore is not name-mangled.
    fake_app.__app_dict = {}

    def _lookup(owner: Any, key: str) -> Any:
        return owner.__app_dict[key]

    def _store(owner: Any, key: str, value: Any) -> None:
        owner.__app_dict[key] = value

    # Functions assigned as magic methods receive the mock as first argument.
    fake_app.__getitem__ = _lookup
    fake_app.__setitem__ = _store

    prepare_signal = Signal(fake_app)
    fake_app.on_response_prepare = prepare_signal
    prepare_signal.freeze()
    return fake_app

548 

549 

550def _create_transport(sslcontext: SSLContext | None = None) -> mock.Mock: 

551 transport = mock.Mock() 

552 

553 def get_extra_info(key: str) -> SSLContext | tuple[str, int] | None: 

554 if key == "sslcontext": 

555 return sslcontext 

556 return ("127.0.0.1", 80) if key == "sockname" else None 

557 

558 transport.get_extra_info.side_effect = get_extra_info 

559 return transport 

560 

561 

def make_mocked_request(
    method: str,
    path: str,
    headers: LooseHeaders | None = None,
    *,
    match_info: dict[str, str] | None = None,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Application | None = None,
    writer: AbstractStreamWriter | None = None,
    protocol: RequestHandler[Request] | None = None,
    transport: asyncio.Transport | None = None,
    payload: StreamReader = EMPTY_PAYLOAD,
    sslcontext: SSLContext | None = None,
    client_max_size: int = 1024**2,
    loop: Any = ...,
) -> Request:
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
    specific conditions and errors are hard to trigger.

    Any of *app*, *transport*, *protocol* and *writer* left as ``None`` is
    replaced by a mock. Versions below HTTP/1.1 force ``closing=True``.
    The (possibly mocked) *transport* is always assigned to
    ``protocol.transport``, including for a caller-supplied protocol.
    *match_info* (default: empty) becomes the request's match info, bound
    to *app*.

    :returns: a web.Request built from the given pieces.
    """
    task = mock.Mock()
    if loop is ...:
        # No loop passed: use the running one if there is one, because a
        # real loop is needed to create executor jobs when testing with a
        # real executor. Otherwise fall back to a mock loop.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = mock.Mock()
            # NOTE(review): placeholder result for loop.create_future();
            # the reason for an empty tuple is not visible here — confirm.
            loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        closing = True

    if headers:
        headers = HeadersDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
            (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
        )
    else:
        headers = HeadersDictProxy(CIMultiDict())
        raw_hdrs = ()

    chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()

    message = RawRequestMessage(
        method,
        path,
        version,
        headers,
        raw_hdrs,
        closing,
        None,
        False,
        chunked,
        URL(path),
    )
    if app is None:
        app = _create_app_mock()

    if transport is None:
        transport = _create_transport(sslcontext)

    if protocol is None:
        protocol = mock.Mock()
        protocol.max_field_size = 8190
        protocol.max_line_length = 8190
        protocol.max_headers = 128
        # PropertyMock has to be attached to the mock's type; each Mock
        # instance gets its own subclass, so this does not leak to others.
        type(protocol).peername = mock.PropertyMock(
            return_value=transport.get_extra_info("peername")
        )
        type(protocol).sockname = mock.PropertyMock(
            return_value=transport.get_extra_info("sockname")
        )
        type(protocol).ssl_context = mock.PropertyMock(return_value=sslcontext)

    if writer is None:
        writer = mock.Mock()
        writer.write_headers = mock.AsyncMock(return_value=None)
        writer.write = mock.AsyncMock(return_value=None)
        writer.write_eof = mock.AsyncMock(return_value=None)
        writer.drain = mock.AsyncMock(return_value=None)
        writer.transport = transport

    # Applies to mocked and caller-supplied protocols alike.
    protocol.transport = transport

    req = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    url_match_info = UrlMappingMatchInfo(
        {} if match_info is None else match_info, mock.Mock()
    )
    url_match_info.add_app(app)
    req._match_info = url_match_info

    return req