1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23# THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26"""Traces network calls using the implementation library from the settings."""
27import logging
28import sys
29import urllib.parse
30from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union, Any, Type, Mapping, Dict, Iterable
31from types import TracebackType
32
33from azure.core.pipeline import PipelineRequest, PipelineResponse
34from azure.core.pipeline.policies import SansIOHTTPPolicy
35from azure.core.pipeline.transport import (
36 HttpResponse as LegacyHttpResponse,
37 HttpRequest as LegacyHttpRequest,
38)
39from azure.core.rest import HttpResponse, HttpRequest
40from azure.core.settings import settings
41from azure.core.tracing import SpanKind
42from azure.core.tracing.common import change_context
43from azure.core.instrumentation import get_tracer
44from azure.core.tracing._models import TracingOptions
45from azure.core.pipeline.policies._utils import sanitize_url
46from azure.core.utils._utils import CaseInsensitiveSet
47
48if TYPE_CHECKING:
49 from opentelemetry.trace import Span
50
# Type variables constrained to the two supported families of HTTP objects:
# the ``azure.core.rest`` types and the legacy ``azure.core.pipeline.transport`` types.
HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse)
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
# A populated ``(exception type, exception instance, traceback)`` triple.
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
# Either a populated exception triple or the all-``None`` triple.
OptExcInfo = Union[ExcInfo, Tuple[None, None, None]]

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
57
58
def _default_network_span_namer(http_request: HTTPRequestType) -> str:
    """Produce the default name for a network span.

    The default name is the HTTP method of the request.

    :param http_request: The HTTP request being traced
    :type http_request: ~azure.core.pipeline.transport.HttpRequest
    :returns: The request's HTTP method, used as the network span name
    :rtype: str
    """
    span_name = http_request.method
    return span_name
68
69
class DistributedTracingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """The policy to create spans for Azure calls.

    A client span is started in :meth:`on_request` and ended in :meth:`on_response` or
    :meth:`on_exception`. If a tracing plugin is configured in the settings it is used;
    otherwise the core tracer returned by ``get_tracer`` is used.

    :keyword network_span_namer: A callable to customize the span name
    :type network_span_namer: callable[[~azure.core.pipeline.transport.HttpRequest], str]
    :keyword tracing_attributes: Attributes to set on all created spans
    :type tracing_attributes: dict[str, str]
    :keyword instrumentation_config: Configuration for the instrumentation providers
    :type instrumentation_config: dict[str, Any]
    :keyword additional_allowed_query_params: Query parameter names whose values are allowed in recorded URLs.
        These are added to the default set which includes "api-version".
    :type additional_allowed_query_params: Iterable[str]
    """

    # Keys used to stash per-request tracing state in the pipeline request context.
    TRACING_CONTEXT = "TRACING_CONTEXT"
    _SUPPRESSION_TOKEN = "SUPPRESSION_TOKEN"

    # Query parameters whose values are kept when URLs are sanitized; all others are
    # replaced with the redaction placeholder.
    DEFAULT_QUERY_PARAMS_ALLOWLIST: set[str] = {"api-version"}
    _REDACTED_PLACEHOLDER = "REDACTED"

    # Current stable HTTP semantic conventions
    _HTTP_RESEND_COUNT = "http.request.resend_count"
    _USER_AGENT_ORIGINAL = "user_agent.original"
    _HTTP_REQUEST_METHOD = "http.request.method"
    _URL_FULL = "url.full"
    _HTTP_RESPONSE_STATUS_CODE = "http.response.status_code"
    _SERVER_ADDRESS = "server.address"
    _SERVER_PORT = "server.port"
    _ERROR_TYPE = "error.type"

    # Azure attributes
    _REQUEST_ID = "x-ms-client-request-id"
    _REQUEST_ID_ATTR = "az.client_request_id"
    _RESPONSE_ID = "x-ms-request-id"
    _RESPONSE_ID_ATTR = "az.service_request_id"

    def __init__(
        self,
        *,
        instrumentation_config: Optional[Mapping[str, Any]] = None,
        additional_allowed_query_params: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ):
        self._network_span_namer = kwargs.get("network_span_namer", _default_network_span_namer)
        self._tracing_attributes = kwargs.get("tracing_attributes", {})
        self._instrumentation_config = instrumentation_config
        # Start from a per-instance copy of the class-level allowlist, then extend it.
        self.allowed_query_params: set[str] = CaseInsensitiveSet(self.__class__.DEFAULT_QUERY_PARAMS_ALLOWLIST)
        if additional_allowed_query_params:
            self.allowed_query_params.update(additional_allowed_query_params)

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Starts a span for the network call.

        Consumes the ``tracing_options``, ``network_span_namer`` and ``tracing_attributes``
        request options, adds the trace propagation headers to the outgoing request and
        stores the created span in the request context under ``TRACING_CONTEXT``.

        Any exception raised while starting the span is logged as a warning and
        suppressed, so a tracing failure never fails the request itself.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        ctxt = request.context.options
        try:
            tracing_options: TracingOptions = ctxt.pop("tracing_options", {})
            tracing_enabled = settings.tracing_enabled()

            # User can explicitly disable tracing for this request.
            user_enabled = tracing_options.get("enabled")
            if user_enabled is False:
                return

            # If tracing is disabled globally and user didn't explicitly enable it, don't trace.
            if not tracing_enabled and user_enabled is None:
                return

            span_impl_type = settings.tracing_implementation()
            namer = ctxt.pop("network_span_namer", self._network_span_namer)
            tracing_attributes = ctxt.pop("tracing_attributes", self._tracing_attributes)
            span_name = namer(request.http_request)

            # Per-request attributes from ``tracing_options`` override the policy-level ones.
            span_attributes = {**tracing_attributes, **tracing_options.get("attributes", {})}

            if span_impl_type:
                # If the plugin is enabled, prioritize it over the core tracing.
                span = span_impl_type(name=span_name, kind=SpanKind.CLIENT)
                for attr, value in span_attributes.items():
                    span.add_attribute(attr, value)  # type: ignore

                with change_context(span.span_instance):
                    headers = span.to_header()
                    request.http_request.headers.update(headers)
                request.context[self.TRACING_CONTEXT] = span
            else:
                # Otherwise, use the core tracing.
                config = self._instrumentation_config or {}
                tracer = get_tracer(
                    library_name=config.get("library_name"),
                    library_version=config.get("library_version"),
                    attributes=config.get("attributes"),
                )
                if not tracer:
                    _LOGGER.warning(
                        "Tracing is enabled, but not able to get an OpenTelemetry tracer. "
                        "Please ensure that `opentelemetry-api` is installed."
                    )
                    return

                otel_span = tracer.start_span(
                    name=span_name,
                    kind=SpanKind.CLIENT,
                    attributes=span_attributes,
                )

                # The span is ended later in ``end_span``, so it must not be ended on exit here.
                with tracer.use_span(otel_span, end_on_exit=False):
                    trace_context_headers = tracer.get_trace_context()
                    request.http_request.headers.update(trace_context_headers)

                request.context[self.TRACING_CONTEXT] = otel_span
                # Keep the token so ``end_span`` can detach the suppression again.
                token = tracer._suppress_auto_http_instrumentation()  # pylint: disable=protected-access
                request.context[self._SUPPRESSION_TOKEN] = token

        except Exception as err:  # pylint: disable=broad-except
            # Tracing is best-effort: never fail the request, but report why the span
            # could not be started so the problem can be diagnosed.
            _LOGGER.warning("Unable to start network span: %s", err)

    def end_span(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: Optional[HTTPResponseType] = None,
        exc_info: Optional[OptExcInfo] = None,
    ) -> None:
        """Ends the span that is tracing the network and updates its status.

        Does nothing if no span was stored in the request context by ``on_request``.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: The HttpResponse object
        :type response: ~azure.core.rest.HTTPResponse or ~azure.core.pipeline.transport.HttpResponse
        :param exc_info: The exception information
        :type exc_info: tuple
        """
        if self.TRACING_CONTEXT not in request.context:
            return

        span = request.context[self.TRACING_CONTEXT]
        if not span:
            return

        http_request: Union[HttpRequest, LegacyHttpRequest] = request.http_request

        # Attributes common to both the plugin and the native code paths.
        attributes: Dict[str, Any] = {}
        if request.context.get("retry_count"):
            attributes[self._HTTP_RESEND_COUNT] = request.context["retry_count"]
        if http_request.headers.get(self._REQUEST_ID):
            attributes[self._REQUEST_ID_ATTR] = http_request.headers[self._REQUEST_ID]
        if response and self._RESPONSE_ID in response.headers:
            attributes[self._RESPONSE_ID_ATTR] = response.headers[self._RESPONSE_ID]

        # We'll determine if the span is from a plugin or the core tracing library based on the presence of the
        # `set_http_attributes` method.
        if hasattr(span, "set_http_attributes"):
            # Plugin-based tracing
            span.set_http_attributes(request=http_request, response=response)
            span.add_attribute(
                "http.url", sanitize_url(http_request.url, self.allowed_query_params, self._REDACTED_PLACEHOLDER)
            )
            for key, value in attributes.items():
                span.add_attribute(key, value)
            if exc_info:
                span.__exit__(*exc_info)
            else:
                span.finish()
        else:
            # Native tracing
            try:
                self._set_http_client_span_attributes(span, request=http_request, response=response)
                span.set_attributes(attributes)
                if exc_info:
                    # If there was an exception, set the error.type attribute.
                    exception_type = exc_info[0]
                    if exception_type:
                        # Builtin exceptions are reported without a module prefix.
                        module = exception_type.__module__ if exception_type.__module__ != "builtins" else ""
                        error_type = (
                            f"{module}.{exception_type.__qualname__}" if module else exception_type.__qualname__
                        )
                        span.set_attribute(self._ERROR_TYPE, error_type)

                    span.__exit__(*exc_info)
                else:
                    span.end()
            finally:
                # Always undo the suppression installed in ``on_request``, even if
                # finalizing the span failed, so it does not outlive this request.
                suppression_token = request.context.get(self._SUPPRESSION_TOKEN)
                if suppression_token:
                    tracer = get_tracer()
                    if tracer:
                        tracer._detach_from_context(suppression_token)  # pylint: disable=protected-access

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> None:
        """Ends the span for the network call and updates its status.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: The PipelineResponse object
        :type response: ~azure.core.pipeline.PipelineResponse
        """
        self.end_span(request, response=response.http_response)

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Ends the span for the network call and updates its status with exception info.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        self.end_span(request, exc_info=sys.exc_info())

    def _set_http_client_span_attributes(
        self,
        span: "Span",
        request: Union[HttpRequest, LegacyHttpRequest],
        response: Optional[HTTPResponseType] = None,
    ) -> None:
        """Add attributes to an HTTP client span.

        The recorded URL is sanitized so that only allowlisted query parameter values
        are kept. A status code of 400 or above is also recorded as the error type.

        :param span: The span to add attributes to.
        :type span: ~opentelemetry.trace.Span
        :param request: The request made
        :type request: ~azure.core.rest.HttpRequest
        :param response: The response received from the server. Is None if no response received.
        :type response: ~azure.core.rest.HTTPResponse or ~azure.core.pipeline.transport.HttpResponse
        """
        attributes: Dict[str, Any] = {
            self._HTTP_REQUEST_METHOD: request.method,
            self._URL_FULL: sanitize_url(request.url, self.allowed_query_params, self._REDACTED_PLACEHOLDER),
        }

        parsed_url = urllib.parse.urlparse(request.url)
        if parsed_url.hostname:
            attributes[self._SERVER_ADDRESS] = parsed_url.hostname
        try:
            port = parsed_url.port
        except ValueError:
            # ``port`` raises ValueError for a non-numeric or out-of-range port. Skip the
            # attribute rather than let tracing raise out of the response/exception hooks.
            port = None
        if port:
            attributes[self._SERVER_PORT] = port

        user_agent = request.headers.get("User-Agent")
        if user_agent:
            attributes[self._USER_AGENT_ORIGINAL] = user_agent
        if response and response.status_code:
            attributes[self._HTTP_RESPONSE_STATUS_CODE] = response.status_code
            if response.status_code >= 400:
                attributes[self._ERROR_TYPE] = str(response.status_code)

        span.set_attributes(attributes)