1from __future__ import annotations
2
3import contextlib
4import inspect
5import io
6import json
7import math
8import queue
9import sys
10import typing
11import warnings
12from concurrent.futures import Future
13from functools import cached_property
14from types import GeneratorType
15from urllib.parse import unquote, urljoin
16
17import anyio
18import anyio.abc
19import anyio.from_thread
20from anyio.abc import ObjectReceiveStream, ObjectSendStream
21from anyio.streams.stapled import StapledObjectStream
22
23from starlette._utils import is_async_callable
24from starlette.types import ASGIApp, Message, Receive, Scope, Send
25from starlette.websockets import WebSocketDisconnect
26
27if sys.version_info >= (3, 10): # pragma: no cover
28 from typing import TypeGuard
29else: # pragma: no cover
30 from typing_extensions import TypeGuard
31
32try:
33 import httpx
34except ModuleNotFoundError: # pragma: no cover
35 raise RuntimeError(
36 "The starlette.testclient module requires the httpx package to be installed.\n"
37 "You can install this with:\n"
38 " $ pip install httpx\n"
39 )
# Zero-argument factory returning a context manager that yields a blocking portal.
_PortalFactoryType = typing.Callable[
    [], typing.ContextManager[anyio.abc.BlockingPortal]
]

# Callable shapes of ASGI applications: an ASGI2 app is called with the scope and
# returns an instance that takes (receive, send); an ASGI3 app takes all three.
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]


# Type of the `data` argument accepted by the request helpers on `TestClient`.
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]]
50
51
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
    """
    Tell whether `app` should be driven through the ASGI3 interface.

    A class counts as ASGI3 when it defines `__await__`; any other object is
    ASGI3 when `is_async_callable` says it is an async callable.
    """
    if not inspect.isclass(app):
        return is_async_callable(app)
    return hasattr(app, "__await__")
56
57
class _WrapASGI2:
    """
    Adapter exposing an ASGI2 application through the ASGI3 call signature.
    """

    def __init__(self, app: ASGI2App) -> None:
        # The wrapped ASGI2 application: calling it with a scope returns the
        # instance that is then awaited with `receive` and `send`.
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope)(receive, send)
69
70
class _AsyncBackend(typing.TypedDict):
    """
    Keyword arguments unpacked into `anyio.from_thread.start_blocking_portal()`.
    """

    # Name of the async backend the portal runs on.
    backend: str
    # Extra options passed along for that backend.
    backend_options: dict[str, typing.Any]
74
75
class _Upgrade(Exception):
    """
    Control-flow exception carrying a prepared WebSocket test session.

    The transport raises it for `ws`/`wss` requests and `websocket_connect()`
    catches it to retrieve the session.
    """

    def __init__(self, session: WebSocketTestSession) -> None:
        self.session = session
79
80
class WebSocketDenialResponse(  # type: ignore[misc]
    httpx.Response,
    WebSocketDisconnect,
):
    """
    An `httpx.Response` that is also a `WebSocketDisconnect`.

    The `TestClient` raises it when the `WebSocket` is closed before being
    accepted by means of a `send_denial_response()`, so the test can inspect
    the denial like any other response.
    """
89
90
class WebSocketTestSession:
    """
    Synchronous handle on a WebSocket connection to the application under test.

    The application is started as a task on a blocking portal, and the calling
    code exchanges ASGI messages with it through two `queue.Queue` objects.
    Used as a context manager: entering performs the connect handshake, and
    leaving sends a disconnect and tells the application task to stop.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        # Filled in from the application's first reply in `__enter__`.
        self.accepted_subprotocol = None
        self.portal_factory = portal_factory
        # Messages travelling from the test code to the application.
        self._receive_queue: queue.Queue[Message] = queue.Queue()
        # Messages travelling back, or the exception that ended the application.
        self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
        # Filled in from the application's first reply in `__enter__`.
        self.extra_headers = None

    def __enter__(self) -> WebSocketTestSession:
        """
        Start the application and perform the connect handshake.

        Raises `WebSocketDisconnect` (or `WebSocketDenialResponse`) when the
        application closes or denies the connection instead of accepting it,
        and re-raises any exception the application itself raised.
        """
        self.exit_stack = contextlib.ExitStack()
        self.portal = self.exit_stack.enter_context(self.portal_factory())

        try:
            _: Future[None] = self.portal.start_task_soon(self._run)
            self.send({"type": "websocket.connect"})
            message = self.receive()
            self._raise_on_close(message)
        except Exception:
            # Release the portal before propagating: `__exit__` will not run.
            self.exit_stack.close()
            raise
        self.accepted_subprotocol = message.get("subprotocol", None)
        self.extra_headers = message.get("headers", None)
        return self

    @cached_property
    def should_close(self) -> anyio.Event:
        """Event that, once set, makes `_run` cancel the application task."""
        return anyio.Event()

    async def _notify_close(self) -> None:
        """Set `should_close`; scheduled on the portal from `__exit__`."""
        self.should_close.set()

    def __exit__(self, *args: typing.Any) -> None:
        """
        Send a normal-closure disconnect, stop the application task and
        re-raise the first exception still waiting in the outgoing queue.
        """
        try:
            self.close(1000)
        finally:
            self.portal.start_task_soon(self._notify_close)
            self.exit_stack.close()
        # Drain what the application left behind so its failure is not lost.
        while not self._send_queue.empty():
            message = self._send_queue.get()
            if isinstance(message, BaseException):
                raise message

    async def _run(self) -> None:
        """
        The sub-thread in which the websocket session runs.

        Runs the application in a task group and cancels it as soon as either
        the application returns or `should_close` is set.
        """

        async def run_app(tg: anyio.abc.TaskGroup) -> None:
            try:
                await self.app(self.scope, self._asgi_receive, self._asgi_send)
            except anyio.get_cancelled_exc_class():
                ...
            except BaseException as exc:
                # Hand the failure to the test code, which is blocked in
                # `receive()` or will drain the queue in `__exit__`.
                self._send_queue.put(exc)
                raise
            finally:
                tg.cancel_scope.cancel()

        async with anyio.create_task_group() as tg:
            tg.start_soon(run_app, tg)
            await self.should_close.wait()
            tg.cancel_scope.cancel()

    async def _asgi_receive(self) -> Message:
        """ASGI `receive` callable: wait for the next message from the test."""
        # The queue is a blocking `queue.Queue`, so poll it and yield to the
        # event loop between checks instead of blocking in `get()`.
        while self._receive_queue.empty():
            await anyio.sleep(0)
        return self._receive_queue.get()

    async def _asgi_send(self, message: Message) -> None:
        """ASGI `send` callable: queue the application's message for the test."""
        self._send_queue.put(message)

    def _raise_on_close(self, message: Message) -> None:
        """
        Turn a close or denial message into the matching exception.

        Raises `WebSocketDisconnect` for `websocket.close`, and
        `WebSocketDenialResponse` for `websocket.http.response.start` after
        collecting all following body messages.  Other messages are ignored.
        """
        if message["type"] == "websocket.close":
            raise WebSocketDisconnect(
                code=message.get("code", 1000), reason=message.get("reason", "")
            )
        elif message["type"] == "websocket.http.response.start":
            status_code: int = message["status"]
            # Tolerate a start message without `headers` instead of failing
            # with a KeyError.
            headers: list[tuple[bytes, bytes]] = message.get("headers", [])
            body: list[bytes] = []
            while True:
                message = self.receive()
                assert message["type"] == "websocket.http.response.body"
                body.append(message["body"])
                if not message.get("more_body", False):
                    break
            raise WebSocketDenialResponse(
                status_code=status_code,
                headers=headers,
                content=b"".join(body),
            )

    def send(self, message: Message) -> None:
        """Queue a raw ASGI message for the application to receive."""
        self._receive_queue.put(message)

    def send_text(self, data: str) -> None:
        """Send `data` to the application as a text frame."""
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        """Send `data` to the application as a binary frame."""
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(
        self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text"
    ) -> None:
        """
        Send `data` serialized as compact JSON.

        With `mode="text"` the JSON goes out as a text frame; any other mode
        sends it UTF-8 encoded as a binary frame.
        """
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            self.send({"type": "websocket.receive", "text": text})
        else:
            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Tell the application that the client disconnected."""
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """
        Block until the application sends a message and return it.

        If the application task ended with an exception, that exception is
        raised here instead.
        """
        message = self._send_queue.get()
        if isinstance(message, BaseException):
            raise message
        return message

    def receive_text(self) -> str:
        """Receive the next message and return its `text` payload."""
        message = self.receive()
        self._raise_on_close(message)
        return typing.cast(str, message["text"])

    def receive_bytes(self) -> bytes:
        """Receive the next message and return its `bytes` payload."""
        message = self.receive()
        self._raise_on_close(message)
        return typing.cast(bytes, message["bytes"])

    def receive_json(
        self, mode: typing.Literal["text", "binary"] = "text"
    ) -> typing.Any:
        """
        Receive the next message and decode its payload as JSON.

        With `mode="text"` the `text` payload is parsed; any other mode decodes
        the `bytes` payload as UTF-8 first.
        """
        message = self.receive()
        self._raise_on_close(message)
        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)
237
238
class _TestClientTransport(httpx.BaseTransport):
    """
    httpx transport that calls the ASGI application directly.

    HTTP requests are translated into an ASGI `http` scope and run to
    completion on a blocking portal.  Requests with a `ws`/`wss` scheme are
    not run: a `WebSocketTestSession` is prepared and handed back to the
    caller by raising `_Upgrade`.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        app_state: dict[str, typing.Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        # A shallow copy of this mapping is placed in every scope as "state".
        self.app_state = app_state

    @staticmethod
    def _split_netloc(netloc: str, default_port: int) -> tuple[str, int]:
        """Split `host[:port]` into `(host, port)`, understanding `[ipv6]` hosts."""
        if netloc.startswith("["):
            # Bracketed IPv6 literal: colons inside the brackets belong to the
            # address; only a colon right after "]" introduces a port.
            host, _, remainder = netloc[1:].partition("]")
            port_string = remainder[1:] if remainder.startswith(":") else ""
            return host, int(port_string) if port_string else default_port
        if ":" in netloc:
            host, port_string = netloc.split(":", 1)
            return host, int(port_string)
        return netloc, default_port

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """
        Run `request` against the application and build the response.

        Raises `_Upgrade` for WebSocket schemes.  Exceptions raised by the
        application propagate when `raise_server_exceptions` is true;
        otherwise they are swallowed and, if no response was started, a bare
        500 response is returned.
        """
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
        host, port = self._split_netloc(netloc, default_port)

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        else:  # pragma: no cover
            # An IPv6 literal needs its brackets back inside a Host header.
            host_header = f"[{host}]" if ":" in host else host
            if port != default_port:
                host_header = f"{host_header}:{port}"
            headers = [(b"host", host_header.encode())]

        # Include other request headers.
        headers += [
            (key.lower().encode(), value.encode())
            for key, value in request.headers.multi_items()
        ]

        scope: dict[str, typing.Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: typing.Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path,
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": ["testclient", 50000],
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            # Not an error: this is how the session reaches `websocket_connect`.
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path,
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": ["testclient", 50000],
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        # Created on the portal below, before the application is called.
        response_complete: anyio.Event
        raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            """ASGI `receive`: the request body once, then a disconnect."""
            nonlocal request_complete

            if request_complete:
                # The disconnect is only reported after the response is done.
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            """ASGI `send`: collect status, headers, body and debug info."""
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert (
                    not response_started
                ), 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [
                    (key.decode(), value.decode())
                    for key, value in message.get("headers", [])
                ]
                response_started = True
            elif message["type"] == "http.response.body":
                assert (
                    response_started
                ), 'Received "http.response.body" without "http.response.start".'
                assert (
                    not response_complete.is_set()
                ), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                # Body chunks sent in reply to a HEAD request are discarded.
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    # Rewind so the collected body can be read back below.
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response
416
417
class TestClient(httpx.Client):
    """
    An `httpx.Client` whose requests are served by an ASGI application.

    Requests go through `_TestClientTransport` instead of the network.  Used
    as a context manager, the client also runs the application's lifespan:
    startup on entry and shutdown on exit, on one blocking portal that is then
    shared by every request made inside the `with` block.
    """

    # NOTE(review): presumably keeps pytest from collecting this class because
    # of its "Test" prefix - confirm against the pytest documentation.
    __test__ = False
    task: Future[None]
    # Set while inside `with client:`; `_portal_factory` reuses it when present.
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: typing.Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, typing.Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
    ) -> None:
        """
        Wrap `app` (ASGI2 or ASGI3) in a client.

        `headers` are default headers for every request; a `user-agent` of
        `testclient` is added unless one is supplied.  The mapping passed in is
        copied, never modified.
        """
        self.async_backend = _AsyncBackend(
            backend=backend, backend_options=backend_options or {}
        )
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = typing.cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        self.app_state: dict[str, typing.Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
        )
        # Work on a copy so the caller's mapping is not mutated by the default.
        headers = {} if headers is None else dict(headers)
        headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
        """
        Yield the client's portal if one is active, else a temporary one that
        lives only for the duration of the `with` block.
        """
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(
                **self.async_backend
            ) as portal:
                yield portal

    def _choose_redirect_arg(
        self, follow_redirects: bool | None, allow_redirects: bool | None
    ) -> bool | httpx._client.UseClientDefault:
        """
        Resolve the redirect setting from the new and the deprecated argument.

        Returns `follow_redirects` when given, else `allow_redirects` (with a
        `DeprecationWarning`), else the "use client default" sentinel.

        Raises `RuntimeError` when both arguments are given.
        """
        # This check must come first: once either argument has been handled
        # the conflicting combination can no longer be told apart.
        if allow_redirects is not None and follow_redirects is not None:
            raise RuntimeError(
                "Cannot use both `allow_redirects` and `follow_redirects`."
            )
        if allow_redirects is not None:
            message = (
                "The `allow_redirects` argument is deprecated. "
                "Use `follow_redirects` instead."
            )
            warnings.warn(message, DeprecationWarning)
            return allow_redirects
        if follow_redirects is not None:
            return follow_redirects
        return httpx._client.USE_CLIENT_DEFAULT

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """
        Send a request with an arbitrary `method`.

        Like `httpx.Client.request`, but also accepts the deprecated
        `allow_redirects` alias of `follow_redirects`.
        """
        url = self._merge_url(url)
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a `GET` request; accepts the deprecated `allow_redirects`."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send an `OPTIONS` request; accepts the deprecated `allow_redirects`."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a `HEAD` request; accepts the deprecated `allow_redirects`."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a `POST` request; accepts the deprecated `allow_redirects`."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a `PUT` request; accepts the deprecated `allow_redirects`."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a `PATCH` request; accepts the deprecated `allow_redirects`."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a `DELETE` request; accepts the deprecated `allow_redirects`."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: typing.Sequence[str] | None = None,
        **kwargs: typing.Any,
    ) -> WebSocketTestSession:
        """
        Prepare a WebSocket session for `url`.

        `url` is resolved against `ws://testserver`.  Handshake headers are
        added unless already present in `kwargs["headers"]`; the mapping passed
        in is copied, never modified.  The returned session is not connected
        until it is entered as a context manager.
        """
        url = urljoin("ws://testserver", url)
        # Copy so the defaults below do not leak into the caller's mapping.
        headers = dict(kwargs.get("headers") or {})
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            # The transport delivers the session by raising `_Upgrade`.
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> TestClient:
        """
        Start a blocking portal and run the application's lifespan startup.

        Cleanup steps are registered on an `ExitStack` that is only kept (via
        `pop_all`) when startup succeeds; on failure they run immediately.
        """
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(
                anyio.from_thread.start_blocking_portal(**self.async_backend)
            )

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None]
            receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None]
            send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
            receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
            send1, receive1 = anyio.create_memory_object_stream(math.inf)
            send2, receive2 = anyio.create_memory_object_stream(math.inf)
            # `stream_send` carries what the application sends (and a final
            # `None`); `stream_receive` carries what the application receives.
            self.stream_send = StapledObjectStream(send1, receive1)
            self.stream_receive = StapledObjectStream(send2, receive2)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: typing.Any) -> None:
        """Run the lifespan shutdown and release the portal."""
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """
        Run the application with a `lifespan` scope.

        Whatever way the application finishes, a `None` marker is sent
        afterwards so a waiting `receive()` wakes up.
        """
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        """
        Send `lifespan.startup` and wait for the application's verdict.

        If the lifespan task has already finished, its result is fetched so
        that an exception it raised propagates from here.
        """
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                # The lifespan task is over: surface its exception, if any.
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            await receive()

    async def wait_shutdown(self) -> None:
        """
        Send `lifespan.shutdown` and wait for the application's verdict.

        If the lifespan task has already finished, its result is fetched so
        that an exception it raised propagates from here.
        """

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                # The lifespan task is over: surface its exception, if any.
                self.task.result()
            return message

        async with self.stream_send:
            await self.stream_receive.send({"type": "lifespan.shutdown"})
            message = await receive()
            assert message["type"] in (
                "lifespan.shutdown.complete",
                "lifespan.shutdown.failed",
            )
            if message["type"] == "lifespan.shutdown.failed":
                await receive()