Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/botocore/tokens.py: 28%
182 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:51 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:51 +0000
# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
13import json
14import logging
15import os
16import threading
17from datetime import datetime, timedelta
18from typing import NamedTuple, Optional
20import dateutil.parser
21from dateutil.tz import tzutc
23from botocore import UNSIGNED
24from botocore.compat import total_seconds
25from botocore.config import Config
26from botocore.exceptions import (
27 ClientError,
28 InvalidConfigError,
29 TokenRetrievalError,
30)
31from botocore.utils import CachedProperty, JSONFileCache, SSOTokenLoader
33logger = logging.getLogger(__name__)
def _utc_now():
    """Return the current time as a timezone-aware datetime using tzutc."""
    utc = tzutc()
    return datetime.now(tz=utc)
def create_token_resolver(session):
    """Build the token provider chain for ``session``.

    The chain contains a single ``SSOTokenProvider`` bound to the session.

    :param session: The session handed to ``SSOTokenProvider``.
    :return: A ``TokenProviderChain`` wrapping the provider.
    """
    sso_provider = SSOTokenProvider(session)
    return TokenProviderChain(providers=[sso_provider])
47def _serialize_utc_timestamp(obj):
48 if isinstance(obj, datetime):
49 return obj.strftime("%Y-%m-%dT%H:%M:%SZ")
50 return obj
def _sso_json_dumps(obj):
    """Dump ``obj`` to a JSON string.

    Values the JSON encoder cannot handle natively are passed through
    ``_serialize_utc_timestamp``.

    :param obj: The object to serialize.
    :return: The JSON document as a string.
    """
    serialized = json.dumps(obj, default=_serialize_utc_timestamp)
    return serialized
class FrozenAuthToken(NamedTuple):
    """Immutable pairing of a token value with its optional expiration."""

    # The token value.
    token: str
    # When the token expires; defaults to ``None``.
    expiration: Optional[datetime] = None
class DeferredRefreshableToken:
    """Token holder that fetches lazily and refreshes as expiry approaches.

    The token is produced by calling ``refresh_using`` and cached on the
    instance. ``get_frozen_token`` triggers a refresh check on every call.
    """

    # The time at which we'll attempt to refresh, but not block if someone else
    # is refreshing.
    _advisory_refresh_timeout = 15 * 60
    # The time at which all threads will block waiting for a refreshed token
    _mandatory_refresh_timeout = 10 * 60
    # Refresh at most once every minute to avoid blocking every request
    _attempt_timeout = 60

    def __init__(self, method, refresh_using, time_fetcher=_utc_now):
        """
        :param method: Provider name; reported as ``provider`` when a
            ``TokenRetrievalError`` is raised.
        :param refresh_using: Zero-argument callable returning a new token.
        :param time_fetcher: Zero-argument callable returning the current
            time.
        """
        self._time_fetcher = time_fetcher
        self._refresh_using = refresh_using
        self.method = method

        # The frozen token is protected by this lock
        self._refresh_lock = threading.Lock()
        self._frozen_token = None
        self._next_refresh = None

    def get_frozen_token(self):
        """Return the current token, refreshing it first when required."""
        self._refresh()
        return self._frozen_token

    def _refresh(self):
        """Refresh the token if a refresh is due.

        A mandatory refresh blocks on the lock; an advisory refresh is
        skipped when the lock is already held.
        """
        # If we don't need to refresh just return
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        # Block for refresh if we're in the mandatory refresh window
        block_for_refresh = refresh_type == "mandatory"
        if self._refresh_lock.acquire(block_for_refresh):
            try:
                self._protected_refresh()
            finally:
                self._refresh_lock.release()

    def _protected_refresh(self):
        """Perform the refresh; the caller must hold ``_refresh_lock``.

        :raises Exception: Whatever ``refresh_using`` raised, when the
            refresh was mandatory.
        :raises TokenRetrievalError: When the token held after the attempt
            is expired.
        """
        # This should only be called after acquiring the refresh lock
        # Another thread may have already refreshed, double check refresh
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        try:
            now = self._time_fetcher()
            # Schedule the next attempt before calling out, so a failed
            # attempt is still throttled by ``_attempt_timeout``.
            self._next_refresh = now + timedelta(seconds=self._attempt_timeout)
            self._frozen_token = self._refresh_using()
        except Exception:
            logger.warning(
                "Refreshing token failed during the %s refresh period.",
                refresh_type,
                exc_info=True,
            )
            if refresh_type == "mandatory":
                # This refresh was mandatory, error must be propagated back
                raise

        if self._is_expired():
            # Fresh credentials should never be expired
            raise TokenRetrievalError(
                provider=self.method,
                error_msg="Token has expired and refresh failed",
            )

    def _is_expired(self):
        """Return True when the held token's expiration has passed.

        A missing token, or a token without an expiration, is never
        reported as expired.
        """
        if self._frozen_token is None:
            return False

        expiration = self._frozen_token.expiration
        if expiration is None:
            # Tokens without an expiration cannot expire; subtracting from
            # None would raise TypeError.
            return False

        remaining = total_seconds(expiration - self._time_fetcher())
        return remaining <= 0

    def _should_refresh(self):
        """Classify the refresh that is currently due.

        :return: ``"mandatory"``, ``"advisory"``, or ``None`` when no
            refresh is needed.
        """
        if self._frozen_token is None:
            # We don't have a token yet, mandatory refresh
            return "mandatory"

        expiration = self._frozen_token.expiration
        if expiration is None:
            # No expiration, so assume we don't need to refresh.
            return None

        now = self._time_fetcher()
        if now < self._next_refresh:
            return None

        remaining = total_seconds(expiration - now)

        if remaining < self._mandatory_refresh_timeout:
            return "mandatory"
        elif remaining < self._advisory_refresh_timeout:
            return "advisory"

        return None
class TokenProviderChain:
    """Query token providers in order and return the first token found."""

    def __init__(self, providers=None):
        """
        :param providers: Ordered providers, each exposing ``load_token()``.
            Defaults to an empty list.
        """
        self._providers = [] if providers is None else providers

    def load_token(self):
        """Return the first non-``None`` token from the providers.

        Providers after the first one that yields a token are not
        consulted. Returns ``None`` when no provider yields a token.
        """
        candidates = (provider.load_token() for provider in self._providers)
        return next(
            (token for token in candidates if token is not None), None
        )
class SSOTokenProvider:
    """Token provider backed by the cached token of an ``sso_session``.

    The provider reads the profile's ``sso_session`` configuration, loads
    the matching token through ``SSOTokenLoader``, and tries to refresh it
    with the ``sso-oidc`` ``create_token`` call when it is close to expiry.
    """

    METHOD = "sso"
    # Seconds of remaining lifetime below which a refresh is attempted.
    _REFRESH_WINDOW = 15 * 60
    _SSO_TOKEN_CACHE_DIR = os.path.expanduser(
        os.path.join("~", ".aws", "sso", "cache")
    )
    # Keys that must be present in the sso_session configuration.
    _SSO_CONFIG_VARS = [
        "sso_start_url",
        "sso_region",
    ]
    _GRANT_TYPE = "refresh_token"
    DEFAULT_CACHE_CLS = JSONFileCache

    def __init__(
        self, session, cache=None, time_fetcher=_utc_now, profile_name=None
    ):
        """
        :param session: Session used for configuration and client creation.
        :param cache: Token cache; when ``None`` a ``DEFAULT_CACHE_CLS``
            instance over ``_SSO_TOKEN_CACHE_DIR`` is created.
        :param time_fetcher: Zero-argument callable returning the current
            time.
        :param profile_name: Profile to read; falls back to the session's
            ``profile`` config variable, then to ``'default'``.
        """
        self._session = session
        self._now = time_fetcher
        self._cache = (
            self.DEFAULT_CACHE_CLS(
                self._SSO_TOKEN_CACHE_DIR,
                dumps_func=_sso_json_dumps,
            )
            if cache is None
            else cache
        )
        self._token_loader = SSOTokenLoader(cache=self._cache)
        self._profile_name = (
            profile_name
            or self._session.get_config_variable("profile")
            or 'default'
        )

    def _load_sso_config(self):
        """Resolve the ``sso_session`` configuration for the profile.

        :return: A dict with ``session_name``, ``sso_region`` and
            ``sso_start_url``, or ``None`` when the profile has no
            ``sso_session`` entry.
        :raises InvalidConfigError: When the referenced sso_session is
            missing or lacks a required key.
        """
        full_config = self._session.full_config
        all_profiles = full_config.get("profiles", {})
        all_sso_sessions = full_config.get("sso_sessions", {})
        profile_config = all_profiles.get(self._profile_name, {})

        if "sso_session" not in profile_config:
            # The profile does not use the SSO token provider.
            return None

        sso_session_name = profile_config["sso_session"]
        sso_config = all_sso_sessions.get(sso_session_name)

        if not sso_config:
            raise InvalidConfigError(
                error_msg=(
                    f'The profile "{self._profile_name}" is configured to use the SSO '
                    f'token provider but the "{sso_session_name}" sso_session '
                    f"configuration does not exist."
                )
            )

        missing_configs = [
            var for var in self._SSO_CONFIG_VARS if var not in sso_config
        ]
        if missing_configs:
            raise InvalidConfigError(
                error_msg=(
                    f'The profile "{self._profile_name}" is configured to use the SSO '
                    f"token provider but is missing the following configuration: "
                    f"{missing_configs}."
                )
            )

        return {
            "session_name": sso_session_name,
            "sso_region": sso_config["sso_region"],
            "sso_start_url": sso_config["sso_start_url"],
        }

    @CachedProperty
    def _sso_config(self):
        """The result of ``_load_sso_config`` for this provider."""
        return self._load_sso_config()

    @CachedProperty
    def _client(self):
        """An unsigned ``sso-oidc`` client for the configured SSO region."""
        region = self._sso_config["sso_region"]
        client_config = Config(region_name=region, signature_version=UNSIGNED)
        return self._session.create_client("sso-oidc", config=client_config)

    def _attempt_create_token(self, token):
        """Exchange the cached refresh token for a new access token.

        :param token: Cached token dict holding ``clientId``,
            ``clientSecret``, ``refreshToken`` and ``registrationExpiresAt``.
        :return: A new token dict; ``refreshToken`` is included only when
            the response carries one.
        """
        response = self._client.create_token(
            grantType=self._GRANT_TYPE,
            clientId=token["clientId"],
            clientSecret=token["clientSecret"],
            refreshToken=token["refreshToken"],
        )
        lifetime = timedelta(seconds=response["expiresIn"])
        refreshed = {
            "startUrl": self._sso_config["sso_start_url"],
            "region": self._sso_config["sso_region"],
            "accessToken": response["accessToken"],
            "expiresAt": self._now() + lifetime,
        }
        # Cache the registration alongside the token
        refreshed.update(
            (key, token[key])
            for key in ("clientId", "clientSecret", "registrationExpiresAt")
        )
        if "refreshToken" in response:
            refreshed["refreshToken"] = response["refreshToken"]
        logger.info("SSO Token refresh succeeded")
        return refreshed

    def _refresh_access_token(self, token):
        """Try to refresh ``token``; return ``None`` when that is not possible.

        No refresh is attempted when required keys are absent or the client
        registration has expired. A ``ClientError`` from the service is
        logged and also results in ``None``.
        """
        keys = (
            "refreshToken",
            "clientId",
            "clientSecret",
            "registrationExpiresAt",
        )
        missing_keys = [key for key in keys if key not in token]
        if missing_keys:
            msg = f"Unable to refresh SSO token: missing keys: {missing_keys}"
            logger.info(msg)
            return None

        expiry = dateutil.parser.parse(token["registrationExpiresAt"])
        registration_left = total_seconds(expiry - self._now())
        if registration_left <= 0:
            logger.info(f"SSO token registration expired at {expiry}")
            return None

        try:
            refreshed = self._attempt_create_token(token)
        except ClientError:
            logger.warning("SSO token refresh attempt failed", exc_info=True)
            return None
        return refreshed

    def _refresher(self):
        """Load the cached token, refreshing and re-saving it when due.

        :return: A ``FrozenAuthToken`` built from the refreshed token when
            the refresh succeeded, otherwise from the cached token.
        """
        start_url = self._sso_config["sso_start_url"]
        session_name = self._sso_config["session_name"]
        logger.info(f"Loading cached SSO token for {session_name}")
        token_dict = self._token_loader(start_url, session_name=session_name)
        expiration = dateutil.parser.parse(token_dict["expiresAt"])
        logger.debug(f"Cached SSO token expires at {expiration}")

        seconds_left = total_seconds(expiration - self._now())
        refreshed = None
        if seconds_left < self._REFRESH_WINDOW:
            refreshed = self._refresh_access_token(token_dict)

        if refreshed is not None:
            token_dict = refreshed
            # A refreshed token already carries its expiry as a datetime.
            expiration = token_dict["expiresAt"]
            self._token_loader.save_token(
                start_url, token_dict, session_name=session_name
            )

        return FrozenAuthToken(
            token_dict["accessToken"], expiration=expiration
        )

    def load_token(self):
        """Return a deferred token, or ``None`` without SSO configuration."""
        if self._sso_config is None:
            return None

        deferred = DeferredRefreshableToken(
            self.METHOD, self._refresher, time_fetcher=self._now
        )
        return deferred