1from __future__ import annotations
2
3import io
4import mimetypes
5import os
6import re
7import typing
8from pathlib import Path
9
10from ._types import (
11 AsyncByteStream,
12 FileContent,
13 FileTypes,
14 RequestData,
15 RequestFiles,
16 SyncByteStream,
17)
18from ._utils import (
19 peek_filelike_length,
20 primitive_value_to_str,
21 to_bytes,
22)
23
# Escape table for quoted multipart parameter values: the double quote is
# percent-encoded, the backslash is doubled, and every control character in
# the range 0x00-0x1F except 0x1B is percent-encoded as two uppercase hex
# digits.
_HTML5_FORM_ENCODING_REPLACEMENTS = {
    '"': "%22",
    "\\": "\\\\",
    **{chr(code): f"%{code:02X}" for code in range(0x20) if code != 0x1B},
}
# Alternation matching any single character that has an entry in the table.
_HTML5_FORM_ENCODING_RE = re.compile(
    r"|".join(map(re.escape, _HTML5_FORM_ENCODING_REPLACEMENTS))
)
31
32
def _format_form_param(name: str, value: str) -> bytes:
    """
    Render `name="value"` as bytes for use in a multipart part header.

    Every character of `value` matched by `_HTML5_FORM_ENCODING_RE` is swapped
    for its entry in `_HTML5_FORM_ENCODING_REPLACEMENTS`. `name` is emitted
    unchanged.
    """
    escaped = _HTML5_FORM_ENCODING_RE.sub(
        lambda found: _HTML5_FORM_ENCODING_REPLACEMENTS[found.group(0)],
        value,
    )
    return f'{name}="{escaped}"'.encode()
43
44
def _guess_content_type(filename: str | None) -> str | None:
    """
    Infer a mimetype from `filename`.

    Returns `None` when `filename` is `None` or empty. Otherwise returns the
    type reported by `mimetypes.guess_type`, falling back to
    `application/octet-stream` when no type is known for the name.
    """
    if not filename:
        return None

    guessed, _encoding = mimetypes.guess_type(filename)
    return guessed or "application/octet-stream"
54
55
def get_multipart_boundary_from_content_type(
    content_type: bytes | None,
) -> bytes | None:
    """
    Extract the `boundary` parameter from a `multipart/form-data` content type.

    Returns the boundary value with any surrounding double quotes removed, or
    `None` when `content_type` is `None` or empty, is not
    `multipart/form-data`, or carries no `boundary` parameter.
    """
    if not content_type:
        return None
    # Media type and subtype names are case-insensitive (RFC 2045, section
    # 5.1), so compare against a lowercased copy rather than the raw bytes.
    if not content_type.lower().startswith(b"multipart/form-data"):
        return None
    # parse boundary according to
    # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1
    prefix = b"boundary="
    for section in content_type.split(b";"):
        parameter = section.strip()
        # The parameter name is matched case-insensitively; the value keeps
        # its original case.
        if parameter.lower().startswith(prefix):
            return parameter[len(prefix) :].strip(b'"')
    return None
68
69
class DataField:
    """
    One plain (non-file) name/value item of a multipart form body.
    """

    def __init__(self, name: str, value: str | bytes | int | float | None) -> None:
        """
        Store `name` and the normalised `value`.

        Raises `TypeError` when `name` is not a `str`, or when `value` is
        neither `None` nor a `str`, `bytes`, `int` or `float`.
        """
        if not isinstance(name, str):
            raise TypeError(
                f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
            )

        value_is_acceptable = value is None or isinstance(
            value, (str, bytes, int, float)
        )
        if not value_is_acceptable:
            raise TypeError(
                "Invalid type for value. Expected primitive type,"
                f" got {type(value)}: {value!r}"
            )

        # Bytes are kept untouched; anything else is converted through
        # `primitive_value_to_str`.
        if isinstance(value, bytes):
            stored: str | bytes = value
        else:
            stored = primitive_value_to_str(value)

        self.name = name
        self.value = stored

    def render_headers(self) -> bytes:
        """
        Return the part's header block, ending with the blank line.

        The result is built on first use and cached on `self._headers`.
        """
        if hasattr(self, "_headers"):
            return self._headers

        disposition = _format_form_param("name", self.name)
        self._headers = b"".join(
            (b"Content-Disposition: form-data; ", disposition, b"\r\n\r\n")
        )
        return self._headers

    def render_data(self) -> bytes:
        """
        Return the field value as bytes, cached on `self._data` after first use.
        """
        if hasattr(self, "_data"):
            return self._data

        self._data = to_bytes(self.value)
        return self._data

    def get_length(self) -> int:
        """
        Return the combined byte length of the rendered headers and data.
        """
        header_size = len(self.render_headers())
        body_size = len(self.render_data())
        return header_size + body_size

    def render(self) -> typing.Iterator[bytes]:
        """
        Yield the rendered headers, then the rendered data.
        """
        for produce in (self.render_headers, self.render_data):
            yield produce()
113
114
class FileField:
    """
    A single file field item, within a multipart form field.

    `value` is either the file content itself, or a tuple of
    `(filename, content)`, `(filename, content, content_type)` or
    `(filename, content, content_type, headers)`.
    """

    # Number of bytes requested per `read()` call when streaming a file object.
    CHUNK_SIZE = 64 * 1024

    def __init__(self, name: str, value: FileTypes) -> None:
        """
        Unpack `value` into a filename, file content and part headers.

        When no content type is given it is guessed from the filename. The
        resulting content type is added to the part headers as `Content-Type`
        unless the supplied headers already contain a `Content-Type` entry
        (compared case-insensitively). The supplied headers mapping is copied
        and never modified.

        Raises `TypeError` if the content is an `io.StringIO` or any other
        text-mode file object.
        """
        self.name = name

        fileobj: FileContent

        headers: dict[str, str] = {}
        content_type: str | None = None

        # This large tuple based API largely mirrors requests' API.
        # It would be good to think of better APIs for this that we could
        # include in httpx 2.0 since variable length tuples (especially of
        # 4 elements) are quite unwieldy.
        if isinstance(value, tuple):
            if len(value) == 2:
                # neither the 3rd parameter (content_type) nor the 4th (headers)
                # was included
                filename, fileobj = value
            elif len(value) == 3:
                filename, fileobj, content_type = value
            else:
                # all 4 parameters included
                filename, fileobj, content_type, given_headers = value  # type: ignore
                # Work on a copy: a default Content-Type may be added below,
                # and that must not leak into (or fail on) the caller's mapping.
                headers = dict(given_headers)
        else:
            # Use the final path component of the object's `name` attribute,
            # or "upload" when it has none.
            filename = Path(str(getattr(value, "name", "upload"))).name
            fileobj = value

        if content_type is None:
            content_type = _guess_content_type(filename)

        # Exact, case-insensitive header name comparison, so that unrelated
        # headers such as "X-Content-Type-Options" do not suppress the default.
        has_content_type_header = any(
            key.lower() == "content-type" for key in headers
        )
        if content_type is not None and not has_content_type_header:
            # Note that unlike requests, we ignore the content_type provided in
            # the 3rd tuple element if it is also included in the headers.
            # requests does the opposite (it overwrites the header with the
            # 3rd tuple element).
            headers["Content-Type"] = content_type

        if isinstance(fileobj, io.StringIO):
            raise TypeError(
                "Multipart file uploads require 'io.BytesIO', not 'io.StringIO'."
            )
        if isinstance(fileobj, io.TextIOBase):
            raise TypeError(
                "Multipart file uploads must be opened in binary mode, not text mode."
            )

        self.filename = filename
        self.file = fileobj
        self.headers = headers

    def get_length(self) -> int | None:
        """
        Return the byte length of the rendered headers plus the file content.

        Returns `None` when the content is a file object whose length
        `peek_filelike_length` cannot determine.
        """
        headers = self.render_headers()

        if isinstance(self.file, (str, bytes)):
            return len(headers) + len(to_bytes(self.file))

        file_length = peek_filelike_length(self.file)

        # If we can't determine the filesize without reading it into memory,
        # then return `None` here, to indicate an unknown file length.
        if file_length is None:
            return None

        return len(headers) + file_length

    def render_headers(self) -> bytes:
        """
        Return the part's header block, ending with the blank line.

        Contains the `Content-Disposition` line (with a `filename` parameter
        only when the filename is non-empty) followed by every entry of
        `self.headers`. Built on first use and cached on `self._headers`.
        """
        if not hasattr(self, "_headers"):
            parts = [
                b"Content-Disposition: form-data; ",
                _format_form_param("name", self.name),
            ]
            if self.filename:
                filename = _format_form_param("filename", self.filename)
                parts.extend([b"; ", filename])
            for header_name, header_value in self.headers.items():
                # Each extra header starts on its own line.
                key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
                parts.extend([key, val])
            parts.append(b"\r\n\r\n")
            self._headers = b"".join(parts)

        return self._headers

    def render_data(self) -> typing.Iterator[bytes]:
        """
        Yield the file content as bytes.

        `str`/`bytes` content is yielded as one chunk. A file object is
        rewound to position 0 when it has a `seek` method (an
        `io.UnsupportedOperation` from that call is ignored) and then read in
        `CHUNK_SIZE` pieces until `read()` returns an empty value.
        """
        if isinstance(self.file, (str, bytes)):
            yield to_bytes(self.file)
            return

        if hasattr(self.file, "seek"):
            try:
                self.file.seek(0)
            except io.UnsupportedOperation:
                # Best effort only: read from the current position instead.
                pass

        chunk = self.file.read(self.CHUNK_SIZE)
        while chunk:
            yield to_bytes(chunk)
            chunk = self.file.read(self.CHUNK_SIZE)

    def render(self) -> typing.Iterator[bytes]:
        """
        Yield the rendered headers followed by the file content chunks.
        """
        yield self.render_headers()
        yield from self.render_data()
222
223
class MultipartStream(SyncByteStream, AsyncByteStream):
    """
    Multipart encoded form data, exposed as a stream of byte chunks.
    """

    def __init__(
        self,
        data: RequestData,
        files: RequestFiles,
        boundary: bytes | None = None,
    ) -> None:
        """
        Build the field list from `data` and `files`.

        When `boundary` is `None`, 16 random bytes are hex-encoded and used
        as the boundary.
        """
        self.boundary = (
            os.urandom(16).hex().encode("ascii") if boundary is None else boundary
        )
        self.content_type = "multipart/form-data; boundary=%s" % self.boundary.decode(
            "ascii"
        )
        self.fields = list(self._iter_fields(data, files))

    def _iter_fields(
        self, data: RequestData, files: RequestFiles
    ) -> typing.Iterator[FileField | DataField]:
        """
        Yield a `DataField` per data value, then a `FileField` per file.

        A data value that is a tuple or list produces one field per element,
        all sharing the same name.
        """
        for name, value in data.items():
            entries = value if isinstance(value, (tuple, list)) else (value,)
            for entry in entries:
                yield DataField(name=name, value=entry)

        # `files` is either a mapping or an iterable of (name, value) pairs.
        pairs = files.items() if isinstance(files, typing.Mapping) else files
        for name, value in pairs:
            yield FileField(name=name, value=value)

    def iter_chunks(self) -> typing.Iterator[bytes]:
        """
        Yield the encoded body: each field wrapped in its boundary line and a
        trailing CRLF, followed by the closing boundary line.
        """
        for part in self.fields:
            yield b"--%s\r\n" % self.boundary
            yield from part.render()
            yield b"\r\n"
        yield b"--%s--\r\n" % self.boundary

    def get_content_length(self) -> int | None:
        """
        Return the total byte length of the encoded body.

        Returns `None` as soon as any field reports an unknown length.
        """
        boundary_length = len(self.boundary)
        # Per field: b"--{boundary}\r\n" before it and b"\r\n" after it.
        per_field_overhead = (2 + boundary_length + 2) + 2
        # Start from the closing b"--{boundary}--\r\n" line.
        total = 2 + boundary_length + 4

        for part in self.fields:
            part_length = part.get_length()
            if part_length is None:
                return None
            total += per_field_overhead + part_length

        return total

    # Content stream interface.

    def get_headers(self) -> dict[str, str]:
        """
        Return the headers describing this body.

        Uses `Content-Length` when the length is known, otherwise
        `Transfer-Encoding: chunked`, always followed by `Content-Type`.
        """
        known_length = self.get_content_length()
        if known_length is None:
            framing = {"Transfer-Encoding": "chunked"}
        else:
            framing = {"Content-Length": str(known_length)}
        return {**framing, "Content-Type": self.content_type}

    def __iter__(self) -> typing.Iterator[bytes]:
        yield from self.iter_chunks()

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        for piece in self.iter_chunks():
            yield piece