# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""The interfaces in this module are not intended for public use.

This module defines interfaces for applying checksums to HTTP requests within
the context of botocore. This involves both resolving the checksum to be used
based on client configuration and environment, as well as application of the
checksum to the request.
"""

import base64
import io
import logging
from binascii import crc32
from hashlib import sha1, sha256

from botocore.compat import HAS_CRT
from botocore.exceptions import (
    AwsChunkedWrapperError,
    FlexibleChecksumError,
    MissingDependencyException,
)
from botocore.response import StreamingBody
from botocore.utils import (
    conditionally_calculate_md5,
    determine_content_length,
)

if HAS_CRT:
    from awscrt import checksums as crt_checksums
else:
    crt_checksums = None

logger = logging.getLogger(__name__)
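
# Illustrative request-side flow only (the real wiring lives in botocore's
# handlers and serializers, which are assumed here rather than shown). A
# serialized ``request`` dict with "url", "headers", "body", and "context"
# keys is first resolved, then the checksum is applied before sending:
#
#     resolve_checksum_context(request, operation_model, params)
#     apply_request_checksum(request)
#
# On the response side, handle_checksum_body wraps or validates the body
# using the algorithms recorded in the request context.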


class BaseChecksum:
    _CHUNK_SIZE = 1024 * 1024

    def update(self, chunk):
        pass

    def digest(self):
        pass

    def b64digest(self):
        bs = self.digest()
        return base64.b64encode(bs).decode("ascii")

    def _handle_fileobj(self, fileobj):
        start_position = fileobj.tell()
        for chunk in iter(lambda: fileobj.read(self._CHUNK_SIZE), b""):
            self.update(chunk)
        fileobj.seek(start_position)

    def handle(self, body):
        if isinstance(body, (bytes, bytearray)):
            self.update(body)
        else:
            self._handle_fileobj(body)
        return self.b64digest()

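# A minimal usage sketch (not part of botocore's public interface): any
# concrete checksum class is driven through BaseChecksum.handle(), which
# accepts either bytes or a seekable file-like object and returns the
# Base64-encoded digest. The file name below is hypothetical.
#
#     Sha256Checksum().handle(b"hello world")
#     with open("payload.bin", "rb") as f:
#         Sha256Checksum().handle(f)  # file position is restored afterwards
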

class Crc32Checksum(BaseChecksum):
    def __init__(self):
        self._int_crc32 = 0

    def update(self, chunk):
        self._int_crc32 = crc32(chunk, self._int_crc32) & 0xFFFFFFFF

    def digest(self):
        return self._int_crc32.to_bytes(4, byteorder="big")


class CrtCrc32Checksum(BaseChecksum):
    # Note: This class is only used if the CRT is available
    def __init__(self):
        self._int_crc32 = 0

    def update(self, chunk):
        new_checksum = crt_checksums.crc32(chunk, self._int_crc32)
        self._int_crc32 = new_checksum & 0xFFFFFFFF

    def digest(self):
        return self._int_crc32.to_bytes(4, byteorder="big")


class CrtCrc32cChecksum(BaseChecksum):
    # Note: This class is only used if the CRT is available
    def __init__(self):
        self._int_crc32c = 0

    def update(self, chunk):
        new_checksum = crt_checksums.crc32c(chunk, self._int_crc32c)
        self._int_crc32c = new_checksum & 0xFFFFFFFF

    def digest(self):
        return self._int_crc32c.to_bytes(4, byteorder="big")


class Sha1Checksum(BaseChecksum):
    def __init__(self):
        self._checksum = sha1()

    def update(self, chunk):
        self._checksum.update(chunk)

    def digest(self):
        return self._checksum.digest()


class Sha256Checksum(BaseChecksum):
    def __init__(self):
        self._checksum = sha256()

    def update(self, chunk):
        self._checksum.update(chunk)

    def digest(self):
        return self._checksum.digest()


class AwsChunkedWrapper:
    _DEFAULT_CHUNK_SIZE = 1024 * 1024

    def __init__(
        self,
        raw,
        checksum_cls=None,
        checksum_name="x-amz-checksum",
        chunk_size=None,
    ):
        self._raw = raw
        self._checksum_name = checksum_name
        self._checksum_cls = checksum_cls
        self._reset()

        if chunk_size is None:
            chunk_size = self._DEFAULT_CHUNK_SIZE
        self._chunk_size = chunk_size

    def _reset(self):
        self._remaining = b""
        self._complete = False
        self._checksum = None
        if self._checksum_cls:
            self._checksum = self._checksum_cls()

    def seek(self, offset, whence=0):
        if offset != 0 or whence != 0:
            raise AwsChunkedWrapperError(
                error_msg="Can only seek to start of stream"
            )
        self._reset()
        self._raw.seek(0)

    def read(self, size=None):
        # Normalize "read all" size values to None
        if size is not None and size <= 0:
            size = None

        # If the underlying body is done and we have nothing left then
        # end the stream
        if self._complete and not self._remaining:
            return b""

        # While we're not done and want more bytes
        want_more_bytes = size is None or size > len(self._remaining)
        while not self._complete and want_more_bytes:
            self._remaining += self._make_chunk()
            want_more_bytes = size is None or size > len(self._remaining)

        # If size was None, we want to return everything
        if size is None:
            size = len(self._remaining)

        # Return a chunk up to the size asked for
        to_return = self._remaining[:size]
        self._remaining = self._remaining[size:]
        return to_return

    def _make_chunk(self):
        # NOTE: Chunk size is not deterministic as read could return less. This
        # means we cannot know the content length of the encoded aws-chunked
        # stream ahead of time without ensuring a consistent chunk size
        raw_chunk = self._raw.read(self._chunk_size)
        hex_len = hex(len(raw_chunk))[2:].encode("ascii")
        self._complete = not raw_chunk

        if self._checksum:
            self._checksum.update(raw_chunk)

        if self._checksum and self._complete:
            name = self._checksum_name.encode("ascii")
            checksum = self._checksum.b64digest().encode("ascii")
            return b"0\r\n%s:%s\r\n\r\n" % (name, checksum)

        return b"%s\r\n%s\r\n" % (hex_len, raw_chunk)

    def __iter__(self):
        while not self._complete:
            yield self._make_chunk()

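# Illustrative sketch of the aws-chunked framing produced by the
# AwsChunkedWrapper above (chunk sizes and the digest value are made up for
# the example; the trailer name comes from ``checksum_name``):
#
#     100000\r\n<first 1 MB of the raw body>\r\n
#     ...                                          more data chunks
#     0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n   final chunk with trailer
#
# Because read() may return fewer bytes than requested, the encoded length
# cannot be computed up front; see the NOTE in _make_chunk.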

class StreamingChecksumBody(StreamingBody):
    def __init__(self, raw_stream, content_length, checksum, expected):
        super().__init__(raw_stream, content_length)
        self._checksum = checksum
        self._expected = expected

    def read(self, amt=None):
        chunk = super().read(amt=amt)
        self._checksum.update(chunk)
        if amt is None or (not chunk and amt > 0):
            self._validate_checksum()
        return chunk

    def _validate_checksum(self):
        if self._checksum.digest() != base64.b64decode(self._expected):
            error_msg = (
                f"Expected checksum {self._expected} did not match calculated "
                f"checksum: {self._checksum.b64digest()}"
            )
            raise FlexibleChecksumError(error_msg=error_msg)


def resolve_checksum_context(request, operation_model, params):
    resolve_request_checksum_algorithm(request, operation_model, params)
    resolve_response_checksum_algorithms(request, operation_model, params)


def resolve_request_checksum_algorithm(
    request,
    operation_model,
    params,
    supported_algorithms=None,
):
    http_checksum = operation_model.http_checksum
    algorithm_member = http_checksum.get("requestAlgorithmMember")
    if algorithm_member and algorithm_member in params:
        # If the client has opted into using flexible checksums and the
        # request supports it, use that instead of checksum required
        if supported_algorithms is None:
            supported_algorithms = _SUPPORTED_CHECKSUM_ALGORITHMS

        algorithm_name = params[algorithm_member].lower()
        if algorithm_name not in supported_algorithms:
            if not HAS_CRT and algorithm_name in _CRT_CHECKSUM_ALGORITHMS:
                raise MissingDependencyException(
                    msg=(
                        f"Using {algorithm_name.upper()} requires an "
                        "additional dependency. You will need to pip install "
                        "botocore[crt] before proceeding."
                    )
                )
            raise FlexibleChecksumError(
                error_msg=f"Unsupported checksum algorithm: {algorithm_name}"
            )

        location_type = "header"
        if operation_model.has_streaming_input:
            # Operations with streaming input must support trailers.
            if request["url"].startswith("https:"):
                # We only support unsigned trailer checksums currently. As this
                # disables payload signing we'll only use trailers over TLS.
                location_type = "trailer"

        algorithm = {
            "algorithm": algorithm_name,
            "in": location_type,
            "name": f"x-amz-checksum-{algorithm_name}",
        }

        if algorithm["name"] in request["headers"]:
            # If the header is already set by the customer, skip calculation
            return

        checksum_context = request["context"].get("checksum", {})
        checksum_context["request_algorithm"] = algorithm
        request["context"]["checksum"] = checksum_context
    elif operation_model.http_checksum_required or http_checksum.get(
        "requestChecksumRequired"
    ):
        # Otherwise apply the old http checksum behavior via Content-MD5
        checksum_context = request["context"].get("checksum", {})
        checksum_context["request_algorithm"] = "conditional-md5"
        request["context"]["checksum"] = checksum_context

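# For reference, an illustrative shape (values are hypothetical) of what
# resolve_request_checksum_algorithm records when the customer selects CRC32
# on a non-streaming operation:
#
#     request["context"]["checksum"] = {
#         "request_algorithm": {
#             "algorithm": "crc32",
#             "in": "header",
#             "name": "x-amz-checksum-crc32",
#         }
#     }
#
# The string "conditional-md5" is stored instead when only the legacy
# Content-MD5 requirement applies.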

def apply_request_checksum(request):
    checksum_context = request.get("context", {}).get("checksum", {})
    algorithm = checksum_context.get("request_algorithm")

    if not algorithm:
        return

    if algorithm == "conditional-md5":
        # Special case to handle the http checksum required trait
        conditionally_calculate_md5(request)
    elif algorithm["in"] == "header":
        _apply_request_header_checksum(request)
    elif algorithm["in"] == "trailer":
        _apply_request_trailer_checksum(request)
    else:
        raise FlexibleChecksumError(
            error_msg=f"Unknown checksum variant: {algorithm['in']}"
        )


def _apply_request_header_checksum(request):
    checksum_context = request.get("context", {}).get("checksum", {})
    algorithm = checksum_context.get("request_algorithm")
    location_name = algorithm["name"]
    if location_name in request["headers"]:
        # If the header is already set by the customer, skip calculation
        return
    checksum_cls = _CHECKSUM_CLS.get(algorithm["algorithm"])
    digest = checksum_cls().handle(request["body"])
    request["headers"][location_name] = digest


def _apply_request_trailer_checksum(request):
    checksum_context = request.get("context", {}).get("checksum", {})
    algorithm = checksum_context.get("request_algorithm")
    location_name = algorithm["name"]
    checksum_cls = _CHECKSUM_CLS.get(algorithm["algorithm"])

    headers = request["headers"]
    body = request["body"]

    if location_name in headers:
        # If the header is already set by the customer, skip calculation
        return

    headers["Transfer-Encoding"] = "chunked"
    if "Content-Encoding" in headers:
        # We need to preserve the existing content encoding and add
        # aws-chunked as a new content encoding.
        headers["Content-Encoding"] += ",aws-chunked"
    else:
        headers["Content-Encoding"] = "aws-chunked"
    headers["X-Amz-Trailer"] = location_name

    content_length = determine_content_length(body)
    if content_length is not None:
        # Send the decoded content length if we can determine it. Some
        # services such as S3 may require the decoded content length
        headers["X-Amz-Decoded-Content-Length"] = str(content_length)

    if isinstance(body, (bytes, bytearray)):
        body = io.BytesIO(body)

    request["body"] = AwsChunkedWrapper(
        body,
        checksum_cls=checksum_cls,
        checksum_name=location_name,
    )

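# A hedged illustration of the headers set by _apply_request_trailer_checksum
# (the algorithm and length shown are examples; the digest itself only
# appears in the aws-chunked trailer, not in these headers):
#
#     Transfer-Encoding: chunked
#     Content-Encoding: aws-chunked
#     X-Amz-Trailer: x-amz-checksum-crc32
#     X-Amz-Decoded-Content-Length: 1048576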

def resolve_response_checksum_algorithms(
    request, operation_model, params, supported_algorithms=None
):
    http_checksum = operation_model.http_checksum
    mode_member = http_checksum.get("requestValidationModeMember")
    if mode_member and mode_member in params:
        if supported_algorithms is None:
            supported_algorithms = _SUPPORTED_CHECKSUM_ALGORITHMS
        response_algorithms = {
            a.lower() for a in http_checksum.get("responseAlgorithms", [])
        }

        usable_algorithms = []
        for algorithm in _ALGORITHMS_PRIORITY_LIST:
            if algorithm not in response_algorithms:
                continue
            if algorithm in supported_algorithms:
                usable_algorithms.append(algorithm)

        checksum_context = request["context"].get("checksum", {})
        checksum_context["response_algorithms"] = usable_algorithms
        request["context"]["checksum"] = checksum_context

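# Illustrative outcome (the exact list depends on the service model and on
# whether the CRT is installed): with every modeled algorithm available, the
# context would hold the priority-ordered list
#
#     request["context"]["checksum"]["response_algorithms"] = [
#         "crc32c", "crc32", "sha1", "sha256"
#     ]
#
# handle_checksum_body below uses the first algorithm in this list for which
# the response actually carries an x-amz-checksum-* header.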

def handle_checksum_body(http_response, response, context, operation_model):
    headers = response["headers"]
    checksum_context = context.get("checksum", {})
    algorithms = checksum_context.get("response_algorithms")

    if not algorithms:
        return

    for algorithm in algorithms:
        header_name = f"x-amz-checksum-{algorithm}"
        # If the header is not found, check the next algorithm
        if header_name not in headers:
            continue

        # If a - is in the checksum this is not valid Base64. S3 returns
        # checksums that include a -# suffix to indicate a checksum derived
        # from the hash of all part checksums. We cannot wrap this response
        if "-" in headers[header_name]:
            continue

        if operation_model.has_streaming_output:
            response["body"] = _handle_streaming_response(
                http_response, response, algorithm
            )
        else:
            response["body"] = _handle_bytes_response(
                http_response, response, algorithm
            )

        # Expose metadata that the checksum check actually occurred
        checksum_context = response["context"].get("checksum", {})
        checksum_context["response_algorithm"] = algorithm
        response["context"]["checksum"] = checksum_context
        return

    logger.info(
        f"Skipping checksum validation. Response did not contain one of the "
        f"following algorithms: {algorithms}."
    )


def _handle_streaming_response(http_response, response, algorithm):
    checksum_cls = _CHECKSUM_CLS.get(algorithm)
    header_name = f"x-amz-checksum-{algorithm}"
    return StreamingChecksumBody(
        http_response.raw,
        response["headers"].get("content-length"),
        checksum_cls(),
        response["headers"][header_name],
    )


def _handle_bytes_response(http_response, response, algorithm):
    body = http_response.content
    header_name = f"x-amz-checksum-{algorithm}"
    checksum_cls = _CHECKSUM_CLS.get(algorithm)
    checksum = checksum_cls()
    checksum.update(body)
    expected = response["headers"][header_name]
    if checksum.digest() != base64.b64decode(expected):
        error_msg = (
            f"Expected checksum {expected} did not match calculated "
            f"checksum: {checksum.b64digest()}"
        )
        raise FlexibleChecksumError(error_msg=error_msg)
    return body


_CHECKSUM_CLS = {
    "crc32": Crc32Checksum,
    "sha1": Sha1Checksum,
    "sha256": Sha256Checksum,
}
_CRT_CHECKSUM_ALGORITHMS = ["crc32", "crc32c"]
if HAS_CRT:
    # Use CRT checksum implementations if available
    _CRT_CHECKSUM_CLS = {
        "crc32": CrtCrc32Checksum,
        "crc32c": CrtCrc32cChecksum,
    }
    _CHECKSUM_CLS.update(_CRT_CHECKSUM_CLS)
    # Validate this list isn't out of sync with _CRT_CHECKSUM_CLS keys
    assert all(
        name in _CRT_CHECKSUM_ALGORITHMS for name in _CRT_CHECKSUM_CLS.keys()
    )
_SUPPORTED_CHECKSUM_ALGORITHMS = list(_CHECKSUM_CLS.keys())
_ALGORITHMS_PRIORITY_LIST = ["crc32c", "crc32", "sha1", "sha256"]