Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/msal/oauth2cli/oidc.py: 37%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

108 statements  

1import json 

2import base64 

3import time 

4import secrets 

5import warnings 

6import hashlib 

7import logging 

8 

9from . import oauth2 

10 

11 

# Module-wide logger; _IdTokenTimeError.log() emits its warnings through it.
logger = logging.getLogger(name=__name__)

13 

def decode_part(raw, encoding="utf-8"):
    """Decode a part of the JWT.

    JWT is encoded by padding-less base64url,
    based on `JWS specs <https://tools.ietf.org/html/rfc7515#appendix-C>`_.

    :param raw:
        One dot-separated segment of a JWT, given as ``str`` or as ASCII
        ``bytes``. Missing ``=`` padding is restored automatically.
    :param encoding:
        If you are going to decode the first 2 parts of a JWT, i.e. the header
        or the payload, the default value "utf-8" would work fine.
        If you are going to decode the last part i.e. the signature part,
        it is a binary string so you should use `None` as encoding here.
    :return:
        A ``str`` when ``encoding`` is truthy, otherwise the decoded ``bytes``.
    :raises binascii.Error: if ``raw`` cannot be base64url-decoded.
    """
    if isinstance(raw, (bytes, bytearray)):
        # base64url text is pure ASCII; normalize to str so the padding
        # arithmetic below works for both input types.
        raw = bytes(raw).decode("ascii")
    # JWT strips the trailing "=" padding; put it back before decoding.
    # https://stackoverflow.com/a/32517907/728675
    raw += "=" * (-len(raw) % 4)
    output = base64.urlsafe_b64decode(raw)
    if encoding:
        output = output.decode(encoding)
    return output


base64decode = decode_part  # Obsolete. For backward compatibility only.

37 

def _epoch_to_local(epoch):
    """Render a POSIX epoch as a ``YYYY-mm-dd HH:MM:SS`` string in local time."""
    local_struct = time.localtime(epoch)
    return time.strftime("%Y-%m-%d %H:%M:%S", local_struct)

40 

class IdTokenError(RuntimeError):  # RuntimeError was raised before, so keep it as the base
    """Raised in the unlikely event that an ID token is malformed or fails validation.

    The message combines the given reason, the supplied "now" epoch rendered
    as local time, and a pretty-printed copy of the claims in which "iat" and
    "exp" are rendered as local timestamps (or None when absent/falsy).
    """

    def __init__(self, reason, now, claims):
        local_now = _epoch_to_local(now)
        issued_at = _epoch_to_local(claims["iat"]) if claims.get("iat") else None
        expires_at = _epoch_to_local(claims["exp"]) if claims.get("exp") else None
        # Copy first, then overwrite: existing keys keep their position and
        # missing ones are appended in iat, exp order.
        readable_claims = dict(claims)
        readable_claims["iat"] = issued_at
        readable_claims["exp"] = expires_at
        pretty_claims = json.dumps(readable_claims, indent=2)
        super(IdTokenError, self).__init__(
            "%s Current epoch = %s. The id_token was approximately: %s" % (
                reason, local_now, pretty_claims))

51 

class _IdTokenTimeError(IdTokenError):  # This is not intended to be raised and caught
    """A time-related ID token problem that is reported via :meth:`log` only.

    The reason text is suffixed with a hint about checking the local clock.
    """

    _SUGGESTION = "Make sure your computer's time and time zone are both correct."

    def __init__(self, reason, now, claims):
        reason_with_hint = reason + " " + self._SUGGESTION
        super(_IdTokenTimeError, self).__init__(reason_with_hint, now, claims)

    def log(self):
        """Report this error as a warning on the module logger instead of raising it.

        Influenced by JWT specs https://tools.ietf.org/html/rfc7519#section-4.1.5
        and OIDC specs https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation

        This used to be raised; it is now only logged because:

        1. If the local machine's clock is wrong, the token(s) are still
           correct and probably functioning, so erroring out gains nothing.
        2. If the IdP's clock is wrong, that is the IdP's fault and the client
           cannot fix it, so the token(s) are returned and downstream
           components decide what to do.
        """
        message = str(self)
        logger.warning(message)

67 

class IdTokenIssuerError(IdTokenError):
    """Raised when the ID token's "iss" claim differs from the expected issuer."""

70 

class IdTokenAudienceError(IdTokenError):
    """Raised when the expected client_id is not found in the ID token's "aud" claim."""

73 

class IdTokenNonceError(IdTokenError):
    """Raised when the ID token's "nonce" claim differs from the expected nonce."""

76 

def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None):
    """Decodes and validates an id_token and returns its claims as a dictionary.

    ID token claims would at least contain: "iss", "sub", "aud", "exp", "iat",
    per `specs <https://openid.net/specs/openid-connect-core-1_0.html#IDToken>`_
    and it may contain other optional content such as "preferred_username",
    `maybe more <https://openid.net/specs/openid-connect-core-1_0.html#Claims>`_

    The token's signature is NOT checked here; only the payload segment is
    decoded and its claims validated.

    :param id_token: A compact JWT string; its second dot-separated part is decoded.
    :param client_id: If given, must equal "aud", or be contained in it when "aud" is a list.
    :param issuer: If given, must exactly equal the "iss" claim.
    :param nonce: If given, must equal the "nonce" claim.
    :param now: Epoch seconds to validate against. Defaults to the current time.
        An explicit ``0`` is honored rather than treated as "not provided".
    :return: The decoded claims dictionary.
    :raises IdTokenIssuerError: on an issuer mismatch.
    :raises IdTokenAudienceError: when client_id is not in the audience.
    :raises IdTokenNonceError: on a nonce mismatch.

    Time problems (not yet valid per "nbf", or already past "exp", each with
    2 minutes of allowed skew) are only logged as warnings, never raised.
    """
    decoded = json.loads(decode_part(id_token.split('.')[1]))
    # Based on https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
    # Compare with None so that a caller-supplied now=0 is not silently replaced.
    _now = int(time.time() if now is None else now)
    skew = 120  # 2 minutes

    if _now + skew < decoded.get("nbf", _now - 1):  # nbf is optional per JWT specs
        # This is not an ID token validation, but a JWT validation
        # https://tools.ietf.org/html/rfc7519#section-4.1.5
        _IdTokenTimeError("0. The ID token is not yet valid.", _now, decoded).log()

    if issuer and issuer != decoded["iss"]:
        # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationResponse
        raise IdTokenIssuerError(
            '2. The Issuer Identifier for the OpenID Provider, "%s", '
            "(which is typically obtained during Discovery), "
            "MUST exactly match the value of the iss (issuer) Claim." % issuer,
            _now,
            decoded)

    if client_id:
        valid_aud = client_id in decoded["aud"] if isinstance(
            decoded["aud"], list) else client_id == decoded["aud"]
        if not valid_aud:
            raise IdTokenAudienceError(
                "3. The aud (audience) claim must contain this client's client_id "
                '"%s", case-sensitively. Was your client_id in wrong casing?'
                # Some IdP accepts wrong casing request but issues right casing IDT
                % client_id,
                _now,
                decoded)

    # Per specs:
    # 6. If the ID Token is received via direct communication between
    # the Client and the Token Endpoint (which it is during _obtain_token()),
    # the TLS server validation MAY be used to validate the issuer
    # in place of checking the token signature.

    if _now - skew > decoded["exp"]:
        _IdTokenTimeError("9. The ID token already expires.", _now, decoded).log()

    if nonce and nonce != decoded.get("nonce"):
        raise IdTokenNonceError(
            "11. Nonce must be the same value "
            "as the one that was sent in the Authentication Request.",
            _now,
            decoded)

    return decoded

133 

134 

def _nonce_hash(nonce):
    """Return the hex SHA-256 digest of the ASCII string ``nonce``.

    See https://openid.net/specs/openid-connect-core-1_0.html#NonceNotes
    """
    hasher = hashlib.sha256()
    hasher.update(nonce.encode("ascii"))
    return hasher.hexdigest()

138 

139 

class Prompt(object):
    """String constants usable as the OIDC ``prompt`` request parameter.

    Values come from
    https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest,
    except ``CREATE``, which is defined by a separate extension spec.
    """

    # Core OIDC values
    NONE = "none"
    LOGIN = "login"
    CONSENT = "consent"
    SELECT_ACCOUNT = "select_account"

    # https://openid.net/specs/openid-connect-prompt-create-1_0.html#PromptParameter
    CREATE = "create"

151 

152 

class Client(oauth2.Client):
    """OpenID Connect is a layer on top of the OAuth2.

    See its specs at https://openid.net/connect/
    """

    def decode_id_token(self, id_token, nonce=None):
        """See :func:`~decode_id_token`.

        This client's ``client_id`` is used as the expected audience, and the
        "issuer" entry of its configuration (if any) as the expected issuer.
        """
        return decode_id_token(
            id_token, nonce=nonce,
            client_id=self.client_id, issuer=self.configuration.get("issuer"))

    def _obtain_token(self, grant_type, *args, **kwargs):
        """The result will also contain one more key "id_token_claims",
        whose value will be a dictionary returned by :func:`~decode_id_token`.

        The extra key is added only when the response contains an "id_token".
        """
        ret = super(Client, self)._obtain_token(grant_type, *args, **kwargs)
        if "id_token" in ret:
            ret["id_token_claims"] = self.decode_id_token(ret["id_token"])
        return ret

    def build_auth_request_uri(self, response_type, nonce=None, **kwargs):
        """Generate an authorization uri to be visited by resource owner.

        Deprecated: emits a DeprecationWarning; use
        :func:`initiate_auth_code_flow` instead.

        Return value and all other parameters are the same as
        :func:`oauth2.Client.build_auth_request_uri`, plus new parameter(s):

        :param nonce:
            A hard-to-guess string used to mitigate replay attacks. See also
            `OIDC specs <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        """
        # stacklevel=2 attributes the warning to the caller, not to this line.
        warnings.warn(
            "Use initiate_auth_code_flow() instead", DeprecationWarning,
            stacklevel=2)
        return super(Client, self).build_auth_request_uri(
            response_type, nonce=nonce, **kwargs)

    def obtain_token_by_authorization_code(self, code, nonce=None, **kwargs):
        """Get a token via authorization code. a.k.a. Authorization Code Grant.

        Deprecated: emits a DeprecationWarning; use
        :func:`obtain_token_by_auth_code_flow` instead.

        Return value and all other parameters are the same as
        :func:`oauth2.Client.obtain_token_by_authorization_code`,
        plus new parameter(s):

        :param nonce:
            If you provided a nonce when calling :func:`build_auth_request_uri`,
            same nonce should also be provided here, so that we'll validate it.
            An exception will be raised if the nonce in id token mismatches.
        :raises ValueError: when an ID token is present and its nonce differs.
        """
        warnings.warn(
            "Use obtain_token_by_auth_code_flow() instead", DeprecationWarning,
            stacklevel=2)
        result = super(Client, self).obtain_token_by_authorization_code(
            code, **kwargs)
        nonce_in_id_token = result.get("id_token_claims", {}).get("nonce")
        if "id_token_claims" in result and nonce and nonce != nonce_in_id_token:
            raise ValueError(
                'The nonce in id token ("%s") should match your nonce ("%s")' %
                (nonce_in_id_token, nonce))
        return result

    def initiate_auth_code_flow(
            self,
            scope=None,
            **kwargs):
        """Initiate an auth code flow.

        It provides nonce protection automatically: a random nonce is stored
        in the returned flow under "nonce", while only its SHA-256 hash is
        sent in the request. When ``max_age`` is given, it is also stored in
        the flow so that :func:`obtain_token_by_auth_code_flow` can enforce it.

        :param list scope:
            A list of strings, e.g. ["profile", "email", ...].
            This method will automatically send ["openid"] to the wire,
            although it won't modify your input list.

        See :func:`oauth2.Client.initiate_auth_code_flow` in parent class
        for descriptions on other parameters and return value.

        :raises ValueError: if ``response_type`` requests an "id_token".
        """
        # `or ""` keeps an explicit response_type=None from crashing `in`.
        if "id_token" in (kwargs.get("response_type") or ""):
            # Implicit grant would cause auth response coming back in #fragment,
            # but fragment won't reach a web service.
            raise ValueError('response_type="id_token ..." is not allowed')
        _scope = list(scope) if scope else []  # We won't modify input parameter
        if "openid" not in _scope:
            # "If no openid scope value is present,
            # the request may still be a valid OAuth 2.0 request,
            # but is not an OpenID Connect request." -- OIDC Core Specs, 3.1.2.2
            # https://openid.net/specs/openid-connect-core-1_0.html#AuthRequestValidation
            # Here we just automatically add it. If the caller does not want
            # an id_token, they should simply go with oauth2.Client.
            _scope.append("openid")
        nonce = secrets.token_urlsafe(16)
        flow = super(Client, self).initiate_auth_code_flow(
            scope=_scope, nonce=_nonce_hash(nonce), **kwargs)
        flow["nonce"] = nonce
        if kwargs.get("max_age") is not None:
            flow["max_age"] = kwargs["max_age"]
        return flow

    def obtain_token_by_auth_code_flow(self, auth_code_flow, auth_response, **kwargs):
        """Validate the auth_response being redirected back, and then obtain tokens,
        including ID token which can be used for user sign in.

        Internally, it implements nonce to mitigate replay attack.
        It also implements PKCE to mitigate the auth code interception attack.

        See :func:`oauth2.Client.obtain_token_by_auth_code_flow` in parent class
        for descriptions on other parameters and return value.

        :raises RuntimeError:
            if the ID token's nonce does not match the hash of the flow's nonce,
            or if ``max_age`` was requested and the ID token either lacks
            "auth_time" or its "auth_time" is too old (2 minutes of skew allowed).
        """
        result = super(Client, self).obtain_token_by_auth_code_flow(
            auth_code_flow, auth_response, **kwargs)
        if "id_token_claims" in result:
            nonce_in_id_token = result.get("id_token_claims", {}).get("nonce")
            expected_hash = _nonce_hash(auth_code_flow["nonce"])
            if nonce_in_id_token != expected_hash:
                raise RuntimeError(
                    'The nonce in id token ("%s") should match our nonce ("%s")' %
                    (nonce_in_id_token, expected_hash))

            if auth_code_flow.get("max_age") is not None:
                auth_time = result.get("id_token_claims", {}).get("auth_time")
                if not auth_time:
                    raise RuntimeError(
                        "13. max_age was requested, ID token should contain auth_time")
                now = int(time.time())
                skew = 120  # 2 minutes. Hardcoded, for now
                if now - skew > auth_time + auth_code_flow["max_age"]:
                    raise RuntimeError(
                        "13. auth_time ({auth_time}) was requested, "
                        "by using max_age ({max_age}) parameter, "
                        "and now ({now}) too much time has elapsed "
                        "since last end-user authentication. "
                        "The ID token was: {id_token}".format(
                            auth_time=auth_time,
                            max_age=auth_code_flow["max_age"],
                            now=now,
                            id_token=json.dumps(result["id_token_claims"], indent=2),
                        ))
        return result

    def obtain_token_by_browser(
            self,
            display=None,
            prompt=None,
            max_age=None,
            ui_locales=None,
            id_token_hint=None,  # It is relevant,
                # because this library exposes raw ID token
            login_hint=None,
            acr_values=None,
            **kwargs):
        """A native app can use this method to obtain token via a local browser.

        Internally, it implements nonce to mitigate replay attack.
        It also implements PKCE to mitigate the auth code interception attack.

        :param string display: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        :param string prompt: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
            You can find the valid string values defined in :class:`oidc.Prompt`.
            A list or tuple of values is joined with spaces.

        :param int max_age: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        :param string ui_locales: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        :param string id_token_hint: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        :param string login_hint: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.
        :param string acr_values: Defined in
            `OIDC <https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest>`_.

        Parameters left as None are not sent. The rest are merged over any
        caller-supplied ``auth_params`` (None is treated as empty).

        See :func:`oauth2.Client.obtain_token_by_browser` in parent class
        for descriptions on other parameters and return value.
        """
        filtered_params = {k: v for k, v in dict(
            prompt=" ".join(prompt) if isinstance(prompt, (list, tuple)) else prompt,
            display=display,
            max_age=max_age,
            ui_locales=ui_locales,
            id_token_hint=id_token_hint,
            login_hint=login_hint,
            acr_values=acr_values,
            ).items() if v is not None}  # Filter out None values
        # `or {}` tolerates an explicit auth_params=None from the caller.
        base_auth_params = kwargs.pop("auth_params", None) or {}
        return super(Client, self).obtain_token_by_browser(
            auth_params=dict(base_auth_params, **filtered_params),
            **kwargs)

337