1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
9# deal in the Software without restriction, including without limitation the
10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
11# sell copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23# IN THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26from __future__ import annotations
27import logging
28from typing import (
29 Generic,
30 TypeVar,
31 Union,
32 Any,
33 List,
34 Dict,
35 Optional,
36 Iterable,
37 ContextManager,
38)
39from azure.core.pipeline import (
40 PipelineRequest,
41 PipelineResponse,
42 PipelineContext,
43)
44from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
45from ._tools import await_result as _await_result
46from .transport import HttpTransport
47
48HTTPResponseType = TypeVar("HTTPResponseType")
49HTTPRequestType = TypeVar("HTTPRequestType")
50
51_LOGGER = logging.getLogger(__name__)
52
53
def cleanup_kwargs_for_transport(kwargs: Dict[str, str]) -> None:
    """Strip pipeline-only options from ``kwargs`` before they reach the transport.

    The dictionary is modified in place; nothing is returned. Three keys are
    discarded when present:

    * ``"insecure_domain_change"`` indicates that a redirect has occurred to a
      different domain, which tells the SensitiveHeaderCleanupPolicy to clean up
      sensitive headers. It is dropped here to cover the case where that policy
      is not part of the pipeline and the key was therefore never popped.
    * ``"enable_cae"`` is added to the ``get_token`` method of the
      ``TokenCredential`` protocol.
    * ``"tracing_options"`` is used in the DistributedTracingPolicy and tracing
      decorators.

    :param kwargs: The keyword arguments. A falsy value (empty or ``None``) is
        left untouched.
    :type kwargs: dict
    """
    if kwargs:
        for pipeline_only_key in ("insecure_domain_change", "enable_cae", "tracing_options"):
            # A default of None makes the pop a no-op when the key is absent.
            kwargs.pop(pipeline_only_key, None)
73
74
class _SansIOHTTPPolicyRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Synchronous adapter that lets a SansIO policy take part in the policy chain.

    The wrapped policy's hooks are invoked around the call to the next policy
    in the chain.

    :param policy: A SansIO policy.
    :type policy: ~azure.core.pipeline.policies.SansIOHTTPPolicy
    """

    def __init__(self, policy: SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]) -> None:
        super().__init__()
        self._policy = policy

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Run the policy's ``on_request`` hook, forward the request, then run ``on_response``.

        If the next policy raises, the policy's ``on_exception`` hook is invoked
        and the same exception is re-raised; ``on_response`` is not called in
        that case.

        :param request: The PipelineRequest object.
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The PipelineResponse object.
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        _await_result(self._policy.on_request, request)
        try:
            response = self.next.send(request)
        except Exception:
            # Let the policy observe the failure, then propagate it unchanged.
            _await_result(self._policy.on_exception, request)
            raise
        else:
            # Kept outside the ``try`` body so an error raised by on_response
            # does not trigger on_exception.
            _await_result(self._policy.on_response, request, response)
            return response
104
105
class _TransportRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Terminal node of the policy chain.

    Hands the HTTP request to the configured transport and wraps what the
    transport returns in a PipelineResponse.

    :param sender: The Http Transport instance.
    :type sender: ~azure.core.pipeline.transport.HttpTransport
    """

    def __init__(self, sender: HttpTransport[HTTPRequestType, HTTPResponseType]) -> None:
        super().__init__()
        self._sender = sender

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Send the request through the HTTP transport.

        The context options are first cleaned of pipeline-only keys (in place)
        and then forwarded to the transport as keyword arguments.

        :param request: The PipelineRequest object.
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The PipelineResponse object.
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        # Must happen before the options are expanded into the transport call.
        cleanup_kwargs_for_transport(request.context.options)
        http_request = request.http_request
        transport_response = self._sender.send(http_request, **request.context.options)
        return PipelineResponse(http_request, transport_response, context=request.context)
133
134
class Pipeline(ContextManager["Pipeline"], Generic[HTTPRequestType, HTTPResponseType]):
    """Chain of policies that ends with an HTTP transport.

    The pipeline is a context manager: entering and exiting it enters and exits
    the context of the HTTP sender. The transport is the last node in the
    pipeline, placed behind every configured policy.

    :param transport: The Http Transport instance
    :type transport: ~azure.core.pipeline.transport.HttpTransport
    :param list policies: List of configured policies. SansIO policies are
        wrapped in a runner so they can be chained; other falsy entries are
        ignored.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_sync.py
            :start-after: [START build_pipeline]
            :end-before: [END build_pipeline]
            :language: python
            :dedent: 4
            :caption: Builds the pipeline for synchronous transport.
    """

    def __init__(
        self,
        transport: HttpTransport[HTTPRequestType, HTTPResponseType],
        policies: Optional[
            Iterable[
                Union[
                    HTTPPolicy[HTTPRequestType, HTTPResponseType],
                    SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType],
                ]
            ]
        ] = None,
    ) -> None:
        self._impl_policies: List[HTTPPolicy[HTTPRequestType, HTTPResponseType]] = []
        self._transport = transport

        # SansIO policies get wrapped in a runner; any other truthy entry is
        # stored as-is, and falsy entries (e.g. None) are dropped.
        for candidate in policies or []:
            if isinstance(candidate, SansIOHTTPPolicy):
                self._impl_policies.append(_SansIOHTTPPolicyRunner(candidate))
            elif candidate:
                self._impl_policies.append(candidate)

        # Link each node to its successor ...
        for current_node, following_node in zip(self._impl_policies, self._impl_policies[1:]):
            current_node.next = following_node
        # ... and terminate the chain with the transport.
        if self._impl_policies:
            self._impl_policies[-1].next = _TransportRunner(self._transport)

    def __enter__(self) -> Pipeline[HTTPRequestType, HTTPResponseType]:
        # Entering the pipeline enters the transport's own context.
        self._transport.__enter__()
        return self

    def __exit__(self, *exc_details: Any) -> None:
        # Exception details are forwarded verbatim to the transport.
        self._transport.__exit__(*exc_details)

    @staticmethod
    def _prepare_multipart_mixed_request(request: HTTPRequestType) -> None:
        """Run the multipart policies on every sub-request of ``request``.

        Does nothing if "set_multipart_mixed" was never called, i.e. when the
        request's ``multipart_mixed_info`` is falsy.

        :param request: The request object.
        :type request: ~azure.core.rest.HttpRequest
        """
        mixed_info = request.multipart_mixed_info  # type: ignore
        if not mixed_info:
            return

        # Only entries 0, 1 and 3 of the info sequence are used here.
        sub_requests: List[HTTPRequestType] = mixed_info[0]
        sansio_policies: List[SansIOHTTPPolicy] = mixed_info[1]
        context_options: Dict[str, Any] = mixed_info[3]

        # Apply on_requests concurrently to all requests
        import concurrent.futures

        def apply_policies(sub_request):
            if sub_request.multipart_mixed_info:
                # Recursively update changeset "sub requests"
                Pipeline._prepare_multipart_mixed_request(sub_request)
            # Each sub-request gets its own context, created without a transport.
            sub_context = PipelineContext(None, **context_options)
            sub_pipeline_request = PipelineRequest(sub_request, sub_context)
            for sansio_policy in sansio_policies:
                _await_result(sansio_policy.on_request, sub_pipeline_request)

        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Drain the result iterator so that an exception raised in a worker
            # is re-raised here.
            for _ in executor.map(apply_policies, sub_requests):
                pass

    def _prepare_multipart(self, request: HTTPRequestType) -> None:
        # This code is fine as long as HTTPRequestType is actually
        # azure.core.pipeline.transport.HTTPRequest, but we don't check it in here
        # since we didn't see (yet) pipeline usage where it's not this actual instance
        # class used
        self._prepare_multipart_mixed_request(request)
        request.prepare_multipart_body()  # type: ignore

    def run(self, request: HTTPRequestType, **kwargs: Any) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Runs the HTTP Request through the chained policies.

        The keyword arguments become the options of the PipelineContext created
        for this call. When the pipeline has no policies, the request goes
        straight to a transport runner.

        :param request: The HTTP request object.
        :type request: ~azure.core.pipeline.transport.HttpRequest
        :return: The PipelineResponse object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        self._prepare_multipart(request)
        context = PipelineContext(self._transport, **kwargs)
        pipeline_request: PipelineRequest[HTTPRequestType] = PipelineRequest(request, context)
        if self._impl_policies:
            entry_node = self._impl_policies[0]
        else:
            entry_node = _TransportRunner(self._transport)
        return entry_node.send(pipeline_request)