1"""
2This module defines a SecurityHandlerFactory which supports the creation of
3SecurityHandler instances for different security schemes.
4
5It also exposes a `SECURITY_HANDLERS` dictionary which maps security scheme
6types to SecurityHandler classes. This dictionary can be used to register
7custom SecurityHandler classes for custom security schemes, or to overwrite
8existing SecurityHandler classes.
This can be done by supplying a value for the `security_handlers` argument of
the SecurityHandlerFactory.
11
12Swagger 2.0 lets you define the following authentication types for an API:
13
14- Basic authentication
15- API key (as a header or a query string parameter)
16- OAuth 2 common flows (authorization code, implicit, resource owner password credentials, client credentials)
17
18
19Changes from OpenAPI 2.0 to OpenAPI 3.0
20If you used OpenAPI 2.0 before, here is a summary of changes to help you get started with OpenAPI 3.0:
21- securityDefinitions were renamed to securitySchemes and moved inside components.
22- type: basic was replaced with type: http and scheme: basic.
23- The new type: http is an umbrella type for all HTTP security schemes, including Basic, Bearer and other,
24and the scheme keyword indicates the scheme type.
25- API keys can now be sent in: cookie.
26- Added support for OpenID Connect Discovery (type: openIdConnect).
27- OAuth 2 security schemes can now define multiple flows.
28- OAuth 2 flows were renamed to match the OAuth 2 Specification: accessCode is now authorizationCode,
29and application is now clientCredentials.
30
31
32OpenAPI uses the term security scheme for authentication and authorization schemes.
33OpenAPI 3.0 lets you describe APIs protected using the following security schemes:
34
35- HTTP authentication schemes (they use the Authorization header):
36 - Basic
37 - Bearer
38 - other HTTP schemes as defined by RFC 7235 and HTTP Authentication Scheme Registry
39- API keys in headers, query string or cookies
40 - Cookie authentication
41- OAuth 2
42- OpenID Connect Discovery
43
44"""
45
46import asyncio
47import base64
48import http.cookies
49import logging
50import os
51import typing as t
52
53import httpx
54
55from connexion.decorators.parameter import inspect_function_arguments
56from connexion.exceptions import OAuthProblem, OAuthResponseProblem, OAuthScopeProblem
57from connexion.lifecycle import ConnexionRequest
58from connexion.utils import get_function_from_name
59
# Module-level logger; the handlers and factory below use it to report missing
# security functions and unsupported security schemes.
logger = logging.getLogger(__name__)
61
62
# Unique marker object. Verification functions in this module return it (and
# callers compare against it with `is`) when a request carries no credentials
# for the scheme being checked, so that the next scheme can be tried.
NO_VALUE = object()
"""Sentinel value to indicate that no security credentials were found."""
65
66
class AbstractSecurityHandler:
    """Common base for the security handlers in this module.

    A subclass names the spec key (``security_definition_key``) and the
    environment variable (``environ_key``) from which the user-supplied
    security function is resolved, and may override ``_get_verify_func`` to
    extract credentials from the request before that function is called.
    """

    required_scopes_kw = "required_scopes"
    request_kw = "request"
    client = None
    #: The key which contains the value for the function name to resolve.
    security_definition_key: str
    #: The name of the environment variable that can be used alternatively for the function name.
    environ_key: str

    def get_fn(self, security_scheme, required_scopes):
        """Return the verification function for ``security_scheme``.

        :param security_scheme: Security Definition (scheme) from the spec.
        :param required_scopes: Scopes required by the operation (unused here).
        :return: The wrapped user function, or ``None`` (after logging a
            warning) when no user function could be resolved.
        """
        user_func = self._resolve_func(security_scheme)
        if user_func:
            return self._get_verify_func(user_func)

        logger.warning("... %s missing", self.security_definition_key)
        return None

    @classmethod
    def _get_function(
        cls,
        security_definition: dict,
        security_definition_key: str,
        environ_key: str,
        default: t.Optional[t.Callable] = None,
    ):
        """
        Resolve a function whose dotted name is given in the spec or the environment.

        :param security_definition: Security Definition (scheme) from the spec.
        :param security_definition_key: The key which contains the value for the function name to resolve.
        :param environ_key: The name of the environment variable that can be used alternatively for the function name.
        :param default: Returned when neither the spec key nor the environment variable yields a name.
        """
        # The spec entry takes precedence; the environment is only consulted
        # when the spec has no (truthy) value for the key.
        dotted_name = security_definition.get(security_definition_key)
        if not dotted_name:
            dotted_name = os.environ.get(environ_key)
        return get_function_from_name(dotted_name) if dotted_name else default

    def _generic_check(self, func, exception_msg):
        """Wrap ``func`` in an async caller that normalises its result.

        The wrapper passes ``required_scopes`` and ``request`` as keyword
        arguments only when ``func`` accepts them, awaits the result for as
        long as it is a coroutine, and raises ``OAuthResponseProblem`` with
        ``exception_msg`` when the final result is ``None``.
        """

        async def wrapper(request, *args, required_scopes=None):
            optional_kwargs = (
                (self.required_scopes_kw, required_scopes),
                (self.request_kw, request),
            )
            call_kwargs = {
                keyword: value
                for keyword, value in optional_kwargs
                if self._accepts_kwarg(func, keyword)
            }

            outcome = func(*args, **call_kwargs)
            while asyncio.iscoroutine(outcome):
                outcome = await outcome

            if outcome is None:
                raise OAuthResponseProblem(detail=exception_msg)
            # Anything else, including the NO_VALUE sentinel, is handed back
            # to the caller untouched.
            return outcome

        return wrapper

    @staticmethod
    def get_auth_header_value(request):
        """
        Split the Authorization header into ``(lowercased type, value)``.

        Returns ``(None, None)`` when the header is absent or empty and raises
        OAuthProblem when it does not consist of a type followed by a value.
        """
        raw_header = request.headers.get("Authorization")
        if not raw_header:
            return None, None

        try:
            scheme, credentials = raw_header.split(maxsplit=1)
        except ValueError:
            raise OAuthProblem(detail="Invalid authorization header")
        return scheme.lower(), credentials

    @staticmethod
    def _accepts_kwarg(func: t.Callable, keyword: str) -> bool:
        """Tell whether ``func`` can be called with ``keyword`` as a keyword argument."""
        arguments, has_kwargs = inspect_function_arguments(func)
        if has_kwargs:
            return has_kwargs
        return keyword in arguments

    def _resolve_func(self, security_scheme):
        """
        Look up the user function for this handler in the spec or the environment.

        :param security_scheme: Security Definition (scheme) from the spec.
        """
        spec_key = self.security_definition_key
        env_key = self.environ_key
        return self._get_function(security_scheme, spec_key, env_key)

    def _get_verify_func(self, function):
        """
        Wrap the user security function so that it is called with the arguments it
        accepts and so that a ``None`` result is reported as invalid authorization.
        """
        failure_detail = "Provided authorization is not valid"
        return self._generic_check(function, failure_detail)
166
167
class BasicSecurityHandler(AbstractSecurityHandler):
    """
    Security Handler for HTTP Basic authentication:
    - `type: basic` (Swagger 2), and
    - `type: http` and `scheme: basic` (OpenAPI 3)
    """

    security_definition_key = "x-basicInfoFunc"
    environ_key = "BASICINFO_FUNC"

    def check_basic_auth(self, basic_info_func):
        """Wrap the user function so a ``None`` result is reported as invalid authorization."""
        failure_detail = "Provided authorization is not valid"
        return self._generic_check(basic_info_func, failure_detail)

    def _get_verify_func(self, basic_info_func):
        """
        Build the request-level check: decode the Basic credentials from the
        Authorization header and pass username and password to the user function.
        """
        verify_credentials = self.check_basic_auth(basic_info_func)

        def wrapper(request):
            scheme, encoded = self.get_auth_header_value(request)
            if scheme != "basic":
                # Not a Basic header (or no header at all): nothing to verify.
                return NO_VALUE

            try:
                decoded = base64.b64decode(encoded).decode("latin1")
                # Split on the first colon only; the remainder is the password.
                username, password = decoded.split(":", 1)
            except Exception:
                raise OAuthProblem(detail="Invalid authorization header")

            return verify_credentials(request, username, password)

        return wrapper
201
202
class BearerSecurityHandler(AbstractSecurityHandler):
    """
    Security Handler for HTTP Bearer authentication.
    """

    security_definition_key = "x-bearerInfoFunc"
    environ_key = "BEARERINFO_FUNC"

    def _get_verify_func(self, token_info_func):
        """
        Build the request-level check: read the bearer token from the
        Authorization header and pass it to the user function.

        :param token_info_func: types.FunctionType
        :rtype: types.FunctionType
        """
        verify_token = self.check_bearer_token(token_info_func)

        def wrapper(request):
            scheme, token = self.get_auth_header_value(request)
            if scheme == "bearer":
                return verify_token(request, token)
            # Header missing or using another scheme: no bearer credentials.
            return NO_VALUE

        return wrapper

    def check_bearer_token(self, token_info_func):
        """Wrap the user function so a ``None`` result is reported as an invalid token."""
        failure_detail = "Provided token is not valid"
        return self._generic_check(token_info_func, failure_detail)
228
229
class ApiKeySecurityHandler(AbstractSecurityHandler):
    """
    Security Handler for API Keys.
    """

    security_definition_key = "x-apikeyInfoFunc"
    environ_key = "APIKEYINFO_FUNC"

    def get_fn(self, security_scheme, required_scopes):
        """
        Return the verification function for an API key scheme, or ``None``
        (after logging a warning) when no user function could be resolved.

        The scheme's ``in`` and ``name`` entries tell the returned function
        where to look for the key in the request.
        """
        user_func = self._resolve_func(security_scheme)
        if user_func:
            location = security_scheme["in"]
            key_name = security_scheme["name"]
            return self._get_verify_func(user_func, location, key_name)

        logger.warning("... %s missing", self.security_definition_key)
        return None

    def check_api_key(self, api_key_info_func):
        """Wrap the user function so a ``None`` result is reported as an invalid API key."""
        failure_detail = "Provided apikey is not valid"
        return self._generic_check(api_key_info_func, failure_detail)

    def _get_verify_func(self, api_key_info_func, loc, name):
        """
        Build the request-level check: read the key called ``name`` from the
        query string, the headers or the cookies (depending on ``loc``) and
        pass it to the user function.
        """
        verify_key = self.check_api_key(api_key_info_func)

        def wrapper(request: ConnexionRequest):
            if loc == "cookie":
                raw_cookies = request.headers.get("Cookie")
                candidate = self.get_cookie_value(raw_cookies, name)
            elif loc == "header":
                candidate = request.headers.get(name)
            elif loc == "query":
                candidate = request.query_params.get(name)
            else:
                # Unknown location: treat as "no credentials supplied".
                return NO_VALUE

            return NO_VALUE if candidate is None else verify_key(request, candidate)

        return wrapper

    @staticmethod
    def get_cookie_value(cookies, name):
        """
        Look up a single cookie in a raw ``Cookie`` header value.

        :param cookies: str: cookies raw data
        :param name: str: cookies key
        :return: The cookie's value, or `None` if there is no such cookie.
        """
        jar = http.cookies.SimpleCookie()
        jar.load(str(cookies))
        morsel = jar.get(name)
        return None if morsel is None else morsel.value
288
289
class OAuthSecurityHandler(AbstractSecurityHandler):
    """
    Security Handler for the OAuth security scheme.

    The token info is obtained either from a user function
    (``x-tokenInfoFunc`` / ``TOKENINFO_FUNC``) or by calling a remote URL
    (``x-tokenInfoUrl`` / ``TOKENINFO_URL``); the scopes it reports are then
    checked against the scopes required by the operation.
    """

    def get_fn(self, security_scheme, required_scopes):
        """
        Return the verification function for an OAuth scheme.

        :param security_scheme: Security Definition (scheme) from the spec.
        :param required_scopes: Scopes required to access the operation.
        :return: The verification function, or ``None`` (after logging a
            warning) when neither a token info function nor URL is configured.
        """
        token_info_func = self.get_tokeninfo_func(security_scheme)
        scope_validate_func = self.get_scope_validate_func(security_scheme)
        if not token_info_func:
            logger.warning("... x-tokenInfoFunc missing")
            return None

        return self._get_verify_func(
            token_info_func, scope_validate_func, required_scopes
        )

    def get_tokeninfo_func(self, security_definition: dict) -> t.Optional[t.Callable]:
        """
        Gets the function for retrieving the token info.

        It is possible to specify a function or a URL. The function variant is
        preferred. If it is not found, the URL variant is used with the
        `get_token_info_remote` function.

        For example, ``{'x-tokenInfoFunc': 'foo.bar'}`` resolves to the
        function ``foo.bar``.

        :param security_definition: Security Definition (scheme) from the spec.
        :return: The token info function, or ``None`` if nothing is configured.
        """
        token_info_func = self._get_function(
            security_definition, "x-tokenInfoFunc", "TOKENINFO_FUNC"
        )
        if token_info_func:
            return token_info_func

        token_info_url = security_definition.get("x-tokenInfoUrl") or os.environ.get(
            "TOKENINFO_URL"
        )
        if token_info_url:
            return self.get_token_info_remote(token_info_url)

        return None

    @classmethod
    def get_scope_validate_func(cls, security_definition):
        """
        Gets the function for validating the token scopes.

        The function is resolved from ``x-scopeValidateFunc`` in the security
        definition or from the ``SCOPEVALIDATE_FUNC`` environment variable.
        If it is not found, the default `validate_scope` function is used.

        :param security_definition: Security Definition (scheme) from the spec.
        """
        return cls._get_function(
            security_definition,
            "x-scopeValidateFunc",
            "SCOPEVALIDATE_FUNC",
            cls.validate_scope,
        )

    @staticmethod
    def validate_scope(required_scopes, token_scopes):
        """
        Check that every required scope is among the token's scopes.

        :param required_scopes: Scopes required to access operation
        :param token_scopes: Scopes granted by authorization server, either as
            a whitespace-separated string or as a list/tuple/set of scope
            names. ``None`` is treated as "no scopes granted".
        :rtype: bool
        """
        required_scopes = set(required_scopes or ())
        if token_scopes is None:
            token_scopes = set()
        elif isinstance(token_scopes, (list, tuple, set, frozenset)):
            token_scopes = set(token_scopes)
        else:
            # Anything else is expected to be a whitespace-separated string.
            token_scopes = set(token_scopes.split())
        logger.debug("... Scopes required: %s", required_scopes)
        logger.debug("... Token scopes: %s", token_scopes)
        if not required_scopes <= token_scopes:
            logger.info(
                "... Token scopes (%s) do not match the scopes necessary to call endpoint (%s)."
                " Aborting with 403.",
                token_scopes,
                required_scopes,
            )
            return False
        return True

    def get_token_info_remote(self, token_info_url: str) -> t.Callable:
        """
        Return a function which will call `token_info_url` to retrieve token info.

        The returned coroutine function accepts the oauth token as its only
        parameter, sends it as a Bearer token to `token_info_url`, and returns
        the decoded JSON body on a 200 response and ``None`` otherwise.

        :param token_info_url: URL to get information about the token
        """

        async def wrapper(token):
            # The HTTP client is created lazily on first use and then reused
            # for subsequent calls made through this handler instance.
            if self.client is None:
                self.client = httpx.AsyncClient()
            headers = {"Authorization": f"Bearer {token}"}
            token_request = await self.client.get(
                token_info_url, headers=headers, timeout=5
            )
            if token_request.status_code != 200:
                return None
            return token_request.json()

        return wrapper

    def _get_verify_func(self, token_info_func, scope_validate_func, required_scopes):
        """
        Build the request-level check: read the bearer token from the
        Authorization header, fetch its token info and validate its scopes.
        """
        check_oauth_func = self.check_oauth_func(token_info_func, scope_validate_func)

        def wrapper(request):
            auth_type, token = self.get_auth_header_value(request)
            if auth_type != "bearer":
                return NO_VALUE

            return check_oauth_func(request, token, required_scopes=required_scopes)

        return wrapper

    def check_oauth_func(self, token_info_func, scope_validate_func):
        """
        Combine token info retrieval and scope validation into one coroutine function.

        :param token_info_func: Function returning the token info for a token.
        :param scope_validate_func: Function called with the required scopes
            and the token's scopes; a falsy result rejects the request.
        :raises OAuthScopeProblem: (from the returned function) when scope
            validation fails.
        """
        get_token_info = self._generic_check(
            token_info_func, "Provided token is not valid"
        )

        async def wrapper(request, token, required_scopes):
            token_info = await get_token_info(
                request, token, required_scopes=required_scopes
            )
            if token_info is NO_VALUE:
                # The token info function signalled "no credentials"; pass the
                # sentinel through instead of treating it as a token info dict.
                return NO_VALUE

            # Fallback to 'scopes' for backward compatibility
            token_scopes = token_info.get("scope", token_info.get("scopes", ""))

            validation = scope_validate_func(required_scopes, token_scopes)
            while asyncio.iscoroutine(validation):
                validation = await validation
            if not validation:
                raise OAuthScopeProblem(
                    required_scopes=required_scopes,
                    token_scopes=token_scopes,
                )

            return token_info

        return wrapper
430
431
# Default mapping from security scheme identifier to handler class.
# SecurityHandlerFactory copies this dict and looks handlers up either by the
# scheme's `type` or, for `type: http`, by its lowercased `scheme`.
SECURITY_HANDLERS = {
    # Swagger 2: `type: basic`
    # OpenAPI 3: `type: http` and `scheme: basic`
    "basic": BasicSecurityHandler,
    # Swagger 2 and OpenAPI 3: `type: apiKey` and `type: oauth2`
    "apiKey": ApiKeySecurityHandler,
    "oauth2": OAuthSecurityHandler,
    # OpenAPI 3: `type: http` and `scheme: bearer`
    "bearer": BearerSecurityHandler,
}
442
443
class SecurityHandlerFactory:
    """
    A factory class for parsing security schemes and returning the appropriate
    security handler.

    By default, it will use the built-in security handlers specified in the
    SECURITY_HANDLERS dict, but you can also pass in your own security handlers
    to override the built-in ones.
    """

    def __init__(
        self,
        security_handlers: t.Optional[dict] = None,
    ) -> None:
        """
        :param security_handlers: Optional mapping of scheme identifier to
            handler class; its entries are added to (and take precedence
            over) a copy of the built-in SECURITY_HANDLERS.
        """
        self.security_handlers = SECURITY_HANDLERS.copy()
        if security_handlers is not None:
            self.security_handlers.update(security_handlers)

    def parse_security_scheme(
        self,
        security_scheme: dict,
        required_scopes: t.List[str],
    ) -> t.Optional[t.Callable]:
        """Parses the security scheme and returns the function for verifying it.

        :param security_scheme: The security scheme from the spec.
        :param required_scopes: List of scopes for this security scheme.
        :return: The verification function, or ``None`` when the scheme is
            unsupported or the handler could not resolve a user function.
        """
        security_type = security_scheme["type"]
        if security_type in ("basic", "oauth2"):
            handler_cls = self.security_handlers[security_type]

        # OpenAPI 3.0.0
        elif security_type == "http":
            scheme = security_scheme["scheme"].lower()
            if scheme not in self.security_handlers:
                logger.warning("... Unsupported http authorization scheme %s", scheme)
                return None
            handler_cls = self.security_handlers[scheme]

        elif security_type == "apiKey":
            scheme = security_scheme.get("x-authentication-scheme", "").lower()
            if scheme == "bearer":
                handler_cls = BearerSecurityHandler
            else:
                handler_cls = self.security_handlers["apiKey"]

        else:
            # Custom security handler, registered under the scheme's `scheme`
            # value. Types such as `openIdConnect` carry no `scheme` key, so
            # it is read defensively to reach the warning below instead of
            # failing with a KeyError.
            scheme = (security_scheme.get("scheme") or "").lower()
            if scheme not in self.security_handlers:
                logger.warning(
                    "... Unsupported security scheme type %s",
                    security_type,
                )
                return None
            handler_cls = self.security_handlers[scheme]

        return handler_cls().get_fn(security_scheme, required_scopes)

    @staticmethod
    async def security_passthrough(request):
        """Used when no security is required for the operation.

        Equivalent OpenAPI snippet:

        .. code-block:: yaml

            /helloworld
              get:
                security: []   # No security
                ...
        """
        return request

    @staticmethod
    def verify_none(request):
        """Used for optional security.

        Equivalent OpenAPI snippet:

        .. code-block:: yaml

            security:
              - {}       # <--
              - myapikey: []
        """
        return {}

    def verify_multiple_schemes(self, schemes):
        """
        Verifies multiple authentication schemes in AND fashion.
        If any scheme fails, the entire authentication fails.

        The returned coroutine function yields a dict mapping each scheme name
        to that scheme's result, or ``NO_VALUE`` as soon as one scheme finds
        no credentials.

        :param schemes: mapping scheme_name to auth function
        :type schemes: dict
        :rtype: types.FunctionType
        """

        async def wrapper(request):
            token_info = {}
            for scheme_name, func in schemes.items():
                result = func(request)
                while asyncio.iscoroutine(result):
                    result = await result
                if result is NO_VALUE:
                    return NO_VALUE
                token_info[scheme_name] = result

            return token_info

        return wrapper

    @classmethod
    def verify_security(cls, auth_funcs):
        """
        Build a coroutine function that tries ``auth_funcs`` in order (OR fashion).

        The first function returning something other than ``NO_VALUE`` wins;
        its result is stored in ``request.context`` as ``token_info`` together
        with the derived ``user``. If none succeeds, the most specific
        collected error is raised, or an ``OAuthProblem`` when no function
        raised at all.

        :param auth_funcs: Iterable of verification functions taking the request.
        """

        async def verify_fn(request):
            token_info = NO_VALUE
            errors = []
            for func in auth_funcs:
                try:
                    token_info = func(request)
                    while asyncio.iscoroutine(token_info):
                        token_info = await token_info
                    if token_info is not NO_VALUE:
                        break
                except Exception as err:
                    # Remember the failure and keep trying the other schemes.
                    errors.append(err)

            else:
                # No scheme produced credentials.
                if errors:
                    cls._raise_most_specific(errors)
                else:
                    logger.info("... No auth provided. Aborting with 401.")
                    raise OAuthProblem(detail="No authorization token provided")

            request.context.update(
                {
                    # Fallback to 'uid' for backward compatibility
                    "user": token_info.get("sub", token_info.get("uid")),
                    "token_info": token_info,
                }
            )

        return verify_fn

    @staticmethod
    def _raise_most_specific(exceptions: t.List[Exception]) -> None:
        """Raises the most specific error from a list of exceptions by status code.

        The status codes are expected to be either in the `status_code`
        or in the `status` attribute of the exceptions.

        The order is as follows:
          - 403: valid credentials but not enough privileges
          - 401: no or invalid credentials
          - for other status codes, the smallest one is selected

        Returns without raising when the list is empty.

        :param exceptions: List of exceptions.
        :type exceptions: t.List[Exception]
        """
        if not exceptions:
            return
        # We only use status code attributes from exceptions
        # We use 600 as default because 599 is highest valid status code
        status_to_exc = {
            getattr(exc, "status_code", getattr(exc, "status", 600)): exc
            for exc in exceptions
        }
        if 403 in status_to_exc:
            raise status_to_exc[403]
        elif 401 in status_to_exc:
            raise status_to_exc[401]
        else:
            lowest_status_code = min(status_to_exc)
            raise status_to_exc[lowest_status_code]