Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/testclient.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

312 statements  

1from __future__ import annotations 

2 

3import contextlib 

4import inspect 

5import io 

6import json 

7import math 

8import sys 

9import warnings 

10from collections.abc import Awaitable, Generator, Iterable, Mapping, MutableMapping, Sequence 

11from concurrent.futures import Future 

12from contextlib import AbstractContextManager 

13from types import GeneratorType 

14from typing import ( 

15 Any, 

16 Callable, 

17 Literal, 

18 TypedDict, 

19 Union, 

20 cast, 

21) 

22from urllib.parse import unquote, urljoin 

23 

24import anyio 

25import anyio.abc 

26import anyio.from_thread 

27from anyio.streams.stapled import StapledObjectStream 

28 

29from starlette._utils import is_async_callable 

30from starlette.types import ASGIApp, Message, Receive, Scope, Send 

31from starlette.websockets import WebSocketDisconnect 

32 

33if sys.version_info >= (3, 10): # pragma: no cover 

34 from typing import TypeGuard 

35else: # pragma: no cover 

36 from typing_extensions import TypeGuard 

37 

38try: 

39 import httpx 

40except ModuleNotFoundError: # pragma: no cover 

41 raise RuntimeError( 

42 "The starlette.testclient module requires the httpx package to be installed.\n" 

43 "You can install this with:\n" 

44 " $ pip install httpx\n" 

45 ) 

# Zero-argument factory returning a context manager that yields an anyio
# blocking portal; used to run the ASGI application from synchronous code.
_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]

# ASGI 2 style: the app is called with the scope and returns an instance that
# is then awaited with (receive, send).
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
ASGI2App = Callable[[Scope], ASGIInstance]
# ASGI 3 style: the app is awaited directly with (scope, receive, send).
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]


# Type accepted for the `data` argument of the request helpers: each value is a
# string, an iterable of strings, or bytes.
_RequestData = Mapping[str, Union[str, Iterable[str], bytes]]

54 

55 

56def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]: 

57 if inspect.isclass(app): 

58 return hasattr(app, "__await__") 

59 return is_async_callable(app) 

60 

61 

62class _WrapASGI2: 

63 """ 

64 Provide an ASGI3 interface onto an ASGI2 app. 

65 """ 

66 

67 def __init__(self, app: ASGI2App) -> None: 

68 self.app = app 

69 

70 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 

71 instance = self.app(scope) 

72 await instance(receive, send) 

73 

74 

class _AsyncBackend(TypedDict):
    """
    Keyword arguments describing the anyio backend used by `TestClient`.

    `TestClient` unpacks this mapping into
    `anyio.from_thread.start_blocking_portal(**...)`.
    """

    # Backend name; `TestClient.__init__` accepts "asyncio" or "trio".
    backend: str
    # Extra options passed through to the backend (empty dict when none given).
    backend_options: dict[str, Any]

78 

79 

80class _Upgrade(Exception): 

81 def __init__(self, session: WebSocketTestSession) -> None: 

82 self.session = session 

83 

84 

class WebSocketDenialResponse(  # type: ignore[misc]
    httpx.Response,
    WebSocketDisconnect,
):
    """
    A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
    `WebSocket` is closed before being accepted with a `send_denial_response()`.

    Because it is also an `httpx.Response`, the status code, headers and body
    sent by the application in its denial response can be inspected on the
    caught exception.
    """

93 

94 

class WebSocketTestSession:
    """
    Synchronous, test-side handle on one WebSocket connection to an ASGI app.

    The application runs as a task inside an anyio blocking portal and the
    `send*` / `receive*` methods exchange ASGI messages with it through a pair
    of in-memory object streams. Use the session as a context manager:
    entering performs the connect handshake, exiting sends a disconnect and
    tears the task down.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        # Both are filled in from the application's reply during `__enter__`.
        self.accepted_subprotocol = None
        self.portal_factory = portal_factory
        self.extra_headers = None

    def __enter__(self) -> WebSocketTestSession:
        """
        Start the application task and perform the connect handshake.

        Raises:
            WebSocketDisconnect: the application closed instead of accepting.
            WebSocketDenialResponse: the application sent an HTTP denial response.
        """
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(self.portal_factory())
            fut, cs = portal.start_task(self._run)
            # ExitStack callbacks run in reverse order of registration: the
            # disconnect message (registered below) is sent first, then the
            # task's cancel scope is cancelled, then the task result is
            # collected so an application error surfaces to the caller.
            stack.callback(fut.result)
            stack.callback(portal.call, cs.cancel)
            self.send({"type": "websocket.connect"})
            message = self.receive()
            self._raise_on_close(message)
            self.accepted_subprotocol = message.get("subprotocol", None)
            self.extra_headers = message.get("headers", None)
            stack.callback(self.close, 1000)
            # Hand the registered callbacks over to `__exit__`. If anything
            # above raised, this line is not reached and the `with` block
            # unwinds them immediately.
            self.exit_stack = stack.pop_all()
            return self

    def __exit__(self, *args: Any) -> bool | None:
        """Run the cleanup callbacks registered by `__enter__`."""
        return self.exit_stack.__exit__(*args)

    async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
        """
        Run the application for this session inside the portal.

        Creates the two in-memory streams, publishes the test-side ends on
        `self`, reports the task's cancel scope through `task_status`, and
        then awaits the application.
        """
        send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        send_tx, send_rx = send
        receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        receive_tx, receive_rx = receive
        with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
            # The test side writes to `_receive_tx` (what the app receives)
            # and reads from `_send_rx` (what the app sends).
            self._receive_tx = receive_tx
            self._send_rx = send_rx
            task_status.started(cs)
            await self.app(self.scope, receive_rx.receive, send_tx.send)

            # wait for cs.cancel to be called before closing streams
            await anyio.sleep_forever()

    def _raise_on_close(self, message: Message) -> None:
        """
        Raise if `message` shows the application closed or denied the connection.

        Messages of any other type are ignored.

        Raises:
            WebSocketDisconnect: for a `websocket.close` message.
            WebSocketDenialResponse: for a `websocket.http.response.start`
                message; the body messages that follow are drained first.
        """
        if message["type"] == "websocket.close":
            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
        elif message["type"] == "websocket.http.response.start":
            status_code: int = message["status"]
            # `headers` and `body` are optional in the denial-response
            # messages, so fall back to their documented defaults rather than
            # failing with KeyError.
            headers: list[tuple[bytes, bytes]] = message.get("headers", [])
            body: list[bytes] = []
            while True:
                message = self.receive()
                assert message["type"] == "websocket.http.response.body"
                body.append(message.get("body", b""))
                if not message.get("more_body", False):
                    break
            raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))

    def send(self, message: Message) -> None:
        """Deliver a raw ASGI message to the application."""
        self.portal.call(self._receive_tx.send, message)

    def send_text(self, data: str) -> None:
        """Send `data` to the application as a text frame."""
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        """Send `data` to the application as a binary frame."""
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
        """
        Serialise `data` to compact JSON and send it.

        With `mode="text"` the JSON is sent as a text frame; any other value
        sends it UTF-8 encoded as a binary frame.
        """
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            self.send({"type": "websocket.receive", "text": text})
        else:
            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Tell the application that the client disconnected."""
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """Block until the application sends a message and return it unmodified."""
        return self.portal.call(self._send_rx.receive)

    def receive_text(self) -> str:
        """Return the `text` payload of the next message, raising if it is a close/denial."""
        message = self.receive()
        self._raise_on_close(message)
        return cast(str, message["text"])

    def receive_bytes(self) -> bytes:
        """Return the `bytes` payload of the next message, raising if it is a close/denial."""
        message = self.receive()
        self._raise_on_close(message)
        return cast(bytes, message["bytes"])

    def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
        """
        Decode the next message as JSON, raising if it is a close/denial.

        With `mode="text"` the `text` payload is parsed; any other value
        parses the UTF-8 decoded `bytes` payload.
        """
        message = self.receive()
        self._raise_on_close(message)
        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

198 

199 

class _TestClientTransport(httpx.BaseTransport):
    """
    httpx transport that dispatches each request directly to an ASGI app.

    HTTP requests are run to completion inside an anyio blocking portal and
    turned into an `httpx.Response`. Requests with a `ws`/`wss` scheme are not
    executed; a `WebSocketTestSession` is created and handed back to the
    caller by raising `_Upgrade`.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        client: tuple[str, int],
        app_state: dict[str, Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        # A shallow copy of this mapping is placed in every scope as "state".
        self.app_state = app_state
        self.client = client

    @staticmethod
    def _split_netloc(netloc: str, default_port: int) -> tuple[str, int]:
        """
        Split `netloc` into `(host, port)`, using `default_port` when no port is present.

        The port separator is the last colon, and only when it lies outside
        the brackets of an IPv6 literal, so `[::1]` and `[::1]:8000` are
        handled (the returned host keeps its brackets).
        """
        host, separator, port_string = netloc.rpartition(":")
        if separator and "]" not in port_string:
            return host, int(port_string)
        return netloc, default_port

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """
        Run `request` against the ASGI application and build the response.

        Raises:
            _Upgrade: for `ws`/`wss` URLs; carries the new WebSocket session.
            BaseException: whatever the application raised, when
                `raise_server_exceptions` is true.
            AssertionError: when `raise_server_exceptions` is true and the
                application finished without starting a response.
        """
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
        host, port = self._split_netloc(netloc, default_port)

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        elif port == default_port:  # pragma: no cover
            headers = [(b"host", host.encode())]
        else:  # pragma: no cover
            headers = [(b"host", (f"{host}:{port}").encode())]

        # Include other request headers.
        headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]

        scope: dict[str, Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path.split(b"?", 1)[0],
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": self.client,
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            # Not an error: this is how the session is returned to
            # `TestClient.websocket_connect` through the httpx call stack.
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path.split(b"?", 1)[0],
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        # Created inside the portal below, before the application is called.
        response_complete: anyio.Event
        raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            nonlocal request_complete

            if request_complete:
                # The whole body was already delivered: report a disconnect,
                # but only once the response has been completed.
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert not response_started, 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
                response_started = True
            elif message["type"] == "http.response.body":
                assert response_started, 'Received "http.response.body" without "http.response.start".'
                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                # Body bytes sent in reply to a HEAD request are discarded.
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            # Deliberate best-effort: with raise_server_exceptions disabled the
            # error is dropped and a response is still produced below.
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response

367 

368 

class TestClient(httpx.Client):
    """
    Synchronous `httpx.Client` that sends its requests to an ASGI application.

    Used as a context manager, the client also runs the application's
    lifespan: startup is awaited on entry and shutdown on exit, and a single
    blocking portal is shared by every request made in between. Outside a
    `with` block each request opens its own short-lived portal.
    """

    # Stops pytest from collecting this class because of its `Test` prefix.
    __test__ = False
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
        client: tuple[str, int] = ("testclient", 50000),
    ) -> None:
        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        self.app_state: dict[str, Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
            client=client,
        )
        # Work on a copy so the default user-agent is never written into the
        # caller's own mapping.
        headers = {} if headers is None else dict(headers)
        headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
        """Yield the portal opened by `__enter__` if there is one, else a temporary portal."""
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
                yield portal

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """
        Send a request to the application.

        Passing `timeout` emits a `DeprecationWarning`; the value is still
        forwarded to httpx unchanged.
        """
        if timeout is not httpx.USE_CLIENT_DEFAULT:
            warnings.warn(
                "You should not use the 'timeout' argument with the TestClient. "
                "See https://github.com/encode/starlette/issues/1108 for more information.",
                DeprecationWarning,
            )
        url = self._merge_url(url)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a GET request; arguments are forwarded to `httpx.Client.get`."""
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send an OPTIONS request; arguments are forwarded to `httpx.Client.options`."""
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a HEAD request; arguments are forwarded to `httpx.Client.head`."""
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a POST request; arguments are forwarded to `httpx.Client.post`."""
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a PUT request; arguments are forwarded to `httpx.Client.put`."""
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a PATCH request; arguments are forwarded to `httpx.Client.patch`."""
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a DELETE request; arguments are forwarded to `httpx.Client.delete`."""
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> WebSocketTestSession:
        """
        Create a WebSocket session for `url`.

        `url` is resolved against `ws://testserver`. The handshake headers are
        added unless already supplied, and `subprotocols`, when given, are
        offered through `sec-websocket-protocol`. Remaining keyword arguments
        are passed to `httpx.Client.request`.

        Returns:
            The session extracted from the `_Upgrade` raised by the transport.
            It is not connected until entered as a context manager.
        """
        url = urljoin("ws://testserver", url)
        # Copy so the handshake defaults are not written into the caller's
        # mapping; `or {}` also accepts an explicit headers=None.
        headers = dict(kwargs.get("headers") or {})
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> TestClient:
        """Open the shared portal, start the lifespan task and wait for startup."""
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
                anyio.create_memory_object_stream(math.inf)
            )
            receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
                math.inf
            )
            for channel in (*send, *receive):
                stack.callback(channel.close)
            self.stream_send = StapledObjectStream(*send)
            self.stream_receive = StapledObjectStream(*receive)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            # Keep the callbacks for `__exit__`; if startup failed above this
            # is never reached and the `with` block unwinds them right away.
            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: Any) -> None:
        """Run shutdown and release everything acquired in `__enter__`."""
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the application's lifespan protocol; always signals completion with `None`."""
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            # Sentinel telling the waiting side that the lifespan task ended.
            await self.stream_send.send(None)

    async def _receive_lifespan_message(self) -> Any:
        """
        Return the next message sent by the lifespan task.

        A `None` message means the task has finished, so its result is
        collected first, re-raising any exception the application raised.
        """
        message = await self.stream_send.receive()
        if message is None:
            self.task.result()
        return message

    async def wait_startup(self) -> None:
        """Send `lifespan.startup` and wait for the application's answer."""
        await self.stream_receive.send({"type": "lifespan.startup"})

        message = await self._receive_lifespan_message()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            # Read once more so the failure raised by the task is surfaced.
            await self._receive_lifespan_message()

    async def wait_shutdown(self) -> None:
        """Send `lifespan.shutdown` and wait for the application's answer."""
        await self.stream_receive.send({"type": "lifespan.shutdown"})
        message = await self._receive_lifespan_message()
        assert message["type"] in (
            "lifespan.shutdown.complete",
            "lifespan.shutdown.failed",
        )
        if message["type"] == "lifespan.shutdown.failed":
            # Read once more so the failure raised by the task is surfaced.
            await self._receive_lifespan_message()