1"""
2This module defines a decorator to convert request parameters to arguments for the view function.
3"""
import abc
import asyncio
import builtins
import functools
import inspect
import keyword
import logging
import re
import typing as t
from copy import copy, deepcopy
13
14import inflection
15
16from connexion.context import context, operation
17from connexion.frameworks.abstract import Framework
18from connexion.http_facts import FORM_CONTENT_TYPES
19from connexion.lifecycle import ConnexionRequest, WSGIRequest
20from connexion.operations import AbstractOperation, Swagger2Operation
21from connexion.utils import (
22 deep_merge,
23 inspect_function_arguments,
24 is_null,
25 is_nullable,
26 make_type,
27)
28
logger = logging.getLogger(__name__)

# Name of the view-function parameter that, when declared, receives the whole
# request `context` object (see `prep_kwargs`).
CONTEXT_NAME = "context_"
32
33
class BaseParameterDecorator(abc.ABC):
    """Base class for decorators that map request data onto view-function kwargs.

    Subclasses implement ``__call__`` to wrap a view function for a specific
    request type (synchronous or asynchronous). Inheriting from ``abc.ABC`` makes
    the ``@abc.abstractmethod`` marker on ``__call__`` actually enforced.
    """

    def __init__(
        self,
        *,
        framework: t.Type[Framework],
        pythonic_params: bool = False,
    ) -> None:
        """
        :param framework: Framework class this decorator is used with.
        :param pythonic_params: If True, parameter names are converted with
            :func:`pythonic` (snake_case, avoids shadowing builtins/keywords);
            otherwise they are only passed through :func:`sanitized`.
        """
        self.framework = framework
        self.sanitize_fn = pythonic if pythonic_params else sanitized

    def _maybe_get_body(
        self,
        request: t.Union[WSGIRequest, ConnexionRequest],
        *,
        arguments: t.List[str],
        has_kwargs: bool,
    ) -> t.Any:
        """Return the request body only when the view function can consume it.

        The body is read when the sanitized body argument name is one of the
        function's parameters, or when the function accepts ``**kwargs``. For
        Swagger 2 form requests it is always read, because form fields are
        unpacked into individual arguments (Connexion 2 compatibility).

        :param request: Incoming request.
        :param arguments: Parameter names of the (unwrapped) view function.
        :param has_kwargs: Whether the view function accepts ``**kwargs``.
        :return: ``request.get_body()`` or ``None``. For async requests this
            may be a coroutine that the caller must await.
        """
        # NOTE(review): body_name() is resolved with request.content_type here,
        # while get_arguments() is later given request.mimetype; confirm both
        # select the same body definition when the header carries parameters
        # such as a charset.
        body_name = self.sanitize_fn(operation.body_name(request.content_type))
        # Pass form contents separately for Swagger2 for backward compatibility with
        # Connexion 2. Checking for body_name is not enough.
        if (body_name in arguments or has_kwargs) or (
            request.mimetype in FORM_CONTENT_TYPES
            and isinstance(operation, Swagger2Operation)
        ):
            return request.get_body()
        else:
            return None

    @abc.abstractmethod
    def __call__(self, function: t.Callable) -> t.Callable:
        """Wrap ``function`` so it is called with kwargs built from the request."""
        raise NotImplementedError
65
66
class SyncParameterDecorator(BaseParameterDecorator):
    """Parameter decorator for synchronous view functions handling WSGI requests."""

    def __call__(self, function: t.Callable) -> t.Callable:
        """Wrap ``function`` so it receives request parameters as keyword arguments.

        The argument list is taken from the innermost function (behind any other
        decorators), so parameters are matched against the real signature.
        """
        arguments, has_kwargs = inspect_function_arguments(
            unwrap_decorators(function)
        )

        @functools.wraps(function)
        def wrapper(request: WSGIRequest) -> t.Any:
            body = self._maybe_get_body(
                request, arguments=arguments, has_kwargs=has_kwargs
            )
            uploaded_files = request.files()
            handler_kwargs = prep_kwargs(
                request,
                request_body=body,
                files=uploaded_files,
                arguments=arguments,
                has_kwargs=has_kwargs,
                sanitize=self.sanitize_fn,
            )
            return function(**handler_kwargs)

        return wrapper
90
91
class AsyncParameterDecorator(BaseParameterDecorator):
    """Parameter decorator for asynchronous view functions (ConnexionRequest)."""

    def __call__(self, function: t.Callable) -> t.Callable:
        """Wrap coroutine ``function`` so it receives request parameters as kwargs.

        The argument list is taken from the innermost function (behind any other
        decorators), so parameters are matched against the real signature.
        """
        arguments, has_kwargs = inspect_function_arguments(
            unwrap_decorators(function)
        )

        @functools.wraps(function)
        async def wrapper(request: ConnexionRequest) -> t.Any:
            body = self._maybe_get_body(
                request, arguments=arguments, has_kwargs=has_kwargs
            )
            # The body getter may hand back a coroutine (possibly nested);
            # resolve it fully before building the kwargs.
            while asyncio.iscoroutine(body):
                body = await body

            uploaded_files = await request.files()
            handler_kwargs = prep_kwargs(
                request,
                request_body=body,
                files=uploaded_files,
                arguments=arguments,
                has_kwargs=has_kwargs,
                sanitize=self.sanitize_fn,
            )
            return await function(**handler_kwargs)

        return wrapper
118
119
def prep_kwargs(
    request: t.Union[WSGIRequest, ConnexionRequest],
    *,
    request_body: t.Any,
    files: t.Dict[str, t.Any],
    arguments: t.List[str],
    has_kwargs: bool,
    sanitize: t.Callable,
) -> dict:
    """Assemble the keyword arguments for a view function call.

    Combines path, query, body and file arguments (via :func:`get_arguments`),
    normalizes their names with ``sanitize``, then adds matching entries from
    the request ``context``. If the function declares ``CONTEXT_NAME`` it also
    receives the whole context object.

    :return: Mapping of argument name to value.
    """
    raw_kwargs = get_arguments(
        operation,
        path_params=request.path_params,
        query_params=request.query_params,
        body=request_body,
        files=files,
        arguments=arguments,
        has_kwargs=has_kwargs,
        sanitize=sanitize,
        content_type=request.mimetype,
    )

    # optionally convert parameter variable names to un-shadowed, snake_case form
    kwargs = {}
    for raw_name, value in raw_kwargs.items():
        kwargs[sanitize(raw_name)] = value

    # add context info (e.g. from security decorator)
    for key, value in context.items():
        if not (has_kwargs or key in arguments):
            logger.debug("Context parameter '%s' not in function arguments", key)
            continue
        kwargs[key] = value

    # attempt to provide the request context to the function
    if CONTEXT_NAME in arguments:
        kwargs[CONTEXT_NAME] = context

    return kwargs
155
156
def unwrap_decorators(function: t.Callable) -> t.Callable:
    """Unwrap decorators to return the original function.

    Follows the ``__wrapped__`` chain (as set by :func:`functools.wraps` /
    :func:`functools.update_wrapper`) down to the innermost callable, so the
    real signature of the view function can be inspected.

    :param function: A possibly-decorated callable.
    :return: The innermost callable, or ``function`` itself if it is not wrapped.
    :raises ValueError: If the ``__wrapped__`` chain forms a cycle.
    """
    # inspect.unwrap walks the same ``__wrapped__`` chain as a manual loop,
    # but detects cycles instead of looping forever.
    return inspect.unwrap(function)
162
163
def snake_and_shadow(name: str) -> str:
    """
    Convert ``name`` to snake_case and protect it from shadowing.

    The name is first converted with ``inflection.underscore``. If the result is
    a Python keyword or a name in the ``builtins`` module, an underscore is
    appended so it can be used as a function parameter.
    :param name: The parameter name
    """
    snake_name = inflection.underscore(name)
    collides = keyword.iskeyword(snake_name) or snake_name in builtins.__dict__
    return snake_name + "_" if collides else snake_name
174
175
def sanitized(name: str) -> str:
    """Reduce a parameter name to identifier characters.

    Each ``[`` that is not immediately followed by ``]`` becomes ``_``. Every
    other character outside ``[0-9a-zA-Z_]`` is dropped, and then leading
    characters that cannot start an identifier (digits) are removed. Falsy input
    (``None`` or ``""``) is returned unchanged.

    Examples: ``"foo[bar]"`` -> ``"foo_bar"``, ``"ids[]"`` -> ``"ids"``.
    """
    if not name:
        return name
    bracketed = re.sub(r"\[(?!])", "_", name)
    identifier_chars = re.sub("[^0-9a-zA-Z_]", "", bracketed)
    return re.sub("^[^a-zA-Z_]+", "", identifier_chars)
180
181
def pythonic(name: str) -> str:
    """Convert a parameter name to a snake_case identifier that avoids shadowing.

    Non-empty names go through :func:`snake_and_shadow` and then
    :func:`sanitized`. Falsy names go straight to :func:`sanitized`.

    Note: the builtin/keyword check happens before sanitization. A name that
    only collides after characters are stripped therefore gets no trailing
    underscore.
    """
    if name:
        name = snake_and_shadow(name)
    return sanitized(name)
185
186
def get_arguments(
    operation: AbstractOperation,
    *,
    path_params: dict,
    query_params: dict,
    body: t.Any,
    files: dict,
    arguments: t.List[str],
    has_kwargs: bool,
    sanitize: t.Callable,
    content_type: str,
) -> t.Dict[str, t.Any]:
    """
    Collect the arguments for the handler function.

    Path and query arguments are always gathered. For non-TRACE methods, the
    body argument(s) and uploaded files are added as well. Later sources
    overwrite earlier ones on name clashes (path < query < body < files).
    """
    collected: t.Dict[str, t.Any] = {}
    collected.update(
        _get_path_arguments(path_params, operation=operation, sanitize=sanitize)
    )
    collected.update(
        _get_query_arguments(
            query_params,
            operation=operation,
            arguments=arguments,
            has_kwargs=has_kwargs,
            sanitize=sanitize,
        )
    )

    # TRACE requests MUST NOT include a body (RFC7231 section 4.3.8)
    if operation.method.upper() == "TRACE":
        return collected

    collected.update(
        _get_body_argument(
            body,
            operation=operation,
            arguments=arguments,
            has_kwargs=has_kwargs,
            sanitize=sanitize,
            content_type=content_type,
        )
    )
    schema_for_files = operation.body_schema(content_type)
    collected.update(
        _get_file_arguments(files, arguments, schema_for_files, has_kwargs)
    )
    return collected
231
232
def _get_path_arguments(
    path_params: dict, *, operation: AbstractOperation, sanitize: t.Callable
) -> dict:
    """
    Extract handler function arguments from path parameters.

    Path parameters declared on the operation are cast according to their
    definition. Undeclared ones are passed through unchanged.
    """
    declared = {
        param["name"]: param for param in operation.parameters if param["in"] == "path"
    }

    result = {}
    for raw_name, raw_value in path_params.items():
        arg_name = sanitize(raw_name)
        if raw_name not in declared:
            # Assume path params mechanism used for injection
            result[arg_name] = raw_value
            continue
        result[arg_name] = _get_val_from_param(raw_value, declared[raw_name])
    return result
254
255
def _get_val_from_param(value: t.Any, param_definitions: t.Dict[str, dict]) -> t.Any:
    """Cast a value according to its definition in the specification.

    The schema is taken from ``param_definitions["schema"]`` when present
    (OpenAPI 3 style), otherwise from the definition itself (Swagger 2 style).
    Null values for nullable schemas become ``None``. Arrays are converted item
    by item; ``value`` is assumed to already be an iterable of parts in that case.
    """
    schema = param_definitions.get("schema", param_definitions)

    if is_nullable(schema) and is_null(value):
        return None

    if schema["type"] != "array":
        return make_type(value, schema["type"], schema.get("format"))

    items = schema["items"]
    item_type = items["type"]
    item_format = items.get("format")
    return [make_type(part, item_type, item_format) for part in value]
271
272
def _get_query_arguments(
    query_params: dict,
    *,
    operation: AbstractOperation,
    arguments: t.List[str],
    has_kwargs: bool,
    sanitize: t.Callable,
) -> dict:
    """
    Extract handler function arguments from the query parameters.

    Spec-defined defaults are deep-copied (so request handling never mutates
    the specification) and then merged with the actual query values before
    type conversion and filtering.
    """
    definitions_by_name = {
        param["name"]: param
        for param in operation.parameters
        if param["in"] == "query"
    }

    defaults = deepcopy(_get_query_defaults(definitions_by_name))
    merged = deep_merge(defaults, query_params)
    return _query_args_helper(
        definitions_by_name, merged, arguments, has_kwargs, sanitize
    )
297
298
299def _get_query_defaults(query_definitions: t.Dict[str, dict]) -> t.Dict[str, t.Any]:
300 """Get the default values for the query parameter from the parameter definition."""
301 defaults = {}
302 for k, v in query_definitions.items():
303 try:
304 if "default" in v:
305 defaults[k] = v["default"]
306 elif v["schema"]["type"] == "object":
307 defaults[k] = _get_default_obj(v["schema"])
308 else:
309 defaults[k] = v["schema"]["default"]
310 except KeyError:
311 pass
312 return defaults
313
314
315def _get_default_obj(schema: dict) -> dict:
316 try:
317 return deepcopy(schema["default"])
318 except KeyError:
319 properties = schema.get("properties", {})
320 return _build_default_obj_recursive(properties, {})
321
322
323def _build_default_obj_recursive(properties: dict, default_object: dict) -> dict:
324 """takes disparate and nested default keys, and builds up a default object"""
325 for name, property_ in properties.items():
326 if "default" in property_ and name not in default_object:
327 default_object[name] = copy(property_["default"])
328 elif property_.get("type") == "object" and "properties" in property_:
329 default_object.setdefault(name, {})
330 default_object[name] = _build_default_obj_recursive(
331 property_["properties"], default_object[name]
332 )
333 return default_object
334
335
336def _query_args_helper(
337 query_definitions: dict,
338 query_arguments: dict,
339 function_arguments: t.List[str],
340 has_kwargs: bool,
341 sanitize: t.Callable,
342) -> dict:
343 result = {}
344 for key, value in query_arguments.items():
345 sanitized_key = sanitize(key)
346 if not has_kwargs and sanitized_key not in function_arguments:
347 logger.debug(
348 "Query Parameter '%s' (sanitized: '%s') not in function arguments",
349 key,
350 sanitized_key,
351 )
352 else:
353 logger.debug(
354 "Query Parameter '%s' (sanitized: '%s') in function arguments",
355 key,
356 sanitized_key,
357 )
358 try:
359 query_defn = query_definitions[key]
360 except KeyError: # pragma: no cover
361 logger.error(
362 "Function argument '%s' (non-sanitized: %s) not defined in specification",
363 sanitized_key,
364 key,
365 )
366 else:
367 logger.debug("%s is a %s", key, query_defn)
368 result.update({sanitized_key: _get_val_from_param(value, query_defn)})
369 return result
370
371
def _get_body_argument(
    body: t.Any,
    *,
    operation: AbstractOperation,
    arguments: t.List[str],
    has_kwargs: bool,
    sanitize: t.Callable,
    content_type: str,
) -> dict:
    """Build the handler kwarg(s) that carry the request body.

    :param body: Parsed request body (may be ``None``).
    :param operation: Operation being handled.
    :param arguments: Parameter names of the handler.
    :param has_kwargs: Whether the handler accepts ``**kwargs``.
    :param sanitize: Function normalizing argument names.
    :param content_type: Request mimetype.
    :return: ``{}`` when the handler takes no arguments or no request body is
        defined. For Swagger 2 form requests, the individual form fields the
        handler accepts. Otherwise ``{body_name: value}`` if the handler accepts
        the body argument, else ``{}``.
    """
    if not arguments and not has_kwargs:
        return {}

    if not operation.is_request_body_defined:
        return {}

    body_name = sanitize(operation.body_name(content_type))

    if content_type in FORM_CONTENT_TYPES:
        result = _get_body_argument_form(
            body, operation=operation, content_type=content_type
        )

        # Unpack form values for Swagger for compatibility with Connexion 2 behavior
        if isinstance(operation, Swagger2Operation):
            if has_kwargs:
                return result
            unpacked = {}
            for name, value in result.items():
                arg_name = sanitize(name)
                if arg_name in arguments:
                    unpacked[arg_name] = value
            return unpacked
    else:
        result = _get_body_argument_json(
            body, operation=operation, content_type=content_type
        )

    if body_name in arguments or has_kwargs:
        return {body_name: result}

    return {}
415
416
def _get_body_argument_json(
    body: t.Any, *, operation: AbstractOperation, content_type: str
) -> t.Any:
    """Resolve the value of a non-form body argument.

    An explicit null body for a nullable schema yields ``None``. A missing body
    (``None``) yields a deep copy of the schema's ``default`` (``{}`` when the
    schema has none). Any other body is returned as-is.
    """
    # If the body came in null and the schema says it can be null, pass no value
    # for the body argument rather than the default body.
    explicit_null = is_nullable(operation.body_schema(content_type)) and is_null(body)
    if explicit_null:
        return None

    if body is not None:
        return body

    return deepcopy(operation.body_schema(content_type).get("default", {}))
430
431
def _get_body_argument_form(
    body: dict, *, operation: AbstractOperation, content_type: str
) -> dict:
    """Build the form body value, merged with schema defaults and type-cast.

    :param body: Parsed form fields from the request (may be ``None`` or empty).
    :param operation: Operation whose body schema describes the form.
    :param content_type: Request mimetype used to select the body schema.
    :return: Form values cast per the schema's ``properties`` and
        ``additionalProperties``, or ``{}`` when the schema declares no
        properties and forbids additional ones.
    """
    # Resolve the schema once, for the request's own content type, so that the
    # default, the property definitions and `additionalProperties` all come from
    # the same schema. Previously `additionalProperties` was read from
    # `body_schema()` without a content type.
    body_schema = operation.body_schema(content_type)

    # now determine the actual value for the body (whether it came in or is default)
    default_body = body_schema.get("default", {})
    body_props = {
        k: {"schema": v} for k, v in body_schema.get("properties", {}).items()
    }

    # by OpenAPI specification `additionalProperties` defaults to `true`
    # see: https://github.com/OAI/OpenAPI-Specification/blame/3.0.2/versions/3.0.2.md#L2305
    additional_props = body_schema.get("additionalProperties", True)

    # Start from a copy of the default so the specification is never mutated.
    body_arg = deepcopy(default_body)
    body_arg.update(body or {})

    if body_props or additional_props:
        return _get_typed_body_values(body_arg, body_props, additional_props)

    return {}
453
454
455def _get_typed_body_values(body_arg, body_props, additional_props):
456 """
457 Return a copy of the provided body_arg dictionary
458 whose values will have the appropriate types
459 as defined in the provided schemas.
460
461 :type body_arg: type dict
462 :type body_props: dict
463 :type additional_props: dict|bool
464 :rtype: dict
465 """
466 additional_props_defn = (
467 {"schema": additional_props} if isinstance(additional_props, dict) else None
468 )
469 res = {}
470
471 for key, value in body_arg.items():
472 try:
473 prop_defn = body_props[key]
474 res[key] = _get_val_from_param(value, prop_defn)
475 except KeyError: # pragma: no cover
476 if not additional_props:
477 logger.error(f"Body property '{key}' not defined in body schema")
478 continue
479 if additional_props_defn is not None:
480 value = _get_val_from_param(value, additional_props_defn)
481 res[key] = value
482
483 return res
484
485
486def _get_file_arguments(files, arguments, body_schema: dict, has_kwargs=False):
487 results = {}
488 for k, v in files.items():
489 if not (k in arguments or has_kwargs):
490 continue
491 if body_schema.get("properties", {}).get(k, {}).get("type") != "array":
492 v = v[0]
493 results[k] = v
494
495 return results