Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/testclient.py: 23%

324 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 06:12 +0000

1import contextlib 

2import inspect 

3import io 

4import json 

5import math 

6import queue 

7import sys 

8import typing 

9import warnings 

10from concurrent.futures import Future 

11from types import GeneratorType 

12from urllib.parse import unquote, urljoin 

13 

14import anyio 

15import anyio.from_thread 

16import httpx 

17from anyio.streams.stapled import StapledObjectStream 

18 

19from starlette._utils import is_async_callable 

20from starlette.types import ASGIApp, Message, Receive, Scope, Send 

21from starlette.websockets import WebSocketDisconnect 

22 

23if sys.version_info >= (3, 8): # pragma: no cover 

24 from typing import TypedDict 

25else: # pragma: no cover 

26 from typing_extensions import TypedDict 

27 

# Zero-argument callable returning a context manager that yields an anyio
# blocking portal; the portal is what lets synchronous test code call into
# the async ASGI app.
_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]

# ASGI3 application: one async callable taking (scope, receive, send).
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]

# ASGI2 application ("double callable"): calling it with a scope returns an
# instance that is then awaited with (receive, send).
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]

# Type of the ``data=`` argument accepted by the request helpers: each key maps
# to a single string or an iterable of strings.
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]

38 

39 

def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
    """Return whether *app* should be treated as an ASGI3 (single-callable) app.

    A class counts as ASGI3 only when it defines ``__await__``; any other
    object counts as ASGI3 when ``is_async_callable`` accepts it.
    """
    return hasattr(app, "__await__") if inspect.isclass(app) else is_async_callable(app)

44 

45 

class _WrapASGI2:
    """Adapter that exposes an ASGI2 (double-callable) app as an ASGI3 app."""

    def __init__(self, app: ASGI2App) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # ASGI2: call the app with the scope alone, then await the returned
        # instance with the receive/send channels.
        await self.app(scope)(receive, send)

57 

58 

59class _AsyncBackend(TypedDict): 

60 backend: str 

61 backend_options: typing.Dict[str, typing.Any] 

62 

63 

64class _Upgrade(Exception): 

65 def __init__(self, session: "WebSocketTestSession") -> None: 

66 self.session = session 

67 

68 

class WebSocketTestSession:
    """Synchronous handle on a websocket connection to an in-process ASGI app.

    The app runs as a task on a blocking portal obtained from
    ``portal_factory``. Two ``queue.Queue`` objects carry traffic between the
    test thread and the app:

    * ``_receive_queue`` holds messages the test sends *to* the app;
    * ``_send_queue`` holds messages the app sends *back*, plus any exception
      the app raised.

    Entering the context manager performs the connect handshake. Exiting
    sends a disconnect (code 1000) and closes the portal context.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        self.portal_factory = portal_factory
        # Populated by __enter__ from the app's reply to websocket.connect.
        self.accepted_subprotocol = None
        self.extra_headers = None
        # Test -> app messages.
        self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
        # App -> test messages, plus exceptions raised by the app.
        self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()

    def __enter__(self) -> "WebSocketTestSession":
        """Start the app task, send ``websocket.connect`` and wait for the reply.

        Raises:
            WebSocketDisconnect: if the app replies with ``websocket.close``.
            BaseException: whatever the app raised before replying.
        """
        self.exit_stack = contextlib.ExitStack()
        self.portal = self.exit_stack.enter_context(self.portal_factory())

        try:
            self.portal.start_task_soon(self._run)
            self.send({"type": "websocket.connect"})
            reply = self._receive_checked()
        except Exception:
            # Release the portal if the handshake fails.
            self.exit_stack.close()
            raise
        self.accepted_subprotocol = reply.get("subprotocol", None)
        self.extra_headers = reply.get("headers", None)
        return self

    def __exit__(self, *args: typing.Any) -> None:
        """Send a code-1000 disconnect, close the portal context, then re-raise
        the first exception still waiting in the outbound queue, if any."""
        try:
            self.close(1000)
        finally:
            self.exit_stack.close()
        while not self._send_queue.empty():
            leftover = self._send_queue.get()
            if isinstance(leftover, BaseException):
                raise leftover

    async def _run(self) -> None:
        """Run the ASGI app on the portal's event loop.

        An exception from the app is queued for the test side (so ``receive``
        re-raises it) and then propagated.
        """
        scope, receive, send = self.scope, self._asgi_receive, self._asgi_send
        try:
            await self.app(scope, receive, send)
        except BaseException as exc:
            self._send_queue.put(exc)
            raise

    async def _asgi_receive(self) -> Message:
        # Poll the queue and yield to the event loop between checks. By the
        # time get() runs the queue is non-empty, so the loop is not blocked.
        while self._receive_queue.empty():
            await anyio.sleep(0)
        return self._receive_queue.get()

    async def _asgi_send(self, message: Message) -> None:
        self._send_queue.put(message)

    def _raise_on_close(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if *message* is a ``websocket.close``."""
        if message["type"] != "websocket.close":
            return
        raise WebSocketDisconnect(
            message.get("code", 1000), message.get("reason", "")
        )

    def _receive_checked(self) -> Message:
        """Receive the next app message and raise if it closes the socket."""
        message = self.receive()
        self._raise_on_close(message)
        return message

    def send(self, message: Message) -> None:
        """Queue a raw ASGI message for the app."""
        self._receive_queue.put(message)

    def send_text(self, data: str) -> None:
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """Serialize *data* to JSON and send it as a text or UTF-8 bytes frame."""
        assert mode in ["text", "binary"]
        encoded = json.dumps(data)
        if mode == "text":
            message = {"type": "websocket.receive", "text": encoded}
        else:
            message = {"type": "websocket.receive", "bytes": encoded.encode("utf-8")}
        self.send(message)

    def close(self, code: int = 1000) -> None:
        self.send({"type": "websocket.disconnect", "code": code})

    def receive(self) -> Message:
        """Block until the app sends a message; re-raise an app exception."""
        item = self._send_queue.get()
        if isinstance(item, BaseException):
            raise item
        return item

    def receive_text(self) -> str:
        return self._receive_checked()["text"]

    def receive_bytes(self) -> bytes:
        return self._receive_checked()["bytes"]

    def receive_json(self, mode: str = "text") -> typing.Any:
        """Receive a frame and decode it as JSON (text or UTF-8 bytes)."""
        assert mode in ["text", "binary"]
        message = self._receive_checked()
        raw = message["text"] if mode == "text" else message["bytes"].decode("utf-8")
        return json.loads(raw)

182 

183 

class _TestClientTransport(httpx.BaseTransport):
    """httpx transport that serves requests from an in-process ASGI app.

    ``http``/``https`` requests are executed by calling the app through a
    blocking portal and assembling an ``httpx.Response`` from the
    ``http.response.*`` messages the app sends. ``ws``/``wss`` requests are
    not executed here. Instead a ``WebSocketTestSession`` is built and handed
    to the caller by raising ``_Upgrade``.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Dispatch *request* to the ASGI app and return the resulting response.

        Raises:
            _Upgrade: for ``ws``/``wss`` URLs, carrying the websocket session.
            KeyError: for URL schemes other than http/https/ws/wss.
            ValueError: if the text after the first ``:`` in the netloc is not
                an integer port. NOTE(review): bracketed IPv6 hosts such as
                ``[::1]:8000`` hit this.
            BaseException: whatever the app raised, when
                ``raise_server_exceptions`` is true. Otherwise a failure before
                ``http.response.start`` yields an empty 500 response.
            AssertionError: when ``raise_server_exceptions`` is true and the app
                never sent ``http.response.start``.
        """
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]

        if ":" in netloc:
            host, port_string = netloc.split(":", 1)
            port = int(port_string)
        else:
            host = netloc
            port = default_port

        # Include the 'host' header.
        if "host" in request.headers:
            headers: typing.List[typing.Tuple[bytes, bytes]] = []
        elif port == default_port:  # pragma: no cover
            headers = [(b"host", host.encode())]
        else:  # pragma: no cover
            headers = [(b"host", (f"{host}:{port}").encode())]

        # Include other request headers.
        headers += [
            (key.lower().encode(), value.encode())
            for key, value in request.headers.items()
        ]

        scope: typing.Dict[str, typing.Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: typing.Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path,
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": ["testclient", 50000],
                "server": [host, port],
                "subprotocols": subprotocols,
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path,
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": ["testclient", 50000],
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
        }

        request_complete = False
        response_started = False
        # Set once the final body chunk arrives. It is created on the portal's
        # event loop below, before the app is called.
        response_complete: anyio.Event
        # Keyword arguments for httpx.Response. "stream" accumulates the body.
        raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
        # Filled from an "http.response.debug" message, if the app sends one.
        template = None
        context = None

        async def receive() -> Message:
            """ASGI receive channel.

            The first call delivers the request body. After that, each call
            waits for the response to complete and then reports a disconnect.
            """
            nonlocal request_complete

            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            """ASGI send channel.

            Records the status and headers, accumulates body chunks (skipped
            for HEAD), and captures template/context from debug messages.
            """
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert (
                    not response_started
                ), 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [
                    (key.decode(), value.decode())
                    for key, value in message.get("headers", [])
                ]
                response_started = True
            elif message["type"] == "http.response.body":
                assert (
                    response_started
                ), 'Received "http.response.body" without "http.response.start".'
                assert (
                    not response_complete.is_set()
                ), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            if self.raise_server_exceptions:
                # A bare raise keeps the original traceback intact.
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        # Use getvalue(), not read(). The buffer is only rewound after the final
        # body chunk, so read() would return b"" for a partially streamed body,
        # for example when the app failed mid-stream with
        # raise_server_exceptions=False.
        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].getvalue())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response

355 

356 

class TestClient(httpx.Client):
    """Synchronous httpx client that sends requests to an in-process ASGI app.

    HTTP requests go through ``_TestClientTransport``. Websocket sessions are
    opened with :meth:`websocket_connect`. When used as a context manager, the
    client keeps one blocking portal open for its whole lifetime and runs the
    app's lifespan startup on enter and shutdown on exit.
    """

    # ``__test__ = False`` stops pytest from collecting this class as a test.
    __test__ = False
    # Future of the lifespan task; assigned in __enter__.
    task: "Future[None]"
    # Portal shared by all requests while inside ``with TestClient(...)``;
    # None otherwise.
    portal: typing.Optional[anyio.abc.BlockingPortal] = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: str = "asyncio",
        backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
        cookies: httpx._client.CookieTypes = None,
        headers: typing.Optional[typing.Dict[str, str]] = None,
    ) -> None:
        """Create a client for *app* (ASGI2 apps are wrapped to ASGI3).

        *backend* and *backend_options* select the anyio backend for the
        blocking portal. *headers* defaults to sending
        ``user-agent: testclient``. The caller's mapping is not modified.
        """
        self.async_backend = _AsyncBackend(
            backend=backend, backend_options=backend_options or {}
        )
        if _is_asgi3(app):
            app = typing.cast(ASGI3App, app)
            asgi_app = app
        else:
            app = typing.cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
        )
        # Copy before setdefault so the caller's headers mapping is not mutated.
        headers = {} if headers is None else headers.copy()
        headers.setdefault("user-agent", "testclient")
        super().__init__(
            app=self.app,
            base_url=base_url,
            headers=headers,
            transport=transport,
            follow_redirects=True,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
        """Yield the shared portal if one is open, else a temporary portal."""
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(
                **self.async_backend
            ) as portal:
                yield portal

    def _choose_redirect_arg(
        self,
        follow_redirects: typing.Optional[bool],
        allow_redirects: typing.Optional[bool],
    ) -> typing.Union[bool, httpx._client.UseClientDefault]:
        """Resolve the redirect policy from ``follow_redirects`` and the
        deprecated ``allow_redirects`` alias.

        Returns the explicit value that was given, or
        ``httpx._client.USE_CLIENT_DEFAULT`` when neither was given.

        Raises:
            RuntimeError: if both arguments are given.
        """
        # Check the conflict first. The previous version tested it in an
        # ``elif`` that could never be reached.
        if allow_redirects is not None and follow_redirects is not None:
            raise RuntimeError(
                "Cannot use both `allow_redirects` and `follow_redirects`."
            )
        redirect: typing.Union[
            bool, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT
        if allow_redirects is not None:
            message = (
                "The `allow_redirects` argument is deprecated. "
                "Use `follow_redirects` instead."
            )
            warnings.warn(message, DeprecationWarning)
            redirect = allow_redirects
        if follow_redirects is not None:
            redirect = follow_redirects
        return redirect

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: typing.Optional[httpx._types.RequestContent] = None,
        data: typing.Optional[_RequestData] = None,
        files: typing.Optional[httpx._types.RequestFiles] = None,
        json: typing.Any = None,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a request; *url* is joined onto ``base_url``."""
        url = self.base_url.join(url)
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().request(
            method,
            url,
            content=content,
            data=data,  # type: ignore[arg-type]
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a GET request (``allow_redirects`` is a deprecated alias)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send an OPTIONS request (``allow_redirects`` is a deprecated alias)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a HEAD request (``allow_redirects`` is a deprecated alias)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: typing.Optional[httpx._types.RequestContent] = None,
        data: typing.Optional[_RequestData] = None,
        files: typing.Optional[httpx._types.RequestFiles] = None,
        json: typing.Any = None,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a POST request (``allow_redirects`` is a deprecated alias)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().post(
            url,
            content=content,
            data=data,  # type: ignore[arg-type]
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: typing.Optional[httpx._types.RequestContent] = None,
        data: typing.Optional[_RequestData] = None,
        files: typing.Optional[httpx._types.RequestFiles] = None,
        json: typing.Any = None,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a PUT request (``allow_redirects`` is a deprecated alias)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().put(
            url,
            content=content,
            data=data,  # type: ignore[arg-type]
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: typing.Optional[httpx._types.RequestContent] = None,
        data: typing.Optional[_RequestData] = None,
        files: typing.Optional[httpx._types.RequestFiles] = None,
        json: typing.Any = None,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a PATCH request (``allow_redirects`` is a deprecated alias)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().patch(
            url,
            content=content,
            data=data,  # type: ignore[arg-type]
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a DELETE request (``allow_redirects`` is a deprecated alias)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: typing.Optional[typing.Sequence[str]] = None,
        **kwargs: typing.Any,
    ) -> typing.Any:
        """Create a websocket session for *url*, resolved against ``ws://testserver``.

        The upgrade request goes through the transport, which answers by
        raising ``_Upgrade``. The session it carries is returned without being
        entered; use it as a context manager to perform the handshake. The
        caller's ``headers`` mapping (if any) is not modified.

        Raises:
            RuntimeError: if the request unexpectedly completes without an
                upgrade.
        """
        url = urljoin("ws://testserver", url)
        supplied_headers = kwargs.get("headers")
        # Copy so the upgrade defaults below don't leak into the caller's mapping.
        headers = {} if supplied_headers is None else supplied_headers.copy()
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> "TestClient":
        """Open a shared blocking portal and run the app's lifespan startup.

        The portal stays open, and ``_portal_factory`` reuses it for every
        request, until ``__exit__`` runs lifespan shutdown and closes it.
        """
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(
                anyio.from_thread.start_blocking_portal(**self.async_backend)
            )

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            # Unbounded in-memory streams between the test thread and the
            # lifespan task: stream_receive carries messages to the app,
            # stream_send carries the app's replies.
            self.stream_send = StapledObjectStream(
                *anyio.create_memory_object_stream(math.inf)
            )
            self.stream_receive = StapledObjectStream(
                *anyio.create_memory_object_stream(math.inf)
            )
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            # Detach the callbacks so they run in __exit__ rather than here. If
            # startup raised above, the with-block unwinds them immediately.
            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: typing.Any) -> None:
        """Run lifespan shutdown and release the portal opened in __enter__."""
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the app with a ``lifespan`` scope.

        A ``None`` sentinel is always sent on ``stream_send`` afterwards, so
        waiters can tell that the app has returned or raised.
        """
        scope = {"type": "lifespan"}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        """Send ``lifespan.startup`` and wait for the app's reply."""
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                # The lifespan task has finished. Consulting its Future
                # re-raises any exception it ended with.
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            # Wait for the next message, presumably the None sentinel, so the
            # task's exception surfaces.
            await receive()

    async def wait_shutdown(self) -> None:
        """Send ``lifespan.shutdown``, wait for the reply, then close the reply stream."""

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                # The lifespan task has finished. Consulting its Future
                # re-raises any exception it ended with.
                self.task.result()
            return message

        async with self.stream_send:
            await self.stream_receive.send({"type": "lifespan.shutdown"})
            message = await receive()
            assert message["type"] in (
                "lifespan.shutdown.complete",
                "lifespan.shutdown.failed",
            )
            if message["type"] == "lifespan.shutdown.failed":
                await receive()