Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/testclient.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

312 statements  

1from __future__ import annotations 

2 

3import contextlib 

4import inspect 

5import io 

6import json 

7import math 

8import sys 

9import warnings 

10from collections.abc import Awaitable, Generator, Iterable, Mapping, MutableMapping, Sequence 

11from concurrent.futures import Future 

12from contextlib import AbstractContextManager 

13from types import GeneratorType 

14from typing import ( 

15 Any, 

16 Callable, 

17 Literal, 

18 TypedDict, 

19 Union, 

20 cast, 

21) 

22from urllib.parse import unquote, urljoin 

23 

24import anyio 

25import anyio.abc 

26import anyio.from_thread 

27from anyio.streams.stapled import StapledObjectStream 

28 

29from starlette._utils import is_async_callable 

30from starlette.types import ASGIApp, Message, Receive, Scope, Send 

31from starlette.websockets import WebSocketDisconnect 

32 

33if sys.version_info >= (3, 10): # pragma: no cover 

34 from typing import TypeGuard 

35else: # pragma: no cover 

36 from typing_extensions import TypeGuard 

37 

38if sys.version_info >= (3, 11): # pragma: no cover 

39 from typing import Self 

40else: # pragma: no cover 

41 from typing_extensions import Self 

42 

43try: 

44 import httpx 

45except ModuleNotFoundError: # pragma: no cover 

46 raise RuntimeError( 

47 "The starlette.testclient module requires the httpx package to be installed.\n" 

48 "You can install this with:\n" 

49 " $ pip install httpx\n" 

50 ) 

# Zero-argument callable returning a context manager that yields a running blocking portal.
_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]

# ASGI3 ("single callable"): app(scope, receive, send) is awaited directly.
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]

# ASGI2 ("double callable"): app(scope) returns an instance awaited with (receive, send).
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
ASGI2App = Callable[[Scope], ASGIInstance]

# Accepted ``data=`` payload for the request helpers: each key maps to a string,
# an iterable of strings, or bytes.
_RequestData = Mapping[str, Union[str, Iterable[str], bytes]]

59 

60 

def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
    """Tell whether *app* already follows the ASGI3 single-callable interface.

    A class counts as ASGI3 only when its instances are awaitable (the class
    defines ``__await__``); any other object is judged by ``is_async_callable``.
    """
    return hasattr(app, "__await__") if inspect.isclass(app) else is_async_callable(app)

65 

66 

class _WrapASGI2:
    """Adapt a legacy ASGI2 ("double callable") app to the ASGI3 interface.

    The wrapped app is first called with the scope; the instance it returns is
    then awaited with ``(receive, send)``.
    """

    def __init__(self, app: ASGI2App) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope)(receive, send)

78 

79 

# Keyword arguments forwarded verbatim to ``anyio.from_thread.start_blocking_portal``:
# the backend name and its backend-specific options.
_AsyncBackend = TypedDict(
    "_AsyncBackend",
    {
        "backend": str,
        "backend_options": dict[str, Any],
    },
)

83 

84 

class _Upgrade(Exception):
    """Internal signal raised by the transport for ``ws``/``wss`` requests.

    It carries the prepared ``WebSocketTestSession`` back out of the httpx
    request machinery to ``TestClient.websocket_connect``.
    """

    def __init__(self, session: WebSocketTestSession) -> None:
        super().__init__(session)
        self.session = session

88 

89 

class WebSocketDenialResponse(  # type: ignore[misc]
    httpx.Response,
    WebSocketDisconnect,
):
    """
    Raised by the `TestClient` when the app refuses a `WebSocket` handshake.

    This is a `WebSocketDisconnect` and also an `httpx.Response`. The app
    closed the connection before accepting it by using
    `send_denial_response()`. The status code, headers and body it sent are
    available on the exception itself.
    """

98 

99 

class WebSocketTestSession:
    """Blocking client-side handle on one WebSocket conversation with an ASGI app.

    The app runs as a task inside a blocking portal. The ``send*`` methods block
    until a message has been handed to the app, and the ``receive*`` methods block
    until one has been taken from it. Enter the session as a context manager to
    perform the handshake.

    On leaving, the session:

    * sends a 1000 close, if the handshake was accepted;
    * cancels the app task;
    * surfaces the task's result, re-raising anything the app raised.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        """Store the app, its websocket scope and the portal factory; nothing runs yet."""
        self.app = app
        self.scope = scope
        self.portal_factory = portal_factory
        # Both are filled from the app's accept message in __enter__.
        self.accepted_subprotocol = None
        self.extra_headers = None

    def __enter__(self) -> WebSocketTestSession:
        """Start the app, send ``websocket.connect`` and wait for its answer.

        Raises ``WebSocketDisconnect`` / ``WebSocketDenialResponse`` when the app
        closes or denies the connection instead of accepting it.
        """
        with contextlib.ExitStack() as cleanup:
            portal = cleanup.enter_context(self.portal_factory())
            self.portal = portal
            future, cancel_scope = portal.start_task(self._run)
            # ExitStack unwinds in reverse order.
            # On exit: close (registered below, only once accepted), then cancel
            # the app task, then collect its result (re-raising app errors).
            cleanup.callback(future.result)
            cleanup.callback(portal.call, cancel_scope.cancel)
            self.send({"type": "websocket.connect"})
            reply = self.receive()
            self._raise_on_close(reply)
            self.accepted_subprotocol = reply.get("subprotocol", None)
            self.extra_headers = reply.get("headers", None)
            cleanup.callback(self.close, 1000)
            # Hand ownership of every registered callback to __exit__.
            self.exit_stack = cleanup.pop_all()
        return self

    def __exit__(self, *args: Any) -> bool | None:
        """Run the cleanup registered in ``__enter__``, forwarding the exception info."""
        return self.exit_stack.__exit__(*args)

    async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
        """Run the app inside the portal, wired to two unbounded in-memory streams.

        ``self._receive_tx`` feeds messages to the app's ``receive`` and
        ``self._send_rx`` yields whatever the app passes to ``send``. The cancel
        scope is reported through ``task_status``.
        """
        app_outbox: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        outbox_tx, outbox_rx = app_outbox
        app_inbox: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        inbox_tx, inbox_rx = app_inbox
        with outbox_tx, outbox_rx, inbox_tx, inbox_rx, anyio.CancelScope() as cancel_scope:
            self._receive_tx = inbox_tx
            self._send_rx = outbox_rx
            task_status.started(cancel_scope)
            await self.app(self.scope, inbox_rx.receive, outbox_tx.send)

            # Keep the streams open until __exit__ cancels this scope.
            await anyio.sleep_forever()

    def _raise_on_close(self, message: Message) -> None:
        """Raise if *message* ends the conversation; otherwise return ``None``.

        ``websocket.close`` becomes ``WebSocketDisconnect``. For
        ``websocket.http.response.start``, the remaining
        ``websocket.http.response.body`` messages are read first. Their body is
        then raised as a ``WebSocketDenialResponse``.
        """
        kind = message["type"]
        if kind == "websocket.close":
            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
        if kind != "websocket.http.response.start":
            return

        status_code: int = message["status"]
        headers: list[tuple[bytes, bytes]] = message["headers"]
        chunks: list[bytes] = []
        more_body = True
        while more_body:
            part = self.receive()
            assert part["type"] == "websocket.http.response.body"
            chunks.append(part["body"])
            more_body = part.get("more_body", False)
        raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(chunks))

    def send(self, message: Message) -> None:
        """Deliver a raw ASGI *message* to the app's ``receive``."""
        self.portal.call(self._receive_tx.send, message)

    def send_text(self, data: str) -> None:
        """Send *data* to the app as a text frame."""
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        """Send *data* to the app as a binary frame."""
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
        """Serialize *data* compactly as JSON and send it.

        The payload goes as a text frame when *mode* is ``"text"``; any other
        mode sends it as UTF-8 bytes.
        """
        payload = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode != "text":
            self.send({"type": "websocket.receive", "bytes": payload.encode("utf-8")})
            return
        self.send({"type": "websocket.receive", "text": payload})

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Tell the app the client disconnected with *code* and *reason*."""
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """Block until the app sends a message and return it unchanged."""
        return self.portal.call(self._send_rx.receive)

    def receive_text(self) -> str:
        """Return the ``text`` of the next message, raising if the app closed."""
        message = self.receive()
        self._raise_on_close(message)
        text: str = message["text"]
        return text

    def receive_bytes(self) -> bytes:
        """Return the ``bytes`` of the next message, raising if the app closed."""
        message = self.receive()
        self._raise_on_close(message)
        payload: bytes = message["bytes"]
        return payload

    def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
        """Decode the next message as JSON.

        The ``text`` field is read when *mode* is ``"text"``; any other mode reads
        UTF-8 ``bytes``.
        """
        message = self.receive()
        self._raise_on_close(message)
        raw = message["text"] if mode == "text" else message["bytes"].decode("utf-8")
        return json.loads(raw)

203 

204 

class _TestClientTransport(httpx.BaseTransport):
    """An httpx transport that serves requests from an in-process ASGI app.

    ``http``/``https`` requests are run through the app inside a blocking portal
    and collected into an ``httpx.Response``. ``ws``/``wss`` requests are not
    executed here. Instead a ``WebSocketTestSession`` is prepared and handed back
    to the caller by raising ``_Upgrade``.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        client: tuple[str, int],
        app_state: dict[str, Any],
    ) -> None:
        """
        Args:
            app: ASGI3 application to dispatch to.
            portal_factory: Provides the blocking portal each request runs in.
            raise_server_exceptions: Re-raise app errors instead of answering 500.
            root_path: Value for the scope's ``root_path``.
            client: Value for the scope's ``client`` entry.
            app_state: Lifespan state; a shallow copy goes into each scope's ``state``.
        """
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        self.app_state = app_state
        self.client = client

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Run *request* through the ASGI app and return the response it produced.

        Raises:
            _Upgrade: For ``ws``/``wss`` URLs, carrying the prepared session.
            Exception: Whatever the app raised, when ``raise_server_exceptions`` is set.
            AssertionError: When ``raise_server_exceptions`` is set and the app never
                sent ``http.response.start``. Also raised when the app's messages
                are out of order.
        """
        url = request.url
        scheme = url.scheme
        path = url.path
        raw_path = url.raw_path
        query = url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]

        # Use httpx's parsed host and port. Splitting the netloc on ":" breaks
        # on bracketed IPv6 literals such as "[::1]:8000". httpx reports
        # default ports as None.
        host = url.raw_host.decode(encoding="ascii")
        port = default_port if url.port is None else url.port

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        else:  # pragma: no cover
            # netloc is "host" or "host:port". The port is omitted when default
            # and IPv6 hosts are bracketed.
            headers = [(b"host", url.netloc)]

        # Forward the request headers as raw bytes so their original encoding is kept.
        headers += [(key.lower(), value) for key, value in request.headers.raw]

        scope: dict[str, Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path.split(b"?", 1)[0],
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": self.client,
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path.split(b"?", 1)[0],
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        response_complete: anyio.Event
        raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            nonlocal request_complete

            if request_complete:
                # The body was already delivered. Report a disconnect only after
                # the response has completed.
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            nonlocal response_started, template, context

            if message["type"] == "http.response.start":
                assert not response_started, 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                # ASGI header names/values are byte strings. Hand them to httpx
                # unchanged, because decoding them as UTF-8 here would raise on
                # non-UTF-8 (e.g. latin-1) values.
                raw_kwargs["headers"] = list(message.get("headers", []))
                response_started = True
            elif message["type"] == "http.response.body":
                assert response_started, 'Received "http.response.body" without "http.response.start".'
                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except Exception:
            # Only application errors may be converted into a 500 below.
            # KeyboardInterrupt and SystemExit always propagate.
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        # getvalue() returns everything written regardless of the buffer position.
        # So a body whose final "more_body": False message never arrived is not dropped.
        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].getvalue())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response

372 

373 

class TestClient(httpx.Client):
    """Synchronous ``httpx.Client`` that talks to an ASGI app in-process.

    Requests are served by ``_TestClientTransport`` through a blocking portal,
    never over the network. Used as a context manager, the client also:

    * runs the app's lifespan startup on enter and shutdown on exit;
    * reuses a single portal for every request made inside the block.
    """

    __test__ = False  # keep pytest from collecting this class as a test case
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
        client: tuple[str, int] = ("testclient", 50000),
    ) -> None:
        """
        Args:
            app: ASGI2 or ASGI3 application; ASGI2 apps are wrapped in ``_WrapASGI2``.
            base_url: Base URL that relative request URLs are resolved against.
            raise_server_exceptions: Re-raise app errors instead of answering 500.
            root_path: Value for every scope's ``root_path``.
            backend: anyio backend used for the blocking portal.
            backend_options: Extra options for that backend.
            cookies: Initial client cookies.
            headers: Default request headers. They are copied, never modified,
                and a ``user-agent: testclient`` default is added unless one is
                present (any case).
            follow_redirects: Default redirect policy.
            client: Value for every scope's ``client`` entry.
        """
        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        self.app_state: dict[str, Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
            client=client,
        )
        # Copy into httpx.Headers so the caller's dict is left untouched. The
        # default is then applied case-insensitively, so a caller-supplied
        # "User-Agent" is not duplicated.
        default_headers = httpx.Headers(headers)
        default_headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=default_headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
        """Yield the client's portal when inside ``with client:``, else a fresh one."""
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
                yield portal

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a request to the app.

        *url* is merged with ``base_url``. Passing *timeout* emits a
        ``DeprecationWarning`` because it has no meaning for an in-process app.
        """
        if timeout is not httpx.USE_CLIENT_DEFAULT:
            warnings.warn(
                "You should not use the 'timeout' argument with the TestClient. "
                "See https://github.com/encode/starlette/issues/1108 for more information.",
                DeprecationWarning,
            )
        url = self._merge_url(url)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    # The verb helpers below only narrow the signatures; httpx routes each of them
    # back through ``self.request``.

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a GET request; see ``httpx.Client.get``."""
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send an OPTIONS request; see ``httpx.Client.options``."""
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a HEAD request; see ``httpx.Client.head``."""
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a POST request; see ``httpx.Client.post``."""
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a PUT request; see ``httpx.Client.put``."""
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a PATCH request; see ``httpx.Client.patch``."""
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a DELETE request; see ``httpx.Client.delete``."""
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> WebSocketTestSession:
        """Prepare a WebSocket session against the app.

        *url* is resolved against ``ws://testserver``, and extra keyword arguments
        are forwarded to ``httpx.Client.request``. Handshake headers are added
        only when the caller has not already supplied them (compared
        case-insensitively). The caller's ``headers`` object is never modified.

        Returns an un-entered ``WebSocketTestSession``; entering it performs the
        handshake.
        """
        url = urljoin("ws://testserver", url)
        # Copy into httpx.Headers. This avoids mutating the caller's mapping and
        # accepts any httpx header type, including None or a list of pairs.
        # It also makes the setdefault checks below case-insensitive.
        headers = httpx.Headers(kwargs.get("headers"))
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> Self:
        """Start a shared portal and run the app's lifespan startup."""
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
                anyio.create_memory_object_stream(math.inf)
            )
            receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
                math.inf
            )
            for channel in (*send, *receive):
                stack.callback(channel.close)
            self.stream_send = StapledObjectStream(*send)
            self.stream_receive = StapledObjectStream(*receive)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            # Hand every registered cleanup to __exit__.
            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: Any) -> None:
        """Run lifespan shutdown, close the streams and stop the portal."""
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the app's lifespan scope.

        It always posts a ``None`` sentinel when the app returns or raises, so
        the waiters can stop.
        """
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        """Send ``lifespan.startup`` and wait for the app's answer.

        When the ``None`` sentinel arrives instead, the lifespan task's result is
        read, re-raising any error the app raised.
        """
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> Any:
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            await receive()

    async def wait_shutdown(self) -> None:
        """Send ``lifespan.shutdown`` and wait for the app's answer.

        The ``None`` sentinel is handled as in ``wait_startup``.
        """

        async def receive() -> Any:
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        await self.stream_receive.send({"type": "lifespan.shutdown"})
        message = await receive()
        assert message["type"] in (
            "lifespan.shutdown.complete",
            "lifespan.shutdown.failed",
        )
        if message["type"] == "lifespan.shutdown.failed":
            await receive()