1from __future__ import annotations
2
3import contextlib
4import inspect
5import io
6import json
7import math
8import sys
9import warnings
10from collections.abc import Awaitable, Generator, Iterable, Mapping, MutableMapping, Sequence
11from concurrent.futures import Future
12from contextlib import AbstractContextManager
13from types import GeneratorType
14from typing import (
15 Any,
16 Callable,
17 Literal,
18 TypedDict,
19 Union,
20 cast,
21)
22from urllib.parse import unquote, urljoin
23
24import anyio
25import anyio.abc
26import anyio.from_thread
27from anyio.streams.stapled import StapledObjectStream
28
29from starlette._utils import is_async_callable
30from starlette.types import ASGIApp, Message, Receive, Scope, Send
31from starlette.websockets import WebSocketDisconnect
32
33if sys.version_info >= (3, 10): # pragma: no cover
34 from typing import TypeGuard
35else: # pragma: no cover
36 from typing_extensions import TypeGuard
37
38if sys.version_info >= (3, 11): # pragma: no cover
39 from typing import Self
40else: # pragma: no cover
41 from typing_extensions import Self
42
43try:
44 import httpx
45except ModuleNotFoundError: # pragma: no cover
46 raise RuntimeError(
47 "The starlette.testclient module requires the httpx package to be installed.\n"
48 "You can install this with:\n"
49 " $ pip install httpx\n"
50 )
# Zero-argument callable returning a context manager that yields a blocking portal.
_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]

# ASGI2-style application: called with the scope, it returns an instance that is
# then awaited with (receive, send).
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
ASGI2App = Callable[[Scope], ASGIInstance]
# ASGI3-style application: a single awaitable call taking (scope, receive, send).
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]


# Type of the ``data=`` argument accepted by the TestClient request methods.
_RequestData = Mapping[str, Union[str, Iterable[str], bytes]]
59
60
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
    """
    Tell whether ``app`` uses the single-call ASGI3 interface.

    A class is treated as ASGI3 only when it defines ``__await__``; any other
    object is judged by ``is_async_callable``.
    """
    return hasattr(app, "__await__") if inspect.isclass(app) else is_async_callable(app)
65
66
67class _WrapASGI2:
68 """
69 Provide an ASGI3 interface onto an ASGI2 app.
70 """
71
72 def __init__(self, app: ASGI2App) -> None:
73 self.app = app
74
75 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
76 instance = self.app(scope)
77 await instance(receive, send)
78
79
class _AsyncBackend(TypedDict):
    """Keyword arguments unpacked into ``anyio.from_thread.start_blocking_portal``."""

    # Name of the anyio backend (the TestClient accepts "asyncio" or "trio").
    backend: str
    # Backend-specific options, passed through unchanged.
    backend_options: dict[str, Any]
83
84
85class _Upgrade(Exception):
86 def __init__(self, session: WebSocketTestSession) -> None:
87 self.session = session
88
89
class WebSocketDenialResponse(  # type: ignore[misc]
    httpx.Response,
    WebSocketDisconnect,
):
    """
    A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
    `WebSocket` is closed before being accepted with a `send_denial_response()`.

    It is also an `httpx.Response`, built from the status, headers and joined
    body of the app's ``websocket.http.response.start`` /
    ``websocket.http.response.body`` messages, so tests can inspect it like any
    other response.
    """
98
99
class WebSocketTestSession:
    """
    A synchronous handle on one websocket connection to an ASGI application.

    Use it as a context manager. Entering starts the app on a blocking portal,
    sends ``websocket.connect`` and inspects the app's first message. Exiting runs
    the cleanup registered during entry: a normal close is sent, the app task is
    cancelled, and the task's result is collected, which re-raises any exception
    the app raised.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        """
        :param app: the ASGI3 application to connect to.
        :param scope: the ``websocket`` scope passed to the app.
        :param portal_factory: returns a context manager yielding the blocking
            portal the app is run on.
        """
        self.app = app
        self.scope = scope
        # Filled in from the app's accept message by __enter__.
        self.accepted_subprotocol = None
        self.portal_factory = portal_factory
        # Filled in from the app's accept message by __enter__.
        self.extra_headers = None

    def __enter__(self) -> WebSocketTestSession:
        """
        Start the app and perform the opening handshake.

        :raises WebSocketDisconnect: if the app's first message is ``websocket.close``.
        :raises WebSocketDenialResponse: if the app answers with an HTTP denial response.
        """
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(self.portal_factory())
            fut, cs = portal.start_task(self._run)
            # ExitStack callbacks run in reverse order: the app's cancel scope is
            # cancelled first, then the task future is awaited so any exception
            # raised by the app propagates to the caller.
            stack.callback(fut.result)
            stack.callback(portal.call, cs.cancel)
            self.send({"type": "websocket.connect"})
            message = self.receive()
            self._raise_on_close(message)
            self.accepted_subprotocol = message.get("subprotocol", None)
            self.extra_headers = message.get("headers", None)
            # Registered last so it runs first on exit: a 1000 close is sent to the
            # app before its task is cancelled.
            stack.callback(self.close, 1000)
            # Move the callbacks out so they run in __exit__ rather than here.
            self.exit_stack = stack.pop_all()
            return self

    def __exit__(self, *args: Any) -> bool | None:
        # Delegate to the stack captured in __enter__; its return value decides
        # whether an in-flight exception is suppressed.
        return self.exit_stack.__exit__(*args)

    async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
        """
        Run the application for this session as a task on the blocking portal.

        Creates the two in-memory message streams that connect the calling thread
        to the app, reports its cancel scope via ``task_status.started``, runs the
        app, and then keeps the streams open until that scope is cancelled.
        """
        # Messages the app sends; read by ``receive()``.
        send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        send_tx, send_rx = send
        # Messages delivered to the app; written by ``send()``.
        receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        receive_tx, receive_rx = receive
        with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
            self._receive_tx = receive_tx
            self._send_rx = send_rx
            task_status.started(cs)
            await self.app(self.scope, receive_rx.receive, send_tx.send)

            # wait for cs.cancel to be called before closing streams
            await anyio.sleep_forever()

    def _raise_on_close(self, message: Message) -> None:
        """
        Raise if ``message`` ends the connection; otherwise return normally.

        ``websocket.close`` is raised as ``WebSocketDisconnect`` (code defaults to
        1000, reason to ""). ``websocket.http.response.start`` begins a denial
        response: subsequent ``websocket.http.response.body`` messages are read until
        ``more_body`` is false and the result is raised as ``WebSocketDenialResponse``.
        """
        if message["type"] == "websocket.close":
            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
        elif message["type"] == "websocket.http.response.start":
            status_code: int = message["status"]
            headers: list[tuple[bytes, bytes]] = message["headers"]
            body: list[bytes] = []
            while True:
                message = self.receive()
                assert message["type"] == "websocket.http.response.body"
                body.append(message["body"])
                if not message.get("more_body", False):
                    break
            raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))

    def send(self, message: Message) -> None:
        """Queue a raw ASGI message for the app, via the portal's event loop."""
        self.portal.call(self._receive_tx.send, message)

    def send_text(self, data: str) -> None:
        """Send ``data`` to the app as a text ``websocket.receive`` message."""
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        """Send ``data`` to the app as a binary ``websocket.receive`` message."""
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
        """
        Serialize ``data`` as compact JSON and send it to the app.

        :param mode: ``"text"`` sends a text frame; anything else sends the UTF-8
            encoded JSON as a binary frame.
        """
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            self.send({"type": "websocket.receive", "text": text})
        else:
            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Send a ``websocket.disconnect`` message with ``code`` and ``reason`` to the app."""
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """Block until the app sends a message and return it unchanged."""
        return self.portal.call(self._send_rx.receive)

    def receive_text(self) -> str:
        """Receive the next message and return its ``text`` field, raising if it is a close."""
        message = self.receive()
        self._raise_on_close(message)
        return cast(str, message["text"])

    def receive_bytes(self) -> bytes:
        """Receive the next message and return its ``bytes`` field, raising if it is a close."""
        message = self.receive()
        self._raise_on_close(message)
        return cast(bytes, message["bytes"])

    def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
        """
        Receive the next message and decode it as JSON.

        :param mode: ``"text"`` reads the ``text`` field; anything else decodes the
            ``bytes`` field as UTF-8.
        """
        message = self.receive()
        self._raise_on_close(message)
        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)
203
204
class _TestClientTransport(httpx.BaseTransport):
    """
    An httpx transport that dispatches requests straight into an ASGI application.

    HTTP requests are executed on a blocking portal and the app's response
    messages are collected into an ``httpx.Response``. Requests whose scheme is
    ``ws`` or ``wss`` are not executed here: a ``WebSocketTestSession`` is built
    for them and handed back to the caller by raising ``_Upgrade``.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        client: tuple[str, int],
        app_state: dict[str, Any],
    ) -> None:
        """
        :param app: the ASGI3 application requests are dispatched to.
        :param portal_factory: returns a context manager yielding the blocking
            portal the app is run on.
        :param raise_server_exceptions: re-raise exceptions escaping the app; when
            false, a request that produced no response becomes an empty 500.
        :param root_path: value of the scope's ``root_path``.
        :param client: value of the scope's ``client``.
        :param app_state: copied into each scope's ``state``.
        """
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        self.app_state = app_state
        self.client = client

    @staticmethod
    def _split_netloc(netloc: str, default_port: int) -> tuple[str, int]:
        """
        Split a ``host[:port]`` netloc into ``(host, port)``.

        Only a purely numeric suffix after the last colon counts as the port, so
        bracketed IPv6 literals such as ``[::1]`` and ``[::1]:8000`` are handled
        instead of crashing on the colons inside the address.
        """
        host, sep, port_string = netloc.rpartition(":")
        if sep and port_string.isdigit():
            return host, int(port_string)
        return netloc, default_port

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """
        Run ``request`` against the ASGI app and return the collected response.

        :raises _Upgrade: for ``ws``/``wss`` requests, carrying the new session.
        :raises BaseException: whatever the app raised, when
            ``raise_server_exceptions`` is true.
        """
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
        host, port = self._split_netloc(netloc, default_port)

        # Include the 'host' header, unless the request already carries one.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        elif port == default_port:  # pragma: no cover
            headers = [(b"host", host.encode())]
        else:  # pragma: no cover
            headers = [(b"host", (f"{host}:{port}").encode())]

        # Include other request headers.
        headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]

        scope: dict[str, Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path.split(b"?", 1)[0],
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": self.client,
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path.split(b"?", 1)[0],
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        response_complete: anyio.Event
        raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            nonlocal request_complete

            # After the body has been delivered, wait for the response to finish
            # and then report the client as disconnected.
            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert not response_started, 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
                response_started = True
            elif message["type"] == "http.response.body":
                assert response_started, 'Received "http.response.body" without "http.response.start".'
                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                # HEAD responses keep their headers but drop any body.
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                # Info reported through the "http.response.debug" extension; it is
                # attached to the returned response as .template / .context.
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                # The Event is created via the portal so it belongs to the portal's
                # event loop, where receive()/send() run.
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        # getvalue() returns everything written regardless of the stream position,
        # so a body that was only partially sent (the app failed, or never sent a
        # final ``more_body=False`` message) is kept instead of coming back empty.
        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].getvalue())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response
372
373
class TestClient(httpx.Client):
    """
    A synchronous httpx client that sends requests to an ASGI app in-process.

    Outside a ``with`` block, each request runs on its own short-lived blocking
    portal. Inside ``with TestClient(app) as client:`` a single portal is shared,
    and the app's lifespan startup/shutdown events run on enter/exit.
    """

    __test__ = False  # keep pytest from collecting this class as a test case
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
        client: tuple[str, int] = ("testclient", 50000),
    ) -> None:
        """
        :param app: an ASGI3 app, or an ASGI2 app (wrapped with ``_WrapASGI2``).
        :param base_url: base URL that relative request URLs are resolved against.
        :param raise_server_exceptions: re-raise exceptions escaping the app.
        :param root_path: ASGI ``root_path`` placed in each scope.
        :param backend: anyio backend used to run the app.
        :param backend_options: options for that backend.
        :param cookies: initial client cookies.
        :param headers: default request headers; copied, never mutated.
        :param follow_redirects: default redirect behavior.
        :param client: ASGI ``client`` value placed in each scope.
        """
        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        self.app_state: dict[str, Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
            client=client,
        )
        # Copy so the default user-agent below never leaks into the caller's dict.
        headers = {} if headers is None else dict(headers)
        headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
        """Yield the shared portal inside ``with``, otherwise a fresh one for this call."""
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
                yield portal

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a request; passing ``timeout`` emits a DeprecationWarning."""
        if timeout is not httpx.USE_CLIENT_DEFAULT:
            warnings.warn(
                "You should not use the 'timeout' argument with the TestClient. "
                "See https://github.com/encode/starlette/issues/1108 for more information.",
                DeprecationWarning,
            )
        url = self._merge_url(url)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a GET request."""
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send an OPTIONS request."""
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a HEAD request."""
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a POST request."""
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a PUT request."""
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a PATCH request."""
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a DELETE request."""
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> WebSocketTestSession:
        """
        Create a websocket session for ``url`` (resolved against ``ws://testserver``).

        The returned session performs the handshake when entered, e.g.
        ``with client.websocket_connect("/ws") as ws:``.

        :param subprotocols: offered via ``sec-websocket-protocol`` unless that
            header is already given.
        """
        url = urljoin("ws://testserver", url)
        # Copy into a case-insensitive header collection: the caller's headers are
        # never mutated, any HeaderTypes form (dict, list of pairs, Headers, None) is
        # accepted, and caller-supplied values are not duplicated under another case.
        headers = httpx.Headers(kwargs.get("headers"))
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> Self:
        """
        Start a shared portal and run the app's lifespan startup.

        The cleanup registered here (lifespan shutdown, stream closing, portal
        reset/exit) runs in reverse order from ``__exit__``.
        """
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            # Messages the app sends (plus a final None sentinel from lifespan()).
            send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
                anyio.create_memory_object_stream(math.inf)
            )
            # Messages delivered to the app.
            receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
                math.inf
            )
            for channel in (*send, *receive):
                stack.callback(channel.close)
            self.stream_send = StapledObjectStream(*send)
            self.stream_receive = StapledObjectStream(*receive)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: Any) -> None:
        # Runs the callbacks registered in __enter__; exception info is not forwarded.
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the app with a ``lifespan`` scope; always send a None sentinel when it ends."""
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        """
        Send ``lifespan.startup`` and wait for the app's answer.

        On ``lifespan.startup.failed``, wait for the end-of-lifespan sentinel, at
        which point ``self.task.result()`` re-raises any exception from the app.
        """
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> Any:
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            await receive()

    async def wait_shutdown(self) -> None:
        """
        Send ``lifespan.shutdown`` and wait for the app's answer.

        On ``lifespan.shutdown.failed``, wait for the end-of-lifespan sentinel, at
        which point ``self.task.result()`` re-raises any exception from the app.
        """

        async def receive() -> Any:
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        await self.stream_receive.send({"type": "lifespan.shutdown"})
        message = await receive()
        assert message["type"] in (
            "lifespan.shutdown.complete",
            "lifespan.shutdown.failed",
        )
        if message["type"] == "lifespan.shutdown.failed":
            await receive()