1import time 
    2import binascii 
    3import base64 
    4import uuid 
    5import logging 
    6 
    7 
# Module-level logger, named after this module; used by the classes below.
logger = logging.getLogger(__name__)
    9 
    10 
    11def _str2bytes(raw): 
    12    # A conversion based on duck-typing rather than six.text_type 
    13    try:  # Assuming it is a string 
    14        return raw.encode(encoding="utf-8") 
    15    except:  # Otherwise we treat it as bytes and return it as-is 
    16        return raw 
    17 
    18def _encode_thumbprint(thumbprint): 
    19    return base64.urlsafe_b64encode(binascii.a2b_hex(thumbprint)).decode() 
    20 
    21class AssertionCreator(object): 
    22    def create_normal_assertion( 
    23            self, audience, issuer, subject, expires_at=None, expires_in=600, 
    24            issued_at=None, assertion_id=None, **kwargs): 
    25        """Create an assertion in bytes, based on the provided claims. 
    26 
    27        All parameter names are defined in https://tools.ietf.org/html/rfc7521#section-5 
    28        except the expires_in is defined here as lifetime-in-seconds, 
    29        which will be automatically translated into expires_at in UTC. 
    30        """ 
    31        raise NotImplementedError("Will be implemented by sub-class") 
    32 
    33    def create_regenerative_assertion( 
    34            self, audience, issuer, subject=None, expires_in=600, **kwargs): 
    35        """Create an assertion as a callable, 
    36        which will then compute the assertion later when necessary. 
    37 
    38        This is a useful optimization to reuse the client assertion. 
    39        """ 
    40        return AutoRefresher(  # Returns a callable 
    41            lambda a=audience, i=issuer, s=subject, e=expires_in, kwargs=kwargs: 
    42                self.create_normal_assertion(a, i, s, expires_in=e, **kwargs), 
    43            expires_in=max(expires_in-60, 0)) 
    44 
    45 
    46class AutoRefresher(object): 
    47    """Cache the output of a factory, and auto-refresh it when necessary. Usage:: 
    48 
    49        r = AutoRefresher(time.time, expires_in=5) 
    50        for i in range(15): 
    51            print(r())  # the timestamp change only after every 5 seconds 
    52            time.sleep(1) 
    53    """ 
    54    def __init__(self, factory, expires_in=540): 
    55        self._factory = factory 
    56        self._expires_in = expires_in 
    57        self._buf = {} 
    58    def __call__(self): 
    59        EXPIRES_AT, VALUE = "expires_at", "value" 
    60        now = time.time() 
    61        if self._buf.get(EXPIRES_AT, 0) <= now: 
    62            logger.debug("Regenerating new assertion") 
    63            self._buf = {VALUE: self._factory(), EXPIRES_AT: now + self._expires_in} 
    64        else: 
    65            logger.debug("Reusing still valid assertion") 
    66        return self._buf.get(VALUE) 
    67 
    68 
    69class JwtAssertionCreator(AssertionCreator): 
    70    def __init__( 
    71        self, key, algorithm, sha1_thumbprint=None, headers=None, 
    72        *, 
    73        sha256_thumbprint=None, 
    74    ): 
    75        """Construct a Jwt assertion creator. 
    76 
    77        Args: 
    78 
    79            key (str): 
    80                An unencrypted private key for signing, in a base64 encoded string. 
    81                It can also be a cryptography ``PrivateKey`` object, 
    82                which is how you can work with a previously-encrypted key. 
    83                See also https://github.com/jpadilla/pyjwt/pull/525 
    84            algorithm (str): 
    85                "RS256", etc.. See https://pyjwt.readthedocs.io/en/latest/algorithms.html 
    86                RSA and ECDSA algorithms require "pip install cryptography". 
    87            sha1_thumbprint (str): The x5t aka X.509 certificate SHA-1 thumbprint. 
    88            headers (dict): Additional headers, e.g. "kid" or "x5c" etc. 
    89            sha256_thumbprint (str): The x5t#S256 aka X.509 certificate SHA-256 thumbprint. 
    90        """ 
    91        self.key = key 
    92        self.algorithm = algorithm 
    93        self.headers = headers or {} 
    94        if sha256_thumbprint:  # https://datatracker.ietf.org/doc/html/rfc7515#section-4.1.8 
    95            self.headers["x5t#S256"] = _encode_thumbprint(sha256_thumbprint) 
    96        if sha1_thumbprint:  # https://tools.ietf.org/html/rfc7515#section-4.1.7 
    97            self.headers["x5t"] = _encode_thumbprint(sha1_thumbprint) 
    98 
    99    def create_normal_assertion( 
    100            self, audience, issuer, subject=None, expires_at=None, expires_in=600, 
    101            issued_at=None, assertion_id=None, not_before=None, 
    102            additional_claims=None, **kwargs): 
    103        """Create a JWT Assertion. 
    104 
    105        Parameters are defined in https://tools.ietf.org/html/rfc7523#section-3 
    106        Key-value pairs in additional_claims will be added into payload as-is. 
    107        """ 
    108        import jwt  # Lazy loading 
    109        now = time.time() 
    110        payload = { 
    111            'aud': audience, 
    112            'iss': issuer, 
    113            'sub': subject or issuer, 
    114            'exp': expires_at or (now + expires_in), 
    115            'iat': issued_at or now, 
    116            'jti': assertion_id or str(uuid.uuid4()), 
    117            } 
    118        if not_before: 
    119            payload['nbf'] = not_before 
    120        payload.update(additional_claims or {}) 
    121        try: 
    122            str_or_bytes = jwt.encode(  # PyJWT 1 returns bytes, PyJWT 2 returns str 
    123                payload, self.key, algorithm=self.algorithm, headers=self.headers) 
    124            return _str2bytes(str_or_bytes)  # We normalize them into bytes 
    125        except: 
    126            if self.algorithm.startswith("RS") or self.algorithm.startswith("ES"): 
    127                logger.exception( 
    128                    'Some algorithms requires "pip install cryptography". ' 
    129                    'See https://pyjwt.readthedocs.io/en/latest/installation.html#cryptographic-dependencies-optional') 
    130            raise 
    131 
    132 
# Obsolete aliases, kept only for backward compatibility.
# They will be removed in future versions.
Signer = AssertionCreator  # For backward compatibility
JwtSigner = JwtAssertionCreator  # For backward compatibility
# JwtSigner is the very same class object as JwtAssertionCreator,
# so this assignment exposes sign_assertion on JwtAssertionCreator as well.
JwtSigner.sign_assertion = JwtAssertionCreator.create_normal_assertion  # For backward compatibility
    137