1import logging
2
3from oauthlib.common import generate_token, urldecode
4from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError
5from oauthlib.oauth2 import LegacyApplicationClient
6from oauthlib.oauth2 import TokenExpiredError, is_secure_transport
7import requests
8
# Module-level logger, named after this module, used for all debug output below.
log = logging.getLogger(name=__name__)
10
11
class TokenUpdated(Warning):
    """Warning raised after an automatic token refresh when no updater is set.

    The refreshed token is exposed through the ``token`` attribute. No
    arguments are forwarded to :class:`Warning`, so ``str()`` of the
    exception is empty and ``args`` is ``()``.
    """

    def __init__(self, token):
        self.token = token
        super(TokenUpdated, self).__init__()
16
17
class OAuth2Session(requests.Session):
    """Versatile OAuth 2 extension to :class:`requests.Session`.

    Supports any grant type adhering to :class:`oauthlib.oauth2.Client` spec
    including the four core OAuth 2 grants.

    Can be used to create authorization urls, fetch tokens and access protected
    resources using the :class:`requests.Session` interface you are used to.

    - :class:`oauthlib.oauth2.WebApplicationClient` (default): Authorization Code Grant
    - :class:`oauthlib.oauth2.MobileApplicationClient`: Implicit Grant
    - :class:`oauthlib.oauth2.LegacyApplicationClient`: Password Credentials Grant
    - :class:`oauthlib.oauth2.BackendApplicationClient`: Client Credentials Grant

    Note that the only time you will be using Implicit Grant from python is if
    you are driving a user agent able to obtain URL fragments.
    """

    def __init__(
        self,
        client_id=None,
        client=None,
        auto_refresh_url=None,
        auto_refresh_kwargs=None,
        scope=None,
        redirect_uri=None,
        token=None,
        state=None,
        token_updater=None,
        pkce=None,
        **kwargs
    ):
        """Construct a new OAuth 2 client session.

        :param client_id: Client id obtained during registration.
        :param client: :class:`oauthlib.oauth2.Client` to be used. Default is
                       WebApplicationClient which is useful for any
                       hosted application but not mobile or desktop.
        :param auto_refresh_url: Refresh token endpoint URL, must be HTTPS.
                                 Supply this if you wish the client to
                                 automatically refresh your access tokens.
        :param auto_refresh_kwargs: Extra arguments to pass to the refresh
                                    token endpoint.
        :param scope: List of scopes you wish to request access to.
        :param redirect_uri: Redirect URI you registered as callback.
        :param token: Token dictionary, must include access_token
                      and token_type.
        :param state: State string used to prevent CSRF. This will be given
                      when creating the authorization url and must be supplied
                      when parsing the authorization response.
                      Can be either a string or a no argument callable.
        :param token_updater: Callable taking one argument, the token, used to
                              update your token storage on automatic token
                              refresh. If not set, a :class:`TokenUpdated`
                              warning is raised when a token has been
                              refreshed; it carries the new token in its
                              ``token`` attribute.
        :param pkce: Set "S256" or "plain" to enable PKCE. Default is disabled.
        :param kwargs: Arguments to pass to the Session constructor.
        :raises AttributeError: if ``pkce`` is not "S256", "plain" or None.
        """
        super(OAuth2Session, self).__init__(**kwargs)
        self._client = client or WebApplicationClient(client_id, token=token)
        self.token = token or {}
        self._scope = scope
        self.redirect_uri = redirect_uri
        # Either the user-supplied state (string or callable) or the default
        # random generator; new_state() resolves it into self._state.
        self.state = state or generate_token
        self._state = state
        self.auto_refresh_url = auto_refresh_url
        self.auto_refresh_kwargs = auto_refresh_kwargs or {}
        self.token_updater = token_updater
        self._pkce = pkce
        # Populated by authorization_url() when PKCE is enabled. Initialised
        # here so fetch_token() reports a missing verifier with its intended
        # ValueError instead of an AttributeError.
        self._code_verifier = None

        if self._pkce not in ["S256", "plain", None]:
            raise AttributeError("Wrong value for {}(.., pkce={})".format(self.__class__, self._pkce))

        # Ensure that requests doesn't do any automatic auth. See #278.
        # The default behavior can be re-enabled by setting auth to None.
        self.auth = lambda r: r

        # Allow customizations for non compliant providers through various
        # hooks to adjust requests and responses.
        self.compliance_hook = {
            "access_token_response": set(),
            "refresh_token_response": set(),
            "protected_request": set(),
            "refresh_token_request": set(),
            "access_token_request": set(),
        }

    @property
    def scope(self):
        """By default the scope from the client is used, except if overridden"""
        if self._scope is not None:
            return self._scope
        elif self._client is not None:
            return self._client.scope
        else:
            return None

    @scope.setter
    def scope(self, scope):
        self._scope = scope

    def new_state(self):
        """Generate (or re-use) the state string used in authorizations.

        If ``self.state`` is callable it is called to produce a fresh value;
        otherwise the supplied value is re-used. The result is stored in
        ``self._state`` and returned.
        """
        # An explicit callable() check (rather than catching TypeError) keeps
        # errors raised inside a state generator from being silently masked.
        if callable(self.state):
            self._state = self.state()
            log.debug("Generated new state %s.", self._state)
        else:
            self._state = self.state
            log.debug("Re-using previously supplied state %s.", self._state)
        return self._state

    @property
    def client_id(self):
        return getattr(self._client, "client_id", None)

    @client_id.setter
    def client_id(self, value):
        self._client.client_id = value

    @client_id.deleter
    def client_id(self):
        del self._client.client_id

    @property
    def token(self):
        return getattr(self._client, "token", None)

    @token.setter
    def token(self, value):
        self._client.token = value
        self._client.populate_token_attributes(value)

    @property
    def access_token(self):
        return getattr(self._client, "access_token", None)

    @access_token.setter
    def access_token(self, value):
        self._client.access_token = value

    @access_token.deleter
    def access_token(self):
        del self._client.access_token

    @property
    def authorized(self):
        """Boolean that indicates whether this session has an OAuth token
        or not. If `self.authorized` is True, you can reasonably expect
        OAuth-protected requests to the resource to succeed. If
        `self.authorized` is False, you need the user to go through the OAuth
        authentication dance before OAuth-protected requests to the resource
        will succeed.
        """
        return bool(self.access_token)

    def authorization_url(self, url, state=None, **kwargs):
        """Form an authorization URL.

        When PKCE is enabled, a new code verifier is created and stored on the
        session (for use by :meth:`fetch_token`), and ``code_challenge`` /
        ``code_challenge_method`` are added to the URL parameters.

        :param url: Authorization endpoint url, must be HTTPS.
        :param state: An optional state string for CSRF protection. If not
                      given it will be generated for you.
        :param kwargs: Extra parameters to include.
        :return: authorization_url, state
        """
        state = state or self.new_state()
        if self._pkce:
            self._code_verifier = self._client.create_code_verifier(43)
            kwargs["code_challenge_method"] = self._pkce
            kwargs["code_challenge"] = self._client.create_code_challenge(
                code_verifier=self._code_verifier,
                code_challenge_method=self._pkce
            )
        return (
            self._client.prepare_request_uri(
                url,
                redirect_uri=self.redirect_uri,
                scope=self.scope,
                state=state,
                **kwargs
            ),
            state,
        )

    def fetch_token(
        self,
        token_url,
        code=None,
        authorization_response=None,
        body="",
        auth=None,
        username=None,
        password=None,
        method="POST",
        force_querystring=False,
        timeout=None,
        headers=None,
        verify=None,
        proxies=None,
        include_client_id=None,
        client_secret=None,
        cert=None,
        **kwargs
    ):
        """Generic method for fetching an access token from the token endpoint.

        If you are using the MobileApplicationClient you will want to use
        `token_from_fragment` instead of `fetch_token`.

        The current implementation enforces the RFC guidelines.

        :param token_url: Token endpoint URL, must use HTTPS.
        :param code: Authorization code (used by WebApplicationClients).
        :param authorization_response: Authorization response URL, the callback
                                       URL of the request back to you. Used by
                                       WebApplicationClients instead of code.
        :param body: Optional application/x-www-form-urlencoded body to
                     include in the token request. Prefer kwargs over body.
        :param auth: An auth tuple or method as accepted by `requests`.
        :param username: Username required by LegacyApplicationClients to appear
                         in the request body.
        :param password: Password required by LegacyApplicationClients to appear
                         in the request body.
        :param method: The HTTP method used to make the request. Defaults
                       to POST, but may also be GET. Other methods should
                       be added as needed.
        :param force_querystring: If True, force the request body to be sent
            in the querystring instead.
        :param timeout: Timeout of the request in seconds.
        :param headers: Dict of request headers. If empty or None, defaults to
                        JSON ``Accept`` and form-urlencoded ``Content-Type``.
        :param verify: Verify SSL certificate.
        :param proxies: The `proxies` argument is passed onto `requests`.
        :param include_client_id: Should the request body include the
                                  `client_id` parameter. Default is `None`,
                                  which will attempt to autodetect. This can be
                                  forced to always include (True) or never
                                  include (False).
        :param client_secret: The `client_secret` paired to the `client_id`.
                              This is generally required unless provided in the
                              `auth` tuple. If the value is `None`, it will be
                              omitted from the request, however if the value is
                              an empty string, an empty string will be sent.
        :param cert: Client certificate to send for OAuth 2.0 Mutual-TLS Client
                     Authentication (draft-ietf-oauth-mtls). Can either be the
                     path of a file containing the private key and certificate or
                     a tuple of two filenames for certificate and key.
        :param kwargs: Extra parameters to include in the token request.
        :return: A token dict
        :raises InsecureTransportError: if ``token_url`` is not secure.
        :raises ValueError: if a WebApplicationClient has no code, PKCE is
                            enabled but no verifier was generated, a
                            LegacyApplicationClient lacks username/password,
                            or ``method`` is neither POST nor GET.
        """
        if not is_secure_transport(token_url):
            raise InsecureTransportError()

        if not code and authorization_response:
            self._client.parse_request_uri_response(
                authorization_response, state=self._state
            )
            code = self._client.code
        elif not code and isinstance(self._client, WebApplicationClient):
            code = self._client.code
            if not code:
                raise ValueError(
                    "Please supply either code or " "authorization_response parameters."
                )

        if self._pkce:
            if self._code_verifier is None:
                raise ValueError(
                    "Code verifier is not found, authorization URL must be generated before"
                )
            kwargs["code_verifier"] = self._code_verifier

        # Earlier versions of this library build an HTTPBasicAuth header out of
        # `username` and `password`. The RFC states, however these attributes
        # must be in the request body and not the header.
        # If an upstream server is not spec compliant and requires them to
        # appear as an Authorization header, supply an explicit `auth` header
        # to this function.
        # This check will allow for empty strings, but not `None`.
        #
        # References
        # 4.3.2 - Resource Owner Password Credentials Grant
        #         https://tools.ietf.org/html/rfc6749#section-4.3.2

        if isinstance(self._client, LegacyApplicationClient):
            if username is None:
                raise ValueError(
                    "`LegacyApplicationClient` requires both the "
                    "`username` and `password` parameters."
                )
            if password is None:
                raise ValueError(
                    "The required parameter `username` was supplied, "
                    "but `password` was not."
                )

        # merge username and password into kwargs for `prepare_request_body`
        if username is not None:
            kwargs["username"] = username
        if password is not None:
            kwargs["password"] = password

        # is an auth explicitly supplied?
        if auth is not None:
            # if we're dealing with the default of `include_client_id` (None):
            # we will assume the `auth` argument is for an RFC compliant server
            # and we should not send the `client_id` in the body.
            # This approach allows us to still force the client_id by submitting
            # `include_client_id=True` along with an `auth` object.
            if include_client_id is None:
                include_client_id = False

        # otherwise we may need to create an auth header
        else:
            # since we don't have an auth header, we MAY need to create one
            # it is possible that we want to send the `client_id` in the body
            # if so, `include_client_id` should be set to True
            # otherwise, we will generate an auth header
            if include_client_id is not True:
                client_id = self.client_id
                if client_id:
                    log.debug(
                        'Encoding `client_id` "%s" with `client_secret` '
                        "as Basic auth credentials.",
                        client_id,
                    )
                    client_secret = client_secret if client_secret is not None else ""
                    auth = requests.auth.HTTPBasicAuth(client_id, client_secret)

        if include_client_id:
            # this was pulled out of the params
            # it needs to be passed into prepare_request_body
            if client_secret is not None:
                kwargs["client_secret"] = client_secret

        body = self._client.prepare_request_body(
            code=code,
            body=body,
            redirect_uri=self.redirect_uri,
            include_client_id=include_client_id,
            **kwargs
        )

        headers = headers or {
            "Accept": "application/json",
            "Content-Type": "application/x-www-form-urlencoded",
        }
        # Clear any existing token so request() does not attach it to the
        # token endpoint call.
        self.token = {}
        request_kwargs = {}
        if method.upper() == "POST":
            request_kwargs["params" if force_querystring else "data"] = dict(
                urldecode(body)
            )
        elif method.upper() == "GET":
            request_kwargs["params"] = dict(urldecode(body))
        else:
            raise ValueError("The method kwarg must be POST or GET.")

        for hook in self.compliance_hook["access_token_request"]:
            log.debug("Invoking access_token_request hook %s.", hook)
            token_url, headers, request_kwargs = hook(
                token_url, headers, request_kwargs
            )

        r = self.request(
            method=method,
            url=token_url,
            timeout=timeout,
            headers=headers,
            auth=auth,
            verify=verify,
            proxies=proxies,
            cert=cert,
            **request_kwargs
        )

        log.debug("Request to fetch token completed with status %s.", r.status_code)
        log.debug("Request url was %s", r.request.url)
        log.debug("Request headers were %s", r.request.headers)
        log.debug("Request body was %s", r.request.body)
        log.debug("Response headers were %s and content %s.", r.headers, r.text)
        log.debug(
            "Invoking %d token response hooks.",
            len(self.compliance_hook["access_token_response"]),
        )
        for hook in self.compliance_hook["access_token_response"]:
            log.debug("Invoking hook %s.", hook)
            r = hook(r)

        self._client.parse_request_body_response(r.text, scope=self.scope)
        self.token = self._client.token
        log.debug("Obtained token %s.", self.token)
        return self.token

    def token_from_fragment(self, authorization_response):
        """Parse token from the URI fragment, used by MobileApplicationClients.

        :param authorization_response: The full URL of the redirect back to you
        :return: A token dict
        """
        self._client.parse_request_uri_response(
            authorization_response, state=self._state
        )
        self.token = self._client.token
        return self.token

    def refresh_token(
        self,
        token_url,
        refresh_token=None,
        body="",
        auth=None,
        timeout=None,
        headers=None,
        verify=None,
        proxies=None,
        cert=None,
        **kwargs
    ):
        """Fetch a new access token using a refresh token.

        :param token_url: The token endpoint, must be HTTPS.
        :param refresh_token: The refresh_token to use. Defaults to the one
                              stored in the current token.
        :param body: Optional application/x-www-form-urlencoded body to
                     include in the token request. Prefer kwargs over body.
        :param auth: An auth tuple or method as accepted by `requests`.
        :param timeout: Timeout of the request in seconds.
        :param headers: A dict of headers to be used by `requests`.
        :param verify: Verify SSL certificate.
        :param proxies: The `proxies` argument will be passed to `requests`.
        :param cert: Client certificate passed to `requests` (e.g. for
                     Mutual-TLS client authentication), as in `fetch_token`.
        :param kwargs: Extra parameters to include in the token request.
                       ``self.auto_refresh_kwargs`` is merged on top of these.
        :return: A token dict
        :raises ValueError: if ``token_url`` is empty.
        :raises InsecureTransportError: if ``token_url`` is not secure.
        """
        if not token_url:
            raise ValueError("No token endpoint set for auto_refresh.")

        if not is_secure_transport(token_url):
            raise InsecureTransportError()

        refresh_token = refresh_token or self.token.get("refresh_token")

        log.debug(
            "Adding auto refresh key word arguments %s.", self.auto_refresh_kwargs
        )
        kwargs.update(self.auto_refresh_kwargs)
        body = self._client.prepare_refresh_body(
            body=body, refresh_token=refresh_token, scope=self.scope, **kwargs
        )
        log.debug("Prepared refresh token request body %s", body)

        if headers is None:
            headers = {
                "Accept": "application/json",
                "Content-Type": ("application/x-www-form-urlencoded"),
            }

        for hook in self.compliance_hook["refresh_token_request"]:
            log.debug("Invoking refresh_token_request hook %s.", hook)
            token_url, headers, body = hook(token_url, headers, body)

        r = self.post(
            token_url,
            data=dict(urldecode(body)),
            auth=auth,
            timeout=timeout,
            headers=headers,
            verify=verify,
            withhold_token=True,
            proxies=proxies,
            cert=cert,
        )
        log.debug("Request to refresh token completed with status %s.", r.status_code)
        log.debug("Response headers were %s and content %s.", r.headers, r.text)
        log.debug(
            "Invoking %d token response hooks.",
            len(self.compliance_hook["refresh_token_response"]),
        )
        for hook in self.compliance_hook["refresh_token_response"]:
            log.debug("Invoking hook %s.", hook)
            r = hook(r)

        self.token = self._client.parse_request_body_response(r.text, scope=self.scope)
        if "refresh_token" not in self.token:
            log.debug("No new refresh token given. Re-using old.")
            self.token["refresh_token"] = refresh_token
        return self.token

    def request(
        self,
        method,
        url,
        data=None,
        headers=None,
        withhold_token=False,
        client_id=None,
        client_secret=None,
        files=None,
        **kwargs
    ):
        """Intercept all requests and add the OAuth 2 token if present.

        When a token is present and ``withhold_token`` is False, the
        ``protected_request`` hooks run and the token is added via the client.
        If the client reports the token as expired and ``auto_refresh_url`` is
        set, the token is refreshed first; the new token is passed to
        ``token_updater`` or, if none is set, raised inside :class:`TokenUpdated`.

        :param withhold_token: If True, do not attach the token.
        :param client_id: Used with ``client_secret`` to build Basic auth for
                          an automatic refresh when no ``auth`` kwarg is given.
        :param client_secret: See ``client_id``.
        :raises InsecureTransportError: if ``url`` is not secure.
        :raises TokenExpiredError: if the token expired and auto refresh is off.
        """
        if not is_secure_transport(url):
            raise InsecureTransportError()
        if self.token and not withhold_token:
            log.debug(
                "Invoking %d protected resource request hooks.",
                len(self.compliance_hook["protected_request"]),
            )
            for hook in self.compliance_hook["protected_request"]:
                log.debug("Invoking hook %s.", hook)
                url, headers, data = hook(url, headers, data)

            log.debug("Adding token %s to request.", self.token)
            try:
                url, headers, data = self._client.add_token(
                    url, http_method=method, body=data, headers=headers
                )
            # Attempt to retrieve and save new access token if expired
            except TokenExpiredError:
                if self.auto_refresh_url:
                    log.debug(
                        "Auto refresh is set, attempting to refresh at %s.",
                        self.auto_refresh_url,
                    )

                    # We mustn't pass auth twice.
                    auth = kwargs.pop("auth", None)
                    if client_id and client_secret and (auth is None):
                        log.debug(
                            'Encoding client_id "%s" with client_secret as Basic auth credentials.',
                            client_id,
                        )
                        auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
                    # Forward only transport options to the refresh call. Any
                    # other request kwargs (params, json, stream,
                    # allow_redirects, ...) would reach refresh_token's
                    # **kwargs and be passed to prepare_refresh_body as extra
                    # token-request parameters.
                    refresh_kwargs = {
                        key: kwargs[key]
                        for key in ("timeout", "verify", "proxies", "cert")
                        if key in kwargs
                    }
                    token = self.refresh_token(
                        self.auto_refresh_url, auth=auth, **refresh_kwargs
                    )
                    if self.token_updater:
                        log.debug(
                            "Updating token to %s using %s.", token, self.token_updater
                        )
                        self.token_updater(token)
                        url, headers, data = self._client.add_token(
                            url, http_method=method, body=data, headers=headers
                        )
                    else:
                        raise TokenUpdated(token)
                else:
                    raise

        log.debug("Requesting url %s using method %s.", url, method)
        log.debug("Supplying headers %s and data %s", headers, data)
        log.debug("Passing through key word arguments %s.", kwargs)
        return super(OAuth2Session, self).request(
            method, url, headers=headers, data=data, files=files, **kwargs
        )

    def register_compliance_hook(self, hook_type, hook):
        """Register a hook for request/response tweaking.

        Available hooks are:
            access_token_response invoked before token parsing.
            refresh_token_response invoked before refresh token parsing.
            protected_request invoked before making a request.
            access_token_request invoked before making a token fetch request.
            refresh_token_request invoked before making a refresh request.

        If you find a new hook is needed please send a GitHub PR request
        or open an issue.

        :raises ValueError: if ``hook_type`` is not one of the above.
        """
        if hook_type not in self.compliance_hook:
            # Format the message eagerly; ValueError does not apply %-style
            # formatting to extra positional arguments.
            raise ValueError(
                "Hook type %s is not in %s." % (hook_type, self.compliance_hook)
            )
        self.compliance_hook[hook_type].add(hook)