Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/requests_mock/adapter.py: 59%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

180 statements  

1# Licensed under the Apache License, Version 2.0 (the "License"); you may 

2# not use this file except in compliance with the License. You may obtain 

3# a copy of the License at 

4# 

5# https://www.apache.org/licenses/LICENSE-2.0 

6# 

7# Unless required by applicable law or agreed to in writing, software 

8# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 

9# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 

10# License for the specific language governing permissions and limitations 

11# under the License. 

12 

13import urllib.parse 

14import weakref 

15 

16from requests.adapters import BaseAdapter 

17from requests.utils import requote_uri 

18 

19from requests_mock import exceptions 

20from requests_mock.request import _RequestObjectProxy 

21from requests_mock.response import _MatcherResponse 

22 

23import logging 

24 

# Module-level logger; Adapter.send() emits one debug line per mocked
# response through it.
logger = logging.getLogger(__name__)

# purl is an optional dependency. When it can be imported, purl.URL objects
# are accepted as URLs by _Matcher; otherwise purl_types is an empty tuple,
# and isinstance(x, ()) is always False, so that branch is simply never taken.
try:
    import purl
except ImportError:
    purl = None
    purl_types = ()
else:
    purl_types = (purl.URL,)

# Sentinel: pass as a matcher's method or URL to match every value.
ANY = object()

35 

36 

37class _RequestHistoryTracker(object): 

38 

39 def __init__(self): 

40 self.request_history = [] 

41 

42 def _add_to_history(self, request): 

43 self.request_history.append(request) 

44 

45 @property 

46 def last_request(self): 

47 """Retrieve the latest request sent""" 

48 try: 

49 return self.request_history[-1] 

50 except IndexError: 

51 return None 

52 

53 @property 

54 def called(self): 

55 return self.call_count > 0 

56 

57 @property 

58 def called_once(self): 

59 return self.call_count == 1 

60 

61 @property 

62 def call_count(self): 

63 return len(self.request_history) 

64 

65 def reset(self): 

66 self.request_history = [] 

67 

68 

69class _RunRealHTTP(Exception): 

70 """A fake exception to jump out of mocking and allow a real request. 

71 

72 This exception is caught at the mocker level and allows it to execute this 

73 request through the real requests mechanism rather than the mocker. 

74 

75 It should never be exposed to a user. 

76 """ 

77 

78 

class _Matcher(_RequestHistoryTracker):
    """Contains all the information about a provided URL to match.

    Instances are callable: ``matcher(request)`` returns a response when the
    request passes every check (method, URL, request headers, additional
    matcher) and ``None`` otherwise. Only requests that produce a mocked
    response are added to this matcher's ``request_history``.
    """

    def __init__(self, method, url, responses, complete_qs, request_headers,
                 additional_matcher, real_http, case_sensitive):
        """
        :param method: HTTP method to match (compared case-insensitively), or
            ``ANY`` to match every method.
        :param url: URL to match. May be a string, a ``purl.URL`` (when purl
            is installed), an object with a ``search`` method such as a
            compiled regex, or ``ANY``.
        :param list responses: Objects providing ``get_response(request)``.
            Each matching call pops the first one until only one remains,
            which is then reused for every later match.
        :param bool complete_qs: Match the entire query string. By default URLs
            match if all the provided matcher query arguments are matched and
            extra query arguments are ignored. Set complete_qs to true to
            require that the entire query string needs to match.
        :param dict request_headers: Headers that must be present on the
            request with exactly these values.
        :param additional_matcher: Optional callable taking the request and
            returning a truthy value to accept it; ``None`` disables the
            check.
        :param bool real_http: If true, a matching request raises
            ``_RunRealHTTP`` instead of producing a mocked response.
        :param bool case_sensitive: If false, the path and query of a string
            or purl URL are lowercased before comparison.
        """
        super(_Matcher, self).__init__()

        self._method = method
        self._url = url
        self._responses = responses
        self._complete_qs = complete_qs
        self._request_headers = request_headers
        self._real_http = real_http
        self._additional_matcher = additional_matcher

        # url can be a regex object or ANY so don't always run urlparse
        if isinstance(url, str):
            url_parts = urllib.parse.urlparse(url)
            self._scheme = url_parts.scheme.lower()
            self._netloc = url_parts.netloc.lower()
            self._path = requote_uri(url_parts.path or '/')
            self._query = url_parts.query

            if not case_sensitive:
                self._path = self._path.lower()
                self._query = self._query.lower()

        elif isinstance(url, purl_types):
            # Normalise exactly like the str branch. _match_url compares
            # scheme and netloc against lowercased request values, so they
            # must be lowercased here too. An empty path must become '/',
            # because it is compared with ``request.path or '/'`` and would
            # otherwise never match.
            self._scheme = (url.scheme() or '').lower()
            self._netloc = (url.netloc() or '').lower()
            self._path = url.path() or '/'
            self._query = url.query() or ''

            if not case_sensitive:
                self._path = self._path.lower()
                self._query = self._query.lower()

        else:
            self._scheme = None
            self._netloc = None
            self._path = None
            self._query = None

    def _match_method(self, request):
        """Return True if the request's HTTP method matches (ignoring case)."""
        if self._method is ANY:
            return True

        return request.method.lower() == self._method.lower()

    def _match_url(self, request):
        """Return True if the request URL satisfies this matcher's URL.

        ``ANY`` matches everything. Objects with a ``search`` method (e.g.
        compiled regexes) are searched against the full request URL.
        Otherwise scheme, netloc, path and query string are compared.
        """
        if self._url is ANY:
            return True

        # regular expression matching
        if hasattr(self._url, 'search'):
            return self._url.search(request.url) is not None

        # scheme is always matched case insensitive
        if self._scheme and request.scheme.lower() != self._scheme:
            return False

        # netloc is always matched case insensitive
        if self._netloc and request.netloc.lower() != self._netloc:
            return False

        if (request.path or '/') != self._path:
            return False

        # construct our own qs structure as we remove items from it below
        request_qs = urllib.parse.parse_qs(request.query,
                                           keep_blank_values=True)
        matcher_qs = urllib.parse.parse_qs(self._query, keep_blank_values=True)

        # Every value the matcher expects must be present. Each request value
        # is consumed by at most one expected value, so a key repeated in the
        # matcher must also be repeated in the request.
        for k, vals in matcher_qs.items():
            for v in vals:
                try:
                    request_qs.get(k, []).remove(v)
                except ValueError:
                    return False

        # With complete_qs, any request value left unconsumed means the
        # request carried extra query arguments.
        if self._complete_qs:
            for v in request_qs.values():
                if v:
                    return False

        return True

    def _match_headers(self, request):
        """Return True if every expected header is present with that value."""
        for k, vals in self._request_headers.items():

            try:
                header = request.headers[k]
            except KeyError:
                # NOTE(jamielennox): This seems to be a requests 1.2/2
                # difference, in 2 they are just whatever the user inputted in
                # 1 they are bytes. Let's optionally handle both and look at
                # removing this when we depend on requests 2.
                if not isinstance(k, str):
                    return False

                try:
                    header = request.headers[k.encode('utf-8')]
                except KeyError:
                    return False

            if header != vals:
                return False

        return True

    def _match_additional(self, request):
        """Apply the user-supplied additional matcher, if any.

        :raises TypeError: if the additional matcher is neither callable nor
            None.
        """
        if callable(self._additional_matcher):
            return self._additional_matcher(request)

        if self._additional_matcher is not None:
            raise TypeError("Unexpected format of additional matcher.")

        return True

    def _match(self, request):
        """Return True only if every individual check accepts the request."""
        return (self._match_method(request) and
                self._match_url(request) and
                self._match_headers(request) and
                self._match_additional(request))

    def __call__(self, request):
        """Return a response for ``request`` if it matches, else ``None``.

        :raises _RunRealHTTP: if the request matches and this matcher was
            created with ``real_http``.
        :raises TypeError: if the additional matcher is neither callable nor
            None.
        """
        if not self._match(request):
            return None

        # doing this before _add_to_history means real requests are not stored
        # in the request history. I'm not sure what is better here.
        if self._real_http:
            raise _RunRealHTTP()

        # Consume queued responses in order; the final one is reused for all
        # subsequent matching calls.
        if len(self._responses) > 1:
            response_matcher = self._responses.pop(0)
        else:
            response_matcher = self._responses[0]

        self._add_to_history(request)
        return response_matcher.get_response(request)

229 

230 

class Adapter(BaseAdapter, _RequestHistoryTracker):
    """A fake adapter that can return predefined responses.

    Register URLs with :py:meth:`register_uri`, or arbitrary callables with
    :py:meth:`add_matcher`. :py:meth:`send` then answers requests from those
    matchers. Every request sent through the adapter is recorded in
    ``request_history``, whether or not it matched.
    """
    def __init__(self, case_sensitive=False):
        super(Adapter, self).__init__()
        # Passed to every request proxy and registered matcher so both sides
        # of a URL comparison use the same case policy.
        self._case_sensitive = case_sensitive
        self._matchers = []

    def send(self, request, **kwargs):
        """Return the mocked response for ``request``.

        Matchers are tried newest-first, so a later registration for the same
        URL takes precedence over an earlier one.

        :raises requests_mock.exceptions.NoMockAddress: if no matcher returns
            a response.
        :raises Exception: anything a matcher raises is re-raised after the
            request has been tagged with that matcher.
        """
        request = _RequestObjectProxy(request,
                                      case_sensitive=self._case_sensitive,
                                      **kwargs)
        # Recorded before matching, so unmatched requests appear in history.
        self._add_to_history(request)

        for matcher in reversed(self._matchers):
            try:
                resp = matcher(request)
            except Exception:
                request._matcher = weakref.ref(matcher)
                raise

            if resp is not None:
                request._matcher = weakref.ref(matcher)
                resp.connection = self
                # Lazy %-style arguments: the message is only formatted when
                # DEBUG logging is actually enabled.
                logger.debug('%s %s %s', request._request.method,
                             request._request.url, resp.status_code)
                return resp

        raise exceptions.NoMockAddress(request)

    def close(self):
        """No-op: the adapter holds no connections to release."""
        pass

    def register_uri(self, method, url, response_list=None, **kwargs):
        """Register a new URI match and fake response.

        :param str method: The HTTP method to match (or ``ANY``).
        :param str url: The URL to match.
        :param list response_list: Optional list of dicts, each holding the
            keyword arguments for one queued response. Mutually exclusive
            with passing response keyword arguments directly.
        :param kwargs: Matcher options ``complete_qs``, ``additional_matcher``,
            ``request_headers``, ``_real_http`` and ``json_encoder``. Any
            remaining keyword arguments describe a single response.
            ``json_encoder`` is only applied to that keyword-built response;
            it is ignored when ``response_list`` is given.
        :returns: The newly registered matcher.
        :raises RuntimeError: if both ``response_list`` and response keyword
            arguments are given, or if real HTTP is requested together with
            response data.
        """
        complete_qs = kwargs.pop('complete_qs', False)
        additional_matcher = kwargs.pop('additional_matcher', None)
        request_headers = kwargs.pop('request_headers', {})
        real_http = kwargs.pop('_real_http', False)
        json_encoder = kwargs.pop('json_encoder', None)

        if response_list and kwargs:
            raise RuntimeError('You should specify either a list of '
                               'responses OR response kwargs. Not both.')
        elif real_http and (response_list or kwargs):
            raise RuntimeError('You should specify either response data '
                               'OR real_http. Not both.')
        elif not response_list:
            if json_encoder is not None:
                kwargs['json_encoder'] = json_encoder
            response_list = [] if real_http else [kwargs]

        # NOTE(jamielennox): case_sensitive is not present as a kwarg because I
        # think there would be an edge case where the adapter and register_uri
        # had different values.
        # Ideally case_sensitive would be a value passed to match() however
        # this would change the contract of matchers so we pass it to the
        # proxy and the matcher separately.
        responses = [_MatcherResponse(**k) for k in response_list]
        matcher = _Matcher(method,
                           url,
                           responses,
                           case_sensitive=self._case_sensitive,
                           complete_qs=complete_qs,
                           additional_matcher=additional_matcher,
                           request_headers=request_headers,
                           real_http=real_http)
        self.add_matcher(matcher)
        return matcher

    def add_matcher(self, matcher):
        """Register a custom matcher.

        A matcher is a callable that takes the request (as wrapped by
        :py:meth:`send`) and returns a `requests.Response` if it matches or
        None if not. Matchers added later are consulted first.

        :param callable matcher: The matcher to execute.
        """
        self._matchers.append(matcher)

    def reset(self):
        """Clear the request history of the adapter and of its matchers.

        Custom matchers added via :py:meth:`add_matcher` may be plain
        callables with no ``reset`` method. Those are skipped instead of
        raising AttributeError.
        """
        super(Adapter, self).reset()
        for matcher in self._matchers:
            matcher_reset = getattr(matcher, 'reset', None)
            if callable(matcher_reset):
                matcher_reset()


__all__ = ['Adapter']