Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/testclient.py: 25%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import contextlib
4import inspect
5import io
6import json
7import math
8import sys
9import warnings
10from collections.abc import Awaitable, Callable, Generator, Iterable, Mapping, MutableMapping, Sequence
11from concurrent.futures import Future
12from contextlib import AbstractContextManager
13from types import GeneratorType
14from typing import (
15 Any,
16 Literal,
17 TypedDict,
18 TypeGuard,
19 cast,
20)
21from urllib.parse import unquote, urljoin
23import anyio
24import anyio.abc
25import anyio.from_thread
26from anyio.streams.stapled import StapledObjectStream
28from starlette._utils import is_async_callable
29from starlette.types import ASGIApp, Message, Receive, Scope, Send
30from starlette.websockets import WebSocketDisconnect
32if sys.version_info >= (3, 11): # pragma: no cover
33 from typing import Self
34else: # pragma: no cover
35 from typing_extensions import Self
37try:
38 import httpx
39except ModuleNotFoundError: # pragma: no cover
40 raise RuntimeError(
41 "The starlette.testclient module requires the httpx package to be installed.\n"
42 "You can install this with:\n"
43 " $ pip install httpx\n"
44 )
# Zero-argument factory returning a context manager that yields a BlockingPortal.
# TestClient passes its ``_portal_factory`` here, which re-yields an already-open
# portal when one exists and otherwise starts a fresh one.
_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]

# ASGI 2 apps are two-phase: ``app(scope)`` returns an instance that is then
# awaited as ``instance(receive, send)`` (see ``_WrapASGI2``).
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
ASGI2App = Callable[[Scope], ASGIInstance]
# ASGI 3 apps are a single awaitable callable ``app(scope, receive, send)``.
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]

# Shape of the ``data=`` argument accepted by the TestClient request helpers.
_RequestData = Mapping[str, str | Iterable[str] | bytes]
55def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
56 if inspect.isclass(app):
57 return hasattr(app, "__await__")
58 return is_async_callable(app)
61class _WrapASGI2:
62 """
63 Provide an ASGI3 interface onto an ASGI2 app.
64 """
66 def __init__(self, app: ASGI2App) -> None:
67 self.app = app
69 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
70 instance = self.app(scope)
71 await instance(receive, send)
74class _AsyncBackend(TypedDict):
75 backend: str
76 backend_options: dict[str, Any]
79class _Upgrade(Exception):
80 def __init__(self, session: WebSocketTestSession) -> None:
81 self.session = session
class WebSocketDenialResponse(httpx.Response, WebSocketDisconnect):  # type: ignore[misc]
    """HTTP response an app sent instead of accepting a WebSocket handshake.

    ``TestClient`` raises this when the ``WebSocket`` is closed before being
    accepted via ``send_denial_response()``. It is both an ``httpx.Response``
    (status code, headers and body can be inspected) and a
    ``WebSocketDisconnect`` (it can be caught as one).
    """
class WebSocketTestSession:
    """Client end of a WebSocket connection to an in-process ASGI app.

    The app runs as a task on a blocking portal (see ``_run``). The synchronous
    ``send_*`` / ``receive_*`` helpers exchange ASGI messages with it through a
    pair of in-memory object streams. Entering the context manager performs the
    handshake; leaving it closes the socket and tears the task down.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        self.portal_factory = portal_factory
        # Filled in from the app's accept message during ``__enter__``.
        self.accepted_subprotocol = None
        self.extra_headers = None

    def __enter__(self) -> WebSocketTestSession:
        with contextlib.ExitStack() as cleanup:
            portal = cleanup.enter_context(self.portal_factory())
            self.portal = portal
            future, cancel_scope = portal.start_task(self._run)
            # ExitStack unwinds LIFO: on exit the socket is closed first (once
            # registered below), then the app task is cancelled, then its
            # future is awaited so any exception it raised propagates.
            cleanup.callback(future.result)
            cleanup.callback(portal.call, cancel_scope.cancel)
            self.send({"type": "websocket.connect"})
            accept = self.receive()
            self._raise_on_close(accept)
            self.accepted_subprotocol = accept.get("subprotocol", None)
            self.extra_headers = accept.get("headers", None)
            cleanup.callback(self.close, 1000)
            # Hand the registered callbacks over to __exit__.
            self.exit_stack = cleanup.pop_all()
        return self

    def __exit__(self, *exc_info: Any) -> bool | None:
        return self.exit_stack.__exit__(*exc_info)

    async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
        """Run the app inside the portal's event loop.

        Reports the task's cancel scope through *task_status* once the message
        streams are wired up. After the app returns, the streams stay open
        until that scope is cancelled from ``__exit__``.
        """
        from_app: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        from_app_tx, from_app_rx = from_app
        to_app: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        to_app_tx, to_app_rx = to_app
        with from_app_tx, from_app_rx, to_app_tx, to_app_rx, anyio.CancelScope() as cancel_scope:
            self._receive_tx = to_app_tx
            self._send_rx = from_app_rx
            task_status.started(cancel_scope)
            await self.app(self.scope, to_app_rx.receive, from_app_tx.send)

            # Keep the streams alive until the cancel scope is cancelled.
            await anyio.sleep_forever()

    def _raise_on_close(self, message: Message) -> None:
        """Turn a close or denial message from the app into an exception.

        Raises:
            WebSocketDisconnect: the app sent ``websocket.close``.
            WebSocketDenialResponse: the app answered the handshake with an HTTP
                response (``websocket.http.response.start`` plus body chunks).
        """
        kind = message["type"]
        if kind == "websocket.close":
            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
        if kind != "websocket.http.response.start":
            return
        status_code: int = message["status"]
        headers: list[tuple[bytes, bytes]] = message["headers"]
        chunks: list[bytes] = []
        more_body = True
        while more_body:
            chunk = self.receive()
            assert chunk["type"] == "websocket.http.response.body"
            chunks.append(chunk["body"])
            more_body = chunk.get("more_body", False)
        raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(chunks))

    def send(self, message: Message) -> None:
        """Deliver one raw ASGI message to the app."""
        self.portal.call(self._receive_tx.send, message)

    def send_text(self, data: str) -> None:
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
        """Send *data* as compact JSON, in a text frame or UTF-8 bytes frame."""
        encoded = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode != "text":
            self.send({"type": "websocket.receive", "bytes": encoded.encode("utf-8")})
            return
        self.send({"type": "websocket.receive", "text": encoded})

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """Return the next raw ASGI message sent by the app."""
        return self.portal.call(self._send_rx.receive)

    def _receive_checked(self) -> Message:
        # Next message from the app, raising if it closed or denied the socket.
        message = self.receive()
        self._raise_on_close(message)
        return message

    def receive_text(self) -> str:
        return cast(str, self._receive_checked()["text"])

    def receive_bytes(self) -> bytes:
        return cast(bytes, self._receive_checked()["bytes"])

    def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
        """Decode the next frame as JSON; *mode* picks the text or bytes field."""
        message = self._receive_checked()
        raw = message["text"] if mode == "text" else message["bytes"].decode("utf-8")
        return json.loads(raw)
class _TestClientTransport(httpx.BaseTransport):
    """httpx transport that executes requests against an ASGI 3 app in-process.

    HTTP requests are run to completion on a blocking portal and converted into
    ``httpx.Response`` objects. ``ws``/``wss`` requests never produce a
    response: they raise ``_Upgrade`` carrying a ``WebSocketTestSession``.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        client: tuple[str, int],
        app_state: dict[str, Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        # Shared with the owning TestClient; each request gets a shallow copy.
        self.app_state = app_state
        self.client = client

    @staticmethod
    def _split_netloc(netloc: str, default_port: int) -> tuple[str, int]:
        """Split an httpx ``netloc`` (``host`` or ``host:port``) into ``(host, port)``.

        Bracketed IPv6 literals such as ``[::1]:8000`` are returned without
        their brackets; splitting them on the first ``:`` would fail. A missing
        port falls back to *default_port*.
        """
        if netloc.startswith("["):
            host, _, tail = netloc[1:].partition("]")
            return host, (int(tail[1:]) if tail.startswith(":") else default_port)
        if ":" in netloc:
            host, port_string = netloc.split(":", 1)
            return host, int(port_string)
        return netloc, default_port

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Run *request* through the app and return the collected response.

        Raises:
            _Upgrade: for ``ws``/``wss`` URLs, carrying a ``WebSocketTestSession``.
            BaseException: whatever the app raised, when ``raise_server_exceptions``
                is true. Otherwise a response that never started becomes an
                empty 500.
        """
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
        host, port = self._split_netloc(netloc, default_port)

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        else:  # pragma: no cover
            # IPv6 literals must be re-bracketed inside a Host header value.
            host_value = f"[{host}]" if ":" in host else host
            if port != default_port:
                host_value = f"{host_value}:{port}"
            headers = [(b"host", host_value.encode())]

        # Include other request headers.
        headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]

        scope: dict[str, Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path.split(b"?", 1)[0],
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": self.client,
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path.split(b"?", 1)[0],
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        response_complete: anyio.Event
        raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            # Feed the request body once, then report a disconnect as soon as
            # the response has been fully sent.
            nonlocal request_complete

            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            # Accumulate the app's response messages into ``raw_kwargs``.
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert not response_started, 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
                response_started = True
            elif message["type"] == "http.response.body":
                assert response_started, 'Received "http.response.body" without "http.response.start".'
                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response
class TestClient(httpx.Client):
    """Synchronous httpx client whose requests go straight into an ASGI app.

    A ``_TestClientTransport`` runs the app on an anyio blocking portal instead
    of opening network connections. Using the client as a context manager keeps
    a single portal open for its lifetime and drives the app's lifespan
    protocol: startup on enter, shutdown on exit.
    """

    __test__ = False  # pytest convention: don't collect this Test* class as tests
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
        client: tuple[str, int] = ("testclient", 50000),
    ) -> None:
        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        # Lifespan state; every request/websocket scope receives a shallow copy.
        self.app_state: dict[str, Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
            client=client,
        )
        # Work on a case-insensitive copy: the caller's dict is never mutated,
        # and an explicit "User-Agent" in any casing is respected.
        client_headers = httpx.Headers(headers)
        client_headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=client_headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
        # Reuse the portal opened by ``__enter__``; otherwise open a short-lived one.
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
                yield portal

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        if timeout is not httpx.USE_CLIENT_DEFAULT:
            warnings.warn(
                "You should not use the 'timeout' argument with the TestClient. "
                "See https://github.com/Kludex/starlette/issues/1108 for more information.",
                DeprecationWarning,
            )
        url = self._merge_url(url)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> WebSocketTestSession:
        """Open a WebSocket session against the app.

        Extra keyword arguments are passed to ``httpx.Client.request``. The
        transport intercepts the handshake request and raises ``_Upgrade``
        carrying the session, which is returned here.
        """
        url = urljoin("ws://testserver", url)
        # Case-insensitive copy: never mutate the caller's headers, and accept
        # ``headers=None`` or a list of pairs as well as a dict.
        headers = httpx.Headers(kwargs.get("headers"))
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> Self:
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
                anyio.create_memory_object_stream(math.inf)
            )
            receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
                math.inf
            )
            for channel in (*send, *receive):
                stack.callback(channel.close)
            self.stream_send = StapledObjectStream(*send)
            self.stream_receive = StapledObjectStream(*receive)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            # Defer all registered cleanup to __exit__.
            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: Any) -> None:
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the app's lifespan protocol; post ``None`` once the app returns."""
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> Any:
            # ``None`` means the lifespan task ended; ``result()`` re-raises its error.
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            await receive()

    async def wait_shutdown(self) -> None:
        async def receive() -> Any:
            # ``None`` means the lifespan task ended; ``result()`` re-raises its error.
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        await self.stream_receive.send({"type": "lifespan.shutdown"})
        message = await receive()
        assert message["type"] in (
            "lifespan.shutdown.complete",
            "lifespan.shutdown.failed",
        )
        if message["type"] == "lifespan.shutdown.failed":
            await receive()