1# Copyright 2016 Google LLC 
    2# 
    3# Licensed under the Apache License, Version 2.0 (the "License"); 
    4# you may not use this file except in compliance with the License. 
    5# You may obtain a copy of the License at 
    6# 
    7#      http://www.apache.org/licenses/LICENSE-2.0 
    8# 
    9# Unless required by applicable law or agreed to in writing, software 
    10# distributed under the License is distributed on an "AS IS" BASIS, 
    11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
    12# See the License for the specific language governing permissions and 
    13# limitations under the License. 
    14 
    15"""Authorization support for gRPC.""" 
    16 
    17from __future__ import absolute_import 
    18 
    19import logging 
    20import os 
    21 
    22from google.auth import environment_vars 
    23from google.auth import exceptions 
    24from google.auth.transport import _mtls_helper 
    25from google.oauth2 import service_account 
    26 
    27try: 
    28    import grpc  # type: ignore 
    29except ImportError as caught_exc:  # pragma: NO COVER 
    30    raise ImportError( 
    31        "gRPC is not installed from please install the grpcio package to use the gRPC transport." 
    32    ) from caught_exc 
    33 
    34_LOGGER = logging.getLogger(__name__) 
    35 
    36 
    37class AuthMetadataPlugin(grpc.AuthMetadataPlugin): 
    38    """A `gRPC AuthMetadataPlugin`_ that inserts the credentials into each 
    39    request. 
    40 
    41    .. _gRPC AuthMetadataPlugin: 
    42        http://www.grpc.io/grpc/python/grpc.html#grpc.AuthMetadataPlugin 
    43 
    44    Args: 
    45        credentials (google.auth.credentials.Credentials): The credentials to 
    46            add to requests. 
    47        request (google.auth.transport.Request): A HTTP transport request 
    48            object used to refresh credentials as needed. 
    49        default_host (Optional[str]): A host like "pubsub.googleapis.com". 
    50            This is used when a self-signed JWT is created from service 
    51            account credentials. 
    52    """ 
    53 
    54    def __init__(self, credentials, request, default_host=None): 
    55        # pylint: disable=no-value-for-parameter 
    56        # pylint doesn't realize that the super method takes no arguments 
    57        # because this class is the same name as the superclass. 
    58        super(AuthMetadataPlugin, self).__init__() 
    59        self._credentials = credentials 
    60        self._request = request 
    61        self._default_host = default_host 
    62 
    63    def _get_authorization_headers(self, context): 
    64        """Gets the authorization headers for a request. 
    65 
    66        Returns: 
    67            Sequence[Tuple[str, str]]: A list of request headers (key, value) 
    68                to add to the request. 
    69        """ 
    70        headers = {} 
    71 
    72        # https://google.aip.dev/auth/4111 
    73        # Attempt to use self-signed JWTs when a service account is used. 
    74        # A default host must be explicitly provided since it cannot always 
    75        # be determined from the context.service_url. 
    76        if isinstance(self._credentials, service_account.Credentials): 
    77            self._credentials._create_self_signed_jwt( 
    78                "https://{}/".format(self._default_host) if self._default_host else None 
    79            ) 
    80 
    81        self._credentials.before_request( 
    82            self._request, context.method_name, context.service_url, headers 
    83        ) 
    84 
    85        return list(headers.items()) 
    86 
    87    def __call__(self, context, callback): 
    88        """Passes authorization metadata into the given callback. 
    89 
    90        Args: 
    91            context (grpc.AuthMetadataContext): The RPC context. 
    92            callback (grpc.AuthMetadataPluginCallback): The callback that will 
    93                be invoked to pass in the authorization metadata. 
    94        """ 
    95        callback(self._get_authorization_headers(context), None) 
    96 
    97 
    98def secure_authorized_channel( 
    99    credentials, 
    100    request, 
    101    target, 
    102    ssl_credentials=None, 
    103    client_cert_callback=None, 
    104    **kwargs 
    105): 
    106    """Creates a secure authorized gRPC channel. 
    107 
    108    This creates a channel with SSL and :class:`AuthMetadataPlugin`. This 
    109    channel can be used to create a stub that can make authorized requests. 
    110    Users can configure client certificate or rely on device certificates to 
    111    establish a mutual TLS channel, if the `GOOGLE_API_USE_CLIENT_CERTIFICATE` 
    112    variable is explicitly set to `true`. 
    113 
    114    Example:: 
    115 
    116        import google.auth 
    117        import google.auth.transport.grpc 
    118        import google.auth.transport.requests 
    119        from google.cloud.speech.v1 import cloud_speech_pb2 
    120 
    121        # Get credentials. 
    122        credentials, _ = google.auth.default() 
    123 
    124        # Get an HTTP request function to refresh credentials. 
    125        request = google.auth.transport.requests.Request() 
    126 
    127        # Create a channel. 
    128        channel = google.auth.transport.grpc.secure_authorized_channel( 
    129            credentials, regular_endpoint, request, 
    130            ssl_credentials=grpc.ssl_channel_credentials()) 
    131 
    132        # Use the channel to create a stub. 
    133        cloud_speech.create_Speech_stub(channel) 
    134 
    135    Usage: 
    136 
    137    There are actually a couple of options to create a channel, depending on if 
    138    you want to create a regular or mutual TLS channel. 
    139 
    140    First let's list the endpoints (regular vs mutual TLS) to choose from:: 
    141 
    142        regular_endpoint = 'speech.googleapis.com:443' 
    143        mtls_endpoint = 'speech.mtls.googleapis.com:443' 
    144 
    145    Option 1: create a regular (non-mutual) TLS channel by explicitly setting 
    146    the ssl_credentials:: 
    147 
    148        regular_ssl_credentials = grpc.ssl_channel_credentials() 
    149 
    150        channel = google.auth.transport.grpc.secure_authorized_channel( 
    151            credentials, regular_endpoint, request, 
    152            ssl_credentials=regular_ssl_credentials) 
    153 
    154    Option 2: create a mutual TLS channel by calling a callback which returns 
    155    the client side certificate and the key (Note that 
    156    `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable must be explicitly 
    157    set to `true`):: 
    158 
    159        def my_client_cert_callback(): 
    160            code_to_load_client_cert_and_key() 
    161            if loaded: 
    162                return (pem_cert_bytes, pem_key_bytes) 
    163            raise MyClientCertFailureException() 
    164 
    165        try: 
    166            channel = google.auth.transport.grpc.secure_authorized_channel( 
    167                credentials, mtls_endpoint, request, 
    168                client_cert_callback=my_client_cert_callback) 
    169        except MyClientCertFailureException: 
    170            # handle the exception 
    171 
    172    Option 3: use application default SSL credentials. It searches and uses 
    173    the command in a context aware metadata file, which is available on devices 
    174    with endpoint verification support (Note that 
    175    `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable must be explicitly 
    176    set to `true`). 
    177    See https://cloud.google.com/endpoint-verification/docs/overview:: 
    178 
    179        try: 
    180            default_ssl_credentials = SslCredentials() 
    181        except: 
    182            # Exception can be raised if the context aware metadata is malformed. 
    183            # See :class:`SslCredentials` for the possible exceptions. 
    184 
    185        # Choose the endpoint based on the SSL credentials type. 
    186        if default_ssl_credentials.is_mtls: 
    187            endpoint_to_use = mtls_endpoint 
    188        else: 
    189            endpoint_to_use = regular_endpoint 
    190        channel = google.auth.transport.grpc.secure_authorized_channel( 
    191            credentials, endpoint_to_use, request, 
    192            ssl_credentials=default_ssl_credentials) 
    193 
    194    Option 4: not setting ssl_credentials and client_cert_callback. For devices 
    195    without endpoint verification support or `GOOGLE_API_USE_CLIENT_CERTIFICATE` 
    196    environment variable is not `true`, a regular TLS channel is created; 
    197    otherwise, a mutual TLS channel is created, however, the call should be 
    198    wrapped in a try/except block in case of malformed context aware metadata. 
    199 
    200    The following code uses regular_endpoint, it works the same no matter the 
    201    created channle is regular or mutual TLS. Regular endpoint ignores client 
    202    certificate and key:: 
    203 
    204        channel = google.auth.transport.grpc.secure_authorized_channel( 
    205            credentials, regular_endpoint, request) 
    206 
    207    The following code uses mtls_endpoint, if the created channle is regular, 
    208    and API mtls_endpoint is confgured to require client SSL credentials, API 
    209    calls using this channel will be rejected:: 
    210 
    211        channel = google.auth.transport.grpc.secure_authorized_channel( 
    212            credentials, mtls_endpoint, request) 
    213 
    214    Args: 
    215        credentials (google.auth.credentials.Credentials): The credentials to 
    216            add to requests. 
    217        request (google.auth.transport.Request): A HTTP transport request 
    218            object used to refresh credentials as needed. Even though gRPC 
    219            is a separate transport, there's no way to refresh the credentials 
    220            without using a standard http transport. 
    221        target (str): The host and port of the service. 
    222        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel 
    223            credentials. This can be used to specify different certificates. 
    224            This argument is mutually exclusive with client_cert_callback; 
    225            providing both will raise an exception. 
    226            If ssl_credentials and client_cert_callback are None, application 
    227            default SSL credentials are used if `GOOGLE_API_USE_CLIENT_CERTIFICATE` 
    228            environment variable is explicitly set to `true`, otherwise one way TLS 
    229            SSL credentials are used. 
    230        client_cert_callback (Callable[[], (bytes, bytes)]): Optional 
    231            callback function to obtain client certicate and key for mutual TLS 
    232            connection. This argument is mutually exclusive with 
    233            ssl_credentials; providing both will raise an exception. 
    234            This argument does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` 
    235            environment variable is explicitly set to `true`. 
    236        kwargs: Additional arguments to pass to :func:`grpc.secure_channel`. 
    237 
    238    Returns: 
    239        grpc.Channel: The created gRPC channel. 
    240 
    241    Raises: 
    242        google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel 
    243            creation failed for any reason. 
    244    """ 
    245    # Create the metadata plugin for inserting the authorization header. 
    246    metadata_plugin = AuthMetadataPlugin(credentials, request) 
    247 
    248    # Create a set of grpc.CallCredentials using the metadata plugin. 
    249    google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin) 
    250 
    251    if ssl_credentials and client_cert_callback: 
    252        raise exceptions.MalformedError( 
    253            "Received both ssl_credentials and client_cert_callback; " 
    254            "these are mutually exclusive." 
    255        ) 
    256 
    257    # If SSL credentials are not explicitly set, try client_cert_callback and ADC. 
    258    if not ssl_credentials: 
    259        use_client_cert = os.getenv( 
    260            environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE, "false" 
    261        ) 
    262        if use_client_cert == "true" and client_cert_callback: 
    263            # Use the callback if provided. 
    264            cert, key = client_cert_callback() 
    265            ssl_credentials = grpc.ssl_channel_credentials( 
    266                certificate_chain=cert, private_key=key 
    267            ) 
    268        elif use_client_cert == "true": 
    269            # Use application default SSL credentials. 
    270            adc_ssl_credentils = SslCredentials() 
    271            ssl_credentials = adc_ssl_credentils.ssl_credentials 
    272        else: 
    273            ssl_credentials = grpc.ssl_channel_credentials() 
    274 
    275    # Combine the ssl credentials and the authorization credentials. 
    276    composite_credentials = grpc.composite_channel_credentials( 
    277        ssl_credentials, google_auth_credentials 
    278    ) 
    279 
    280    return grpc.secure_channel(target, composite_credentials, **kwargs) 
    281 
    282 
    283class SslCredentials: 
    284    """Class for application default SSL credentials. 
    285 
    286    The behavior is controlled by `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment 
    287    variable whose default value is `false`. Client certificate will not be used 
    288    unless the environment variable is explicitly set to `true`. See 
    289    https://google.aip.dev/auth/4114 
    290 
    291    If the environment variable is `true`, then for devices with endpoint verification 
    292    support, a device certificate will be automatically loaded and mutual TLS will 
    293    be established. 
    294    See https://cloud.google.com/endpoint-verification/docs/overview. 
    295    """ 
    296 
    297    def __init__(self): 
    298        use_client_cert = os.getenv( 
    299            environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE, "false" 
    300        ) 
    301        if use_client_cert != "true": 
    302            self._is_mtls = False 
    303        else: 
    304            # Load client SSL credentials. 
    305            metadata_path = _mtls_helper._check_config_path( 
    306                _mtls_helper.CONTEXT_AWARE_METADATA_PATH 
    307            ) 
    308            self._is_mtls = metadata_path is not None 
    309 
    310    @property 
    311    def ssl_credentials(self): 
    312        """Get the created SSL channel credentials. 
    313 
    314        For devices with endpoint verification support, if the device certificate 
    315        loading has any problems, corresponding exceptions will be raised. For 
    316        a device without endpoint verification support, no exceptions will be 
    317        raised. 
    318 
    319        Returns: 
    320            grpc.ChannelCredentials: The created grpc channel credentials. 
    321 
    322        Raises: 
    323            google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel 
    324                creation failed for any reason. 
    325        """ 
    326        if self._is_mtls: 
    327            try: 
    328                _, cert, key, _ = _mtls_helper.get_client_ssl_credentials() 
    329                self._ssl_credentials = grpc.ssl_channel_credentials( 
    330                    certificate_chain=cert, private_key=key 
    331                ) 
    332            except exceptions.ClientCertError as caught_exc: 
    333                new_exc = exceptions.MutualTLSChannelError(caught_exc) 
    334                raise new_exc from caught_exc 
    335        else: 
    336            self._ssl_credentials = grpc.ssl_channel_credentials() 
    337 
    338        return self._ssl_credentials 
    339 
    340    @property 
    341    def is_mtls(self): 
    342        """Indicates if the created SSL channel credentials is mutual TLS.""" 
    343        return self._is_mtls