1# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"). You
4# may not use this file except in compliance with the License. A copy of
5# the License is located at
6#
7# http://aws.amazon.com/apache2.0/
8#
9# or in the "license" file accompanying this file. This file is
10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11# ANY KIND, either express or implied. See the License for the specific
12# language governing permissions and limitations under the License.
13import logging
14import time
15import weakref
16
17from botocore import xform_name
18from botocore.exceptions import BotoCoreError, ConnectionError, HTTPClientError
19from botocore.model import OperationNotFoundError
20from botocore.utils import CachedProperty
21
22logger = logging.getLogger(__name__)
23
24
class EndpointDiscoveryException(BotoCoreError):
    """Common base class for the endpoint discovery errors in this module."""
27
28
class EndpointDiscoveryRequired(EndpointDiscoveryException):
    """Raised when an operation needs endpoint discovery but it is disabled."""

    fmt = 'Endpoint Discovery is not enabled but this operation requires it.'
33
34
class EndpointDiscoveryRefreshFailed(EndpointDiscoveryException):
    """Raised when the known endpoints could not be refreshed."""

    fmt = 'Endpoint Discovery failed to refresh the required endpoints.'
39
40
def block_endpoint_discovery_required_operations(model, **kwargs):
    """Reject operations that cannot work without endpoint discovery.

    :param model: Operation model; only its ``endpoint_discovery``
        attribute is read.
    :param kwargs: Ignored; accepted so extra keyword arguments can be
        passed by the caller.
    :raises EndpointDiscoveryRequired: If the operation's endpoint
        discovery settings mark discovery as required.
    """
    discovery_trait = model.endpoint_discovery
    if not discovery_trait:
        # No discovery settings on this operation, nothing to block.
        return
    if discovery_trait.get('required'):
        raise EndpointDiscoveryRequired()
45
46
class EndpointDiscoveryModel:
    """Answer endpoint discovery questions about a service model.

    Wraps the service model given to the constructor and exposes the name
    and input keys of its endpoint discovery operation, whether a given
    operation requires discovery, and the identifier parameters of a call.
    """

    def __init__(self, service_model):
        """
        :param service_model: Service model to inspect. Its
            ``endpoint_discovery_operation`` attribute and
            ``operation_model`` method are used.
        """
        self._service_model = service_model

    @CachedProperty
    def discovery_operation_name(self):
        """Name of the discovery operation, passed through ``xform_name``."""
        discovery_operation = self._service_model.endpoint_discovery_operation
        return xform_name(discovery_operation.name)

    @CachedProperty
    def discovery_operation_keys(self):
        """Member names of the discovery operation's input shape.

        An empty list is returned when the operation has no input shape.
        """
        discovery_operation = self._service_model.endpoint_discovery_operation
        keys = []
        if discovery_operation.input_shape:
            keys = list(discovery_operation.input_shape.members.keys())
        return keys

    def discovery_required_for(self, operation_name):
        """Return whether ``operation_name`` requires endpoint discovery.

        Unknown operations, and operations that carry no endpoint
        discovery settings at all, are reported as not requiring it.

        :param operation_name: Name looked up via the service model.
        :returns: The ``required`` value of the operation's endpoint
            discovery settings, or ``False`` when unavailable.
        """
        try:
            operation_model = self._service_model.operation_model(
                operation_name
            )
            # ``endpoint_discovery`` may be None for operations that do not
            # take part in discovery; treat that like empty settings
            # instead of failing on ``None.get``.
            endpoint_discovery = operation_model.endpoint_discovery or {}
            return endpoint_discovery.get('required', False)
        except OperationNotFoundError:
            return False

    def discovery_operation_kwargs(self, **kwargs):
        """Filter ``kwargs`` down to what the discovery operation accepts.

        :returns: A dict holding only the entries whose key is one of the
            discovery operation's input members.
        """
        input_keys = self.discovery_operation_keys
        # Operation and Identifiers are only sent if there are Identifiers
        if not kwargs.get('Identifiers'):
            kwargs.pop('Operation', None)
            kwargs.pop('Identifiers', None)
        return {k: v for k, v in kwargs.items() if k in input_keys}

    def gather_identifiers(self, operation, params):
        """Collect the endpoint discovery identifiers supplied in ``params``.

        :param operation: Operation model whose ``input_shape`` is walked.
        :param params: Parameters given for the operation.
        :returns: Dict mapping identifier member names to their values;
            empty when the operation has no input shape.
        """
        input_shape = operation.input_shape
        if input_shape is None:
            # Nothing to traverse, hence no identifiers.
            return {}
        return self._gather_ids(input_shape, params)

    def _gather_ids(self, shape, params, ids=None):
        """Recursively fill ``ids`` with identifier members found in ``shape``.

        :param shape: Structure shape whose members are inspected.
        :param params: Parameter values corresponding to ``shape``.
        :param ids: Accumulator shared across recursive calls; a new dict
            is created when omitted.
        :returns: The accumulator dict.
        """
        # Traverse the input shape and corresponding parameters, gathering
        # any input fields labeled as an endpoint discovery id
        if ids is None:
            ids = {}
        for member_name, member_shape in shape.members.items():
            if member_shape.metadata.get('endpointdiscoveryid'):
                # An identifier the caller did not supply is skipped rather
                # than raising KeyError here.
                if member_name in params:
                    ids[member_name] = params[member_name]
            elif (
                member_shape.type_name == 'structure' and member_name in params
            ):
                self._gather_ids(member_shape, params[member_name], ids)
        return ids
97
98
class EndpointDiscoveryManager:
    """Discover, cache and hand out endpoints for a client's operations.

    Endpoints returned by the discovery operation are stored in a cache
    keyed by the discovery arguments, each with an absolute ``Expiration``
    time. A failed refresh is remembered for 60 seconds, during which no
    new refresh is attempted unless discovery is required.
    """

    def __init__(
        self, client, cache=None, current_time=None, always_discover=True
    ):
        """
        :param client: Client used to call the discovery operation. Only a
            weak proxy to it is kept.
        :param cache: Optional mapping used to store discovered endpoints;
            a new dict is used when omitted.
        :param current_time: Optional zero-argument callable returning the
            current time in seconds; defaults to ``time.time``.
        :param always_discover: When False, discovery only runs for
            operations that require it.
        """
        if cache is None:
            cache = {}
        self._cache = cache
        # cache key -> time until which refreshing is considered
        # recently failed
        self._failed_attempts = {}
        if current_time is None:
            current_time = time.time
        self._time = current_time
        self._always_discover = always_discover

        # Hold the client through a weak proxy so that this manager does
        # not keep the client alive (avoids memory leaks).
        self._client = weakref.proxy(client)
        self._model = EndpointDiscoveryModel(client.meta.service_model)

    def _parse_endpoints(self, response):
        """Stamp each endpoint in ``response`` with its expiration time.

        The endpoint dicts are modified in place: ``Expiration`` is set to
        the current time plus ``CachePeriodInMinutes`` converted to
        seconds. A missing or ``None`` cache period counts as zero, so
        such an endpoint is already expired.

        :returns: The list stored under ``response['Endpoints']``.
        """
        endpoints = response['Endpoints']
        current_time = self._time()
        for endpoint in endpoints:
            cache_minutes = endpoint.get('CachePeriodInMinutes') or 0
            endpoint['Expiration'] = current_time + cache_minutes * 60
        return endpoints

    def _cache_item(self, value):
        """Convert a dict into a sorted tuple of its items for cache keys.

        Any other value is returned unchanged.
        """
        if isinstance(value, dict):
            return tuple(sorted(value.items()))
        else:
            return value

    def _create_cache_key(self, **kwargs):
        """Build the cache key for the given discovery arguments.

        Only arguments accepted by the discovery operation contribute;
        their values are ordered by argument name.
        """
        kwargs = self._model.discovery_operation_kwargs(**kwargs)
        return tuple(self._cache_item(v) for k, v in sorted(kwargs.items()))

    def gather_identifiers(self, operation, params):
        """Return the discovery identifiers found in ``params``."""
        return self._model.gather_identifiers(operation, params)

    def delete_endpoints(self, **kwargs):
        """Drop the cached endpoints for the given discovery arguments."""
        cache_key = self._create_cache_key(**kwargs)
        if cache_key in self._cache:
            del self._cache[cache_key]

    def _describe_endpoints(self, **kwargs):
        """Call the client's discovery operation and return its response."""
        # This is effectively a proxy to whatever name/kwargs the service
        # supports for endpoint discovery.
        kwargs = self._model.discovery_operation_kwargs(**kwargs)
        operation_name = self._model.discovery_operation_name
        discovery_operation = getattr(self._client, operation_name)
        logger.debug('Discovering endpoints with kwargs: %s', kwargs)
        return discovery_operation(**kwargs)

    def _get_current_endpoints(self, key):
        """Return the unexpired cached endpoints for ``key``.

        :returns: ``None`` when ``key`` is not cached, otherwise the
            (possibly empty) list of endpoints that have not expired.
        """
        if key not in self._cache:
            return None
        now = self._time()
        return [e for e in self._cache[key] if now < e['Expiration']]

    def _refresh_current_endpoints(self, **kwargs):
        """Fetch fresh endpoints and store them in the cache.

        :returns: The parsed endpoints, or ``None`` when the discovery
            call failed with ``ConnectionError`` or ``HTTPClientError``.
            In that case the failure is recorded for the next 60 seconds.
        """
        cache_key = self._create_cache_key(**kwargs)
        try:
            response = self._describe_endpoints(**kwargs)
            endpoints = self._parse_endpoints(response)
            self._cache[cache_key] = endpoints
            self._failed_attempts.pop(cache_key, None)
            return endpoints
        except (ConnectionError, HTTPClientError):
            self._failed_attempts[cache_key] = self._time() + 60
            return None

    def _recently_failed(self, cache_key):
        """Return True while a recorded failure for ``cache_key`` is active.

        A record whose time has passed is removed.
        """
        if cache_key in self._failed_attempts:
            now = self._time()
            if now < self._failed_attempts[cache_key]:
                return True
            del self._failed_attempts[cache_key]
        return False

    def _select_endpoint(self, endpoints):
        """Return the ``Address`` of the first endpoint in ``endpoints``."""
        return endpoints[0]['Address']

    def describe_endpoint(self, **kwargs):
        """Return an endpoint address for the given operation.

        Unexpired cached endpoints are preferred, then a refresh (unless
        one failed recently), then stale cached endpoints.

        :param kwargs: Discovery arguments; ``Operation`` is mandatory.
        :returns: The selected endpoint address, or ``None`` when
            discovery is skipped or is optional and yielded nothing.
        :raises EndpointDiscoveryRefreshFailed: If discovery is required
            for the operation and no endpoint could be obtained.
        """
        operation = kwargs['Operation']
        discovery_required = self._model.discovery_required_for(operation)

        if not self._always_discover and not discovery_required:
            # Discovery set to only run on required operations
            logger.debug(
                'Optional discovery disabled. '
                'Skipping discovery for Operation: %s',
                operation,
            )
            return None

        # Get the endpoint for the provided operation and identifiers
        cache_key = self._create_cache_key(**kwargs)
        endpoints = self._get_current_endpoints(cache_key)
        if endpoints:
            return self._select_endpoint(endpoints)
        # All known endpoints are stale
        recently_failed = self._recently_failed(cache_key)
        if not recently_failed:
            # We haven't failed to discover recently, go ahead and refresh
            endpoints = self._refresh_current_endpoints(**kwargs)
            if endpoints:
                return self._select_endpoint(endpoints)
        # Discovery has failed recently, do our best to get an endpoint
        logger.debug('Endpoint Discovery has failed for: %s', kwargs)
        stale_entries = self._cache.get(cache_key, None)
        if stale_entries:
            # We have stale entries, use those while discovery is failing
            return self._select_endpoint(stale_entries)
        if discovery_required:
            # It looks strange to be checking recently_failed again but,
            # this informs us as to whether or not we tried to refresh earlier
            if recently_failed:
                # Discovery is required and we haven't already refreshed
                endpoints = self._refresh_current_endpoints(**kwargs)
                if endpoints:
                    return self._select_endpoint(endpoints)
            # No endpoints even after a refresh, raise hard error
            raise EndpointDiscoveryRefreshFailed()
        # Discovery is optional, just use the default endpoint for now
        return None
222
223
class EndpointDiscoveryHandler:
    """Connect an endpoint discovery manager to a client's events."""

    def __init__(self, manager):
        """
        :param manager: Object providing ``gather_identifiers``,
            ``describe_endpoint`` and ``delete_endpoints``.
        """
        self._manager = manager

    def register(self, events, service_id):
        """Subscribe this handler's callbacks for ``service_id``.

        ``discover_endpoint`` is registered with ``register_first``; the
        other two callbacks use plain ``register``.
        """
        before_build = f'before-parameter-build.{service_id}'
        events.register(before_build, self.gather_identifiers)

        request_created = f'request-created.{service_id}'
        events.register_first(request_created, self.discover_endpoint)

        needs_retry = f'needs-retry.{service_id}'
        events.register(needs_retry, self.handle_retries)

    def gather_identifiers(self, params, model, context, **kwargs):
        """Store the operation's discovery identifiers in ``context``.

        Nothing is stored when the operation model has no endpoint
        discovery settings.
        """
        # Operations that do not support endpoint discovery are left alone
        if model.endpoint_discovery is None:
            return
        context['discovery'] = {
            'identifiers': self._manager.gather_identifiers(model, params)
        }

    def discover_endpoint(self, request, operation_name, **kwargs):
        """Point ``request`` at a discovered endpoint when one is available.

        The request is left untouched if no identifiers were gathered for
        it or if the manager returns no endpoint.
        """
        discovery_context = request.context.get('discovery', {})
        identifiers = discovery_context.get('identifiers')
        if identifiers is None:
            return

        address = self._manager.describe_endpoint(
            Operation=operation_name, Identifiers=identifiers
        )
        if address is None:
            logger.debug('Failed to discover and inject endpoint')
            return

        # Addresses that do not already start with 'http' get an https scheme
        url = address if address.startswith('http') else 'https://' + address
        logger.debug('Injecting discovered endpoint: %s', url)
        request.url = url

    def handle_retries(self, request_dict, response, operation, **kwargs):
        """Evict cached endpoints after an invalid-endpoint response.

        :param response: ``None`` or a two-item sequence whose second item
            is the parsed response dict.
        :returns: ``0`` once the cached endpoints were deleted, otherwise
            ``None``. (Presumably 0 is read by the caller as a retry
            delay; confirm against the retry event contract.)
        """
        if response is None:
            return None

        _http_response, parsed = response
        http_status = parsed.get('ResponseMetadata', {}).get('HTTPStatusCode')
        code = parsed.get('Error', {}).get('Code')
        endpoint_rejected = (
            http_status == 421 or code == 'InvalidEndpointException'
        )
        if not endpoint_rejected:
            return None

        discovery_context = request_dict.get('context', {}).get('discovery', {})
        identifiers = discovery_context.get('identifiers')
        if identifiers is None:
            return None

        # Delete the cached endpoints, forcing a refresh on retry
        # TODO: Improve eviction behavior to only evict the bad endpoint if
        # there are multiple. This will almost certainly require a lock.
        self._manager.delete_endpoints(
            Operation=operation.name, Identifiers=identifiers
        )
        return 0