1# -------------------------------------------------------------------------- 
    2# 
    3# Copyright (c) Microsoft Corporation. All rights reserved. 
    4# 
    5# The MIT License (MIT) 
    6# 
    7# Permission is hereby granted, free of charge, to any person obtaining a copy 
# of this software and associated documentation files (the "Software"), to deal
    9# in the Software without restriction, including without limitation the rights 
    10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
    11# copies of the Software, and to permit persons to whom the Software is 
    12# furnished to do so, subject to the following conditions: 
    13# 
    14# The above copyright notice and this permission notice shall be included in 
    15# all copies or substantial portions of the Software. 
    16# 
    17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
    18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
    19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
    20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
    21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
    22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 
    23# THE SOFTWARE. 
    24# 
    25# -------------------------------------------------------------------------- 
    26"""Traces network calls using the implementation library from the settings.""" 
    27import logging 
    28import sys 
    29import urllib.parse 
    30from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union, Any, Type, Mapping, Dict 
    31from types import TracebackType 
    32 
    33from azure.core.pipeline import PipelineRequest, PipelineResponse 
    34from azure.core.pipeline.policies import SansIOHTTPPolicy 
    35from azure.core.pipeline.transport import ( 
    36    HttpResponse as LegacyHttpResponse, 
    37    HttpRequest as LegacyHttpRequest, 
    38) 
    39from azure.core.rest import HttpResponse, HttpRequest 
    40from azure.core.settings import settings 
    41from azure.core.tracing import SpanKind 
    42from azure.core.tracing.common import change_context 
    43from azure.core.instrumentation import get_tracer 
    44from azure.core.tracing._models import TracingOptions 
    45 
    46if TYPE_CHECKING: 
    47    from opentelemetry.trace import Span 
    48 
# Type variables constrained to the two request/response families imported above:
# the `azure.core.rest` types and the legacy `azure.core.pipeline.transport` types.
HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse)
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
# A fully populated (exception type, exception instance, traceback) triple.
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
# Either a populated triple or the all-None triple; `on_exception` below passes the
# result of `sys.exc_info()` to `end_span` under this type.
OptExcInfo = Union[ExcInfo, Tuple[None, None, None]]

# Module-level logger named after this module.
_LOGGER = logging.getLogger(__name__)
    55 
    56 
    57def _default_network_span_namer(http_request: HTTPRequestType) -> str: 
    58    """Extract the path to be used as network span name. 
    59 
    60    :param http_request: The HTTP request 
    61    :type http_request: ~azure.core.pipeline.transport.HttpRequest 
    62    :returns: The string to use as network span name 
    63    :rtype: str 
    64    """ 
    65    return http_request.method 
    66 
    67 
    68class DistributedTracingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): 
    69    """The policy to create spans for Azure calls. 
    70 
    71    :keyword network_span_namer: A callable to customize the span name 
    72    :type network_span_namer: callable[[~azure.core.pipeline.transport.HttpRequest], str] 
    73    :keyword tracing_attributes: Attributes to set on all created spans 
    74    :type tracing_attributes: dict[str, str] 
    75    :keyword instrumentation_config: Configuration for the instrumentation providers 
    76    :type instrumentation_config: dict[str, Any] 
    77    """ 
    78 
    79    TRACING_CONTEXT = "TRACING_CONTEXT" 
    80    _SUPPRESSION_TOKEN = "SUPPRESSION_TOKEN" 
    81 
    82    # Current stable HTTP semantic conventions 
    83    _HTTP_RESEND_COUNT = "http.request.resend_count" 
    84    _USER_AGENT_ORIGINAL = "user_agent.original" 
    85    _HTTP_REQUEST_METHOD = "http.request.method" 
    86    _URL_FULL = "url.full" 
    87    _HTTP_RESPONSE_STATUS_CODE = "http.response.status_code" 
    88    _SERVER_ADDRESS = "server.address" 
    89    _SERVER_PORT = "server.port" 
    90    _ERROR_TYPE = "error.type" 
    91 
    92    # Azure attributes 
    93    _REQUEST_ID = "x-ms-client-request-id" 
    94    _REQUEST_ID_ATTR = "az.client_request_id" 
    95    _RESPONSE_ID = "x-ms-request-id" 
    96    _RESPONSE_ID_ATTR = "az.service_request_id" 
    97 
    98    def __init__(self, *, instrumentation_config: Optional[Mapping[str, Any]] = None, **kwargs: Any): 
    99        self._network_span_namer = kwargs.get("network_span_namer", _default_network_span_namer) 
    100        self._tracing_attributes = kwargs.get("tracing_attributes", {}) 
    101        self._instrumentation_config = instrumentation_config 
    102 
    103    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: 
    104        """Starts a span for the network call. 
    105 
    106        :param request: The PipelineRequest object 
    107        :type request: ~azure.core.pipeline.PipelineRequest 
    108        """ 
    109        ctxt = request.context.options 
    110        try: 
    111            tracing_options: TracingOptions = ctxt.pop("tracing_options", {}) 
    112            tracing_enabled = settings.tracing_enabled() 
    113 
    114            # User can explicitly disable tracing for this request. 
    115            user_enabled = tracing_options.get("enabled") 
    116            if user_enabled is False: 
    117                return 
    118 
    119            # If tracing is disabled globally and user didn't explicitly enable it, don't trace. 
    120            if not tracing_enabled and user_enabled is None: 
    121                return 
    122 
    123            span_impl_type = settings.tracing_implementation() 
    124            namer = ctxt.pop("network_span_namer", self._network_span_namer) 
    125            tracing_attributes = ctxt.pop("tracing_attributes", self._tracing_attributes) 
    126            span_name = namer(request.http_request) 
    127 
    128            span_attributes = {**tracing_attributes, **tracing_options.get("attributes", {})} 
    129 
    130            if span_impl_type: 
    131                # If the plugin is enabled, prioritize it over the core tracing. 
    132                span = span_impl_type(name=span_name, kind=SpanKind.CLIENT) 
    133                for attr, value in span_attributes.items(): 
    134                    span.add_attribute(attr, value)  # type: ignore 
    135 
    136                with change_context(span.span_instance): 
    137                    headers = span.to_header() 
    138                    request.http_request.headers.update(headers) 
    139                request.context[self.TRACING_CONTEXT] = span 
    140            else: 
    141                # Otherwise, use the core tracing. 
    142                config = self._instrumentation_config or {} 
    143                tracer = get_tracer( 
    144                    library_name=config.get("library_name"), 
    145                    library_version=config.get("library_version"), 
    146                    attributes=config.get("attributes"), 
    147                ) 
    148                if not tracer: 
    149                    _LOGGER.warning( 
    150                        "Tracing is enabled, but not able to get an OpenTelemetry tracer. " 
    151                        "Please ensure that `opentelemetry-api` is installed." 
    152                    ) 
    153                    return 
    154 
    155                otel_span = tracer.start_span( 
    156                    name=span_name, 
    157                    kind=SpanKind.CLIENT, 
    158                    attributes=span_attributes, 
    159                ) 
    160 
    161                with tracer.use_span(otel_span, end_on_exit=False): 
    162                    trace_context_headers = tracer.get_trace_context() 
    163                    request.http_request.headers.update(trace_context_headers) 
    164 
    165                request.context[self.TRACING_CONTEXT] = otel_span 
    166                token = tracer._suppress_auto_http_instrumentation()  # pylint: disable=protected-access 
    167                request.context[self._SUPPRESSION_TOKEN] = token 
    168 
    169        except Exception:  # pylint: disable=broad-except 
    170            _LOGGER.warning("Unable to start network span.") 
    171 
    172    def end_span( 
    173        self, 
    174        request: PipelineRequest[HTTPRequestType], 
    175        response: Optional[HTTPResponseType] = None, 
    176        exc_info: Optional[OptExcInfo] = None, 
    177    ) -> None: 
    178        """Ends the span that is tracing the network and updates its status. 
    179 
    180        :param request: The PipelineRequest object 
    181        :type request: ~azure.core.pipeline.PipelineRequest 
    182        :param response: The HttpResponse object 
    183        :type response: ~azure.core.rest.HTTPResponse or ~azure.core.pipeline.transport.HttpResponse 
    184        :param exc_info: The exception information 
    185        :type exc_info: tuple 
    186        """ 
    187        if self.TRACING_CONTEXT not in request.context: 
    188            return 
    189 
    190        span = request.context[self.TRACING_CONTEXT] 
    191        if not span: 
    192            return 
    193 
    194        http_request: Union[HttpRequest, LegacyHttpRequest] = request.http_request 
    195 
    196        attributes: Dict[str, Any] = {} 
    197        if request.context.get("retry_count"): 
    198            attributes[self._HTTP_RESEND_COUNT] = request.context["retry_count"] 
    199        if http_request.headers.get(self._REQUEST_ID): 
    200            attributes[self._REQUEST_ID_ATTR] = http_request.headers[self._REQUEST_ID] 
    201        if response and self._RESPONSE_ID in response.headers: 
    202            attributes[self._RESPONSE_ID_ATTR] = response.headers[self._RESPONSE_ID] 
    203 
    204        # We'll determine if the span is from a plugin or the core tracing library based on the presence of the 
    205        # `set_http_attributes` method. 
    206        if hasattr(span, "set_http_attributes"): 
    207            # Plugin-based tracing 
    208            span.set_http_attributes(request=http_request, response=response) 
    209            for key, value in attributes.items(): 
    210                span.add_attribute(key, value) 
    211            if exc_info: 
    212                span.__exit__(*exc_info) 
    213            else: 
    214                span.finish() 
    215        else: 
    216            # Native tracing 
    217            self._set_http_client_span_attributes(span, request=http_request, response=response) 
    218            span.set_attributes(attributes) 
    219            if exc_info: 
    220                # If there was an exception, set the error.type attribute. 
    221                exception_type = exc_info[0] 
    222                if exception_type: 
    223                    module = exception_type.__module__ if exception_type.__module__ != "builtins" else "" 
    224                    error_type = f"{module}.{exception_type.__qualname__}" if module else exception_type.__qualname__ 
    225                    span.set_attribute(self._ERROR_TYPE, error_type) 
    226 
    227                span.__exit__(*exc_info) 
    228            else: 
    229                span.end() 
    230 
    231        suppression_token = request.context.get(self._SUPPRESSION_TOKEN) 
    232        if suppression_token: 
    233            tracer = get_tracer() 
    234            if tracer: 
    235                tracer._detach_from_context(suppression_token)  # pylint: disable=protected-access 
    236 
    237    def on_response( 
    238        self, 
    239        request: PipelineRequest[HTTPRequestType], 
    240        response: PipelineResponse[HTTPRequestType, HTTPResponseType], 
    241    ) -> None: 
    242        """Ends the span for the network call and updates its status. 
    243 
    244        :param request: The PipelineRequest object 
    245        :type request: ~azure.core.pipeline.PipelineRequest 
    246        :param response: The PipelineResponse object 
    247        :type response: ~azure.core.pipeline.PipelineResponse 
    248        """ 
    249        self.end_span(request, response=response.http_response) 
    250 
    251    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None: 
    252        """Ends the span for the network call and updates its status with exception info. 
    253 
    254        :param request: The PipelineRequest object 
    255        :type request: ~azure.core.pipeline.PipelineRequest 
    256        """ 
    257        self.end_span(request, exc_info=sys.exc_info()) 
    258 
    259    def _set_http_client_span_attributes( 
    260        self, 
    261        span: "Span", 
    262        request: Union[HttpRequest, LegacyHttpRequest], 
    263        response: Optional[HTTPResponseType] = None, 
    264    ) -> None: 
    265        """Add attributes to an HTTP client span. 
    266 
    267        :param span: The span to add attributes to. 
    268        :type span: ~opentelemetry.trace.Span 
    269        :param request: The request made 
    270        :type request: ~azure.core.rest.HttpRequest 
    271        :param response: The response received from the server. Is None if no response received. 
    272        :type response: ~azure.core.rest.HTTPResponse or ~azure.core.pipeline.transport.HttpResponse 
    273        """ 
    274        attributes: Dict[str, Any] = { 
    275            self._HTTP_REQUEST_METHOD: request.method, 
    276            self._URL_FULL: request.url, 
    277        } 
    278 
    279        parsed_url = urllib.parse.urlparse(request.url) 
    280        if parsed_url.hostname: 
    281            attributes[self._SERVER_ADDRESS] = parsed_url.hostname 
    282        if parsed_url.port: 
    283            attributes[self._SERVER_PORT] = parsed_url.port 
    284 
    285        user_agent = request.headers.get("User-Agent") 
    286        if user_agent: 
    287            attributes[self._USER_AGENT_ORIGINAL] = user_agent 
    288        if response and response.status_code: 
    289            attributes[self._HTTP_RESPONSE_STATUS_CODE] = response.status_code 
    290            if response.status_code >= 400: 
    291                attributes[self._ERROR_TYPE] = str(response.status_code) 
    292 
    293        span.set_attributes(attributes)