1import collections
2import copy
3import logging
4
5from jsonschema import Draft4Validator, ValidationError
6from starlette.requests import Request
7
8from connexion.exceptions import BadRequestProblem, ExtraParameterProblem
9from connexion.utils import boolean, is_null, is_nullable
10
logger = logging.getLogger("connexion.validators.parameter")

# Schema type names mapped to the Python callables used for them
# (how the map is consumed is outside this module section).
TYPE_MAP = {
    "integer": int,
    "number": float,
    "boolean": boolean,
    "object": dict,
}

# jsonschema >= 4.5.0 exposes the Draft 4 format checker as a class attribute;
# older releases only provide the module-level ``draft4_format_checker``.
draft4_format_checker = getattr(Draft4Validator, "FORMAT_CHECKER", None)
if draft4_format_checker is None:  # jsonschema < 4.5.0
    from jsonschema import draft4_format_checker
19
20
class ParameterValidator:
    """Validate request parameters (query, path, header and cookie) against the spec.

    Parameter definitions are grouped by their ``in`` location. Each location is
    checked by the matching ``validate_<location>_parameter`` method, and the
    first failing parameter aborts validation with a ``BadRequestProblem``.
    """

    def __init__(
        self,
        parameters,
        uri_parser,
        strict_validation=False,
        security_query_params=None,
    ):
        """
        :param parameters: List of request parameter dictionaries
        :param uri_parser: class to use for uri parsing
        :param strict_validation: Flag indicating if parameters not in spec are allowed
        :param security_query_params: List of query parameter names used for security.
            These parameters will be ignored when checking for extra parameters in case of
            strict validation.
        """
        # Group parameter definitions by location ("query", "path", ...).
        self.parameters = collections.defaultdict(list)
        for p in parameters:
            self.parameters[p["in"]].append(p)

        self.uri_parser = uri_parser
        self.strict_validation = strict_validation
        self.security_query_params = set(security_query_params or [])

    @staticmethod
    def validate_parameter(parameter_type, value, param, param_name=None):
        """Validate a single parameter value against its definition.

        :param parameter_type: Location of the parameter (e.g. "query"); only used
            in the missing-parameter message.
        :param value: The value taken from the request, or None if it is absent.
        :param param: The parameter definition. Its schema is read from the
            "schema" key when present, otherwise the definition itself is used
            as the schema.
        :param param_name: Unused; kept for backwards compatibility.
        :return: An error message, or None if the value is valid.
        """
        if is_nullable(param) and is_null(value):
            return None

        if value is not None:
            # jsonschema validators only read the schema, so the definition is
            # validated directly instead of being deep-copied on every call.
            schema = param.get("schema", param)
            try:
                Draft4Validator(schema, format_checker=draft4_format_checker).validate(
                    value
                )
            except ValidationError as exception:
                return str(exception)
            return None

        if param.get("required"):
            return f"Missing {parameter_type} parameter '{param['name']}'"

        return None

    @staticmethod
    def validate_parameter_list(request_params, spec_params):
        """Return the request parameter names that are not declared in the spec.

        :param request_params: Iterable of parameter names found in the request.
        :param spec_params: Iterable of parameter names allowed by the spec.
        :return: Set of unexpected parameter names (empty when all are allowed).
        """
        return set(request_params) - set(spec_params)

    def validate_query_parameter_list(self, request, security_params=None):
        """Return query parameter names present in the request but not in the spec.

        :param request: The incoming request.
        :param security_params: Additional names to allow, such as query
            parameters used for security.
        :return: Set of unexpected query parameter names.
        """
        request_params = request.query_params.keys()
        spec_params = [x["name"] for x in self.parameters.get("query", [])]
        spec_params.extend(security_params or [])
        return self.validate_parameter_list(request_params, spec_params)

    def validate_query_parameter(self, param, request):
        """
        Validate a single query parameter of the request.

        :type param: dict
        :return: An error message, or None if the parameter is valid.
        :rtype: str or None
        """
        # NOTE(review): the query string is collected and resolved again for
        # every query parameter; resolving it once per request would avoid the
        # repeated work, but would change the signature subclasses may override.
        # Convert to dict of lists
        query_params = {
            k: request.query_params.getlist(k) for k in request.query_params
        }
        query_params = self.uri_parser.resolve_query(query_params)
        val = query_params.get(param["name"])
        return self.validate_parameter("query", val, param)

    def validate_path_parameter(self, param, request):
        """Validate a single path parameter of the request.

        The resolved path parameters are looked up with hyphens in the
        parameter name replaced by underscores.

        :return: An error message, or None if the parameter is valid.
        """
        path_params = self.uri_parser.resolve_path(request.path_params)
        val = path_params.get(param["name"].replace("-", "_"))
        return self.validate_parameter("path", val, param)

    def validate_header_parameter(self, param, request):
        """Validate a single header parameter of the request.

        :return: An error message, or None if the parameter is valid.
        """
        val = request.headers.get(param["name"])
        return self.validate_parameter("header", val, param)

    def validate_cookie_parameter(self, param, request):
        """Validate a single cookie parameter of the request.

        :return: An error message, or None if the parameter is valid.
        """
        val = request.cookies.get(param["name"])
        return self.validate_parameter("cookie", val, param)

    def validate(self, scope):
        """Validate the parameters of the request described by an ASGI ``scope``.

        :raises ExtraParameterProblem: see ``validate_request``.
        :raises BadRequestProblem: see ``validate_request``.
        """
        logger.debug("%s validating parameters...", scope.get("path"))

        request = Request(scope)
        self.validate_request(request)

    def validate_request(self, request):
        """Validate all declared parameters of ``request``.

        :raises ExtraParameterProblem: when strict validation is enabled and the
            request carries query parameters that are neither in the spec nor
            in ``security_query_params``.
        :raises BadRequestProblem: for the first parameter that fails validation.
            Locations are checked in the order query, path, header, cookie.
        """
        if self.strict_validation:
            query_errors = self.validate_query_parameter_list(
                request, security_params=self.security_query_params
            )

            if query_errors:
                raise ExtraParameterProblem(
                    param_type="query", extra_params=query_errors
                )

        # Bound methods are used so that subclass overrides are still honoured.
        location_validators = (
            ("query", self.validate_query_parameter),
            ("path", self.validate_path_parameter),
            ("header", self.validate_header_parameter),
            ("cookie", self.validate_cookie_parameter),
        )
        for location, validate_param in location_validators:
            for param in self.parameters.get(location, []):
                error = validate_param(param, request)
                if error:
                    raise BadRequestProblem(detail=error)