1from __future__ import annotations
2
3import typing
4
5import sniffio
6
7from .._models import Request, Response
8from .._types import AsyncByteStream
9from .base import AsyncBaseTransport
10
if typing.TYPE_CHECKING:  # pragma: no cover
    # These imports exist only for static type checkers; at runtime the
    # concrete event class is picked by ``create_event``.
    import asyncio

    import trio

    # The event type of whichever async backend is in use.
    Event = typing.Union[asyncio.Event, trio.Event]


# A single ASGI event message exchanged between server and application.
_Message = typing.Dict[str, typing.Any]

# The connection scope handed to the application (a plain str-keyed dict).
_Scope = typing.Dict[str, typing.Any]

# Zero-argument awaitable the application calls to pull the next incoming event.
_Receive = typing.Callable[[], typing.Awaitable[_Message]]

# Coroutine function the application calls to push an outgoing event.
_Send = typing.Callable[[_Message], typing.Coroutine[None, None, None]]

# An ASGI 3 application: ``await app(scope, receive, send)``.
_ASGIApp = typing.Callable[
    [_Scope, _Receive, _Send], typing.Coroutine[None, None, None]
]
27
28
def create_event() -> Event:
    """Return a new event object for the async library that is currently running.

    ``sniffio`` reports the active library: under Trio a ``trio.Event`` is
    returned, otherwise an ``asyncio.Event``. Both modules are imported inside
    the function, so ``trio`` is only imported when Trio is the running library.
    """
    if sniffio.current_async_library() != "trio":
        import asyncio

        return asyncio.Event()

    import trio

    return trio.Event()
38
39
class ASGIResponseStream(AsyncByteStream):
    """Async byte stream over a response body that is already fully buffered.

    Iterating the stream concatenates the stored chunks and yields them as one
    ``bytes`` object.
    """

    def __init__(self, body: list[bytes]) -> None:
        # The buffered body chunks, kept in the order they were received.
        self._body = body

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The join happens lazily, on the first step of iteration.
        payload = bytes().join(self._body)
        yield payload
46
47
class ASGITransport(AsyncBaseTransport):
    """
    A custom AsyncTransport that handles sending requests directly to an ASGI app.
    The simplest way to use this functionality is to use the `app` argument.

    ```
    client = httpx.AsyncClient(app=app)
    ```

    Alternatively, you can set up the transport instance explicitly.
    This allows you to include any additional configuration arguments specific
    to the ASGITransport class:

    ```
    transport = httpx.ASGITransport(
        app=app,
        root_path="/submount",
        client=("1.2.3.4", 123)
    )
    client = httpx.AsyncClient(transport=transport)
    ```

    Arguments:

    * `app` - The ASGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
       should be raised. Defaults to `True`. Can be set to `False` for use cases
       such as testing the content of a client 500 response.
    * `root_path` - The root path on which the ASGI application should be mounted.
    * `client` - A two-tuple indicating the client IP and port of incoming requests.

    The application is awaited to completion and its whole response body is
    buffered in memory before the `Response` is returned.
    """

    def __init__(
        self,
        app: _ASGIApp,
        raise_app_exceptions: bool = True,
        root_path: str = "",
        client: tuple[str, int] = ("127.0.0.1", 123),
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        # Reported to the app verbatim as the scope's "root_path".
        self.root_path = root_path
        # Reported to the app verbatim as the scope's "client" address.
        self.client = client

    async def handle_async_request(
        self,
        request: Request,
    ) -> Response:
        """Run the ASGI app for ``request`` and return its buffered response.

        The request body is fed to the app chunk by chunk via ``receive``.
        Everything the app passes to ``send`` is collected in memory. Body
        chunks are dropped for ``HEAD`` requests.

        If the app raises and ``raise_app_exceptions`` is false, the error is
        suppressed. The response is then built from whatever the app sent
        before failing, defaulting to status 500 and empty headers.
        """
        assert isinstance(request.stream, AsyncByteStream)

        # HTTP connection scope passed to the application.
        scope = {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": request.method,
            "headers": [(name.lower(), value) for name, value in request.headers.raw],
            "scheme": request.url.scheme,
            "path": request.url.path,
            # The query string travels separately, so cut it off the raw path.
            "raw_path": request.url.raw_path.split(b"?")[0],
            "query_string": request.url.query,
            "server": (request.url.host, request.url.port),
            "client": self.client,
            "root_path": self.root_path,
        }

        # Incoming side: the request body, pulled lazily as the app asks for it.
        body_iter = request.stream.__aiter__()
        body_exhausted = False

        # Outgoing side: populated by ``send``.
        status = None
        headers = None
        chunks = []
        started = False
        done = create_event()

        async def receive() -> dict[str, typing.Any]:
            nonlocal body_exhausted

            if not body_exhausted:
                try:
                    chunk = await body_iter.__anext__()
                except StopAsyncIteration:
                    body_exhausted = True
                    return {"type": "http.request", "body": b"", "more_body": False}
                return {"type": "http.request", "body": chunk, "more_body": True}

            # The body has been fully delivered. Block until the response is
            # complete, then tell the app the client has gone away.
            await done.wait()
            return {"type": "http.disconnect"}

        async def send(message: dict[str, typing.Any]) -> None:
            nonlocal status, headers, started

            kind = message["type"]

            if kind == "http.response.start":
                assert not started
                status = message["status"]
                headers = message.get("headers", [])
                started = True
                return

            if kind != "http.response.body":
                # Any other message type is ignored.
                return

            assert not done.is_set()
            chunk = message.get("body", b"")
            more_body = message.get("more_body", False)

            if chunk and request.method != "HEAD":
                chunks.append(chunk)

            if not more_body:
                done.set()

        try:
            await self.app(scope, receive, send)
        except Exception:  # noqa: PIE-786
            if self.raise_app_exceptions:
                raise

            # Error suppressed: finalise a response from whatever was sent.
            done.set()
            if status is None:
                status = 500
            if headers is None:
                headers = {}

        assert done.is_set()
        assert status is not None
        assert headers is not None

        return Response(status, headers=headers, stream=ASGIResponseStream(chunks))