1from __future__ import annotations
2
3import io
4import itertools
5import sys
6import typing
7
8from .._models import Request, Response
9from .._types import SyncByteStream
10from .base import BaseTransport
11
12if typing.TYPE_CHECKING:
13 from _typeshed import OptExcInfo # pragma: no cover
14 from _typeshed.wsgi import WSGIApplication # pragma: no cover
15
# Generic element type used by `_skip_leading_empty_chunks` so it preserves
# the item type of whatever iterable it is given.
_T = typing.TypeVar("_T")
17
18
def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]:
    """
    Eagerly consume falsy items from the front of *body*.

    Returns an iterable that starts with the first truthy item and continues
    with every remaining item unchanged (later falsy items are kept). If no
    truthy item exists, an empty list is returned.

    The advance happens immediately, not lazily. Generator-based WSGI apps
    may only run (and call ``start_response``) once they are iterated.
    """
    remaining = iter(body)
    # The generator expression pulls from `remaining` itself, so after it
    # yields, `remaining` is positioned just past the first truthy item.
    # A truthy item can never be None, so None is a safe "nothing found" marker.
    head = next((item for item in remaining if item), None)
    if head is None:
        return []
    return itertools.chain((head,), remaining)
25
26
class WSGIByteStream(SyncByteStream):
    """
    Byte stream over the iterable returned by a WSGI application.

    Construction eagerly advances the iterable to its first non-empty chunk,
    via `_skip_leading_empty_chunks`. Generator-based applications may not
    call ``start_response`` until they are first iterated, so this advance
    has to happen before the status and headers can be read.

    If that advance raises, the iterable's ``close()`` (when it has one) is
    called before the exception propagates. PEP 3333 requires ``close()`` to
    be called even when a request terminates because of an application error.
    """

    def __init__(self, result: typing.Iterable[bytes]) -> None:
        # Captured before iteration so the iterable can be closed even if
        # advancing it fails below.
        self._close = getattr(result, "close", None)
        try:
            self._result = _skip_leading_empty_chunks(result)
        except BaseException:
            # This object never reaches the caller, so nobody else could
            # release the application's iterable.
            self.close()
            raise

    def __iter__(self) -> typing.Iterator[bytes]:
        """Yield the remaining body chunks, starting at the first non-empty one."""
        yield from self._result

    def close(self) -> None:
        """Call the application iterable's ``close()`` if it defines one."""
        if self._close is not None:
            self._close()
39
40
class WSGITransport(BaseTransport):
    """
    A custom transport that handles sending requests directly to a WSGI app.
    The simplest way to use this functionality is to use the `app` argument.

    ```
    client = httpx.Client(app=app)
    ```

    Alternatively, you can set up the transport instance explicitly.
    This allows you to include any additional configuration arguments specific
    to the WSGITransport class:

    ```
    transport = httpx.WSGITransport(
        app=app,
        script_name="/submount",
        remote_addr="1.2.3.4"
    )
    client = httpx.Client(transport=transport)
    ```

    Arguments:

    * `app` - The WSGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions the application
       reports via `start_response(..., exc_info)` should be raised. Defaults to
       `True`. Can be set to `False` for use cases such as testing the content
       of a client 500 response.
    * `script_name` - The root path on which the WSGI application should be mounted.
    * `remote_addr` - A string indicating the client IP of incoming requests.
    * `wsgi_errors` - A text stream exposed to the application as `wsgi.errors`.
       Falls back to `sys.stderr` when not provided.
    """

    def __init__(
        self,
        app: WSGIApplication,
        raise_app_exceptions: bool = True,
        script_name: str = "",
        remote_addr: str = "127.0.0.1",
        wsgi_errors: typing.TextIO | None = None,
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.script_name = script_name
        self.remote_addr = remote_addr
        self.wsgi_errors = wsgi_errors

    def handle_request(self, request: Request) -> Response:
        """
        Dispatch *request* to the WSGI application and return its response.

        The request body is read fully into memory and exposed as ``wsgi.input``.

        Exceptions are handled as follows:

        * Exceptions raised directly by the application call, or while
          advancing to the first non-empty body chunk, always propagate.
        * ``raise_app_exceptions`` only controls whether an error reported
          through ``start_response(..., exc_info)`` is re-raised.

        Raises:
            AssertionError: if the application never called ``start_response``.
        """
        request.read()
        wsgi_input = io.BytesIO(request.content)

        port = request.url.port or {"http": 80, "https": 443}[request.url.scheme]
        environ = {
            "wsgi.version": (1, 0),
            "wsgi.url_scheme": request.url.scheme,
            "wsgi.input": wsgi_input,
            "wsgi.errors": self.wsgi_errors or sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "REQUEST_METHOD": request.method,
            "SCRIPT_NAME": self.script_name,
            "PATH_INFO": request.url.path,
            "QUERY_STRING": request.url.query.decode("ascii"),
            "SERVER_NAME": request.url.host,
            "SERVER_PORT": str(port),
            "SERVER_PROTOCOL": "HTTP/1.1",
            "REMOTE_ADDR": self.remote_addr,
        }
        for header_key, header_value in request.headers.raw:
            key = header_key.decode("ascii").upper().replace("-", "_")
            # PEP 3333 exposes header values as str decoded with ISO-8859-1
            # (latin-1). That codec maps every byte, whereas "ascii" would
            # raise on any non-ASCII header value.
            value = header_value.decode("latin-1")
            if key in ("CONTENT_TYPE", "CONTENT_LENGTH"):
                environ[key] = value
                continue
            key = "HTTP_" + key
            if key in environ:
                # Comma-join repeated headers (as CGI-style servers such as
                # wsgiref do) rather than letting the last one silently win.
                environ[key] = f"{environ[key]},{value}"
            else:
                environ[key] = value

        seen_status: str | None = None
        seen_response_headers: list[tuple[str, str]] | None = None
        seen_exc_info: OptExcInfo | None = None

        def start_response(
            status: str,
            response_headers: list[tuple[str, str]],
            exc_info: OptExcInfo | None = None,
        ) -> typing.Callable[[bytes], typing.Any]:
            nonlocal seen_status, seen_response_headers, seen_exc_info
            seen_status = status
            seen_response_headers = response_headers
            seen_exc_info = exc_info
            # NOTE: data passed to the returned write() callable is discarded;
            # apps relying on the legacy write() API will lose that output.
            return lambda _: None

        result = self.app(environ, start_response)

        # Building the stream advances the app's iterable to its first
        # non-empty chunk. This lets generator-based apps call start_response
        # before the status and headers are inspected below.
        stream = WSGIByteStream(result)

        try:
            assert seen_status is not None
            assert seen_response_headers is not None
            if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions:
                raise seen_exc_info[1]

            status_code = int(seen_status.split()[0])
            # PEP 3333 restricts header strings to the ISO-8859-1 range, so
            # latin-1 round-trips every value an app may legally set.
            headers = [
                (key.encode("latin-1"), value.encode("latin-1"))
                for key, value in seen_response_headers
            ]
        except BaseException:
            # No Response will take ownership of the stream, so release the
            # app's iterable here. PEP 3333 requires close() even on failure.
            stream.close()
            raise

        return Response(status_code, headers=headers, stream=stream)