Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/testclient.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

355 statements  

1from __future__ import annotations 

2 

3import contextlib 

4import inspect 

5import io 

6import json 

7import math 

8import queue 

9import sys 

10import typing 

11import warnings 

12from concurrent.futures import Future 

13from functools import cached_property 

14from types import GeneratorType 

15from urllib.parse import unquote, urljoin 

16 

17import anyio 

18import anyio.abc 

19import anyio.from_thread 

20from anyio.abc import ObjectReceiveStream, ObjectSendStream 

21from anyio.streams.stapled import StapledObjectStream 

22 

23from starlette._utils import is_async_callable 

24from starlette.types import ASGIApp, Message, Receive, Scope, Send 

25from starlette.websockets import WebSocketDisconnect 

26 

27if sys.version_info >= (3, 10): # pragma: no cover 

28 from typing import TypeGuard 

29else: # pragma: no cover 

30 from typing_extensions import TypeGuard 

31 

32try: 

33 import httpx 

34except ModuleNotFoundError: # pragma: no cover 

35 raise RuntimeError( 

36 "The starlette.testclient module requires the httpx package to be installed.\n" 

37 "You can install this with:\n" 

38 " $ pip install httpx\n" 

39 ) 

# Zero-argument factory returning a context manager that yields an anyio
# blocking portal, through which async app code is driven from sync test code.
_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]

# ASGI3 application: a single awaitable callable taking (scope, receive, send).
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]

# ASGI2 application: called with the scope alone, it returns an "instance"
# that is then awaited with (receive, send).
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]

# Form-style ``data=`` payload accepted by the TestClient request helpers.
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]]

48 

49 

def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
    """Report whether *app* uses the ASGI3 single-callable interface.

    For a class, the answer is whether it defines ``__await__``. For any other
    object, it is whether the object is an async callable.
    """
    return hasattr(app, "__await__") if inspect.isclass(app) else is_async_callable(app)

54 

55 

class _WrapASGI2:
    """Adapt a legacy ASGI2 application to the ASGI3 calling convention.

    An ASGI2 app is first called with ``scope`` and returns an instance that
    is then awaited with ``receive`` and ``send``. This wrapper performs both
    steps in a single ``__call__(scope, receive, send)``.
    """

    def __init__(self, app: ASGI2App) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope)(receive, send)

67 

68 

class _AsyncBackend(typing.TypedDict, total=True):
    """Keyword arguments unpacked into ``anyio.from_thread.start_blocking_portal``.

    Both keys are always present.
    """

    # Name of the async backend, e.g. "asyncio" or "trio".
    backend: str
    # Backend-specific options forwarded alongside ``backend``.
    backend_options: dict[str, typing.Any]

72 

73 

class _Upgrade(Exception):
    """Control-flow signal raised by the transport for ``ws``/``wss`` requests.

    It carries the prepared :class:`WebSocketTestSession` out of the httpx
    request machinery so that ``TestClient.websocket_connect`` can catch it
    and return the session.
    """

    def __init__(self, session: WebSocketTestSession) -> None:
        super().__init__(session)
        self.session = session

77 

78 

class WebSocketDenialResponse(  # type: ignore[misc]
    httpx.Response,
    WebSocketDisconnect,
):
    """Raised by ``TestClient`` when a WebSocket is denied with an HTTP response.

    The app sends ``websocket.http.response.start`` (via
    ``send_denial_response()``) instead of accepting. The exception is both a
    :class:`WebSocketDisconnect` and an :class:`httpx.Response`, so a test can
    inspect the status code, headers and body of the denial.
    """

87 

88 

class WebSocketTestSession:
    """Synchronous, test-side handle on one WebSocket connection to an ASGI app.

    The app runs as a task inside a blocking portal (see :meth:`_run`). Two
    ``queue.Queue`` objects carry traffic between the test code and the app:

    - ``_receive_queue`` holds client->app messages (what the app's
      ``receive`` returns).
    - ``_send_queue`` holds app->client messages, or an exception the app
      raised.

    Use the session as a context manager. Entering performs the handshake;
    exiting sends a ``websocket.disconnect`` with code 1000 and re-raises any
    queued app exception.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        self.portal_factory = portal_factory
        # Populated from the app's accept message in __enter__.
        self.accepted_subprotocol = None
        self.extra_headers = None
        self._receive_queue: queue.Queue[Message] = queue.Queue()
        self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()

    def __enter__(self) -> WebSocketTestSession:
        """Start the app task and complete the WebSocket handshake.

        Raises:
            WebSocketDisconnect: if the app closes instead of accepting.
            WebSocketDenialResponse: if the app answers with an HTTP response.
        """
        self.exit_stack = contextlib.ExitStack()
        self.portal = self.exit_stack.enter_context(self.portal_factory())

        try:
            self.portal.start_task_soon(self._run)
            self.send({"type": "websocket.connect"})
            reply = self.receive()
            self._raise_on_close(reply)
        except Exception:
            # Release the portal context before letting the error propagate.
            self.exit_stack.close()
            raise
        self.accepted_subprotocol = reply.get("subprotocol", None)
        self.extra_headers = reply.get("headers", None)
        return self

    @cached_property
    def should_close(self) -> anyio.Event:
        """Event that tells :meth:`_run` to stop; created on first access."""
        return anyio.Event()

    async def _notify_close(self) -> None:
        """Set :attr:`should_close` from within the portal's event loop."""
        self.should_close.set()

    def __exit__(self, *args: typing.Any) -> None:
        """Close the connection, stop the app task and surface app errors."""
        try:
            self.close(1000)
        finally:
            self.portal.start_task_soon(self._notify_close)
            self.exit_stack.close()
        # Drain leftover app output; the first queued exception is re-raised.
        while not self._send_queue.empty():
            pending = self._send_queue.get()
            if isinstance(pending, BaseException):
                raise pending

    async def _run(self) -> None:
        """
        The sub-thread in which the websocket session runs.

        The app runs in a task group until either it finishes or
        :attr:`should_close` is set. Any exception the app raises (other than
        cancellation) is pushed onto ``_send_queue`` for the test side.
        """

        async def run_app(group: anyio.abc.TaskGroup) -> None:
            try:
                await self.app(self.scope, self._asgi_receive, self._asgi_send)
            except anyio.get_cancelled_exc_class():
                ...
            except BaseException as exc:
                self._send_queue.put(exc)
                raise
            finally:
                group.cancel_scope.cancel()

        async with anyio.create_task_group() as group:
            group.start_soon(run_app, group)
            await self.should_close.wait()
            group.cancel_scope.cancel()

    async def _asgi_receive(self) -> Message:
        """ASGI ``receive``: wait on an event until a client message is queued."""
        while self._receive_queue.empty():
            self._queue_event = anyio.Event()
            await self._queue_event.wait()
        return self._receive_queue.get()

    async def _asgi_send(self, message: Message) -> None:
        """ASGI ``send``: hand an app message to the test side."""
        self._send_queue.put(message)

    def _raise_on_close(self, message: Message) -> None:
        """Raise if *message* means the app closed or denied the connection.

        Raises:
            WebSocketDisconnect: for ``websocket.close``.
            WebSocketDenialResponse: for ``websocket.http.response.start``,
                after collecting every following body chunk.
        """
        kind = message["type"]
        if kind == "websocket.close":
            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
        if kind != "websocket.http.response.start":
            return

        status_code: int = message["status"]
        headers: list[tuple[bytes, bytes]] = message["headers"]
        chunks: list[bytes] = []
        more_body = True
        while more_body:
            part = self.receive()
            assert part["type"] == "websocket.http.response.body"
            chunks.append(part["body"])
            more_body = part.get("more_body", False)
        raise WebSocketDenialResponse(
            status_code=status_code,
            headers=headers,
            content=b"".join(chunks),
        )

    def send(self, message: Message) -> None:
        """Queue a client->app message and wake the app's pending ``receive``."""
        self._receive_queue.put(message)
        # _queue_event exists only once _asgi_receive has started waiting.
        if hasattr(self, "_queue_event"):
            self.portal.start_task_soon(self._queue_event.set)

    def send_text(self, data: str) -> None:
        """Send a text frame."""
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        """Send a binary frame."""
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
        """Send *data* as compact JSON, in a text frame or a UTF-8 binary frame."""
        encoded = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        key, value = ("text", encoded) if mode == "text" else ("bytes", encoded.encode("utf-8"))
        self.send({"type": "websocket.receive", key: value})

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Send a client disconnect with the given close *code* and *reason*."""
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """Block until the app sends a message; re-raise it if it is an exception."""
        item = self._send_queue.get()
        if isinstance(item, BaseException):
            raise item
        return item

    def receive_text(self) -> str:
        """Receive a message and return its ``text`` payload."""
        message = self.receive()
        self._raise_on_close(message)
        return typing.cast(str, message["text"])

    def receive_bytes(self) -> bytes:
        """Receive a message and return its ``bytes`` payload."""
        message = self.receive()
        self._raise_on_close(message)
        return typing.cast(bytes, message["bytes"])

    def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
        """Receive a message and decode its JSON payload (text or UTF-8 bytes)."""
        message = self.receive()
        self._raise_on_close(message)
        raw = message["text"] if mode == "text" else message["bytes"].decode("utf-8")
        return json.loads(raw)

232 

233 

class _TestClientTransport(httpx.BaseTransport):
    """httpx transport that hands each request straight to an in-process ASGI app.

    HTTP requests are run to completion inside a blocking portal and converted
    into an ``httpx.Response``. Requests with a ``ws``/``wss`` scheme are not
    executed here. Instead, a :class:`WebSocketTestSession` is prepared and
    raised inside :class:`_Upgrade` for ``TestClient.websocket_connect``.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        app_state: dict[str, typing.Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        # Each request scope receives a shallow copy of this mapping.
        self.app_state = app_state

    @staticmethod
    def _host_and_port(url: httpx.URL, default_port: int) -> tuple[str, int]:
        """Return ``(host, port)`` for *url*, using *default_port* if none is given.

        The port is read from ``url.port``, and only the trailing ``":<port>"``
        is stripped from the netloc. A bracketed IPv6 host such as
        ``[::1]:8000`` therefore splits correctly; splitting on the first
        ``":"`` would not.
        """
        netloc = url.netloc.decode(encoding="ascii")
        port = url.port
        if port is None:
            return netloc, default_port
        host, _, _ = netloc.rpartition(":")
        return host, port

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Run *request* against the app and build the resulting response.

        Raises:
            _Upgrade: for ``ws``/``wss`` requests, carrying the prepared session.
            BaseException: whatever the app (or this transport's protocol
                checks) raised, when ``raise_server_exceptions`` is true.
            AssertionError: if ``raise_server_exceptions`` is true and the app
                never sent ``http.response.start``.
        """
        scheme = request.url.scheme
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
        host, port = self._host_and_port(request.url, default_port)

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        elif port == default_port:  # pragma: no cover
            headers = [(b"host", host.encode())]
        else:  # pragma: no cover
            headers = [(b"host", (f"{host}:{port}").encode())]

        # Include other request headers.
        headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]

        scope: dict[str, typing.Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: typing.Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path,
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": ["testclient", 50000],
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path,
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": ["testclient", 50000],
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        response_complete: anyio.Event
        raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            # ASGI receive: deliver the request body once, then report a
            # disconnect as soon as the response has completed.
            nonlocal request_complete

            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            # ASGI send: accumulate status, headers and body; capture the
            # template/context announced via the debug extension.
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert not response_started, 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
                response_started = True
            elif message["type"] == "http.response.body":
                assert response_started, 'Received "http.response.body" without "http.response.start".'
                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        # getvalue() ignores the buffer position, so body bytes written before
        # an incomplete (never-finished) response are not silently dropped.
        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].getvalue())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response

399 

400 

class TestClient(httpx.Client):
    """Synchronous httpx client that drives an ASGI application in-process.

    Requests are routed through :class:`_TestClientTransport` instead of the
    network. Using the client as a context manager keeps one blocking portal
    open for its whole lifetime and runs the app's lifespan startup and
    shutdown. Outside a ``with`` block, a fresh portal is started per request.
    """

    __test__ = False  # keep pytest from collecting this class as a test
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: typing.Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, typing.Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
    ) -> None:
        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = typing.cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        self.app_state: dict[str, typing.Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
        )
        # Copy into a case-insensitive httpx.Headers. This leaves the caller's
        # mapping untouched, and a caller-supplied "User-Agent" is honoured
        # instead of being overwritten by a separately-cased "user-agent" key.
        client_headers = httpx.Headers(headers)
        client_headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=client_headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
        """Yield the client's open portal, or a temporary one if none is open."""
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
                yield portal

    def _choose_redirect_arg(
        self, follow_redirects: bool | None, allow_redirects: bool | None
    ) -> bool | httpx._client.UseClientDefault:
        """Resolve the redirect setting from the current and the deprecated kwarg.

        Returns whichever of ``follow_redirects`` / ``allow_redirects`` was
        given, or httpx's ``USE_CLIENT_DEFAULT`` sentinel when neither was.
        Using ``allow_redirects`` emits a ``DeprecationWarning``.

        Raises:
            RuntimeError: if both ``follow_redirects`` and ``allow_redirects``
                are given.
        """
        if allow_redirects is None:
            if follow_redirects is None:
                return httpx._client.USE_CLIENT_DEFAULT
            return follow_redirects
        if follow_redirects is not None:
            raise RuntimeError("Cannot use both `allow_redirects` and `follow_redirects`.")
        message = "The `allow_redirects` argument is deprecated. Use `follow_redirects` instead."
        warnings.warn(message, DeprecationWarning)
        return allow_redirects

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a request; accepts the deprecated ``allow_redirects`` too."""
        url = self._merge_url(url)
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a GET request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send an OPTIONS request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a HEAD request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a POST request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a PUT request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a PATCH request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a DELETE request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: typing.Sequence[str] | None = None,
        **kwargs: typing.Any,
    ) -> WebSocketTestSession:
        """Prepare a WebSocket session against the application.

        *url* is resolved against ``ws://testserver``. The handshake headers are
        added to a case-insensitive copy of any ``headers`` in *kwargs*, so the
        caller's object is never modified and headers the caller set (in any
        letter case) take precedence. The remaining *kwargs* are forwarded to
        ``httpx.Client.request``.

        Returns the session object. Entering it as a context manager performs
        the handshake.
        """
        url = urljoin("ws://testserver", url)
        headers = httpx.Headers(kwargs.get("headers"))
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> TestClient:
        """Open a shared portal and run the app's lifespan startup."""
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None]
            receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None]
            send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
            receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
            send1, receive1 = anyio.create_memory_object_stream(math.inf)
            send2, receive2 = anyio.create_memory_object_stream(math.inf)
            self.stream_send = StapledObjectStream(send1, receive1)
            self.stream_receive = StapledObjectStream(send2, receive2)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            # Keep the callbacks registered past this with-block; they run in __exit__.
            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: typing.Any) -> None:
        """Run lifespan shutdown and close the shared portal."""
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the app with a lifespan scope; always send the ``None`` sentinel when done."""
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        """Send ``lifespan.startup`` and wait for the app's reply."""
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                # The lifespan task finished; surface its exception, if any.
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            await receive()

    async def wait_shutdown(self) -> None:
        """Send ``lifespan.shutdown`` and wait for the app's reply."""

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                # The lifespan task finished; surface its exception, if any.
                self.task.result()
            return message

        async with self.stream_send:
            await self.stream_receive.send({"type": "lifespan.shutdown"})
            message = await receive()
            assert message["type"] in (
                "lifespan.shutdown.complete",
                "lifespan.shutdown.failed",
            )
            if message["type"] == "lifespan.shutdown.failed":
                await receive()