1from __future__ import annotations
2
3import contextlib
4import inspect
5import io
6import json
7import math
8import sys
9import warnings
10from collections.abc import Awaitable, Generator, Iterable, Mapping, MutableMapping, Sequence
11from concurrent.futures import Future
12from contextlib import AbstractContextManager
13from types import GeneratorType
14from typing import (
15 Any,
16 Callable,
17 Literal,
18 TypedDict,
19 Union,
20 cast,
21)
22from urllib.parse import unquote, urljoin
23
24import anyio
25import anyio.abc
26import anyio.from_thread
27from anyio.streams.stapled import StapledObjectStream
28
29from starlette._utils import is_async_callable
30from starlette.types import ASGIApp, Message, Receive, Scope, Send
31from starlette.websockets import WebSocketDisconnect
32
33if sys.version_info >= (3, 10): # pragma: no cover
34 from typing import TypeGuard
35else: # pragma: no cover
36 from typing_extensions import TypeGuard
37
38try:
39 import httpx
40except ModuleNotFoundError: # pragma: no cover
41 raise RuntimeError(
42 "The starlette.testclient module requires the httpx package to be installed.\n"
43 "You can install this with:\n"
44 " $ pip install httpx\n"
45 )
# Zero-argument callable returning a context manager that yields a BlockingPortal;
# TestClient._portal_factory is the implementation passed around in this module.
_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]

# ASGI2 ("double-callable") apps are called with the scope and return an instance
# that is then awaited with (receive, send); ASGI3 apps take all three at once.
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
ASGI2App = Callable[[Scope], ASGIInstance]
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]


# Form-field mapping accepted by the ``data=`` argument of TestClient's request methods.
_RequestData = Mapping[str, Union[str, Iterable[str], bytes]]
54
55
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
    """
    Return True if *app* should be called with the ASGI3 single-callable signature.

    A class counts as ASGI3 when its instances are awaitable (the class defines
    ``__await__``); any other object counts as ASGI3 when ``is_async_callable``
    reports it as asynchronous.
    """
    return hasattr(app, "__await__") if inspect.isclass(app) else is_async_callable(app)
60
61
class _WrapASGI2:
    """
    Adapt an ASGI2 ("double-callable") application to the ASGI3 calling convention.

    The wrapped app is called with the scope alone; the instance it returns is then
    awaited with ``(receive, send)``.
    """

    def __init__(self, app: ASGI2App) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope)(receive, send)
73
74
class _AsyncBackend(TypedDict):
    """
    Keyword arguments unpacked into ``anyio.from_thread.start_blocking_portal``
    to select the event loop backend that runs the app under test.
    """

    backend: str  # anyio backend name; TestClient accepts "asyncio" or "trio"
    backend_options: dict[str, Any]  # extra options for that backend ({} by default)
78
79
class _Upgrade(Exception):
    """
    Control-flow exception that carries a websocket session out of the transport.

    The transport raises it for ``ws``/``wss`` requests and
    ``TestClient.websocket_connect`` catches it to return ``session``.
    """

    def __init__(self, session: WebSocketTestSession) -> None:
        super().__init__(session)
        self.session = session
83
84
class WebSocketDenialResponse(  # type: ignore[misc]
    httpx.Response,
    WebSocketDisconnect,
):
    """
    A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
    `WebSocket` is closed before being accepted with a `send_denial_response()`.

    Because it is also an `httpx.Response` (built from the denial's status code,
    headers and concatenated body), a test can either catch it as a disconnect or
    inspect it like an ordinary HTTP response.
    """
93
94
class WebSocketTestSession:
    """
    Synchronous handle on a single websocket connection to the app under test.

    Instances are built by the test transport for ``ws``/``wss`` requests and are
    returned, not yet connected, by ``TestClient.websocket_connect``. Entering the
    session (``with``) starts the app in a blocking portal and performs the
    connect handshake. Leaving it sends ``websocket.disconnect`` with code 1000,
    cancels the app task and re-raises any exception the app raised.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        # Set by __enter__ from the first message the app sends (normally websocket.accept).
        self.accepted_subprotocol = None
        self.portal_factory = portal_factory
        self.extra_headers = None

    def __enter__(self) -> WebSocketTestSession:
        # ExitStack callbacks run in reverse registration order. On teardown the
        # session is therefore closed first, then the app task is cancelled, then
        # the task result is collected (re-raising any app exception), then the
        # portal is exited.
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(self.portal_factory())
            # ``cs`` is the CancelScope that ``_run`` hands back via ``task_status.started``.
            fut, cs = portal.start_task(self._run)
            stack.callback(fut.result)
            stack.callback(portal.call, cs.cancel)
            self.send({"type": "websocket.connect"})
            message = self.receive()
            # Raises if the app closed or denied the connection instead of accepting.
            # The ``with`` block then unwinds and runs the cleanups registered so far.
            self._raise_on_close(message)
            self.accepted_subprotocol = message.get("subprotocol", None)
            self.extra_headers = message.get("headers", None)
            stack.callback(self.close, 1000)
            # Transfer the registered cleanups to __exit__; nothing runs here.
            self.exit_stack = stack.pop_all()
            return self

    def __exit__(self, *args: Any) -> bool | None:
        return self.exit_stack.__exit__(*args)

    async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
        """
        Run the app inside the portal's event loop for the lifetime of the session.

        Two unbounded in-memory streams connect the synchronous session to the app.
        Messages pushed to ``_receive_tx`` by ``send()`` are what the app receives.
        Messages the app sends are read from ``_send_rx`` by ``receive()``. The
        CancelScope reported through ``task_status`` lets ``__enter__`` stop this
        task on teardown.
        """
        send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        send_tx, send_rx = send
        receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        receive_tx, receive_rx = receive
        with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
            self._receive_tx = receive_tx
            self._send_rx = send_rx
            task_status.started(cs)
            await self.app(self.scope, receive_rx.receive, send_tx.send)

            # wait for cs.cancel to be called before closing streams
            await anyio.sleep_forever()

    def _raise_on_close(self, message: Message) -> None:
        """
        Raise if *message* shows that the app ended or refused the connection.

        Raises:
            WebSocketDisconnect: for a ``websocket.close`` message (code defaults to
                1000, reason to "").
            WebSocketDenialResponse: for a ``websocket.http.response.start`` message.
                The following ``websocket.http.response.body`` messages are read
                until ``more_body`` is false, and their bodies are concatenated
                into the response content.
        """
        if message["type"] == "websocket.close":
            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
        elif message["type"] == "websocket.http.response.start":
            status_code: int = message["status"]
            headers: list[tuple[bytes, bytes]] = message["headers"]
            body: list[bytes] = []
            while True:
                message = self.receive()
                assert message["type"] == "websocket.http.response.body"
                body.append(message["body"])
                if not message.get("more_body", False):
                    break
            raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))

    def send(self, message: Message) -> None:
        """Deliver a raw ASGI *message* to the app's ``receive`` channel via the portal."""
        self.portal.call(self._receive_tx.send, message)

    def send_text(self, data: str) -> None:
        """Send *data* to the app as a text ``websocket.receive`` message."""
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        """Send *data* to the app as a binary ``websocket.receive`` message."""
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
        """
        Serialize *data* as compact JSON (no spaces, non-ASCII kept) and send it.

        With ``mode="text"`` the JSON goes in the ``text`` field. Any other mode
        sends it UTF-8 encoded in the ``bytes`` field.
        """
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            self.send({"type": "websocket.receive", "text": text})
        else:
            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Send ``websocket.disconnect`` with *code* and *reason* to the app."""
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """Block until the app sends its next message and return it unchanged."""
        return self.portal.call(self._send_rx.receive)

    def receive_text(self) -> str:
        """Return the ``text`` field of the next message, raising if the app closed/denied."""
        message = self.receive()
        self._raise_on_close(message)
        return cast(str, message["text"])

    def receive_bytes(self) -> bytes:
        """Return the ``bytes`` field of the next message, raising if the app closed/denied."""
        message = self.receive()
        self._raise_on_close(message)
        return cast(bytes, message["bytes"])

    def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
        """
        Decode the next message as JSON, raising if the app closed/denied.

        With ``mode="text"`` the ``text`` field is parsed. Any other mode parses
        the ``bytes`` field decoded as UTF-8.
        """
        message = self.receive()
        self._raise_on_close(message)
        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)
198
199
class _TestClientTransport(httpx.BaseTransport):
    """
    httpx transport that hands each request to an ASGI app instead of the network.

    HTTP requests are run to completion inside a blocking portal and converted into
    an ``httpx.Response``. ``ws``/``wss`` requests are not executed here. Instead, a
    ``WebSocketTestSession`` is built and raised inside ``_Upgrade`` for
    ``TestClient.websocket_connect`` to catch.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        client: tuple[str, int],
        app_state: dict[str, Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        self.app_state = app_state
        self.client = client

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """
        Translate *request* into an ASGI scope, run the app, and build the response.

        Raises:
            _Upgrade: for ``ws``/``wss`` URLs, carrying the websocket session.
            BaseException: whatever the app raised, when ``raise_server_exceptions``
                is true.
            AssertionError: when ``raise_server_exceptions`` is true and the app never
                sent ``http.response.start``.
        """
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]

        if ":" in netloc:
            host, port_string = netloc.split(":", 1)
            port = int(port_string)
        else:
            host = netloc
            port = default_port

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        elif port == default_port:  # pragma: no cover
            headers = [(b"host", host.encode())]
        else:  # pragma: no cover
            headers = [(b"host", (f"{host}:{port}").encode())]

        # Include other request headers. Use httpx's raw byte pairs rather than
        # decoding via ``multi_items()`` and re-encoding with ``str.encode()`` (UTF-8):
        # that round trip does not necessarily reproduce the original bytes of
        # non-ASCII header values.
        headers += [(key.lower(), value) for key, value in request.headers.raw]

        scope: dict[str, Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path.split(b"?", 1)[0],
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": self.client,
                "server": [host, port],
                "subprotocols": subprotocols,
                # Each connection gets a shallow copy of the lifespan state.
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path.split(b"?", 1)[0],
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            # Each request gets a shallow copy of the lifespan state.
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        # Created inside the portal below, before the app can call receive/send.
        response_complete: anyio.Event
        raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            """
            Return the next ``http.request`` message for the app.

            After the body has been delivered, wait until the response is complete
            and then report ``http.disconnect``.
            """
            nonlocal request_complete

            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                # NOTE(review): presumably unreachable since request.read() returns
                # bytes; kept as a defensive fallback.
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            """Accumulate the app's response messages into ``raw_kwargs``."""
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert not response_started, 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                # Keep the raw byte pairs: httpx accepts bytes header names/values,
                # whereas ``.decode()`` (UTF-8) here raised on latin-1 encoded values.
                raw_kwargs["headers"] = [(key, value) for key, value in message.get("headers", [])]
                response_started = True
            elif message["type"] == "http.response.body":
                assert response_started, 'Received "http.response.body" without "http.response.start".'
                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                # HEAD responses have their body discarded.
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            # The app failed (or returned) without responding: synthesize a bare 500.
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response
367
368
class TestClient(httpx.Client):
    """
    A synchronous ``httpx.Client`` that dispatches requests directly to an ASGI app.

    Requests never reach the network: they are routed through
    ``_TestClientTransport``, which runs the app inside an anyio blocking portal.
    Used as a context manager (``with TestClient(app) as client:``), the client
    keeps one portal open for the whole block and drives the app's lifespan
    startup handshake on enter and shutdown handshake on exit.
    """

    # NOTE: presumably set so pytest does not try to collect this "Test*" class.
    __test__ = False
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
        client: tuple[str, int] = ("testclient", 50000),
    ) -> None:
        """
        Args:
            app: The ASGI application under test. ASGI2 apps are wrapped so they can
                be called with the ASGI3 signature.
            base_url: Base URL that relative request URLs are resolved against.
            raise_server_exceptions: When true, exceptions raised by the app
                propagate to the caller. When false, they are suppressed, and a bare
                500 response is returned if the app had not started a response.
            root_path: Value of the ``root_path`` key in every scope.
            backend: anyio backend used to run the app.
            backend_options: Extra options forwarded to the anyio backend.
            cookies: Initial cookies for the client.
            headers: Default headers for every request. The mapping is copied, so
                the caller's dict is never modified.
            follow_redirects: Whether redirects are followed by default.
            client: ``(host, port)`` reported as the scope's ``client``.
        """
        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        # Populated by the lifespan scope and copied into every request/websocket scope.
        self.app_state: dict[str, Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
            client=client,
        )
        # Work on a copy so adding the default user-agent does not mutate the caller's dict.
        headers = {} if headers is None else dict(headers)
        headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
        """
        Yield a blocking portal for running the app.

        Reuses the portal opened by ``__enter__`` when there is one. Otherwise a
        fresh portal is started and torn down around the ``with`` body.
        """
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
                yield portal

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """
        Send a request to the app and return the response.

        Passing an explicit ``timeout`` emits a DeprecationWarning; the value is
        still forwarded to httpx.
        """
        if timeout is not httpx.USE_CLIENT_DEFAULT:
            warnings.warn(
                "You should not use the 'timeout' argument with the TestClient. "
                "See https://github.com/encode/starlette/issues/1108 for more information.",
                DeprecationWarning,
            )
        url = self._merge_url(url)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a GET request to the app."""
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send an OPTIONS request to the app."""
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a HEAD request to the app (the response body is discarded)."""
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a POST request to the app."""
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a PUT request to the app."""
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a PATCH request to the app."""
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        """Send a DELETE request to the app."""
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> WebSocketTestSession:
        """
        Create a websocket session for *url* (resolved against ``ws://testserver``).

        The returned session is not connected yet. Enter it
        (``with client.websocket_connect("/ws") as ws:``) to perform the handshake.

        Args:
            url: Websocket URL or path.
            subprotocols: Offered subprotocols, sent in ``sec-websocket-protocol``
                unless that header is already supplied.
            **kwargs: Forwarded to ``httpx.Client.request``. A ``headers`` argument
                is copied, never modified in place.
        """
        url = urljoin("ws://testserver", url)
        # Copy the caller's headers (dict, list of pairs, httpx.Headers or None) so the
        # handshake defaults below never mutate an object the caller still owns.
        headers = httpx.Headers(kwargs.get("headers"))
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> TestClient:
        """
        Open a shared portal, start the lifespan task and wait for app startup.

        The cleanups registered here run from ``__exit__`` in reverse order: wait
        for lifespan shutdown, close the four stream ends, clear ``self.portal``,
        and finally exit the portal.
        """
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            # ``stream_send`` carries app -> client lifespan messages (``None`` marks
            # that the lifespan task has finished); ``stream_receive`` carries
            # client -> app messages.
            send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
                anyio.create_memory_object_stream(math.inf)
            )
            receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
                math.inf
            )
            for channel in (*send, *receive):
                stack.callback(channel.close)
            self.stream_send = StapledObjectStream(*send)
            self.stream_receive = StapledObjectStream(*receive)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            # Transfer the cleanups to __exit__; nothing is torn down on success.
            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: Any) -> None:
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """
        Run the app with a ``lifespan`` scope sharing ``self.app_state``.

        ``None`` is always pushed onto ``stream_send`` when the app returns or
        raises, so pending ``receive()`` helpers stop waiting and can surface the
        task's outcome via ``self.task.result()``.
        """
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        """
        Send ``lifespan.startup`` and wait for the app's reply.

        After ``lifespan.startup.failed``, wait for the next event from the
        lifespan task. If that event is ``None`` (the task finished), the task's
        result is collected, which re-raises any exception it raised.
        """
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> Any:
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            await receive()

    async def wait_shutdown(self) -> None:
        """
        Send ``lifespan.shutdown`` and wait for the app's reply.

        After ``lifespan.shutdown.failed``, wait for the next event from the
        lifespan task. If that event is ``None`` (the task finished), the task's
        result is collected, which re-raises any exception it raised.
        """

        async def receive() -> Any:
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        await self.stream_receive.send({"type": "lifespan.shutdown"})
        message = await receive()
        assert message["type"] in (
            "lifespan.shutdown.complete",
            "lifespan.shutdown.failed",
        )
        if message["type"] == "lifespan.shutdown.failed":
            await receive()