Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/botocore/tokens.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

198 statements  

1# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"). You 

4# may not use this file except in compliance with the License. A copy of 

5# the License is located at 

6# 

7# http://aws.amazon.com/apache2.0/ 

8# 

9# or in the "license" file accompanying this file. This file is 

10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 

11# ANY KIND, either express or implied. See the License for the specific 

12# language governing permissions and limitations under the License. 

13import json 

14import logging 

15import os 

16import threading 

17from datetime import datetime, timedelta 

18from typing import NamedTuple, Optional 

19 

20import dateutil.parser 

21from dateutil.tz import tzutc 

22 

23from botocore import UNSIGNED 

24from botocore.compat import total_seconds 

25from botocore.config import Config 

26from botocore.exceptions import ( 

27 ClientError, 

28 InvalidConfigError, 

29 TokenRetrievalError, 

30) 

31from botocore.utils import ( 

32 CachedProperty, 

33 JSONFileCache, 

34 SSOTokenLoader, 

35 get_token_from_environment, 

36) 

37 

# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)

39 

40 

def _utc_now():
    """Return the current time as a timezone-aware datetime in UTC."""
    utc = tzutc()
    return datetime.now(tz=utc)

43 

44 

def create_token_resolver(session):
    """Build the default token provider chain for *session*.

    Providers are consulted in order: environment variables scoped to the
    service's signing name first, then the SSO token cache.
    """
    env_provider = ScopedEnvTokenProvider(session)
    sso_provider = SSOTokenProvider(session)
    return TokenProviderChain(providers=[env_provider, sso_provider])

51 

52 

53def _serialize_utc_timestamp(obj): 

54 if isinstance(obj, datetime): 

55 return obj.strftime("%Y-%m-%dT%H:%M:%SZ") 

56 return obj 

57 

58 

def _sso_json_dumps(obj):
    """Encode *obj* as a JSON string for the SSO token cache.

    Values the stock encoder cannot handle (datetimes) are passed to
    ``_serialize_utc_timestamp``.
    """
    encoder = json.JSONEncoder(default=_serialize_utc_timestamp)
    return encoder.encode(obj)

61 

62 

class FrozenAuthToken(NamedTuple):
    """Immutable snapshot of a token value and its optional expiration.

    ``expiration`` defaults to ``None``, which is how tokens without a known
    expiry (such as those read from environment variables) are represented.
    """

    # The token string itself.
    token: str
    # When the token stops being valid; None when there is no known expiry.
    expiration: Optional[datetime] = None

66 

67 

class DeferredRefreshableToken:
    """A token that is loaded lazily and refreshed as it nears expiration.

    The token is produced by calling ``refresh_using()`` with no arguments.
    The returned object must expose an ``expiration`` attribute: either a
    datetime comparable with ``time_fetcher()``, or ``None`` for a token that
    never needs refreshing (e.g. a ``FrozenAuthToken`` without expiry).

    Refresh policy, based on seconds remaining until expiration:

    * fewer than ``_advisory_refresh_timeout``: a refresh is attempted
      without blocking, so threads that lose the lock race keep using the
      current token, and a failed refresh is only logged;
    * fewer than ``_mandatory_refresh_timeout``: every thread blocks on the
      refresh lock, and a failed refresh is re-raised.

    Once a token is held, refresh attempts are made at most once every
    ``_attempt_timeout`` seconds.
    """

    # The time at which we'll attempt to refresh, but not block if someone else
    # is refreshing.
    _advisory_refresh_timeout = 15 * 60
    # The time at which all threads will block waiting for a refreshed token
    _mandatory_refresh_timeout = 10 * 60
    # Refresh at most once every minute to avoid blocking every request
    _attempt_timeout = 60

    def __init__(self, method, refresh_using, time_fetcher=_utc_now):
        """
        :param method: Provider name, reported in ``TokenRetrievalError``.
        :param refresh_using: Zero-argument callable returning a new token.
        :param time_fetcher: Zero-argument callable returning the current time.
        """
        self._time_fetcher = time_fetcher
        self._refresh_using = refresh_using
        self.method = method

        # The frozen token is protected by this lock
        self._refresh_lock = threading.Lock()
        self._frozen_token = None
        # Earliest time the next refresh attempt may run; set on each attempt.
        self._next_refresh = None

    def get_frozen_token(self):
        """Return the current token, refreshing it first if one is due.

        :raises: Whatever ``refresh_using`` raises during a mandatory refresh,
            or ``TokenRetrievalError`` if the refreshed token is already
            expired.
        """
        self._refresh()
        return self._frozen_token

    def _refresh(self):
        """Refresh the token if ``_should_refresh`` reports that one is due."""
        # If we don't need to refresh just return
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        # Block for refresh if we're in the mandatory refresh window; in the
        # advisory window a failed (non-blocking) acquire just skips the refresh.
        block_for_refresh = refresh_type == "mandatory"
        if self._refresh_lock.acquire(block_for_refresh):
            try:
                self._protected_refresh()
            finally:
                self._refresh_lock.release()

    def _protected_refresh(self):
        """Perform the refresh; the caller must hold ``_refresh_lock``.

        :raises: The refresh error if the refresh was mandatory, or
            ``TokenRetrievalError`` if the resulting token is expired.
        """
        # Another thread may have already refreshed, double check refresh
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        try:
            now = self._time_fetcher()
            # Schedule the next permitted attempt before calling out, so a
            # failed refresh is also rate-limited.
            self._next_refresh = now + timedelta(seconds=self._attempt_timeout)
            self._frozen_token = self._refresh_using()
        except Exception:
            logger.warning(
                "Refreshing token failed during the %s refresh period.",
                refresh_type,
                exc_info=True,
            )
            if refresh_type == "mandatory":
                # This refresh was mandatory, error must be propagated back
                raise

        if self._is_expired():
            # Fresh credentials should never be expired
            raise TokenRetrievalError(
                provider=self.method,
                error_msg="Token has expired and refresh failed",
            )

    def _is_expired(self):
        """Return True if the current token has an expiration in the past."""
        if self._frozen_token is None:
            return False

        expiration = self._frozen_token.expiration
        if expiration is None:
            # A token without an expiration never expires; this mirrors the
            # handling in _should_refresh and avoids ``None - datetime``.
            return False

        remaining = total_seconds(expiration - self._time_fetcher())
        return remaining <= 0

    def _should_refresh(self):
        """Classify the refresh need.

        :returns: ``"mandatory"``, ``"advisory"`` or ``None`` (no refresh).
        """
        if self._frozen_token is None:
            # We don't have a token yet, mandatory refresh
            return "mandatory"

        expiration = self._frozen_token.expiration
        if expiration is None:
            # No expiration, so assume we don't need to refresh.
            return None

        now = self._time_fetcher()
        if self._next_refresh is not None and now < self._next_refresh:
            # A refresh was attempted recently; wait out the attempt timeout.
            return None

        remaining = total_seconds(expiration - now)

        if remaining < self._mandatory_refresh_timeout:
            return "mandatory"
        elif remaining < self._advisory_refresh_timeout:
            return "advisory"

        return None

163 

164 

class TokenProviderChain:
    """Query token providers in order and return the first token found."""

    def __init__(self, providers=None):
        # An omitted provider list means "no providers"; never share a default.
        self._providers = [] if providers is None else providers

    def load_token(self, **kwargs):
        """Return the first non-``None`` token from the providers, else ``None``.

        Providers are consulted lazily; once one yields a token, the rest are
        not called. All keyword arguments are forwarded to each provider.
        """
        candidates = (
            provider.load_token(**kwargs) for provider in self._providers
        )
        return next(
            (candidate for candidate in candidates if candidate is not None),
            None,
        )

177 

178 

class SSOTokenProvider:
    """Provide bearer tokens from the SSO token cache, refreshing when needed.

    The provider applies only when the selected profile names an
    ``sso_session``. The cached token for that session is loaded through
    ``SSOTokenLoader``. If it has fewer than ``_REFRESH_WINDOW`` seconds left,
    it is refreshed with the ``sso-oidc`` ``create_token`` call and written
    back to the cache.
    """

    METHOD = "sso"
    # Seconds before expiration at which a cached token is refreshed.
    _REFRESH_WINDOW = 15 * 60
    _SSO_TOKEN_CACHE_DIR = os.path.expanduser(
        os.path.join("~", ".aws", "sso", "cache")
    )
    # Keys every referenced sso-session section must define.
    _SSO_CONFIG_VARS = [
        "sso_start_url",
        "sso_region",
    ]
    # Grant type sent to create_token when exchanging a refresh token.
    _GRANT_TYPE = "refresh_token"
    DEFAULT_CACHE_CLS = JSONFileCache

    def __init__(
        self, session, cache=None, time_fetcher=_utc_now, profile_name=None
    ):
        """
        :param session: Session used for config lookups and client creation.
        :param cache: Token cache; defaults to ``DEFAULT_CACHE_CLS`` rooted at
            ``~/.aws/sso/cache`` and serializing with ``_sso_json_dumps``.
        :param time_fetcher: Zero-argument callable returning the current time.
        :param profile_name: Profile to read; falls back to the session's
            ``profile`` config variable, then ``'default'``.
        """
        self._session = session
        if cache is None:
            cache = self.DEFAULT_CACHE_CLS(
                self._SSO_TOKEN_CACHE_DIR,
                dumps_func=_sso_json_dumps,
            )
        self._now = time_fetcher
        self._cache = cache
        self._token_loader = SSOTokenLoader(cache=self._cache)
        self._profile_name = (
            profile_name
            or self._session.get_config_variable("profile")
            or 'default'
        )

    def _load_sso_config(self):
        """Resolve the profile's sso-session configuration.

        :returns: A dict with ``session_name``, ``sso_region`` and
            ``sso_start_url``, or ``None`` if the profile has no
            ``sso_session`` setting.
        :raises InvalidConfigError: If the referenced sso-session section is
            missing/empty or lacks a required key.
        """
        loaded_config = self._session.full_config
        profiles = loaded_config.get("profiles", {})
        sso_sessions = loaded_config.get("sso_sessions", {})
        profile_config = profiles.get(self._profile_name, {})

        if "sso_session" not in profile_config:
            # Not an SSO-session profile: this provider does not apply.
            return None

        sso_session_name = profile_config["sso_session"]
        sso_config = sso_sessions.get(sso_session_name, None)

        if not sso_config:
            error_msg = (
                f'The profile "{self._profile_name}" is configured to use the SSO '
                f'token provider but the "{sso_session_name}" sso_session '
                f"configuration does not exist."
            )
            raise InvalidConfigError(error_msg=error_msg)

        missing_configs = [
            var for var in self._SSO_CONFIG_VARS if var not in sso_config
        ]
        if missing_configs:
            error_msg = (
                f'The profile "{self._profile_name}" is configured to use the SSO '
                f"token provider but is missing the following configuration: "
                f"{missing_configs}."
            )
            raise InvalidConfigError(error_msg=error_msg)

        return {
            "session_name": sso_session_name,
            "sso_region": sso_config["sso_region"],
            "sso_start_url": sso_config["sso_start_url"],
        }

    @CachedProperty
    def _sso_config(self):
        """The resolved sso-session config (see ``_load_sso_config``)."""
        return self._load_sso_config()

    @CachedProperty
    def _client(self):
        """Unsigned ``sso-oidc`` client in the sso-session's region."""
        config = Config(
            region_name=self._sso_config["sso_region"],
            signature_version=UNSIGNED,
        )
        return self._session.create_client("sso-oidc", config=config)

    def _attempt_create_token(self, token):
        """Exchange the cached refresh token for a new access token.

        :returns: A new cache entry. ``expiresAt`` is a datetime computed from
            ``expiresIn``. The client registration is copied from ``token``.
            ``refreshToken`` is included only if the response carries one.
        :raises ClientError: Propagated from ``create_token``.
        """
        response = self._client.create_token(
            grantType=self._GRANT_TYPE,
            clientId=token["clientId"],
            clientSecret=token["clientSecret"],
            refreshToken=token["refreshToken"],
        )
        expires_in = timedelta(seconds=response["expiresIn"])
        new_token = {
            "startUrl": self._sso_config["sso_start_url"],
            "region": self._sso_config["sso_region"],
            "accessToken": response["accessToken"],
            "expiresAt": self._now() + expires_in,
            # Cache the registration alongside the token
            "clientId": token["clientId"],
            "clientSecret": token["clientSecret"],
            "registrationExpiresAt": token["registrationExpiresAt"],
        }
        if "refreshToken" in response:
            new_token["refreshToken"] = response["refreshToken"]
        logger.info("SSO Token refresh succeeded")
        return new_token

    def _refresh_access_token(self, token):
        """Try to refresh ``token``.

        :returns: The new token dict, or ``None`` if refresh is impossible.
            That is the case when required keys are missing, the client
            registration has expired, or the service call failed.
        """
        keys = (
            "refreshToken",
            "clientId",
            "clientSecret",
            "registrationExpiresAt",
        )
        missing_keys = [k for k in keys if k not in token]
        if missing_keys:
            logger.info(
                "Unable to refresh SSO token: missing keys: %s", missing_keys
            )
            return None

        expiry = dateutil.parser.parse(token["registrationExpiresAt"])
        if total_seconds(expiry - self._now()) <= 0:
            logger.info("SSO token registration expired at %s", expiry)
            return None

        try:
            return self._attempt_create_token(token)
        except ClientError:
            # Best effort: fall back to the still-cached token.
            logger.warning("SSO token refresh attempt failed", exc_info=True)
            return None

    def _refresher(self):
        """Load the cached token, refreshing and re-caching it if it is close
        to expiring.

        Used as the ``refresh_using`` callback of ``DeferredRefreshableToken``.
        """
        start_url = self._sso_config["sso_start_url"]
        session_name = self._sso_config["session_name"]
        logger.info("Loading cached SSO token for %s", session_name)
        token_dict = self._token_loader(start_url, session_name=session_name)
        expiration = dateutil.parser.parse(token_dict["expiresAt"])
        logger.debug("Cached SSO token expires at %s", expiration)

        remaining = total_seconds(expiration - self._now())
        if remaining < self._REFRESH_WINDOW:
            new_token_dict = self._refresh_access_token(token_dict)
            if new_token_dict is not None:
                token_dict = new_token_dict
                # Freshly created tokens already hold a datetime here.
                expiration = token_dict["expiresAt"]
                self._token_loader.save_token(
                    start_url, token_dict, session_name=session_name
                )

        return FrozenAuthToken(
            token_dict["accessToken"], expiration=expiration
        )

    def load_token(self, **kwargs):
        """Return a lazily-refreshed SSO token, or ``None`` for non-SSO
        profiles.

        :raises InvalidConfigError: If the profile's sso-session config is
            invalid (raised while resolving ``_sso_config``).
        """
        if self._sso_config is None:
            return None

        return DeferredRefreshableToken(
            self.METHOD, self._refresher, time_fetcher=self._now
        )

337 

338 

class ScopedEnvTokenProvider:
    """Provide a token read from environment variables for one service.

    The lookup is keyed by the ``signing_name`` keyword passed to
    ``load_token``. Without it the provider yields nothing.
    """

    METHOD = 'env'

    def __init__(self, session, environ=None):
        """
        :param session: Stored but not used by this provider's visible code.
        :param environ: Mapping to read variables from; defaults to
            ``os.environ``.
        """
        self._session = session
        self.environ = os.environ if environ is None else environ

    def load_token(self, **kwargs):
        """Return a ``FrozenAuthToken`` from the environment, or ``None``."""
        signing_name = kwargs.get("signing_name")
        if signing_name is None:
            return None

        env_token = get_token_from_environment(signing_name, self.environ)
        if env_token is None:
            return None

        logger.info("Found token in environment variables.")
        return FrozenAuthToken(env_token)