1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the ""Software""), to
9# deal in the Software without restriction, including without limitation the
10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
11# sell copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23# IN THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26from __future__ import annotations
27import logging
28from typing import Generic, TypeVar, Union, Any, List, Dict, Optional, Iterable, ContextManager
29from azure.core.pipeline import (
30 PipelineRequest,
31 PipelineResponse,
32 PipelineContext,
33)
34from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
35from ._tools import await_result as _await_result
36from .transport import HttpTransport
37
38HTTPResponseType = TypeVar("HTTPResponseType")
39HTTPRequestType = TypeVar("HTTPRequestType")
40
41_LOGGER = logging.getLogger(__name__)
42
43
def cleanup_kwargs_for_transport(kwargs: Optional[Dict[str, Any]]) -> None:
    """Remove kwargs that are not meant for the transport layer.

    The dictionary is modified in place; nothing is returned. A falsy
    argument (``None`` or an empty dict) is accepted and left untouched.

    "insecure_domain_change" is used to indicate that a redirect
    has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy
    to clean up sensitive headers. We need to remove it before sending the request
    to the transport layer. This code is needed to handle the case that the
    SensitiveHeaderCleanupPolicy is not added into the pipeline and "insecure_domain_change" is not popped.

    "enable_cae" is a keyword argument of the `get_token` method of the
    `TokenCredential` protocol, so it must not reach the transport either.

    :param kwargs: The keyword arguments (pipeline options) to clean up in place.
    :type kwargs: dict or None
    """
    if not kwargs:
        return
    # Keys that are consumed by policies/credentials, never by the transport.
    for key in ("insecure_domain_change", "enable_cae"):
        kwargs.pop(key, None)
61
62
class _SansIOHTTPPolicyRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adapter that lets a SansIO policy take part in the synchronous policy chain.

    The wrapped policy's ``on_request`` hook runs before the request is forwarded
    to the next policy. ``on_response`` runs once a response comes back, and
    ``on_exception`` runs if forwarding raises. In that case the original
    exception is re-raised afterwards.

    :param policy: A SansIO policy.
    :type policy: ~azure.core.pipeline.policies.SansIOHTTPPolicy
    """

    def __init__(self, policy: SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]) -> None:
        super().__init__()
        self._policy = policy

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Run the SansIO hooks around forwarding the request to the next policy.

        :param request: The PipelineRequest object.
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The PipelineResponse object.
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        _await_result(self._policy.on_request, request)
        try:
            pipeline_response = self.next.send(request)
        except Exception:  # pylint: disable=broad-except
            # Let the wrapped policy observe the failure, then propagate it unchanged.
            _await_result(self._policy.on_exception, request)
            raise
        else:
            # Only reached when forwarding succeeded; errors raised here are not
            # routed through on_exception.
            _await_result(self._policy.on_response, request, pipeline_response)
            return pipeline_response
92
93
class _TransportRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Final node of the policy chain: hands the request to the HTTP transport.

    The transport's raw response is wrapped in a PipelineResponse that shares
    the incoming request's context.

    :param sender: The Http Transport instance.
    :type sender: ~azure.core.pipeline.transport.HttpTransport
    """

    def __init__(self, sender: HttpTransport[HTTPRequestType, HTTPResponseType]) -> None:
        super().__init__()
        self._sender = sender

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Send the request through the transport and wrap the result.

        :param request: The PipelineRequest object.
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The PipelineResponse object.
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        # Strip pipeline-only keys (in place) before the options become transport kwargs.
        cleanup_kwargs_for_transport(request.context.options)
        http_request = request.http_request
        transport_response = self._sender.send(http_request, **request.context.options)
        return PipelineResponse(http_request, transport_response, context=request.context)
121
122
class Pipeline(ContextManager["Pipeline"], Generic[HTTPRequestType, HTTPResponseType]):
    """A pipeline implementation.

    This is implemented as a context manager, that will activate the context
    of the HTTP sender. The transport is the last node in the pipeline.

    :param transport: The Http Transport instance
    :type transport: ~azure.core.pipeline.transport.HttpTransport
    :param list policies: List of configured policies.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_sync.py
            :start-after: [START build_pipeline]
            :end-before: [END build_pipeline]
            :language: python
            :dedent: 4
            :caption: Builds the pipeline for synchronous transport.
    """

    def __init__(
        self,
        transport: HttpTransport[HTTPRequestType, HTTPResponseType],
        policies: Optional[
            Iterable[
                Union[
                    HTTPPolicy[HTTPRequestType, HTTPResponseType], SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]
                ]
            ]
        ] = None,
    ) -> None:
        self._impl_policies: List[HTTPPolicy[HTTPRequestType, HTTPResponseType]] = []
        self._transport = transport

        for policy in policies or []:
            if isinstance(policy, SansIOHTTPPolicy):
                # SansIO policies only expose hooks; wrap them so they can forward the request.
                self._impl_policies.append(_SansIOHTTPPolicyRunner(policy))
            elif policy:
                # Falsy entries (e.g. None) are skipped.
                self._impl_policies.append(policy)

        # Chain each policy to its successor; the last one forwards to the transport.
        for current, successor in zip(self._impl_policies, self._impl_policies[1:]):
            current.next = successor
        if self._impl_policies:
            self._impl_policies[-1].next = _TransportRunner(self._transport)

    def __enter__(self) -> Pipeline[HTTPRequestType, HTTPResponseType]:
        self._transport.__enter__()
        return self

    def __exit__(self, *exc_details: Any) -> None:  # pylint: disable=arguments-differ
        self._transport.__exit__(*exc_details)

    @staticmethod
    def _prepare_multipart_mixed_request(request: HTTPRequestType) -> None:
        """Will execute the multipart policies.

        Does nothing if "set_multipart_mixed" was never called.

        :param request: The request object.
        :type request: ~azure.core.rest.HttpRequest
        """
        multipart_mixed_info = request.multipart_mixed_info  # type: ignore
        if not multipart_mixed_info:
            return

        requests: List[HTTPRequestType] = multipart_mixed_info[0]
        policies: List[SansIOHTTPPolicy] = multipart_mixed_info[1]
        pipeline_options: Dict[str, Any] = multipart_mixed_info[3]

        # Apply on_requests concurrently to all requests
        import concurrent.futures

        def prepare_requests(req):
            if req.multipart_mixed_info:
                # Recursively update changeset "sub requests"
                Pipeline._prepare_multipart_mixed_request(req)
            context = PipelineContext(None, **pipeline_options)
            pipeline_request = PipelineRequest(req, context)
            for policy in policies:
                _await_result(policy.on_request, pipeline_request)

        with concurrent.futures.ThreadPoolExecutor() as executor:
            # executor.map is lazy about errors: draining the result iterator
            # re-raises the first exception raised by any worker.
            for _ in executor.map(prepare_requests, requests):
                pass

    def _prepare_multipart(self, request: HTTPRequestType) -> None:
        """Run the multipart/mixed sub-request policies, then build the multipart body.

        :param request: The request object.
        :type request: ~azure.core.pipeline.transport.HttpRequest
        """
        # This code is fine as long as HTTPRequestType is actually
        # azure.core.pipeline.transport.HTTPRequest, but we don't check it in here
        # since we didn't see (yet) pipeline usage where it's not this actual instance
        # class used
        self._prepare_multipart_mixed_request(request)
        request.prepare_multipart_body()  # type: ignore

    def run(self, request: HTTPRequestType, **kwargs: Any) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Runs the HTTP Request through the chained policies.

        :param request: The HTTP request object.
        :type request: ~azure.core.pipeline.transport.HttpRequest
        :keyword kwargs: Keyword arguments used to build the PipelineContext for this run.
        :return: The PipelineResponse object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        self._prepare_multipart(request)
        context = PipelineContext(self._transport, **kwargs)
        pipeline_request: PipelineRequest[HTTPRequestType] = PipelineRequest(request, context)
        # With no policies configured, the request goes straight to the transport.
        first_node = self._impl_policies[0] if self._impl_policies else _TransportRunner(self._transport)
        return first_node.send(pipeline_request)