Coverage Report

Created: 2025-11-15 07:36

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/duckdb/extension/parquet/parquet_crypto.cpp
Line
Count
Source
1
#include "parquet_crypto.hpp"
2
3
#include "mbedtls_wrapper.hpp"
4
#include "thrift_tools.hpp"
5
6
#include "duckdb/common/exception/conversion_exception.hpp"
7
#include "duckdb/common/helper.hpp"
8
#include "duckdb/common/types/blob.hpp"
9
#include "duckdb/storage/arena_allocator.hpp"
10
11
namespace duckdb {
12
13
0
//! Returns the ParquetKeys entry stored in the object cache of the given context,
//! inserting a fresh (empty) entry first if the cache does not hold one yet.
ParquetKeys &ParquetKeys::Get(ClientContext &context) {
  auto &object_cache = ObjectCache::GetObjectCache(context);
  const auto cache_key = ParquetKeys::ObjectType();
  const bool missing = !object_cache.Get<ParquetKeys>(cache_key);
  if (missing) {
    object_cache.Put(cache_key, make_shared_ptr<ParquetKeys>());
  }
  // Always hand out the instance that lives in the cache
  return *object_cache.Get<ParquetKeys>(cache_key);
}
20
21
0
void ParquetKeys::AddKey(const string &key_name, const string &key) {
22
0
  keys[key_name] = key;
23
0
}
24
25
0
bool ParquetKeys::HasKey(const string &key_name) const {
26
0
  return keys.find(key_name) != keys.end();
27
0
}
28
29
0
//! Returns the key registered under "key_name".
//! The name is expected to exist (asserted via D_ASSERT); callers should check HasKey first.
const string &ParquetKeys::GetKey(const string &key_name) const {
  D_ASSERT(HasKey(key_name));
  const auto &stored_key = keys.at(key_name);
  return stored_key;
}
33
34
0
//! Name under which the ParquetKeys entry is stored in the object cache
string ParquetKeys::ObjectType() {
  return string("parquet_keys");
}
37
38
0
//! Instance-level accessor for the object type name; forwards to ObjectType()
string ParquetKeys::GetObjectType() {
  return ParquetKeys::ObjectType();
}
41
42
0
//! Creates a config without a footer key
ParquetEncryptionConfig::ParquetEncryptionConfig() = default;
44
45
0
//! Creates a config that uses the given footer key (taken by value and moved into place)
ParquetEncryptionConfig::ParquetEncryptionConfig(string footer_key_p)
    : footer_key(std::move(footer_key_p)) {
}
47
48
0
//! Builds the encryption config from a STRUCT value.
//! Recognised (case-insensitive) struct fields:
//!   - footer_key:       name of a key previously registered in ParquetKeys for this context
//!   - footer_key_value: the key itself, cast to BLOB
//!   - column_keys:      not implemented
//! Throws BinderException if "arg" is not a STRUCT, if a named key is not registered, or if a field
//! is unknown; throws NotImplementedException for column_keys.
ParquetEncryptionConfig::ParquetEncryptionConfig(ClientContext &context, const Value &arg) {
  if (arg.type().id() != LogicalTypeId::STRUCT) {
    throw BinderException("Parquet encryption_config must be of type STRUCT");
  }
  const auto &child_types = StructType::GetChildTypes(arg.type());
  auto &children = StructValue::GetChildren(arg);
  // Look up the key store once; it is only consulted for the "footer_key" field
  const auto &keys = ParquetKeys::Get(context);
  const auto child_count = StructType::GetChildCount(arg.type());
  for (idx_t i = 0; i < child_count; i++) {
    auto &struct_key = child_types[i].first;
    // Field names are matched case-insensitively; lower-case once per field
    const auto option = StringUtil::Lower(struct_key);
    if (option == "footer_key") {
      // footer key name provided - read the key from the key store
      const auto footer_key_name = StringValue::Get(children[i].DefaultCastAs(LogicalType::VARCHAR));
      if (!keys.HasKey(footer_key_name)) {
        throw BinderException(
            "No key with name \"%s\" exists. Add it with PRAGMA add_parquet_key('<key_name>','<key>');",
            footer_key_name);
      }
      footer_key = keys.GetKey(footer_key_name);
    } else if (option == "footer_key_value") {
      // footer key provided inline
      footer_key = StringValue::Get(children[i].DefaultCastAs(LogicalType::BLOB));
    } else if (option == "column_keys") {
      throw NotImplementedException("Parquet encryption_config column_keys not yet implemented");
    } else {
      throw BinderException("Unknown key in encryption_config \"%s\"", struct_key);
    }
  }
}
76
77
0
//! Factory: builds a shared ParquetEncryptionConfig from a STRUCT value (see the matching constructor)
shared_ptr<ParquetEncryptionConfig> ParquetEncryptionConfig::Create(ClientContext &context, const Value &arg) {
  shared_ptr<ParquetEncryptionConfig> config(new ParquetEncryptionConfig(context, arg));
  return config;
}
80
81
0
//! Returns the footer key held by this config
const string &ParquetEncryptionConfig::GetFooterKey() const {
  return this->footer_key;
}
84
85
using duckdb_apache::thrift::protocol::TCompactProtocolFactoryT;
86
using duckdb_apache::thrift::transport::TTransport;
87
88
//! Encryption wrapper for a transport protocol
89
class EncryptionTransport : public TTransport {
90
public:
91
  EncryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p)
92
0
      : prot(prot_p), trans(*prot.getTransport()),
93
0
        aes(encryption_util_p.CreateEncryptionState(EncryptionTypes::GCM, key.size())),
94
0
        allocator(Allocator::DefaultAllocator(), ParquetCrypto::CRYPTO_BLOCK_SIZE) {
95
0
    Initialize(key);
96
0
  }
97
98
0
  bool isOpen() const override {
99
0
    return trans.isOpen();
100
0
  }
101
102
0
  void open() override {
103
0
    trans.open();
104
0
  }
105
106
0
  void close() override {
107
0
    trans.close();
108
0
  }
109
110
0
  void write_virt(const uint8_t *buf, uint32_t len) override {
111
0
    memcpy(allocator.Allocate(len), buf, len);
112
0
  }
113
114
0
  uint32_t Finalize() {
115
    // Write length
116
0
    const auto ciphertext_length = allocator.SizeInBytes();
117
0
    const uint32_t total_length = ParquetCrypto::NONCE_BYTES + ciphertext_length + ParquetCrypto::TAG_BYTES;
118
119
0
    trans.write(const_data_ptr_cast(&total_length), ParquetCrypto::LENGTH_BYTES);
120
    // Write nonce at beginning of encrypted chunk
121
0
    trans.write(nonce, ParquetCrypto::NONCE_BYTES);
122
123
0
    data_t aes_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE];
124
0
    auto current = allocator.GetTail();
125
126
    // Loop through the whole chunk
127
0
    while (current != nullptr) {
128
0
      for (idx_t pos = 0; pos < current->current_position; pos += ParquetCrypto::CRYPTO_BLOCK_SIZE) {
129
0
        auto next = MinValue<idx_t>(current->current_position - pos, ParquetCrypto::CRYPTO_BLOCK_SIZE);
130
0
        auto write_size =
131
0
            aes->Process(current->data.get() + pos, next, aes_buffer, ParquetCrypto::CRYPTO_BLOCK_SIZE);
132
0
        trans.write(aes_buffer, write_size);
133
0
      }
134
0
      current = current->prev;
135
0
    }
136
137
    // Finalize the last encrypted data
138
0
    data_t tag[ParquetCrypto::TAG_BYTES];
139
0
    auto write_size = aes->Finalize(aes_buffer, 0, tag, ParquetCrypto::TAG_BYTES);
140
0
    trans.write(aes_buffer, write_size);
141
    // Write tag for verification
142
0
    trans.write(tag, ParquetCrypto::TAG_BYTES);
143
144
0
    return ParquetCrypto::LENGTH_BYTES + total_length;
145
0
  }
146
147
private:
148
0
  void Initialize(const string &key) {
149
    // Generate Nonce
150
0
    aes->GenerateRandomData(nonce, ParquetCrypto::NONCE_BYTES);
151
    // Initialize Encryption
152
0
    aes->InitializeEncryption(nonce, ParquetCrypto::NONCE_BYTES, reinterpret_cast<const_data_ptr_t>(key.data()),
153
0
                              key.size());
154
0
  }
155
156
private:
157
  //! Protocol and corresponding transport that we're wrapping
158
  TProtocol &prot;
159
  TTransport &trans;
160
161
  //! AES context and buffers
162
  shared_ptr<EncryptionState> aes;
163
164
  //! Nonce created by Initialize()
165
  data_t nonce[ParquetCrypto::NONCE_BYTES];
166
167
  //! Arena Allocator to fully materialize in memory before encrypting
168
  ArenaAllocator allocator;
169
};
170
171
//! Decryption wrapper for a transport protocol
172
class DecryptionTransport : public TTransport {
173
public:
174
  DecryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p)
175
0
      : prot(prot_p), trans(*prot.getTransport()),
176
0
        aes(encryption_util_p.CreateEncryptionState(EncryptionTypes::GCM, key.size())), read_buffer_size(0),
177
0
        read_buffer_offset(0) {
178
0
    Initialize(key);
179
0
  }
180
0
  uint32_t read_virt(uint8_t *buf, uint32_t len) override {
181
0
    const uint32_t result = len;
182
183
0
    if (len > transport_remaining - ParquetCrypto::TAG_BYTES + read_buffer_size - read_buffer_offset) {
184
0
      throw InvalidInputException("Too many bytes requested from crypto buffer");
185
0
    }
186
187
0
    while (len != 0) {
188
0
      if (read_buffer_offset == read_buffer_size) {
189
0
        ReadBlock(buf);
190
0
      }
191
0
      const auto next = MinValue(read_buffer_size - read_buffer_offset, len);
192
0
      read_buffer_offset += next;
193
0
      buf += next;
194
0
      len -= next;
195
0
    }
196
197
0
    return result;
198
0
  }
199
200
0
  uint32_t Finalize() {
201
0
    if (read_buffer_offset != read_buffer_size) {
202
0
      throw InternalException("DecryptionTransport::Finalize was called with bytes remaining in read buffer: \n"
203
0
                              "read buffer offset: %d, read buffer size: %d",
204
0
                              read_buffer_offset, read_buffer_size);
205
0
    }
206
207
0
    data_t computed_tag[ParquetCrypto::TAG_BYTES];
208
0
    transport_remaining -= trans.read(computed_tag, ParquetCrypto::TAG_BYTES);
209
0
    aes->Finalize(read_buffer, 0, computed_tag, ParquetCrypto::TAG_BYTES);
210
211
0
    if (transport_remaining != 0) {
212
0
      throw InvalidInputException("Encoded ciphertext length differs from actual ciphertext length");
213
0
    }
214
215
0
    return ParquetCrypto::LENGTH_BYTES + total_bytes;
216
0
  }
217
218
0
  AllocatedData ReadAll() {
219
0
    D_ASSERT(transport_remaining == total_bytes - ParquetCrypto::NONCE_BYTES);
220
0
    auto result = Allocator::DefaultAllocator().Allocate(transport_remaining - ParquetCrypto::TAG_BYTES);
221
0
    read_virt(result.get(), transport_remaining - ParquetCrypto::TAG_BYTES);
222
0
    Finalize();
223
0
    return result;
224
0
  }
225
226
private:
227
0
  void Initialize(const string &key) {
228
    // Read encoded length (don't add to read_bytes)
229
0
    data_t length_buf[ParquetCrypto::LENGTH_BYTES];
230
0
    trans.read(length_buf, ParquetCrypto::LENGTH_BYTES);
231
0
    total_bytes = Load<uint32_t>(length_buf);
232
0
    transport_remaining = total_bytes;
233
    // Read nonce and initialize AES
234
0
    transport_remaining -= trans.read(nonce, ParquetCrypto::NONCE_BYTES);
235
    // check whether context is initialized
236
0
    aes->InitializeDecryption(nonce, ParquetCrypto::NONCE_BYTES, reinterpret_cast<const_data_ptr_t>(key.data()),
237
0
                              key.size());
238
0
  }
239
240
0
  void ReadBlock(uint8_t *buf) {
241
    // Read from transport into read_buffer at one AES block size offset (up to the tag)
242
0
    read_buffer_size = MinValue(ParquetCrypto::CRYPTO_BLOCK_SIZE, transport_remaining - ParquetCrypto::TAG_BYTES);
243
0
    transport_remaining -= trans.read(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size);
244
245
    // Decrypt from read_buffer + block size into read_buffer start (decryption can trail behind in same buffer)
246
#ifdef DEBUG
247
    auto size = aes->Process(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size, buf,
248
                             ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE);
249
    D_ASSERT(size == read_buffer_size);
250
#else
251
0
    aes->Process(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size, buf,
252
0
                 ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE);
253
0
#endif
254
0
    read_buffer_offset = 0;
255
0
  }
256
257
private:
258
  //! Protocol and corresponding transport that we're wrapping
259
  TProtocol &prot;
260
  TTransport &trans;
261
262
  //! AES context and buffers
263
  shared_ptr<EncryptionState> aes;
264
265
  //! We read/decrypt big blocks at a time
266
  data_t read_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE];
267
  uint32_t read_buffer_size;
268
  uint32_t read_buffer_offset;
269
270
  //! Remaining bytes to read, set by Initialize(), decremented by ReadBlock()
271
  uint32_t total_bytes;
272
  uint32_t transport_remaining;
273
  //! Nonce read by Initialize()
274
  data_t nonce[ParquetCrypto::NONCE_BYTES];
275
};
276
277
class SimpleReadTransport : public TTransport {
278
public:
279
  explicit SimpleReadTransport(data_ptr_t read_buffer_p, uint32_t read_buffer_size_p)
280
0
      : read_buffer(read_buffer_p), read_buffer_size(read_buffer_size_p), read_buffer_offset(0) {
281
0
  }
282
283
0
  uint32_t read_virt(uint8_t *buf, uint32_t len) override {
284
0
    const auto remaining = read_buffer_size - read_buffer_offset;
285
0
    if (len > remaining) {
286
0
      return remaining;
287
0
    }
288
0
    memcpy(buf, read_buffer + read_buffer_offset, len);
289
0
    read_buffer_offset += len;
290
0
    return len;
291
0
  }
292
293
private:
294
  const data_ptr_t read_buffer;
295
  const uint32_t read_buffer_size;
296
  uint32_t read_buffer_offset;
297
};
298
299
//! Decrypts a thrift object from "iprot" with "key" and deserializes it into "object".
//! Returns the number of bytes the encrypted object occupies: length prefix, nonce, ciphertext and tag.
uint32_t ParquetCrypto::Read(TBase &object, TProtocol &iprot, const string &key,
                             const EncryptionUtil &encryption_util_p) {
  // Create decryption protocol
  TCompactProtocolFactoryT<DecryptionTransport> decrypt_factory;
  auto decrypt_prot =
      decrypt_factory.getProtocol(duckdb_base_std::make_shared<DecryptionTransport>(iprot, key, encryption_util_p));
  auto &decrypt_trans = reinterpret_cast<DecryptionTransport &>(*decrypt_prot->getTransport());

  // Decrypt everything up front: otherwise thrift throws an error before we realize that the
  // decryption is wrong
  auto plaintext = decrypt_trans.ReadAll();

  // Deserialize the object from the decrypted bytes
  TCompactProtocolFactoryT<SimpleReadTransport> plain_factory;
  auto plain_prot = plain_factory.getProtocol(
      duckdb_base_std::make_shared<SimpleReadTransport>(plaintext.get(), plaintext.GetSize()));
  object.read(plain_prot.get());

  return ParquetCrypto::LENGTH_BYTES + ParquetCrypto::NONCE_BYTES + plaintext.GetSize() +
         ParquetCrypto::TAG_BYTES;
}
317
318
uint32_t ParquetCrypto::Write(const TBase &object, TProtocol &oprot, const string &key,
319
0
                              const EncryptionUtil &encryption_util_p) {
320
  // Create encryption protocol
321
0
  TCompactProtocolFactoryT<EncryptionTransport> tproto_factory;
322
0
  auto eprot =
323
0
      tproto_factory.getProtocol(duckdb_base_std::make_shared<EncryptionTransport>(oprot, key, encryption_util_p));
324
0
  auto &etrans = reinterpret_cast<EncryptionTransport &>(*eprot->getTransport());
325
326
  // Write the object in memory
327
0
  object.write(eprot.get());
328
329
  // Encrypt and write to oprot
330
0
  return etrans.Finalize();
331
0
}
332
333
//! Decrypts "buffer_size" bytes from "iprot" into "buffer" using "key".
//! Returns the value reported by DecryptionTransport::Finalize().
uint32_t ParquetCrypto::ReadData(TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size,
                                 const string &key, const EncryptionUtil &encryption_util_p) {
  // Set up a protocol on top of a decrypting transport that wraps iprot
  TCompactProtocolFactoryT<DecryptionTransport> decrypt_factory;
  auto decrypt_prot =
      decrypt_factory.getProtocol(duckdb_base_std::make_shared<DecryptionTransport>(iprot, key, encryption_util_p));
  auto &decrypt_trans = reinterpret_cast<DecryptionTransport &>(*decrypt_prot->getTransport());

  // Fill the caller's buffer with decrypted bytes
  decrypt_trans.read(buffer, buffer_size);

  // Consume the tag, finalize the decryption and report the bytes consumed
  const auto bytes_read = decrypt_trans.Finalize();
  return bytes_read;
}
347
348
//! Encrypts "buffer_size" bytes from "buffer" with "key" and writes the result to "oprot".
//! Returns the value reported by EncryptionTransport::Finalize().
uint32_t ParquetCrypto::WriteData(TProtocol &oprot, const const_data_ptr_t buffer, const uint32_t buffer_size,
                                  const string &key, const EncryptionUtil &encryption_util_p) {
  // FIXME: we know the size upfront so we could do a streaming write instead of this
  // Set up a protocol on top of an encrypting transport that wraps oprot
  TCompactProtocolFactoryT<EncryptionTransport> encrypt_factory;
  auto encrypt_prot =
      encrypt_factory.getProtocol(duckdb_base_std::make_shared<EncryptionTransport>(oprot, key, encryption_util_p));
  auto &encrypt_trans = reinterpret_cast<EncryptionTransport &>(*encrypt_prot->getTransport());

  // Buffer the plaintext in memory
  encrypt_trans.write(buffer, buffer_size);

  // Encrypt the buffered bytes and flush them to oprot
  const auto bytes_written = encrypt_trans.Finalize();
  return bytes_written;
}
363
364
0
//! Returns true if "key" has a valid AES key length: 16, 24 or 32 bytes
bool ParquetCrypto::ValidKey(const std::string &key) {
  const auto key_bytes = key.size();
  return key_bytes == 16 || key_bytes == 24 || key_bytes == 32;
}
374
375
0
//! Decodes a base64 string into its raw bytes using Blob::FromBase64
static string Base64Decode(const string &key) {
  const auto decoded_size = Blob::FromBase64Size(key);
  // Decode directly into the string that is returned
  string decoded_key(decoded_size, '\0');
  Blob::FromBase64(key, reinterpret_cast<unsigned char *>(&decoded_key[0]), decoded_size);
  return decoded_key;
}
382
383
0
//! Registers an AES key under a name for the given context.
//! parameters.values[0] is the key name, parameters.values[1] the key: either the raw key
//! (16, 24 or 32 bytes) or a base64 encoding of such a key.
//! Throws InvalidInputException if the key is neither.
void ParquetCrypto::AddKey(ClientContext &context, const FunctionParameters &parameters) {
  const auto &name = StringValue::Get(parameters.values[0]);
  const auto &raw_key = StringValue::Get(parameters.values[1]);
  auto &key_store = ParquetKeys::Get(context);

  // A value with a valid AES key length is stored as-is
  if (ValidKey(raw_key)) {
    key_store.AddKey(name, raw_key);
    return;
  }

  // Otherwise it has to be base64 that decodes to a valid AES key
  string decoded;
  try {
    decoded = Base64Decode(raw_key);
  } catch (const ConversionException &) {
    throw InvalidInputException("Invalid AES key. Not a plain AES key NOR a base64 encoded string");
  }
  if (!ValidKey(decoded)) {
    throw InvalidInputException(
        "Invalid AES key. Must have a length of 128, 192, or 256 bits (16, 24, or 32 bytes)");
  }
  key_store.AddKey(name, decoded);
}
404
405
} // namespace duckdb