1import logging 
    2 
    3from oauthlib.common import generate_token, urldecode 
    4from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError 
    5from oauthlib.oauth2 import LegacyApplicationClient 
    6from oauthlib.oauth2 import TokenExpiredError, is_secure_transport 
    7import requests 
    8 
# Module-level logger; OAuth2Session emits its request/response/token
# diagnostics through log.debug on this logger.
log = logging.getLogger(__name__)
    10 
    11 
    12class TokenUpdated(Warning): 
    13    def __init__(self, token): 
    14        super(TokenUpdated, self).__init__() 
    15        self.token = token 
    16 
    17 
    18class OAuth2Session(requests.Session): 
    19    """Versatile OAuth 2 extension to :class:`requests.Session`. 
    20 
    21    Supports any grant type adhering to :class:`oauthlib.oauth2.Client` spec 
    22    including the four core OAuth 2 grants. 
    23 
    24    Can be used to create authorization urls, fetch tokens and access protected 
    25    resources using the :class:`requests.Session` interface you are used to. 
    26 
    27    - :class:`oauthlib.oauth2.WebApplicationClient` (default): Authorization Code Grant 
    28    - :class:`oauthlib.oauth2.MobileApplicationClient`: Implicit Grant 
    29    - :class:`oauthlib.oauth2.LegacyApplicationClient`: Password Credentials Grant 
    30    - :class:`oauthlib.oauth2.BackendApplicationClient`: Client Credentials Grant 
    31 
    32    Note that the only time you will be using Implicit Grant from python is if 
    33    you are driving a user agent able to obtain URL fragments. 
    34    """ 
    35 
    36    def __init__( 
    37        self, 
    38        client_id=None, 
    39        client=None, 
    40        auto_refresh_url=None, 
    41        auto_refresh_kwargs=None, 
    42        scope=None, 
    43        redirect_uri=None, 
    44        token=None, 
    45        state=None, 
    46        token_updater=None, 
    47        pkce=None, 
    48        **kwargs 
    49    ): 
    50        """Construct a new OAuth 2 client session. 
    51 
    52        :param client_id: Client id obtained during registration 
    53        :param client: :class:`oauthlib.oauth2.Client` to be used. Default is 
    54                       WebApplicationClient which is useful for any 
    55                       hosted application but not mobile or desktop. 
    56        :param scope: List of scopes you wish to request access to 
    57        :param redirect_uri: Redirect URI you registered as callback 
    58        :param token: Token dictionary, must include access_token 
    59                      and token_type. 
    60        :param state: State string used to prevent CSRF. This will be given 
    61                      when creating the authorization url and must be supplied 
    62                      when parsing the authorization response. 
    63                      Can be either a string or a no argument callable. 
    64        :auto_refresh_url: Refresh token endpoint URL, must be HTTPS. Supply 
    65                           this if you wish the client to automatically refresh 
    66                           your access tokens. 
    67        :auto_refresh_kwargs: Extra arguments to pass to the refresh token 
    68                              endpoint. 
    69        :token_updater: Method with one argument, token, to be used to update 
    70                        your token database on automatic token refresh. If not 
    71                        set a TokenUpdated warning will be raised when a token 
    72                        has been refreshed. This warning will carry the token 
    73                        in its token argument. 
    74        :param pkce: Set "S256" or "plain" to enable PKCE. Default is disabled. 
    75        :param kwargs: Arguments to pass to the Session constructor. 
    76        """ 
    77        super(OAuth2Session, self).__init__(**kwargs) 
    78        self._client = client or WebApplicationClient(client_id, token=token) 
    79        self.token = token or {} 
    80        self._scope = scope 
    81        self.redirect_uri = redirect_uri 
    82        self.state = state or generate_token 
    83        self._state = state 
    84        self.auto_refresh_url = auto_refresh_url 
    85        self.auto_refresh_kwargs = auto_refresh_kwargs or {} 
    86        self.token_updater = token_updater 
    87        self._pkce = pkce 
    88 
    89        if self._pkce not in ["S256", "plain", None]: 
    90            raise AttributeError("Wrong value for {}(.., pkce={})".format(self.__class__, self._pkce)) 
    91 
    92        # Ensure that requests doesn't do any automatic auth. See #278. 
    93        # The default behavior can be re-enabled by setting auth to None. 
    94        self.auth = lambda r: r 
    95 
    96        # Allow customizations for non compliant providers through various 
    97        # hooks to adjust requests and responses. 
    98        self.compliance_hook = { 
    99            "access_token_response": set(), 
    100            "refresh_token_response": set(), 
    101            "protected_request": set(), 
    102            "refresh_token_request": set(), 
    103            "access_token_request": set(), 
    104        } 
    105 
    106    @property 
    107    def scope(self): 
    108        """By default the scope from the client is used, except if overridden""" 
    109        if self._scope is not None: 
    110            return self._scope 
    111        elif self._client is not None: 
    112            return self._client.scope 
    113        else: 
    114            return None 
    115 
    116    @scope.setter 
    117    def scope(self, scope): 
    118        self._scope = scope 
    119 
    120    def new_state(self): 
    121        """Generates a state string to be used in authorizations.""" 
    122        try: 
    123            self._state = self.state() 
    124            log.debug("Generated new state %s.", self._state) 
    125        except TypeError: 
    126            self._state = self.state 
    127            log.debug("Re-using previously supplied state %s.", self._state) 
    128        return self._state 
    129 
    130    @property 
    131    def client_id(self): 
    132        return getattr(self._client, "client_id", None) 
    133 
    134    @client_id.setter 
    135    def client_id(self, value): 
    136        self._client.client_id = value 
    137 
    138    @client_id.deleter 
    139    def client_id(self): 
    140        del self._client.client_id 
    141 
    142    @property 
    143    def token(self): 
    144        return getattr(self._client, "token", None) 
    145 
    146    @token.setter 
    147    def token(self, value): 
    148        self._client.token = value 
    149        self._client.populate_token_attributes(value) 
    150 
    151    @property 
    152    def access_token(self): 
    153        return getattr(self._client, "access_token", None) 
    154 
    155    @access_token.setter 
    156    def access_token(self, value): 
    157        self._client.access_token = value 
    158 
    159    @access_token.deleter 
    160    def access_token(self): 
    161        del self._client.access_token 
    162 
    163    @property 
    164    def authorized(self): 
    165        """Boolean that indicates whether this session has an OAuth token 
    166        or not. If `self.authorized` is True, you can reasonably expect 
    167        OAuth-protected requests to the resource to succeed. If 
    168        `self.authorized` is False, you need the user to go through the OAuth 
    169        authentication dance before OAuth-protected requests to the resource 
    170        will succeed. 
    171        """ 
    172        return bool(self.access_token) 
    173 
    174    def authorization_url(self, url, state=None, **kwargs): 
    175        """Form an authorization URL. 
    176 
    177        :param url: Authorization endpoint url, must be HTTPS. 
    178        :param state: An optional state string for CSRF protection. If not 
    179                      given it will be generated for you. 
    180        :param kwargs: Extra parameters to include. 
    181        :return: authorization_url, state 
    182        """ 
    183        state = state or self.new_state() 
    184        if self._pkce: 
    185            self._code_verifier = self._client.create_code_verifier(43) 
    186            kwargs["code_challenge_method"] = self._pkce 
    187            kwargs["code_challenge"] = self._client.create_code_challenge( 
    188                code_verifier=self._code_verifier, 
    189                code_challenge_method=self._pkce 
    190            ) 
    191        return ( 
    192            self._client.prepare_request_uri( 
    193                url, 
    194                redirect_uri=self.redirect_uri, 
    195                scope=self.scope, 
    196                state=state, 
    197                **kwargs 
    198            ), 
    199            state, 
    200        ) 
    201 
    202    def fetch_token( 
    203        self, 
    204        token_url, 
    205        code=None, 
    206        authorization_response=None, 
    207        body="", 
    208        auth=None, 
    209        username=None, 
    210        password=None, 
    211        method="POST", 
    212        force_querystring=False, 
    213        timeout=None, 
    214        headers=None, 
    215        verify=None, 
    216        proxies=None, 
    217        include_client_id=None, 
    218        client_secret=None, 
    219        cert=None, 
    220        **kwargs 
    221    ): 
    222        """Generic method for fetching an access token from the token endpoint. 
    223 
    224        If you are using the MobileApplicationClient you will want to use 
    225        `token_from_fragment` instead of `fetch_token`. 
    226 
    227        The current implementation enforces the RFC guidelines. 
    228 
    229        :param token_url: Token endpoint URL, must use HTTPS. 
    230        :param code: Authorization code (used by WebApplicationClients). 
    231        :param authorization_response: Authorization response URL, the callback 
    232                                       URL of the request back to you. Used by 
    233                                       WebApplicationClients instead of code. 
    234        :param body: Optional application/x-www-form-urlencoded body to add the 
    235                     include in the token request. Prefer kwargs over body. 
    236        :param auth: An auth tuple or method as accepted by `requests`. 
    237        :param username: Username required by LegacyApplicationClients to appear 
    238                         in the request body. 
    239        :param password: Password required by LegacyApplicationClients to appear 
    240                         in the request body. 
    241        :param method: The HTTP method used to make the request. Defaults 
    242                       to POST, but may also be GET. Other methods should 
    243                       be added as needed. 
    244        :param force_querystring: If True, force the request body to be sent 
    245            in the querystring instead. 
    246        :param timeout: Timeout of the request in seconds. 
    247        :param headers: Dict to default request headers with. 
    248        :param verify: Verify SSL certificate. 
    249        :param proxies: The `proxies` argument is passed onto `requests`. 
    250        :param include_client_id: Should the request body include the 
    251                                  `client_id` parameter. Default is `None`, 
    252                                  which will attempt to autodetect. This can be 
    253                                  forced to always include (True) or never 
    254                                  include (False). 
    255        :param client_secret: The `client_secret` paired to the `client_id`. 
    256                              This is generally required unless provided in the 
    257                              `auth` tuple. If the value is `None`, it will be 
    258                              omitted from the request, however if the value is 
    259                              an empty string, an empty string will be sent. 
    260        :param cert: Client certificate to send for OAuth 2.0 Mutual-TLS Client 
    261                     Authentication (draft-ietf-oauth-mtls). Can either be the 
    262                     path of a file containing the private key and certificate or 
    263                     a tuple of two filenames for certificate and key. 
    264        :param kwargs: Extra parameters to include in the token request. 
    265        :return: A token dict 
    266        """ 
    267        if not is_secure_transport(token_url): 
    268            raise InsecureTransportError() 
    269 
    270        if not code and authorization_response: 
    271            self._client.parse_request_uri_response( 
    272                authorization_response, state=self._state 
    273            ) 
    274            code = self._client.code 
    275        elif not code and isinstance(self._client, WebApplicationClient): 
    276            code = self._client.code 
    277            if not code: 
    278                raise ValueError( 
    279                    "Please supply either code or " "authorization_response parameters." 
    280                ) 
    281 
    282        if self._pkce: 
    283            if self._code_verifier is None: 
    284                raise ValueError( 
    285                    "Code verifier is not found, authorization URL must be generated before" 
    286                ) 
    287            kwargs["code_verifier"] = self._code_verifier 
    288 
    289        # Earlier versions of this library build an HTTPBasicAuth header out of 
    290        # `username` and `password`. The RFC states, however these attributes 
    291        # must be in the request body and not the header. 
    292        # If an upstream server is not spec compliant and requires them to 
    293        # appear as an Authorization header, supply an explicit `auth` header 
    294        # to this function. 
    295        # This check will allow for empty strings, but not `None`. 
    296        # 
    297        # References 
    298        # 4.3.2 - Resource Owner Password Credentials Grant 
    299        #         https://tools.ietf.org/html/rfc6749#section-4.3.2 
    300 
    301        if isinstance(self._client, LegacyApplicationClient): 
    302            if username is None: 
    303                raise ValueError( 
    304                    "`LegacyApplicationClient` requires both the " 
    305                    "`username` and `password` parameters." 
    306                ) 
    307            if password is None: 
    308                raise ValueError( 
    309                    "The required parameter `username` was supplied, " 
    310                    "but `password` was not." 
    311                ) 
    312 
    313        # merge username and password into kwargs for `prepare_request_body` 
    314        if username is not None: 
    315            kwargs["username"] = username 
    316        if password is not None: 
    317            kwargs["password"] = password 
    318 
    319        # is an auth explicitly supplied? 
    320        if auth is not None: 
    321            # if we're dealing with the default of `include_client_id` (None): 
    322            # we will assume the `auth` argument is for an RFC compliant server 
    323            # and we should not send the `client_id` in the body. 
    324            # This approach allows us to still force the client_id by submitting 
    325            # `include_client_id=True` along with an `auth` object. 
    326            if include_client_id is None: 
    327                include_client_id = False 
    328 
    329        # otherwise we may need to create an auth header 
    330        else: 
    331            # since we don't have an auth header, we MAY need to create one 
    332            # it is possible that we want to send the `client_id` in the body 
    333            # if so, `include_client_id` should be set to True 
    334            # otherwise, we will generate an auth header 
    335            if include_client_id is not True: 
    336                client_id = self.client_id 
    337                if client_id: 
    338                    log.debug( 
    339                        'Encoding `client_id` "%s" with `client_secret` ' 
    340                        "as Basic auth credentials.", 
    341                        client_id, 
    342                    ) 
    343                    client_secret = client_secret if client_secret is not None else "" 
    344                    auth = requests.auth.HTTPBasicAuth(client_id, client_secret) 
    345 
    346        if include_client_id: 
    347            # this was pulled out of the params 
    348            # it needs to be passed into prepare_request_body 
    349            if client_secret is not None: 
    350                kwargs["client_secret"] = client_secret 
    351 
    352        body = self._client.prepare_request_body( 
    353            code=code, 
    354            body=body, 
    355            redirect_uri=self.redirect_uri, 
    356            include_client_id=include_client_id, 
    357            **kwargs 
    358        ) 
    359 
    360        headers = headers or { 
    361            "Accept": "application/json", 
    362            "Content-Type": "application/x-www-form-urlencoded", 
    363        } 
    364        self.token = {} 
    365        request_kwargs = {} 
    366        if method.upper() == "POST": 
    367            request_kwargs["params" if force_querystring else "data"] = dict( 
    368                urldecode(body) 
    369            ) 
    370        elif method.upper() == "GET": 
    371            request_kwargs["params"] = dict(urldecode(body)) 
    372        else: 
    373            raise ValueError("The method kwarg must be POST or GET.") 
    374 
    375        for hook in self.compliance_hook["access_token_request"]: 
    376            log.debug("Invoking access_token_request hook %s.", hook) 
    377            token_url, headers, request_kwargs = hook( 
    378                token_url, headers, request_kwargs 
    379            ) 
    380 
    381        r = self.request( 
    382            method=method, 
    383            url=token_url, 
    384            timeout=timeout, 
    385            headers=headers, 
    386            auth=auth, 
    387            verify=verify, 
    388            proxies=proxies, 
    389            cert=cert, 
    390            **request_kwargs 
    391        ) 
    392 
    393        log.debug("Request to fetch token completed with status %s.", r.status_code) 
    394        log.debug("Request url was %s", r.request.url) 
    395        log.debug("Request headers were %s", r.request.headers) 
    396        log.debug("Request body was %s", r.request.body) 
    397        log.debug("Response headers were %s and content %s.", r.headers, r.text) 
    398        log.debug( 
    399            "Invoking %d token response hooks.", 
    400            len(self.compliance_hook["access_token_response"]), 
    401        ) 
    402        for hook in self.compliance_hook["access_token_response"]: 
    403            log.debug("Invoking hook %s.", hook) 
    404            r = hook(r) 
    405 
    406        self._client.parse_request_body_response(r.text, scope=self.scope) 
    407        self.token = self._client.token 
    408        log.debug("Obtained token %s.", self.token) 
    409        return self.token 
    410 
    411    def token_from_fragment(self, authorization_response): 
    412        """Parse token from the URI fragment, used by MobileApplicationClients. 
    413 
    414        :param authorization_response: The full URL of the redirect back to you 
    415        :return: A token dict 
    416        """ 
    417        self._client.parse_request_uri_response( 
    418            authorization_response, state=self._state 
    419        ) 
    420        self.token = self._client.token 
    421        return self.token 
    422 
    423    def refresh_token( 
    424        self, 
    425        token_url, 
    426        refresh_token=None, 
    427        body="", 
    428        auth=None, 
    429        timeout=None, 
    430        headers=None, 
    431        verify=None, 
    432        proxies=None, 
    433        **kwargs 
    434    ): 
    435        """Fetch a new access token using a refresh token. 
    436 
    437        :param token_url: The token endpoint, must be HTTPS. 
    438        :param refresh_token: The refresh_token to use. 
    439        :param body: Optional application/x-www-form-urlencoded body to add the 
    440                     include in the token request. Prefer kwargs over body. 
    441        :param auth: An auth tuple or method as accepted by `requests`. 
    442        :param timeout: Timeout of the request in seconds. 
    443        :param headers: A dict of headers to be used by `requests`. 
    444        :param verify: Verify SSL certificate. 
    445        :param proxies: The `proxies` argument will be passed to `requests`. 
    446        :param kwargs: Extra parameters to include in the token request. 
    447        :return: A token dict 
    448        """ 
    449        if not token_url: 
    450            raise ValueError("No token endpoint set for auto_refresh.") 
    451 
    452        if not is_secure_transport(token_url): 
    453            raise InsecureTransportError() 
    454 
    455        refresh_token = refresh_token or self.token.get("refresh_token") 
    456 
    457        log.debug( 
    458            "Adding auto refresh key word arguments %s.", self.auto_refresh_kwargs 
    459        ) 
    460        kwargs.update(self.auto_refresh_kwargs) 
    461        body = self._client.prepare_refresh_body( 
    462            body=body, refresh_token=refresh_token, scope=self.scope, **kwargs 
    463        ) 
    464        log.debug("Prepared refresh token request body %s", body) 
    465 
    466        if headers is None: 
    467            headers = { 
    468                "Accept": "application/json", 
    469                "Content-Type": ("application/x-www-form-urlencoded"), 
    470            } 
    471 
    472        for hook in self.compliance_hook["refresh_token_request"]: 
    473            log.debug("Invoking refresh_token_request hook %s.", hook) 
    474            token_url, headers, body = hook(token_url, headers, body) 
    475 
    476        r = self.post( 
    477            token_url, 
    478            data=dict(urldecode(body)), 
    479            auth=auth, 
    480            timeout=timeout, 
    481            headers=headers, 
    482            verify=verify, 
    483            withhold_token=True, 
    484            proxies=proxies, 
    485        ) 
    486        log.debug("Request to refresh token completed with status %s.", r.status_code) 
    487        log.debug("Response headers were %s and content %s.", r.headers, r.text) 
    488        log.debug( 
    489            "Invoking %d token response hooks.", 
    490            len(self.compliance_hook["refresh_token_response"]), 
    491        ) 
    492        for hook in self.compliance_hook["refresh_token_response"]: 
    493            log.debug("Invoking hook %s.", hook) 
    494            r = hook(r) 
    495 
    496        self.token = self._client.parse_request_body_response(r.text, scope=self.scope) 
    497        if "refresh_token" not in self.token: 
    498            log.debug("No new refresh token given. Re-using old.") 
    499            self.token["refresh_token"] = refresh_token 
    500        return self.token 
    501 
    502    def request( 
    503        self, 
    504        method, 
    505        url, 
    506        data=None, 
    507        headers=None, 
    508        withhold_token=False, 
    509        client_id=None, 
    510        client_secret=None, 
    511        files=None, 
    512        **kwargs 
    513    ): 
    514        """Intercept all requests and add the OAuth 2 token if present.""" 
    515        if not is_secure_transport(url): 
    516            raise InsecureTransportError() 
    517        if self.token and not withhold_token: 
    518            log.debug( 
    519                "Invoking %d protected resource request hooks.", 
    520                len(self.compliance_hook["protected_request"]), 
    521            ) 
    522            for hook in self.compliance_hook["protected_request"]: 
    523                log.debug("Invoking hook %s.", hook) 
    524                url, headers, data = hook(url, headers, data) 
    525 
    526            log.debug("Adding token %s to request.", self.token) 
    527            try: 
    528                url, headers, data = self._client.add_token( 
    529                    url, http_method=method, body=data, headers=headers 
    530                ) 
    531            # Attempt to retrieve and save new access token if expired 
    532            except TokenExpiredError: 
    533                if self.auto_refresh_url: 
    534                    log.debug( 
    535                        "Auto refresh is set, attempting to refresh at %s.", 
    536                        self.auto_refresh_url, 
    537                    ) 
    538 
    539                    # We mustn't pass auth twice. 
    540                    auth = kwargs.pop("auth", None) 
    541                    if client_id and client_secret and (auth is None): 
    542                        log.debug( 
    543                            'Encoding client_id "%s" with client_secret as Basic auth credentials.', 
    544                            client_id, 
    545                        ) 
    546                        auth = requests.auth.HTTPBasicAuth(client_id, client_secret) 
    547                    token = self.refresh_token( 
    548                        self.auto_refresh_url, auth=auth, **kwargs 
    549                    ) 
    550                    if self.token_updater: 
    551                        log.debug( 
    552                            "Updating token to %s using %s.", token, self.token_updater 
    553                        ) 
    554                        self.token_updater(token) 
    555                        url, headers, data = self._client.add_token( 
    556                            url, http_method=method, body=data, headers=headers 
    557                        ) 
    558                    else: 
    559                        raise TokenUpdated(token) 
    560                else: 
    561                    raise 
    562 
    563        log.debug("Requesting url %s using method %s.", url, method) 
    564        log.debug("Supplying headers %s and data %s", headers, data) 
    565        log.debug("Passing through key word arguments %s.", kwargs) 
    566        return super(OAuth2Session, self).request( 
    567            method, url, headers=headers, data=data, files=files, **kwargs 
    568        ) 
    569 
    570    def register_compliance_hook(self, hook_type, hook): 
    571        """Register a hook for request/response tweaking. 
    572 
    573        Available hooks are: 
    574            access_token_response invoked before token parsing. 
    575            refresh_token_response invoked before refresh token parsing. 
    576            protected_request invoked before making a request. 
    577            access_token_request invoked before making a token fetch request. 
    578            refresh_token_request invoked before making a refresh request. 
    579 
    580        If you find a new hook is needed please send a GitHub PR request 
    581        or open an issue. 
    582        """ 
    583        if hook_type not in self.compliance_hook: 
    584            raise ValueError( 
    585                "Hook type %s is not in %s.", hook_type, self.compliance_hook 
    586            ) 
    587        self.compliance_hook[hook_type].add(hook)