Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/msal/oauth2cli/assertion.py: 36%
53 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:20 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:20 +0000
1import time
2import binascii
3import base64
4import uuid
5import logging
# Module-level logger, named after this module, shared by the classes below.
logger = logging.getLogger(__name__)
11def _str2bytes(raw):
12 # A conversion based on duck-typing rather than six.text_type
13 try: # Assuming it is a string
14 return raw.encode(encoding="utf-8")
15 except: # Otherwise we treat it as bytes and return it as-is
16 return raw
class AssertionCreator(object):
    """Base class for assertion creators; sub-classes supply the encoding."""

    def create_normal_assertion(
            self, audience, issuer, subject, expires_at=None, expires_in=600,
            issued_at=None, assertion_id=None, **kwargs):
        """Build one assertion, as bytes, out of the supplied claims.

        The parameter names follow https://tools.ietf.org/html/rfc7521#section-5
        with one exception: expires_in is a lifetime in seconds,
        which gets translated into an expires_at value in UTC.

        Raises:
            NotImplementedError: always; a sub-class must override this.
        """
        raise NotImplementedError("Will be implemented by sub-class")

    def create_regenerative_assertion(
            self, audience, issuer, subject=None, expires_in=600, **kwargs):
        """Return a callable that produces the assertion lazily.

        The callable computes the assertion when it is invoked and keeps
        reusing it, which avoids re-creating the client assertion each time.
        The cached value is refreshed 60 seconds before ``expires_in``
        elapses (never less than 0 seconds).
        """
        def _make_assertion():
            # The claims are captured from the enclosing call.
            return self.create_normal_assertion(
                audience, issuer, subject, expires_in=expires_in, **kwargs)

        refresh_interval = max(expires_in - 60, 0)
        return AutoRefresher(_make_assertion, expires_in=refresh_interval)
class AutoRefresher(object):
    """Wrap a factory, caching its output and refreshing it once expired. Usage::

        r = AutoRefresher(time.time, expires_in=5)
        for i in range(15):
            print(r())  # the timestamp change only after every 5 seconds
            time.sleep(1)
    """

    def __init__(self, factory, expires_in=540):
        self._factory = factory
        self._expires_in = expires_in
        self._buf = {}  # Holds "value" and "expires_at" once populated

    def __call__(self):
        """Return the cached value, calling the factory again if it expired."""
        now = time.time()
        expiry = self._buf.get("expires_at", 0)  # 0 forces the first build
        if expiry <= now:
            logger.debug("Regenerating new assertion")
            fresh = self._factory()
            self._buf = {"value": fresh, "expires_at": now + self._expires_in}
            return fresh
        logger.debug("Reusing still valid assertion")
        return self._buf.get("value")
class JwtAssertionCreator(AssertionCreator):
    """Create assertions in the form of signed JSON Web Tokens."""

    def __init__(self, key, algorithm, sha1_thumbprint=None, headers=None):
        """Construct a Jwt assertion creator.

        Args:

            key (str):
                An unencrypted private key for signing, in a base64 encoded string.
                It can also be a cryptography ``PrivateKey`` object,
                which is how you can work with a previously-encrypted key.
                See also https://github.com/jpadilla/pyjwt/pull/525
            algorithm (str):
                "RS256", etc.. See https://pyjwt.readthedocs.io/en/latest/algorithms.html
                RSA and ECDSA algorithms require "pip install cryptography".
            sha1_thumbprint (str): The x5t aka X.509 certificate SHA-1 thumbprint,
                as a hex string.
            headers (dict): Additional headers, e.g. "kid" or "x5c" etc.
                The dict is copied, so the caller's object is never modified.
        """
        self.key = key
        self.algorithm = algorithm
        # Copy rather than alias: "x5t" may be added below, and that must not
        # leak into a dict the caller still owns (or shares between creators).
        self.headers = dict(headers or {})
        if sha1_thumbprint:  # https://tools.ietf.org/html/rfc7515#section-4.1.7
            self.headers["x5t"] = base64.urlsafe_b64encode(
                binascii.a2b_hex(sha1_thumbprint)).decode()

    def create_normal_assertion(
            self, audience, issuer, subject=None, expires_at=None, expires_in=600,
            issued_at=None, assertion_id=None, not_before=None,
            additional_claims=None, **kwargs):
        """Create a JWT Assertion.

        Parameters are defined in https://tools.ietf.org/html/rfc7523#section-3
        Key-value pairs in additional_claims will be added into payload as-is.

        Args:
            audience: Value of the "aud" claim.
            issuer: Value of the "iss" claim.
            subject: Value of the "sub" claim; falls back to ``issuer``.
            expires_at: Value of the "exp" claim; when falsy it is computed
                as the current time plus ``expires_in``.
            expires_in: Lifetime in seconds, used only without ``expires_at``.
            issued_at: Value of the "iat" claim; falls back to the current time.
            assertion_id: Value of the "jti" claim; falls back to a random UUID.
            not_before: Value of the "nbf" claim; omitted when falsy.
            additional_claims (dict): Extra claims merged into the payload
                last, so they override the claims above on a key clash.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The encoded JWT, normalized into bytes.

        Raises:
            Exception: whatever ``jwt.encode()`` raises is re-raised unchanged.
        """
        import jwt  # Lazy loading
        now = time.time()
        payload = {
            'aud': audience,
            'iss': issuer,
            'sub': subject or issuer,
            'exp': expires_at or (now + expires_in),
            'iat': issued_at or now,
            'jti': assertion_id or str(uuid.uuid4()),
            }
        if not_before:
            payload['nbf'] = not_before
        payload.update(additional_claims or {})
        try:
            str_or_bytes = jwt.encode(  # PyJWT 1 returns bytes, PyJWT 2 returns str
                payload, self.key, algorithm=self.algorithm, headers=self.headers)
            return _str2bytes(str_or_bytes)  # We normalize them into bytes
        except Exception:
            # Only ordinary errors get the installation hint; interrupts and
            # interpreter exits pass through without a misleading log entry.
            if self.algorithm.startswith(("RS", "ES")):
                logger.exception(
                    'Some algorithms requires "pip install cryptography". '
                    'See https://pyjwt.readthedocs.io/en/latest/installation.html#cryptographic-dependencies-optional')
            raise
# Obsolete names, kept only for backward compatibility.
# They will be removed in future versions.
Signer = AssertionCreator
JwtSigner = JwtAssertionCreator
# JwtSigner is the very same class object, so the legacy method name is
# attached through the canonical class name.
JwtAssertionCreator.sign_assertion = JwtAssertionCreator.create_normal_assertion