Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/testclient.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

315 statements  

1from __future__ import annotations 

2 

3import contextlib 

4import inspect 

5import io 

6import json 

7import math 

8import sys 

9import warnings 

10from collections.abc import Awaitable, Callable, Generator, Iterable, Mapping, MutableMapping, Sequence 

11from concurrent.futures import Future 

12from contextlib import AbstractContextManager 

13from types import GeneratorType 

14from typing import ( 

15 TYPE_CHECKING, 

16 Any, 

17 Literal, 

18 TypedDict, 

19 TypeGuard, 

20 cast, 

21) 

22from urllib.parse import unquote, urljoin 

23 

24import anyio 

25import anyio.abc 

26import anyio.from_thread 

27from anyio.streams.stapled import StapledObjectStream 

28 

29from starlette._utils import is_async_callable 

30from starlette.exceptions import StarletteDeprecationWarning 

31from starlette.types import ASGIApp, Message, Receive, Scope, Send 

32from starlette.websockets import WebSocketDisconnect 

33 

34if sys.version_info >= (3, 11): # pragma: no cover 

35 from typing import Self 

36else: # pragma: no cover 

37 from typing_extensions import Self 

38 

39if TYPE_CHECKING: 

40 import httpx2 as httpx 

41else: 

42 try: 

43 import httpx2 as httpx 

44 except ModuleNotFoundError: # pragma: no cover 

45 try: 

46 import httpx 

47 except ModuleNotFoundError: 

48 raise RuntimeError( 

49 "The starlette.testclient module requires the httpx2 package to be installed.\n" 

50 "You can install this with:\n" 

51 " $ pip install httpx2\n" 

52 ) 

53 else: 

54 warnings.warn( 

55 "Using `httpx` with `starlette.testclient` is deprecated; install `httpx2` instead.", 

56 StarletteDeprecationWarning, 

57 stacklevel=2, 

58 ) 

# Zero-argument factory returning a context manager that yields an anyio
# BlockingPortal; the portal is what lets synchronous test code ``call`` into
# the async app.
_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]

# ASGI2 style: ``app(scope)`` returns an instance that is then awaited with
# ``(receive, send)``.
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
ASGI2App = Callable[[Scope], ASGIInstance]
# ASGI3 style: a single awaitable call taking ``(scope, receive, send)``.
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]


# Type accepted for the ``data=`` argument of the TestClient request methods.
_RequestData = Mapping[str, str | Iterable[str] | bytes]

67 

68 

69def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]: 

70 if inspect.isclass(app): 

71 return hasattr(app, "__await__") 

72 return is_async_callable(app) 

73 

74 

75class _WrapASGI2: 

76 """ 

77 Provide an ASGI3 interface onto an ASGI2 app. 

78 """ 

79 

80 def __init__(self, app: ASGI2App) -> None: 

81 self.app = app 

82 

83 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 

84 instance = self.app(scope) 

85 await instance(receive, send) 

86 

87 

88class _AsyncBackend(TypedDict): 

89 backend: str 

90 backend_options: dict[str, Any] 

91 

92 

93class _Upgrade(Exception): 

94 def __init__(self, session: WebSocketTestSession) -> None: 

95 self.session = session 

96 

97 

class WebSocketDenialResponse(httpx.Response, WebSocketDisconnect):  # type: ignore[misc]
    """HTTP response an app sent instead of accepting a WebSocket handshake.

    ``TestClient`` raises this variant of ``WebSocketDisconnect`` when the
    app rejects the connection via ``send_denial_response()`` before
    accepting it. Because it is also an ``httpx.Response``, the denial's
    status code, headers and body can be inspected directly.
    """

106 

107 

108class WebSocketTestSession: 

109 def __init__( 

110 self, 

111 app: ASGI3App, 

112 scope: Scope, 

113 portal_factory: _PortalFactoryType, 

114 ) -> None: 

115 self.app = app 

116 self.scope = scope 

117 self.accepted_subprotocol = None 

118 self.portal_factory = portal_factory 

119 self.extra_headers = None 

120 

121 def __enter__(self) -> WebSocketTestSession: 

122 with contextlib.ExitStack() as stack: 

123 self.portal = portal = stack.enter_context(self.portal_factory()) 

124 fut, cs = portal.start_task(self._run) 

125 stack.callback(fut.result) 

126 stack.callback(portal.call, cs.cancel) 

127 self.send({"type": "websocket.connect"}) 

128 message = self.receive() 

129 self._raise_on_close(message) 

130 self.accepted_subprotocol = message.get("subprotocol", None) 

131 self.extra_headers = message.get("headers", None) 

132 stack.callback(self.close, 1000) 

133 self.exit_stack = stack.pop_all() 

134 return self 

135 

136 def __exit__(self, *args: Any) -> bool | None: 

137 return self.exit_stack.__exit__(*args) 

138 

139 async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: 

140 """ 

141 The sub-thread in which the websocket session runs. 

142 """ 

143 send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) 

144 send_tx, send_rx = send 

145 receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) 

146 receive_tx, receive_rx = receive 

147 with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs: 

148 self._receive_tx = receive_tx 

149 self._send_rx = send_rx 

150 task_status.started(cs) 

151 await self.app(self.scope, receive_rx.receive, send_tx.send) 

152 

153 # wait for cs.cancel to be called before closing streams 

154 await anyio.sleep_forever() 

155 

156 def _raise_on_close(self, message: Message) -> None: 

157 if message["type"] == "websocket.close": 

158 raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", "")) 

159 elif message["type"] == "websocket.http.response.start": 

160 status_code: int = message["status"] 

161 headers: list[tuple[bytes, bytes]] = message["headers"] 

162 body: list[bytes] = [] 

163 while True: 

164 message = self.receive() 

165 assert message["type"] == "websocket.http.response.body" 

166 body.append(message["body"]) 

167 if not message.get("more_body", False): 

168 break 

169 raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body)) 

170 

171 def send(self, message: Message) -> None: 

172 self.portal.call(self._receive_tx.send, message) 

173 

174 def send_text(self, data: str) -> None: 

175 self.send({"type": "websocket.receive", "text": data}) 

176 

177 def send_bytes(self, data: bytes) -> None: 

178 self.send({"type": "websocket.receive", "bytes": data}) 

179 

180 def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None: 

181 text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) 

182 if mode == "text": 

183 self.send({"type": "websocket.receive", "text": text}) 

184 else: 

185 self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) 

186 

187 def close(self, code: int = 1000, reason: str | None = None) -> None: 

188 self.send({"type": "websocket.disconnect", "code": code, "reason": reason}) 

189 

190 def receive(self) -> Message: 

191 return self.portal.call(self._send_rx.receive) 

192 

193 def receive_text(self) -> str: 

194 message = self.receive() 

195 self._raise_on_close(message) 

196 return cast(str, message["text"]) 

197 

198 def receive_bytes(self) -> bytes: 

199 message = self.receive() 

200 self._raise_on_close(message) 

201 return cast(bytes, message["bytes"]) 

202 

203 def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any: 

204 message = self.receive() 

205 self._raise_on_close(message) 

206 if mode == "text": 

207 text = message["text"] 

208 else: 

209 text = message["bytes"].decode("utf-8") 

210 return json.loads(text) 

211 

212 

class _TestClientTransport(httpx.BaseTransport):
    """httpx transport that serves requests by calling an ASGI3 app in-process.

    HTTP requests are run through a blocking portal and collected into an
    ``httpx.Response``. WebSocket (``ws``/``wss``) requests are not executed
    here: a ``WebSocketTestSession`` is prepared and handed back to the caller
    by raising ``_Upgrade``.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        client: tuple[str, int],
        app_state: dict[str, Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        # Shared with the owning TestClient; each request's scope gets a shallow copy.
        self.app_state = app_state
        self.client = client

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Run ``request`` against the app and return the collected response.

        Raises:
            _Upgrade: for ``ws``/``wss`` URLs, carrying the prepared session.
            BaseException: whatever the app raised, if ``raise_server_exceptions``.
            AssertionError: if ``raise_server_exceptions`` and the app never sent
                ``http.response.start``, or if the app breaks response message order.
        """
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        # ASGI's raw_path must not contain the query string.
        raw_path = request.url.raw_path.split(b"?", 1)[0]
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]

        # Split off a trailing ":<digits>" port. rpartition + isdigit keeps
        # bracketed IPv6 literals ("[::1]", "[::1]:8000") intact; splitting on
        # the first ":" would cut inside the address and make int() fail.
        host, sep, port_string = netloc.rpartition(":")
        if sep and port_string.isdigit():
            port = int(port_string)
        else:
            host = netloc
            port = default_port

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        elif port == default_port:  # pragma: no cover
            headers = [(b"host", host.encode())]
        else:  # pragma: no cover
            headers = [(b"host", (f"{host}:{port}").encode())]

        # Include other request headers.
        headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]

        scope: dict[str, Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path,
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": self.client,
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path,
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        # Created inside the portal's event loop below, before the app runs.
        response_complete: anyio.Event
        raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            """ASGI ``receive``: deliver the request body, then wait and disconnect."""
            nonlocal request_complete

            if request_complete:
                # Body already delivered: hold the app here until the response
                # is complete, then report the client as gone.
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            """ASGI ``send``: record status/headers, buffer body, capture debug info."""
            nonlocal response_started, template, context

            if message["type"] == "http.response.start":
                assert not response_started, 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
                response_started = True
            elif message["type"] == "http.response.body":
                assert response_started, 'Received "http.response.body" without "http.response.start".'
                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                # Any body the app sends for a HEAD request is discarded.
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            # Deliberately broad: with raise_server_exceptions=False any app
            # failure is surfaced as whatever response was produced (or a 500).
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        # getvalue() returns every byte written regardless of the buffer
        # position, so a response whose final body message never arrived keeps
        # the body that was sent instead of silently becoming empty.
        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].getvalue())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response

380 

381 

class TestClient(httpx.Client):
    """Synchronous httpx client that sends requests to an ASGI app in-process.

    Requests go through ``_TestClientTransport``; WebSocket sessions are
    opened with :meth:`websocket_connect`. Used as a context manager, the
    client keeps a single blocking portal alive for its lifetime and drives
    the app's lifespan startup/shutdown around it. Otherwise each request
    starts its own short-lived portal.
    """

    __test__ = False  # opt out of pytest collection despite the "Test" prefix
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
        client: tuple[str, int] = ("testclient", 50000),
    ) -> None:
        """Wrap ``app`` (ASGI2 or ASGI3) in an httpx client.

        ``headers`` is copied, never modified; a ``user-agent`` of
        ``"testclient"`` is added unless the caller supplied one under any
        capitalization.
        """
        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        # Lifespan state; copied into every request/websocket scope by the transport.
        self.app_state: dict[str, Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
            client=client,
        )
        # Case-insensitive copy: the caller's dict is left untouched, and a
        # caller-provided "User-Agent" is not overridden by the default.
        default_headers = httpx.Headers(headers)
        default_headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=default_headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
        """Yield the long-lived portal if one is open, else a temporary one."""
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
                yield portal

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a request to the app; warns (DeprecationWarning) if ``timeout`` is given."""
        if timeout is not httpx.USE_CLIENT_DEFAULT:
            warnings.warn(
                "You should not use the 'timeout' argument with the TestClient. "
                "See https://github.com/Kludex/starlette/issues/1108 for more information.",
                DeprecationWarning,
            )
        url = self._merge_url(url)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    # The verb helpers below only narrow httpx's signatures (e.g. ``data``
    # typed as ``_RequestData``); httpx routes each of them through
    # ``self.request`` above.

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> WebSocketTestSession:
        """Prepare a WebSocket session for ``url`` (resolved against ``ws://testserver``).

        The returned session must be entered (``with ...``) to perform the
        handshake. Extra ``kwargs`` are forwarded to ``httpx.Client.request``.
        Any ``headers`` given are copied, never modified. Upgrade headers are
        only added when absent under any capitalization. A reused headers dict
        therefore no longer carries one call's ``sec-websocket-protocol`` into
        the next call.

        Raises:
            RuntimeError: if the transport did not signal a WebSocket upgrade.
        """
        url = urljoin("ws://testserver", url)
        headers = httpx.Headers(kwargs.get("headers"))
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> Self:
        """Open a long-lived portal and run the app's lifespan startup."""
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            # stream_send carries app -> client lifespan messages (None marks the
            # end of the lifespan task); stream_receive carries client -> app.
            send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
                anyio.create_memory_object_stream(math.inf)
            )
            receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
                math.inf
            )
            for channel in (*send, *receive):
                stack.callback(channel.close)
            self.stream_send = StapledObjectStream(*send)
            self.stream_receive = StapledObjectStream(*receive)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: Any) -> None:
        """Run lifespan shutdown, close the streams and the portal."""
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the app with a lifespan scope; always post a ``None`` sentinel when it ends."""
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def _receive_lifespan_message(self) -> Any:
        """Return the app's next lifespan message.

        On the ``None`` sentinel, ``self.task.result()`` re-raises any
        exception the lifespan task ended with.
        """
        message = await self.stream_send.receive()
        if message is None:
            self.task.result()
        return message

    async def wait_startup(self) -> None:
        """Send ``lifespan.startup`` and wait for the app's reply.

        After ``lifespan.startup.failed`` one more message is awaited, which
        surfaces the lifespan task's outcome.
        """
        await self.stream_receive.send({"type": "lifespan.startup"})
        message = await self._receive_lifespan_message()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            await self._receive_lifespan_message()

    async def wait_shutdown(self) -> None:
        """Send ``lifespan.shutdown`` and wait for the app's reply.

        After ``lifespan.shutdown.failed`` one more message is awaited, which
        surfaces the lifespan task's outcome.
        """
        await self.stream_receive.send({"type": "lifespan.shutdown"})
        message = await self._receive_lifespan_message()
        assert message["type"] in (
            "lifespan.shutdown.complete",
            "lifespan.shutdown.failed",
        )
        if message["type"] == "lifespan.shutdown.failed":
            await self._receive_lifespan_message()