/src/duckdb/extension/parquet/parquet_crypto.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | #include "parquet_crypto.hpp" |
2 | | |
3 | | #include "mbedtls_wrapper.hpp" |
4 | | #include "thrift_tools.hpp" |
5 | | |
6 | | #include "duckdb/common/exception/conversion_exception.hpp" |
7 | | #include "duckdb/common/helper.hpp" |
8 | | #include "duckdb/common/types/blob.hpp" |
9 | | #include "duckdb/storage/arena_allocator.hpp" |
10 | | |
11 | | namespace duckdb { |
12 | | |
//! Returns the ParquetKeys object held in the ObjectCache obtained from the given context.
//! If the cache has no entry under ParquetKeys::ObjectType() yet, an empty ParquetKeys is
//! created and stored first, so the returned reference always refers to a cached object.
13 | 0 | ParquetKeys &ParquetKeys::Get(ClientContext &context) { |
14 | 0 | auto &cache = ObjectCache::GetObjectCache(context); |
15 | 0 | if (!cache.Get<ParquetKeys>(ParquetKeys::ObjectType())) { |
// Nothing cached yet: register a fresh, empty key store
16 | 0 | cache.Put(ParquetKeys::ObjectType(), make_shared_ptr<ParquetKeys>()); |
17 | 0 | } |
18 | 0 | return *cache.Get<ParquetKeys>(ParquetKeys::ObjectType()); |
19 | 0 | } |
20 | | |
//! Stores `key` under `key_name` in the `keys` member (assigned through operator[]).
21 | 0 | void ParquetKeys::AddKey(const string &key_name, const string &key) { |
22 | 0 | keys[key_name] = key; |
23 | 0 | } |
24 | | |
//! Returns true if an entry for `key_name` is present in `keys`.
25 | 0 | bool ParquetKeys::HasKey(const string &key_name) const { |
26 | 0 | return keys.find(key_name) != keys.end(); |
27 | 0 | } |
28 | | |
//! Returns a reference to the key stored under `key_name`.
//! The existence of the key is only checked through D_ASSERT; callers are expected to
//! call HasKey() first.
29 | 0 | const string &ParquetKeys::GetKey(const string &key_name) const { |
30 | 0 | D_ASSERT(HasKey(key_name)); |
31 | 0 | return keys.at(key_name); |
32 | 0 | } |
33 | | |
//! Returns the identifier string "parquet_keys"; ParquetKeys::Get() uses it as the object cache key.
34 | 0 | string ParquetKeys::ObjectType() { |
35 | 0 | return "parquet_keys"; |
36 | 0 | } |
37 | | |
//! Member-function counterpart of ObjectType(); returns the same identifier string.
38 | 0 | string ParquetKeys::GetObjectType() { |
39 | 0 | return ObjectType(); |
40 | 0 | } |
41 | | |
//! Default constructor: performs no explicit initialization (footer_key is default-constructed).
42 | 0 | ParquetEncryptionConfig::ParquetEncryptionConfig() { |
43 | 0 | } |
44 | | |
//! Constructs a config from an explicit footer key; the argument is moved into footer_key.
45 | 0 | ParquetEncryptionConfig::ParquetEncryptionConfig(string footer_key_p) : footer_key(std::move(footer_key_p)) { |
46 | 0 | } |
47 | | |
//! Constructs a config from a STRUCT value (the "encryption_config" argument).
//! Recognized struct fields (field names are compared case-insensitively):
//!   - "footer_key":       name of a key previously registered in ParquetKeys; the stored key is copied
//!   - "footer_key_value": the key itself; the value is cast to BLOB and used directly
//!   - "column_keys":      recognized but not supported
//! Throws BinderException if `arg` is not a STRUCT, if the named key is not registered, or if an
//! unknown field is present; throws NotImplementedException for "column_keys".
48 | 0 | ParquetEncryptionConfig::ParquetEncryptionConfig(ClientContext &context, const Value &arg) { |
49 | 0 | if (arg.type().id() != LogicalTypeId::STRUCT) { |
50 | 0 | throw BinderException("Parquet encryption_config must be of type STRUCT"); |
51 | 0 | } |
52 | 0 | const auto &child_types = StructType::GetChildTypes(arg.type()); |
53 | 0 | auto &children = StructValue::GetChildren(arg); |
54 | 0 | const auto &keys = ParquetKeys::Get(context); |
// Walk all struct fields; child_types[i] provides the field name, children[i] the field value
55 | 0 | for (idx_t i = 0; i < StructType::GetChildCount(arg.type()); i++) { |
56 | 0 | auto &struct_key = child_types[i].first; |
57 | 0 | if (StringUtil::Lower(struct_key) == "footer_key") { |
58 | 0 | const auto footer_key_name = StringValue::Get(children[i].DefaultCastAs(LogicalType::VARCHAR)); |
59 | 0 | if (!keys.HasKey(footer_key_name)) { |
60 | 0 | throw BinderException( |
61 | 0 | "No key with name \"%s\" exists. Add it with PRAGMA add_parquet_key('<key_name>','<key>');", |
62 | 0 | footer_key_name); |
63 | 0 | } |
64 | | // footer key name provided - read the key from the config |
// NOTE(review): this second lookup shadows the `keys` reference declared before the loop and
// appears redundant - both come from ParquetKeys::Get(context)
65 | 0 | const auto &keys = ParquetKeys::Get(context); |
66 | 0 | footer_key = keys.GetKey(footer_key_name); |
67 | 0 | } else if (StringUtil::Lower(struct_key) == "footer_key_value") { |
// Raw key supplied inline: cast to BLOB and store the bytes as-is
68 | 0 | footer_key = StringValue::Get(children[i].DefaultCastAs(LogicalType::BLOB)); |
69 | 0 | } else if (StringUtil::Lower(struct_key) == "column_keys") { |
70 | 0 | throw NotImplementedException("Parquet encryption_config column_keys not yet implemented"); |
71 | 0 | } else { |
72 | 0 | throw BinderException("Unknown key in encryption_config \"%s\"", struct_key); |
73 | 0 | } |
74 | 0 | } |
75 | 0 | } |
76 | | |
//! Factory: builds a ParquetEncryptionConfig from (context, arg) and returns it in a shared_ptr.
//! Any exception thrown by the constructor propagates to the caller.
77 | 0 | shared_ptr<ParquetEncryptionConfig> ParquetEncryptionConfig::Create(ClientContext &context, const Value &arg) { |
78 | 0 | return shared_ptr<ParquetEncryptionConfig>(new ParquetEncryptionConfig(context, arg)); |
79 | 0 | } |
80 | | |
//! Returns a reference to the stored footer key.
81 | 0 | const string &ParquetEncryptionConfig::GetFooterKey() const { |
82 | 0 | return footer_key; |
83 | 0 | } |
84 | | |
85 | | using duckdb_apache::thrift::protocol::TCompactProtocolFactoryT; |
86 | | using duckdb_apache::thrift::transport::TTransport; |
87 | | |
88 | | //! Encryption wrapper for a transport protocol |
//! Everything written through write_virt() is buffered in an ArenaAllocator; nothing reaches the
//! wrapped transport until Finalize() is called. Finalize() then emits, in this order:
//! the length field, the nonce, the encrypted data and the tag.
89 | | class EncryptionTransport : public TTransport { |
90 | | public: |
//! Wraps the transport of `prot_p`, creates the encryption state for `key` and calls
//! Initialize(), which generates the nonce and starts the encryption.
91 | | EncryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p) |
92 | 0 | : prot(prot_p), trans(*prot.getTransport()), |
93 | 0 | aes(encryption_util_p.CreateEncryptionState(reinterpret_cast<const_data_ptr_t>(key.data()), key.size())), |
94 | 0 | allocator(Allocator::DefaultAllocator(), ParquetCrypto::CRYPTO_BLOCK_SIZE) { |
95 | 0 | Initialize(key); |
96 | 0 | } |
97 | | |
//! isOpen/open/close are forwarded to the wrapped transport
98 | 0 | bool isOpen() const override { |
99 | 0 | return trans.isOpen(); |
100 | 0 | } |
101 | | |
102 | 0 | void open() override { |
103 | 0 | trans.open(); |
104 | 0 | } |
105 | | |
106 | 0 | void close() override { |
107 | 0 | trans.close(); |
108 | 0 | } |
109 | | |
//! Buffers `len` plaintext bytes in the arena allocator; no encryption happens here
110 | 0 | void write_virt(const uint8_t *buf, uint32_t len) override { |
111 | 0 | memcpy(allocator.Allocate(len), buf, len); |
112 | 0 | } |
113 | | |
//! Encrypts everything buffered so far and writes it to the wrapped transport.
//! Returns LENGTH_BYTES + NONCE_BYTES + buffered size + TAG_BYTES.
114 | 0 | uint32_t Finalize() { |
115 | | // Write length |
116 | 0 | const auto ciphertext_length = allocator.SizeInBytes(); |
// NOTE(review): the sum is stored in a uint32_t; if SizeInBytes() returns a wider type the value
// is narrowed here - confirm buffered data cannot exceed the 32-bit range
117 | 0 | const uint32_t total_length = ParquetCrypto::NONCE_BYTES + ciphertext_length + ParquetCrypto::TAG_BYTES; |
118 |

119 | 0 | trans.write(const_data_ptr_cast(&total_length), ParquetCrypto::LENGTH_BYTES); |
120 | | // Write nonce at beginning of encrypted chunk |
121 | 0 | trans.write(nonce, ParquetCrypto::NONCE_BYTES); |
122 |

// Scratch buffer that receives the output of each aes->Process call
123 | 0 | data_t aes_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE]; |
124 | 0 | auto current = allocator.GetTail(); |
125 | | |
126 | | // Loop through the whole chunk |
// Chunks are visited starting at GetTail() and following the `prev` links; each chunk is
// processed in pieces of at most CRYPTO_BLOCK_SIZE bytes
127 | 0 | while (current != nullptr) { |
128 | 0 | for (idx_t pos = 0; pos < current->current_position; pos += ParquetCrypto::CRYPTO_BLOCK_SIZE) { |
129 | 0 | auto next = MinValue<idx_t>(current->current_position - pos, ParquetCrypto::CRYPTO_BLOCK_SIZE); |
130 | 0 | auto write_size = |
131 | 0 | aes->Process(current->data.get() + pos, next, aes_buffer, ParquetCrypto::CRYPTO_BLOCK_SIZE); |
132 | 0 | trans.write(aes_buffer, write_size); |
133 | 0 | } |
134 | 0 | current = current->prev; |
135 | 0 | } |
136 | | |
137 | | // Finalize the last encrypted data |
138 | 0 | data_t tag[ParquetCrypto::TAG_BYTES]; |
139 | 0 | auto write_size = aes->Finalize(aes_buffer, 0, tag, ParquetCrypto::TAG_BYTES); |
140 | 0 | trans.write(aes_buffer, write_size); |
141 | | // Write tag for verification |
142 | 0 | trans.write(tag, ParquetCrypto::TAG_BYTES); |
143 |

144 | 0 | return ParquetCrypto::LENGTH_BYTES + total_length; |
145 | 0 | } |
146 | | |
147 | | private: |
//! Fills `nonce` with random data and initializes the encryption state with nonce and key
148 | 0 | void Initialize(const string &key) { |
149 | | // Generate Nonce |
150 | 0 | aes->GenerateRandomData(nonce, ParquetCrypto::NONCE_BYTES); |
151 | | // Initialize Encryption |
152 | 0 | aes->InitializeEncryption(nonce, ParquetCrypto::NONCE_BYTES, reinterpret_cast<const_data_ptr_t>(key.data()), |
153 | 0 | key.size()); |
154 | 0 | } |
155 | | |
156 | | private: |
157 | | //! Protocol and corresponding transport that we're wrapping |
158 | | TProtocol &prot; |
159 | | TTransport &trans; |
160 | | |
161 | | //! AES context and buffers |
162 | | shared_ptr<EncryptionState> aes; |
163 | | |
164 | | //! Nonce created by Initialize() |
165 | | data_t nonce[ParquetCrypto::NONCE_BYTES]; |
166 | | |
167 | | //! Arena Allocator to fully materialize in memory before encrypting |
168 | | ArenaAllocator allocator; |
169 | | }; |
170 | | |
171 | | //! Decryption wrapper for a transport protocol |
//! The constructor reads the length field and the nonce from the wrapped transport. read_virt()
//! then pulls ciphertext in blocks of at most CRYPTO_BLOCK_SIZE bytes and decrypts them;
//! Finalize() reads the trailing tag and hands it to the AES state.
172 | | class DecryptionTransport : public TTransport { |
173 | | public: |
//! Wraps the transport of `prot_p`, creates the AES state for `key` and calls Initialize()
174 | | DecryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p) |
175 | 0 | : prot(prot_p), trans(*prot.getTransport()), |
176 | 0 | aes(encryption_util_p.CreateEncryptionState(reinterpret_cast<const_data_ptr_t>(key.data()), key.size())), |
177 | 0 | read_buffer_size(0), read_buffer_offset(0) { |
178 | 0 | Initialize(key); |
179 | 0 | } |
//! Reads `len` decrypted bytes into `buf` and returns `len`.
//! Throws InvalidInputException if more bytes are requested than remain before the tag.
180 | 0 | uint32_t read_virt(uint8_t *buf, uint32_t len) override { |
181 | 0 | const uint32_t result = len; |
182 |

// NOTE(review): transport_remaining is an unsigned 32-bit value, so
// `transport_remaining - TAG_BYTES` wraps around if fewer than TAG_BYTES bytes remain -
// confirm this cannot happen for a malformed length field
183 | 0 | if (len > transport_remaining - ParquetCrypto::TAG_BYTES + read_buffer_size - read_buffer_offset) { |
184 | 0 | throw InvalidInputException("Too many bytes requested from crypto buffer"); |
185 | 0 | } |
186 | | |
187 | 0 | while (len != 0) { |
188 | 0 | if (read_buffer_offset == read_buffer_size) { |
// NOTE(review): ReadBlock() decrypts a whole block (up to CRYPTO_BLOCK_SIZE bytes) directly into
// `buf`, even when `len` is smaller than the block; this looks safe only when the caller
// requests everything up to the tag in one call, as ReadAll() does - confirm for other callers
189 | 0 | ReadBlock(buf); |
190 | 0 | } |
// Account for the bytes of the current block that this call consumes
191 | 0 | const auto next = MinValue(read_buffer_size - read_buffer_offset, len); |
192 | 0 | read_buffer_offset += next; |
193 | 0 | buf += next; |
194 | 0 | len -= next; |
195 | 0 | } |
196 |

197 | 0 | return result; |
198 | 0 | } |
199 | | |
//! Reads the trailing tag from the transport and finalizes the AES state with it.
//! Throws InternalException if the current block has not been fully consumed or if
//! aes->Finalize() reports leftover output; throws InvalidInputException if the number of bytes
//! consumed does not match the encoded length. Returns LENGTH_BYTES + total_bytes.
200 | 0 | uint32_t Finalize() { |
201 |

202 | 0 | if (read_buffer_offset != read_buffer_size) { |
203 | 0 | throw InternalException("DecryptionTransport::Finalize was called with bytes remaining in read buffer: \n" |
204 | 0 | "read buffer offset: %d, read buffer size: %d", |
205 | 0 | read_buffer_offset, read_buffer_size); |
206 | 0 | } |
207 | | |
// Despite its name, computed_tag is filled with the tag read from the transport; it is then
// passed to aes->Finalize() (presumably for verification - confirm in EncryptionState)
208 | 0 | data_t computed_tag[ParquetCrypto::TAG_BYTES]; |
209 | 0 | transport_remaining -= trans.read(computed_tag, ParquetCrypto::TAG_BYTES); |
210 | 0 | if (aes->Finalize(read_buffer, 0, computed_tag, ParquetCrypto::TAG_BYTES) != 0) { |
211 | 0 | throw InternalException("DecryptionTransport::Finalize was called with bytes remaining in AES context out"); |
212 | 0 | } |
213 | | |
214 | 0 | if (transport_remaining != 0) { |
215 | 0 | throw InvalidInputException("Encoded ciphertext length differs from actual ciphertext length"); |
216 | 0 | } |
217 | | |
218 | 0 | return ParquetCrypto::LENGTH_BYTES + total_bytes; |
219 | 0 | } |
220 | | |
//! Decrypts the complete payload (everything between nonce and tag) into a newly allocated
//! buffer, calls Finalize() and returns the buffer. Must be called before any other read
//! (asserted via D_ASSERT on transport_remaining).
221 | 0 | AllocatedData ReadAll() { |
222 | 0 | D_ASSERT(transport_remaining == total_bytes - ParquetCrypto::NONCE_BYTES); |
223 | 0 | auto result = Allocator::DefaultAllocator().Allocate(transport_remaining - ParquetCrypto::TAG_BYTES); |
224 | 0 | read_virt(result.get(), transport_remaining - ParquetCrypto::TAG_BYTES); |
225 | 0 | Finalize(); |
226 | 0 | return result; |
227 | 0 | } |
228 | | |
229 | | private: |
//! Reads the length field and the nonce from the transport and initializes decryption
230 | 0 | void Initialize(const string &key) { |
231 | | // Read encoded length (the length field itself is not counted in total_bytes) |
232 | 0 | data_t length_buf[ParquetCrypto::LENGTH_BYTES]; |
233 | 0 | trans.read(length_buf, ParquetCrypto::LENGTH_BYTES); |
234 | 0 | total_bytes = Load<uint32_t>(length_buf); |
235 | 0 | transport_remaining = total_bytes; |
236 | | // Read nonce and initialize AES |
237 | 0 | transport_remaining -= trans.read(nonce, ParquetCrypto::NONCE_BYTES); |
238 | | // Initialize decryption with the nonce and the key |
239 | 0 | aes->InitializeDecryption(nonce, ParquetCrypto::NONCE_BYTES, reinterpret_cast<const_data_ptr_t>(key.data()), |
240 | 0 | key.size()); |
241 | 0 | } |
242 | | |
//! Reads the next block of ciphertext (at most CRYPTO_BLOCK_SIZE bytes, never into the tag)
//! and decrypts it into `buf`; resets read_buffer_offset to 0.
243 | 0 | void ReadBlock(uint8_t *buf) { |
244 | | // Read from transport into read_buffer at one AES block size offset (up to the tag) |
245 | 0 | read_buffer_size = MinValue(ParquetCrypto::CRYPTO_BLOCK_SIZE, transport_remaining - ParquetCrypto::TAG_BYTES); |
246 | 0 | transport_remaining -= trans.read(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size); |
247 | | |
248 | | // Decrypt from read_buffer + block size directly into the caller-provided buffer `buf` |
249 | | #ifdef DEBUG |
250 | | auto size = aes->Process(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size, buf, |
251 | | ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE); |
252 | | D_ASSERT(size == read_buffer_size); |
253 | | #else |
254 | 0 | aes->Process(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size, buf, |
255 | 0 | ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE); |
256 | 0 | #endif |
257 | 0 | read_buffer_offset = 0; |
258 | 0 | } |
259 | | |
260 | | private: |
261 | | //! Protocol and corresponding transport that we're wrapping |
262 | | TProtocol &prot; |
263 | | TTransport &trans; |
264 | | |
265 | | //! AES context and buffers |
266 | | shared_ptr<EncryptionState> aes; |
267 | | |
268 | | //! We read/decrypt big blocks at a time |
269 | | data_t read_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE]; |
270 | | uint32_t read_buffer_size; |
271 | | uint32_t read_buffer_offset; |
272 | | |
273 | | //! Total encoded length and remaining bytes to read: set by Initialize(), transport_remaining is decremented by every transport read |
274 | | uint32_t total_bytes; |
275 | | uint32_t transport_remaining; |
276 | | //! Nonce read by Initialize() |
277 | | data_t nonce[ParquetCrypto::NONCE_BYTES]; |
278 | | }; |
279 | | |
//! Read-only transport over a caller-owned, in-memory buffer.
//! The buffer is not copied; it has to stay alive for as long as the transport is used.
280 | | class SimpleReadTransport : public TTransport { |
281 | | public: |
282 | | explicit SimpleReadTransport(data_ptr_t read_buffer_p, uint32_t read_buffer_size_p) |
283 | 0 | : read_buffer(read_buffer_p), read_buffer_size(read_buffer_size_p), read_buffer_offset(0) { |
284 | 0 | } |
285 | | |
//! Copies up to `len` bytes from the current offset into `buf` and advances the offset.
//! Returns the number of bytes actually copied, which is less than `len` (possibly 0) when
//! the buffer is exhausted.
286 | 0 | uint32_t read_virt(uint8_t *buf, uint32_t len) override { |
287 | 0 | const auto remaining = read_buffer_size - read_buffer_offset; |
// Short read: clamp to what is left so that the returned count always matches the bytes that
// were really copied (previously `remaining` was reported without copying or advancing)
288 | 0 | if (len > remaining) { |
289 | 0 | len = remaining; |
290 | 0 | } |
291 | 0 | memcpy(buf, read_buffer + read_buffer_offset, len); |
292 | 0 | read_buffer_offset += len; |
293 | 0 | return len; |
294 | 0 | } |
295 | | |
296 | | private: |
//! Borrowed buffer, its size in bytes, and the current read position within it
297 | | const data_ptr_t read_buffer; |
298 | | const uint32_t read_buffer_size; |
299 | | uint32_t read_buffer_offset; |
300 | | }; |
301 | | |
//! Reads an encrypted thrift object from `iprot`: the payload is decrypted completely into
//! memory first and `object` is then deserialized from the decrypted bytes.
//! Returns the number of bytes consumed: length field + nonce + payload + tag.
302 | | uint32_t ParquetCrypto::Read(TBase &object, TProtocol &iprot, const string &key, |
303 | 0 | const EncryptionUtil &encryption_util_p) { |
304 | 0 | TCompactProtocolFactoryT<DecryptionTransport> tproto_factory; |
305 | 0 | auto dprot = |
306 | 0 | tproto_factory.getProtocol(duckdb_base_std::make_shared<DecryptionTransport>(iprot, key, encryption_util_p)); |
307 | 0 | auto &dtrans = reinterpret_cast<DecryptionTransport &>(*dprot->getTransport()); |
308 | | |
309 | | // We have to read the whole thing, otherwise thrift throws an error before we realize the decryption is wrong |
310 | 0 | auto all = dtrans.ReadAll(); |
// Deserialize from the decrypted in-memory buffer through a plain read transport
311 | 0 | TCompactProtocolFactoryT<SimpleReadTransport> tsimple_proto_factory; |
312 | 0 | auto simple_prot = |
313 | 0 | tsimple_proto_factory.getProtocol(duckdb_base_std::make_shared<SimpleReadTransport>(all.get(), all.GetSize())); |
314 | | |
315 | | // Read the object |
316 | 0 | object.read(simple_prot.get()); |
317 |

318 | 0 | return ParquetCrypto::LENGTH_BYTES + ParquetCrypto::NONCE_BYTES + all.GetSize() + ParquetCrypto::TAG_BYTES; |
319 | 0 | } |
320 | | |
//! Serializes `object` into an in-memory EncryptionTransport, then encrypts the buffered bytes
//! and writes them to the transport of `oprot`. Returns the value of EncryptionTransport::Finalize().
321 | | uint32_t ParquetCrypto::Write(const TBase &object, TProtocol &oprot, const string &key, |
322 | 0 | const EncryptionUtil &encryption_util_p) { |
323 | | // Create encryption protocol |
324 | 0 | TCompactProtocolFactoryT<EncryptionTransport> tproto_factory; |
325 | 0 | auto eprot = |
326 | 0 | tproto_factory.getProtocol(duckdb_base_std::make_shared<EncryptionTransport>(oprot, key, encryption_util_p)); |
327 | 0 | auto &etrans = reinterpret_cast<EncryptionTransport &>(*eprot->getTransport()); |
328 | | |
329 | | // Write the object in memory |
330 | 0 | object.write(eprot.get()); |
331 | | |
332 | | // Encrypt and write to oprot |
333 | 0 | return etrans.Finalize(); |
334 | 0 | } |
335 | | |
//! Reads an encrypted data buffer from `iprot`, decrypting `buffer_size` bytes into `buffer`.
//! Returns the value of DecryptionTransport::Finalize(), which throws if the encoded length
//! does not match the number of bytes consumed.
336 | | uint32_t ParquetCrypto::ReadData(TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size, |
337 | 0 | const string &key, const EncryptionUtil &encryption_util_p) { |
338 | | // Create decryption protocol |
339 | 0 | TCompactProtocolFactoryT<DecryptionTransport> tproto_factory; |
340 | 0 | auto dprot = |
341 | 0 | tproto_factory.getProtocol(duckdb_base_std::make_shared<DecryptionTransport>(iprot, key, encryption_util_p)); |
342 | 0 | auto &dtrans = reinterpret_cast<DecryptionTransport &>(*dprot->getTransport()); |
343 | | |
344 | | // Read buffer |
345 | 0 | dtrans.read(buffer, buffer_size); |
346 | | |
347 | | // Verify AES tag and read length |
348 | 0 | return dtrans.Finalize(); |
349 | 0 | } |
350 | | |
//! Encrypts `buffer_size` bytes from `buffer` and writes them to the transport of `oprot`.
//! The data is first buffered inside an EncryptionTransport and encrypted in Finalize(),
//! whose return value is passed through.
351 | | uint32_t ParquetCrypto::WriteData(TProtocol &oprot, const const_data_ptr_t buffer, const uint32_t buffer_size, |
352 | 0 | const string &key, const EncryptionUtil &encryption_util_p) { |
353 | | // FIXME: we know the size upfront so we could do a streaming write instead of this |
354 | | // Create encryption protocol |
355 | 0 | TCompactProtocolFactoryT<EncryptionTransport> tproto_factory; |
356 | 0 | auto eprot = |
357 | 0 | tproto_factory.getProtocol(duckdb_base_std::make_shared<EncryptionTransport>(oprot, key, encryption_util_p)); |
358 | 0 | auto &etrans = reinterpret_cast<EncryptionTransport &>(*eprot->getTransport()); |
359 | | |
360 | | // Write the data in memory |
361 | 0 | etrans.write(buffer, buffer_size); |
362 | | |
363 | | // Encrypt and write to oprot |
364 | 0 | return etrans.Finalize(); |
365 | 0 | } |
366 | | |
//! Returns true if `key` is exactly 16, 24 or 32 bytes long; only the length is checked.
367 | 0 | bool ParquetCrypto::ValidKey(const std::string &key) { |
368 | 0 | switch (key.size()) { |
369 | 0 | case 16: |
370 | 0 | case 24: |
371 | 0 | case 32: |
372 | 0 | return true; |
373 | 0 | default: |
374 | 0 | return false; |
375 | 0 | } |
376 | 0 | } |
377 | | |
//! Decodes the base64 string `key` and returns the decoded bytes as a string.
//! The output size comes from Blob::FromBase64Size(); errors raised by the Blob helpers
//! propagate to the caller (AddKey catches ConversionException around this call).
378 | 0 | static string Base64Decode(const string &key) { |
379 | 0 | auto result_size = Blob::FromBase64Size(key); |
// Temporary buffer that receives the decoded bytes before they are copied into the result string
380 | 0 | auto output = duckdb::unique_ptr<unsigned char[]>(new unsigned char[result_size]); |
381 | 0 | Blob::FromBase64(key, output.get(), result_size); |
382 | 0 | string decoded_key(reinterpret_cast<const char *>(output.get()), result_size); |
383 | 0 | return decoded_key; |
384 | 0 | } |
385 | | |
//! Registers an AES key in the ParquetKeys store of `context`.
//! parameters.values[0] is the key name, parameters.values[1] the key. The key is stored as-is
//! if its length is valid (see ValidKey); otherwise it is base64-decoded and the decoded bytes
//! are stored, provided they have a valid length.
//! Throws InvalidInputException if the key is neither a valid plain key nor valid base64, or if
//! the decoded key has an invalid length.
386 | 0 | void ParquetCrypto::AddKey(ClientContext &context, const FunctionParameters &parameters) { |
387 | 0 | const auto &key_name = StringValue::Get(parameters.values[0]); |
388 | 0 | const auto &key = StringValue::Get(parameters.values[1]); |
389 |

390 | 0 | auto &keys = ParquetKeys::Get(context); |
391 | 0 | if (ValidKey(key)) { |
// Plain key of valid length: store unchanged
392 | 0 | keys.AddKey(key_name, key); |
393 | 0 | } else { |
// Not a valid plain key: try to interpret it as base64
394 | 0 | string decoded_key; |
395 | 0 | try { |
396 | 0 | decoded_key = Base64Decode(key); |
397 | 0 | } catch (const ConversionException &e) { |
398 | 0 | throw InvalidInputException("Invalid AES key. Not a plain AES key NOR a base64 encoded string"); |
399 | 0 | } |
400 | 0 | if (!ValidKey(decoded_key)) { |
401 | 0 | throw InvalidInputException( |
402 | 0 | "Invalid AES key. Must have a length of 128, 192, or 256 bits (16, 24, or 32 bytes)"); |
403 | 0 | } |
404 | 0 | keys.AddKey(key_name, decoded_key); |
405 | 0 | } |
406 | 0 | } |
407 | | |
408 | | } // namespace duckdb |