Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/testclient.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

352 statements  

1from __future__ import annotations 

2 

3import contextlib 

4import inspect 

5import io 

6import json 

7import math 

8import queue 

9import sys 

10import typing 

11import warnings 

12from concurrent.futures import Future 

13from functools import cached_property 

14from types import GeneratorType 

15from urllib.parse import unquote, urljoin 

16 

17import anyio 

18import anyio.abc 

19import anyio.from_thread 

20from anyio.abc import ObjectReceiveStream, ObjectSendStream 

21from anyio.streams.stapled import StapledObjectStream 

22 

23from starlette._utils import is_async_callable 

24from starlette.types import ASGIApp, Message, Receive, Scope, Send 

25from starlette.websockets import WebSocketDisconnect 

26 

27if sys.version_info >= (3, 10): # pragma: no cover 

28 from typing import TypeGuard 

29else: # pragma: no cover 

30 from typing_extensions import TypeGuard 

31 

32try: 

33 import httpx 

34except ModuleNotFoundError: # pragma: no cover 

35 raise RuntimeError( 

36 "The starlette.testclient module requires the httpx package to be installed.\n" 

37 "You can install this with:\n" 

38 " $ pip install httpx\n" 

39 ) 

# Zero-argument callable returning a context manager that yields an anyio
# blocking portal (the bridge used to run the async app from sync test code).
_PortalFactoryType = typing.Callable[
    [],
    typing.ContextManager[anyio.abc.BlockingPortal],
]

# ASGI2 ("double callable") shape: app(scope) returns an instance which is then
# awaited with the receive/send channels.
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]

# ASGI3 ("single callable") shape: app(scope, receive, send) is awaited directly.
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]


# Accepted type for the ``data=`` argument of the request helpers.
_RequestData = typing.Mapping[
    str,
    typing.Union[str, typing.Iterable[str], bytes],
]

50 

51 

def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
    """
    Report whether ``app`` should be treated as a single-callable (ASGI3) app.

    A class is judged by whether it exposes ``__await__``; any other object is
    judged by ``is_async_callable``.
    """
    if not inspect.isclass(app):
        return is_async_callable(app)
    return hasattr(app, "__await__")

56 

57 

class _WrapASGI2:
    """
    Adapter that exposes a double-callable ASGI2 app through the ASGI3 interface.
    """

    def __init__(self, app: ASGI2App) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # ASGI2 protocol: build the per-connection instance from the scope,
        # then drive it with the receive/send channels.
        await self.app(scope)(receive, send)

69 

70 

class _AsyncBackend(typing.TypedDict):
    """
    Backend selection for the blocking portal.

    Instances are unpacked as keyword arguments into
    ``anyio.from_thread.start_blocking_portal`` by ``TestClient``.
    """

    # Name of the async backend to run on.
    backend: str
    # Extra options handed to that backend.
    backend_options: dict[str, typing.Any]

74 

75 

class _Upgrade(Exception):
    """
    Control-flow exception raised by the transport for ``ws``/``wss`` requests.

    It carries the prepared websocket session back to
    ``TestClient.websocket_connect``, which catches it and returns the session.
    """

    def __init__(self, session: WebSocketTestSession) -> None:
        self.session = session

79 

80 

class WebSocketDenialResponse(httpx.Response, WebSocketDisconnect):  # type: ignore[misc]
    """
    A `WebSocketDisconnect` that is also an `httpx.Response`.

    Raised in the `TestClient` when the `WebSocket` is closed before being
    accepted, with the app having answered via `send_denial_response()`; the
    response part carries the status, headers and body the app sent.
    """

89 

90 

class WebSocketTestSession:
    """
    Blocking, test-side view of a single websocket connection to an ASGI app.

    The app is run through a blocking portal and exchanges ASGI messages with
    the test via two thread-safe queues. Use the session as a context manager:
    entering performs the ``websocket.connect`` handshake, exiting sends a
    disconnect and re-raises any exception the app raised.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        self.accepted_subprotocol = None
        self.portal_factory = portal_factory
        # Messages travelling test -> app (drained by ``_asgi_receive``).
        self._receive_queue: queue.Queue[Message] = queue.Queue()
        # Messages travelling app -> test, plus any exception raised by the app.
        self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
        self.extra_headers = None

    def __enter__(self) -> WebSocketTestSession:
        """
        Start the app and perform the connect handshake.

        Raises:
            WebSocketDisconnect: the app closed the socket instead of accepting.
            WebSocketDenialResponse: the app answered with a denial response.
        """
        self.exit_stack = contextlib.ExitStack()
        self.portal = self.exit_stack.enter_context(self.portal_factory())

        try:
            _: Future[None] = self.portal.start_task_soon(self._run)
            self.send({"type": "websocket.connect"})
            message = self.receive()
            self._raise_on_close(message)
        except Exception:
            # Release the portal before propagating a failed handshake.
            self.exit_stack.close()
            raise
        self.accepted_subprotocol = message.get("subprotocol", None)
        self.extra_headers = message.get("headers", None)
        return self

    @cached_property
    def should_close(self) -> anyio.Event:
        # Created lazily on first access and then cached on the instance.
        return anyio.Event()

    async def _notify_close(self) -> None:
        self.should_close.set()

    def __exit__(self, *args: typing.Any) -> None:
        """Send a normal-closure disconnect, stop the app task and surface app errors."""
        try:
            self.close(1000)
        finally:
            self.portal.start_task_soon(self._notify_close)
            self.exit_stack.close()
        # Drain whatever the app left behind; an exception among the leftovers
        # is re-raised in the test.
        while not self._send_queue.empty():
            message = self._send_queue.get()
            if isinstance(message, BaseException):
                raise message

    async def _run(self) -> None:
        """
        The sub-thread in which the websocket session runs.
        """

        async def run_app(tg: anyio.abc.TaskGroup) -> None:
            try:
                await self.app(self.scope, self._asgi_receive, self._asgi_send)
            except anyio.get_cancelled_exc_class():
                ...
            except BaseException as exc:
                # Hand the failure to the test side, which raises it from
                # ``receive()`` / ``__exit__``.
                self._send_queue.put(exc)
                raise
            finally:
                tg.cancel_scope.cancel()

        async with anyio.create_task_group() as tg:
            tg.start_soon(run_app, tg)
            await self.should_close.wait()
            tg.cancel_scope.cancel()

    async def _asgi_receive(self) -> Message:
        # Poll the queue, yielding to the event loop between checks so the
        # blocking ``queue.Queue`` never stalls it.
        while self._receive_queue.empty():
            await anyio.sleep(0)
        return self._receive_queue.get()

    async def _asgi_send(self, message: Message) -> None:
        self._send_queue.put(message)

    def _raise_on_close(self, message: Message) -> None:
        """
        Raise if ``message`` signals that the app closed or denied the socket.

        Raises:
            WebSocketDisconnect: for a ``websocket.close`` message.
            WebSocketDenialResponse: for a ``websocket.http.response.start``
                message; the following body messages are read and joined.
        """
        if message["type"] == "websocket.close":
            raise WebSocketDisconnect(
                code=message.get("code", 1000), reason=message.get("reason", "")
            )
        elif message["type"] == "websocket.http.response.start":
            status_code: int = message["status"]
            # ``headers`` is optional in the denial-response extension and
            # defaults to an empty list, so do not index it directly.
            headers: list[tuple[bytes, bytes]] = message.get("headers", [])
            body: list[bytes] = []
            while True:
                message = self.receive()
                assert message["type"] == "websocket.http.response.body"
                # ``body`` is optional as well and defaults to empty bytes.
                body.append(message.get("body", b""))
                if not message.get("more_body", False):
                    break
            raise WebSocketDenialResponse(
                status_code=status_code,
                headers=headers,
                content=b"".join(body),
            )

    def send(self, message: Message) -> None:
        """Queue a raw ASGI message for the app to receive."""
        self._receive_queue.put(message)

    def send_text(self, data: str) -> None:
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(
        self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text"
    ) -> None:
        """Serialize ``data`` compactly and send it as a text or binary frame."""
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            self.send({"type": "websocket.receive", "text": text})
        else:
            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """Block for the next message from the app; re-raise an app exception."""
        message = self._send_queue.get()
        if isinstance(message, BaseException):
            raise message
        return message

    def receive_text(self) -> str:
        message = self.receive()
        self._raise_on_close(message)
        return typing.cast(str, message["text"])

    def receive_bytes(self) -> bytes:
        message = self.receive()
        self._raise_on_close(message)
        return typing.cast(bytes, message["bytes"])

    def receive_json(
        self, mode: typing.Literal["text", "binary"] = "text"
    ) -> typing.Any:
        """Receive one frame and decode it as JSON (binary frames are UTF-8)."""
        message = self.receive()
        self._raise_on_close(message)
        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

237 

238 

class _TestClientTransport(httpx.BaseTransport):
    """
    httpx transport that dispatches requests directly to an ASGI app.

    HTTP requests are run to completion through a blocking portal and turned
    into an ``httpx.Response``; ``ws``/``wss`` requests raise ``_Upgrade``
    carrying a ``WebSocketTestSession`` instead of returning a response.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        app_state: dict[str, typing.Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        self.app_state = app_state

    @staticmethod
    def _split_netloc(netloc: str, default_port: int) -> tuple[str, int]:
        """Split ``host[:port]`` into ``(host, port)``, handling ``[ipv6]`` literals."""
        if netloc.startswith("["):
            # Bracketed IPv6 literal: the colons inside the brackets are part
            # of the address, only a colon after "]" introduces the port.
            host, _, remainder = netloc[1:].partition("]")
            if remainder.startswith(":") and remainder[1:]:
                return host, int(remainder[1:])
            return host, default_port
        if ":" in netloc:
            host, port_string = netloc.split(":", 1)
            return host, int(port_string)
        return netloc, default_port

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """
        Run ``request`` against the app and build the response.

        Raises:
            _Upgrade: for ``ws``/``wss`` URLs, carrying the websocket session.
            BaseException: whatever the app raised, when
                ``raise_server_exceptions`` is true.
        """
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
        host, port = self._split_netloc(netloc, default_port)

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        else:  # pragma: no cover
            # IPv6 literals must be re-bracketed inside a Host header value.
            host_value = f"[{host}]" if ":" in host else host
            if port != default_port:
                host_value = f"{host_value}:{port}"
            headers = [(b"host", host_value.encode())]

        # Include other request headers.
        headers += [
            (key.lower().encode(), value.encode())
            for key, value in request.headers.multi_items()
        ]

        scope: dict[str, typing.Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: typing.Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path,
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": ["testclient", 50000],
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            # Not an error: this is how the session is handed back to
            # ``TestClient.websocket_connect`` through httpx.
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path,
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": ["testclient", 50000],
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        # Bound inside the portal below, before the app is called.
        response_complete: anyio.Event
        raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            nonlocal request_complete

            if request_complete:
                # The body has been delivered; only report a disconnect once
                # the response has been completed.
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert (
                    not response_started
                ), 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [
                    (key.decode(), value.decode())
                    for key, value in message.get("headers", [])
                ]
                response_started = True
            elif message["type"] == "http.response.body":
                assert (
                    response_started
                ), 'Received "http.response.body" without "http.response.start".'
                assert (
                    not response_complete.is_set()
                ), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                # Body bytes sent in answer to a HEAD request are dropped.
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    # Rewind so the accumulated body can be read back below.
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            # Deliberately swallowed when the client was configured not to
            # raise server exceptions; a response is synthesised below.
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response

416 

417 

class TestClient(httpx.Client):
    """
    Synchronous ``httpx.Client`` that talks to an ASGI app in-process.

    Requests are routed through ``_TestClientTransport``. Used as a context
    manager, the client additionally runs the app's lifespan protocol
    (startup on enter, shutdown on exit) on a shared blocking portal.
    """

    # Keeps pytest from collecting this class as a test case because of its
    # ``Test`` prefix.
    __test__ = False
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: typing.Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, typing.Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
    ) -> None:
        self.async_backend = _AsyncBackend(
            backend=backend, backend_options=backend_options or {}
        )
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = typing.cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        self.app_state: dict[str, typing.Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
        )
        # Work on a copy so the default user-agent is never written into the
        # caller's dictionary.
        headers = {} if headers is None else dict(headers)
        headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
        """Yield the client's active portal, or a temporary one if none is open."""
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(
                **self.async_backend
            ) as portal:
                yield portal

    def _choose_redirect_arg(
        self, follow_redirects: bool | None, allow_redirects: bool | None
    ) -> bool | httpx._client.UseClientDefault:
        """
        Resolve the redirect setting from the current and deprecated arguments.

        Returns ``follow_redirects`` when given, else the deprecated
        ``allow_redirects`` (with a ``DeprecationWarning``), else the
        "use client default" sentinel.

        Raises:
            RuntimeError: if both arguments are supplied.
        """
        # This check must come first: it was previously attached as an
        # ``elif`` to ``follow_redirects is not None`` and could never fire.
        if allow_redirects is not None and follow_redirects is not None:
            raise RuntimeError(  # pragma: no cover
                "Cannot use both `allow_redirects` and `follow_redirects`."
            )
        if allow_redirects is not None:
            message = (
                "The `allow_redirects` argument is deprecated. "
                "Use `follow_redirects` instead."
            )
            warnings.warn(message, DeprecationWarning)
            return allow_redirects
        if follow_redirects is not None:
            return follow_redirects
        return httpx._client.USE_CLIENT_DEFAULT

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a request, accepting the deprecated ``allow_redirects`` alias."""
        url = self._merge_url(url)
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: typing.Sequence[str] | None = None,
        **kwargs: typing.Any,
    ) -> WebSocketTestSession:
        """
        Open a websocket to ``url`` and return the (not yet entered) session.

        Handshake headers are filled in with defaults unless the caller
        supplied them; the remaining ``kwargs`` go to ``httpx.Client.request``.
        """
        url = urljoin("ws://testserver", url)
        # Copy before adding defaults so the caller's headers object is left
        # untouched; an explicit ``headers=None`` is treated as "no headers".
        supplied = kwargs.get("headers")
        headers = {} if supplied is None else supplied.copy()
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            # The transport signals a websocket request by raising ``_Upgrade``.
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> TestClient:
        """Open a shared portal and run the app's lifespan startup."""
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(
                anyio.from_thread.start_blocking_portal(**self.async_backend)
            )

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None]
            receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None]
            send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
            receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
            send1, receive1 = anyio.create_memory_object_stream(math.inf)
            send2, receive2 = anyio.create_memory_object_stream(math.inf)
            # ``stream_send`` carries app -> client lifespan messages,
            # ``stream_receive`` carries client -> app ones.
            self.stream_send = StapledObjectStream(send1, receive1)
            self.stream_receive = StapledObjectStream(send2, receive2)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            # Startup succeeded: keep the callbacks alive for ``__exit__``
            # instead of running them when this ``with`` block ends.
            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: typing.Any) -> None:
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the app with a lifespan scope; always signal completion with ``None``."""
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def _receive_lifespan_message(self) -> typing.Any:
        """Return the next lifespan message sent by the app."""
        message = await self.stream_send.receive()
        if message is None:
            # ``None`` marks the end of the lifespan task; ``result()``
            # re-raises the exception the task failed with, if any.
            self.task.result()
        return message

    async def wait_startup(self) -> None:
        """Send ``lifespan.startup`` and wait for the app's answer."""
        await self.stream_receive.send({"type": "lifespan.startup"})

        message = await self._receive_lifespan_message()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            await self._receive_lifespan_message()

    async def wait_shutdown(self) -> None:
        """Send ``lifespan.shutdown``, wait for the answer and close the stream."""
        async with self.stream_send:
            await self.stream_receive.send({"type": "lifespan.shutdown"})
            message = await self._receive_lifespan_message()
            assert message["type"] in (
                "lifespan.shutdown.complete",
                "lifespan.shutdown.failed",
            )
            if message["type"] == "lifespan.shutdown.failed":
                await self._receive_lifespan_message()