Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/test_utils.py: 43%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

311 statements  

1"""Utilities shared by tests.""" 

2 

3import asyncio 

4import contextlib 

5import gc 

6import ipaddress 

7import os 

8import socket 

9import sys 

10from abc import ABC, abstractmethod 

11from types import TracebackType 

12from typing import ( 

13 TYPE_CHECKING, 

14 Any, 

15 Callable, 

16 Dict, 

17 Generic, 

18 Iterator, 

19 List, 

20 Optional, 

21 Type, 

22 TypeVar, 

23 Union, 

24 cast, 

25 overload, 

26) 

27from unittest import IsolatedAsyncioTestCase, mock 

28 

29from aiosignal import Signal 

30from multidict import CIMultiDict, CIMultiDictProxy 

31from yarl import URL 

32 

33import aiohttp 

34from aiohttp.client import ( 

35 _RequestContextManager, 

36 _RequestOptions, 

37 _WSRequestContextManager, 

38) 

39 

40from . import ClientSession, hdrs 

41from .abc import AbstractCookieJar, AbstractStreamWriter 

42from .client_reqrep import ClientResponse 

43from .client_ws import ClientWebSocketResponse 

44from .http import HttpVersion, RawRequestMessage 

45from .streams import EMPTY_PAYLOAD, StreamReader 

46from .typedefs import LooseHeaders, StrOrURL 

47from .web import ( 

48 Application, 

49 AppRunner, 

50 BaseRequest, 

51 BaseRunner, 

52 Request, 

53 RequestHandler, 

54 Server, 

55 ServerRunner, 

56 SockSite, 

57 UrlMappingMatchInfo, 

58) 

59from .web_protocol import _RequestHandler 

60 

61if TYPE_CHECKING: 

62 from ssl import SSLContext 

63else: 

64 SSLContext = None 

65 

66if sys.version_info >= (3, 11) and TYPE_CHECKING: 

67 from typing import Unpack 

68 

69if sys.version_info >= (3, 11): 

70 from typing import Self 

71else: 

72 Self = Any 

73 

# TestClient's second type parameter: Application when it wraps a
# TestServer, None for any other BaseTestServer.
_ApplicationNone = TypeVar("_ApplicationNone", Application, None)
# Request type handled by a test server and the client wrapping it.
_Request = TypeVar("_Request", bound=BaseRequest)

# Whether test sockets should set SO_REUSEADDR: only on POSIX systems that
# are not Cygwin.
REUSE_ADDRESS = sys.platform != "cygwin" and os.name == "posix"

78 

79 

def get_unused_port_socket(
    host: str, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Return a socket bound to *host* on a port chosen by the OS.

    Convenience wrapper around get_port_socket() that requests port 0,
    letting the operating system pick a free port.
    """
    os_chosen_port = 0
    return get_port_socket(host, os_chosen_port, family)

84 

85 

def get_port_socket(
    host: str, port: int, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
    """Create a TCP socket bound to ``(host, port)`` and return it.

    The socket is bound but not put into listening mode; ownership passes
    to the caller, who is responsible for closing it.  If configuring or
    binding the socket fails, the socket is closed before the exception
    propagates so its file descriptor is not leaked.

    :param host: address to bind to.
    :param port: port to bind to; ``0`` lets the OS pick a free port.
    :param family: address family, ``AF_INET`` by default.
    :raises OSError: if the socket cannot be configured or bound.
    """
    s = socket.socket(family, socket.SOCK_STREAM)
    try:
        if REUSE_ADDRESS:
            # Windows has different semantics for SO_REUSEADDR,
            # so don't set it. Ref:
            # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
    except BaseException:
        # Don't leak the descriptor when e.g. the address is already in use.
        s.close()
        raise
    return s

97 

98 

def unused_port() -> int:
    """Return a TCP port number that was free on 127.0.0.1 when probed.

    A throw-away socket is bound to port 0 so the OS assigns a free port,
    and the socket is closed before returning.  Another process may
    therefore claim the port before the caller binds to it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("127.0.0.1", 0))
        _address, free_port = probe.getsockname()
    return cast(int, free_port)

104 

105 

class BaseTestServer(ABC, Generic[_Request]):
    """Base class for servers used in tests.

    Subclasses implement _make_runner() to decide what is served.
    start_server() binds a socket and records the root URL, which
    make_url() uses to build request URLs.
    """

    __test__ = False  # keep pytest from collecting this class as a test

    def __init__(
        self,
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        skip_url_asserts: bool = False,
        socket_factory: Callable[
            [str, int, socket.AddressFamily], socket.socket
        ] = get_port_socket,
        **kwargs: Any,
    ) -> None:
        # NOTE(review): extra **kwargs are accepted but not used here.
        self.runner: Optional[BaseRunner[_Request]] = None
        self._root: Optional[URL] = None
        self.host = host
        # 0 asks the socket factory / OS for any free port.
        self.port = port or 0
        self._closed = False
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts
        self.socket_factory = socket_factory
        # SSL context passed to start_server(); None until the server starts.
        self._ssl: Optional[SSLContext] = None

    async def start_server(self, **kwargs: Any) -> None:
        """Create the runner, bind a socket and start serving.

        Does nothing if a runner already exists.  An ``ssl`` keyword
        argument is removed from *kwargs* and used as the site's SSL
        context.  The remaining keyword arguments are forwarded to
        _make_runner() together with ``handler_cancellation=True``.
        """
        if self.runner:
            return
        self._ssl = kwargs.pop("ssl", None)
        self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
        await self.runner.setup()
        absolute_host = self.host
        try:
            version = ipaddress.ip_address(self.host).version
        except ValueError:
            # Not an IP literal (e.g. a hostname): treat it as IPv4.
            version = 4
        if version == 6:
            # IPv6 literals must be bracketed inside a URL authority.
            absolute_host = f"[{self.host}]"
        family = socket.AF_INET6 if version == 6 else socket.AF_INET
        _sock = self.socket_factory(self.host, self.port, family)
        self.host, self.port = _sock.getsockname()[:2]
        site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
        await site.start()
        server = site._server
        assert server is not None
        sockets = server.sockets
        assert sockets is not None
        # Use the port reported by the socket that is actually serving.
        self.port = sockets[0].getsockname()[1]
        if not self.scheme:
            self.scheme = "https" if self._ssl else "http"
        self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")

    @abstractmethod
    async def _make_runner(self, **kwargs: Any) -> BaseRunner[_Request]:  # type: ignore[misc]
        """Return a new runner for the server."""
        # TODO(PY311): Use Unpack to specify Server kwargs.

    def make_url(self, path: StrOrURL) -> URL:
        """Build an absolute URL for *path* on this server.

        The server must have been started.  Unless skip_url_asserts is set,
        *path* must be relative and is joined onto the root URL.  Otherwise
        its text is appended to the root URL as-is.
        """
        assert self._root is not None
        url = URL(path)
        if not self.skip_url_asserts:
            assert not url.absolute
            return self._root.join(url)
        else:
            return URL(str(self._root) + str(path))

    @property
    def started(self) -> bool:
        """True once start_server() has created a runner."""
        return self.runner is not None

    @property
    def closed(self) -> bool:
        """True once close() has been called."""
        return self._closed

    @property
    def handler(self) -> Server[_Request]:
        """The web.Server instance of the runner (kept for backward compatibility)."""
        runner = self.runner
        assert runner is not None
        assert runner.server is not None
        return runner.server

    async def close(self) -> None:
        """Stop the test server and release its resources.

        After that point, the server is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when the server is used as an
        asynchronous context manager.
        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = 0
            self._closed = True

    async def __aenter__(self) -> Self:
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.close()

218 

219 

class TestServer(BaseTestServer[Request]):
    """Test server that serves an aiohttp web ``Application``."""

    def __init__(
        self,
        app: Application,
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ):
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)
        # The application served once start_server() runs.
        self.app = app

    async def _make_runner(self, **kwargs: Any) -> AppRunner:
        """Return an AppRunner for ``self.app`` configured with *kwargs*."""
        # TODO(PY311): Use Unpack to specify Server kwargs.
        app_runner = AppRunner(self.app, **kwargs)
        return app_runner

236 

237 

class RawTestServer(BaseTestServer[BaseRequest]):
    """Test server that hands every request to a single low-level handler."""

    def __init__(
        self,
        handler: _RequestHandler[BaseRequest],
        *,
        scheme: str = "",
        host: str = "127.0.0.1",
        port: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)
        # Request handler wrapped in a web.Server by _make_runner().
        self._handler = handler

    async def _make_runner(self, **kwargs: Any) -> ServerRunner:
        """Wrap the handler in a Server and return a ServerRunner for it."""
        # TODO(PY311): Use Unpack to specify Server kwargs.
        # NOTE(review): the same kwargs go to both Server and ServerRunner;
        # confirm each of them accepts every option passed.
        return ServerRunner(Server(self._handler, **kwargs), **kwargs)

255 

256 

class TestClient(Generic[_Request, _ApplicationNone]):
    """A test client for functional tests of aiohttp-based servers.

    Wraps a test server together with a ClientSession.  Request helpers
    take paths relative to the server root (see make_url).  Every response
    and websocket created through the client is recorded so that close()
    can release it.
    """

    __test__ = False  # keep pytest from collecting this class as a test

    @overload
    def __init__(  # type: ignore[misc]
        self: "TestClient[Request, Application]",
        server: TestServer,
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        **kwargs: Any,
    ) -> None: ...
    @overload
    def __init__(  # type: ignore[misc]
        self: "TestClient[_Request, None]",
        server: BaseTestServer[_Request],
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        **kwargs: Any,
    ) -> None: ...
    def __init__(  # type: ignore[misc]
        self,
        server: BaseTestServer[_Request],
        *,
        cookie_jar: Optional[AbstractCookieJar] = None,
        **kwargs: Any,
    ) -> None:
        """Wrap *server* and create the client session.

        :param server: the server under test.
        :param cookie_jar: cookie jar for the session; defaults to
            ``aiohttp.CookieJar(unsafe=True)``.
        :param kwargs: forwarded to ClientSession.
        :raises TypeError: if *server* is not a BaseTestServer.
        """
        # TODO(PY311): Use Unpack to specify ClientSession kwargs.
        if not isinstance(server, BaseTestServer):
            raise TypeError(
                "server must be TestServer instance, found type: %r" % type(server)
            )
        self._server = server
        if cookie_jar is None:
            # unsafe=True so cookies from IP-address hosts (the test server
            # listens on 127.0.0.1 by default) are accepted.
            cookie_jar = aiohttp.CookieJar(unsafe=True)
        self._session = ClientSession(cookie_jar=cookie_jar, **kwargs)
        # Presumably disables the session's automatic connection retry so
        # tests observe connection failures directly — confirm.
        self._session._retry_connection = False
        self._closed = False
        # Objects created through this client, released by close().
        self._responses: List[ClientResponse] = []
        self._websockets: List[ClientWebSocketResponse] = []

    async def start_server(self) -> None:
        """Start the wrapped server."""
        await self._server.start_server()

    @property
    def scheme(self) -> Union[str, object]:
        """URL scheme of the wrapped server."""
        return self._server.scheme

    @property
    def host(self) -> str:
        """Host of the wrapped server."""
        return self._server.host

    @property
    def port(self) -> int:
        """Port of the wrapped server."""
        return self._server.port

    @property
    def server(self) -> BaseTestServer[_Request]:
        """The wrapped test server."""
        return self._server

    @property
    def app(self) -> _ApplicationNone:
        """The server's ``app`` attribute, or None if it has none."""
        return getattr(self._server, "app", None)  # type: ignore[return-value]

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.

        """
        return self._session

    def make_url(self, path: StrOrURL) -> URL:
        """Return the absolute URL of *path* on the wrapped server."""
        return self._server.make_url(path)

    async def _request(
        self, method: str, path: StrOrURL, **kwargs: Any
    ) -> ClientResponse:
        """Send *method* to make_url(*path*) and track the response."""
        resp = await self._session.request(method, self.make_url(path), **kwargs)
        # save it to close later
        self._responses.append(resp)
        return resp

    if sys.version_info >= (3, 11) and TYPE_CHECKING:

        def request(
            self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions]
        ) -> _RequestContextManager: ...

        def get(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def options(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def head(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def post(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def put(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def patch(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

        def delete(
            self,
            path: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> _RequestContextManager: ...

    else:

        def request(
            self, method: str, path: StrOrURL, **kwargs: Any
        ) -> _RequestContextManager:
            """Routes a request to tested http server.

            The interface is identical to aiohttp.ClientSession.request,
            except the loop kwarg is overridden by the instance used by the
            test server.

            """
            return _RequestContextManager(self._request(method, path, **kwargs))

        def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP GET request."""
            return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

        def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP POST request."""
            return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

        def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP OPTIONS request."""
            return _RequestContextManager(
                self._request(hdrs.METH_OPTIONS, path, **kwargs)
            )

        def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP HEAD request."""
            return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

        def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP PUT request."""
            return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

        def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP PATCH request."""
            return _RequestContextManager(
                self._request(hdrs.METH_PATCH, path, **kwargs)
            )

        def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
            """Perform an HTTP DELETE request."""
            return _RequestContextManager(
                self._request(hdrs.METH_DELETE, path, **kwargs)
            )

    def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect.

        """
        return _WSRequestContextManager(self._ws_connect(path, **kwargs))

    async def _ws_connect(
        self, path: StrOrURL, **kwargs: Any
    ) -> ClientWebSocketResponse:
        """Open a websocket to make_url(*path*) and track it for close()."""
        ws = await self._session.ws_connect(self.make_url(path), **kwargs)
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        Closes every tracked response and websocket, then the client
        session, then the wrapped server.  After that point, the
        TestClient is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when the client is used as an
        asynchronous context manager.
        """
        if not self._closed:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
            await self._session.close()
            await self._server.close()
            self._closed = True

    async def __aenter__(self) -> Self:
        await self.start_server()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        await self.close()

492 

493 

class AioHTTPTestCase(IsolatedAsyncioTestCase, ABC):
    """Base class for unittest-style tests of aiohttp web applications.

    For every test, asyncSetUp() makes the following available:

    * self.app (aiohttp.web.Application): the object returned by
      self.get_application()
    * self.server (aiohttp.test_utils.TestServer): built by get_server()
    * self.client (aiohttp.test_utils.TestClient): built by get_client()
      and already started

    asyncTearDown() closes the client.  The client's request helpers are
    asynchronous, so test methods must be ``async def`` and await them.
    """

    @abstractmethod
    async def get_application(self) -> Application:
        """Return the aiohttp.web.Application under test.

        Subclasses must override this.
        """

    async def asyncSetUp(self) -> None:
        """Build the application, its server and client, then start serving."""
        self.app = await self.get_application()
        self.server = await self.get_server(self.app)
        test_client = await self.get_client(self.server)
        self.client = test_client

        await test_client.start_server()

    async def asyncTearDown(self) -> None:
        """Close the client, which also closes the server."""
        await self.client.close()

    async def get_server(self, app: Application) -> TestServer:
        """Return the TestServer used for *app*; override to customize."""
        server = TestServer(app)
        return server

    async def get_client(self, server: TestServer) -> TestClient[Request, Application]:
        """Return the TestClient used for *server*; override to customize."""
        client = TestClient(server)
        return client

532 

533 

_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]


@contextlib.contextmanager
def loop_context(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
) -> Iterator[asyncio.AbstractEventLoop]:
    """A contextmanager that creates an event_loop, for test purposes.

    Handles the creation and cleanup of a test loop.  Teardown runs even if
    the body of the ``with`` statement raises, so a failing test does not
    leave an open loop installed as the current event loop.

    :param loop_factory: zero-argument callable returning a new event loop.
    :param fast: passed to teardown_test_loop(); when true the final
        garbage-collection pass is skipped.
    """
    loop = setup_test_loop(loop_factory)
    try:
        yield loop
    finally:
        teardown_test_loop(loop, fast=fast)


def setup_test_loop(
    loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
) -> asyncio.AbstractEventLoop:
    """Create an event loop, install it as the current loop and return it.

    The caller should also call teardown_test_loop,
    once they are done with the loop.
    """
    loop = loop_factory()
    asyncio.set_event_loop(loop)
    return loop


def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
    """Teardown and cleanup an event_loop created by setup_test_loop.

    If the loop is still open it is run once more, then closed.  Unless
    *fast* is true, a garbage collection is forced.  Finally the current
    event loop is unset.
    """
    if not loop.is_closed():
        # Queue stop() behind any already-scheduled callbacks, so
        # run_forever() processes them before returning.
        loop.call_soon(loop.stop)
        loop.run_forever()
        loop.close()

    if not fast:
        gc.collect()

    asyncio.set_event_loop(None)

575 

576 

def _create_app_mock() -> mock.MagicMock:
    """Return a MagicMock standing in for an Application.

    Item access (``app[key]`` / ``app[key] = value``) is backed by a plain
    dict stored on the mock.  ``on_response_prepare`` is a frozen Signal.
    """
    app = mock.MagicMock(spec=Application)
    app.__app_dict = {}

    def _lookup(owner: Any, key: str) -> Any:
        return owner.__app_dict[key]

    def _store(owner: Any, key: str, value: Any) -> None:
        owner.__app_dict[key] = value

    app.__getitem__ = _lookup
    app.__setitem__ = _store

    prepare_signal = Signal(app)
    app.on_response_prepare = prepare_signal
    prepare_signal.freeze()
    return app

592 

593 

594def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock: 

595 transport = mock.Mock() 

596 

597 def get_extra_info(key: str) -> Optional[SSLContext]: 

598 if key == "sslcontext": 

599 return sslcontext 

600 else: 

601 return None 

602 

603 transport.get_extra_info.side_effect = get_extra_info 

604 return transport 

605 

606 

def make_mocked_request(
    method: str,
    path: str,
    headers: Optional[LooseHeaders] = None,
    *,
    match_info: Optional[Dict[str, str]] = None,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Optional[Application] = None,
    writer: Optional[AbstractStreamWriter] = None,
    protocol: Optional[RequestHandler[Request]] = None,
    transport: Optional[asyncio.Transport] = None,
    payload: StreamReader = EMPTY_PAYLOAD,
    sslcontext: Optional[SSLContext] = None,
    client_max_size: int = 1024**2,
    loop: Any = ...,
) -> Request:
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests when spinning up a full web server is overkill,
    or when specific conditions and errors are hard to trigger.  Any of
    app, writer, protocol and transport left as None is replaced by a
    mock.  Requests with an HTTP version below 1.1 are always marked as
    closing.
    """
    task = mock.Mock()
    if loop is ...:
        # No loop given: prefer the running loop, since a real loop is
        # needed to create executor jobs when testing with a real executor.
        # Outside a running loop, fall back to a mock.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = mock.Mock()
            loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        closing = True

    if not headers:
        header_view = CIMultiDictProxy(CIMultiDict())
        raw_hdrs = ()
    else:
        header_view = CIMultiDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
            (name.encode("utf-8"), value.encode("utf-8"))
            for name, value in header_view.items()
        )

    transfer_encoding = header_view.get(hdrs.TRANSFER_ENCODING, "")
    chunked = "chunked" in transfer_encoding.lower()

    message = RawRequestMessage(
        method,
        path,
        version,
        header_view,
        raw_hdrs,
        closing,
        None,
        False,
        chunked,
        URL(path),
    )

    if app is None:
        app = _create_app_mock()

    if transport is None:
        transport = _create_transport(sslcontext)

    if protocol is None:
        protocol = mock.Mock()
        protocol.transport = transport
        protocol_type = type(protocol)
        protocol_type.peername = mock.PropertyMock(
            return_value=transport.get_extra_info("peername")
        )
        protocol_type.ssl_context = mock.PropertyMock(return_value=sslcontext)

    if writer is None:
        writer = mock.Mock()
        # Each writer coroutine is stubbed to complete immediately with None.
        for coro_name in ("write_headers", "write", "write_eof", "drain"):
            setattr(writer, coro_name, mock.AsyncMock(return_value=None))
        writer.transport = transport

    # Applies to caller-supplied protocols too.
    protocol.transport = transport

    request = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    route_info = UrlMappingMatchInfo(
        match_info if match_info is not None else {}, mock.Mock()
    )
    route_info.add_app(app)
    request._match_info = route_info

    return request