Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/botocore/tokens.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

198 statements  

1# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"). You 

4# may not use this file except in compliance with the License. A copy of 

5# the License is located at 

6# 

7# http://aws.amazon.com/apache2.0/ 

8# 

9# or in the "license" file accompanying this file. This file is 

10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 

11# ANY KIND, either express or implied. See the License for the specific 

12# language governing permissions and limitations under the License. 

13import json 

14import logging 

15import os 

16import threading 

17from datetime import datetime, timedelta 

18from typing import NamedTuple, Optional 

19 

20import dateutil.parser 

21from dateutil.tz import tzutc 

22 

23from botocore import UNSIGNED 

24from botocore.compat import total_seconds 

25from botocore.config import Config 

26from botocore.exceptions import ( 

27 ClientError, 

28 InvalidConfigError, 

29 TokenRetrievalError, 

30) 

31from botocore.utils import ( 

32 CachedProperty, 

33 JSONFileCache, 

34 SSOTokenLoader, 

35 create_nested_client, 

36 get_token_from_environment, 

37) 

38 

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

40 

41 

def _utc_now():
    """Return the current time as a timezone-aware datetime in UTC."""
    utc = tzutc()
    return datetime.now(tz=utc)

44 

45 

def create_token_resolver(session):
    """Build the default token provider chain for ``session``.

    Providers are consulted in order: environment-variable tokens scoped to
    a signing name first, then SSO tokens for the session's profile.
    """
    return TokenProviderChain(
        providers=[ScopedEnvTokenProvider(session), SSOTokenProvider(session)]
    )

52 

53 

54def _serialize_utc_timestamp(obj): 

55 if isinstance(obj, datetime): 

56 return obj.strftime("%Y-%m-%dT%H:%M:%SZ") 

57 return obj 

58 

59 

def _sso_json_dumps(obj):
    """Serialize ``obj`` to a JSON string for the SSO token cache.

    Datetime values are rendered by ``_serialize_utc_timestamp``.
    """
    serialized = json.dumps(obj, default=_serialize_utc_timestamp)
    return serialized

62 

63 

class FrozenAuthToken(NamedTuple):
    """Immutable snapshot of an authentication token and its expiry."""

    # The token value itself.
    token: str
    # When the token stops being valid; None means no known expiration.
    expiration: Optional[datetime] = None

67 

68 

class DeferredRefreshableToken:
    """Lazily fetch a token and refresh it as its expiration approaches.

    The token is obtained by calling ``refresh_using`` on first access and
    again whenever the held token enters one of two refresh windows:

    * advisory (< ``_advisory_refresh_timeout`` seconds left): a refresh is
      attempted only if the lock is free; a failure is logged and the current
      token keeps being served unless it has already expired.
    * mandatory (< ``_mandatory_refresh_timeout`` seconds left, or no token
      yet): callers block on the lock and refresh errors are re-raised.

    Once a token is held, refresh attempts are rate-limited to one per
    ``_attempt_timeout`` seconds. A token whose ``expiration`` is None is
    never refreshed and never considered expired.
    """

    # The time at which we'll attempt to refresh, but not block if someone else
    # is refreshing.
    _advisory_refresh_timeout = 15 * 60
    # The time at which all threads will block waiting for a refreshed token
    _mandatory_refresh_timeout = 10 * 60
    # Refresh at most once every minute to avoid blocking every request
    _attempt_timeout = 60

    def __init__(self, method, refresh_using, time_fetcher=_utc_now):
        """
        :param method: Name of the token source; reported as the provider in
            ``TokenRetrievalError``.
        :param refresh_using: Zero-argument callable returning a token object
            with an ``expiration`` attribute (e.g. ``FrozenAuthToken``).
        :param time_fetcher: Zero-argument callable returning the current
            time; defaults to ``_utc_now``.
        """
        self._time_fetcher = time_fetcher
        self._refresh_using = refresh_using
        self.method = method

        # The frozen token is protected by this lock
        self._refresh_lock = threading.Lock()
        self._frozen_token = None
        self._next_refresh = None

    def get_frozen_token(self):
        """Return the current token, refreshing it first if needed."""
        self._refresh()
        return self._frozen_token

    def _refresh(self):
        # If we don't need to refresh just return
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        # Block for refresh if we're in the mandatory refresh window
        block_for_refresh = refresh_type == "mandatory"
        if self._refresh_lock.acquire(block_for_refresh):
            try:
                self._protected_refresh()
            finally:
                self._refresh_lock.release()

    def _protected_refresh(self):
        # This should only be called after acquiring the refresh lock
        # Another thread may have already refreshed, double check refresh
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        try:
            now = self._time_fetcher()
            # Set before calling the refresher so a failed attempt is still
            # rate-limited.
            self._next_refresh = now + timedelta(seconds=self._attempt_timeout)
            self._frozen_token = self._refresh_using()
        except Exception:
            logger.warning(
                "Refreshing token failed during the %s refresh period.",
                refresh_type,
                exc_info=True,
            )
            if refresh_type == "mandatory":
                # This refresh was mandatory, error must be propagated back
                raise

        if self._is_expired():
            # Fresh credentials should never be expired
            raise TokenRetrievalError(
                provider=self.method,
                error_msg="Token has expired and refresh failed",
            )

    def _is_expired(self):
        """Return True only if the held token has an expiration in the past."""
        if self._frozen_token is None:
            return False

        expiration = self._frozen_token.expiration
        if expiration is None:
            # Consistent with _should_refresh: a token without an expiration
            # never expires (``None - now`` would raise TypeError).
            return False

        remaining = total_seconds(expiration - self._time_fetcher())
        return remaining <= 0

    def _should_refresh(self):
        """Return "mandatory", "advisory", or None (no refresh needed)."""
        if self._frozen_token is None:
            # We don't have a token yet, mandatory refresh
            return "mandatory"

        expiration = self._frozen_token.expiration
        if expiration is None:
            # No expiration, so assume we don't need to refresh.
            return None

        now = self._time_fetcher()
        if now < self._next_refresh:
            # A refresh was attempted too recently.
            return None

        remaining = total_seconds(expiration - now)

        if remaining < self._mandatory_refresh_timeout:
            return "mandatory"
        elif remaining < self._advisory_refresh_timeout:
            return "advisory"

        return None

164 

165 

class TokenProviderChain:
    """Consult token providers in order and return the first token found."""

    def __init__(self, providers=None):
        # Each provider must expose ``load_token(**kwargs)``.
        self._providers = [] if providers is None else providers

    def load_token(self, **kwargs):
        """Return the first non-None token produced by the providers.

        ``kwargs`` are forwarded unchanged to each provider consulted.
        Providers after the first successful one are not called. Returns
        None when no provider yields a token.
        """
        candidates = (p.load_token(**kwargs) for p in self._providers)
        return next((tok for tok in candidates if tok is not None), None)

178 

179 

class SSOTokenProvider:
    """Provide tokens from the local SSO cache, refreshing them via SSO-OIDC.

    Only active for profiles that reference an ``sso_session`` section. The
    cached token for that session is loaded and, when it is close to expiry,
    exchanged for a new one using its refresh token and written back.
    """

    METHOD = "sso"
    # Refresh the cached token once fewer than this many seconds remain.
    _REFRESH_WINDOW = 15 * 60
    _SSO_TOKEN_CACHE_DIR = os.path.expanduser(
        os.path.join("~", ".aws", "sso", "cache")
    )
    # Settings every referenced sso_session section must define.
    _SSO_CONFIG_VARS = [
        "sso_start_url",
        "sso_region",
    ]
    _GRANT_TYPE = "refresh_token"
    DEFAULT_CACHE_CLS = JSONFileCache

    def __init__(
        self, session, cache=None, time_fetcher=_utc_now, profile_name=None
    ):
        self._session = session
        self._now = time_fetcher
        self._cache = (
            cache
            if cache is not None
            else self.DEFAULT_CACHE_CLS(
                self._SSO_TOKEN_CACHE_DIR, dumps_func=_sso_json_dumps
            )
        )
        self._token_loader = SSOTokenLoader(cache=self._cache)
        # Explicit argument, then the session's profile, then "default".
        self._profile_name = (
            profile_name
            or self._session.get_config_variable("profile")
            or 'default'
        )

    def _load_sso_config(self):
        """Return the sso_session settings for the active profile.

        :returns: dict with ``session_name``, ``sso_region`` and
            ``sso_start_url``; None if the profile has no ``sso_session``.
        :raises InvalidConfigError: if the referenced sso_session section is
            missing or empty, or lacks a required setting.
        """
        full_config = self._session.full_config
        profile_config = full_config.get("profiles", {}).get(
            self._profile_name, {}
        )
        if "sso_session" not in profile_config:
            return None

        session_name = profile_config["sso_session"]
        session_config = full_config.get("sso_sessions", {}).get(session_name)
        if not session_config:
            error_msg = (
                f'The profile "{self._profile_name}" is configured to use the SSO '
                f'token provider but the "{session_name}" sso_session '
                f"configuration does not exist."
            )
            raise InvalidConfigError(error_msg=error_msg)

        missing_configs = [
            var for var in self._SSO_CONFIG_VARS if var not in session_config
        ]
        if missing_configs:
            error_msg = (
                f'The profile "{self._profile_name}" is configured to use the SSO '
                f"token provider but is missing the following configuration: "
                f"{missing_configs}."
            )
            raise InvalidConfigError(error_msg=error_msg)

        return {
            "session_name": session_name,
            "sso_region": session_config["sso_region"],
            "sso_start_url": session_config["sso_start_url"],
        }

    @CachedProperty
    def _sso_config(self):
        return self._load_sso_config()

    @CachedProperty
    def _client(self):
        # The sso-oidc client is created with an UNSIGNED signature version.
        client_config = Config(
            region_name=self._sso_config["sso_region"],
            signature_version=UNSIGNED,
        )
        return create_nested_client(
            self._session, "sso-oidc", config=client_config
        )

    def _attempt_create_token(self, token):
        """Exchange ``token``'s refresh token for a new access token.

        :returns: a new cache entry dict. The client registration from
            ``token`` is carried over, and ``refreshToken`` is included only
            when the service response contains one.
        """
        response = self._client.create_token(
            grantType=self._GRANT_TYPE,
            clientId=token["clientId"],
            clientSecret=token["clientSecret"],
            refreshToken=token["refreshToken"],
        )
        lifetime = timedelta(seconds=response["expiresIn"])
        refreshed = {
            "startUrl": self._sso_config["sso_start_url"],
            "region": self._sso_config["sso_region"],
            "accessToken": response["accessToken"],
            "expiresAt": self._now() + lifetime,
        }
        # Cache the registration alongside the token
        for key in ("clientId", "clientSecret", "registrationExpiresAt"):
            refreshed[key] = token[key]
        if "refreshToken" in response:
            refreshed["refreshToken"] = response["refreshToken"]
        logger.info("SSO Token refresh succeeded")
        return refreshed

    def _refresh_access_token(self, token):
        """Try to refresh ``token``; return the new entry, or None on failure.

        Returns None without calling the service when ``token`` lacks the
        fields required for a refresh or its client registration has
        expired. A ``ClientError`` from the service is logged and also
        yields None.
        """
        required = (
            "refreshToken",
            "clientId",
            "clientSecret",
            "registrationExpiresAt",
        )
        missing_keys = [key for key in required if key not in token]
        if missing_keys:
            logger.info(
                f"Unable to refresh SSO token: missing keys: {missing_keys}"
            )
            return None

        registration_expiry = dateutil.parser.parse(
            token["registrationExpiresAt"]
        )
        if total_seconds(registration_expiry - self._now()) <= 0:
            logger.info(
                "SSO token registration expired at %s", registration_expiry
            )
            return None

        try:
            new_token = self._attempt_create_token(token)
        except ClientError:
            logger.warning("SSO token refresh attempt failed", exc_info=True)
            return None
        return new_token

    def _refresher(self):
        """Load the cached token, refreshing and re-caching it near expiry.

        Used as the ``refresh_using`` callable of the returned
        ``DeferredRefreshableToken``. If a refresh is not possible, the
        cached token is returned unchanged.
        """
        sso_config = self._sso_config
        start_url = sso_config["sso_start_url"]
        session_name = sso_config["session_name"]
        logger.info("Loading cached SSO token for %s", session_name)
        token_dict = self._token_loader(start_url, session_name=session_name)
        expiration = dateutil.parser.parse(token_dict["expiresAt"])
        logger.debug("Cached SSO token expires at %s", expiration)

        if total_seconds(expiration - self._now()) < self._REFRESH_WINDOW:
            refreshed = self._refresh_access_token(token_dict)
            if refreshed is not None:
                token_dict = refreshed
                # A freshly created entry already holds a datetime here.
                expiration = token_dict["expiresAt"]
                self._token_loader.save_token(
                    start_url, token_dict, session_name=session_name
                )

        return FrozenAuthToken(
            token_dict["accessToken"], expiration=expiration
        )

    def load_token(self, **kwargs):
        """Return a refreshable SSO token, or None if SSO is not configured."""
        if self._sso_config is None:
            return None
        return DeferredRefreshableToken(
            self.METHOD, self._refresher, time_fetcher=self._now
        )

338 

339 

class ScopedEnvTokenProvider:
    """
    Token provider backed by environment variables scoped to the
    caller-supplied `signing_name`.
    """

    METHOD = 'env'

    def __init__(self, session, environ=None):
        self._session = session
        # Use the live process environment unless a mapping is injected.
        self.environ = os.environ if environ is None else environ

    def load_token(self, **kwargs):
        """Return a ``FrozenAuthToken`` from the environment, or None.

        A ``signing_name`` keyword argument is required; without it, no
        lookup is performed and None is returned.
        """
        signing_name = kwargs.get("signing_name")
        if signing_name is None:
            return None

        env_token = get_token_from_environment(signing_name, self.environ)
        if env_token is None:
            return None

        logger.info("Found token in environment variables.")
        return FrozenAuthToken(env_token)