1from __future__ import annotations
2
3import contextlib
4import inspect
5import io
6import json
7import math
8import queue
9import sys
10import typing
11import warnings
12from concurrent.futures import Future
13from functools import cached_property
14from types import GeneratorType
15from urllib.parse import unquote, urljoin
16
17import anyio
18import anyio.abc
19import anyio.from_thread
20from anyio.abc import ObjectReceiveStream, ObjectSendStream
21from anyio.streams.stapled import StapledObjectStream
22
23from starlette._utils import is_async_callable
24from starlette.types import ASGIApp, Message, Receive, Scope, Send
25from starlette.websockets import WebSocketDisconnect
26
27if sys.version_info >= (3, 10): # pragma: no cover
28 from typing import TypeGuard
29else: # pragma: no cover
30 from typing_extensions import TypeGuard
31
32try:
33 import httpx
34except ModuleNotFoundError: # pragma: no cover
35 raise RuntimeError(
36 "The starlette.testclient module requires the httpx package to be installed.\n"
37 "You can install this with:\n"
38 " $ pip install httpx\n"
39 )
# Zero-argument callable returning a context manager that yields a BlockingPortal.
# The transport and WebSocket session use it to run async app code from sync code.
_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]

# ASGI2 "double callable" shape: `app(scope)` returns an instance that is then
# awaited as `instance(receive, send)` (see `_WrapASGI2`).
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]
# ASGI3 "single callable" shape: `await app(scope, receive, send)`.
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]


# Type accepted by the `data=` argument of the TestClient request helpers.
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]]
48
49
50def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
51 if inspect.isclass(app):
52 return hasattr(app, "__await__")
53 return is_async_callable(app)
54
55
56class _WrapASGI2:
57 """
58 Provide an ASGI3 interface onto an ASGI2 app.
59 """
60
61 def __init__(self, app: ASGI2App) -> None:
62 self.app = app
63
64 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
65 instance = self.app(scope)
66 await instance(receive, send)
67
68
69class _AsyncBackend(typing.TypedDict):
70 backend: str
71 backend_options: dict[str, typing.Any]
72
73
74class _Upgrade(Exception):
75 def __init__(self, session: WebSocketTestSession) -> None:
76 self.session = session
77
78
class WebSocketDenialResponse(httpx.Response, WebSocketDisconnect):  # type: ignore[misc]
    """
    Raised by the ``TestClient`` when the app refuses a WebSocket handshake by
    sending an HTTP response (``send_denial_response()``) instead of accepting.

    It is both an ``httpx.Response``, so the status, headers and body can be
    inspected, and a ``WebSocketDisconnect``, so code that already handles
    disconnects keeps working.
    """
87
88
class WebSocketTestSession:
    """
    Synchronous client side of a WebSocket connection to an ASGI app.

    Used as a context manager. Entering starts the app as a task on a blocking
    portal and performs the connect handshake. Leaving sends a disconnect and
    stops the task. Messages travel through two thread-safe queues:
    ``_receive_queue`` carries client -> app messages and ``_send_queue``
    carries app -> client messages, plus any exception the app raised.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        self.accepted_subprotocol = None
        self.portal_factory = portal_factory
        # Client -> app: filled by `send()`, drained by `_asgi_receive()`.
        self._receive_queue: queue.Queue[Message] = queue.Queue()
        # App -> client: filled by `_asgi_send()` and by `_run()` with any
        # exception the app raises; drained by `receive()`.
        self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
        self.extra_headers = None

    def __enter__(self) -> WebSocketTestSession:
        """
        Start the app and complete the connect handshake.

        Returns:
            This session, with ``accepted_subprotocol`` and ``extra_headers``
            taken from the app's first message.

        Raises:
            WebSocketDisconnect: if the app closes instead of accepting.
            WebSocketDenialResponse: if the app answers with an HTTP response.
            BaseException: whatever the app itself raised.
        """
        self.exit_stack = contextlib.ExitStack()
        self.portal = self.exit_stack.enter_context(self.portal_factory())

        try:
            _: Future[None] = self.portal.start_task_soon(self._run)
            self.send({"type": "websocket.connect"})
            message = self.receive()
            self._raise_on_close(message)
        except Exception:
            # Release the portal before propagating the handshake failure.
            self.exit_stack.close()
            raise
        self.accepted_subprotocol = message.get("subprotocol", None)
        self.extra_headers = message.get("headers", None)
        return self

    @cached_property
    def should_close(self) -> anyio.Event:
        """Event that tells `_run` to cancel the app; created lazily on first access."""
        return anyio.Event()

    async def _notify_close(self) -> None:
        # Runs on the portal's loop so the event is set from its own loop.
        self.should_close.set()

    def __exit__(self, *args: typing.Any) -> None:
        """
        Send a disconnect with code 1000, stop the app task and release the portal.

        Afterwards, any exception the app left in `_send_queue` is re-raised.
        """
        try:
            self.close(1000)
        finally:
            self.portal.start_task_soon(self._notify_close)
            self.exit_stack.close()
        while not self._send_queue.empty():
            message = self._send_queue.get()
            if isinstance(message, BaseException):
                raise message

    async def _run(self) -> None:
        """
        Run the app on the portal's event loop until it returns or `should_close` is set.

        An exception raised by the app (other than cancellation) is put on
        `_send_queue`, so the caller's `receive()` re-raises it, and is then re-raised here.
        """

        async def run_app(tg: anyio.abc.TaskGroup) -> None:
            try:
                await self.app(self.scope, self._asgi_receive, self._asgi_send)
            except anyio.get_cancelled_exc_class():
                ...
            except BaseException as exc:
                self._send_queue.put(exc)
                raise
            finally:
                # Once the app finishes, stop waiting on `should_close` too.
                tg.cancel_scope.cancel()

        async with anyio.create_task_group() as tg:
            tg.start_soon(run_app, tg)
            await self.should_close.wait()
            tg.cancel_scope.cancel()

    async def _asgi_receive(self) -> Message:
        """ASGI ``receive`` callable: wait for and return the next client message."""
        while True:
            # Publish a fresh event *before* checking the queue. `send()` (on the
            # caller's thread) first puts a message and then sets whichever event
            # it reads from `_queue_event`. With this ordering, either the check
            # below already sees the message, or `send()` reads this very event and
            # wakes us. Checking first and creating the event afterwards can lose
            # the wakeup and block the app forever.
            self._queue_event = anyio.Event()
            if not self._receive_queue.empty():
                return self._receive_queue.get()
            await self._queue_event.wait()

    async def _asgi_send(self, message: Message) -> None:
        """ASGI ``send`` callable: hand an app message to the client side."""
        self._send_queue.put(message)

    def _raise_on_close(self, message: Message) -> None:
        """
        Raise if `message` ends the connection; otherwise do nothing.

        Raises:
            WebSocketDisconnect: for a ``websocket.close`` message.
            WebSocketDenialResponse: for ``websocket.http.response.start``, after
                reading the remaining ``websocket.http.response.body`` messages.
        """
        if message["type"] == "websocket.close":
            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
        elif message["type"] == "websocket.http.response.start":
            status_code: int = message["status"]
            headers: list[tuple[bytes, bytes]] = message["headers"]
            body: list[bytes] = []
            while True:
                message = self.receive()
                assert message["type"] == "websocket.http.response.body"
                body.append(message["body"])
                if not message.get("more_body", False):
                    break
            raise WebSocketDenialResponse(
                status_code=status_code,
                headers=headers,
                content=b"".join(body),
            )

    def send(self, message: Message) -> None:
        """Queue `message` for the app and wake `_asgi_receive` if it is waiting."""
        self._receive_queue.put(message)
        # `_queue_event` only exists once `_asgi_receive` has run at least once.
        if hasattr(self, "_queue_event"):
            self.portal.start_task_soon(self._queue_event.set)

    def send_text(self, data: str) -> None:
        """Send `data` as a text frame."""
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        """Send `data` as a binary frame."""
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
        """Serialize `data` to compact JSON and send it as a text or UTF-8 binary frame."""
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            self.send({"type": "websocket.receive", "text": text})
        else:
            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Send a ``websocket.disconnect`` message to the app."""
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """Block until the app sends a message; re-raise an exception the app raised."""
        message = self._send_queue.get()
        if isinstance(message, BaseException):
            raise message
        return message

    def receive_text(self) -> str:
        """Receive the next message and return its ``text`` payload."""
        message = self.receive()
        self._raise_on_close(message)
        return typing.cast(str, message["text"])

    def receive_bytes(self) -> bytes:
        """Receive the next message and return its ``bytes`` payload."""
        message = self.receive()
        self._raise_on_close(message)
        return typing.cast(bytes, message["bytes"])

    def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
        """Receive the next message and parse its text (or UTF-8 bytes) payload as JSON."""
        message = self.receive()
        self._raise_on_close(message)
        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)
232
233
class _TestClientTransport(httpx.BaseTransport):
    """
    httpx transport that dispatches requests directly to an in-process ASGI app.

    HTTP requests are run through the app and turned into an ``httpx.Response``.
    For ``ws``/``wss`` URLs no request is executed here. Instead a
    ``WebSocketTestSession`` is built and handed back by raising ``_Upgrade``.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        app_state: dict[str, typing.Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        # Each request scope receives a shallow copy of this dict as "state".
        self.app_state = app_state

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """
        Translate `request` into an ASGI scope, run the app and build the response.

        Raises:
            _Upgrade: for ``ws``/``wss`` URLs, carrying the prepared session.
            BaseException: anything the app raised, if ``raise_server_exceptions``.
            AssertionError: if ``raise_server_exceptions`` and the app never sent
                ``http.response.start``.
        """
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]

        if ":" in netloc:
            host, port_string = netloc.split(":", 1)
            port = int(port_string)
        else:
            host = netloc
            port = default_port

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        elif port == default_port:  # pragma: no cover
            headers = [(b"host", host.encode())]
        else:  # pragma: no cover
            headers = [(b"host", (f"{host}:{port}").encode())]

        # Include other request headers. Use the raw bytes httpx holds instead of
        # decoding and re-encoding them, which would turn latin-1 values into UTF-8.
        headers += [(key.lower(), value) for key, value in request.headers.raw]

        scope: dict[str, typing.Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: typing.Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path,
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": ["testclient", 50000],
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path,
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": ["testclient", 50000],
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        response_complete: anyio.Event
        raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            # First call(s) deliver the request body; once it is complete, block
            # until the response is finished and then report a disconnect.
            nonlocal request_complete

            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            # Collect status, headers and body from the app's response messages.
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert not response_started, 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                # Hand the raw byte pairs to httpx. Decoding them here as UTF-8
                # fails on latin-1 values, and httpx re-encodes str values as
                # ASCII, so any non-ASCII header used to crash the client.
                raw_kwargs["headers"] = [(key, value) for key, value in message.get("headers", [])]
                response_started = True
            elif message["type"] == "http.response.body":
                assert response_started, 'Received "http.response.body" without "http.response.start".'
                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                # The event must be created on the portal's loop, not this thread.
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response
399
400
class TestClient(httpx.Client):
    """
    An ``httpx.Client`` that sends requests straight to an ASGI app in-process.

    Used as a context manager, it also runs the app's lifespan protocol
    (startup on enter, shutdown on exit) and keeps one blocking portal alive for
    all requests made inside the ``with`` block. Outside a ``with`` block, each
    request opens its own short-lived portal.
    """

    __test__ = False  # keep pytest from collecting this class as a test case
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: typing.Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, typing.Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
    ) -> None:
        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = typing.cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        # Passed as-is as the lifespan scope's "state"; each request scope gets a copy.
        self.app_state: dict[str, typing.Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
        )
        # Copy into httpx.Headers so the caller's dict is never mutated and the
        # user-agent default is applied case-insensitively (no duplicate header
        # when the caller already passed "User-Agent").
        client_headers = httpx.Headers(headers)
        client_headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=client_headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
        """Yield the portal opened by ``__enter__``, or a temporary one otherwise."""
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
                yield portal

    def _choose_redirect_arg(
        self, follow_redirects: bool | None, allow_redirects: bool | None
    ) -> bool | httpx._client.UseClientDefault:
        """
        Resolve the redirect flag from `follow_redirects` and the deprecated `allow_redirects`.

        Returns:
            Whichever of the two was given (``allow_redirects`` also emits a
            DeprecationWarning), otherwise httpx's ``USE_CLIENT_DEFAULT``.

        Raises:
            RuntimeError: if both arguments are given.
        """
        # Checked first: previously this lived in an `elif` after
        # `if follow_redirects is not None`, which made it unreachable.
        if allow_redirects is not None and follow_redirects is not None:
            raise RuntimeError("Cannot use both `allow_redirects` and `follow_redirects`.")
        if allow_redirects is not None:
            message = "The `allow_redirects` argument is deprecated. Use `follow_redirects` instead."
            warnings.warn(message, DeprecationWarning)
            return allow_redirects
        if follow_redirects is not None:
            return follow_redirects
        return httpx._client.USE_CLIENT_DEFAULT

    # The request helpers below mirror httpx.Client's, adding the deprecated
    # `allow_redirects` argument, which is resolved via `_choose_redirect_arg`.

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        url = self._merge_url(url)
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: typing.Sequence[str] | None = None,
        **kwargs: typing.Any,
    ) -> WebSocketTestSession:
        """
        Open a WebSocket session to `url`; use the result as a context manager.

        `url` is resolved against ``ws://testserver``. Extra keyword arguments
        are forwarded to ``httpx.Client.request``.

        Raises:
            RuntimeError: if the transport did not produce a WebSocket upgrade.
        """
        url = urljoin("ws://testserver", url)
        # Work on a copy so the caller's headers are never mutated. httpx.Headers
        # also accepts None and makes the defaults below case-insensitive.
        headers = httpx.Headers(kwargs.get("headers"))
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> TestClient:
        """
        Open a long-lived portal, start the lifespan task and wait for startup.

        If anything fails part-way, the ExitStack unwinds what was set up so far.
        """
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None]
            receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None]
            send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
            receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
            send1, receive1 = anyio.create_memory_object_stream(math.inf)
            send2, receive2 = anyio.create_memory_object_stream(math.inf)
            # stream_send: app -> client lifespan messages (None marks the task's end).
            # stream_receive: client -> app lifespan messages.
            self.stream_send = StapledObjectStream(send1, receive1)
            self.stream_receive = StapledObjectStream(send2, receive2)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            # Success: keep the callbacks for __exit__ instead of running them now.
            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: typing.Any) -> None:
        """Run lifespan shutdown and close the portal opened by ``__enter__``."""
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the app with a ``lifespan`` scope, then send a None end-marker."""
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        """Send ``lifespan.startup`` and wait for the app's reply."""
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                # The lifespan task ended; surface its exception, if any.
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            # Wait for the end-marker so the task's error is re-raised.
            await receive()

    async def wait_shutdown(self) -> None:
        """Send ``lifespan.shutdown`` and wait for the app's reply."""

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                # The lifespan task ended; surface its exception, if any.
                self.task.result()
            return message

        async with self.stream_send:
            await self.stream_receive.send({"type": "lifespan.shutdown"})
            message = await receive()
            assert message["type"] in (
                "lifespan.shutdown.complete",
                "lifespan.shutdown.failed",
            )
            if message["type"] == "lifespan.shutdown.failed":
                # Wait for the end-marker so the task's error is re-raised.
                await receive()