Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/testclient.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

312 statements  

1from __future__ import annotations 

2 

3import contextlib 

4import inspect 

5import io 

6import json 

7import math 

8import sys 

9import warnings 

10from collections.abc import Awaitable, Callable, Generator, Iterable, Mapping, MutableMapping, Sequence 

11from concurrent.futures import Future 

12from contextlib import AbstractContextManager 

13from types import GeneratorType 

14from typing import ( 

15 Any, 

16 Literal, 

17 TypedDict, 

18 TypeGuard, 

19 cast, 

20) 

21from urllib.parse import unquote, urljoin 

22 

23import anyio 

24import anyio.abc 

25import anyio.from_thread 

26from anyio.streams.stapled import StapledObjectStream 

27 

28from starlette._utils import is_async_callable 

29from starlette.types import ASGIApp, Message, Receive, Scope, Send 

30from starlette.websockets import WebSocketDisconnect 

31 

32if sys.version_info >= (3, 11): # pragma: no cover 

33 from typing import Self 

34else: # pragma: no cover 

35 from typing_extensions import Self 

36 

37try: 

38 import httpx 

39except ModuleNotFoundError: # pragma: no cover 

40 raise RuntimeError( 

41 "The starlette.testclient module requires the httpx package to be installed.\n" 

42 "You can install this with:\n" 

43 " $ pip install httpx\n" 

44 ) 

# ASGI application shapes. An ASGI2 app is called with the scope alone and
# returns an "instance" coroutine function taking (receive, send); an ASGI3
# app receives all three arguments in one call.
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
ASGI2App = Callable[[Scope], ASGIInstance]
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]

# Zero-argument factory returning a context manager that yields an anyio
# BlockingPortal, through which the async app is driven from sync code.
_PortalFactoryType = Callable[
    [],
    AbstractContextManager[anyio.abc.BlockingPortal],
]

# Shape of the `data=` argument accepted by the TestClient request methods:
# each field maps to a string, an iterable of strings, or raw bytes.
_RequestData = Mapping[
    str,
    str | Iterable[str] | bytes,
]

53 

54 

55def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]: 

56 if inspect.isclass(app): 

57 return hasattr(app, "__await__") 

58 return is_async_callable(app) 

59 

60 

61class _WrapASGI2: 

62 """ 

63 Provide an ASGI3 interface onto an ASGI2 app. 

64 """ 

65 

66 def __init__(self, app: ASGI2App) -> None: 

67 self.app = app 

68 

69 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 

70 instance = self.app(scope) 

71 await instance(receive, send) 

72 

73 

74class _AsyncBackend(TypedDict): 

75 backend: str 

76 backend_options: dict[str, Any] 

77 

78 

79class _Upgrade(Exception): 

80 def __init__(self, session: WebSocketTestSession) -> None: 

81 self.session = session 

82 

83 

class WebSocketDenialResponse(httpx.Response, WebSocketDisconnect):  # type: ignore[misc]
    """
    Raised by the `TestClient` when the application rejects a WebSocket handshake
    by answering with an HTTP response (`send_denial_response()`) instead of
    accepting. It is both an `httpx.Response` and a `WebSocketDisconnect`.
    """

92 

93 

class WebSocketTestSession:
    """Synchronous test-side handle on a WebSocket connection to an ASGI app.

    The app runs as a task in an anyio blocking portal and exchanges messages
    with this object through two in-memory streams, so tests can call
    ``send_*`` / ``receive_*`` from ordinary synchronous code. Entering the
    context manager performs the handshake; exiting closes the connection and
    tears the task down.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        self.accepted_subprotocol = None
        self.portal_factory = portal_factory
        self.extra_headers = None

    def __enter__(self) -> WebSocketTestSession:
        with contextlib.ExitStack() as stack:
            portal = stack.enter_context(self.portal_factory())
            self.portal = portal
            task_future, cancel_scope = portal.start_task(self._run)
            # ExitStack unwinds LIFO: the app task is cancelled first, then its
            # future is waited on (surfacing any error the app raised).
            stack.callback(task_future.result)
            stack.callback(portal.call, cancel_scope.cancel)

            # Handshake: a close or HTTP denial from the app raises here, and the
            # stack above unwinds immediately.
            self.send({"type": "websocket.connect"})
            reply = self.receive()
            self._raise_on_close(reply)
            self.accepted_subprotocol = reply.get("subprotocol", None)
            self.extra_headers = reply.get("headers", None)

            # On a normal exit, send a 1000 close before cancelling the task.
            stack.callback(self.close, 1000)
            # Transfer the registered cleanups to __exit__.
            self.exit_stack = stack.pop_all()
        return self

    def __exit__(self, *args: Any) -> bool | None:
        return self.exit_stack.__exit__(*args)

    async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
        """Drive the ASGI app inside the portal; started via ``portal.start_task``.

        One stream carries app -> test messages and the other test -> app
        messages. The cancel scope is reported through ``task_status`` so the
        session can later cancel this task.
        """
        app_to_test: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        outgoing_tx, outgoing_rx = app_to_test
        test_to_app: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        incoming_tx, incoming_rx = test_to_app
        with outgoing_tx, outgoing_rx, incoming_tx, incoming_rx, anyio.CancelScope() as cancel_scope:
            self._receive_tx = incoming_tx
            self._send_rx = outgoing_rx
            task_status.started(cancel_scope)
            await self.app(self.scope, incoming_rx.receive, outgoing_tx.send)

            # Keep the streams open until the session cancels this scope.
            await anyio.sleep_forever()

    def _raise_on_close(self, message: Message) -> None:
        """Turn a close or HTTP-denial message from the app into an exception.

        Raises:
            WebSocketDisconnect: the app sent ``websocket.close``.
            WebSocketDenialResponse: the app answered with
                ``websocket.http.response.start``; the following body messages
                are read from the stream and joined into the response content.
        """
        message_type = message["type"]
        if message_type == "websocket.close":
            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
        if message_type != "websocket.http.response.start":
            return

        status_code: int = message["status"]
        headers: list[tuple[bytes, bytes]] = message["headers"]
        chunks: list[bytes] = []
        more_body = True
        while more_body:
            chunk_message = self.receive()
            assert chunk_message["type"] == "websocket.http.response.body"
            chunks.append(chunk_message["body"])
            more_body = chunk_message.get("more_body", False)
        raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(chunks))

    def send(self, message: Message) -> None:
        """Deliver a raw ASGI message to the app."""
        self.portal.call(self._receive_tx.send, message)

    def send_text(self, data: str) -> None:
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
        """Send *data* as compact JSON, as a text frame or (any other mode) UTF-8 bytes."""
        encoded = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        frame = {"text": encoded} if mode == "text" else {"bytes": encoded.encode("utf-8")}
        self.send({"type": "websocket.receive", **frame})

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """Return the next raw ASGI message sent by the app."""
        return self.portal.call(self._send_rx.receive)

    def receive_text(self) -> str:
        incoming = self.receive()
        self._raise_on_close(incoming)
        return cast(str, incoming["text"])

    def receive_bytes(self) -> bytes:
        incoming = self.receive()
        self._raise_on_close(incoming)
        return cast(bytes, incoming["bytes"])

    def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
        """Decode the next message as JSON from its text (or, otherwise, UTF-8 bytes) payload."""
        incoming = self.receive()
        self._raise_on_close(incoming)
        payload = incoming["text"] if mode == "text" else incoming["bytes"].decode("utf-8")
        return json.loads(payload)

197 

198 

class _TestClientTransport(httpx.BaseTransport):
    """httpx transport that dispatches requests directly into an ASGI3 app.

    HTTP requests are run to completion inside a blocking portal and the
    collected messages are turned into an ``httpx.Response``. ``ws``/``wss``
    requests are not executed here: a ``WebSocketTestSession`` is built and
    handed back to the caller by raising ``_Upgrade``.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        client: tuple[str, int],
        app_state: dict[str, Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        self.app_state = app_state
        self.client = client

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Run *request* against the app and return the collected response.

        Raises:
            _Upgrade: for ``ws``/``wss`` URLs, carrying the prepared session.
            BaseException: whatever the app raised, if ``raise_server_exceptions``.
            AssertionError: if ``raise_server_exceptions`` and the app never
                sent ``http.response.start``.
        """
        scheme = request.url.scheme
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]

        # Take host/port from httpx's parsed URL rather than splitting the netloc
        # on the first ":" -- that split raised ValueError for IPv6 literals such
        # as "[::1]:8000". `raw_host` is the ASCII host without IPv6 brackets;
        # `port` may be None (no explicit port), in which case the scheme's
        # default applies.
        host = request.url.raw_host.decode(encoding="ascii")
        port = request.url.port
        if port is None:
            port = default_port

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        else:  # pragma: no cover
            # IPv6 literals must be bracketed inside a Host header value.
            host_header = f"[{host}]" if ":" in host else host
            if port != default_port:
                host_header = f"{host_header}:{port}"
            headers = [(b"host", host_header.encode())]

        # Include other request headers.
        headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]

        scope: dict[str, Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path.split(b"?", 1)[0],
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": self.client,
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            # Escape httpx's request flow; websocket_connect catches this.
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path.split(b"?", 1)[0],
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        response_complete: anyio.Event
        raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            nonlocal request_complete

            # Once the body has been delivered, block until the response is
            # finished and then report the client as disconnected.
            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert not response_started, 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
                response_started = True
            elif message["type"] == "http.response.body":
                assert response_started, 'Received "http.response.body" without "http.response.start".'
                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                # HEAD responses keep their headers but the body is discarded.
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            # The app failed before responding: synthesize an empty 500.
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response

366 

367 

class TestClient(httpx.Client):
    """Synchronous httpx client that sends requests directly to an ASGI app.

    Requests go through `_TestClientTransport`, so no network is involved. When
    used as a context manager, the client also runs the app's lifespan (startup
    on enter, shutdown on exit) inside one blocking portal that every request
    made within the ``with`` block reuses.
    """

    __test__ = False  # keep pytest from collecting this class as a test case
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
        client: tuple[str, int] = ("testclient", 50000),
    ) -> None:
        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        # Shared with the lifespan scope and (shallow-copied) into each request
        # scope, so state set during startup is visible to requests.
        self.app_state: dict[str, Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
            client=client,
        )
        # Work on a case-insensitive copy instead of calling setdefault on the
        # caller's dict: the argument is no longer mutated, and a caller-supplied
        # "User-Agent" in any casing suppresses the default instead of producing
        # a second user-agent header.
        client_headers = httpx.Headers(headers)
        client_headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=client_headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
        """Yield the client's long-lived portal if one is open, else a temporary one."""
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
                yield portal

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a request to the app; warns (DeprecationWarning) if ``timeout`` is given."""
        if timeout is not httpx.USE_CLIENT_DEFAULT:
            warnings.warn(
                "You should not use the 'timeout' argument with the TestClient. "
                "See https://github.com/Kludex/starlette/issues/1108 for more information.",
                DeprecationWarning,
            )
        url = self._merge_url(url)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a GET request to the app."""
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send an OPTIONS request to the app."""
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a HEAD request to the app."""
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a POST request to the app."""
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a PUT request to the app."""
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a PATCH request to the app."""
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a DELETE request to the app."""
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> WebSocketTestSession:
        """Prepare a WebSocket session to *url* (resolved against ``ws://testserver``).

        Upgrade handshake headers are added unless already provided. The
        handshake itself happens when the returned session is entered
        (``with client.websocket_connect(...) as ws:``).

        Raises:
            RuntimeError: if the transport did not produce a WebSocket upgrade.
        """
        url = urljoin("ws://testserver", url)
        # Work on a case-insensitive copy: the caller's headers object is not
        # mutated, `headers=None` or a list of pairs is accepted, and
        # caller-provided handshake headers in any casing win over the defaults.
        headers = httpx.Headers(kwargs.get("headers"))
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> Self:
        """Open a long-lived portal and run the app's lifespan startup."""
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
                anyio.create_memory_object_stream(math.inf)
            )
            receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
                math.inf
            )
            for channel in (*send, *receive):
                stack.callback(channel.close)
            self.stream_send = StapledObjectStream(*send)
            self.stream_receive = StapledObjectStream(*receive)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            # Defer all registered cleanups to __exit__.
            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: Any) -> None:
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the app with a lifespan scope; a trailing None signals that it returned."""
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        """Send ``lifespan.startup`` and wait for the app's completion/failure reply."""
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> Any:
            message = await self.stream_send.receive()
            if message is None:
                # The lifespan task ended; re-raise any exception it produced.
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            await receive()

    async def wait_shutdown(self) -> None:
        """Send ``lifespan.shutdown`` and wait for the app's completion/failure reply."""

        async def receive() -> Any:
            message = await self.stream_send.receive()
            if message is None:
                # The lifespan task ended; re-raise any exception it produced.
                self.task.result()
            return message

        await self.stream_receive.send({"type": "lifespan.shutdown"})
        message = await receive()
        assert message["type"] in (
            "lifespan.shutdown.complete",
            "lifespan.shutdown.failed",
        )
        if message["type"] == "lifespan.shutdown.failed":
            await receive()