Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/adal/oauth2_client.py: 15%
225 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:05 +0000
1#------------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation.
4# All rights reserved.
5#
6# This code is licensed under the MIT License.
7#
8# Permission is hereby granted, free of charge, to any person obtaining a copy
9# of this software and associated documentation files(the "Software"), to deal
10# in the Software without restriction, including without limitation the rights
11# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
12# copies of the Software, and to permit persons to whom the Software is
13# furnished to do so, subject to the following conditions :
14#
15# The above copyright notice and this permission notice shall be included in
16# all copies or substantial portions of the Software.
17#
18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
21# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24# THE SOFTWARE.
25#
26#------------------------------------------------------------------------------
28from datetime import datetime, timedelta
29import math
30import re
31import json
32import time
33import uuid
35try:
36 from urllib.parse import urlencode, urlparse
37except ImportError:
38 from urllib import urlencode # pylint: disable=no-name-in-module
39 from urlparse import urlparse # pylint: disable=import-error,ungrouped-imports
41import requests
43from . import log
44from . import util
45from .constants import OAuth2, TokenResponseFields, IdTokenFields
46from .adal_error import AdalError
# Translation table used with map_fields(): each key is a field name as it
# appears in the wire response, each value is the field name exposed in the
# token response handed back to callers. Wire fields not listed here are
# dropped by map_fields().
TOKEN_RESPONSE_MAP = {
    OAuth2.ResponseParameters.TOKEN_TYPE: TokenResponseFields.TOKEN_TYPE,
    OAuth2.ResponseParameters.ACCESS_TOKEN: TokenResponseFields.ACCESS_TOKEN,
    OAuth2.ResponseParameters.REFRESH_TOKEN: TokenResponseFields.REFRESH_TOKEN,
    OAuth2.ResponseParameters.CREATED_ON: TokenResponseFields.CREATED_ON,
    OAuth2.ResponseParameters.EXPIRES_ON: TokenResponseFields.EXPIRES_ON,
    OAuth2.ResponseParameters.EXPIRES_IN: TokenResponseFields.EXPIRES_IN,
    OAuth2.ResponseParameters.RESOURCE: TokenResponseFields.RESOURCE,
    OAuth2.ResponseParameters.ERROR: TokenResponseFields.ERROR,
    OAuth2.ResponseParameters.ERROR_DESCRIPTION: TokenResponseFields.ERROR_DESCRIPTION,
}
60_REQ_OPTION = {'headers' : {'content-type': 'application/x-www-form-urlencoded'}}
61_ERROR_TEMPLATE = u"{} request returned http error: {}"
def map_fields(in_obj, map_to):
    """Re-key the entries of ``in_obj`` through ``map_to``.

    :param in_obj: dict whose entries are to be renamed.
    :param map_to: dict mapping a key of ``in_obj`` to its new name.
    :returns: a new dict containing only the entries of ``in_obj`` whose key
        is present in ``map_to``, stored under ``map_to[key]``. Entries with
        unmapped keys are left out; ``in_obj`` itself is not modified.
    """
    renamed = {}
    for key, value in in_obj.items():
        if key in map_to:
            renamed[map_to[key]] = value
    return renamed
def _get_user_id(id_token):
    """Pick a user identifier out of the decoded id_token claims.

    The claims are tried in the order ``upn``, ``email``, ``sub``; the first
    one holding a truthy value wins. ``upn`` and ``email`` are flagged as
    displayable, ``sub`` is not. When none of them is usable a random UUID
    string is generated instead.

    :param id_token: dict of id_token claims.
    :returns: dict with ``IdTokenFields.USER_ID`` and, only when the id is
        displayable, ``IdTokenFields.IS_USER_ID_DISPLAYABLE`` set to True.
    """
    # (claim name, whether the value is displayable), in order of preference.
    claim_preference = (('upn', True), ('email', True), ('sub', False))

    user_id = None
    is_displayable = False
    for claim, displayable in claim_preference:
        candidate = id_token.get(claim)
        if candidate:
            user_id = candidate
            is_displayable = displayable
            break

    if not user_id:
        # No usable claim: fall back to a freshly generated identifier.
        user_id = str(uuid.uuid4())

    user_id_vals = {IdTokenFields.USER_ID: user_id}
    if is_displayable:
        user_id_vals[IdTokenFields.IS_USER_ID_DISPLAYABLE] = True
    return user_id_vals
def _extract_token_values(id_token):
    """Build the caller-facing values derived from decoded id_token claims.

    :param id_token: dict of id_token claims.
    :returns: the claims re-keyed through ``OAuth2.IdTokenMap`` (unmapped
        claims are dropped), updated with the user id entries produced by
        ``_get_user_id``.
    """
    values = map_fields(id_token, OAuth2.IdTokenMap)
    values.update(_get_user_id(id_token))
    return values
class OAuth2Client(object):
    """Issues token and device-code requests against an authority.

    All requests are form-encoded POSTs sent with ``requests``. Transport
    options (``verify_ssl``, ``proxies``, ``timeout``) and the optional
    ``api_version`` are read from the ``call_context`` dict supplied at
    construction time.
    """

    # A compact JWT: three dot-separated segments without dots or whitespace.
    # Only the middle (payload) segment is required to be non-empty.
    _JWT_PARTS_RE = re.compile(r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$")

    def __init__(self, call_context, authority):
        """Bind the client to an authority.

        :param call_context: dict; must contain ``log_context`` and may
            contain ``api_version``, ``verify_ssl``, ``proxies``, ``timeout``.
        :param authority: object exposing ``token_endpoint`` and
            ``device_code_endpoint`` attributes.
        """
        self._token_endpoint = authority.token_endpoint
        self._device_code_endpoint = authority.device_code_endpoint
        self._log = log.Logger("OAuth2Client", call_context['log_context'])
        self._call_context = call_context
        # Set by cancel_polling_request(); checked before each poll attempt.
        self._cancel_polling_request = False

    def _create_token_url(self):
        """Return the parsed token endpoint URL.

        The ``api_version`` from the call context, when truthy, is appended
        as a query parameter; otherwise the query string is empty.
        """
        parameters = {}
        if self._call_context.get('api_version'):
            parameters[OAuth2.Parameters.AAD_API_VERSION] = self._call_context[
                'api_version']

        return urlparse('{}?{}'.format(self._token_endpoint, urlencode(parameters)))

    def _create_device_code_url(self):
        """Return the parsed device-code endpoint URL, pinned to API version '1.0'."""
        parameters = {}
        parameters[OAuth2.Parameters.AAD_API_VERSION] = '1.0'
        return urlparse('{}?{}'.format(self._device_code_endpoint, urlencode(parameters)))

    def _post(self, url, body, headers):
        """POST ``body`` to ``url`` with the call context's transport options.

        Every request of this client goes through here so that ``verify``,
        ``proxies`` and ``timeout`` are applied uniformly.

        :param url: parsed URL (result of ``urlparse``).
        :param body: url-encoded request body.
        :param headers: header dict to send.
        :returns: the ``requests`` response object.
        """
        return requests.post(url.geturl(),
                             data=body,
                             headers=headers,
                             verify=self._call_context.get('verify_ssl', None),
                             proxies=self._call_context.get('proxies', None),
                             timeout=self._call_context.get('timeout', None))

    def _raise_http_error(self, operation, resp):
        """Raise the error matching an unsuccessful HTTP response.

        :param operation: human-readable operation name used in the message.
        :param resp: the unsuccessful ``requests`` response.
        :raises requests.exceptions.HTTPError: for status 429, so that the
            caller can see the throttling response itself.
        :raises AdalError: for every other status. The message carries the
            status code and, if present, the response text; the second
            argument is the parsed JSON body, or ``""`` when the body is
            empty or not JSON.
        """
        if resp.status_code == 429:
            resp.raise_for_status()  # Will raise requests.exceptions.HTTPError
        return_error_string = _ERROR_TEMPLATE.format(operation, resp.status_code)
        error_response = ""
        if resp.text:
            return_error_string = u"{} and server response: {}".format(return_error_string,
                                                                        resp.text)
            try:
                error_response = resp.json()
            except ValueError:
                # Body is not JSON; the raw text is already in the message.
                pass
        raise AdalError(return_error_string, error_response)

    def _parse_optional_ints(self, obj, keys):
        """Convert ``obj[key]`` to ``int`` in place for each key that is present.

        :param obj: dict to modify.
        :param keys: iterable of keys to convert; absent keys are skipped.
        :raises ValueError: when a present value cannot be parsed as an int
            (logged before being re-raised).
        """
        for key in keys:
            try:
                obj[key] = int(obj[key])
            except ValueError:
                self._log.exception("%(key)s could not be parsed as an int", {"key": key})
                raise
            except KeyError:
                # if the key isn't present we can just continue
                pass

    def _parse_id_token(self, encoded_token):
        """Decode the payload of an id_token and extract its values.

        The signature segment is not inspected here.

        :param encoded_token: the id_token in compact JWT form.
        :returns: dict produced by ``_extract_token_values``, or ``None``
            when the token cannot be split into segments or its payload
            does not base64-url decode.
        :raises ValueError: when the decoded payload is not valid UTF-8 JSON
            (logged before being re-raised).
        """
        cracked_token = self._open_jwt(encoded_token)
        if not cracked_token:
            return

        try:
            b64_id_token = cracked_token['JWSPayload']
            b64_decoded = util.base64_urlsafe_decode(b64_id_token)
            if not b64_decoded:
                self._log.warn('The returned id_token could not be base64 url safe decoded.')
                return

            id_token = json.loads(b64_decoded.decode('utf-8'))
        except ValueError:
            self._log.exception(
                "The returned id_token could not be decoded: %(id_token)s",
                {"id_token": encoded_token})
            raise

        return _extract_token_values(id_token)

    def _open_jwt(self, jwt_token):
        """Split a compact JWT into its three segments.

        :param jwt_token: string of the form ``header.payload.signature``.
        :returns: dict with keys ``header``, ``JWSPayload`` and ``JWSSig``,
            or an empty dict (after logging a warning) when the string does
            not have that shape.
        """
        matches = self._JWT_PARTS_RE.search(jwt_token)
        if not matches:
            self._log.warn('The token was not parsable.')
            return {}

        return {
            'header': matches.group(1),
            'JWSPayload': matches.group(2),
            'JWSSig': matches.group(3)
        }

    def _validate_token_response(self, body):
        """Parse and validate the JSON body of a token response.

        :param body: response text.
        :returns: dict of the wire fields re-keyed through
            ``TOKEN_RESPONSE_MAP``, updated with the values extracted from
            the id_token when one is present and parsable.
        :raises ValueError: when ``body`` is not JSON, or an expiry/creation
            field is not an integer.
        :raises AdalError: when ``token_type`` or ``access_token`` is missing.
        """
        try:
            wire_response = json.loads(body)
        except ValueError:
            self._log.exception(
                'The token response from the server is unparseable as JSON: %(token_response)s',
                {"token_response": body})
            raise

        int_keys = [
            OAuth2.ResponseParameters.EXPIRES_ON,
            OAuth2.ResponseParameters.EXPIRES_IN,
            OAuth2.ResponseParameters.CREATED_ON
        ]

        self._parse_optional_ints(wire_response, int_keys)

        expires_in = wire_response.get(OAuth2.ResponseParameters.EXPIRES_IN)
        if expires_in:
            # expires_on is recomputed from the local clock and replaces
            # whatever value the server sent for it.
            now = datetime.now()
            soon = timedelta(seconds=expires_in)
            wire_response[OAuth2.ResponseParameters.EXPIRES_ON] = str(now + soon)

        created_on = wire_response.get(OAuth2.ResponseParameters.CREATED_ON)
        if created_on:
            # Epoch seconds -> local date-time string.
            temp_date = datetime.fromtimestamp(created_on)
            wire_response[OAuth2.ResponseParameters.CREATED_ON] = str(temp_date)

        if not wire_response.get(OAuth2.ResponseParameters.TOKEN_TYPE):
            raise AdalError('wire_response is missing token_type', wire_response)

        if not wire_response.get(OAuth2.ResponseParameters.ACCESS_TOKEN):
            raise AdalError('wire_response is missing access_token', wire_response)

        token_response = map_fields(wire_response, TOKEN_RESPONSE_MAP)

        if wire_response.get(OAuth2.ResponseParameters.ID_TOKEN):
            id_token = self._parse_id_token(wire_response[OAuth2.ResponseParameters.ID_TOKEN])
            if id_token:
                token_response.update(id_token)

        return token_response

    def _validate_device_code_response(self, body):
        """Parse and validate the JSON body of a device-code response.

        :param body: response text.
        :returns: the parsed response dict, with ``expires_in`` and
            ``interval`` converted to int when present.
        :raises ValueError: when ``body`` is not JSON or an integer field
            cannot be parsed.
        :raises AdalError: when ``expires_in``, ``device_code`` or
            ``user_code`` is missing.
        """
        try:
            wire_response = json.loads(body)
        except ValueError:
            self._log.info('The device code response returned from the server is unparseable as JSON:')
            raise

        int_keys = [
            OAuth2.DeviceCodeResponseParameters.EXPIRES_IN,
            OAuth2.DeviceCodeResponseParameters.INTERVAL
        ]

        self._parse_optional_ints(wire_response, int_keys)

        if not wire_response.get(OAuth2.DeviceCodeResponseParameters.EXPIRES_IN):
            raise AdalError('wire_response is missing expires_in', wire_response)

        if not wire_response.get(OAuth2.DeviceCodeResponseParameters.DEVICE_CODE):
            raise AdalError('wire_response is missing device_code', wire_response)

        if not wire_response.get(OAuth2.DeviceCodeResponseParameters.USER_CODE):
            raise AdalError('wire_response is missing user_code', wire_response)

        #skip field naming tweak, because names from wire are python style already
        return wire_response

    def _handle_get_token_response(self, body):
        """Validate a token response, logging and re-raising any failure."""
        try:
            return self._validate_token_response(body)
        except Exception:
            # NOTE(review): ``body`` is the raw token response and may hold
            # token material; confirm the logger scrubs it before shipping
            # these logs anywhere.
            self._log.exception(
                "Error validating get token response: %(token_response)s",
                {"token_response": body})
            raise

    def _handle_get_device_code_response(self, body):
        """Validate a device-code response, logging and re-raising any failure."""
        try:
            return self._validate_device_code_response(body)
        except Exception:
            self._log.exception(
                "Error validating get user code response: %(token_response)s",
                {"token_response": body})
            raise

    def get_token(self, oauth_parameters):
        """Request a token from the token endpoint.

        :param oauth_parameters: dict of form fields for the request.
        :returns: validated token response dict
            (see ``_validate_token_response``).
        :raises requests.exceptions.HTTPError: on HTTP 429.
        :raises AdalError: on any other unsuccessful status, or when the
            response lacks required fields.
        """
        token_url = self._create_token_url()
        url_encoded_token_request = urlencode(oauth_parameters)
        post_options = util.create_request_options(self, _REQ_OPTION)

        operation = "Get Token"

        try:
            resp = self._post(token_url, url_encoded_token_request, post_options['headers'])
            util.log_return_correlation_id(self._log, operation, resp)
        except Exception:
            self._log.exception("%(operation)s request failed", {"operation": operation})
            raise

        if not util.is_http_success(resp.status_code):
            self._raise_http_error(operation, resp)

        return self._handle_get_token_response(resp.text)

    def get_user_code_info(self, oauth_parameters):
        """Request a device code / user code pair from the device-code endpoint.

        :param oauth_parameters: dict of form fields for the request.
        :returns: validated device-code response dict with an added
            ``correlation_id`` entry taken from the ``client-request-id``
            response header (``None`` when the header is absent).
        :raises requests.exceptions.HTTPError: on HTTP 429.
        :raises AdalError: on any other unsuccessful status, or when the
            response lacks required fields.
        """
        device_code_url = self._create_device_code_url()
        url_encoded_code_request = urlencode(oauth_parameters)

        post_options = util.create_request_options(self, _REQ_OPTION)
        operation = "Get Device Code"
        try:
            resp = self._post(device_code_url, url_encoded_code_request,
                              post_options['headers'])
            util.log_return_correlation_id(self._log, operation, resp)
        except Exception:
            self._log.exception("%(operation)s request failed", {"operation": operation})
            raise

        if not util.is_http_success(resp.status_code):
            self._raise_http_error(operation, resp)

        user_code_info = self._handle_get_device_code_response(resp.text)
        user_code_info['correlation_id'] = resp.headers.get('client-request-id')
        return user_code_info

    def get_token_with_polling(self, oauth_parameters, refresh_internal, expires_in):
        """Poll the token endpoint until the device-code flow completes.

        At most ``floor(expires_in / refresh_internal)`` requests are made,
        sleeping ``refresh_internal`` seconds after each
        ``authorization_pending`` answer.

        Each poll honours the call context's ``timeout`` exactly like
        ``get_token`` does, so a single stalled connection cannot block the
        loop indefinitely.

        :param oauth_parameters: dict of form fields for the request.
        :param refresh_internal: seconds to wait between polls.
        :param expires_in: total seconds the device code stays valid.
        :returns: validated token response dict.
        :raises requests.exceptions.HTTPError: on HTTP 429.
        :raises AdalError: when polling was cancelled, the server reports an
            error other than ``authorization_pending``, the token response is
            invalid, or all attempts were used up.
        """
        token_url = self._create_token_url()
        url_encoded_code_request = urlencode(oauth_parameters)

        post_options = util.create_request_options(self, _REQ_OPTION)

        operation = "Get token with device code"

        max_times_for_retry = math.floor(expires_in/refresh_internal)
        for _ in range(int(max_times_for_retry)):
            if self._cancel_polling_request:
                raise AdalError('Polling_Request_Cancelled')

            resp = self._post(token_url, url_encoded_code_request, post_options['headers'])
            if resp.status_code == 429:
                resp.raise_for_status()  # Will raise requests.exceptions.HTTPError

            util.log_return_correlation_id(self._log, operation, resp)

            wire_response = {}
            if not util.is_http_success(resp.status_code):
                # on error, the body should be json already
                wire_response = json.loads(resp.text)

            error = wire_response.get(OAuth2.DeviceCodeResponseParameters.ERROR)
            if error == 'authorization_pending':
                # The user has not finished signing in yet; try again later.
                time.sleep(refresh_internal)
                continue
            elif error:
                raise AdalError('Unexpected polling state {}'.format(error),
                                wire_response)
            else:
                try:
                    return self._validate_token_response(resp.text)
                except Exception:
                    self._log.exception(
                        u"Error validating get token response %(access_token)s",
                        {"access_token": resp.text})
                    raise

        raise AdalError('Timeout from "get_token_with_polling"')

    def cancel_polling_request(self):
        """Ask a running ``get_token_with_polling`` to stop.

        The flag is checked at the start of each poll attempt, which then
        raises ``AdalError('Polling_Request_Cancelled')``.
        """
        self._cancel_polling_request = True