1# -------------------------------------------------------------------------- 
    2# 
    3# Copyright (c) Microsoft Corporation. All rights reserved. 
    4# 
    5# The MIT License (MIT) 
    6# 
    7# Permission is hereby granted, free of charge, to any person obtaining a copy 
    8# of this software and associated documentation files (the ""Software""), to 
    9# deal in the Software without restriction, including without limitation the 
    10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 
    11# sell copies of the Software, and to permit persons to whom the Software is 
    12# furnished to do so, subject to the following conditions: 
    13# 
    14# The above copyright notice and this permission notice shall be included in 
    15# all copies or substantial portions of the Software. 
    16# 
    17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
    18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
    19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
    20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
    21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 
    22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 
    23# IN THE SOFTWARE. 
    24# 
    25# -------------------------------------------------------------------------- 
    26from typing import Awaitable, Optional, List, Union, Any 
    27import inspect 
    28 
    29from azure.core.pipeline.policies import ( 
    30    AsyncBearerTokenCredentialPolicy, 
    31    AsyncHTTPPolicy, 
    32) 
    33from azure.core.pipeline import PipelineRequest, PipelineResponse 
    34from azure.core.pipeline.transport import ( 
    35    HttpRequest as LegacyHttpRequest, 
    36    AsyncHttpResponse as LegacyAsyncHttpResponse, 
    37) 
    38from azure.core.rest import HttpRequest, AsyncHttpResponse 
    39from azure.core.credentials import AccessToken 
    40from azure.core.credentials_async import AsyncTokenCredential 
    41 
    42 
    43from ._authentication import _AuxiliaryAuthenticationPolicyBase 
    44 
    45 
    46HTTPRequestType = Union[LegacyHttpRequest, HttpRequest] 
    47AsyncHTTPResponseType = Union[LegacyAsyncHttpResponse, AsyncHttpResponse] 
    48 
    49 
    50async def await_result(func, *args, **kwargs): 
    51    """If func returns an awaitable, await it. 
    52 
    53    :param callable func: Function to call 
    54    :param any args: Positional arguments to pass to func 
    55    :return: Result of func 
    56    :rtype: any 
    57    """ 
    58    result = func(*args, **kwargs) 
    59    if inspect.isawaitable(result): 
    60        return await result 
    61    return result 
    62 
    63 
class AsyncARMChallengeAuthenticationPolicy(AsyncBearerTokenCredentialPolicy):
    """Adds a bearer token Authorization header to requests.

    This policy internally handles Continuous Access Evaluation (CAE) challenges. When it can't complete a challenge,
    it will return the 401 (unauthorized) response from ARM.

    This subclass defines no overrides: all behavior is inherited from ``AsyncBearerTokenCredentialPolicy``.
    NOTE(review): presumably kept as a distinct, ARM-named public type. Confirm before adding behavior here.
    """
    70 
    71 
    72class AsyncAuxiliaryAuthenticationPolicy( 
    73    _AuxiliaryAuthenticationPolicyBase[AsyncTokenCredential], 
    74    AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType], 
    75): 
    76    async def _get_auxiliary_tokens(self, *scopes: str, **kwargs: Any) -> Optional[List[AccessToken]]: 
    77        if self._auxiliary_credentials: 
    78            return [await cred.get_token(*scopes, **kwargs) for cred in self._auxiliary_credentials] 
    79        return None 
    80 
    81    async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: 
    82        """Called before the policy sends a request. 
    83 
    84        The base implementation authorizes the request with an auxiliary authorization token. 
    85 
    86        :param ~azure.core.pipeline.PipelineRequest request: the request 
    87        """ 
    88        self._enforce_https(request) 
    89 
    90        if self._need_new_aux_tokens: 
    91            self._aux_tokens = await self._get_auxiliary_tokens(*self._scopes) 
    92 
    93        self._update_headers(request.http_request.headers) 
    94 
    95    def on_response( 
    96        self, 
    97        request: PipelineRequest[HTTPRequestType], 
    98        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType], 
    99    ) -> Optional[Awaitable[None]]: 
    100        """Executed after the request comes back from the next policy. 
    101 
    102        :param request: Request to be modified after returning from the policy. 
    103        :type request: ~azure.core.pipeline.PipelineRequest 
    104        :param response: Pipeline response object 
    105        :type response: ~azure.core.pipeline.PipelineResponse 
    106        """ 
    107 
    108    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None: 
    109        """Executed when an exception is raised while executing the next policy. 
    110 
    111        This method is executed inside the exception handler. 
    112 
    113        :param request: The Pipeline request object 
    114        :type request: ~azure.core.pipeline.PipelineRequest 
    115        """ 
    116        # pylint: disable=unused-argument 
    117        return 
    118 
    119    async def send( 
    120        self, request: PipelineRequest[HTTPRequestType] 
    121    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: 
    122        """Authorize request with a bearer token and send it to the next policy 
    123 
    124        :param request: The pipeline request object 
    125        :type request: ~azure.core.pipeline.PipelineRequest 
    126        :return: The pipeline response object 
    127        :rtype: ~azure.core.pipeline.PipelineResponse 
    128        """ 
    129        await await_result(self.on_request, request) 
    130        try: 
    131            response = await self.next.send(request) 
    132            await await_result(self.on_response, request, response) 
    133        except Exception:  # pylint:disable=broad-except 
    134            handled = await await_result(self.on_exception, request) 
    135            if not handled: 
    136                raise 
    137        return response