/src/duckdb/extension/parquet/zstd_file_system.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | #include "zstd_file_system.hpp" |
2 | | |
3 | | #include "zstd.h" |
4 | | |
5 | | namespace duckdb { |
6 | | |
7 | | struct ZstdStreamWrapper : public StreamWrapper { |
8 | | ~ZstdStreamWrapper() override; |
9 | | |
10 | | CompressedFile *file = nullptr; |
11 | | duckdb_zstd::ZSTD_DStream *zstd_stream_ptr = nullptr; |
12 | | duckdb_zstd::ZSTD_CStream *zstd_compress_ptr = nullptr; |
13 | | bool writing = false; |
14 | | |
15 | | public: |
16 | | void Initialize(CompressedFile &file, bool write) override; |
17 | | bool Read(StreamData &stream_data) override; |
18 | | void Write(CompressedFile &file, StreamData &stream_data, data_ptr_t buffer, int64_t nr_bytes) override; |
19 | | |
20 | | void Close() override; |
21 | | |
22 | | void FlushStream(); |
23 | | }; |
24 | | |
25 | 0 | ZstdStreamWrapper::~ZstdStreamWrapper() { |
26 | 0 | if (Exception::UncaughtException()) { |
27 | 0 | return; |
28 | 0 | } |
29 | 0 | try { |
30 | 0 | Close(); |
31 | 0 | } catch (...) { // NOLINT: swallow exceptions in destructor |
32 | 0 | } |
33 | 0 | } |
34 | | |
35 | 0 | void ZstdStreamWrapper::Initialize(CompressedFile &file, bool write) { |
36 | 0 | Close(); |
37 | 0 | this->file = &file; |
38 | 0 | this->writing = write; |
39 | 0 | if (write) { |
40 | 0 | zstd_compress_ptr = duckdb_zstd::ZSTD_createCStream(); |
41 | 0 | } else { |
42 | 0 | zstd_stream_ptr = duckdb_zstd::ZSTD_createDStream(); |
43 | 0 | } |
44 | 0 | } |
45 | | |
46 | 0 | bool ZstdStreamWrapper::Read(StreamData &sd) { |
47 | 0 | D_ASSERT(!writing); |
48 | |
|
49 | 0 | duckdb_zstd::ZSTD_inBuffer in_buffer; |
50 | 0 | duckdb_zstd::ZSTD_outBuffer out_buffer; |
51 | |
|
52 | 0 | in_buffer.src = sd.in_buff_start; |
53 | 0 | in_buffer.size = sd.in_buff_end - sd.in_buff_start; |
54 | 0 | in_buffer.pos = 0; |
55 | |
|
56 | 0 | out_buffer.dst = sd.out_buff_start; |
57 | 0 | out_buffer.size = sd.out_buf_size; |
58 | 0 | out_buffer.pos = 0; |
59 | |
|
60 | 0 | auto res = duckdb_zstd::ZSTD_decompressStream(zstd_stream_ptr, &out_buffer, &in_buffer); |
61 | 0 | if (duckdb_zstd::ZSTD_isError(res)) { |
62 | 0 | throw IOException(duckdb_zstd::ZSTD_getErrorName(res)); |
63 | 0 | } |
64 | | |
65 | 0 | sd.in_buff_start = (data_ptr_t)in_buffer.src + in_buffer.pos; // NOLINT |
66 | 0 | sd.in_buff_end = (data_ptr_t)in_buffer.src + in_buffer.size; // NOLINT |
67 | 0 | sd.out_buff_end = (data_ptr_t)out_buffer.dst + out_buffer.pos; // NOLINT |
68 | 0 | return false; |
69 | 0 | } |
70 | | |
71 | | void ZstdStreamWrapper::Write(CompressedFile &file, StreamData &sd, data_ptr_t uncompressed_data, |
72 | 0 | int64_t uncompressed_size) { |
73 | 0 | D_ASSERT(writing); |
74 | |
|
75 | 0 | auto remaining = uncompressed_size; |
76 | 0 | while (remaining > 0) { |
77 | 0 | D_ASSERT(sd.out_buff.get() + sd.out_buf_size > sd.out_buff_start); |
78 | 0 | idx_t output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start; |
79 | |
|
80 | 0 | duckdb_zstd::ZSTD_inBuffer in_buffer; |
81 | 0 | duckdb_zstd::ZSTD_outBuffer out_buffer; |
82 | |
|
83 | 0 | in_buffer.src = uncompressed_data; |
84 | 0 | in_buffer.size = remaining; |
85 | 0 | in_buffer.pos = 0; |
86 | |
|
87 | 0 | out_buffer.dst = sd.out_buff_start; |
88 | 0 | out_buffer.size = output_remaining; |
89 | 0 | out_buffer.pos = 0; |
90 | 0 | auto res = |
91 | 0 | duckdb_zstd::ZSTD_compressStream2(zstd_compress_ptr, &out_buffer, &in_buffer, duckdb_zstd::ZSTD_e_continue); |
92 | 0 | if (duckdb_zstd::ZSTD_isError(res)) { |
93 | 0 | throw IOException(duckdb_zstd::ZSTD_getErrorName(res)); |
94 | 0 | } |
95 | 0 | idx_t input_consumed = in_buffer.pos; |
96 | 0 | idx_t written_to_output = out_buffer.pos; |
97 | 0 | sd.out_buff_start += written_to_output; |
98 | 0 | if (sd.out_buff_start == sd.out_buff.get() + sd.out_buf_size) { |
99 | | // no more output buffer available: flush |
100 | 0 | file.child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get()); |
101 | 0 | sd.out_buff_start = sd.out_buff.get(); |
102 | 0 | } |
103 | 0 | uncompressed_data += input_consumed; |
104 | 0 | remaining -= UnsafeNumericCast<int64_t>(input_consumed); |
105 | 0 | } |
106 | 0 | } |
107 | | |
108 | 0 | void ZstdStreamWrapper::FlushStream() { |
109 | 0 | auto &sd = file->stream_data; |
110 | 0 | duckdb_zstd::ZSTD_inBuffer in_buffer; |
111 | 0 | duckdb_zstd::ZSTD_outBuffer out_buffer; |
112 | |
|
113 | 0 | in_buffer.src = nullptr; |
114 | 0 | in_buffer.size = 0; |
115 | 0 | in_buffer.pos = 0; |
116 | 0 | while (true) { |
117 | 0 | idx_t output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start; |
118 | |
|
119 | 0 | out_buffer.dst = sd.out_buff_start; |
120 | 0 | out_buffer.size = output_remaining; |
121 | 0 | out_buffer.pos = 0; |
122 | |
|
123 | 0 | auto res = |
124 | 0 | duckdb_zstd::ZSTD_compressStream2(zstd_compress_ptr, &out_buffer, &in_buffer, duckdb_zstd::ZSTD_e_end); |
125 | 0 | if (duckdb_zstd::ZSTD_isError(res)) { |
126 | 0 | throw IOException(duckdb_zstd::ZSTD_getErrorName(res)); |
127 | 0 | } |
128 | 0 | idx_t written_to_output = out_buffer.pos; |
129 | 0 | sd.out_buff_start += written_to_output; |
130 | 0 | if (sd.out_buff_start > sd.out_buff.get()) { |
131 | 0 | file->child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get()); |
132 | 0 | sd.out_buff_start = sd.out_buff.get(); |
133 | 0 | } |
134 | 0 | if (res == 0) { |
135 | 0 | break; |
136 | 0 | } |
137 | 0 | } |
138 | 0 | } |
139 | | |
140 | 0 | void ZstdStreamWrapper::Close() { |
141 | 0 | if (!zstd_stream_ptr && !zstd_compress_ptr) { |
142 | 0 | return; |
143 | 0 | } |
144 | 0 | if (writing) { |
145 | 0 | FlushStream(); |
146 | 0 | } |
147 | 0 | if (zstd_stream_ptr) { |
148 | 0 | duckdb_zstd::ZSTD_freeDStream(zstd_stream_ptr); |
149 | 0 | } |
150 | 0 | if (zstd_compress_ptr) { |
151 | 0 | duckdb_zstd::ZSTD_freeCStream(zstd_compress_ptr); |
152 | 0 | } |
153 | 0 | zstd_stream_ptr = nullptr; |
154 | 0 | zstd_compress_ptr = nullptr; |
155 | 0 | } |
156 | | |
157 | | class ZStdFile : public CompressedFile { |
158 | | public: |
159 | | ZStdFile(unique_ptr<FileHandle> child_handle_p, const string &path, bool write) |
160 | 0 | : CompressedFile(zstd_fs, std::move(child_handle_p), path) { |
161 | 0 | Initialize(write); |
162 | 0 | } |
163 | | |
164 | 0 | FileCompressionType GetFileCompressionType() override { |
165 | 0 | return FileCompressionType::ZSTD; |
166 | 0 | } |
167 | | |
168 | | ZStdFileSystem zstd_fs; |
169 | | }; |
170 | | |
171 | 0 | unique_ptr<FileHandle> ZStdFileSystem::OpenCompressedFile(unique_ptr<FileHandle> handle, bool write) { |
172 | 0 | auto path = handle->path; |
173 | 0 | return make_uniq<ZStdFile>(std::move(handle), path, write); |
174 | 0 | } |
175 | | |
176 | 0 | unique_ptr<StreamWrapper> ZStdFileSystem::CreateStream() { |
177 | 0 | return make_uniq<ZstdStreamWrapper>(); |
178 | 0 | } |
179 | | |
180 | 0 | idx_t ZStdFileSystem::InBufferSize() { |
181 | 0 | return duckdb_zstd::ZSTD_DStreamInSize(); |
182 | 0 | } |
183 | | |
184 | 0 | idx_t ZStdFileSystem::OutBufferSize() { |
185 | 0 | return duckdb_zstd::ZSTD_DStreamOutSize(); |
186 | 0 | } |
187 | | |
188 | 0 | int64_t ZStdFileSystem::DefaultCompressionLevel() { |
189 | 0 | return duckdb_zstd::ZSTD_defaultCLevel(); |
190 | 0 | } |
191 | | |
192 | 0 | int64_t ZStdFileSystem::MinimumCompressionLevel() { |
193 | 0 | return duckdb_zstd::ZSTD_minCLevel(); |
194 | 0 | } |
195 | | |
196 | 0 | int64_t ZStdFileSystem::MaximumCompressionLevel() { |
197 | 0 | return duckdb_zstd::ZSTD_maxCLevel(); |
198 | 0 | } |
199 | | |
200 | | } // namespace duckdb |