1import decimal
2
3try:
4 from collections.abc import Hashable
5except ImportError:
6 from collections import Hashable
7from copy import deepcopy
8from flask import current_app, request
9
10from werkzeug.datastructures import MultiDict, FileStorage
11from werkzeug import exceptions
12
13from .errors import abort, SpecsError
14from .marshalling import marshal
15from .model import Model
16from ._http import HTTPStatus
17
18
class ParseResult(dict):
    """
    The default result container: a dict whose keys can also be read,
    written and deleted as attributes (``result.foo`` is ``result["foo"]``).
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so real dict
        # methods (``keys``, ``items``...) still take precedence over keys.
        try:
            return self[name]
        except KeyError:
            # The KeyError is an implementation detail; hide it from the
            # traceback so callers only see the AttributeError.
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Attributes are stored as dict items, never in the instance __dict__.
        self[name] = value

    def __delattr__(self, name):
        # Mirror __setattr__: attribute deletion removes the dict item.
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None
32
33
#: Human-readable description of each request location, used in the
#: "Missing required parameter in ..." error message.
#: "get_json" is accepted by Argument.source as a synonym of "json".
_friendly_location = {
    "json": "the JSON body",
    "get_json": "the JSON body",
    "form": "the post body",
    "args": "the query string",
    "values": "the post body or the query string",
    "headers": "the HTTP headers",
    "cookies": "the request's cookies",
    "files": "an uploaded file",
}

#: Maps Flask-RESTX RequestParser locations to Swagger ones
#: (locations not listed here are documented as "query").
LOCATIONS = {
    "args": "query",
    "form": "formData",
    "headers": "header",
    "json": "body",
    "get_json": "body",
    "values": "query",
    "files": "formData",
}

#: Maps Python primitives types to Swagger ones
PY_TYPES = {
    int: "integer",
    str: "string",
    bool: "boolean",
    float: "number",
    None: "void",
}

#: Separator used by the "split" argument action.
SPLIT_CHAR = ","
64
65
class Argument(object):
    """
    A single named argument to pull from a request, convert and validate.

    :param name: Either a name or a list of option strings, e.g. foo or -f, --foo.
    :param default: The value produced if the argument is absent from the request.
        If callable, it is called with no arguments each time a default is needed.
    :param dest: The name of the attribute to be added to the object
        returned by :meth:`~reqparse.RequestParser.parse_args()`.
    :param bool required: Whether or not the argument may be omitted (optionals only).
    :param string action: The basic type of action to be taken when this argument
        is encountered in the request. Valid options are "store" (keep the first
        value), "append" (keep every value, as a list) and "split" (split each
        value on commas and convert every piece).
    :param bool ignore: Whether to ignore cases where the argument fails type
        conversion (the offending value is then skipped).
    :param type: The type to which the request argument should be converted.
        If a type raises an exception, the message in the error will be returned in the response.
        Defaults to :class:`str`.
    :param location: The attributes of the :class:`flask.Request` object
        to source the arguments from (ex: headers, args, etc.), can be an
        iterable. When several locations are given, their values are merged
        in the listed order, so with the "store" action the earliest location
        providing the argument wins.
    :param choices: A container of the allowable values for the argument.
        With the "split" action every piece is checked individually.
    :param help: A brief description of the argument, returned in the
        response when the argument is invalid. May optionally contain
        an "{error_msg}" interpolation token, which will be replaced with
        the text of the error; without the token, the error text is
        appended after the help text.
    :param operators: Suffixes used to build the looked-up key names: the
        first "=" of each operator is removed and the remainder appended to
        ``name``. The default ``("=",)`` looks up the bare name.
    :param bool case_sensitive: Whether argument values in the request are
        case sensitive or not. When disabled, string values are lowercased
        before conversion and compared against lowercased string choices.
    :param bool store_missing: Whether the arguments default value should
        be stored if the argument is missing from the request.
    :param bool trim: If enabled, trims whitespace around the argument.
    :param bool nullable: If enabled, allows null value in argument.
    """

    def __init__(
        self,
        name,
        default=None,
        dest=None,
        required=False,
        ignore=False,
        type=str,
        location=(
            "json",
            "values",
        ),
        choices=(),
        action="store",
        help=None,
        operators=("=",),
        case_sensitive=True,
        store_missing=True,
        trim=False,
        nullable=True,
    ):
        self.name = name
        self.default = default
        self.dest = dest
        self.required = required
        self.ignore = ignore
        self.location = location
        self.type = type
        self.choices = choices
        self.action = action
        self.help = help
        self.case_sensitive = case_sensitive
        self.operators = operators
        self.store_missing = store_missing
        self.trim = trim
        self.nullable = nullable

    @staticmethod
    def _json_payload(request):
        """Return the request's JSON body when it is an object, else ``None``."""
        payload = request.get_json(silent=True)
        # A JSON document that is not an object (list, string, number...)
        # carries no named arguments; merging it into a MultiDict or looking
        # names up in it would raise on untrusted input.
        return payload if isinstance(payload, dict) else None

    def source(self, request):
        """
        Pulls values off the request in the provided location(s).

        :param request: The flask request object to parse arguments from
        :return: for a single string location, the value of that request
            attribute (called first if it is a method); for several
            locations, a :class:`MultiDict` merging all of them in order.
            An empty :class:`MultiDict` when nothing is available.
        """
        if isinstance(self.location, str):
            if self.location in {"json", "get_json"}:
                value = self._json_payload(request)
            else:
                value = getattr(request, self.location, MultiDict())
            if callable(value):
                value = value()
            if value is not None:
                return value
        else:
            values = MultiDict()
            for location in self.location:
                if location in {"json", "get_json"}:
                    value = self._json_payload(request)
                else:
                    value = getattr(request, location, None)
                if callable(value):
                    value = value()
                if value is not None:
                    values.update(value)
            return values

        return MultiDict()

    def convert(self, value, op):
        """
        Convert a raw request value using ``self.type``.

        :param value: the raw value pulled from the request
        :param op: the operator under which the value was found
        :raises ValueError: when the value is ``None`` and the argument is
            not nullable; any exception raised by ``self.type`` propagates.
        """
        # Don't cast None
        if value is None:
            if not self.nullable:
                raise ValueError("Must not be null!")
            return None

        elif isinstance(self.type, Model) and isinstance(value, dict):
            return marshal(value, self.type)

        # and check if we're expecting a filestorage and haven't overridden `type`
        # (required because the below instantiation isn't valid for FileStorage)
        elif isinstance(value, FileStorage) and self.type == FileStorage:
            return value

        # Try the richest converter signature first, then fall back to
        # simpler ones when the type rejects the extra arguments.
        try:
            return self.type(value, self.name, op)
        except TypeError:
            try:
                if self.type is decimal.Decimal:
                    return self.type(str(value), self.name)
                else:
                    return self.type(value, self.name)
            except TypeError:
                return self.type(value)

    def handle_validation_error(self, error, bundle_errors):
        """
        Called when an error is raised while parsing. Aborts the request
        with a 400 status and an error message

        :param error: the error that was raised
        :param bool bundle_errors: do not abort when first error occurs, return a
            dict with the name of the argument and the error message to be
            bundled
        :return: when bundling, a ``(ValueError instance, {name: message})``
            pair; otherwise ``abort`` is called with HTTP 400.
        """
        error_str = str(error)
        if self.help:
            help_str = str(self.help)
            if "{error_msg}" in help_str:
                # str.replace rather than str.format: help texts may contain
                # other braces that must not be interpreted.
                error_msg = help_str.replace("{error_msg}", error_str)
            else:
                error_msg = " ".join([help_str, error_str])
        else:
            error_msg = error_str
        errors = {self.name: error_msg}

        if bundle_errors:
            return ValueError(error), errors
        abort(HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors)

    def parse(self, request, bundle_errors=False):
        """
        Parses argument value(s) from the request, converting according to
        the argument's type.

        :param request: The flask request object to parse arguments from
        :param bool bundle_errors: do not abort when first error occurs, return a
            dict with the name of the argument and the error message to be
            bundled
        :return: a ``(value, found)`` pair where ``found`` tells whether the
            argument was present in the request, or the pair returned by
            :meth:`handle_validation_error` when validation fails.
        """
        bundle_errors = current_app.config.get("BUNDLE_ERRORS", False) or bundle_errors
        source = self.source(request)
        # Set by RequestParser.parse_args(); tolerate direct calls without it.
        unparsed = getattr(request, "unparsed_arguments", None)

        results = []

        # Sentinels
        _not_found = False
        _found = True

        for operator in self.operators:
            name = self.name + operator.replace("=", "", 1)
            if name not in source:
                continue
            # Account for MultiDict and regular dict
            if hasattr(source, "getlist"):
                values = source.getlist(name)
            else:
                values = [source.get(name)]

            for value in values:
                # Work on a local copy: self.choices is shared by every
                # request and is also published in the Swagger schema.
                choices = self.choices
                if hasattr(value, "strip") and self.trim:
                    value = value.strip()
                if hasattr(value, "lower") and not self.case_sensitive:
                    value = value.lower()
                    if hasattr(choices, "__iter__"):
                        choices = [
                            choice.lower() if hasattr(choice, "lower") else choice
                            for choice in choices
                        ]

                try:
                    if self.action == "split":
                        value = [
                            self.convert(v, operator)
                            for v in value.split(SPLIT_CHAR)
                        ]
                    else:
                        value = self.convert(value, operator)
                except Exception as error:
                    if self.ignore:
                        continue
                    return self.handle_validation_error(error, bundle_errors)

                if choices:
                    # A split value is a list: each piece must be a valid choice.
                    candidates = value if self.action == "split" else [value]
                    for candidate in candidates:
                        if candidate not in choices:
                            msg = "The value '{0}' is not a valid choice for '{1}'.".format(
                                candidate, name
                            )
                            return self.handle_validation_error(msg, bundle_errors)

                if unparsed is not None:
                    unparsed.pop(name, None)
                results.append(value)

        if not results and self.required:
            if isinstance(self.location, str):
                location = _friendly_location.get(self.location, self.location)
            else:
                locations = [_friendly_location.get(loc, loc) for loc in self.location]
                location = " or ".join(locations)
            error_msg = "Missing required parameter in {0}".format(location)
            return self.handle_validation_error(error_msg, bundle_errors)

        if not results:
            if callable(self.default):
                return self.default(), _not_found
            else:
                return self.default, _not_found

        if self.action == "append":
            return results, _found

        if self.action == "store" or len(results) == 1:
            return results[0], _found
        return results, _found

    @property
    def __schema__(self):
        """
        The Swagger parameter description of this argument, or ``None`` for
        cookie arguments (Swagger 2.0 has no cookie parameter location).
        """
        # "cookies" is the actual flask.Request attribute; "cookie" is kept
        # for backward compatibility.
        if self.location in ("cookie", "cookies"):
            return None
        if isinstance(self.location, str):
            location = LOCATIONS.get(self.location, "query")
        else:
            # Several locations have no single Swagger equivalent (and a list
            # is not a valid dict key), so document them as "query".
            location = "query"
        param = {"name": self.name, "in": location}
        _handle_arg_type(self, param)
        if self.required:
            param["required"] = True
        if self.help:
            param["description"] = self.help
        if self.default is not None:
            param["default"] = (
                self.default() if callable(self.default) else self.default
            )
        if self.action in ("append", "split"):
            param["items"] = {"type": param["type"]}
            # A pattern constrains each item, not the array itself.
            if "pattern" in param:
                param["items"]["pattern"] = param.pop("pattern")
            param["type"] = "array"
            param["collectionFormat"] = "multi" if self.action == "append" else "csv"
        if self.choices:
            param["enum"] = self.choices
        return param
314
315
class RequestParser(object):
    """
    Enables adding and parsing of multiple arguments in the context of a single request.
    Ex::

        from flask_restx import RequestParser

        parser = RequestParser()
        parser.add_argument('foo')
        parser.add_argument('int_bar', type=int)
        args = parser.parse_args()

    :param argument_class: the class instantiated for each added argument
        (defaults to :class:`Argument`)
    :param result_class: the container returned by :meth:`parse_args`
        (defaults to :class:`ParseResult`)
    :param bool trim: If enabled, trims whitespace on all arguments in this parser
    :param bool bundle_errors: If enabled, do not abort when first error occurs,
        return a dict with the name of the argument and the error message to be
        bundled and return all validation errors
    """

    def __init__(
        self,
        argument_class=Argument,
        result_class=ParseResult,
        trim=False,
        bundle_errors=False,
    ):
        self.args = []
        self.argument_class = argument_class
        self.result_class = result_class
        self.trim = trim
        self.bundle_errors = bundle_errors

    def _apply_trim(self, arg, kwargs):
        """Propagate the parser-level ``trim`` to ``arg`` unless ``kwargs`` overrides it."""
        # Do not know what other argument classes are out there
        if self.trim and self.argument_class is Argument:
            arg.trim = kwargs.get("trim", self.trim)

    def add_argument(self, *args, **kwargs):
        """
        Adds an argument to be parsed.

        Accepts either a single instance of Argument or arguments to be passed
        into :class:`Argument`'s constructor.

        See :class:`Argument`'s constructor for documentation on the available options.

        :return: the parser itself, so calls can be chained
        """

        if len(args) == 1 and isinstance(args[0], self.argument_class):
            self.args.append(args[0])
        else:
            self.args.append(self.argument_class(*args, **kwargs))

        self._apply_trim(self.args[-1], kwargs)

        return self

    def parse_args(self, req=None, strict=False):
        """
        Parse all arguments from the provided request and return the results as a ParseResult

        :param req: the request to parse; defaults to flask's current ``request``
        :param bool strict: if req includes args not in parser, throw 400 BadRequest exception
        :return: the parsed results as :class:`ParseResult` (or any class defined as :attr:`result_class`)
        :rtype: ParseResult
        :raises werkzeug.exceptions.BadRequest: in strict mode, when the request
            carries arguments this parser does not declare
        """
        if req is None:
            req = request

        result = self.result_class()

        # A record of arguments not yet parsed; as each is found
        # among self.args, it will be popped out
        req.unparsed_arguments = (
            dict(self.argument_class("").source(req)) if strict else {}
        )
        errors = {}
        for arg in self.args:
            value, found = arg.parse(req, self.bundle_errors)
            # With bundled errors, Argument.parse returns (ValueError, errors)
            if isinstance(value, ValueError):
                errors.update(found)
                found = None
            if found or arg.store_missing:
                result[arg.dest or arg.name] = value
        if errors:
            abort(
                HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors
            )

        if strict and req.unparsed_arguments:
            arguments = ", ".join(req.unparsed_arguments.keys())
            msg = "Unknown arguments: {0}".format(arguments)
            raise exceptions.BadRequest(msg)

        return result

    def copy(self):
        """Creates a copy of this RequestParser with the same set of arguments"""
        parser_copy = self.__class__(self.argument_class, self.result_class)
        parser_copy.args = deepcopy(self.args)
        parser_copy.trim = self.trim
        parser_copy.bundle_errors = self.bundle_errors
        return parser_copy

    def replace_argument(self, name, *args, **kwargs):
        """
        Replace the argument matching the given name with a new version.

        The replacement is moved to the end of the argument list and honours
        the parser-level ``trim`` exactly like :meth:`add_argument`. Nothing is
        added when no argument has that name.

        :return: the parser itself, so calls can be chained
        """
        new_arg = self.argument_class(name, *args, **kwargs)
        self._apply_trim(new_arg, kwargs)
        for index, arg in enumerate(self.args):
            if new_arg.name == arg.name:
                # Safe to mutate while iterating: we stop right after.
                del self.args[index]
                self.args.append(new_arg)
                break
        return self

    def remove_argument(self, name):
        """Remove the argument matching the given name."""
        for index, arg in enumerate(self.args[:]):
            if name == arg.name:
                del self.args[index]
                break
        return self

    @property
    def __schema__(self):
        """
        The Swagger parameters of every documented argument.

        :raises SpecsError: when both body and formData parameters are present
        """
        params = []
        locations = set()
        for arg in self.args:
            param = arg.__schema__
            if param:
                params.append(param)
                locations.add(param["in"])
        if "body" in locations and "formData" in locations:
            raise SpecsError("Can't use formData and body at the same time")
        return params
445
446
def _handle_arg_type(arg, param):
    """
    Fill in the Swagger type information of ``param`` for argument ``arg``.

    Resolution order:

    1. a Python primitive listed in ``PY_TYPES``;
    2. a type exposing ``__apidoc__``, whose name is used and which also
       moves the parameter into the body;
    3. a type exposing its own ``__schema__``, merged into ``param``;
    4. a "files" location, documented as a file;
    5. anything else, documented as a string.

    :param arg: the argument being documented
    :param dict param: the Swagger parameter, updated in place
    """
    arg_type = arg.type

    if isinstance(arg_type, Hashable) and arg_type in PY_TYPES:
        param["type"] = PY_TYPES[arg_type]
        return

    if hasattr(arg_type, "__apidoc__"):
        param["type"] = arg_type.__apidoc__["name"]
        param["in"] = "body"
        return

    if hasattr(arg_type, "__schema__"):
        param.update(arg_type.__schema__)
        return

    param["type"] = "file" if arg.location == "files" else "string"