Line data Source code
1 : #include "source/extensions/compression/zstd/decompressor/zstd_decompressor_impl.h" 2 : 3 : #include "source/common/runtime/runtime_features.h" 4 : 5 : namespace Envoy { 6 : namespace Extensions { 7 : namespace Compression { 8 : namespace Zstd { 9 : namespace Decompressor { 10 : 11 : namespace { 12 : 13 : // How many times the output buffer is allowed to be bigger than the size of 14 : // accumulated input. This value is used to detect compression bombs. 15 : // TODO(rojkov): Re-design the Decompressor interface to handle compression 16 : // bombs gracefully instead of this quick solution. 17 : constexpr uint64_t MaxInflateRatio = 100; 18 : 19 : } // namespace 20 : 21 : ZstdDecompressorImpl::ZstdDecompressorImpl(Stats::Scope& scope, const std::string& stats_prefix, 22 : const ZstdDDictManagerPtr& ddict_manager, 23 : uint32_t chunk_size) 24 : : Common::Base(chunk_size), dctx_(ZSTD_createDCtx(), &ZSTD_freeDCtx), 25 0 : ddict_manager_(ddict_manager), stats_(generateStats(stats_prefix, scope)) {} 26 : 27 : void ZstdDecompressorImpl::decompress(const Buffer::Instance& input_buffer, 28 0 : Buffer::Instance& output_buffer) { 29 0 : uint64_t limit = MaxInflateRatio * input_buffer.length(); 30 : 31 0 : for (const Buffer::RawSlice& input_slice : input_buffer.getRawSlices()) { 32 0 : if (input_slice.len_ > 0) { 33 0 : if (ddict_manager_ && !is_dictionary_set_) { 34 0 : is_dictionary_set_ = true; 35 : // If id == 0, it means that dictionary id could not be decoded. 36 0 : dictionary_id_ = 37 0 : ZSTD_getDictID_fromFrame(static_cast<uint8_t*>(input_slice.mem_), input_slice.len_); 38 0 : if (dictionary_id_ != 0) { 39 0 : auto dictionary = ddict_manager_->getDictionaryById(dictionary_id_); 40 0 : if (!dictionary) { 41 0 : stats_.zstd_dictionary_error_.inc(); 42 0 : return; 43 0 : } 44 0 : const size_t result = ZSTD_DCtx_refDDict(dctx_.get(), dictionary); 45 0 : if (isError(result)) { 46 0 : return; 47 0 : } 48 0 : } 49 0 : } 50 : 51 0 : setInput(input_slice); 52 0 : if (!process(output_buffer)) { 53 0 : return; 54 0 : } 55 0 : if (Runtime::runtimeFeatureEnabled( 56 0 : "envoy.reloadable_features.enable_compression_bomb_protection") && 57 0 : (output_buffer.length() > limit)) { 58 0 : stats_.zstd_generic_error_.inc(); 59 0 : ENVOY_LOG(trace, 60 0 : "excessive decompression ratio detected: output " 61 0 : "size {} for input size {}", 62 0 : output_buffer.length(), input_buffer.length()); 63 0 : return; 64 0 : } 65 0 : } 66 0 : } 67 0 : } 68 : 69 0 : bool ZstdDecompressorImpl::process(Buffer::Instance& output_buffer) { 70 0 : while (input_.pos < input_.size) { 71 0 : const size_t result = ZSTD_decompressStream(dctx_.get(), &output_, &input_); 72 0 : if (isError(result)) { 73 0 : return false; 74 0 : } 75 : 76 0 : getOutput(output_buffer); 77 0 : } 78 : 79 0 : return true; 80 0 : } 81 : 82 0 : bool ZstdDecompressorImpl::isError(size_t result) { 83 0 : switch (ZSTD_getErrorCode(result)) { 84 0 : case ZSTD_error_no_error: 85 0 : return false; 86 0 : case ZSTD_error_memory_allocation: 87 0 : stats_.zstd_memory_error_.inc(); 88 0 : break; 89 0 : case ZSTD_error_dictionary_corrupted: 90 0 : case ZSTD_error_dictionary_wrong: 91 0 : stats_.zstd_dictionary_error_.inc(); 92 0 : break; 93 0 : case ZSTD_error_checksum_wrong: 94 0 : stats_.zstd_checksum_wrong_error_.inc(); 95 0 : break; 96 0 : default: 97 0 : stats_.zstd_generic_error_.inc(); 98 0 : break; 99 0 : } 100 0 : return true; 101 0 : } 102 : 103 : } // namespace Decompressor 104 : } // namespace Zstd 105 : } // namespace Compression 106 : } // namespace Extensions 107 : } // namespace Envoy