1# Licensed under the Apache License, Version 2.0 (the "License"); you may
2# not use this file except in compliance with the License. You may obtain
3# a copy of the License at
4#
5# https://www.apache.org/licenses/LICENSE-2.0
6#
7# Unless required by applicable law or agreed to in writing, software
8# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10# License for the specific language governing permissions and limitations
11# under the License.
12
13import urllib.parse
14import weakref
15
16from requests.adapters import BaseAdapter
17from requests.utils import requote_uri
18
19from requests_mock import exceptions
20from requests_mock.request import _RequestObjectProxy
21from requests_mock.response import _MatcherResponse
22
23import logging
24
# Module logger; used for a debug line each time a mocked response is returned.
logger = logging.getLogger(__name__)

# purl is an optional dependency. When it is installed, purl.URL objects are
# accepted as matcher URLs. Otherwise purl_types is an empty tuple, and
# isinstance(x, ()) is always False, so that branch is simply never taken.
try:
    import purl
    purl_types = (purl.URL,)
except ImportError:
    purl = None
    purl_types = ()

# Sentinel that may be passed as the method and/or URL to match anything;
# compared by identity in the matcher.
ANY = object()
35
36
37class _RequestHistoryTracker(object):
38
39 def __init__(self):
40 self.request_history = []
41
42 def _add_to_history(self, request):
43 self.request_history.append(request)
44
45 @property
46 def last_request(self):
47 """Retrieve the latest request sent"""
48 try:
49 return self.request_history[-1]
50 except IndexError:
51 return None
52
53 @property
54 def called(self):
55 return self.call_count > 0
56
57 @property
58 def called_once(self):
59 return self.call_count == 1
60
61 @property
62 def call_count(self):
63 return len(self.request_history)
64
65 def reset(self):
66 self.request_history = []
67
68
69class _RunRealHTTP(Exception):
70 """A fake exception to jump out of mocking and allow a real request.
71
72 This exception is caught at the mocker level and allows it to execute this
73 request through the real requests mechanism rather than the mocker.
74
75 It should never be exposed to a user.
76 """
77
78
class _Matcher(_RequestHistoryTracker):
    """Contains all the information about a provided URL to match.

    Calling a matcher with a request returns a response when the request
    matches on method, URL, headers and the optional additional matcher, and
    ``None`` otherwise. Requests it answers are recorded in its own history.
    """

    def __init__(self, method, url, responses, complete_qs, request_headers,
                 additional_matcher, real_http, case_sensitive):
        """
        :param method: HTTP method to match (compared case-insensitively),
            or ``ANY``.
        :param url: The URL to match. This may be a string, a ``purl.URL``
            (when purl is installed), a compiled regular expression (anything
            with a ``search`` method), or ``ANY``.
        :param list responses: Response matchers. Each match consumes the
            first one until only one remains; that last one is then reused
            for every further match.
        :param bool complete_qs: Match the entire query string. By default URLs
            match if all the provided matcher query arguments are matched and
            extra query arguments are ignored. Set complete_qs to true to
            require that the entire query string needs to match.
        :param dict request_headers: Headers that must be present on the
            request with exactly these values. ``None`` means no constraint.
        :param additional_matcher: Optional callable that receives the request
            and returns a truthy value if it matches.
        :param bool real_http: If true, a match raises ``_RunRealHTTP`` so the
            request is executed for real instead of being mocked.
        :param bool case_sensitive: If false, the path and query of the
            matcher URL are lowercased before comparison.
        """
        super(_Matcher, self).__init__()

        self._method = method
        self._url = url
        self._responses = responses
        self._complete_qs = complete_qs
        self._request_headers = request_headers
        self._real_http = real_http
        self._additional_matcher = additional_matcher

        # Component attributes stay None for regex and ANY URLs; _match_url
        # handles those before it looks at the components.
        self._scheme = None
        self._netloc = None
        self._path = None
        self._query = None

        # url can be a regex object or ANY so don't always run urlparse
        if isinstance(url, str):
            url_parts = urllib.parse.urlparse(url)
            self._store_url_parts(url_parts.scheme,
                                  url_parts.netloc,
                                  requote_uri(url_parts.path or '/'),
                                  url_parts.query,
                                  case_sensitive)
        elif isinstance(url, purl_types):
            # Normalized exactly like string URLs. Otherwise a mixed-case
            # host or an empty path could never match in _match_url.
            self._store_url_parts(url.scheme(),
                                  url.netloc(),
                                  url.path(),
                                  url.query(),
                                  case_sensitive)

    def _store_url_parts(self, scheme, netloc, path, query, case_sensitive):
        """Normalize and store the URL components compared by _match_url.

        Scheme and netloc are always lowercased because _match_url compares
        them against the lowercased request values. Path and query are
        lowercased only for case-insensitive matching. An empty path becomes
        ``/`` to agree with how _match_url normalizes the request path.
        """
        self._scheme = (scheme or '').lower()
        self._netloc = (netloc or '').lower()
        self._path = path or '/'
        self._query = query or ''

        if not case_sensitive:
            self._path = self._path.lower()
            self._query = self._query.lower()

    def _match_method(self, request):
        """Return True if the request method matches (case-insensitive)."""
        if self._method is ANY:
            return True

        return request.method.lower() == self._method.lower()

    def _match_url(self, request):
        """Return True if the request URL satisfies this matcher's URL."""
        if self._url is ANY:
            return True

        # regular expression matching
        if hasattr(self._url, 'search'):
            return self._url.search(request.url) is not None

        # scheme is always matched case insensitive
        if self._scheme and request.scheme.lower() != self._scheme:
            return False

        # netloc is always matched case insensitive
        if self._netloc and request.netloc.lower() != self._netloc:
            return False

        if (request.path or '/') != self._path:
            return False

        # construct our own qs structure as we remove items from it below
        request_qs = urllib.parse.parse_qs(request.query,
                                           keep_blank_values=True)
        matcher_qs = urllib.parse.parse_qs(self._query, keep_blank_values=True)

        # Every matcher value must be present in the request. Each request
        # value is consumed at most once, so repeated keys are counted.
        for k, vals in matcher_qs.items():
            for v in vals:
                try:
                    request_qs.get(k, []).remove(v)
                except ValueError:
                    return False

        # With complete_qs, any request value left unconsumed is a mismatch.
        if self._complete_qs:
            for v in request_qs.values():
                if v:
                    return False

        return True

    def _match_headers(self, request):
        """Return True if every required header has exactly the given value."""
        for k, vals in (self._request_headers or {}).items():

            try:
                header = request.headers[k]
            except KeyError:
                # NOTE(jamielennox): This seems to be a requests 1.2/2
                # difference, in 2 they are just whatever the user inputted in
                # 1 they are bytes. Let's optionally handle both and look at
                # removing this when we depend on requests 2.
                if not isinstance(k, str):
                    return False

                try:
                    header = request.headers[k.encode('utf-8')]
                except KeyError:
                    return False

            if header != vals:
                return False

        return True

    def _match_additional(self, request):
        """Apply the user-supplied additional matcher, if any.

        :raises TypeError: if additional_matcher is neither None nor callable.
        """
        if callable(self._additional_matcher):
            return self._additional_matcher(request)

        if self._additional_matcher is not None:
            raise TypeError("Unexpected format of additional matcher.")

        return True

    def _match(self, request):
        """Return a truthy value if all individual checks pass."""
        return (self._match_method(request) and
                self._match_url(request) and
                self._match_headers(request) and
                self._match_additional(request))

    def __call__(self, request):
        """Return a response for ``request``, or None if it does not match.

        :raises _RunRealHTTP: if the request matches and real_http was set.
        """
        if not self._match(request):
            return None

        # doing this before _add_to_history means real requests are not stored
        # in the request history. I'm not sure what is better here.
        if self._real_http:
            raise _RunRealHTTP()

        # Consume queued responses in order, but keep the last one forever.
        if len(self._responses) > 1:
            response_matcher = self._responses.pop(0)
        else:
            response_matcher = self._responses[0]

        self._add_to_history(request)
        return response_matcher.get_response(request)
229
230
class Adapter(BaseAdapter, _RequestHistoryTracker):
    """A fake adapter that can return predefined responses.

    Matchers are tried newest-first. Every request sent through the adapter is
    recorded in the adapter's own history, whether or not a matcher answers
    it.
    """

    def __init__(self, case_sensitive=False):
        super(Adapter, self).__init__()
        self._case_sensitive = case_sensitive
        self._matchers = []

    def send(self, request, **kwargs):
        """Return the response of the most recently added matching matcher.

        :param request: The prepared request being sent.
        :param kwargs: Transport keyword arguments, forwarded to the request
            proxy.
        :raises requests_mock.exceptions.NoMockAddress: if no matcher
            returns a response for the request.
        """
        request = _RequestObjectProxy(request,
                                      case_sensitive=self._case_sensitive,
                                      **kwargs)
        self._add_to_history(request)

        for matcher in reversed(self._matchers):
            try:
                resp = matcher(request)
            except Exception:
                # Record which matcher raised before propagating. A weak
                # reference is used so the request does not keep the
                # matcher alive.
                request._matcher = weakref.ref(matcher)
                raise

            if resp is not None:
                request._matcher = weakref.ref(matcher)
                resp.connection = self
                # Lazy %-style args: the message is only built when debug
                # logging is enabled.
                logger.debug('%s %s %s',
                             request._request.method,
                             request._request.url,
                             resp.status_code)
                return resp

        raise exceptions.NoMockAddress(request)

    def close(self):
        """Nothing to release; required by the BaseAdapter interface."""
        pass

    def register_uri(self, method, url, response_list=None, **kwargs):
        """Register a new URI match and fake response.

        :param str method: The HTTP method to match.
        :param str url: The URL to match.
        :param list response_list: Optional list of dicts, each holding the
            keyword arguments for one response, returned in order.
        :param kwargs: Matching options ``complete_qs``,
            ``additional_matcher``, ``request_headers``, ``_real_http`` and
            ``json_encoder``. Any remaining keywords describe a single
            response and may not be combined with ``response_list``.
            NOTE: ``json_encoder`` is only applied when the response is given
            via kwargs; it is ignored when ``response_list`` is used.
        :returns: The created matcher, which is also registered.
        :raises RuntimeError: if both response_list and response kwargs are
            given, or response data is combined with ``_real_http``.
        """
        complete_qs = kwargs.pop('complete_qs', False)
        additional_matcher = kwargs.pop('additional_matcher', None)
        request_headers = kwargs.pop('request_headers', {})
        real_http = kwargs.pop('_real_http', False)
        json_encoder = kwargs.pop('json_encoder', None)

        if response_list and kwargs:
            raise RuntimeError('You should specify either a list of '
                               'responses OR response kwargs. Not both.')
        elif real_http and (response_list or kwargs):
            raise RuntimeError('You should specify either response data '
                               'OR real_http. Not both.')
        elif not response_list:
            if json_encoder is not None:
                kwargs['json_encoder'] = json_encoder
            response_list = [] if real_http else [kwargs]

        # NOTE(jamielennox): case_sensitive is not present as a kwarg because i
        # think there would be an edge case where the adapter and register_uri
        # had different values.
        # Ideally case_sensitive would be a value passed to match() however
        # this would change the contract of matchers so we pass it to the
        # proxy and the matcher separately.
        responses = [_MatcherResponse(**k) for k in response_list]
        matcher = _Matcher(method,
                           url,
                           responses,
                           case_sensitive=self._case_sensitive,
                           complete_qs=complete_qs,
                           additional_matcher=additional_matcher,
                           request_headers=request_headers,
                           real_http=real_http)
        self.add_matcher(matcher)
        return matcher

    def add_matcher(self, matcher):
        """Register a custom matcher.

        A matcher is a callable that takes a `requests.Request` and returns a
        `requests.Response` if it matches or None if not. Matchers added
        later take precedence over earlier ones. Because ``reset()`` calls
        ``matcher.reset()``, a custom matcher should also provide ``reset``.

        :param callable matcher: The matcher to execute.
        """
        self._matchers.append(matcher)

    def reset(self):
        """Clear this adapter's history and that of every registered matcher."""
        super(Adapter, self).reset()
        for matcher in self._matchers:
            matcher.reset()
321
322
# Only Adapter is exported by star-imports; the other names defined here are
# implementation details of the mocking machinery.
__all__ = ['Adapter']