Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/adal/token_request.py: 22%

242 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:05 +0000

1#------------------------------------------------------------------------------ 

2# 

3# Copyright (c) Microsoft Corporation.  

4# All rights reserved. 

5#  

6# This code is licensed under the MIT License. 

7#  

8# Permission is hereby granted, free of charge, to any person obtaining a copy 

9# of this software and associated documentation files(the "Software"), to deal 

10# in the Software without restriction, including without limitation the rights 

11# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell 

12# copies of the Software, and to permit persons to whom the Software is 

13# furnished to do so, subject to the following conditions : 

14#  

15# The above copyright notice and this permission notice shall be included in 

16# all copies or substantial portions of the Software. 

17#  

18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

19# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

20# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE 

21# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

22# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

23# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 

24# THE SOFTWARE. 

25# 

26#------------------------------------------------------------------------------ 

27 

28from base64 import b64encode 

29 

30from . import constants 

31from . import log 

32from . import mex 

33from . import oauth2_client 

34from . import self_signed_jwt 

35from . import user_realm 

36from . import wstrust_request 

37from .adal_error import AdalError 

38from .cache_driver import CacheDriver 

39from .constants import WSTrustVersion 

40 

# Short aliases for the OAuth2 constant namespaces used by the requests below.
(OAUTH2_PARAMETERS,
 OAUTH2_GRANT_TYPE,
 OAUTH2_SCOPE,
 OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS) = (constants.OAuth2.Parameters,
                                           constants.OAuth2.GrantType,
                                           constants.OAuth2.Scope,
                                           constants.OAuth2.DeviceCodeResponseParameters)

# Token-response field names; USER_ID and _CLIENT_ID are also the keys used
# when building cache queries.
TOKEN_RESPONSE_FIELDS = constants.TokenResponseFields
USER_ID = TOKEN_RESPONSE_FIELDS.USER_ID
_CLIENT_ID = TOKEN_RESPONSE_FIELDS._CLIENT_ID  #pylint: disable=protected-access

# SAML token-type identifiers and the user-realm account-type lookup table.
SAML = constants.Saml
ACCOUNT_TYPE = constants.UserRealm.account_type

50 

def add_parameter_if_available(parameters, key, value):
    """Set ``parameters[key] = value`` only when ``value`` is truthy.

    Falsy values (``None``, empty strings, ``0``, ...) leave ``parameters``
    untouched. The dict is modified in place and nothing is returned.
    """
    if not value:
        return
    parameters[key] = value

54 

def _get_saml_grant_type(wstrust_response):
    """Map the token type of a WS-Trust response to an OAuth2 SAML grant type.

    :param wstrust_response: object exposing a ``token_type`` attribute.
    :returns: ``OAUTH2_GRANT_TYPE.SAML1`` or ``OAUTH2_GRANT_TYPE.SAML2``.
    :raises AdalError: when the token type is not one of the known SAML types.
    """
    token_type = wstrust_response.token_type

    saml1_token_types = (SAML.TokenTypeV1, SAML.OasisWssSaml11TokenProfile11)
    if token_type in saml1_token_types:
        return OAUTH2_GRANT_TYPE.SAML1

    saml2_token_types = (SAML.TokenTypeV2, SAML.OasisWssSaml2TokenProfile2)
    if token_type in saml2_token_types:
        return OAUTH2_GRANT_TYPE.SAML2

    raise AdalError("RSTR returned unknown token type: {}".format(token_type))

65 

class TokenRequest(object):
    """Performs one token acquisition against the context's authority.

    Each public ``get_token_*`` method implements one grant flow. The
    username/password, client-secret and certificate flows look in the token
    cache first; newly acquired tokens are written back to the cache.
    """

    def __init__(self, call_context, authentication_context, client_id,
                 resource, redirect_uri=None):
        """
        :param call_context: per-call context dict; must contain 'log_context'.
        :param authentication_context: provides ``authority`` and ``cache``.
        :param client_id: application id added to every OAuth2 request.
        :param resource: resource the token is requested for.
        :param redirect_uri: optional; only sent when truthy.
        """
        self._log = log.Logger("TokenRequest", call_context['log_context'])
        self._call_context = call_context

        self._authentication_context = authentication_context
        self._resource = resource
        self._client_id = client_id
        self._redirect_uri = redirect_uri

        # Created by _find_token_from_cache(); some flows reuse it afterwards
        # to store the token they acquired.
        self._cache_driver = None

        # should be set at the beginning of get_token
        # functions that have a user_id
        self._user_id = None
        self._user_realm = None

        # should be set when acquire token using device flow
        self._polling_client = None

    def _create_user_realm_request(self, username):
        """Create a user-realm discovery request for ``username``."""
        return user_realm.UserRealm(self._call_context,
                                    username,
                                    self._authentication_context.authority.url)

    def _create_mex(self, mex_endpoint):
        """Create a MEX (metadata exchange) client for ``mex_endpoint``."""
        return mex.Mex(self._call_context, mex_endpoint)

    def _create_wstrust_request(self, wstrust_endpoint, applies_to, wstrust_endpoint_version):
        """Create a WS-Trust request against ``wstrust_endpoint``."""
        return wstrust_request.WSTrustRequest(self._call_context, wstrust_endpoint,
                                              applies_to, wstrust_endpoint_version)

    def _create_oauth2_client(self):
        """Create an OAuth2 client bound to the context's authority."""
        return oauth2_client.OAuth2Client(self._call_context,
                                          self._authentication_context.authority)

    def _create_self_signed_jwt(self):
        """Create a helper that builds client-assertion JWTs for this client id."""
        return self_signed_jwt.SelfSignedJwt(self._call_context,
                                             self._authentication_context.authority,
                                             self._client_id)

    def _oauth_get_token(self, oauth_parameters):
        """Send ``oauth_parameters`` to the token endpoint and return the result."""
        client = self._create_oauth2_client()
        return client.get_token(oauth_parameters)

    def _create_cache_driver(self):
        """Create a CacheDriver for this request's authority/resource/client.

        ``_get_token_with_token_response`` is passed in as the callback the
        driver uses to refresh cached entries.
        """
        return CacheDriver(
            self._call_context,
            self._authentication_context.authority.url,
            self._resource,
            self._client_id,
            self._authentication_context.cache,
            self._get_token_with_token_response
        )

    def _find_token_from_cache(self):
        """Look up a token for this request in the cache.

        Side effect: replaces ``self._cache_driver`` with a fresh driver.
        """
        self._cache_driver = self._create_cache_driver()
        cache_query = self._create_cache_query()
        return self._cache_driver.find(cache_query)

    def _add_token_into_cache(self, token):
        """Store ``token`` in the cache through a freshly created driver."""
        cache_driver = self._create_cache_driver()
        self._log.debug('Storing retrieved token into cache')
        cache_driver.add(token)

    def _get_token_with_token_response(self, entry, resource):
        """Refresh callback for the CacheDriver: redeem an entry's refresh token."""
        self._log.debug("called to refresh a token from the cache")
        refresh_token = entry[TOKEN_RESPONSE_FIELDS.REFRESH_TOKEN]
        return self._get_token_with_refresh_token(refresh_token, resource, None)

    def _create_cache_query(self):
        """Build the cache query: always the client id, plus the user id if set."""
        query = {_CLIENT_ID : self._client_id}
        if self._user_id:
            query[USER_ID] = self._user_id
        else:
            self._log.debug("No user_id passed for cache query")

        return query

    def _create_oauth_parameters(self, grant_type):
        """Build the base OAuth2 request parameters for ``grant_type``.

        The ``openid`` scope is added for every grant type except authorization
        code, client credentials, refresh token and device code. client_id,
        resource and redirect_uri are included only when they are truthy.
        """
        oauth_parameters = {OAUTH2_PARAMETERS.GRANT_TYPE: grant_type}

        grants_without_openid_scope = (
            OAUTH2_GRANT_TYPE.AUTHORIZATION_CODE,
            OAUTH2_GRANT_TYPE.CLIENT_CREDENTIALS,
            OAUTH2_GRANT_TYPE.REFRESH_TOKEN,
            OAUTH2_GRANT_TYPE.DEVICE_CODE,
        )
        if grant_type not in grants_without_openid_scope:
            oauth_parameters[OAUTH2_PARAMETERS.SCOPE] = OAUTH2_SCOPE.OPENID

        add_parameter_if_available(oauth_parameters, OAUTH2_PARAMETERS.CLIENT_ID,
                                   self._client_id)
        add_parameter_if_available(oauth_parameters, OAUTH2_PARAMETERS.RESOURCE,
                                   self._resource)
        add_parameter_if_available(oauth_parameters, OAUTH2_PARAMETERS.REDIRECT_URI,
                                   self._redirect_uri)

        return oauth_parameters

    def _get_token_username_password_managed(self, username, password):
        """Acquire a token with the OAuth2 password grant."""
        self._log.debug('Acquiring token with username password for managed user')

        oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.PASSWORD)

        oauth_parameters[OAUTH2_PARAMETERS.PASSWORD] = password
        oauth_parameters[OAUTH2_PARAMETERS.USERNAME] = username

        return self._oauth_get_token(oauth_parameters)

    def _perform_wstrust_assertion_oauth_exchange(self, wstrust_response):
        """Exchange a WS-Trust (SAML) token for an OAuth2 token.

        :raises AdalError: if the SAML token type is not recognized.
        """
        self._log.debug("Performing OAuth assertion grant type exchange.")

        grant_type = _get_saml_grant_type(wstrust_response)

        # The raw RSTR token is sent base64-encoded as the assertion.
        assertion = b64encode(wstrust_response.token)

        oauth_parameters = self._create_oauth_parameters(grant_type)
        oauth_parameters[OAUTH2_PARAMETERS.ASSERTION] = assertion

        return self._oauth_get_token(oauth_parameters)

    def _perform_wstrust_exchange(self, wstrust_endpoint, wstrust_endpoint_version, cloud_audience_urn, username, password):
        """Request a token from the WS-Trust endpoint with the user's credentials.

        :raises AdalError: when the response carries no token.
        """
        wstrust = self._create_wstrust_request(wstrust_endpoint, cloud_audience_urn,
                                               wstrust_endpoint_version)
        result = wstrust.acquire_token(username, password)

        if not result.token:
            err_template = "Unsuccessful RSTR.\n\terror code: {}\n\tfaultMessage: {}"
            error_msg = err_template.format(result.error_code, result.fault_message)
            self._log.info(error_msg)
            raise AdalError(error_msg)

        return result

    def _perform_username_password_for_access_token_exchange(self, wstrust_endpoint, wstrust_endpoint_version, cloud_audience_urn,
                                                             username, password):
        """WS-Trust exchange followed by the SAML assertion -> OAuth2 exchange."""
        wstrust_response = self._perform_wstrust_exchange(wstrust_endpoint, wstrust_endpoint_version, cloud_audience_urn,
                                                          username, password)
        return self._perform_wstrust_assertion_oauth_exchange(wstrust_response)

    def _get_token_username_password_federated(self, username, password):
        """Acquire a token for a federated user via WS-Trust.

        When the user realm has a federation metadata URL, the WS-Trust
        endpoint is discovered through MEX. Otherwise, or if MEX discovery
        fails, the AAD-supplied active-auth URL is used.

        :raises AdalError: if no WS-Trust endpoint can be determined.
        """
        self._log.debug("Acquiring token with username password for federated user")

        cloud_audience_urn = self._user_realm.cloud_audience_urn
        if not self._user_realm.federation_metadata_url:
            self._log.warn("Unable to retrieve federationMetadataUrl from AAD. "
                           "Attempting fallback to AAD supplied endpoint.")

            if not self._user_realm.federation_active_auth_url:
                raise AdalError('AAD did not return a WSTrust endpoint. Unable to proceed.')

            wstrust_version = TokenRequest._parse_wstrust_version_from_federation_active_authurl(
                self._user_realm.federation_active_auth_url)
            self._log.debug(
                'wstrust endpoint version is: %(wstrust_version)s',
                {"wstrust_version": wstrust_version})

            return self._perform_username_password_for_access_token_exchange(
                self._user_realm.federation_active_auth_url,
                wstrust_version, cloud_audience_urn, username, password)

        mex_endpoint = self._user_realm.federation_metadata_url
        self._log.debug(
            "Attempting mex at: %(mex_endpoint)s",
            {"mex_endpoint": mex_endpoint})
        mex_instance = self._create_mex(mex_endpoint)

        try:
            mex_instance.discover()
            wstrust_endpoint = mex_instance.username_password_policy['url']
            wstrust_version = mex_instance.username_password_policy['version']
        except Exception: #pylint: disable=broad-except
            self._log.warn(
                "MEX exchange failed for %(mex_endpoint)s. "
                "Attempting fallback to AAD supplied endpoint.",
                {"mex_endpoint": mex_endpoint})
            wstrust_endpoint = self._user_realm.federation_active_auth_url
            # Validate before parsing: the parser does a substring test and
            # would raise TypeError instead of AdalError on a missing URL.
            if not wstrust_endpoint:
                raise AdalError('AAD did not return a WSTrust endpoint. Unable to proceed.')
            wstrust_version = TokenRequest._parse_wstrust_version_from_federation_active_authurl(
                wstrust_endpoint)

        return self._perform_username_password_for_access_token_exchange(wstrust_endpoint, wstrust_version,
                                                                         cloud_audience_urn,
                                                                         username, password)

    @staticmethod
    def _parse_wstrust_version_from_federation_active_authurl(federation_active_authurl):
        """Infer the WS-Trust version from the path of an active-auth URL.

        Returns WSTRUST2005 or WSTRUST13 when the matching 'usernamemixed' path
        segment is present, and UNDEFINED otherwise. ``federation_active_authurl``
        must be a string.
        """
        if '/trust/2005/usernamemixed' in federation_active_authurl:
            return WSTrustVersion.WSTRUST2005
        if '/trust/13/usernamemixed' in federation_active_authurl:
            return WSTrustVersion.WSTRUST13
        return WSTrustVersion.UNDEFINED

    def get_token_with_username_password(self, username, password):
        """Acquire a token for ``username``/``password``, trying the cache first.

        For non-ADFS authorities, user-realm discovery selects the managed
        (password grant) or federated (WS-Trust) flow. ADFS authorities always
        use the password grant. The acquired token is added to the cache.

        :raises AdalError: for an unknown account type or a failed exchange.
        """
        self._log.debug("Acquiring token with username password.")
        self._user_id = username
        try:
            token = self._find_token_from_cache()
            if token:
                return token
        except AdalError:
            # Cache failures are logged and treated as a cache miss.
            self._log.exception('Attempt to look for token in cache resulted in Error')

        if not self._authentication_context.authority.is_adfs_authority:
            self._user_realm = self._create_user_realm_request(username)
            self._user_realm.discover()

            try:
                if self._user_realm.account_type == ACCOUNT_TYPE['Managed']:
                    token = self._get_token_username_password_managed(username, password)
                elif self._user_realm.account_type == ACCOUNT_TYPE['Federated']:
                    token = self._get_token_username_password_federated(username, password)
                else:
                    raise AdalError(
                        "Server returned an unknown AccountType: {}".format(self._user_realm.account_type))
                self._log.debug("Successfully retrieved token from authority.")
            except Exception:
                self._log.info("get_token_func returned with error")
                raise
        else:
            self._log.info('Skipping user realm discovery for ADFS authority')
            token = self._get_token_username_password_managed(username, password)

        self._cache_driver.add(token)
        return token

    def get_token_with_client_credentials(self, client_secret):
        """Acquire an app-only token with a client secret, trying the cache first."""
        self._log.debug("Getting token with client credentials.")
        try:
            token = self._find_token_from_cache()
            if token:
                return token
        except AdalError:
            self._log.exception('Attempt to look for token in cache resulted in Error')

        oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.CLIENT_CREDENTIALS)
        oauth_parameters[OAUTH2_PARAMETERS.CLIENT_SECRET] = client_secret

        token = self._oauth_get_token(oauth_parameters)
        self._cache_driver.add(token)
        return token

    def get_token_with_authorization_code(self, authorization_code, client_secret, code_verifier):
        """Redeem an authorization code and cache the resulting token.

        ``client_secret`` and ``code_verifier`` are sent only when not None.
        """
        self._log.info("Getting token with auth code.")
        oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.AUTHORIZATION_CODE)
        oauth_parameters[OAUTH2_PARAMETERS.CODE] = authorization_code
        if client_secret is not None:
            oauth_parameters[OAUTH2_PARAMETERS.CLIENT_SECRET] = client_secret
        if code_verifier is not None:
            oauth_parameters[OAUTH2_PARAMETERS.CODE_VERIFIER] = code_verifier
        token = self._oauth_get_token(oauth_parameters)
        self._add_token_into_cache(token)
        return token

    def _get_token_with_refresh_token(self, refresh_token, resource, client_secret):
        """Redeem ``refresh_token``; ``resource``/``client_secret`` override when truthy."""
        self._log.info("Getting a new token from a refresh token")

        oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.REFRESH_TOKEN)
        if resource:
            oauth_parameters[OAUTH2_PARAMETERS.RESOURCE] = resource

        if client_secret:
            oauth_parameters[OAUTH2_PARAMETERS.CLIENT_SECRET] = client_secret

        oauth_parameters[OAUTH2_PARAMETERS.REFRESH_TOKEN] = refresh_token
        return self._oauth_get_token(oauth_parameters)

    def get_token_with_refresh_token(self, refresh_token, client_secret):
        """Redeem ``refresh_token`` for this request's resource (not cached)."""
        return self._get_token_with_refresh_token(refresh_token, None, client_secret)

    def get_token_from_cache_with_refresh(self, user_id):
        """Return the cached token for ``user_id``; the CacheDriver may refresh it."""
        self._log.debug("Getting token from cache with refresh if necessary.")
        self._user_id = user_id
        return self._find_token_from_cache()

    def _create_jwt(self, certificate, thumbprint, public_certificate):
        """Build a signed client-assertion JWT.

        :raises AdalError: if no JWT was produced.
        """
        ssj = self._create_self_signed_jwt()
        jwt = ssj.create(certificate, thumbprint, public_certificate)

        if not jwt:
            raise AdalError("Failed to create JWT.")
        return jwt

    def get_token_with_certificate(self, certificate, thumbprint, public_certificate):
        """Acquire an app-only token with a certificate-signed client assertion.

        The cache is checked first, so no JWT is signed when a cached token is
        usable. A newly acquired token is stored in the cache, as in the
        client-secret flow.
        """
        self._log.info("Getting a token via certificate.")

        try:
            token = self._find_token_from_cache()
            if token:
                return token
        except AdalError:
            self._log.exception('Attempt to look for token in cache resulted in Error')

        jwt = self._create_jwt(certificate, thumbprint, public_certificate)

        oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.CLIENT_CREDENTIALS)
        oauth_parameters[OAUTH2_PARAMETERS.CLIENT_ASSERTION_TYPE] = OAUTH2_GRANT_TYPE.JWT_BEARER
        oauth_parameters[OAUTH2_PARAMETERS.CLIENT_ASSERTION] = jwt

        token = self._oauth_get_token(oauth_parameters)
        self._add_token_into_cache(token)
        return token

    def get_token_with_device_code(self, user_code_info):
        """Poll the token endpoint with a device code until it yields a token.

        ``user_code_info`` must contain the device code, the polling interval
        and the expiry. The resulting token is added to the cache.

        :raises AdalError: if the polling interval is not positive.
        """
        self._log.info("Getting a token via device code")

        oauth_parameters = self._create_oauth_parameters(OAUTH2_GRANT_TYPE.DEVICE_CODE)
        oauth_parameters[OAUTH2_PARAMETERS.CODE] = user_code_info[OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS.DEVICE_CODE]

        interval = user_code_info[OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS.INTERVAL]
        expires_in = user_code_info[OAUTH2_DEVICE_CODE_RESPONSE_PARAMETERS.EXPIRES_IN]

        if interval <= 0:
            raise AdalError('invalid refresh interval')

        client = self._create_oauth2_client()
        # Kept so cancel_token_request_with_device_code() can stop the polling.
        self._polling_client = client

        token = client.get_token_with_polling(oauth_parameters, interval, expires_in)
        self._add_token_into_cache(token)

        return token

    def cancel_token_request_with_device_code(self):
        """Cancel the polling started by get_token_with_device_code().

        ``_polling_client`` is only set by get_token_with_device_code(). Calling
        this before that method has created the client fails because the
        attribute is still None.
        """
        self._polling_client.cancel_polling_request()