Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/adal/oauth2_client.py: 15%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

225 statements  

1#------------------------------------------------------------------------------ 

2# 

3# Copyright (c) Microsoft Corporation.  

4# All rights reserved. 

5#  

6# This code is licensed under the MIT License. 

7#  

8# Permission is hereby granted, free of charge, to any person obtaining a copy 

9# of this software and associated documentation files(the "Software"), to deal 

10# in the Software without restriction, including without limitation the rights 

11# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell 

12# copies of the Software, and to permit persons to whom the Software is 

13# furnished to do so, subject to the following conditions : 

14#  

15# The above copyright notice and this permission notice shall be included in 

16# all copies or substantial portions of the Software. 

17#  

18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

19# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

20# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE 

21# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

22# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

23# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 

24# THE SOFTWARE. 

25# 

26#------------------------------------------------------------------------------ 

27 

28from datetime import datetime, timedelta 

29import math 

30import re 

31import json 

32import time 

33import uuid 

34 

35try: 

36 from urllib.parse import urlencode, urlparse 

37except ImportError: 

38 from urllib import urlencode # pylint: disable=no-name-in-module 

39 from urlparse import urlparse # pylint: disable=import-error,ungrouped-imports 

40 

41import requests 

42 

43from . import log 

44from . import util 

45from .constants import OAuth2, TokenResponseFields, IdTokenFields 

46from .adal_error import AdalError 

47 

# Wire-format response parameter -> key used in the token dictionary handed
# back to callers.  Both constant holders expose the same attribute names, so
# the map is built by pairing them up by name, in this order.
TOKEN_RESPONSE_MAP = {
    getattr(OAuth2.ResponseParameters, field): getattr(TokenResponseFields, field)
    for field in (
        'TOKEN_TYPE',
        'ACCESS_TOKEN',
        'REFRESH_TOKEN',
        'CREATED_ON',
        'EXPIRES_ON',
        'EXPIRES_IN',
        'RESOURCE',
        'ERROR',
        'ERROR_DESCRIPTION',
    )
}

# Base request options: every POST issued by this module is form-encoded.
_REQ_OPTION = dict(headers={'content-type': 'application/x-www-form-urlencoded'})

# Filled with (operation name, HTTP status code) when a request fails.
_ERROR_TEMPLATE = u"{} request returned http error: {}"

62 

63 

def map_fields(in_obj, map_to):
    """Return a new dict holding the entries of *in_obj* renamed via *map_to*.

    Only keys of ``in_obj`` that also appear in ``map_to`` are kept; each is
    stored under ``map_to[key]``.  Keys without a mapping are dropped.
    """
    renamed = {}
    for source_key, value in in_obj.items():
        if source_key in map_to:
            renamed[map_to[source_key]] = value
    return renamed

66 

def _get_user_id(id_token):
    """Pick a user identifier out of the decoded *id_token* claims.

    Claims are tried in priority order ``upn``, ``email``, ``sub``; the first
    one with a truthy value wins.  ``upn`` and ``email`` are additionally
    flagged as displayable.  When none of the claims is usable a random UUID
    string is used instead.

    Returns a dict keyed by ``IdTokenFields.USER_ID`` and, only for
    displayable ids, ``IdTokenFields.IS_USER_ID_DISPLAYABLE`` set to True.
    """
    claim_priority = (('upn', True), ('email', True), ('sub', False))

    for claim, displayable in claim_priority:
        candidate = id_token.get(claim)
        if not candidate:
            continue
        found = {IdTokenFields.USER_ID: candidate}
        if displayable:
            found[IdTokenFields.IS_USER_ID_DISPLAYABLE] = True
        return found

    # No usable claim: fall back to a freshly generated identifier.
    return {IdTokenFields.USER_ID: str(uuid.uuid4())}

90 

def _extract_token_values(id_token):
    """Build the token-response fields derived from decoded *id_token* claims.

    The claims are first renamed through ``OAuth2.IdTokenMap`` and the result
    is then updated with the user-id fields from ``_get_user_id``, so the
    latter win on any key collision.
    """
    claims = map_fields(id_token, OAuth2.IdTokenMap)
    user_fields = _get_user_id(id_token)
    claims.update(user_fields)
    return claims

96 

class OAuth2Client(object):
    """Client for an authority's token and device-code endpoints.

    Requests are form-encoded POSTs issued with ``requests``; the JSON bodies
    that come back are validated and reshaped before being returned.
    """

    def __init__(self, call_context, authority):
        """Capture the endpoints of *authority* and the per-call settings.

        ``call_context`` must contain ``'log_context'``.  The keys
        ``'api_version'``, ``'verify_ssl'``, ``'proxies'`` and ``'timeout'``
        are read with ``.get`` by other methods and are therefore optional.
        """
        self._token_endpoint = authority.token_endpoint
        self._device_code_endpoint = authority.device_code_endpoint
        self._log = log.Logger("OAuth2Client", call_context['log_context'])
        self._call_context = call_context
        self._cancel_polling_request = False

    def _create_token_url(self):
        """Return the parsed token endpoint URL, with api-version if configured."""
        parameters = {}
        if self._call_context.get('api_version'):
            parameters[OAuth2.Parameters.AAD_API_VERSION] = self._call_context[
                'api_version']

        return urlparse('{}?{}'.format(self._token_endpoint, urlencode(parameters)))

    def _create_device_code_url(self):
        """Return the parsed device-code endpoint URL, pinned to api-version 1.0."""
        parameters = {}
        parameters[OAuth2.Parameters.AAD_API_VERSION] = '1.0'
        return urlparse('{}?{}'.format(self._device_code_endpoint, urlencode(parameters)))

    def _parse_optional_ints(self, obj, keys):
        """Convert ``obj[key]`` to ``int`` in place for every key that is present.

        Missing keys are skipped.  A value that ``int()`` rejects is logged
        and the original ``ValueError``/``TypeError`` is re-raised.
        """
        for key in keys:
            try:
                obj[key] = int(obj[key])
            except (TypeError, ValueError):
                # TypeError covers non-string, non-numeric values such as a
                # JSON null; log it the same way as an unparsable string.
                self._log.exception("%(key)s could not be parsed as an int", {"key": key})
                raise
            except KeyError:
                # if the key isn't present we can just continue
                pass

    def _parse_id_token(self, encoded_token):
        """Decode the payload of *encoded_token* and extract its fields.

        Returns ``None`` when the token cannot be split into JWT parts or the
        payload does not base64-decode; otherwise the dict produced by
        ``_extract_token_values``.  A payload that is not valid UTF-8 JSON is
        logged and the ``ValueError`` is re-raised.
        """
        cracked_token = self._open_jwt(encoded_token)
        if not cracked_token:
            return

        try:
            b64_id_token = cracked_token['JWSPayload']
            b64_decoded = util.base64_urlsafe_decode(b64_id_token)
            if not b64_decoded:
                self._log.warn('The returned id_token could not be base64 url safe decoded.')
                return

            id_token = json.loads(b64_decoded.decode('utf-8'))
        except ValueError:
            self._log.exception(
                "The returned id_token could not be decoded: %(id_token)s",
                {"id_token": encoded_token})
            raise

        return _extract_token_values(id_token)

    def _open_jwt(self, jwt_token):
        """Split *jwt_token* into its three dot-separated segments.

        Returns a dict with ``'header'``, ``'JWSPayload'`` and ``'JWSSig'``,
        or an empty dict (after logging a warning) when the token does not
        have exactly three segments with a non-empty payload.
        """
        id_token_parts_reg = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$"
        matches = re.search(id_token_parts_reg, jwt_token)
        # The pattern has exactly three groups, so a match always yields all
        # three; only the no-match case needs handling.
        if not matches:
            self._log.warn('The token was not parsable.')
            return {}

        return {
            'header': matches.group(1),
            'JWSPayload': matches.group(2),
            'JWSSig': matches.group(3)
        }

    def _validate_token_response(self, body):
        """Parse and validate the JSON *body* of a token response.

        ``expires_on``, ``expires_in`` and ``created_on`` are coerced to int
        when present.  A truthy ``expires_in`` replaces ``expires_on`` with
        the string form of ``datetime.now()`` plus that many seconds, and a
        truthy ``created_on`` is replaced by the string form of the matching
        local datetime.  The fields are then renamed via
        ``TOKEN_RESPONSE_MAP`` and merged with whatever can be extracted from
        an ``id_token``, if one is present.

        Raises ``ValueError`` for unparsable JSON or integer fields and
        ``AdalError`` when ``token_type`` or ``access_token`` is missing.
        """
        try:
            wire_response = json.loads(body)
        except ValueError:
            self._log.exception(
                'The token response from the server is unparseable as JSON: %(token_response)s',
                {"token_response": body})
            raise

        int_keys = [
            OAuth2.ResponseParameters.EXPIRES_ON,
            OAuth2.ResponseParameters.EXPIRES_IN,
            OAuth2.ResponseParameters.CREATED_ON
        ]

        self._parse_optional_ints(wire_response, int_keys)

        expires_in = wire_response.get(OAuth2.ResponseParameters.EXPIRES_IN)
        if expires_in:
            now = datetime.now()
            soon = timedelta(seconds=expires_in)
            wire_response[OAuth2.ResponseParameters.EXPIRES_ON] = str(now + soon)

        created_on = wire_response.get(OAuth2.ResponseParameters.CREATED_ON)
        if created_on:
            temp_date = datetime.fromtimestamp(created_on)
            wire_response[OAuth2.ResponseParameters.CREATED_ON] = str(temp_date)

        if not wire_response.get(OAuth2.ResponseParameters.TOKEN_TYPE):
            raise AdalError('wire_response is missing token_type', wire_response)

        if not wire_response.get(OAuth2.ResponseParameters.ACCESS_TOKEN):
            raise AdalError('wire_response is missing access_token', wire_response)

        token_response = map_fields(wire_response, TOKEN_RESPONSE_MAP)

        if wire_response.get(OAuth2.ResponseParameters.ID_TOKEN):
            id_token = self._parse_id_token(wire_response[OAuth2.ResponseParameters.ID_TOKEN])
            if id_token:
                token_response.update(id_token)

        return token_response

    def _validate_device_code_response(self, body):
        """Parse and validate the JSON *body* of a device-code response.

        ``expires_in`` and ``interval`` are coerced to int when present.
        Returns the parsed dict unchanged otherwise.

        Raises ``ValueError`` for unparsable JSON or integer fields and
        ``AdalError`` when ``expires_in``, ``device_code`` or ``user_code``
        is missing or falsy.
        """
        try:
            wire_response = json.loads(body)
        except ValueError:
            self._log.info('The device code response returned from the server is unparseable as JSON:')
            raise

        int_keys = [
            OAuth2.DeviceCodeResponseParameters.EXPIRES_IN,
            OAuth2.DeviceCodeResponseParameters.INTERVAL
        ]

        self._parse_optional_ints(wire_response, int_keys)

        if not wire_response.get(OAuth2.DeviceCodeResponseParameters.EXPIRES_IN):
            raise AdalError('wire_response is missing expires_in', wire_response)

        if not wire_response.get(OAuth2.DeviceCodeResponseParameters.DEVICE_CODE):
            raise AdalError('wire_response is missing device_code', wire_response)

        if not wire_response.get(OAuth2.DeviceCodeResponseParameters.USER_CODE):
            raise AdalError('wire_response is missing user_code', wire_response)

        #skip field naming tweak, because names from wire are python style already
        return wire_response

    def _handle_get_token_response(self, body):
        """Validate a token response *body*, logging any failure before re-raising."""
        try:
            return self._validate_token_response(body)
        except Exception:
            self._log.exception(
                "Error validating get token response: %(token_response)s",
                {"token_response": body})
            raise

    def _handle_get_device_code_response(self, body):
        """Validate a device-code response *body*, logging any failure before re-raising."""
        try:
            return self._validate_device_code_response(body)
        except Exception:
            self._log.exception(
                "Error validating get user code response: %(token_response)s",
                {"token_response": body})
            raise

    def _post(self, url, data, headers):
        """POST the url-encoded *data* to the parsed *url*.

        TLS verification, proxies and timeout all come from the call context
        so that every request made by this client honours the same settings.
        """
        return requests.post(url.geturl(),
                             data=data,
                             headers=headers,
                             verify=self._call_context.get('verify_ssl', None),
                             proxies=self._call_context.get('proxies', None),
                             timeout=self._call_context.get('timeout', None))

    def _raise_http_error(self, operation, resp):
        """Raise the error matching a non-success response *resp*.

        A 429 is surfaced as ``requests.exceptions.HTTPError``.  Anything
        else becomes an ``AdalError`` whose message includes the status code
        and, when there is one, the response text; its second argument is the
        JSON-decoded body, or ``""`` when the body is empty or not JSON.
        """
        if resp.status_code == 429:
            resp.raise_for_status() # Will raise requests.exceptions.HTTPError
        return_error_string = _ERROR_TEMPLATE.format(operation, resp.status_code)
        error_response = ""
        if resp.text:
            return_error_string = u"{} and server response: {}".format(return_error_string,
                                                                       resp.text)
            try:
                error_response = resp.json()
            except ValueError:
                # Non-JSON error body: the raw text is already in the message.
                pass
        raise AdalError(return_error_string, error_response)

    def get_token(self, oauth_parameters):
        """POST *oauth_parameters* to the token endpoint and return the token dict.

        Raises ``AdalError`` for a non-success status (or
        ``requests.exceptions.HTTPError`` for 429); transport and validation
        errors are logged and re-raised unchanged.
        """
        token_url = self._create_token_url()
        url_encoded_token_request = urlencode(oauth_parameters)
        post_options = util.create_request_options(self, _REQ_OPTION)

        operation = "Get Token"

        try:
            resp = self._post(token_url, url_encoded_token_request, post_options['headers'])
            util.log_return_correlation_id(self._log, operation, resp)
        except Exception:
            self._log.exception("%(operation)s request failed", {"operation": operation})
            raise

        if util.is_http_success(resp.status_code):
            return self._handle_get_token_response(resp.text)
        self._raise_http_error(operation, resp)

    def get_user_code_info(self, oauth_parameters):
        """POST *oauth_parameters* to the device-code endpoint.

        Returns the validated response dict with an added
        ``'correlation_id'`` entry taken from the ``client-request-id``
        response header.  Error behaviour matches ``get_token``.
        """
        device_code_url = self._create_device_code_url()
        url_encoded_code_request = urlencode(oauth_parameters)

        post_options = util.create_request_options(self, _REQ_OPTION)
        operation = "Get Device Code"
        try:
            resp = self._post(device_code_url, url_encoded_code_request, post_options['headers'])
            util.log_return_correlation_id(self._log, operation, resp)
        except Exception:
            self._log.exception("%(operation)s request failed", {"operation": operation})
            raise

        if util.is_http_success(resp.status_code):
            user_code_info = self._handle_get_device_code_response(resp.text)
            user_code_info['correlation_id'] = resp.headers.get('client-request-id')
            return user_code_info
        self._raise_http_error(operation, resp)

    def get_token_with_polling(self, oauth_parameters, refresh_internal, expires_in):
        """Poll the token endpoint until the device-code flow completes.

        The request is repeated up to ``floor(expires_in / refresh_internal)``
        times, sleeping ``refresh_internal`` seconds after every
        ``authorization_pending`` answer.

        Returns the validated token dict on success.  Raises ``AdalError``
        when polling was cancelled, the server reports any other error, or
        the attempts run out; raises ``requests.exceptions.HTTPError`` on 429.
        """
        token_url = self._create_token_url()
        url_encoded_code_request = urlencode(oauth_parameters)

        post_options = util.create_request_options(self, _REQ_OPTION)

        operation = "Get token with device code"

        max_times_for_retry = math.floor(expires_in/refresh_internal)
        for _ in range(int(max_times_for_retry)):
            if self._cancel_polling_request:
                raise AdalError('Polling_Request_Cancelled')

            # Goes through _post so the configured timeout applies here too;
            # without it a stalled connection would block polling forever.
            resp = self._post(token_url, url_encoded_code_request, post_options['headers'])
            if resp.status_code == 429:
                resp.raise_for_status() # Will raise requests.exceptions.HTTPError

            util.log_return_correlation_id(self._log, operation, resp)

            wire_response = {}
            if not util.is_http_success(resp.status_code):
                # on error, the body should be json already
                wire_response = json.loads(resp.text)

            error = wire_response.get(OAuth2.DeviceCodeResponseParameters.ERROR)
            if error == 'authorization_pending':
                time.sleep(refresh_internal)
                continue
            elif error:
                raise AdalError('Unexpected polling state {}'.format(error),
                                wire_response)
            else:
                try:
                    return self._validate_token_response(resp.text)
                except Exception:
                    self._log.exception(
                        u"Error validating get token response %(access_token)s",
                        {"access_token": resp.text})
                    raise

        raise AdalError('Timeout from "get_token_with_polling"')

    def cancel_polling_request(self):
        """Ask a running ``get_token_with_polling`` loop to stop at its next iteration."""
        self._cancel_polling_request = True

376