Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/botocore/tokens.py: 28%

182 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:51 +0000

1# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"). You 

4# may not use this file except in compliance with the License. A copy of 

5# the License is located at 

6# 

7# http://aws.amazon.com/apache2.0/ 

8# 

9# or in the "license" file accompanying this file. This file is 

10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 

11# ANY KIND, either express or implied. See the License for the specific 

12# language governing permissions and limitations under the License. 

13import json 

14import logging 

15import os 

16import threading 

17from datetime import datetime, timedelta 

18from typing import NamedTuple, Optional 

19 

20import dateutil.parser 

21from dateutil.tz import tzutc 

22 

23from botocore import UNSIGNED 

24from botocore.compat import total_seconds 

25from botocore.config import Config 

26from botocore.exceptions import ( 

27 ClientError, 

28 InvalidConfigError, 

29 TokenRetrievalError, 

30) 

31from botocore.utils import CachedProperty, JSONFileCache, SSOTokenLoader 

32 

33logger = logging.getLogger(__name__) 

34 

35 

def _utc_now():
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    utc_zone = tzutc()
    return datetime.now(utc_zone)

38 

39 

def create_token_resolver(session):
    """Build the token provider chain for ``session``.

    The chain contains a single ``SSOTokenProvider`` bound to the session.

    :param session: Session object handed to ``SSOTokenProvider``.
    :return: A ``TokenProviderChain`` wrapping the provider list.
    """
    sso_provider = SSOTokenProvider(session)
    return TokenProviderChain(providers=[sso_provider])

45 

46 

def _serialize_utc_timestamp(obj):
    """Render a ``datetime`` as ``YYYY-MM-DDTHH:MM:SSZ``.

    Anything that is not a ``datetime`` is handed back unchanged. No timezone
    conversion is performed; the trailing ``Z`` is appended verbatim.
    """
    if not isinstance(obj, datetime):
        return obj
    return obj.strftime("%Y-%m-%dT%H:%M:%SZ")

51 

52 

def _sso_json_dumps(obj):
    """Serialize ``obj`` to a JSON string.

    Values the encoder cannot handle natively are passed through
    ``_serialize_utc_timestamp`` (which formats ``datetime`` objects).
    """
    fallback_encoder = _serialize_utc_timestamp
    return json.dumps(obj, default=fallback_encoder)

55 

56 

class FrozenAuthToken(NamedTuple):
    """Immutable pairing of a token string with its optional expiry time."""

    # The token value itself.
    token: str
    # When the token expires; defaults to ``None`` when no expiry is given.
    expiration: Optional[datetime] = None

60 

61 

class DeferredRefreshableToken:
    """Token wrapper that fetches and refreshes its value on demand.

    The token is not retrieved at construction time. ``get_frozen_token``
    triggers a refresh check and returns the most recently stored frozen
    token. Refreshes are serialized through an internal lock.

    :param method: Name of the provider; reported in ``TokenRetrievalError``.
    :param refresh_using: Zero-argument callable returning a new frozen token
        (an object exposing an ``expiration`` attribute).
    :param time_fetcher: Zero-argument callable returning the current time.
    """

    # The time at which we'll attempt to refresh, but not block if someone else
    # is refreshing.
    _advisory_refresh_timeout = 15 * 60
    # The time at which all threads will block waiting for a refreshed token
    _mandatory_refresh_timeout = 10 * 60
    # Refresh at most once every minute to avoid blocking every request
    _attempt_timeout = 60

    def __init__(self, method, refresh_using, time_fetcher=_utc_now):
        self._time_fetcher = time_fetcher
        self._refresh_using = refresh_using
        self.method = method

        # The frozen token is protected by this lock
        self._refresh_lock = threading.Lock()
        self._frozen_token = None
        self._next_refresh = None

    def get_frozen_token(self):
        """Refresh the token if needed and return the stored frozen token."""
        self._refresh()
        return self._frozen_token

    def _refresh(self):
        """Run a refresh if one is due, blocking only for mandatory ones."""
        # If we don't need to refresh just return
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        # Block for refresh if we're in the mandatory refresh window.
        # A plain ``with`` cannot be used here because advisory refreshes
        # must give up immediately when another thread holds the lock.
        block_for_refresh = refresh_type == "mandatory"
        if self._refresh_lock.acquire(block_for_refresh):
            try:
                self._protected_refresh()
            finally:
                self._refresh_lock.release()

    def _protected_refresh(self):
        """Perform the refresh; the caller must hold ``_refresh_lock``.

        :raises TokenRetrievalError: If the stored token is expired once the
            refresh attempt has finished.
        :raises Exception: Whatever ``refresh_using`` raised, when the
            refresh was mandatory.
        """
        # This should only be called after acquiring the refresh lock
        # Another thread may have already refreshed, double check refresh
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        try:
            now = self._time_fetcher()
            # Recorded before the attempt so a failing refresh is still
            # throttled by ``_attempt_timeout``.
            self._next_refresh = now + timedelta(seconds=self._attempt_timeout)
            self._frozen_token = self._refresh_using()
        except Exception:
            logger.warning(
                "Refreshing token failed during the %s refresh period.",
                refresh_type,
                exc_info=True,
            )
            if refresh_type == "mandatory":
                # This refresh was mandatory, error must be propagated back
                raise

        if self._is_expired():
            # Fresh credentials should never be expired
            raise TokenRetrievalError(
                provider=self.method,
                error_msg="Token has expired and refresh failed",
            )

    def _is_expired(self):
        """Return ``True`` when the stored token's expiration has passed.

        A missing token, or a token whose ``expiration`` is ``None``, is
        reported as not expired.
        """
        if self._frozen_token is None:
            return False

        expiration = self._frozen_token.expiration
        if expiration is None:
            # Tokens may carry no expiration (``_should_refresh`` already
            # treats that as "never refresh"); subtracting from ``None``
            # would raise ``TypeError`` right after a successful refresh.
            return False

        remaining = total_seconds(expiration - self._time_fetcher())
        return remaining <= 0

    def _should_refresh(self):
        """Classify the refresh that is currently due.

        :return: ``"mandatory"``, ``"advisory"``, or ``None`` when no refresh
            is needed.
        """
        if self._frozen_token is None:
            # We don't have a token yet, mandatory refresh
            return "mandatory"

        expiration = self._frozen_token.expiration
        if expiration is None:
            # No expiration, so assume we don't need to refresh.
            return None

        now = self._time_fetcher()
        if now < self._next_refresh:
            # Still inside the throttle window set by the last attempt.
            return None

        remaining = total_seconds(expiration - now)

        if remaining < self._mandatory_refresh_timeout:
            return "mandatory"
        elif remaining < self._advisory_refresh_timeout:
            return "advisory"

        return None

157 

158 

class TokenProviderChain:
    """Ordered collection of token providers consulted one after another."""

    def __init__(self, providers=None):
        """Store ``providers``; a missing list becomes a fresh empty one."""
        self._providers = [] if providers is None else providers

    def load_token(self):
        """Return the first non-``None`` token produced by the providers.

        Providers are asked in order and later ones are not consulted once a
        token is found. ``None`` is returned when no provider yields a token.
        """
        # The generator is lazy, so providers past the first hit never run.
        candidates = (source.load_token() for source in self._providers)
        return next((found for found in candidates if found is not None), None)

171 

172 

class SSOTokenProvider:
    """Token provider that serves the cached SSO token of a profile.

    The profile must reference an ``sso_session``; the cached token for that
    session is loaded through ``SSOTokenLoader`` and refreshed via the
    ``sso-oidc`` client when it is close to expiring.
    """

    METHOD = "sso"
    _REFRESH_WINDOW = 15 * 60
    _SSO_TOKEN_CACHE_DIR = os.path.expanduser(
        os.path.join("~", ".aws", "sso", "cache")
    )
    _SSO_CONFIG_VARS = ["sso_start_url", "sso_region"]
    _GRANT_TYPE = "refresh_token"
    DEFAULT_CACHE_CLS = JSONFileCache

    def __init__(
        self, session, cache=None, time_fetcher=_utc_now, profile_name=None
    ):
        """Bind the provider to ``session``.

        :param session: Session used for config lookup and client creation.
        :param cache: Token cache; when ``None`` a ``DEFAULT_CACHE_CLS``
            instance over ``_SSO_TOKEN_CACHE_DIR`` is created.
        :param time_fetcher: Zero-argument callable returning the current time.
        :param profile_name: Profile to use; falls back to the session's
            ``profile`` config variable and then to ``'default'``.
        """
        self._session = session
        if cache is None:
            cache = self.DEFAULT_CACHE_CLS(
                self._SSO_TOKEN_CACHE_DIR, dumps_func=_sso_json_dumps
            )
        self._now = time_fetcher
        self._cache = cache
        self._token_loader = SSOTokenLoader(cache=self._cache)

        # Explicit argument first, then the session setting, then 'default'.
        chosen_profile = profile_name
        if not chosen_profile:
            chosen_profile = self._session.get_config_variable("profile")
        if not chosen_profile:
            chosen_profile = 'default'
        self._profile_name = chosen_profile

    def _load_sso_config(self):
        """Read and validate the ``sso_session`` settings of the profile.

        :return: ``None`` when the profile has no ``sso_session`` key,
            otherwise a dict with ``session_name``, ``sso_region`` and
            ``sso_start_url``.
        :raises InvalidConfigError: If the referenced session is absent or
            empty, or lacks any entry of ``_SSO_CONFIG_VARS``.
        """
        full_config = self._session.full_config
        all_profiles = full_config.get("profiles", {})
        all_sso_sessions = full_config.get("sso_sessions", {})
        this_profile = all_profiles.get(self._profile_name, {})

        if "sso_session" not in this_profile:
            return None

        sso_session_name = this_profile["sso_session"]
        sso_config = all_sso_sessions.get(sso_session_name)

        if not sso_config:
            raise InvalidConfigError(
                error_msg=(
                    f'The profile "{self._profile_name}" is configured to use the SSO '
                    f'token provider but the "{sso_session_name}" sso_session '
                    f"configuration does not exist."
                )
            )

        missing_configs = [
            name for name in self._SSO_CONFIG_VARS if name not in sso_config
        ]
        if missing_configs:
            raise InvalidConfigError(
                error_msg=(
                    f'The profile "{self._profile_name}" is configured to use the SSO '
                    f"token provider but is missing the following configuration: "
                    f"{missing_configs}."
                )
            )

        return {
            "session_name": sso_session_name,
            "sso_region": sso_config["sso_region"],
            "sso_start_url": sso_config["sso_start_url"],
        }

    @CachedProperty
    def _sso_config(self):
        """SSO settings from ``_load_sso_config`` (via ``CachedProperty``)."""
        return self._load_sso_config()

    @CachedProperty
    def _client(self):
        """Unsigned ``sso-oidc`` client for the configured SSO region."""
        unsigned_config = Config(
            region_name=self._sso_config["sso_region"],
            signature_version=UNSIGNED,
        )
        return self._session.create_client("sso-oidc", config=unsigned_config)

    def _attempt_create_token(self, token):
        """Call ``create_token`` with ``token``'s refresh data.

        :param token: Cached token dict holding ``clientId``,
            ``clientSecret``, ``refreshToken`` and ``registrationExpiresAt``.
        :return: A new token dict built from the service response.
        """
        created = self._client.create_token(
            grantType=self._GRANT_TYPE,
            clientId=token["clientId"],
            clientSecret=token["clientSecret"],
            refreshToken=token["refreshToken"],
        )
        lifetime = timedelta(seconds=created["expiresIn"])
        refreshed = {
            "startUrl": self._sso_config["sso_start_url"],
            "region": self._sso_config["sso_region"],
            "accessToken": created["accessToken"],
            "expiresAt": self._now() + lifetime,
            # Cache the registration alongside the token
            "clientId": token["clientId"],
            "clientSecret": token["clientSecret"],
            "registrationExpiresAt": token["registrationExpiresAt"],
        }
        # The refresh token is only carried over when the response has one.
        if "refreshToken" in created:
            refreshed["refreshToken"] = created["refreshToken"]
        logger.info("SSO Token refresh succeeded")
        return refreshed

    def _refresh_access_token(self, token):
        """Try to obtain a new token dict from the cached ``token``.

        :return: The new token dict, or ``None`` when required keys are
            missing, the client registration has expired, or the service
            call raised ``ClientError``.
        """
        required = (
            "refreshToken",
            "clientId",
            "clientSecret",
            "registrationExpiresAt",
        )
        missing_keys = [key for key in required if key not in token]
        if missing_keys:
            logger.info(
                f"Unable to refresh SSO token: missing keys: {missing_keys}"
            )
            return None

        expiry = dateutil.parser.parse(token["registrationExpiresAt"])
        if total_seconds(expiry - self._now()) <= 0:
            logger.info(f"SSO token registration expired at {expiry}")
            return None

        try:
            return self._attempt_create_token(token)
        except ClientError:
            logger.warning("SSO token refresh attempt failed", exc_info=True)
            return None

    def _refresher(self):
        """Load the cached token, refreshing it when near expiry.

        A refresh is attempted when less than ``_REFRESH_WINDOW`` seconds
        remain; a successful refresh is written back through the token
        loader. When the refresh fails the cached token is returned as is.

        :return: ``FrozenAuthToken`` with the access token and expiration.
        """
        sso_settings = self._sso_config
        start_url = sso_settings["sso_start_url"]
        session_name = sso_settings["session_name"]
        logger.info(f"Loading cached SSO token for {session_name}")
        current = self._token_loader(start_url, session_name=session_name)
        expiration = dateutil.parser.parse(current["expiresAt"])
        logger.debug(f"Cached SSO token expires at {expiration}")

        if total_seconds(expiration - self._now()) < self._REFRESH_WINDOW:
            replacement = self._refresh_access_token(current)
            if replacement is not None:
                current = replacement
                # Freshly created tokens already hold a datetime here.
                expiration = current["expiresAt"]
                self._token_loader.save_token(
                    start_url, current, session_name=session_name
                )

        return FrozenAuthToken(current["accessToken"], expiration=expiration)

    def load_token(self):
        """Return a deferred token, or ``None`` without SSO configuration."""
        if self._sso_config is None:
            return None

        return DeferredRefreshableToken(
            self.METHOD, self._refresher, time_fetcher=self._now
        )