1from __future__ import annotations
2
3import io
4import itertools
5import sys
6import typing
7
8from .._models import Request, Response
9from .._types import SyncByteStream
10from .base import BaseTransport
11
12if typing.TYPE_CHECKING:
13 from _typeshed import OptExcInfo # pragma: no cover
14 from _typeshed.wsgi import WSGIApplication # pragma: no cover
15
# Generic element type so _skip_leading_empty_chunks returns the same element
# type it was given.
_T = typing.TypeVar("_T")


# Public API of this module: only the transport class is exported.
__all__ = ["WSGITransport"]
20
21
22def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]:
23 body = iter(body)
24 for chunk in body:
25 if chunk:
26 return itertools.chain([chunk], body)
27 return []
28
29
class WSGIByteStream(SyncByteStream):
    """
    Synchronous byte stream over the iterable returned by a WSGI application.

    Construction eagerly advances the iterable past any empty leading chunks.
    For a generator-based app, this gives it the chance to call
    `start_response` before the caller inspects status and headers.
    `close()` forwards to the original iterable's `close` method, if it has one.
    """

    def __init__(self, result: typing.Iterable[bytes]) -> None:
        # Capture the close hook from the raw iterable; the wrapped iterable
        # built below does not carry it.
        self._close = getattr(result, "close", None)
        self._result = _skip_leading_empty_chunks(result)

    def __iter__(self) -> typing.Iterator[bytes]:
        yield from self._result

    def close(self) -> None:
        closer = self._close
        if closer is None:
            return
        closer()
42
43
class WSGITransport(BaseTransport):
    """
    A custom transport that handles sending requests directly to a WSGI app.
    The simplest way to use this functionality is to use the `app` argument.

    ```
    client = httpx.Client(app=app)
    ```

    Alternatively, you can set up the transport instance explicitly.
    This allows you to include any additional configuration arguments specific
    to the WSGITransport class:

    ```
    transport = httpx.WSGITransport(
        app=app,
        script_name="/submount",
        remote_addr="1.2.3.4"
    )
    client = httpx.Client(transport=transport)
    ```

    Arguments:

    * `app` - The WSGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
       should be raised. Defaults to `True`. Can be set to `False` for use cases
       such as testing the content of a client 500 response.
    * `script_name` - The root path on which the WSGI application should be mounted.
    * `remote_addr` - A string indicating the client IP of incoming requests.
    * `wsgi_errors` - A text stream exposed to the application as `wsgi.errors`.
       Defaults to `sys.stderr` when not given.
    """

    def __init__(
        self,
        app: WSGIApplication,
        raise_app_exceptions: bool = True,
        script_name: str = "",
        remote_addr: str = "127.0.0.1",
        wsgi_errors: typing.TextIO | None = None,
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.script_name = script_name
        self.remote_addr = remote_addr
        self.wsgi_errors = wsgi_errors

    def _build_environ(self, request: Request) -> dict[str, typing.Any]:
        """
        Build the WSGI `environ` mapping for `request`.

        `request.read()` must already have been called, because the body is
        taken from `request.content`. Header names and values are decoded as
        latin-1, the "native string" encoding PEP 3333 prescribes. Repeated
        request headers are folded into a single environ entry rather than
        overwriting each other.
        """
        port = request.url.port or {"http": 80, "https": 443}[request.url.scheme]
        environ: dict[str, typing.Any] = {
            "wsgi.version": (1, 0),
            "wsgi.url_scheme": request.url.scheme,
            "wsgi.input": io.BytesIO(request.content),
            "wsgi.errors": (
                self.wsgi_errors if self.wsgi_errors is not None else sys.stderr
            ),
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "REQUEST_METHOD": request.method,
            "SCRIPT_NAME": self.script_name,
            "PATH_INFO": request.url.path,
            "QUERY_STRING": request.url.query.decode("ascii"),
            "SERVER_NAME": request.url.host,
            "SERVER_PORT": str(port),
            "SERVER_PROTOCOL": "HTTP/1.1",
            "REMOTE_ADDR": self.remote_addr,
        }

        header_environ: dict[str, str] = {}
        for header_key, header_value in request.headers.raw:
            key = header_key.decode("latin-1").upper().replace("-", "_")
            # CONTENT_TYPE / CONTENT_LENGTH are the only CGI header variables
            # without the HTTP_ prefix.
            if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"):
                key = "HTTP_" + key
            value = header_value.decode("latin-1")
            if key in header_environ:
                # Repeated header: combine rather than silently dropping the
                # earlier value. Cookie pairs are separated by "; ", and other
                # list-valued fields by ",".
                separator = "; " if key == "HTTP_COOKIE" else ","
                value = header_environ[key] + separator + value
            header_environ[key] = value
        environ.update(header_environ)
        return environ

    def handle_request(self, request: Request) -> Response:
        """
        Run `request` through the WSGI application and return its response.

        The request body is read fully into memory before the app is called.
        If the app passes `exc_info` to `start_response` and
        `raise_app_exceptions` is true, that exception is re-raised. If
        building the response fails for any reason, the iterable returned by
        the app is closed (when it has a `close` method) before the error
        propagates, so the app's resources are not leaked.
        """
        request.read()
        environ = self._build_environ(request)

        seen_status = None
        seen_response_headers = None
        seen_exc_info = None

        def start_response(
            status: str,
            response_headers: list[tuple[str, str]],
            exc_info: OptExcInfo | None = None,
        ) -> typing.Callable[[bytes], typing.Any]:
            nonlocal seen_status, seen_response_headers, seen_exc_info
            seen_status = status
            seen_response_headers = response_headers
            seen_exc_info = exc_info
            # The legacy `write()` callable is a no-op: anything an app passes
            # to it is discarded.
            return lambda _: None

        result = self.app(environ, start_response)

        try:
            # Wrapping advances `result` past any empty leading chunks, giving
            # generator-based apps the chance to call `start_response` before
            # the status and headers are read below.
            stream = WSGIByteStream(result)

            assert seen_status is not None
            assert seen_response_headers is not None
            if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions:
                raise seen_exc_info[1]

            status_code = int(seen_status.split()[0])
            # PEP 3333 header strings are latin-1 native strings.
            headers = [
                (key.encode("latin-1"), value.encode("latin-1"))
                for key, value in seen_response_headers
            ]
        except BaseException:
            # PEP 3333: the server must call the iterable's close() if it has
            # one, even when the response is abandoned.
            close = getattr(result, "close", None)
            if close is not None:
                close()
            raise

        return Response(status_code, headers=headers, stream=stream)