1# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"). You
4# may not use this file except in compliance with the License. A copy of
5# the License is located at
6#
7# http://aws.amazon.com/apache2.0/
8#
9# or in the "license" file accompanying this file. This file is
10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11# ANY KIND, either express or implied. See the License for the specific
12# language governing permissions and limitations under the License.
13
14"""The interfaces in this module are not intended for public use.
15
16This module defines interfaces for applying checksums to HTTP requests within
17the context of botocore. This involves both resolving the checksum to be used
18based on client configuration and environment, as well as application of the
19checksum to the request.
20"""
21
22import base64
23import io
24import logging
25from binascii import crc32
26from hashlib import sha1, sha256
27
28from botocore.compat import HAS_CRT, has_minimum_crt_version, urlparse
29from botocore.exceptions import (
30 AwsChunkedWrapperError,
31 FlexibleChecksumError,
32 MissingDependencyException,
33)
34from botocore.model import StructureShape
35from botocore.response import StreamingBody
36from botocore.useragent import register_feature_id
37from botocore.utils import (
38 conditionally_calculate_md5,
39 determine_content_length,
40 has_checksum_header,
41)
42
43if HAS_CRT:
44 from awscrt import checksums as crt_checksums
45else:
46 crt_checksums = None
47
# Module-level logger for debug output produced while handling checksums.
logger = logging.getLogger(__name__)

# Request checksum algorithm used when a checksum is required (or
# request_checksum_calculation is "when_supported") but the caller did not
# select an algorithm explicitly.
DEFAULT_CHECKSUM_ALGORITHM = "CRC32"
51
52
class BaseChecksum:
    """Common interface shared by the checksum implementations below.

    Subclasses override ``update`` and ``digest``. This base class adds
    base64 encoding of the digest and a ``handle`` helper that checksums a
    whole request body (bytes-like or seekable file object).
    """

    _CHUNK_SIZE = 1024 * 1024

    def update(self, chunk):
        """Feed ``chunk`` into the running checksum (no-op in the base)."""

    def digest(self):
        """Return the raw checksum bytes (``None`` in the base)."""

    def b64digest(self):
        """Return ``digest()`` encoded as an ASCII base64 string."""
        return base64.b64encode(self.digest()).decode("ascii")

    def _handle_fileobj(self, fileobj):
        # Checksum from the current position to EOF, then rewind so the
        # body can still be sent from where it started.
        original_position = fileobj.tell()
        while True:
            block = fileobj.read(self._CHUNK_SIZE)
            if block == b"":
                break
            self.update(block)
        fileobj.seek(original_position)

    def handle(self, body):
        """Checksum an entire body and return the base64 digest.

        ``bytes``/``bytearray`` bodies are hashed directly; anything else is
        treated as a file object that supports ``tell``/``read``/``seek``.
        """
        if not isinstance(body, (bytes, bytearray)):
            self._handle_fileobj(body)
        else:
            self.update(body)
        return self.b64digest()
78
79
class Crc32Checksum(BaseChecksum):
    """CRC32 checksum computed with the standard library's ``crc32``."""

    def __init__(self):
        self._int_crc32 = 0

    def update(self, chunk):
        # Carry the running value forward and keep it as an unsigned 32-bit
        # integer.
        running = crc32(chunk, self._int_crc32)
        self._int_crc32 = running & 0xFFFFFFFF

    def digest(self):
        """Return the checksum as 4 big-endian bytes."""
        return self._int_crc32.to_bytes(4, "big")
89
90
class CrtCrc32Checksum(BaseChecksum):
    """CRC32 checksum computed with the CRT's ``crc32``.

    Only registered for use when the CRT is available.
    """

    def __init__(self):
        self._int_crc32 = 0

    def update(self, chunk):
        # Continue from the previous value; mask to an unsigned 32-bit int.
        self._int_crc32 = (
            crt_checksums.crc32(chunk, self._int_crc32) & 0xFFFFFFFF
        )

    def digest(self):
        """Return the checksum as 4 big-endian bytes."""
        return self._int_crc32.to_bytes(4, "big")
102
103
class CrtCrc32cChecksum(BaseChecksum):
    """CRC32C checksum computed with the CRT's ``crc32c``.

    Only registered for use when the CRT is available.
    """

    def __init__(self):
        self._int_crc32c = 0

    def update(self, chunk):
        # Continue from the previous value; mask to an unsigned 32-bit int.
        self._int_crc32c = (
            crt_checksums.crc32c(chunk, self._int_crc32c) & 0xFFFFFFFF
        )

    def digest(self):
        """Return the checksum as 4 big-endian bytes."""
        return self._int_crc32c.to_bytes(4, "big")
115
116
class CrtCrc64NvmeChecksum(BaseChecksum):
    """CRC64NVME checksum computed with the CRT's ``crc64nvme``.

    Only registered for use when a sufficiently new CRT is available.
    """

    def __init__(self):
        self._int_crc64nvme = 0

    def update(self, chunk):
        # Continue from the previous value; mask to an unsigned 64-bit int.
        self._int_crc64nvme = (
            crt_checksums.crc64nvme(chunk, self._int_crc64nvme)
            & 0xFFFFFFFFFFFFFFFF
        )

    def digest(self):
        """Return the checksum as 8 big-endian bytes."""
        return self._int_crc64nvme.to_bytes(8, "big")
128
129
class Sha1Checksum(BaseChecksum):
    """SHA-1 checksum backed by ``hashlib.sha1``."""

    def __init__(self):
        self._checksum = sha1()

    def update(self, chunk):
        """Hash ``chunk`` into the running digest."""
        self._checksum.update(chunk)

    def digest(self):
        """Return the raw 20-byte SHA-1 digest."""
        return self._checksum.digest()
139
140
class Sha256Checksum(BaseChecksum):
    """SHA-256 checksum backed by ``hashlib.sha256``."""

    def __init__(self):
        self._checksum = sha256()

    def update(self, chunk):
        """Hash ``chunk`` into the running digest."""
        self._checksum.update(chunk)

    def digest(self):
        """Return the raw 32-byte SHA-256 digest."""
        return self._checksum.digest()
150
151
class AwsChunkedWrapper:
    """File-like wrapper that frames a raw stream as aws-chunked data.

    Every chunk read from ``raw`` (up to ``chunk_size`` bytes) is emitted as
    ``<hex length>\\r\\n<data>\\r\\n``. Once ``raw`` is exhausted a
    zero-length terminating chunk is emitted. If ``checksum_cls`` was given,
    that terminating chunk carries a ``<checksum_name>:<base64 digest>``
    trailer computed over all raw bytes.
    """

    _DEFAULT_CHUNK_SIZE = 1024 * 1024

    def __init__(
        self,
        raw,
        checksum_cls=None,
        checksum_name="x-amz-checksum",
        chunk_size=None,
    ):
        self._raw = raw
        self._checksum_name = checksum_name
        self._checksum_cls = checksum_cls
        self._reset()
        self._chunk_size = (
            self._DEFAULT_CHUNK_SIZE if chunk_size is None else chunk_size
        )

    def _reset(self):
        """Drop buffered output and start a fresh checksum, if any."""
        self._remaining = b""
        self._complete = False
        self._checksum = self._checksum_cls() if self._checksum_cls else None

    def seek(self, offset, whence=0):
        """Rewind to the start; any other seek target is rejected."""
        if not (offset == 0 and whence == 0):
            raise AwsChunkedWrapperError(
                error_msg="Can only seek to start of stream"
            )
        self._reset()
        self._raw.seek(0)

    def read(self, size=None):
        """Return up to ``size`` encoded bytes (all of them if ``size`` is
        ``None`` or not positive); ``b""`` once everything was returned."""
        if size is not None and size <= 0:
            size = None

        if self._complete and not self._remaining:
            return b""

        def _need_more():
            return size is None or size > len(self._remaining)

        # Encode more chunks until the request can be satisfied or the
        # underlying stream is exhausted.
        while not self._complete and _need_more():
            self._remaining += self._make_chunk()

        if size is None:
            buffered, self._remaining = self._remaining, b""
            return buffered
        buffered = self._remaining[:size]
        self._remaining = self._remaining[size:]
        return buffered

    def _make_chunk(self):
        # NOTE: Chunk size is not deterministic as read could return less. This
        # means we cannot know the content length of the encoded aws-chunked
        # stream ahead of time without ensuring a consistent chunk size
        raw_chunk = self._raw.read(self._chunk_size)
        size_prefix = b"%x" % len(raw_chunk)
        self._complete = not raw_chunk

        if self._checksum:
            self._checksum.update(raw_chunk)
            if self._complete:
                # Terminating chunk with the checksum trailer attached.
                trailer_name = self._checksum_name.encode("ascii")
                trailer_value = self._checksum.b64digest().encode("ascii")
                return b"0\r\n%s:%s\r\n\r\n" % (trailer_name, trailer_value)

        return b"%s\r\n%s\r\n" % (size_prefix, raw_chunk)

    def __iter__(self):
        while not self._complete:
            yield self._make_chunk()
232
233
class StreamingChecksumBody(StreamingBody):
    """StreamingBody that checksums data as it is read.

    The accumulated checksum is compared with ``expected`` (a base64 string)
    when a read indicates the end of the stream. A mismatch raises
    FlexibleChecksumError.
    """

    def __init__(self, raw_stream, content_length, checksum, expected):
        super().__init__(raw_stream, content_length)
        self._checksum = checksum
        self._expected = expected

    def read(self, amt=None):
        data = super().read(amt=amt)
        self._checksum.update(data)
        # A full read, or an empty result for a positive amount, means the
        # stream has been consumed and the checksum can be checked.
        reached_end = amt is None or (amt > 0 and not data)
        if reached_end:
            self._validate_checksum()
        return data

    def readinto(self, b):
        count = super().readinto(b)
        # Only the filled prefix of the buffer contributes to the checksum.
        self._checksum.update(b if count == len(b) else memoryview(b)[:count])
        if count == 0 and len(b) > 0:
            self._validate_checksum()
        return count

    def _validate_checksum(self):
        if self._checksum.digest() == base64.b64decode(self._expected):
            return
        error_msg = (
            f"Expected checksum {self._expected} did not match calculated "
            f"checksum: {self._checksum.b64digest()}"
        )
        raise FlexibleChecksumError(error_msg=error_msg)
265
266
def resolve_checksum_context(request, operation_model, params):
    """Resolve request-side, then response-side, checksum settings.

    Both resolvers receive the same arguments and may record their results
    in ``request["context"]["checksum"]``.
    """
    for resolve in (
        resolve_request_checksum_algorithm,
        resolve_response_checksum_algorithms,
    ):
        resolve(request, operation_model, params)
270
271
def resolve_request_checksum_algorithm(
    request,
    operation_model,
    params,
    supported_algorithms=None,
):
    """Decide which checksum, if any, to compute for an outgoing request.

    The decision is stored in ``request["context"]["checksum"]``:

    * ``"request_algorithm"`` -- a dict holding the lowercase ``"algorithm"``
      name, where it is sent (``"in"``: ``"header"`` or ``"trailer"``) and
      the ``"name"`` ``x-amz-checksum-<algorithm>``.
    * ``"request_algorithm_header"`` -- set only when the default algorithm
      is chosen and the operation's ``requestAlgorithmMember`` serializes to
      a named header. It holds that header name and the default value.

    Nothing is recorded in any of these cases:

    * ``has_checksum_header`` reports that a checksum header is already set.
    * The operation neither requires a checksum nor (under the current
      ``request_checksum_calculation`` setting) supports one.
    * A default checksum would be needed for a presigned request.

    :param request: Request dict. ``request["context"]`` must contain
        ``"client_config"`` and may contain ``"checksum"`` and
        ``"is_presign_request"``.
    :param operation_model: Supplies ``http_checksum``,
        ``http_checksum_required``, ``has_streaming_input`` and (via the
        helper) ``input_shape``.
    :param params: The caller's operation parameters.
    :param supported_algorithms: Lowercase names accepted for a
        caller-selected algorithm. Defaults to
        ``_SUPPORTED_CHECKSUM_ALGORITHMS``.
    :raises MissingDependencyException: The caller selected a CRT-only
        algorithm and the CRT is not installed.
    :raises FlexibleChecksumError: The caller selected any other algorithm
        missing from ``supported_algorithms``.
    """
    # If the header is already set by the customer, skip calculation
    if has_checksum_header(request):
        return

    # May be a fresh dict; it is written back to the context at the end.
    checksum_context = request["context"].get("checksum", {})
    request_checksum_calculation = request["context"][
        "client_config"
    ].request_checksum_calculation
    http_checksum = operation_model.http_checksum
    request_checksum_required = (
        operation_model.http_checksum_required
        or http_checksum.get("requestChecksumRequired")
    )
    algorithm_member = http_checksum.get("requestAlgorithmMember")
    if algorithm_member and algorithm_member in params:
        # If the client has opted into using flexible checksums and the
        # request supports it, use that instead of checksum required
        if supported_algorithms is None:
            supported_algorithms = _SUPPORTED_CHECKSUM_ALGORITHMS

        algorithm_name = params[algorithm_member].lower()
        if algorithm_name not in supported_algorithms:
            # NOTE(review): if the CRT is installed but older than 0.23.4,
            # "crc64nvme" is never registered in _CHECKSUM_CLS while HAS_CRT
            # is true. That case skips this hint and raises the generic
            # FlexibleChecksumError below -- confirm this is intended.
            if not HAS_CRT and algorithm_name in _CRT_CHECKSUM_ALGORITHMS:
                raise MissingDependencyException(
                    msg=(
                        f"Using {algorithm_name.upper()} requires an "
                        "additional dependency. You will need to pip install "
                        "botocore[crt] before proceeding."
                    )
                )
            raise FlexibleChecksumError(
                error_msg=f"Unsupported checksum algorithm: {algorithm_name}"
            )
    elif request_checksum_required or (
        algorithm_member and request_checksum_calculation == "when_supported"
    ):
        # Don't use a default checksum for presigned requests.
        if request["context"].get("is_presign_request"):
            return
        algorithm_name = DEFAULT_CHECKSUM_ALGORITHM.lower()
        # When the algorithm member maps to a header, also send that header
        # so the service knows which default algorithm was applied.
        algorithm_member_header = _get_request_algorithm_member_header(
            operation_model, request, algorithm_member
        )
        if algorithm_member_header is not None:
            checksum_context["request_algorithm_header"] = {
                "name": algorithm_member_header,
                "value": DEFAULT_CHECKSUM_ALGORITHM,
            }
    else:
        # No checksum is required and none applies under the current
        # calculation setting.
        return

    location_type = "header"
    if (
        operation_model.has_streaming_input
        and urlparse(request["url"]).scheme == "https"
    ):
        # Requests signed with signature_version 's3' keep header placement.
        if request["context"]["client_config"].signature_version != 's3':
            # Operations with streaming input must support trailers.
            # We only support unsigned trailer checksums currently. As this
            # disables payload signing we'll only use trailers over TLS.
            location_type = "trailer"

    algorithm = {
        "algorithm": algorithm_name,
        "in": location_type,
        "name": f"x-amz-checksum-{algorithm_name}",
    }

    checksum_context["request_algorithm"] = algorithm
    request["context"]["checksum"] = checksum_context
348
349
def _get_request_algorithm_member_header(
    operation_model, request, algorithm_member
):
    """Get the name of the header targeted by the "requestAlgorithmMember".

    ``request`` is accepted for call-site compatibility but is not used.
    Returns ``None`` in any of these cases:

    * the operation input is not a StructureShape;
    * ``algorithm_member`` is not one of its members;
    * the member's serialization has no ``"name"``.
    """
    input_shape = operation_model.input_shape
    if isinstance(input_shape, StructureShape):
        member_shape = input_shape.members.get(algorithm_member)
        if member_shape:
            return member_shape.serialization.get("name")
    return None
364
365
def apply_request_checksum(request):
    """Apply the checksum chosen during resolution to ``request``.

    Does nothing when no ``request_algorithm`` was resolved. The special
    value ``"conditional-md5"`` delegates to
    ``conditionally_calculate_md5``. Otherwise the algorithm's ``"in"``
    location selects header or trailer handling. An unknown location raises
    FlexibleChecksumError. If a ``request_algorithm_header`` was resolved,
    that header is set afterwards.
    """
    checksum_context = request.get("context", {}).get("checksum", {})
    algorithm = checksum_context.get("request_algorithm")
    if not algorithm:
        return

    if algorithm == "conditional-md5":
        # Special case to handle the http checksum required trait
        conditionally_calculate_md5(request)
    else:
        location = algorithm["in"]
        if location == "header":
            _apply_request_header_checksum(request)
        elif location == "trailer":
            _apply_request_trailer_checksum(request)
        else:
            raise FlexibleChecksumError(
                error_msg=f"Unknown checksum variant: {location}"
            )

    if "request_algorithm_header" in checksum_context:
        extra_header = checksum_context["request_algorithm_header"]
        request["headers"][extra_header["name"]] = extra_header["value"]
389
390
391def _apply_request_header_checksum(request):
392 checksum_context = request.get("context", {}).get("checksum", {})
393 algorithm = checksum_context.get("request_algorithm")
394 location_name = algorithm["name"]
395 if location_name in request["headers"]:
396 # If the header is already set by the customer, skip calculation
397 return
398 checksum_cls = _CHECKSUM_CLS.get(algorithm["algorithm"])
399 digest = checksum_cls().handle(request["body"])
400 request["headers"][location_name] = digest
401 _register_checksum_algorithm_feature_id(algorithm)
402
403
def _apply_request_trailer_checksum(request):
    """Send the body aws-chunked with the checksum as a trailing header.

    Sets the chunked transfer/content encoding headers and ``X-Amz-Trailer``.
    It also sends ``X-Amz-Decoded-Content-Length`` when the length is known
    and drops any ``Content-Length``. The body is wrapped in an
    AwsChunkedWrapper that appends the checksum trailer. A header the
    caller already set under the trailer's name disables all of this.
    """
    algorithm = (
        request.get("context", {})
        .get("checksum", {})
        .get("request_algorithm")
    )
    trailer_name = algorithm["name"]
    checksum_cls = _CHECKSUM_CLS.get(algorithm["algorithm"])
    headers = request["headers"]
    body = request["body"]

    # If the header is already set by the customer, skip calculation
    if trailer_name in headers:
        return

    headers["Transfer-Encoding"] = "chunked"
    # Keep any existing content encoding and append aws-chunked to it.
    headers["Content-Encoding"] = (
        headers["Content-Encoding"] + ",aws-chunked"
        if "Content-Encoding" in headers
        else "aws-chunked"
    )
    headers["X-Amz-Trailer"] = trailer_name
    _register_checksum_algorithm_feature_id(algorithm)

    # Some services such as S3 may require the decoded content length.
    decoded_length = determine_content_length(body)
    if decoded_length is not None:
        headers["X-Amz-Decoded-Content-Length"] = str(decoded_length)

    if "Content-Length" in headers:
        del headers["Content-Length"]
        logger.debug(
            "Removing the Content-Length header since 'chunked' is "
            "specified for Transfer-Encoding."
        )

    # The wrapper reads from a file-like object, so wrap raw bytes first.
    if isinstance(body, (bytes, bytearray)):
        body = io.BytesIO(body)

    request["body"] = AwsChunkedWrapper(
        body,
        checksum_cls=checksum_cls,
        checksum_name=trailer_name,
    )
447
448
def _register_checksum_algorithm_feature_id(algorithm):
    """Record the user-agent feature ID for the request checksum algorithm.

    The ID is ``FLEXIBLE_CHECKSUMS_REQ_<NAME>``, where CRC64NVME is reported
    under the shorter name ``CRC64``.
    """
    upper_name = algorithm["algorithm"].upper()
    feature_name = {"CRC64NVME": "CRC64"}.get(upper_name, upper_name)
    register_feature_id(f"FLEXIBLE_CHECKSUMS_REQ_{feature_name}")
457
458
def resolve_response_checksum_algorithms(
    request, operation_model, params, supported_algorithms=None
):
    """Record which response checksum algorithms may be validated.

    Only acts when the operation's ``requestValidationModeMember`` is present
    in ``params``. In that case ``request["context"]["checksum"]`` receives
    ``"response_algorithms"``. This lists, in ``_ALGORITHMS_PRIORITY_LIST``
    order, every algorithm that the operation advertises and that is also
    in ``supported_algorithms`` (default ``_SUPPORTED_CHECKSUM_ALGORITHMS``).
    """
    http_checksum = operation_model.http_checksum
    mode_member = http_checksum.get("requestValidationModeMember")
    if not mode_member or mode_member not in params:
        return

    if supported_algorithms is None:
        supported_algorithms = _SUPPORTED_CHECKSUM_ALGORITHMS
    advertised = {
        name.lower() for name in http_checksum.get("responseAlgorithms", [])
    }
    usable_algorithms = [
        name
        for name in _ALGORITHMS_PRIORITY_LIST
        if name in advertised and name in supported_algorithms
    ]

    checksum_context = request["context"].get("checksum", {})
    checksum_context["response_algorithms"] = usable_algorithms
    request["context"]["checksum"] = checksum_context
481
482
def handle_checksum_body(http_response, response, context, operation_model):
    """Validate the response body against a returned checksum header.

    Algorithms come from ``context["checksum"]["response_algorithms"]`` and
    are tried in order. The first one whose ``x-amz-checksum-<algorithm>``
    header is present and has no ``-`` is used:

    * Streaming outputs get ``response["body"]`` replaced by a body that
      validates while it is read.
    * Other outputs are validated immediately; a mismatch raises
      FlexibleChecksumError.

    The chosen algorithm is recorded as
    ``response["context"]["checksum"]["response_algorithm"]``. When no usable
    header is found, a debug message is logged and the body is left as-is.
    """
    headers = response["headers"]
    checksum_context = context.get("checksum", {})
    algorithms = checksum_context.get("response_algorithms")

    if not algorithms:
        return

    for algorithm in algorithms:
        header_name = f"x-amz-checksum-{algorithm}"
        # If the header is not found, check the next algorithm
        if header_name not in headers:
            continue

        # If a - is in the checksum this is not valid Base64. S3 returns
        # checksums that include a -# suffix to indicate a checksum derived
        # from the hash of all part checksums. We cannot wrap this response
        if "-" in headers[header_name]:
            continue

        if operation_model.has_streaming_output:
            response["body"] = _handle_streaming_response(
                http_response, response, algorithm
            )
        else:
            response["body"] = _handle_bytes_response(
                http_response, response, algorithm
            )

        # Expose metadata that the checksum check actually occurred. This is
        # the response's own context, distinct from ``context`` above.
        response_checksum_context = response["context"].get("checksum", {})
        response_checksum_context["response_algorithm"] = algorithm
        response["context"]["checksum"] = response_checksum_context
        return

    # Lazy %-style args: the message is only formatted if DEBUG is enabled.
    logger.debug(
        "Skipping checksum validation. Response did not contain one of the "
        "following algorithms: %s.",
        algorithms,
    )
522
523
def _handle_streaming_response(http_response, response, algorithm):
    """Wrap the raw response stream so it is checksummed while being read."""
    checksum_cls = _CHECKSUM_CLS.get(algorithm)
    response_headers = response["headers"]
    return StreamingChecksumBody(
        http_response.raw,
        response_headers.get("content-length"),
        checksum_cls(),
        response_headers[f"x-amz-checksum-{algorithm}"],
    )
533
534
def _handle_bytes_response(http_response, response, algorithm):
    """Checksum the fully read response body and return it.

    Raises FlexibleChecksumError when the calculated checksum does not match
    the base64 value in the ``x-amz-checksum-<algorithm>`` header.
    """
    body = http_response.content
    checksum = _CHECKSUM_CLS.get(algorithm)()
    checksum.update(body)
    expected = response["headers"][f"x-amz-checksum-{algorithm}"]
    if checksum.digest() == base64.b64decode(expected):
        return body
    raise FlexibleChecksumError(
        error_msg=(
            f"Expected checksum {expected} did not match calculated "
            f"checksum: {checksum.b64digest()}"
        )
    )
549
550
# Checksum implementations that are always available (pure Python/stdlib).
_CHECKSUM_CLS = dict(
    crc32=Crc32Checksum,
    sha1=Sha1Checksum,
    sha256=Sha256Checksum,
)
# Algorithms the CRT can provide. Used to point users at botocore[crt] when
# they request one of these without the CRT installed.
_CRT_CHECKSUM_ALGORITHMS = ["crc32", "crc32c", "crc64nvme"]
if HAS_CRT:
    # Prefer the CRT implementations when available; "crc32" here replaces
    # the pure-Python class registered above.
    _CRT_CHECKSUM_CLS = {
        "crc32": CrtCrc32Checksum,
        "crc32c": CrtCrc32cChecksum,
    }
    # CRC64NVME support wasn't officially added until 0.23.4
    if has_minimum_crt_version((0, 23, 4)):
        _CRT_CHECKSUM_CLS["crc64nvme"] = CrtCrc64NvmeChecksum

    _CHECKSUM_CLS.update(_CRT_CHECKSUM_CLS)
    # Validate this list isn't out of sync with _CRT_CHECKSUM_CLS keys
    assert set(_CRT_CHECKSUM_CLS).issubset(_CRT_CHECKSUM_ALGORITHMS)
_SUPPORTED_CHECKSUM_ALGORITHMS = list(_CHECKSUM_CLS)
# Preference order when choosing which response checksum to validate.
_ALGORITHMS_PRIORITY_LIST = ['crc64nvme', 'crc32c', 'crc32', 'sha1', 'sha256']