Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/testclient.py: 25%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import contextlib
4import inspect
5import io
6import json
7import math
8import sys
9import warnings
10from collections.abc import Awaitable, Callable, Generator, Iterable, Mapping, MutableMapping, Sequence
11from concurrent.futures import Future
12from contextlib import AbstractContextManager
13from types import GeneratorType
14from typing import (
15 TYPE_CHECKING,
16 Any,
17 Literal,
18 TypedDict,
19 TypeGuard,
20 cast,
21)
22from urllib.parse import unquote, urljoin
24import anyio
25import anyio.abc
26import anyio.from_thread
27from anyio.streams.stapled import StapledObjectStream
29from starlette._utils import is_async_callable
30from starlette.exceptions import StarletteDeprecationWarning
31from starlette.types import ASGIApp, Message, Receive, Scope, Send
32from starlette.websockets import WebSocketDisconnect
34if sys.version_info >= (3, 11): # pragma: no cover
35 from typing import Self
36else: # pragma: no cover
37 from typing_extensions import Self
39if TYPE_CHECKING:
40 import httpx2 as httpx
41else:
42 try:
43 import httpx2 as httpx
44 except ModuleNotFoundError: # pragma: no cover
45 try:
46 import httpx
47 except ModuleNotFoundError:
48 raise RuntimeError(
49 "The starlette.testclient module requires the httpx2 package to be installed.\n"
50 "You can install this with:\n"
51 " $ pip install httpx2\n"
52 )
53 else:
54 warnings.warn(
55 "Using `httpx` with `starlette.testclient` is deprecated; install `httpx2` instead.",
56 StarletteDeprecationWarning,
57 stacklevel=2,
58 )
# Zero-argument callable returning a context manager that yields an anyio
# blocking portal; the test client passes its `_portal_factory` method here.
_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]

# Call signatures of the two ASGI interface versions. An ASGI2 app is called
# with the scope and returns an instance that is then called with
# (receive, send); an ASGI3 app takes (scope, receive, send) in one call.
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
ASGI2App = Callable[[Scope], ASGIInstance]
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]

# Accepted shape of the `data=` argument of the TestClient request helpers.
_RequestData = Mapping[str, str | Iterable[str] | bytes]
69def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
70 if inspect.isclass(app):
71 return hasattr(app, "__await__")
72 return is_async_callable(app)
75class _WrapASGI2:
76 """
77 Provide an ASGI3 interface onto an ASGI2 app.
78 """
80 def __init__(self, app: ASGI2App) -> None:
81 self.app = app
83 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
84 instance = self.app(scope)
85 await instance(receive, send)
class _AsyncBackend(TypedDict):
    """Keyword arguments unpacked into `anyio.from_thread.start_blocking_portal`."""

    backend: str  # anyio backend name; TestClient accepts "asyncio" or "trio"
    backend_options: dict[str, Any]  # extra backend-specific options
class _Upgrade(Exception):
    """
    Internal signal raised by the test transport for `ws://`/`wss://` requests.

    It carries the prepared `WebSocketTestSession` out of the httpx request
    machinery so `TestClient.websocket_connect` can catch it and return the
    session.
    """

    def __init__(self, session: WebSocketTestSession) -> None:
        self.session = session
class WebSocketDenialResponse(  # type: ignore[misc]
    httpx.Response,
    WebSocketDisconnect,
):
    """
    A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
    `WebSocket` is closed before being accepted with a `send_denial_response()`.

    Because it is also an `httpx.Response`, built from the denial's status,
    headers and collected body, tests can inspect the rejection like any other
    response.
    """
class WebSocketTestSession:
    """
    Synchronous handle on one WebSocket connection to an ASGI app.

    Use it as a context manager. Entering starts the app in a background task
    on a blocking portal, sends `websocket.connect`, and records the accepted
    subprotocol and extra headers from the app's reply. If the app closes or
    denies the connection instead, the corresponding exception is raised.

    Exiting sends a `websocket.disconnect` with code 1000, cancels the app
    task and waits for it, re-raising any error it produced. The send and
    receive helpers block the calling thread through the portal.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        self.portal_factory = portal_factory
        # Populated by __enter__ from the app's accept message.
        self.accepted_subprotocol = None
        self.extra_headers = None

    def __enter__(self) -> WebSocketTestSession:
        with contextlib.ExitStack() as cleanup:
            portal = cleanup.enter_context(self.portal_factory())
            self.portal = portal
            future, cancel_scope = portal.start_task(self._run)
            # Exit callbacks run last-in-first-out: disconnect (registered
            # below), then cancel the app task, then wait for its result.
            cleanup.callback(future.result)
            cleanup.callback(portal.call, cancel_scope.cancel)
            self.send({"type": "websocket.connect"})
            reply = self.receive()
            self._raise_on_close(reply)
            self.accepted_subprotocol = reply.get("subprotocol", None)
            self.extra_headers = reply.get("headers", None)
            cleanup.callback(self.close, 1000)
            # Hand the registered callbacks over to __exit__.
            self.exit_stack = cleanup.pop_all()
        return self

    def __exit__(self, *args: Any) -> bool | None:
        return self.exit_stack.__exit__(*args)

    async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
        """
        Run the app on the portal's event loop, wired to two in-memory streams.

        One stream carries messages the app sends (read by `receive`). The
        other carries messages for the app (written by `send`). The cancel
        scope is reported back through `task_status` so `__exit__` can stop
        this task.
        """
        app_to_client: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        client_to_app: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        app_send, client_recv = app_to_client
        client_send, app_recv = client_to_app
        with app_send, client_recv, client_send, app_recv, anyio.CancelScope() as cancel_scope:
            self._receive_tx = client_send
            self._send_rx = client_recv
            task_status.started(cancel_scope)
            await self.app(self.scope, app_recv.receive, app_send.send)

            # Keep the streams open until __exit__ cancels this scope.
            await anyio.sleep_forever()

    def _raise_on_close(self, message: Message) -> None:
        """
        Turn a close or denial message from the app into an exception.

        Raises:
            WebSocketDisconnect: for a `websocket.close` message.
            WebSocketDenialResponse: for a `websocket.http.response.start`
                message, after collecting all following body chunks.
        """
        kind = message["type"]
        if kind == "websocket.close":
            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
        if kind != "websocket.http.response.start":
            return
        status_code: int = message["status"]
        headers: list[tuple[bytes, bytes]] = message["headers"]
        chunks: list[bytes] = []
        more_body = True
        while more_body:
            part = self.receive()
            assert part["type"] == "websocket.http.response.body"
            chunks.append(part["body"])
            more_body = part.get("more_body", False)
        raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(chunks))

    def send(self, message: Message) -> None:
        """Deliver a raw ASGI message to the app."""
        self.portal.call(self._receive_tx.send, message)

    def send_text(self, data: str) -> None:
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
        """Send `data` as compact JSON, in a text frame or a UTF-8 binary frame."""
        payload = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        self.send(
            {"type": "websocket.receive", "text": payload}
            if mode == "text"
            else {"type": "websocket.receive", "bytes": payload.encode("utf-8")}
        )

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Tell the app that the client disconnected."""
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """Block until the app sends its next message and return it."""
        return self.portal.call(self._send_rx.receive)

    def receive_text(self) -> str:
        message = self.receive()
        self._raise_on_close(message)
        return cast(str, message["text"])

    def receive_bytes(self) -> bytes:
        message = self.receive()
        self._raise_on_close(message)
        return cast(bytes, message["bytes"])

    def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
        """Receive a frame and decode its JSON payload (text, or UTF-8 bytes)."""
        message = self.receive()
        self._raise_on_close(message)
        text = message["text"] if mode == "text" else message["bytes"].decode("utf-8")
        return json.loads(text)
class _TestClientTransport(httpx.BaseTransport):
    """
    httpx transport that dispatches requests directly to an ASGI app.

    HTTP requests are run to completion through a blocking portal and
    converted into an `httpx.Response`. `ws://`/`wss://` requests are not run
    here. Instead a `WebSocketTestSession` is prepared and handed back to the
    caller by raising `_Upgrade`.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        client: tuple[str, int],
        app_state: dict[str, Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        self.app_state = app_state
        self.client = client

    @staticmethod
    def _split_netloc(netloc: str, default_port: int) -> tuple[str, int]:
        """
        Split a `host[:port]` network location into `(host, port)`.

        Only the last `:` is treated as the port separator, and only when it is
        followed by digits. That way bracketed IPv6 literals such as `[::1]`
        and `[::1]:8000` parse correctly instead of crashing on `int()`. When
        no explicit port is present, `default_port` is returned.
        """
        host, sep, port_string = netloc.rpartition(":")
        if sep and port_string.isdigit() and (not host.startswith("[") or host.endswith("]")):
            return host, int(port_string)
        return netloc, default_port

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """
        Run `request` against the app and build the resulting response.

        Raises:
            _Upgrade: for `ws`/`wss` URLs, carrying the prepared session.
            BaseException: whatever the app raised, when
                `raise_server_exceptions` is true.
            AssertionError: when `raise_server_exceptions` is true and the app
                finished without sending `http.response.start`.
        """
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
        host, port = self._split_netloc(netloc, default_port)

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        elif port == default_port:  # pragma: no cover
            headers = [(b"host", host.encode())]
        else:  # pragma: no cover
            headers = [(b"host", (f"{host}:{port}").encode())]

        # Include other request headers.
        headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]

        scope: dict[str, Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path.split(b"?", 1)[0],
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": self.client,
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path.split(b"?", 1)[0],
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        response_complete: anyio.Event
        raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
        # Filled in from an "http.response.debug" message, if the app sends one.
        template = None
        context = None

        async def receive() -> Message:
            """Hand the request body to the app, then report a disconnect once the response is complete."""
            nonlocal request_complete

            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            """Collect the app's response status, headers, body and debug info."""
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert not response_started, 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
                response_started = True
            elif message["type"] == "http.response.body":
                assert response_started, 'Received "http.response.body" without "http.response.start".'
                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                # Created through the portal so the event belongs to the loop running the app.
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        # getvalue() returns everything written regardless of the buffer position,
        # so a partial body sent before a swallowed server error is preserved.
        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].getvalue())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response
382class TestClient(httpx.Client):
383 __test__ = False
384 task: Future[None]
385 portal: anyio.abc.BlockingPortal | None = None
387 def __init__(
388 self,
389 app: ASGIApp,
390 base_url: str = "http://testserver",
391 raise_server_exceptions: bool = True,
392 root_path: str = "",
393 backend: Literal["asyncio", "trio"] = "asyncio",
394 backend_options: dict[str, Any] | None = None,
395 cookies: httpx._types.CookieTypes | None = None,
396 headers: dict[str, str] | None = None,
397 follow_redirects: bool = True,
398 client: tuple[str, int] = ("testclient", 50000),
399 ) -> None:
400 self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
401 if _is_asgi3(app):
402 asgi_app = app
403 else:
404 app = cast(ASGI2App, app) # type: ignore[assignment]
405 asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
406 self.app = asgi_app
407 self.app_state: dict[str, Any] = {}
408 transport = _TestClientTransport(
409 self.app,
410 portal_factory=self._portal_factory,
411 raise_server_exceptions=raise_server_exceptions,
412 root_path=root_path,
413 app_state=self.app_state,
414 client=client,
415 )
416 if headers is None:
417 headers = {}
418 headers.setdefault("user-agent", "testclient")
419 super().__init__(
420 base_url=base_url,
421 headers=headers,
422 transport=transport,
423 follow_redirects=follow_redirects,
424 cookies=cookies,
425 )
427 @contextlib.contextmanager
428 def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
429 if self.portal is not None:
430 yield self.portal
431 else:
432 with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
433 yield portal
435 def request( # type: ignore[override]
436 self,
437 method: str,
438 url: httpx._types.URLTypes,
439 *,
440 content: httpx._types.RequestContent | None = None,
441 data: _RequestData | None = None,
442 files: httpx._types.RequestFiles | None = None,
443 json: Any = None,
444 params: httpx._types.QueryParamTypes | None = None,
445 headers: httpx._types.HeaderTypes | None = None,
446 cookies: httpx._types.CookieTypes | None = None,
447 auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
448 follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
449 timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
450 extensions: dict[str, Any] | None = None,
451 ) -> httpx.Response:
452 if timeout is not httpx.USE_CLIENT_DEFAULT:
453 warnings.warn(
454 "You should not use the 'timeout' argument with the TestClient. "
455 "See https://github.com/Kludex/starlette/issues/1108 for more information.",
456 DeprecationWarning,
457 )
458 url = self._merge_url(url)
459 return super().request(
460 method,
461 url,
462 content=content,
463 data=data,
464 files=files,
465 json=json,
466 params=params,
467 headers=headers,
468 cookies=cookies,
469 auth=auth,
470 follow_redirects=follow_redirects,
471 timeout=timeout,
472 extensions=extensions,
473 )
475 def get( # type: ignore[override]
476 self,
477 url: httpx._types.URLTypes,
478 *,
479 params: httpx._types.QueryParamTypes | None = None,
480 headers: httpx._types.HeaderTypes | None = None,
481 cookies: httpx._types.CookieTypes | None = None,
482 auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
483 follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
484 timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
485 extensions: dict[str, Any] | None = None,
486 ) -> httpx.Response:
487 return super().get(
488 url,
489 params=params,
490 headers=headers,
491 cookies=cookies,
492 auth=auth,
493 follow_redirects=follow_redirects,
494 timeout=timeout,
495 extensions=extensions,
496 )
498 def options( # type: ignore[override]
499 self,
500 url: httpx._types.URLTypes,
501 *,
502 params: httpx._types.QueryParamTypes | None = None,
503 headers: httpx._types.HeaderTypes | None = None,
504 cookies: httpx._types.CookieTypes | None = None,
505 auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
506 follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
507 timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
508 extensions: dict[str, Any] | None = None,
509 ) -> httpx.Response:
510 return super().options(
511 url,
512 params=params,
513 headers=headers,
514 cookies=cookies,
515 auth=auth,
516 follow_redirects=follow_redirects,
517 timeout=timeout,
518 extensions=extensions,
519 )
521 def head( # type: ignore[override]
522 self,
523 url: httpx._types.URLTypes,
524 *,
525 params: httpx._types.QueryParamTypes | None = None,
526 headers: httpx._types.HeaderTypes | None = None,
527 cookies: httpx._types.CookieTypes | None = None,
528 auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
529 follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
530 timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
531 extensions: dict[str, Any] | None = None,
532 ) -> httpx.Response:
533 return super().head(
534 url,
535 params=params,
536 headers=headers,
537 cookies=cookies,
538 auth=auth,
539 follow_redirects=follow_redirects,
540 timeout=timeout,
541 extensions=extensions,
542 )
544 def post( # type: ignore[override]
545 self,
546 url: httpx._types.URLTypes,
547 *,
548 content: httpx._types.RequestContent | None = None,
549 data: _RequestData | None = None,
550 files: httpx._types.RequestFiles | None = None,
551 json: Any = None,
552 params: httpx._types.QueryParamTypes | None = None,
553 headers: httpx._types.HeaderTypes | None = None,
554 cookies: httpx._types.CookieTypes | None = None,
555 auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
556 follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
557 timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
558 extensions: dict[str, Any] | None = None,
559 ) -> httpx.Response:
560 return super().post(
561 url,
562 content=content,
563 data=data,
564 files=files,
565 json=json,
566 params=params,
567 headers=headers,
568 cookies=cookies,
569 auth=auth,
570 follow_redirects=follow_redirects,
571 timeout=timeout,
572 extensions=extensions,
573 )
575 def put( # type: ignore[override]
576 self,
577 url: httpx._types.URLTypes,
578 *,
579 content: httpx._types.RequestContent | None = None,
580 data: _RequestData | None = None,
581 files: httpx._types.RequestFiles | None = None,
582 json: Any = None,
583 params: httpx._types.QueryParamTypes | None = None,
584 headers: httpx._types.HeaderTypes | None = None,
585 cookies: httpx._types.CookieTypes | None = None,
586 auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
587 follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
588 timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
589 extensions: dict[str, Any] | None = None,
590 ) -> httpx.Response:
591 return super().put(
592 url,
593 content=content,
594 data=data,
595 files=files,
596 json=json,
597 params=params,
598 headers=headers,
599 cookies=cookies,
600 auth=auth,
601 follow_redirects=follow_redirects,
602 timeout=timeout,
603 extensions=extensions,
604 )
606 def patch( # type: ignore[override]
607 self,
608 url: httpx._types.URLTypes,
609 *,
610 content: httpx._types.RequestContent | None = None,
611 data: _RequestData | None = None,
612 files: httpx._types.RequestFiles | None = None,
613 json: Any = None,
614 params: httpx._types.QueryParamTypes | None = None,
615 headers: httpx._types.HeaderTypes | None = None,
616 cookies: httpx._types.CookieTypes | None = None,
617 auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
618 follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
619 timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
620 extensions: dict[str, Any] | None = None,
621 ) -> httpx.Response:
622 return super().patch(
623 url,
624 content=content,
625 data=data,
626 files=files,
627 json=json,
628 params=params,
629 headers=headers,
630 cookies=cookies,
631 auth=auth,
632 follow_redirects=follow_redirects,
633 timeout=timeout,
634 extensions=extensions,
635 )
637 def delete( # type: ignore[override]
638 self,
639 url: httpx._types.URLTypes,
640 *,
641 params: httpx._types.QueryParamTypes | None = None,
642 headers: httpx._types.HeaderTypes | None = None,
643 cookies: httpx._types.CookieTypes | None = None,
644 auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
645 follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
646 timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
647 extensions: dict[str, Any] | None = None,
648 ) -> httpx.Response:
649 return super().delete(
650 url,
651 params=params,
652 headers=headers,
653 cookies=cookies,
654 auth=auth,
655 follow_redirects=follow_redirects,
656 timeout=timeout,
657 extensions=extensions,
658 )
660 def websocket_connect(
661 self,
662 url: str,
663 subprotocols: Sequence[str] | None = None,
664 **kwargs: Any,
665 ) -> WebSocketTestSession:
666 url = urljoin("ws://testserver", url)
667 headers = kwargs.get("headers", {})
668 headers.setdefault("connection", "upgrade")
669 headers.setdefault("sec-websocket-key", "testserver==")
670 headers.setdefault("sec-websocket-version", "13")
671 if subprotocols is not None:
672 headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
673 kwargs["headers"] = headers
674 try:
675 super().request("GET", url, **kwargs)
676 except _Upgrade as exc:
677 session = exc.session
678 else:
679 raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover
681 return session
683 def __enter__(self) -> Self:
684 with contextlib.ExitStack() as stack:
685 self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
687 @stack.callback
688 def reset_portal() -> None:
689 self.portal = None
691 send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
692 anyio.create_memory_object_stream(math.inf)
693 )
694 receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
695 math.inf
696 )
697 for channel in (*send, *receive):
698 stack.callback(channel.close)
699 self.stream_send = StapledObjectStream(*send)
700 self.stream_receive = StapledObjectStream(*receive)
701 self.task = portal.start_task_soon(self.lifespan)
702 portal.call(self.wait_startup)
704 @stack.callback
705 def wait_shutdown() -> None:
706 portal.call(self.wait_shutdown)
708 self.exit_stack = stack.pop_all()
710 return self
712 def __exit__(self, *args: Any) -> None:
713 self.exit_stack.close()
715 async def lifespan(self) -> None:
716 scope = {"type": "lifespan", "state": self.app_state}
717 try:
718 await self.app(scope, self.stream_receive.receive, self.stream_send.send)
719 finally:
720 await self.stream_send.send(None)
722 async def wait_startup(self) -> None:
723 await self.stream_receive.send({"type": "lifespan.startup"})
725 async def receive() -> Any:
726 message = await self.stream_send.receive()
727 if message is None:
728 self.task.result()
729 return message
731 message = await receive()
732 assert message["type"] in (
733 "lifespan.startup.complete",
734 "lifespan.startup.failed",
735 )
736 if message["type"] == "lifespan.startup.failed":
737 await receive()
739 async def wait_shutdown(self) -> None:
740 async def receive() -> Any:
741 message = await self.stream_send.receive()
742 if message is None:
743 self.task.result()
744 return message
746 await self.stream_receive.send({"type": "lifespan.shutdown"})
747 message = await receive()
748 assert message["type"] in (
749 "lifespan.shutdown.complete",
750 "lifespan.shutdown.failed",
751 )
752 if message["type"] == "lifespan.shutdown.failed":
753 await receive()