Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/adal/self_signed_jwt.py: 32%

85 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:05 +0000

1#------------------------------------------------------------------------------ 

2# 

3# Copyright (c) Microsoft Corporation.  

4# All rights reserved. 

5#  

6# This code is licensed under the MIT License. 

7#  

8# Permission is hereby granted, free of charge, to any person obtaining a copy 

9# of this software and associated documentation files(the "Software"), to deal 

10# in the Software without restriction, including without limitation the rights 

11# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell 

12# copies of the Software, and to permit persons to whom the Software is 

13# furnished to do so, subject to the following conditions : 

14#  

15# The above copyright notice and this permission notice shall be included in 

16# all copies or substantial portions of the Software. 

17#  

18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

19# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

20# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE 

21# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

22# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

23# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 

24# THE SOFTWARE. 

25# 

26#------------------------------------------------------------------------------ 

27 

28import time 

29import datetime 

30import uuid 

31import base64 

32import binascii 

33import re 

34 

35import jwt 

36 

37from .constants import Jwt 

38from .log import Logger 

39from .adal_error import AdalError 

40 

def _get_date_now():
    """Return the current local wall-clock time as a naive ``datetime``.

    Isolated in its own helper, presumably so the clock can be substituted
    in one place (e.g. by tests).
    """
    current_time = datetime.datetime.now()
    return current_time

43 

def _get_new_jwt_id():
    """Return a freshly generated random UUID (version 4) as a string.

    Used as the JWT ID claim so every self-signed assertion is unique.
    """
    identifier = uuid.uuid4()
    return str(identifier)

46 

def _create_x5t_value(thumbprint):
    """Convert a hex thumbprint into the text placed in the ``x5t`` header.

    :param thumbprint: certificate thumbprint as a string of hex digits.
    :returns: URL-safe base64 of the decoded bytes, ``=`` padding retained,
        as ``str``.
    :raises ValueError: (``binascii.Error``) if ``thumbprint`` is not valid hex.
    """
    return base64.urlsafe_b64encode(binascii.a2b_hex(thumbprint)).decode()

50 

def _sign_jwt(header, payload, certificate):
    """Sign ``payload`` with ``certificate`` and return the compact JWT string.

    Any exception raised while encoding is wrapped in an ``AdalError`` that
    carries the original exception. A successfully encoded token is then
    checked for a non-empty signature segment.

    :param header: JOSE header dict passed through to the encoder.
    :param payload: claims dict to sign.
    :param certificate: private key material handed to the encoder.
    :returns: the encoded JWT as ``str``.
    :raises AdalError: if encoding fails or the signature segment is missing.
    """
    try:
        signed_token = _encode_jwt(payload, certificate, header)
    except Exception as encode_error:
        # Any encoder failure is reported as a certificate problem.
        raise AdalError("Error:Invalid Certificate: Expected Start of Certificate to be '-----BEGIN RSA PRIVATE KEY-----'", encode_error)
    else:
        # Runs only on successful encoding; errors here are not re-wrapped.
        _raise_on_invalid_jwt_signature(signed_token)
        return signed_token

58 

def _encode_jwt(payload, certificate, header):
    """RS256-sign ``payload`` with PyJWT and return the token as ``str``.

    PyJWT 1.x returns ``bytes`` from ``jwt.encode`` while PyJWT 2.x already
    returns ``str``. Historically this module hands back a string, so
    anything exposing ``decode`` is decoded.
    """
    token = jwt.encode(payload, certificate, algorithm='RS256', headers=header)
    try:
        token_text = token.decode()  # PyJWT 1.x: bytes -> str
    except AttributeError:
        token_text = token  # PyJWT 2.x: already a str
    return token_text

65 

def _raise_on_invalid_jwt_signature(encoded_jwt):
    """Raise ``AdalError`` unless ``encoded_jwt`` has a non-empty third segment.

    The token is split on ``'.'``. A missing or empty third (signature)
    segment is treated as a signing failure.
    """
    parts = encoded_jwt.split('.')
    has_signature = len(parts) >= 3 and bool(parts[2])
    if not has_signature:
        raise AdalError('Failed to sign JWT. This is most likely due to an invalid certificate.')

70 

def _extract_certs(public_cert_content):
    """Split PEM text into the base64 bodies of its public certificates.

    Intended for building an ``x5c`` header, e.g.
    ``headers = {"x5c": _extract_certs(open("my_cert.pem").read())}``.

    :param public_cert_content: raw text of a public certificate file.
    :returns: list of certificate bodies found between
        ``-----BEGIN CERTIFICATE-----`` / ``-----END CERTIFICATE-----``
        markers (matched case-insensitively), each whitespace-stripped. When
        no markers are present, the whole stripped input is returned as the
        single element.
    :raises ValueError: if no certificate markers are present and the text
        contains ``PRIVATE KEY``.
    """
    bodies = [
        found.strip()
        for found in re.findall(
            r'-----BEGIN CERTIFICATE-----(?P<cert_value>[^-]+)-----END CERTIFICATE-----',
            public_cert_content, re.I)
    ]
    if bodies:
        return bodies
    # No certificate markers: make a best-effort check that we were not
    # handed a private key PEM by mistake.
    if "PRIVATE KEY" in public_cert_content:
        raise ValueError(
            "We expect your public key but detect a private key instead")
    return [public_cert_content.strip()]

85 

class SelfSignedJwt(object):
    """Builds and RS256-signs a self-signed JWT for a client.

    The token's issuer and subject are the client id, and its audience is
    the authority's token endpoint. The header identifies the signing
    certificate via ``x5t`` and, optionally, ``x5c``.
    """

    # Accepted thumbprint lengths in hex characters: 128-bit (32 chars) and
    # 160-bit (40 chars, e.g. a SHA-1 thumbprint). Floor division keeps these
    # ints on every Python version; true division made them floats on Py3.
    NumCharIn128BitHexString = 128 // 8 * 2
    numCharIn160BitHexString = 160 // 8 * 2
    # Lower-case hex digits only. ``[0-9a-f]`` rather than ``\d`` so non-ASCII
    # digits are rejected. ``\Z`` rather than ``$`` so a trailing newline is
    # rejected instead of slipping through to ``binascii.a2b_hex``.
    ThumbprintRegEx = r"^[0-9a-f]*\Z"

    def __init__(self, call_context, authority, client_id):
        """Capture the logging context, authority and client id.

        :param call_context: mapping that must provide ``'log_context'``.
        :param authority: object exposing ``token_endpoint`` (used as the
            JWT audience).
        :param client_id: used as both the issuer and subject claims.
        """
        self._log = Logger('SelfSignedJwt', call_context['log_context'])
        self._call_context = call_context

        self._authority = authority
        # Legacy misspelled attribute name, kept in case anything reads it.
        self._authortiy = authority
        self._token_endpoint = authority.token_endpoint
        self._client_id = client_id

    def _create_header(self, thumbprint, public_certificate):
        """Build the JOSE header: ``typ``, ``alg``, ``x5t`` and optional ``x5c``.

        :param thumbprint: canonical (lower-case, separator-free) hex thumbprint.
        :param public_certificate: optional PEM text. When truthy, its
            certificate bodies are attached as ``x5c``.
        :returns: header dict.
        """
        x5t = _create_x5t_value(thumbprint)
        header = {'typ': 'JWT', 'alg': 'RS256', 'x5t': x5t}
        if public_certificate:
            header['x5c'] = _extract_certs(public_certificate)
        self._log.debug("Creating self signed JWT header. x5t: %(x5t)s, x5c: %(x5c)s",
                        {"x5t": x5t, "x5c": public_certificate})

        return header

    def _create_payload(self):
        """Build the claims dict: audience, issuer, subject, nbf, exp and JWT id.

        ``nbf`` is the current time and ``exp`` is
        ``Jwt.SELF_SIGNED_JWT_LIFETIME`` minutes later. Both are integer
        epoch seconds.
        """
        now = _get_date_now()
        lifetime = datetime.timedelta(minutes=Jwt.SELF_SIGNED_JWT_LIFETIME)
        expires = now + lifetime

        self._log.debug(
            'Creating self signed JWT payload. Expires: %(expires)s NotBefore: %(nbf)s',
            {"expires": expires, "nbf": now})

        # ``now`` is naive local time, so convert it to epoch seconds once and
        # derive the expiry by adding the lifetime. Running ``mktime`` on
        # ``expires`` separately is off by an hour whenever a DST transition
        # falls inside the lifetime window.
        not_before = int(time.mktime(now.timetuple()))
        expires_on = not_before + int(lifetime.total_seconds())

        return {
            Jwt.AUDIENCE: self._token_endpoint,
            Jwt.ISSUER: self._client_id,
            Jwt.SUBJECT: self._client_id,
            Jwt.NOT_BEFORE: not_before,
            Jwt.EXPIRES_ON: expires_on,
            Jwt.JWT_ID: _get_new_jwt_id(),
        }

    def _raise_on_invalid_thumbprint(self, thumbprint):
        """Raise ``AdalError`` unless ``thumbprint`` is 32 or 40 lower-case hex chars."""
        thumbprint_sizes = [self.NumCharIn128BitHexString, self.numCharIn160BitHexString]
        size_ok = len(thumbprint) in thumbprint_sizes
        if not size_ok or not re.search(self.ThumbprintRegEx, thumbprint):
            raise AdalError("The thumbprint does not match a known format")

    def _reduce_thumbprint(self, thumbprint):
        """Lower-case ``thumbprint``, strip spaces and colons, then validate it.

        :returns: the canonical thumbprint.
        :raises AdalError: if the canonical form is not a valid thumbprint.
        """
        canonical = thumbprint.lower().replace(' ', '').replace(':', '')
        self._raise_on_invalid_thumbprint(canonical)
        return canonical

    def create(self, certificate, thumbprint, public_certificate):
        """Create and sign the JWT.

        :param certificate: private key material used to sign.
        :param thumbprint: certificate thumbprint in hex. Spaces, colons and
            upper case are tolerated.
        :param public_certificate: optional PEM text used to populate ``x5c``.
        :returns: the signed JWT as ``str``.
        :raises AdalError: on an invalid thumbprint or a signing failure.
        """
        thumbprint = self._reduce_thumbprint(thumbprint)

        header = self._create_header(thumbprint, public_certificate)
        payload = self._create_payload()
        return _sign_jwt(header, payload, certificate)