1# Copyright 2016 Google LLC 
    2# 
    3# Licensed under the Apache License, Version 2.0 (the "License"); 
    4# you may not use this file except in compliance with the License. 
    5# You may obtain a copy of the License at 
    6# 
    7#      http://www.apache.org/licenses/LICENSE-2.0 
    8# 
    9# Unless required by applicable law or agreed to in writing, software 
    10# distributed under the License is distributed on an "AS IS" BASIS, 
    11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
    12# See the License for the specific language governing permissions and 
    13# limitations under the License. 
    14 
    15"""Transport adapter for Requests.""" 
    16 
    17from __future__ import absolute_import 
    18 
    19import functools 
    20import logging 
    21import numbers 
    22import os 
    23import time 
    24 
    25try: 
    26    import requests 
    27except ImportError as caught_exc:  # pragma: NO COVER 
    28    raise ImportError( 
    29        "The requests library is not installed from please install the requests package to use the requests transport." 
    30    ) from caught_exc 
    31import requests.adapters  # pylint: disable=ungrouped-imports 
    32import requests.exceptions  # pylint: disable=ungrouped-imports 
    33from requests.packages.urllib3.util.ssl_ import (  # type: ignore 
    34    create_urllib3_context, 
    35)  # pylint: disable=ungrouped-imports 
    36 
    37from google.auth import _helpers 
    38from google.auth import environment_vars 
    39from google.auth import exceptions 
    40from google.auth import transport 
    41import google.auth.transport._mtls_helper 
    42from google.oauth2 import service_account 
    43 
    44_LOGGER = logging.getLogger(__name__) 
    45 
    46_DEFAULT_TIMEOUT = 120  # in seconds 
    47 
    48 
    49class _Response(transport.Response): 
    50    """Requests transport response adapter. 
    51 
    52    Args: 
    53        response (requests.Response): The raw Requests response. 
    54    """ 
    55 
    56    def __init__(self, response): 
    57        self._response = response 
    58 
    59    @property 
    60    def status(self): 
    61        return self._response.status_code 
    62 
    63    @property 
    64    def headers(self): 
    65        return self._response.headers 
    66 
    67    @property 
    68    def data(self): 
    69        return self._response.content 
    70 
    71 
    72class TimeoutGuard(object): 
    73    """A context manager raising an error if the suite execution took too long. 
    74 
    75    Args: 
    76        timeout (Union[None, Union[float, Tuple[float, float]]]): 
    77            The maximum number of seconds a suite can run without the context 
    78            manager raising a timeout exception on exit. If passed as a tuple, 
    79            the smaller of the values is taken as a timeout. If ``None``, a 
    80            timeout error is never raised. 
    81        timeout_error_type (Optional[Exception]): 
    82            The type of the error to raise on timeout. Defaults to 
    83            :class:`requests.exceptions.Timeout`. 
    84    """ 
    85 
    86    def __init__(self, timeout, timeout_error_type=requests.exceptions.Timeout): 
    87        self._timeout = timeout 
    88        self.remaining_timeout = timeout 
    89        self._timeout_error_type = timeout_error_type 
    90 
    91    def __enter__(self): 
    92        self._start = time.time() 
    93        return self 
    94 
    95    def __exit__(self, exc_type, exc_value, traceback): 
    96        if exc_value: 
    97            return  # let the error bubble up automatically 
    98 
    99        if self._timeout is None: 
    100            return  # nothing to do, the timeout was not specified 
    101 
    102        elapsed = time.time() - self._start 
    103        deadline_hit = False 
    104 
    105        if isinstance(self._timeout, numbers.Number): 
    106            self.remaining_timeout = self._timeout - elapsed 
    107            deadline_hit = self.remaining_timeout <= 0 
    108        else: 
    109            self.remaining_timeout = tuple(x - elapsed for x in self._timeout) 
    110            deadline_hit = min(self.remaining_timeout) <= 0 
    111 
    112        if deadline_hit: 
    113            raise self._timeout_error_type() 
    114 
    115 
    116class Request(transport.Request): 
    117    """Requests request adapter. 
    118 
    119    This class is used internally for making requests using various transports 
    120    in a consistent way. If you use :class:`AuthorizedSession` you do not need 
    121    to construct or use this class directly. 
    122 
    123    This class can be useful if you want to manually refresh a 
    124    :class:`~google.auth.credentials.Credentials` instance:: 
    125 
    126        import google.auth.transport.requests 
    127        import requests 
    128 
    129        request = google.auth.transport.requests.Request() 
    130 
    131        credentials.refresh(request) 
    132 
    133    Args: 
    134        session (requests.Session): An instance :class:`requests.Session` used 
    135            to make HTTP requests. If not specified, a session will be created. 
    136 
    137    .. automethod:: __call__ 
    138    """ 
    139 
    140    def __init__(self, session=None): 
    141        if not session: 
    142            session = requests.Session() 
    143 
    144        self.session = session 
    145 
    146    def __del__(self): 
    147        try: 
    148            if hasattr(self, "session") and self.session is not None: 
    149                self.session.close() 
    150        except TypeError: 
    151            # NOTE: For certain Python binary built, the queue.Empty exception 
    152            # might not be considered a normal Python exception causing 
    153            # TypeError. 
    154            pass 
    155 
    156    def __call__( 
    157        self, 
    158        url, 
    159        method="GET", 
    160        body=None, 
    161        headers=None, 
    162        timeout=_DEFAULT_TIMEOUT, 
    163        **kwargs 
    164    ): 
    165        """Make an HTTP request using requests. 
    166 
    167        Args: 
    168            url (str): The URI to be requested. 
    169            method (str): The HTTP method to use for the request. Defaults 
    170                to 'GET'. 
    171            body (bytes): The payload or body in HTTP request. 
    172            headers (Mapping[str, str]): Request headers. 
    173            timeout (Optional[int]): The number of seconds to wait for a 
    174                response from the server. If not specified or if None, the 
    175                requests default timeout will be used. 
    176            kwargs: Additional arguments passed through to the underlying 
    177                requests :meth:`~requests.Session.request` method. 
    178 
    179        Returns: 
    180            google.auth.transport.Response: The HTTP response. 
    181 
    182        Raises: 
    183            google.auth.exceptions.TransportError: If any exception occurred. 
    184        """ 
    185        try: 
    186            _helpers.request_log(_LOGGER, method, url, body, headers) 
    187            response = self.session.request( 
    188                method, url, data=body, headers=headers, timeout=timeout, **kwargs 
    189            ) 
    190            _helpers.response_log(_LOGGER, response) 
    191            return _Response(response) 
    192        except requests.exceptions.RequestException as caught_exc: 
    193            new_exc = exceptions.TransportError(caught_exc) 
    194            raise new_exc from caught_exc 
    195 
    196 
    197class _MutualTlsAdapter(requests.adapters.HTTPAdapter): 
    198    """ 
    199    A TransportAdapter that enables mutual TLS. 
    200 
    201    Args: 
    202        cert (bytes): client certificate in PEM format 
    203        key (bytes): client private key in PEM format 
    204 
    205    Raises: 
    206        ImportError: if certifi or pyOpenSSL is not installed 
    207        OpenSSL.crypto.Error: if client cert or key is invalid 
    208    """ 
    209 
    210    def __init__(self, cert, key): 
    211        import certifi 
    212        from OpenSSL import crypto 
    213        import urllib3.contrib.pyopenssl  # type: ignore 
    214 
    215        urllib3.contrib.pyopenssl.inject_into_urllib3() 
    216 
    217        pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key) 
    218        x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) 
    219 
    220        ctx_poolmanager = create_urllib3_context() 
    221        ctx_poolmanager.load_verify_locations(cafile=certifi.where()) 
    222        ctx_poolmanager._ctx.use_certificate(x509) 
    223        ctx_poolmanager._ctx.use_privatekey(pkey) 
    224        self._ctx_poolmanager = ctx_poolmanager 
    225 
    226        ctx_proxymanager = create_urllib3_context() 
    227        ctx_proxymanager.load_verify_locations(cafile=certifi.where()) 
    228        ctx_proxymanager._ctx.use_certificate(x509) 
    229        ctx_proxymanager._ctx.use_privatekey(pkey) 
    230        self._ctx_proxymanager = ctx_proxymanager 
    231 
    232        super(_MutualTlsAdapter, self).__init__() 
    233 
    234    def init_poolmanager(self, *args, **kwargs): 
    235        kwargs["ssl_context"] = self._ctx_poolmanager 
    236        super(_MutualTlsAdapter, self).init_poolmanager(*args, **kwargs) 
    237 
    238    def proxy_manager_for(self, *args, **kwargs): 
    239        kwargs["ssl_context"] = self._ctx_proxymanager 
    240        return super(_MutualTlsAdapter, self).proxy_manager_for(*args, **kwargs) 
    241 
    242 
    243class _MutualTlsOffloadAdapter(requests.adapters.HTTPAdapter): 
    244    """ 
    245    A TransportAdapter that enables mutual TLS and offloads the client side 
    246    signing operation to the signing library. 
    247 
    248    Args: 
    249        enterprise_cert_file_path (str): the path to a enterprise cert JSON 
    250            file. The file should contain the following field: 
    251 
    252                { 
    253                    "libs": { 
    254                        "signer_library": "...", 
    255                        "offload_library": "..." 
    256                    } 
    257                } 
    258 
    259    Raises: 
    260        ImportError: if certifi or pyOpenSSL is not installed 
    261        google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel 
    262            creation failed for any reason. 
    263    """ 
    264 
    265    def __init__(self, enterprise_cert_file_path): 
    266        import certifi 
    267        from google.auth.transport import _custom_tls_signer 
    268 
    269        self.signer = _custom_tls_signer.CustomTlsSigner(enterprise_cert_file_path) 
    270        self.signer.load_libraries() 
    271 
    272        import urllib3.contrib.pyopenssl 
    273 
    274        urllib3.contrib.pyopenssl.inject_into_urllib3() 
    275 
    276        poolmanager = create_urllib3_context() 
    277        poolmanager.load_verify_locations(cafile=certifi.where()) 
    278        self.signer.attach_to_ssl_context(poolmanager) 
    279        self._ctx_poolmanager = poolmanager 
    280 
    281        proxymanager = create_urllib3_context() 
    282        proxymanager.load_verify_locations(cafile=certifi.where()) 
    283        self.signer.attach_to_ssl_context(proxymanager) 
    284        self._ctx_proxymanager = proxymanager 
    285 
    286        super(_MutualTlsOffloadAdapter, self).__init__() 
    287 
    288    def init_poolmanager(self, *args, **kwargs): 
    289        kwargs["ssl_context"] = self._ctx_poolmanager 
    290        super(_MutualTlsOffloadAdapter, self).init_poolmanager(*args, **kwargs) 
    291 
    292    def proxy_manager_for(self, *args, **kwargs): 
    293        kwargs["ssl_context"] = self._ctx_proxymanager 
    294        return super(_MutualTlsOffloadAdapter, self).proxy_manager_for(*args, **kwargs) 
    295 
    296 
    297class AuthorizedSession(requests.Session): 
    298    """A Requests Session class with credentials. 
    299 
    300    This class is used to perform requests to API endpoints that require 
    301    authorization:: 
    302 
    303        from google.auth.transport.requests import AuthorizedSession 
    304 
    305        authed_session = AuthorizedSession(credentials) 
    306 
    307        response = authed_session.request( 
    308            'GET', 'https://www.googleapis.com/storage/v1/b') 
    309 
    310 
    311    The underlying :meth:`request` implementation handles adding the 
    312    credentials' headers to the request and refreshing credentials as needed. 
    313 
    314    This class also supports mutual TLS via :meth:`configure_mtls_channel` 
    315    method. In order to use this method, the `GOOGLE_API_USE_CLIENT_CERTIFICATE` 
    316    environment variable must be explicitly set to ``true``, otherwise it does 
    317    nothing. Assume the environment is set to ``true``, the method behaves in the 
    318    following manner: 
    319 
    320    If client_cert_callback is provided, client certificate and private 
    321    key are loaded using the callback; if client_cert_callback is None, 
    322    application default SSL credentials will be used. Exceptions are raised if 
    323    there are problems with the certificate, private key, or the loading process, 
    324    so it should be called within a try/except block. 
    325 
    326    First we set the environment variable to ``true``, then create an :class:`AuthorizedSession` 
    327    instance and specify the endpoints:: 
    328 
    329        regular_endpoint = 'https://pubsub.googleapis.com/v1/projects/{my_project_id}/topics' 
    330        mtls_endpoint = 'https://pubsub.mtls.googleapis.com/v1/projects/{my_project_id}/topics' 
    331 
    332        authed_session = AuthorizedSession(credentials) 
    333 
    334    Now we can pass a callback to :meth:`configure_mtls_channel`:: 
    335 
    336        def my_cert_callback(): 
    337            # some code to load client cert bytes and private key bytes, both in 
    338            # PEM format. 
    339            some_code_to_load_client_cert_and_key() 
    340            if loaded: 
    341                return cert, key 
    342            raise MyClientCertFailureException() 
    343 
    344        # Always call configure_mtls_channel within a try/except block. 
    345        try: 
    346            authed_session.configure_mtls_channel(my_cert_callback) 
    347        except: 
    348            # handle exceptions. 
    349 
    350        if authed_session.is_mtls: 
    351            response = authed_session.request('GET', mtls_endpoint) 
    352        else: 
    353            response = authed_session.request('GET', regular_endpoint) 
    354 
    355 
    356    You can alternatively use application default SSL credentials like this:: 
    357 
    358        try: 
    359            authed_session.configure_mtls_channel() 
    360        except: 
    361            # handle exceptions. 
    362 
    363    Args: 
    364        credentials (google.auth.credentials.Credentials): The credentials to 
    365            add to the request. 
    366        refresh_status_codes (Sequence[int]): Which HTTP status codes indicate 
    367            that credentials should be refreshed and the request should be 
    368            retried. 
    369        max_refresh_attempts (int): The maximum number of times to attempt to 
    370            refresh the credentials and retry the request. 
    371        refresh_timeout (Optional[int]): The timeout value in seconds for 
    372            credential refresh HTTP requests. 
    373        auth_request (google.auth.transport.requests.Request): 
    374            (Optional) An instance of 
    375            :class:`~google.auth.transport.requests.Request` used when 
    376            refreshing credentials. If not passed, 
    377            an instance of :class:`~google.auth.transport.requests.Request` 
    378            is created. 
    379        default_host (Optional[str]): A host like "pubsub.googleapis.com". 
    380            This is used when a self-signed JWT is created from service 
    381            account credentials. 
    382    """ 
    383 
    384    def __init__( 
    385        self, 
    386        credentials, 
    387        refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES, 
    388        max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, 
    389        refresh_timeout=None, 
    390        auth_request=None, 
    391        default_host=None, 
    392    ): 
    393        super(AuthorizedSession, self).__init__() 
    394        self.credentials = credentials 
    395        self._refresh_status_codes = refresh_status_codes 
    396        self._max_refresh_attempts = max_refresh_attempts 
    397        self._refresh_timeout = refresh_timeout 
    398        self._is_mtls = False 
    399        self._default_host = default_host 
    400 
    401        if auth_request is None: 
    402            self._auth_request_session = requests.Session() 
    403 
    404            # Using an adapter to make HTTP requests robust to network errors. 
    405            # This adapter retrys HTTP requests when network errors occur 
    406            # and the requests seems safely retryable. 
    407            retry_adapter = requests.adapters.HTTPAdapter(max_retries=3) 
    408            self._auth_request_session.mount("https://", retry_adapter) 
    409 
    410            # Do not pass `self` as the session here, as it can lead to 
    411            # infinite recursion. 
    412            auth_request = Request(self._auth_request_session) 
    413        else: 
    414            self._auth_request_session = None 
    415 
    416        # Request instance used by internal methods (for example, 
    417        # credentials.refresh). 
    418        self._auth_request = auth_request 
    419 
    420        # https://google.aip.dev/auth/4111 
    421        # Attempt to use self-signed JWTs when a service account is used. 
    422        if isinstance(self.credentials, service_account.Credentials): 
    423            self.credentials._create_self_signed_jwt( 
    424                "https://{}/".format(self._default_host) if self._default_host else None 
    425            ) 
    426 
    427    def configure_mtls_channel(self, client_cert_callback=None): 
    428        """Configure the client certificate and key for SSL connection. 
    429 
    430        The function does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` is 
    431        explicitly set to `true`. In this case if client certificate and key are 
    432        successfully obtained (from the given client_cert_callback or from application 
    433        default SSL credentials), a :class:`_MutualTlsAdapter` instance will be mounted 
    434        to "https://" prefix. 
    435 
    436        Args: 
    437            client_cert_callback (Optional[Callable[[], (bytes, bytes)]]): 
    438                The optional callback returns the client certificate and private 
    439                key bytes both in PEM format. 
    440                If the callback is None, application default SSL credentials 
    441                will be used. 
    442 
    443        Raises: 
    444            google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel 
    445                creation failed for any reason. 
    446        """ 
    447        use_client_cert = os.getenv( 
    448            environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE, "false" 
    449        ) 
    450        if use_client_cert != "true": 
    451            self._is_mtls = False 
    452            return 
    453 
    454        try: 
    455            import OpenSSL 
    456        except ImportError as caught_exc: 
    457            new_exc = exceptions.MutualTLSChannelError(caught_exc) 
    458            raise new_exc from caught_exc 
    459 
    460        try: 
    461            ( 
    462                self._is_mtls, 
    463                cert, 
    464                key, 
    465            ) = google.auth.transport._mtls_helper.get_client_cert_and_key( 
    466                client_cert_callback 
    467            ) 
    468 
    469            if self._is_mtls: 
    470                mtls_adapter = _MutualTlsAdapter(cert, key) 
    471                self.mount("https://", mtls_adapter) 
    472        except ( 
    473            exceptions.ClientCertError, 
    474            ImportError, 
    475            OpenSSL.crypto.Error, 
    476        ) as caught_exc: 
    477            new_exc = exceptions.MutualTLSChannelError(caught_exc) 
    478            raise new_exc from caught_exc 
    479 
    480    def request( 
    481        self, 
    482        method, 
    483        url, 
    484        data=None, 
    485        headers=None, 
    486        max_allowed_time=None, 
    487        timeout=_DEFAULT_TIMEOUT, 
    488        **kwargs 
    489    ): 
    490        """Implementation of Requests' request. 
    491 
    492        Args: 
    493            timeout (Optional[Union[float, Tuple[float, float]]]): 
    494                The amount of time in seconds to wait for the server response 
    495                with each individual request. Can also be passed as a tuple 
    496                ``(connect_timeout, read_timeout)``. See :meth:`requests.Session.request` 
    497                documentation for details. 
    498            max_allowed_time (Optional[float]): 
    499                If the method runs longer than this, a ``Timeout`` exception is 
    500                automatically raised. Unlike the ``timeout`` parameter, this 
    501                value applies to the total method execution time, even if 
    502                multiple requests are made under the hood. 
    503 
    504                Mind that it is not guaranteed that the timeout error is raised 
    505                at ``max_allowed_time``. It might take longer, for example, if 
    506                an underlying request takes a lot of time, but the request 
    507                itself does not timeout, e.g. if a large file is being 
    508                transmitted. The timout error will be raised after such 
    509                request completes. 
    510        """ 
    511        # pylint: disable=arguments-differ 
    512        # Requests has a ton of arguments to request, but only two 
    513        # (method, url) are required. We pass through all of the other 
    514        # arguments to super, so no need to exhaustively list them here. 
    515 
    516        # Use a kwarg for this instead of an attribute to maintain 
    517        # thread-safety. 
    518        _credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0) 
    519 
    520        # Make a copy of the headers. They will be modified by the credentials 
    521        # and we want to pass the original headers if we recurse. 
    522        request_headers = headers.copy() if headers is not None else {} 
    523 
    524        # Do not apply the timeout unconditionally in order to not override the 
    525        # _auth_request's default timeout. 
    526        auth_request = ( 
    527            self._auth_request 
    528            if timeout is None 
    529            else functools.partial(self._auth_request, timeout=timeout) 
    530        ) 
    531 
    532        remaining_time = max_allowed_time 
    533 
    534        with TimeoutGuard(remaining_time) as guard: 
    535            self.credentials.before_request(auth_request, method, url, request_headers) 
    536        remaining_time = guard.remaining_timeout 
    537 
    538        with TimeoutGuard(remaining_time) as guard: 
    539            _helpers.request_log(_LOGGER, method, url, data, headers) 
    540            response = super(AuthorizedSession, self).request( 
    541                method, 
    542                url, 
    543                data=data, 
    544                headers=request_headers, 
    545                timeout=timeout, 
    546                **kwargs 
    547            ) 
    548        remaining_time = guard.remaining_timeout 
    549 
    550        # If the response indicated that the credentials needed to be 
    551        # refreshed, then refresh the credentials and re-attempt the 
    552        # request. 
    553        # A stored token may expire between the time it is retrieved and 
    554        # the time the request is made, so we may need to try twice. 
    555        if ( 
    556            response.status_code in self._refresh_status_codes 
    557            and _credential_refresh_attempt < self._max_refresh_attempts 
    558        ): 
    559 
    560            _LOGGER.info( 
    561                "Refreshing credentials due to a %s response. Attempt %s/%s.", 
    562                response.status_code, 
    563                _credential_refresh_attempt + 1, 
    564                self._max_refresh_attempts, 
    565            ) 
    566 
    567            # Do not apply the timeout unconditionally in order to not override the 
    568            # _auth_request's default timeout. 
    569            auth_request = ( 
    570                self._auth_request 
    571                if timeout is None 
    572                else functools.partial(self._auth_request, timeout=timeout) 
    573            ) 
    574 
    575            with TimeoutGuard(remaining_time) as guard: 
    576                self.credentials.refresh(auth_request) 
    577            remaining_time = guard.remaining_timeout 
    578 
    579            # Recurse. Pass in the original headers, not our modified set, but 
    580            # do pass the adjusted max allowed time (i.e. the remaining total time). 
    581            return self.request( 
    582                method, 
    583                url, 
    584                data=data, 
    585                headers=headers, 
    586                max_allowed_time=remaining_time, 
    587                timeout=timeout, 
    588                _credential_refresh_attempt=_credential_refresh_attempt + 1, 
    589                **kwargs 
    590            ) 
    591 
    592        return response 
    593 
    594    @property 
    595    def is_mtls(self): 
    596        """Indicates if the created SSL channel is mutual TLS.""" 
    597        return self._is_mtls 
    598 
    599    def close(self): 
    600        if self._auth_request_session is not None: 
    601            self._auth_request_session.close() 
    602        super(AuthorizedSession, self).close()