Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/adal/token_request.py: 22%
242 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:05 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:05 +0000
1#------------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation.
4# All rights reserved.
5#
6# This code is licensed under the MIT License.
7#
8# Permission is hereby granted, free of charge, to any person obtaining a copy
9# of this software and associated documentation files(the "Software"), to deal
10# in the Software without restriction, including without limitation the rights
11# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
12# copies of the Software, and to permit persons to whom the Software is
13# furnished to do so, subject to the following conditions :
14#
15# The above copyright notice and this permission notice shall be included in
16# all copies or substantial portions of the Software.
17#
18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
21# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24# THE SOFTWARE.
25#
26#------------------------------------------------------------------------------
28from base64 import b64encode
30from . import constants
31from . import log
32from . import mex
33from . import oauth2_client
34from . import self_signed_jwt
35from . import user_realm
36from . import wstrust_request
37from .adal_error import AdalError
38from .cache_driver import CacheDriver
39from .constants import WSTrustVersion
# Short module-level aliases for the protocol constants used below.
OAUTH2_PARAMETERS = constants.OAuth2.Parameters
OAUTH2_GRANT_TYPE = constants.OAuth2.GrantType
OAUTH2_SCOPE = constants.OAuth2.Scope
OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS = constants.OAuth2.DeviceCodeResponseParameters
TOKEN_RESPONSE_FIELDS = constants.TokenResponseFields
SAML = constants.Saml
ACCOUNT_TYPE = constants.UserRealm.account_type

# Token-response field names that TokenRequest uses as cache-query keys.
USER_ID = TOKEN_RESPONSE_FIELDS.USER_ID
_CLIENT_ID = TOKEN_RESPONSE_FIELDS._CLIENT_ID #pylint: disable=protected-access
def add_parameter_if_available(parameters, key, value):
    """Set ``parameters[key] = value`` unless ``value`` is falsy.

    ``None``, empty strings and any other falsy value leave ``parameters``
    untouched. The dict is modified in place and nothing is returned.
    """
    if not value:
        return
    parameters[key] = value
def _get_saml_grant_type(wstrust_response):
    """Map the token type of a WS-Trust response to an OAuth2 SAML grant type.

    Returns:
        ``OAUTH2_GRANT_TYPE.SAML1`` for SAML 1.1 token types, or
        ``OAUTH2_GRANT_TYPE.SAML2`` for SAML 2.0 token types.

    Raises:
        AdalError: if the token type is not one of the recognised values.
    """
    kind = wstrust_response.token_type

    if kind in (SAML.TokenTypeV1, SAML.OasisWssSaml11TokenProfile11):
        return OAUTH2_GRANT_TYPE.SAML1
    if kind in (SAML.TokenTypeV2, SAML.OasisWssSaml2TokenProfile2):
        return OAUTH2_GRANT_TYPE.SAML2

    raise AdalError("RSTR returned unknown token type: {}".format(kind))
class TokenRequest(object):
    """Acquires tokens for one client id / resource pair.

    Each public ``get_token_*`` method implements one OAuth2 flow. The
    username/password, client-credentials and certificate flows look in the
    token cache before contacting the authority. Those three, plus the
    authorization-code and device-code flows, store the token they acquire
    in the cache.
    """

    def __init__(self, call_context, authentication_context, client_id,
                 resource, redirect_uri=None):
        """Create a token request.

        Args:
            call_context: dict handed to every helper object this request
                creates; its ``'log_context'`` entry is used for logging.
            authentication_context: supplies ``authority`` and ``cache``.
            client_id: client id sent with every OAuth2 request.
            resource: resource the token is requested for.
            redirect_uri: optional redirect URI, sent only when truthy.
        """
        self._log = log.Logger("TokenRequest", call_context['log_context'])
        self._call_context = call_context

        self._authentication_context = authentication_context
        self._resource = resource
        self._client_id = client_id
        self._redirect_uri = redirect_uri

        # Set by _find_token_from_cache(); some flows reuse it to store the
        # token they acquire.
        self._cache_driver = None

        # should be set at the beginning of get_token
        # functions that have a user_id
        self._user_id = None
        self._user_realm = None

        # should be set when acquire token using device flow
        self._polling_client = None

    def _create_user_realm_request(self, username):
        """Build a user-realm discovery request for ``username``."""
        return user_realm.UserRealm(self._call_context,
                                    username,
                                    self._authentication_context.authority.url)

    def _create_mex(self, mex_endpoint):
        """Build a metadata-exchange (MEX) client for ``mex_endpoint``."""
        return mex.Mex(self._call_context, mex_endpoint)

    def _create_wstrust_request(self, wstrust_endpoint, applies_to, wstrust_endpoint_version):
        """Build a WS-Trust request against ``wstrust_endpoint``."""
        return wstrust_request.WSTrustRequest(self._call_context, wstrust_endpoint,
                                              applies_to, wstrust_endpoint_version)

    def _create_oauth2_client(self):
        """Build an OAuth2 client bound to the context's authority."""
        return oauth2_client.OAuth2Client(self._call_context,
                                          self._authentication_context.authority)

    def _create_self_signed_jwt(self):
        """Build the helper that creates certificate-signed client assertions."""
        return self_signed_jwt.SelfSignedJwt(self._call_context,
                                             self._authentication_context.authority,
                                             self._client_id)

    def _oauth_get_token(self, oauth_parameters):
        """POST ``oauth_parameters`` to the token endpoint via a fresh client."""
        client = self._create_oauth2_client()
        return client.get_token(oauth_parameters)

    def _create_cache_driver(self):
        """Build a CacheDriver for this authority/resource/client.

        ``_get_token_with_token_response`` is passed in as the driver's
        callback for refreshing cached entries.
        """
        return CacheDriver(
            self._call_context,
            self._authentication_context.authority.url,
            self._resource,
            self._client_id,
            self._authentication_context.cache,
            self._get_token_with_token_response
        )

    def _find_token_from_cache(self):
        """Query the cache for a token for this client (and user, if set).

        The driver is kept on ``self._cache_driver`` so that callers can add
        a newly acquired token through it.
        """
        self._cache_driver = self._create_cache_driver()
        cache_query = self._create_cache_query()
        return self._cache_driver.find(cache_query)

    def _add_token_into_cache(self, token):
        """Store ``token`` in the cache using a freshly built driver."""
        cache_driver = self._create_cache_driver()
        self._log.debug('Storing retrieved token into cache')
        cache_driver.add(token)

    def _get_token_with_token_response(self, entry, resource):
        """Refresh callback for the CacheDriver.

        Uses the refresh token stored in cache ``entry`` to obtain a new
        token for ``resource``.
        """
        self._log.debug("called to refresh a token from the cache")
        refresh_token = entry[TOKEN_RESPONSE_FIELDS.REFRESH_TOKEN]
        return self._get_token_with_refresh_token(refresh_token, resource, None)

    def _create_cache_query(self):
        """Return the cache query: always the client id, plus the user id if known."""
        query = {_CLIENT_ID : self._client_id}
        if self._user_id:
            query[USER_ID] = self._user_id
        else:
            self._log.debug("No user_id passed for cache query")

        return query

    def _create_oauth_parameters(self, grant_type):
        """Build the base OAuth2 parameter dict for ``grant_type``.

        The OpenID scope is added for every grant except authorization code,
        client credentials, refresh token and device code. Client id,
        resource and redirect URI are included only when set.
        """
        oauth_parameters = {OAUTH2_PARAMETERS.GRANT_TYPE: grant_type}

        if grant_type not in (OAUTH2_GRANT_TYPE.AUTHORIZATION_CODE,
                              OAUTH2_GRANT_TYPE.CLIENT_CREDENTIALS,
                              OAUTH2_GRANT_TYPE.REFRESH_TOKEN,
                              OAUTH2_GRANT_TYPE.DEVICE_CODE):
            oauth_parameters[OAUTH2_PARAMETERS.SCOPE] = OAUTH2_SCOPE.OPENID

        add_parameter_if_available(oauth_parameters, OAUTH2_PARAMETERS.CLIENT_ID,
                                   self._client_id)
        add_parameter_if_available(oauth_parameters, OAUTH2_PARAMETERS.RESOURCE,
                                   self._resource)
        add_parameter_if_available(oauth_parameters, OAUTH2_PARAMETERS.REDIRECT_URI,
                                   self._redirect_uri)

        return oauth_parameters

    def _get_token_username_password_managed(self, username, password):
        """Acquire a token with the resource-owner password grant."""
        self._log.debug('Acquiring token with username password for managed user')

        oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.PASSWORD)

        oauth_parameters[OAUTH2_PARAMETERS.PASSWORD] = password
        oauth_parameters[OAUTH2_PARAMETERS.USERNAME] = username

        return self._oauth_get_token(oauth_parameters)

    def _perform_wstrust_assertion_oauth_exchange(self, wstrust_response):
        """Exchange the token from a WS-Trust response for an OAuth2 token.

        The grant type (SAML 1.1 or 2.0) follows the response's token type,
        and the raw token is sent base64-encoded as the assertion.
        """
        self._log.debug("Performing OAuth assertion grant type exchange.")

        grant_type = _get_saml_grant_type(wstrust_response)
        assertion = b64encode(wstrust_response.token)

        oauth_parameters = self._create_oauth_parameters(grant_type)
        oauth_parameters[OAUTH2_PARAMETERS.ASSERTION] = assertion

        return self._oauth_get_token(oauth_parameters)

    def _perform_wstrust_exchange(self, wstrust_endpoint, wstrust_endpoint_version, cloud_audience_urn, username, password):
        """Run a WS-Trust request and return its response.

        Raises:
            AdalError: if the response carries no token.
        """
        wstrust = self._create_wstrust_request(wstrust_endpoint, cloud_audience_urn,
                                               wstrust_endpoint_version)
        result = wstrust.acquire_token(username, password)

        if not result.token:
            err_template = "Unsuccessful RSTR.\n\terror code: {}\n\tfaultMessage: {}"
            error_msg = err_template.format(result.error_code, result.fault_message)
            self._log.info(error_msg)
            raise AdalError(error_msg)

        return result

    def _perform_username_password_for_access_token_exchange(self, wstrust_endpoint, wstrust_endpoint_version, cloud_audience_urn,
                                                             username, password):
        """WS-Trust exchange followed by the SAML-assertion OAuth2 exchange."""
        wstrust_response = self._perform_wstrust_exchange(wstrust_endpoint, wstrust_endpoint_version, cloud_audience_urn,
                                                          username, password)
        return self._perform_wstrust_assertion_oauth_exchange(wstrust_response)

    def _get_token_username_password_federated(self, username, password):
        """Acquire a token for a federated user via WS-Trust.

        The WS-Trust endpoint comes from MEX discovery when AAD supplied a
        federation metadata URL. Otherwise, or when MEX fails, the endpoint
        falls back to AAD's ``federation_active_auth_url``.

        Raises:
            AdalError: if no usable WS-Trust endpoint is available.
        """
        self._log.debug("Acquiring token with username password for federated user")

        cloud_audience_urn = self._user_realm.cloud_audience_urn
        if not self._user_realm.federation_metadata_url:
            self._log.warn("Unable to retrieve federationMetadataUrl from AAD. "
                           "Attempting fallback to AAD supplied endpoint.")

            if not self._user_realm.federation_active_auth_url:
                raise AdalError('AAD did not return a WSTrust endpoint. Unable to proceed.')

            wstrust_version = TokenRequest._parse_wstrust_version_from_federation_active_authurl(
                self._user_realm.federation_active_auth_url)
            self._log.debug(
                'wstrust endpoint version is: %(wstrust_version)s',
                {"wstrust_version": wstrust_version})

            return self._perform_username_password_for_access_token_exchange(
                self._user_realm.federation_active_auth_url,
                wstrust_version, cloud_audience_urn, username, password)

        mex_endpoint = self._user_realm.federation_metadata_url
        self._log.debug(
            "Attempting mex at: %(mex_endpoint)s",
            {"mex_endpoint": mex_endpoint})
        mex_instance = self._create_mex(mex_endpoint)

        try:
            mex_instance.discover()
            wstrust_endpoint = mex_instance.username_password_policy['url']
            wstrust_version = mex_instance.username_password_policy['version']
        except Exception: #pylint: disable=broad-except
            self._log.warn(
                "MEX exchange failed for %(mex_endpoint)s. "
                "Attempting fallback to AAD supplied endpoint.",
                {"mex_endpoint": mex_endpoint})
            wstrust_endpoint = self._user_realm.federation_active_auth_url
            # Validate before parsing. Parsing a missing URL would raise a
            # TypeError and mask this intended, descriptive error.
            if not wstrust_endpoint:
                raise AdalError('AAD did not return a WSTrust endpoint. Unable to proceed.')
            wstrust_version = TokenRequest._parse_wstrust_version_from_federation_active_authurl(
                wstrust_endpoint)

        return self._perform_username_password_for_access_token_exchange(wstrust_endpoint, wstrust_version,
                                                                         cloud_audience_urn,
                                                                         username, password)

    @staticmethod
    def _parse_wstrust_version_from_federation_active_authurl(federation_active_authurl):
        """Infer the WS-Trust version from an active-auth endpoint URL.

        Returns ``WSTrustVersion.UNDEFINED`` when the URL is empty/None or
        contains neither known ``usernamemixed`` path.
        """
        if not federation_active_authurl:
            return WSTrustVersion.UNDEFINED
        if '/trust/2005/usernamemixed' in federation_active_authurl:
            return WSTrustVersion.WSTRUST2005
        if '/trust/13/usernamemixed' in federation_active_authurl:
            return WSTrustVersion.WSTRUST13
        return WSTrustVersion.UNDEFINED

    def get_token_with_username_password(self, username, password):
        """Acquire a token with a username and password.

        A cached token for ``username`` is returned when available. For
        non-ADFS authorities, user-realm discovery decides between the
        managed (password grant) and federated (WS-Trust) paths. ADFS
        authorities always use the password grant. The acquired token is
        added to the cache.

        Raises:
            AdalError: if user-realm discovery reports an unknown account type.
        """
        self._log.debug("Acquiring token with username password.")
        self._user_id = username
        try:
            token = self._find_token_from_cache()
            if token:
                return token
        except AdalError:
            self._log.exception('Attempt to look for token in cache resulted in Error')

        if not self._authentication_context.authority.is_adfs_authority:
            self._user_realm = self._create_user_realm_request(username)
            self._user_realm.discover()

            try:
                if self._user_realm.account_type == ACCOUNT_TYPE['Managed']:
                    token = self._get_token_username_password_managed(username, password)
                elif self._user_realm.account_type == ACCOUNT_TYPE['Federated']:
                    token = self._get_token_username_password_federated(username, password)
                else:
                    raise AdalError(
                        "Server returned an unknown AccountType: {}".format(self._user_realm.account_type))
                self._log.debug("Successfully retrieved token from authority.")
            except Exception:
                self._log.info("get_token_func returned with error")
                raise
        else:
            self._log.info('Skipping user realm discovery for ADFS authority')
            token = self._get_token_username_password_managed(username, password)

        self._cache_driver.add(token)
        return token

    def get_token_with_client_credentials(self, client_secret):
        """Acquire a token with the client-credentials grant and a client secret.

        A cached token is returned when available. Otherwise the acquired
        token is added to the cache.
        """
        self._log.debug("Getting token with client credentials.")
        try:
            token = self._find_token_from_cache()
            if token:
                return token
        except AdalError:
            self._log.exception('Attempt to look for token in cache resulted in Error')

        oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.CLIENT_CREDENTIALS)
        oauth_parameters[OAUTH2_PARAMETERS.CLIENT_SECRET] = client_secret

        token = self._oauth_get_token(oauth_parameters)
        self._cache_driver.add(token)
        return token

    def get_token_with_authorization_code(self, authorization_code, client_secret, code_verifier):
        """Redeem an authorization code and cache the resulting token.

        ``client_secret`` and ``code_verifier`` are sent only when not None.
        """
        self._log.info("Getting token with auth code.")
        oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.AUTHORIZATION_CODE)
        oauth_parameters[OAUTH2_PARAMETERS.CODE] = authorization_code
        if client_secret is not None:
            oauth_parameters[OAUTH2_PARAMETERS.CLIENT_SECRET] = client_secret
        if code_verifier is not None:
            oauth_parameters[OAUTH2_PARAMETERS.CODE_VERIFIER] = code_verifier
        token = self._oauth_get_token(oauth_parameters)
        self._add_token_into_cache(token)
        return token

    def _get_token_with_refresh_token(self, refresh_token, resource, client_secret):
        """Redeem ``refresh_token``.

        A truthy ``resource`` overrides the request's resource, and a truthy
        ``client_secret`` is included in the request.
        """
        self._log.info("Getting a new token from a refresh token")

        oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.REFRESH_TOKEN)
        if resource:
            oauth_parameters[OAUTH2_PARAMETERS.RESOURCE] = resource

        if client_secret:
            oauth_parameters[OAUTH2_PARAMETERS.CLIENT_SECRET] = client_secret

        oauth_parameters[OAUTH2_PARAMETERS.REFRESH_TOKEN] = refresh_token
        return self._oauth_get_token(oauth_parameters)

    def get_token_with_refresh_token(self, refresh_token, client_secret):
        """Redeem ``refresh_token`` for this request's resource (not cached here)."""
        return self._get_token_with_refresh_token(refresh_token, None, client_secret)

    def get_token_from_cache_with_refresh(self, user_id):
        """Look up ``user_id``'s token in the cache.

        The cache driver may refresh the entry through
        ``_get_token_with_token_response``.
        """
        self._log.debug("Getting token from cache with refresh if necessary.")
        self._user_id = user_id
        return self._find_token_from_cache()

    def _create_jwt(self, certificate, thumbprint, public_certificate):
        """Create a self-signed client-assertion JWT.

        Raises:
            AdalError: if no JWT was produced.
        """
        ssj = self._create_self_signed_jwt()
        jwt = ssj.create(certificate, thumbprint, public_certificate)

        if not jwt:
            raise AdalError("Failed to create JWT.")
        return jwt

    def get_token_with_certificate(self, certificate, thumbprint, public_certificate):
        """Acquire a token using a certificate-signed client assertion.

        A cached token is returned when available. Otherwise a self-signed
        JWT is built and exchanged with the client-credentials grant, and the
        resulting token is added to the cache, as the client-secret flow does.
        """
        self._log.info("Getting a token via certificate.")

        # Check the cache first so that no JWT is signed on a cache hit.
        try:
            token = self._find_token_from_cache()
            if token:
                return token
        except AdalError:
            self._log.exception('Attempt to look for token in cache resulted in Error')

        jwt = self._create_jwt(certificate, thumbprint, public_certificate)

        oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.CLIENT_CREDENTIALS)
        oauth_parameters[OAUTH2_PARAMETERS.CLIENT_ASSERTION_TYPE] = OAUTH2_GRANT_TYPE.JWT_BEARER
        oauth_parameters[OAUTH2_PARAMETERS.CLIENT_ASSERTION] = jwt

        token = self._oauth_get_token(oauth_parameters)
        # Store the token so that the cache lookup above can hit on later calls.
        # A fresh driver is used because self._cache_driver may be unset if
        # the lookup failed early.
        self._add_token_into_cache(token)
        return token

    def get_token_with_device_code(self, user_code_info):
        """Poll the token endpoint with a device code until a token is issued.

        ``user_code_info`` must contain the device code, polling interval and
        expiry under the ``OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS`` keys. The
        polling client is kept so that the request can be cancelled, and the
        resulting token is cached.

        Raises:
            AdalError: if the polling interval is not positive.
        """
        self._log.info("Getting a token via device code")

        oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.DEVICE_CODE)
        oauth_parameters[OAUTH2_PARAMETERS.CODE] = user_code_info[OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS.DEVICE_CODE]

        interval = user_code_info[OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS.INTERVAL]
        expires_in = user_code_info[OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS.EXPIRES_IN]

        if interval <= 0:
            raise AdalError('invalid refresh interval')

        client = self._create_oauth2_client()
        self._polling_client = client

        token = client.get_token_with_polling(oauth_parameters, interval, expires_in)
        self._add_token_into_cache(token)

        return token

    def cancel_token_request_with_device_code(self):
        """Cancel the polling started by ``get_token_with_device_code``.

        NOTE(review): raises AttributeError if called before polling has
        started, because ``_polling_client`` is still None at that point.
        """
        self._polling_client.cancel_polling_request()