1# -------------------------------------------------------------------------- 
    2# 
    3# Copyright (c) Microsoft Corporation. All rights reserved. 
    4# 
    5# The MIT License (MIT) 
    6# 
    7# Permission is hereby granted, free of charge, to any person obtaining a copy 
    8# of this software and associated documentation files (the ""Software""), to 
    9# deal in the Software without restriction, including without limitation the 
    10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 
    11# sell copies of the Software, and to permit persons to whom the Software is 
    12# furnished to do so, subject to the following conditions: 
    13# 
    14# The above copyright notice and this permission notice shall be included in 
    15# all copies or substantial portions of the Software. 
    16# 
    17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
    18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
    19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
    20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
    21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 
    22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 
    23# IN THE SOFTWARE. 
    24# 
    25# -------------------------------------------------------------------------- 
    26import time 
    27from typing import Optional, Union, MutableMapping, List, Any, Sequence, TypeVar, Generic 
    28 
    29from azure.core.credentials import AccessToken, TokenCredential 
    30from azure.core.credentials_async import AsyncTokenCredential 
    31from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy 
    32from azure.core.pipeline import PipelineRequest 
    33from azure.core.exceptions import ServiceRequestError 
    34from azure.core.pipeline.transport import ( 
    35    HttpRequest as LegacyHttpRequest, 
    36    HttpResponse as LegacyHttpResponse, 
    37) 
    38from azure.core.rest import HttpRequest, HttpResponse 
    39 
    40 
# Request type accepted by the policies in this module: either the legacy
# azure.core.pipeline.transport request or the azure.core.rest request.
HTTPRequestType = Union[LegacyHttpRequest, HttpRequest]
# Response counterpart of HTTPRequestType (legacy transport or azure.core.rest response).
HTTPResponseType = Union[LegacyHttpResponse, HttpResponse]
# Type parameter for the credential held by the auxiliary policy base class;
# bound to either the synchronous or the asynchronous token credential protocol.
TokenCredentialType = TypeVar("TokenCredentialType", bound=Union[TokenCredential, AsyncTokenCredential])
    44 
    45 
    46class ARMChallengeAuthenticationPolicy(BearerTokenCredentialPolicy): 
    47    """Adds a bearer token Authorization header to requests. 
    48 
    49    This policy internally handles Continuous Access Evaluation (CAE) challenges. When it can't complete a challenge, 
    50    it will return the 401 (unauthorized) response from ARM. 
    51    """ 
    52 
    53 
    54# pylint:disable=too-few-public-methods 
    55class _AuxiliaryAuthenticationPolicyBase(Generic[TokenCredentialType]): 
    56    """Adds auxiliary authorization token header to requests. 
    57 
    58    :param ~azure.core.credentials.TokenCredential auxiliary_credentials: auxiliary credential for authorizing requests 
    59    :param str scopes: required authentication scopes 
    60    """ 
    61 
    62    def __init__(  # pylint: disable=unused-argument 
    63        self, auxiliary_credentials: Sequence[TokenCredentialType], *scopes: str, **kwargs: Any 
    64    ) -> None: 
    65        self._auxiliary_credentials = auxiliary_credentials 
    66        self._scopes = scopes 
    67        self._aux_tokens: Optional[List[AccessToken]] = None 
    68 
    69    @staticmethod 
    70    def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None: 
    71        # move 'enforce_https' from options to context, so it persists 
    72        # across retries but isn't passed to transport implementation 
    73        option = request.context.options.pop("enforce_https", None) 
    74 
    75        # True is the default setting; we needn't preserve an explicit opt in to the default behavior 
    76        if option is False: 
    77            request.context["enforce_https"] = option 
    78 
    79        enforce_https = request.context.get("enforce_https", True) 
    80        if enforce_https and not request.http_request.url.lower().startswith("https"): 
    81            raise ServiceRequestError( 
    82                "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs." 
    83            ) 
    84 
    85    def _update_headers(self, headers: MutableMapping[str, str]) -> None: 
    86        """Updates the x-ms-authorization-auxiliary header with the auxiliary token. 
    87 
    88        :param dict headers: The HTTP Request headers 
    89        """ 
    90        if self._aux_tokens: 
    91            headers["x-ms-authorization-auxiliary"] = ", ".join( 
    92                "Bearer {}".format(token.token) for token in self._aux_tokens 
    93            ) 
    94 
    95    @property 
    96    def _need_new_aux_tokens(self) -> bool: 
    97        if not self._aux_tokens: 
    98            return True 
    99        for token in self._aux_tokens: 
    100            if token.expires_on - time.time() < 300: 
    101                return True 
    102        return False 
    103 
    104 
    105class AuxiliaryAuthenticationPolicy( 
    106    _AuxiliaryAuthenticationPolicyBase[TokenCredential], 
    107    SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType], 
    108): 
    109    def _get_auxiliary_tokens(self, *scopes: str, **kwargs: Any) -> Optional[List[AccessToken]]: 
    110        if self._auxiliary_credentials: 
    111            return [cred.get_token(*scopes, **kwargs) for cred in self._auxiliary_credentials] 
    112        return None 
    113 
    114    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: 
    115        """Called before the policy sends a request. 
    116 
    117        The base implementation authorizes the request with an auxiliary authorization token. 
    118 
    119        :param ~azure.core.pipeline.PipelineRequest request: the request 
    120        """ 
    121        self._enforce_https(request) 
    122 
    123        if self._need_new_aux_tokens: 
    124            self._aux_tokens = self._get_auxiliary_tokens(*self._scopes) 
    125 
    126        self._update_headers(request.http_request.headers)