Coverage Report

Created: 2025-08-28 07:58

/src/duckdb/extension/parquet/parquet_crypto.cpp
Line
Count
Source (jump to first uncovered line)
1
#include "parquet_crypto.hpp"
2
3
#include "mbedtls_wrapper.hpp"
4
#include "thrift_tools.hpp"
5
6
#include "duckdb/common/exception/conversion_exception.hpp"
7
#include "duckdb/common/helper.hpp"
8
#include "duckdb/common/types/blob.hpp"
9
#include "duckdb/storage/arena_allocator.hpp"
10
11
namespace duckdb {
12
13
0
//! Returns the Parquet key store kept in the object cache reachable from `context`,
//! registering an empty store first if none is cached yet.
ParquetKeys &ParquetKeys::Get(ClientContext &context) {
	auto &object_cache = ObjectCache::GetObjectCache(context);
	const auto cache_key = ParquetKeys::ObjectType();
	const bool missing = !object_cache.Get<ParquetKeys>(cache_key);
	if (missing) {
		object_cache.Put(cache_key, make_shared_ptr<ParquetKeys>());
	}
	// Re-read from the cache instead of keeping the freshly created object, so the reference handed out
	// is always the instance the cache actually holds
	return *object_cache.Get<ParquetKeys>(cache_key);
}
20
21
0
void ParquetKeys::AddKey(const string &key_name, const string &key) {
22
0
  keys[key_name] = key;
23
0
}
24
25
0
bool ParquetKeys::HasKey(const string &key_name) const {
26
0
  return keys.find(key_name) != keys.end();
27
0
}
28
29
0
//! Returns the key registered under `key_name`.
//! Callers are expected to check HasKey() first; lookup goes through `at()`.
const string &ParquetKeys::GetKey(const string &key_name) const {
	D_ASSERT(keys.find(key_name) != keys.end());
	const auto &stored = keys.at(key_name);
	return stored;
}
33
34
0
//! Object-cache key under which the Parquet key store is registered.
string ParquetKeys::ObjectType() {
	return string("parquet_keys");
}
37
38
0
//! Instance-level accessor returning the same object-cache key as ObjectType().
string ParquetKeys::GetObjectType() {
	return ParquetKeys::ObjectType();
}
41
42
0
//! Creates an encryption config with every member default-initialized.
ParquetEncryptionConfig::ParquetEncryptionConfig() = default;
44
45
0
ParquetEncryptionConfig::ParquetEncryptionConfig(string footer_key_p) : footer_key(std::move(footer_key_p)) {
46
0
}
47
48
0
//! Builds an encryption config from the `encryption_config` STRUCT option.
//! Recognized fields (matched case-insensitively):
//!   footer_key       - name of a key previously registered via PRAGMA add_parquet_key
//!   footer_key_value - the key bytes themselves (cast to BLOB)
//!   column_keys      - not implemented yet (throws NotImplementedException)
//! Throws BinderException if `arg` is not a STRUCT, on unknown fields, or when a named key does not exist.
ParquetEncryptionConfig::ParquetEncryptionConfig(ClientContext &context, const Value &arg) {
	if (arg.type().id() != LogicalTypeId::STRUCT) {
		throw BinderException("Parquet encryption_config must be of type STRUCT");
	}
	const auto &child_types = StructType::GetChildTypes(arg.type());
	auto &children = StructValue::GetChildren(arg);
	const auto &keys = ParquetKeys::Get(context);
	for (idx_t i = 0; i < StructType::GetChildCount(arg.type()); i++) {
		auto &struct_key = child_types[i].first;
		// Lower-case the field name once instead of once per comparison
		const auto option = StringUtil::Lower(struct_key);
		if (option == "footer_key") {
			// footer key name provided - read the key from the key store fetched above
			const auto footer_key_name = StringValue::Get(children[i].DefaultCastAs(LogicalType::VARCHAR));
			if (!keys.HasKey(footer_key_name)) {
				throw BinderException(
				    "No key with name \"%s\" exists. Add it with PRAGMA add_parquet_key('<key_name>','<key>');",
				    footer_key_name);
			}
			footer_key = keys.GetKey(footer_key_name);
		} else if (option == "footer_key_value") {
			// Key bytes supplied inline
			footer_key = StringValue::Get(children[i].DefaultCastAs(LogicalType::BLOB));
		} else if (option == "column_keys") {
			throw NotImplementedException("Parquet encryption_config column_keys not yet implemented");
		} else {
			throw BinderException("Unknown key in encryption_config \"%s\"", struct_key);
		}
	}
}
76
77
0
//! Factory: parses `arg` (see the STRUCT constructor) into a shared encryption config.
//! Uses plain `new` rather than a make_shared helper — presumably because the constructor is not public.
shared_ptr<ParquetEncryptionConfig> ParquetEncryptionConfig::Create(ClientContext &context, const Value &arg) {
	shared_ptr<ParquetEncryptionConfig> config(new ParquetEncryptionConfig(context, arg));
	return config;
}
80
81
0
//! Returns the raw footer key bytes held by this config.
const string &ParquetEncryptionConfig::GetFooterKey() const {
	return this->footer_key;
}
84
85
using duckdb_apache::thrift::protocol::TCompactProtocolFactoryT;
86
using duckdb_apache::thrift::transport::TTransport;
87
88
//! Encryption wrapper for a transport protocol
89
class EncryptionTransport : public TTransport {
90
public:
91
  EncryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p)
92
0
      : prot(prot_p), trans(*prot.getTransport()),
93
0
        aes(encryption_util_p.CreateEncryptionState(reinterpret_cast<const_data_ptr_t>(key.data()), key.size())),
94
0
        allocator(Allocator::DefaultAllocator(), ParquetCrypto::CRYPTO_BLOCK_SIZE) {
95
0
    Initialize(key);
96
0
  }
97
98
0
  bool isOpen() const override {
99
0
    return trans.isOpen();
100
0
  }
101
102
0
  void open() override {
103
0
    trans.open();
104
0
  }
105
106
0
  void close() override {
107
0
    trans.close();
108
0
  }
109
110
0
  void write_virt(const uint8_t *buf, uint32_t len) override {
111
0
    memcpy(allocator.Allocate(len), buf, len);
112
0
  }
113
114
0
  uint32_t Finalize() {
115
    // Write length
116
0
    const auto ciphertext_length = allocator.SizeInBytes();
117
0
    const uint32_t total_length = ParquetCrypto::NONCE_BYTES + ciphertext_length + ParquetCrypto::TAG_BYTES;
118
119
0
    trans.write(const_data_ptr_cast(&total_length), ParquetCrypto::LENGTH_BYTES);
120
    // Write nonce at beginning of encrypted chunk
121
0
    trans.write(nonce, ParquetCrypto::NONCE_BYTES);
122
123
0
    data_t aes_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE];
124
0
    auto current = allocator.GetTail();
125
126
    // Loop through the whole chunk
127
0
    while (current != nullptr) {
128
0
      for (idx_t pos = 0; pos < current->current_position; pos += ParquetCrypto::CRYPTO_BLOCK_SIZE) {
129
0
        auto next = MinValue<idx_t>(current->current_position - pos, ParquetCrypto::CRYPTO_BLOCK_SIZE);
130
0
        auto write_size =
131
0
            aes->Process(current->data.get() + pos, next, aes_buffer, ParquetCrypto::CRYPTO_BLOCK_SIZE);
132
0
        trans.write(aes_buffer, write_size);
133
0
      }
134
0
      current = current->prev;
135
0
    }
136
137
    // Finalize the last encrypted data
138
0
    data_t tag[ParquetCrypto::TAG_BYTES];
139
0
    auto write_size = aes->Finalize(aes_buffer, 0, tag, ParquetCrypto::TAG_BYTES);
140
0
    trans.write(aes_buffer, write_size);
141
    // Write tag for verification
142
0
    trans.write(tag, ParquetCrypto::TAG_BYTES);
143
144
0
    return ParquetCrypto::LENGTH_BYTES + total_length;
145
0
  }
146
147
private:
148
0
  void Initialize(const string &key) {
149
    // Generate Nonce
150
0
    aes->GenerateRandomData(nonce, ParquetCrypto::NONCE_BYTES);
151
    // Initialize Encryption
152
0
    aes->InitializeEncryption(nonce, ParquetCrypto::NONCE_BYTES, reinterpret_cast<const_data_ptr_t>(key.data()),
153
0
                              key.size());
154
0
  }
155
156
private:
157
  //! Protocol and corresponding transport that we're wrapping
158
  TProtocol &prot;
159
  TTransport &trans;
160
161
  //! AES context and buffers
162
  shared_ptr<EncryptionState> aes;
163
164
  //! Nonce created by Initialize()
165
  data_t nonce[ParquetCrypto::NONCE_BYTES];
166
167
  //! Arena Allocator to fully materialize in memory before encrypting
168
  ArenaAllocator allocator;
169
};
170
171
//! Decryption wrapper for a transport protocol
172
class DecryptionTransport : public TTransport {
173
public:
174
  DecryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p)
175
0
      : prot(prot_p), trans(*prot.getTransport()),
176
0
        aes(encryption_util_p.CreateEncryptionState(reinterpret_cast<const_data_ptr_t>(key.data()), key.size())),
177
0
        read_buffer_size(0), read_buffer_offset(0) {
178
0
    Initialize(key);
179
0
  }
180
0
  uint32_t read_virt(uint8_t *buf, uint32_t len) override {
181
0
    const uint32_t result = len;
182
183
0
    if (len > transport_remaining - ParquetCrypto::TAG_BYTES + read_buffer_size - read_buffer_offset) {
184
0
      throw InvalidInputException("Too many bytes requested from crypto buffer");
185
0
    }
186
187
0
    while (len != 0) {
188
0
      if (read_buffer_offset == read_buffer_size) {
189
0
        ReadBlock(buf);
190
0
      }
191
0
      const auto next = MinValue(read_buffer_size - read_buffer_offset, len);
192
0
      read_buffer_offset += next;
193
0
      buf += next;
194
0
      len -= next;
195
0
    }
196
197
0
    return result;
198
0
  }
199
200
0
  uint32_t Finalize() {
201
202
0
    if (read_buffer_offset != read_buffer_size) {
203
0
      throw InternalException("DecryptionTransport::Finalize was called with bytes remaining in read buffer: \n"
204
0
                              "read buffer offset: %d, read buffer size: %d",
205
0
                              read_buffer_offset, read_buffer_size);
206
0
    }
207
208
0
    data_t computed_tag[ParquetCrypto::TAG_BYTES];
209
0
    transport_remaining -= trans.read(computed_tag, ParquetCrypto::TAG_BYTES);
210
0
    if (aes->Finalize(read_buffer, 0, computed_tag, ParquetCrypto::TAG_BYTES) != 0) {
211
0
      throw InternalException("DecryptionTransport::Finalize was called with bytes remaining in AES context out");
212
0
    }
213
214
0
    if (transport_remaining != 0) {
215
0
      throw InvalidInputException("Encoded ciphertext length differs from actual ciphertext length");
216
0
    }
217
218
0
    return ParquetCrypto::LENGTH_BYTES + total_bytes;
219
0
  }
220
221
0
  AllocatedData ReadAll() {
222
0
    D_ASSERT(transport_remaining == total_bytes - ParquetCrypto::NONCE_BYTES);
223
0
    auto result = Allocator::DefaultAllocator().Allocate(transport_remaining - ParquetCrypto::TAG_BYTES);
224
0
    read_virt(result.get(), transport_remaining - ParquetCrypto::TAG_BYTES);
225
0
    Finalize();
226
0
    return result;
227
0
  }
228
229
private:
230
0
  void Initialize(const string &key) {
231
    // Read encoded length (don't add to read_bytes)
232
0
    data_t length_buf[ParquetCrypto::LENGTH_BYTES];
233
0
    trans.read(length_buf, ParquetCrypto::LENGTH_BYTES);
234
0
    total_bytes = Load<uint32_t>(length_buf);
235
0
    transport_remaining = total_bytes;
236
    // Read nonce and initialize AES
237
0
    transport_remaining -= trans.read(nonce, ParquetCrypto::NONCE_BYTES);
238
    // check whether context is initialized
239
0
    aes->InitializeDecryption(nonce, ParquetCrypto::NONCE_BYTES, reinterpret_cast<const_data_ptr_t>(key.data()),
240
0
                              key.size());
241
0
  }
242
243
0
  void ReadBlock(uint8_t *buf) {
244
    // Read from transport into read_buffer at one AES block size offset (up to the tag)
245
0
    read_buffer_size = MinValue(ParquetCrypto::CRYPTO_BLOCK_SIZE, transport_remaining - ParquetCrypto::TAG_BYTES);
246
0
    transport_remaining -= trans.read(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size);
247
248
    // Decrypt from read_buffer + block size into read_buffer start (decryption can trail behind in same buffer)
249
#ifdef DEBUG
250
    auto size = aes->Process(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size, buf,
251
                             ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE);
252
    D_ASSERT(size == read_buffer_size);
253
#else
254
0
    aes->Process(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size, buf,
255
0
                 ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE);
256
0
#endif
257
0
    read_buffer_offset = 0;
258
0
  }
259
260
private:
261
  //! Protocol and corresponding transport that we're wrapping
262
  TProtocol &prot;
263
  TTransport &trans;
264
265
  //! AES context and buffers
266
  shared_ptr<EncryptionState> aes;
267
268
  //! We read/decrypt big blocks at a time
269
  data_t read_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE];
270
  uint32_t read_buffer_size;
271
  uint32_t read_buffer_offset;
272
273
  //! Remaining bytes to read, set by Initialize(), decremented by ReadBlock()
274
  uint32_t total_bytes;
275
  uint32_t transport_remaining;
276
  //! Nonce read by Initialize()
277
  data_t nonce[ParquetCrypto::NONCE_BYTES];
278
};
279
280
class SimpleReadTransport : public TTransport {
281
public:
282
  explicit SimpleReadTransport(data_ptr_t read_buffer_p, uint32_t read_buffer_size_p)
283
0
      : read_buffer(read_buffer_p), read_buffer_size(read_buffer_size_p), read_buffer_offset(0) {
284
0
  }
285
286
0
  uint32_t read_virt(uint8_t *buf, uint32_t len) override {
287
0
    const auto remaining = read_buffer_size - read_buffer_offset;
288
0
    if (len > remaining) {
289
0
      return remaining;
290
0
    }
291
0
    memcpy(buf, read_buffer + read_buffer_offset, len);
292
0
    read_buffer_offset += len;
293
0
    return len;
294
0
  }
295
296
private:
297
  const data_ptr_t read_buffer;
298
  const uint32_t read_buffer_size;
299
  uint32_t read_buffer_offset;
300
};
301
302
//! Reads one encrypted thrift module from `iprot`, decrypts it with `key` and deserializes it into `object`.
//! Returns the number of bytes consumed from `iprot` (length + nonce + ciphertext + tag).
uint32_t ParquetCrypto::Read(TBase &object, TProtocol &iprot, const string &key,
                             const EncryptionUtil &encryption_util_p) {
	// Keep the typed transport directly: no protocol is needed to read raw bytes,
	// and no downcast of getTransport() is required
	auto dtrans = duckdb_base_std::make_shared<DecryptionTransport>(iprot, key, encryption_util_p);

	// Decrypt (and verify) the whole module first: deserializing while decrypting would let thrift throw
	// on garbage before the tag check reports that decryption went wrong
	auto all = dtrans->ReadAll();

	// Deserialize the object from the verified plaintext
	TCompactProtocolFactoryT<SimpleReadTransport> tsimple_proto_factory;
	auto simple_prot =
	    tsimple_proto_factory.getProtocol(duckdb_base_std::make_shared<SimpleReadTransport>(all.get(), all.GetSize()));
	object.read(simple_prot.get());

	return ParquetCrypto::LENGTH_BYTES + ParquetCrypto::NONCE_BYTES + all.GetSize() + ParquetCrypto::TAG_BYTES;
}
320
321
uint32_t ParquetCrypto::Write(const TBase &object, TProtocol &oprot, const string &key,
322
0
                              const EncryptionUtil &encryption_util_p) {
323
  // Create encryption protocol
324
0
  TCompactProtocolFactoryT<EncryptionTransport> tproto_factory;
325
0
  auto eprot =
326
0
      tproto_factory.getProtocol(duckdb_base_std::make_shared<EncryptionTransport>(oprot, key, encryption_util_p));
327
0
  auto &etrans = reinterpret_cast<EncryptionTransport &>(*eprot->getTransport());
328
329
  // Write the object in memory
330
0
  object.write(eprot.get());
331
332
  // Encrypt and write to oprot
333
0
  return etrans.Finalize();
334
0
}
335
336
//! Reads one encrypted module from `iprot` and decrypts exactly `buffer_size` plaintext bytes into `buffer`,
//! then verifies the tag. Returns the number of bytes consumed from `iprot`.
uint32_t ParquetCrypto::ReadData(TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size,
                                 const string &key, const EncryptionUtil &encryption_util_p) {
	// Raw bytes need no protocol: use the typed transport directly (no reinterpret_cast downcast)
	auto dtrans = duckdb_base_std::make_shared<DecryptionTransport>(iprot, key, encryption_util_p);

	// Read buffer
	dtrans->read(buffer, buffer_size);

	// Verify AES tag and read length
	return dtrans->Finalize();
}
350
351
//! Encrypts `buffer_size` bytes from `buffer` with `key` and writes the resulting module to `oprot`.
//! Returns the number of bytes written to `oprot`.
uint32_t ParquetCrypto::WriteData(TProtocol &oprot, const const_data_ptr_t buffer, const uint32_t buffer_size,
                                  const string &key, const EncryptionUtil &encryption_util_p) {
	// FIXME: we know the size upfront so we could do a streaming write instead of this
	// Raw bytes need no protocol: use the typed transport directly (no reinterpret_cast downcast)
	auto etrans = duckdb_base_std::make_shared<EncryptionTransport>(oprot, key, encryption_util_p);

	// Write the data in memory
	etrans->write(buffer, buffer_size);

	// Encrypt and write to oprot
	return etrans->Finalize();
}
366
367
0
//! Returns true when `key` has a raw AES key length: 16, 24 or 32 bytes (128, 192 or 256 bits).
bool ParquetCrypto::ValidKey(const std::string &key) {
	const auto key_bytes = key.size();
	return key_bytes == 16 || key_bytes == 24 || key_bytes == 32;
}
377
378
0
//! Decodes a base64 string into raw bytes.
//! Errors from Blob::FromBase64Size / FromBase64 propagate; the caller (AddKey) catches ConversionException.
static string Base64Decode(const string &key) {
	const auto result_size = Blob::FromBase64Size(key);
	// Decode straight into the result: no raw new[] and no extra copy out of a temporary buffer
	string decoded_key(result_size, '\0');
	Blob::FromBase64(key, reinterpret_cast<data_ptr_t>(&decoded_key[0]), result_size);
	return decoded_key;
}
385
386
0
//! PRAGMA add_parquet_key(key_name, key): registers `key` under `key_name` in the context's key store.
//! `key` is used as-is when it already has a raw AES key length; otherwise it is base64-decoded and the
//! decoded bytes must have a valid AES key length. Throws InvalidInputException otherwise.
void ParquetCrypto::AddKey(ClientContext &context, const FunctionParameters &parameters) {
	const auto &key_name = StringValue::Get(parameters.values[0]);
	const auto &key = StringValue::Get(parameters.values[1]);

	auto &keys = ParquetKeys::Get(context);
	// NOTE(review): the raw-length check wins, so a base64-encoded 128-bit key (24 characters) is stored
	// as a raw 192-bit key instead of being decoded — confirm this precedence is intended.
	if (ValidKey(key)) {
		keys.AddKey(key_name, key);
		return;
	}

	string decoded_key;
	try {
		decoded_key = Base64Decode(key);
	} catch (const ConversionException &) {
		throw InvalidInputException("Invalid AES key. Not a plain AES key NOR a base64 encoded string");
	}
	if (!ValidKey(decoded_key)) {
		throw InvalidInputException(
		    "Invalid AES key. Must have a length of 128, 192, or 256 bits (16, 24, or 32 bytes)");
	}
	keys.AddKey(key_name, decoded_key);
}
407
408
} // namespace duckdb