1from __future__ import annotations
2
3import typing
4
5from .._models import Request, Response
6from .._types import AsyncByteStream
7from .base import AsyncBaseTransport
8
if typing.TYPE_CHECKING:  # pragma: no cover
    import asyncio

    import trio

    # The event primitive returned by ``create_event``; which concrete class is
    # used depends on the async library the transport is running under.
    Event = typing.Union[asyncio.Event, trio.Event]


# An ASGI message (or scope): a mutable mapping from string keys to any value.
_Message = typing.MutableMapping[str, typing.Any]
# The ASGI ``receive`` callable: awaits and returns the next incoming message.
_Receive = typing.Callable[[], typing.Awaitable[_Message]]
# The ASGI ``send`` callable: awaits delivery of one outgoing message.
_Send = typing.Callable[[_Message], typing.Awaitable[None]]
# An ASGI 3 application, invoked as ``await app(scope, receive, send)``.
_ASGIApp = typing.Callable[[_Message, _Receive, _Send], typing.Awaitable[None]]

__all__ = ["ASGITransport"]
27
28
def is_running_trio() -> bool:
    """Return ``True`` only when the current async context is driven by trio.

    Returns ``False`` when sniffio is not installed, when a different async
    library is running, or when no async library can be detected at all
    (for example when called from synchronous code).
    """
    try:
        # sniffio is a dependency of trio, so if it is missing trio cannot be
        # installed and we cannot be running under it.

        # See https://github.com/python-trio/trio/issues/2802
        import sniffio
    except ImportError:  # pragma: nocover
        return False

    try:
        return sniffio.current_async_library() == "trio"
    except sniffio.AsyncLibraryNotFoundError:
        # Outside any async context sniffio recognises: definitely not trio.
        return False
42
43
def create_event() -> Event:
    """Build a new, unset event for the async library currently in use.

    Under trio this is a ``trio.Event``; in every other case an
    ``asyncio.Event`` is returned. Each library is imported lazily so that
    only the one actually needed is loaded.
    """
    if not is_running_trio():
        import asyncio

        return asyncio.Event()

    import trio

    return trio.Event()
53
54
class ASGIResponseStream(AsyncByteStream):
    """Response stream replaying a fully-buffered ASGI body.

    The collected body parts are concatenated and emitted as a single chunk
    (an empty ``b""`` chunk when nothing was collected).
    """

    def __init__(self, body: list[bytes]) -> None:
        # Body fragments as collected from the app's ``http.response.body`` messages.
        self._body = body

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        combined = b"".join(self._body)
        yield combined
61
62
class ASGITransport(AsyncBaseTransport):
    """
    A custom AsyncTransport that handles sending requests directly to an ASGI app.

    ```python
    transport = httpx.ASGITransport(
        app=app,
        root_path="/submount",
        client=("1.2.3.4", 123)
    )
    client = httpx.AsyncClient(transport=transport)
    ```

    Arguments:

    * `app` - The ASGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
       should be raised. Defaults to `True`. Can be set to `False` for use cases
       such as testing the content of a client 500 response.
    * `root_path` - The root path on which the ASGI application should be mounted.
    * `client` - A two-tuple indicating the client IP and port of incoming requests.
    """

    def __init__(
        self,
        app: _ASGIApp,
        raise_app_exceptions: bool = True,
        root_path: str = "",
        client: tuple[str, int] = ("127.0.0.1", 123),
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.root_path = root_path
        self.client = client

    async def handle_async_request(
        self,
        request: Request,
    ) -> Response:
        """
        Run `request` through the ASGI app and return the buffered response.

        The request body is fed to the app via `receive()`. Everything the app
        passes to `send()` is collected in memory, and a `Response` is built
        once the app returns.

        If the app raises and `raise_app_exceptions` is false, the response is
        built from whatever the app had already sent. The status falls back to
        500 and the headers to empty when those were never sent.
        """
        assert isinstance(request.stream, AsyncByteStream)

        # ASGI HTTP connection scope.
        scope = {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": request.method,
            # Header names are lower-cased before being handed to the app.
            "headers": [(k.lower(), v) for (k, v) in request.headers.raw],
            "scheme": request.url.scheme,
            "path": request.url.path,
            # The query string is passed separately, so strip it from the raw path.
            "raw_path": request.url.raw_path.split(b"?")[0],
            "query_string": request.url.query,
            "server": (request.url.host, request.url.port),
            "client": self.client,
            "root_path": self.root_path,
        }

        # Request.
        request_body_chunks = request.stream.__aiter__()
        request_complete = False

        # Response.
        status_code = None
        response_headers = None
        body_parts: list[bytes] = []
        response_started = False
        response_complete = create_event()

        # ASGI callables.

        async def receive() -> dict[str, typing.Any]:
            """Deliver the request body chunk by chunk, then a disconnect."""
            nonlocal request_complete

            if request_complete:
                # The body is exhausted: report a disconnect only once the
                # app has finished sending its response.
                await response_complete.wait()
                return {"type": "http.disconnect"}

            try:
                body = await request_body_chunks.__anext__()
            except StopAsyncIteration:
                request_complete = True
                return {"type": "http.request", "body": b"", "more_body": False}
            return {"type": "http.request", "body": body, "more_body": True}

        async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
            """Record the response start and accumulate response body parts."""
            nonlocal status_code, response_headers, response_started

            if message["type"] == "http.response.start":
                assert not response_started

                status_code = message["status"]
                response_headers = message.get("headers", [])
                response_started = True

            elif message["type"] == "http.response.body":
                assert not response_complete.is_set()
                body = message.get("body", b"")
                more_body = message.get("more_body", False)

                # HEAD responses carry no body, so any bytes sent are discarded.
                if body and request.method != "HEAD":
                    body_parts.append(body)

                if not more_body:
                    response_complete.set()

        try:
            await self.app(scope, receive, send)
        except Exception:  # noqa: PIE-786
            if self.raise_app_exceptions:
                raise

            # Treat the response as finished with whatever the app had sent,
            # filling in a 500 status and empty headers if they were never sent.
            response_complete.set()
            if status_code is None:
                status_code = 500
            if response_headers is None:
                response_headers = []

        assert response_complete.is_set()
        assert status_code is not None
        assert response_headers is not None

        stream = ASGIResponseStream(body_parts)

        return Response(status_code, headers=response_headers, stream=stream)