1from __future__ import annotations
2
3import typing
4
5import sniffio
6
7from .._models import Request, Response
8from .._types import AsyncByteStream
9from .base import AsyncBaseTransport
10
if typing.TYPE_CHECKING:  # pragma: no cover
    import asyncio

    import trio

    # Static-only union of the two event types `create_event` may return.
    Event = typing.Union[asyncio.Event, trio.Event]


# A single ASGI message: a mutable, str-keyed mapping (scopes share this shape).
_Message = typing.MutableMapping[str, typing.Any]
# The zero-argument awaitable callable the app uses to pull request messages.
_Receive = typing.Callable[[], typing.Awaitable[_Message]]
# The awaitable callable the app uses to push response messages.
_Send = typing.Callable[[_Message], typing.Awaitable[None]]
# An ASGI application, invoked as `app(scope, receive, send)`.
_ASGIApp = typing.Callable[[_Message, _Receive, _Send], typing.Awaitable[None]]
27
# Public API of this module: only `ASGITransport` is exported by `import *`.
__all__ = ["ASGITransport"]
29
30
def create_event() -> Event:
    """
    Create an event object for whichever async library is currently running.

    Returns a `trio.Event` when sniffio reports trio, and an `asyncio.Event`
    in every other case. Each library is imported only inside the branch
    that needs it.
    """
    library = sniffio.current_async_library()
    if library != "trio":
        import asyncio

        return asyncio.Event()

    import trio

    return trio.Event()
40
41
class ASGIResponseStream(AsyncByteStream):
    """
    Byte stream over the body chunks collected from an ASGI app.

    Iteration yields exactly one `bytes` value: all collected chunks joined
    together (`b""` when no chunks were collected).
    """

    def __init__(self, body: list[bytes]) -> None:
        # Kept by reference; the chunks are joined only when iteration starts.
        self._body = body

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        joined = bytes().join(self._body)
        yield joined
48
49
class ASGITransport(AsyncBaseTransport):
    """
    A custom AsyncTransport that handles sending requests directly to an ASGI app.

    ```python
    transport = httpx.ASGITransport(
        app=app,
        root_path="/submount",
        client=("1.2.3.4", 123)
    )
    client = httpx.AsyncClient(transport=transport)
    ```

    Arguments:

    * `app` - The ASGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
       should be raised. Defaults to `True`. Can be set to `False` for use cases
       such as testing the content of a client 500 response.
    * `root_path` - The root path on which the ASGI application should be mounted.
    * `client` - A two-tuple indicating the client IP and port of incoming requests.
    """

    def __init__(
        self,
        app: _ASGIApp,
        raise_app_exceptions: bool = True,
        root_path: str = "",
        client: tuple[str, int] = ("127.0.0.1", 123),
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.root_path = root_path
        self.client = client

    async def handle_async_request(
        self,
        request: Request,
    ) -> Response:
        """
        Run `request` through the ASGI app and return its response.

        The app is awaited to completion, and the response body is buffered
        in memory (body chunks are discarded for `HEAD` requests).

        If the app raises and `raise_app_exceptions` is `False`, a response is
        still returned. Its status is 500 if the app had not yet sent
        `http.response.start`.
        """
        assert isinstance(request.stream, AsyncByteStream)

        # `request.url.port` may be `None` (presumably when the URL uses its
        # scheme's default port). The ASGI spec expects an integer port in
        # `scope["server"]` for a TCP server, so fall back to the well-known
        # default for the scheme.
        server_port = request.url.port
        if server_port is None:
            server_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}.get(
                request.url.scheme
            )

        # ASGI scope.
        scope = {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": request.method,
            "headers": [(k.lower(), v) for (k, v) in request.headers.raw],
            "scheme": request.url.scheme,
            "path": request.url.path,
            # raw_path must not include the query string.
            "raw_path": request.url.raw_path.split(b"?")[0],
            "query_string": request.url.query,
            "server": (request.url.host, server_port),
            "client": self.client,
            "root_path": self.root_path,
        }

        # Request: the body is handed to the app one chunk per receive() call.
        request_body_chunks = request.stream.__aiter__()
        request_complete = False

        # Response: state filled in by send().
        status_code = None
        response_headers = None
        body_parts = []
        response_started = False
        response_complete = create_event()

        # ASGI callables.

        async def receive() -> dict[str, typing.Any]:
            nonlocal request_complete

            if request_complete:
                # The body has already been fully delivered. Block until the
                # response is finished, then report the client as disconnected.
                await response_complete.wait()
                return {"type": "http.disconnect"}

            try:
                body = await request_body_chunks.__anext__()
            except StopAsyncIteration:
                request_complete = True
                return {"type": "http.request", "body": b"", "more_body": False}
            return {"type": "http.request", "body": body, "more_body": True}

        async def send(message: _Message) -> None:
            nonlocal status_code, response_headers, response_started

            if message["type"] == "http.response.start":
                assert not response_started

                status_code = message["status"]
                response_headers = message.get("headers", [])
                response_started = True

            elif message["type"] == "http.response.body":
                assert not response_complete.is_set()
                body = message.get("body", b"")
                more_body = message.get("more_body", False)

                if body and request.method != "HEAD":
                    body_parts.append(body)

                if not more_body:
                    response_complete.set()

        try:
            await self.app(scope, receive, send)
        except Exception:  # noqa: PIE-786
            if self.raise_app_exceptions:
                raise

            # The error is swallowed: mark the response finished and fill in
            # whatever the app never sent.
            response_complete.set()
            if status_code is None:
                status_code = 500
            if response_headers is None:
                # Same list shape that send() uses for headers.
                response_headers = []

        assert response_complete.is_set()
        assert status_code is not None
        assert response_headers is not None

        stream = ASGIResponseStream(body_parts)

        return Response(status_code, headers=response_headers, stream=stream)