/src/duckdb/third_party/mbedtls/mbedtls_wrapper.cpp
Line | Count | Source |
1 | | #include "mbedtls_wrapper.hpp" |
2 | | |
3 | | // otherwise we have different definitions for mbedtls_pk_context / mbedtls_sha256_context |
4 | | #define MBEDTLS_ALLOW_PRIVATE_ACCESS |
5 | | |
6 | | #include "duckdb/common/helper.hpp" |
7 | | #include "mbedtls/md.h" |
8 | | #include "mbedtls/pk.h" |
9 | | #include "mbedtls/sha1.h" |
10 | | #include "mbedtls/sha256.h" |
11 | | #include "mbedtls/cipher.h" |
12 | | |
13 | | #include "duckdb/common/random_engine.hpp" |
14 | | #include "duckdb/common/types/timestamp.hpp" |
15 | | #include "duckdb/main/config.hpp" |
16 | | #include "duckdb/common/encryption_types.hpp" |
17 | | |
18 | | #include <stdexcept> |
19 | | |
20 | | using namespace duckdb_mbedtls; |
21 | | using CipherType = duckdb::EncryptionTypes::CipherType; |
22 | | using EncryptionVersion = duckdb::EncryptionTypes::EncryptionVersion; |
23 | | using MainHeader = duckdb::MainHeader; |
24 | | |
25 | | /* |
26 | | # Command line tricks to help here |
27 | | # Create a new key |
28 | | openssl genrsa -out private.pem 2048 |
29 | | |
30 | | # Export public key |
31 | | openssl rsa -in private.pem -outform PEM -pubout -out public.pem |
32 | | |
33 | | # Calculate digest and write to 'hash' file on command line |
34 | | openssl dgst -binary -sha256 dummy > hash |
35 | | |
36 | | # Calculate signature from hash |
37 | | openssl pkeyutl -sign -in hash -inkey private.pem -pkeyopt digest:sha256 -out dummy.sign |
38 | | */ |
39 | | |
40 | 0 | void MbedTlsWrapper::ComputeSha256Hash(const char *in, size_t in_len, char *out) { |
41 | |
|
42 | 0 | mbedtls_sha256_context sha_context; |
43 | 0 | mbedtls_sha256_init(&sha_context); |
44 | 0 | if (mbedtls_sha256_starts(&sha_context, false) || |
45 | 0 | mbedtls_sha256_update(&sha_context, reinterpret_cast<const unsigned char *>(in), in_len) || |
46 | 0 | mbedtls_sha256_finish(&sha_context, reinterpret_cast<unsigned char *>(out))) { |
47 | 0 | throw std::runtime_error("SHA256 Error"); |
48 | 0 | } |
49 | 0 | mbedtls_sha256_free(&sha_context); |
50 | 0 | } |
51 | | |
52 | 0 | std::string MbedTlsWrapper::ComputeSha256Hash(const std::string &file_content) { |
53 | 0 | std::string hash; |
54 | 0 | hash.resize(MbedTlsWrapper::SHA256_HASH_LENGTH_BYTES); |
55 | 0 | ComputeSha256Hash(file_content.data(), file_content.size(), (char *)hash.data()); |
56 | 0 | return hash; |
57 | 0 | } |
58 | | |
59 | | bool MbedTlsWrapper::IsValidSha256Signature(const std::string &pubkey, const std::string &signature, |
60 | 0 | const std::string &sha256_hash) { |
61 | |
|
62 | 0 | if (signature.size() != 256 || sha256_hash.size() != 32) { |
63 | 0 | throw std::runtime_error("Invalid input lengths, expected signature length 256, got " + |
64 | 0 | std::to_string(signature.size()) + ", hash length 32, got " + |
65 | 0 | std::to_string(sha256_hash.size())); |
66 | 0 | } |
67 | | |
68 | 0 | mbedtls_pk_context pk_context; |
69 | 0 | mbedtls_pk_init(&pk_context); |
70 | |
|
71 | 0 | if (mbedtls_pk_parse_public_key(&pk_context, reinterpret_cast<const unsigned char *>(pubkey.c_str()), |
72 | 0 | pubkey.size() + 1)) { |
73 | 0 | throw std::runtime_error("RSA public key import error"); |
74 | 0 | } |
75 | | |
76 | | // actually verify |
77 | 0 | bool valid = mbedtls_pk_verify(&pk_context, MBEDTLS_MD_SHA256, |
78 | 0 | reinterpret_cast<const unsigned char *>(sha256_hash.data()), sha256_hash.size(), |
79 | 0 | reinterpret_cast<const unsigned char *>(signature.data()), signature.length()) == 0; |
80 | |
|
81 | 0 | mbedtls_pk_free(&pk_context); |
82 | 0 | return valid; |
83 | 0 | } |
84 | | |
85 | | // used in s3fs |
86 | 0 | void MbedTlsWrapper::Hmac256(const char *key, size_t key_len, const char *message, size_t message_len, char *out) { |
87 | 0 | mbedtls_md_context_t hmac_ctx; |
88 | 0 | const mbedtls_md_info_t *md_type = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256); |
89 | 0 | if (!md_type) { |
90 | 0 | throw std::runtime_error("failed to init hmac"); |
91 | 0 | } |
92 | | |
93 | 0 | if (mbedtls_md_setup(&hmac_ctx, md_type, 1) || |
94 | 0 | mbedtls_md_hmac_starts(&hmac_ctx, reinterpret_cast<const unsigned char *>(key), key_len) || |
95 | 0 | mbedtls_md_hmac_update(&hmac_ctx, reinterpret_cast<const unsigned char *>(message), message_len) || |
96 | 0 | mbedtls_md_hmac_finish(&hmac_ctx, reinterpret_cast<unsigned char *>(out))) { |
97 | 0 | throw std::runtime_error("HMAC256 Error"); |
98 | 0 | } |
99 | 0 | mbedtls_md_free(&hmac_ctx); |
100 | 0 | } |
101 | | |
102 | 0 | void MbedTlsWrapper::ToBase16(char *in, char *out, size_t len) { |
103 | 0 | static char const HEX_CODES[] = "0123456789abcdef"; |
104 | 0 | size_t i, j; |
105 | |
|
106 | 0 | for (j = i = 0; i < len; i++) { |
107 | 0 | int a = in[i]; |
108 | 0 | out[j++] = HEX_CODES[(a >> 4) & 0xf]; |
109 | 0 | out[j++] = HEX_CODES[a & 0xf]; |
110 | 0 | } |
111 | 0 | } |
112 | | |
113 | 0 | MbedTlsWrapper::SHA256State::SHA256State() : sha_context(new mbedtls_sha256_context()) { |
114 | 0 | auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context); |
115 | |
|
116 | 0 | mbedtls_sha256_init(context); |
117 | |
|
118 | 0 | if (mbedtls_sha256_starts(context, false)) { |
119 | 0 | throw std::runtime_error("SHA256 Error"); |
120 | 0 | } |
121 | 0 | } |
122 | | |
123 | 0 | MbedTlsWrapper::SHA256State::~SHA256State() { |
124 | 0 | auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context); |
125 | 0 | mbedtls_sha256_free(context); |
126 | 0 | delete context; |
127 | 0 | } |
128 | | |
129 | 0 | void MbedTlsWrapper::SHA256State::AddString(const std::string &str) { |
130 | 0 | auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context); |
131 | 0 | if (mbedtls_sha256_update(context, (unsigned char *)str.data(), str.size())) { |
132 | 0 | throw std::runtime_error("SHA256 Error"); |
133 | 0 | } |
134 | 0 | } |
135 | | |
136 | 0 | void MbedTlsWrapper::SHA256State::AddBytes(duckdb::const_data_ptr_t input_bytes, duckdb::idx_t len) { |
137 | 0 | auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context); |
138 | 0 | if (mbedtls_sha256_update(context, input_bytes, len)) { |
139 | 0 | throw std::runtime_error("SHA256 Error"); |
140 | 0 | } |
141 | 0 | } |
142 | | |
143 | 0 | void MbedTlsWrapper::SHA256State::AddBytes(duckdb::data_ptr_t input_bytes, duckdb::idx_t len) { |
144 | 0 | AddBytes(duckdb::const_data_ptr_t(input_bytes), len); |
145 | 0 | } |
146 | | |
147 | 0 | void MbedTlsWrapper::SHA256State::AddSalt(unsigned char *salt, size_t salt_len) { |
148 | 0 | auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context); |
149 | 0 | if (mbedtls_sha256_update(context, salt, salt_len)) { |
150 | 0 | throw std::runtime_error("SHA256 Error"); |
151 | 0 | } |
152 | 0 | } |
153 | | |
154 | 0 | void MbedTlsWrapper::SHA256State::FinalizeDerivedKey(duckdb::data_ptr_t hash) { |
155 | 0 | auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context); |
156 | |
|
157 | 0 | if (mbedtls_sha256_finish(context, (duckdb::data_ptr_t)hash)) { |
158 | 0 | throw std::runtime_error("SHA256 Error"); |
159 | 0 | } |
160 | 0 | } |
161 | | |
162 | 0 | std::string MbedTlsWrapper::SHA256State::Finalize() { |
163 | 0 | auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context); |
164 | |
|
165 | 0 | std::string hash; |
166 | 0 | hash.resize(MbedTlsWrapper::SHA256_HASH_LENGTH_BYTES); |
167 | |
|
168 | 0 | if (mbedtls_sha256_finish(context, (unsigned char *)hash.data())) { |
169 | 0 | throw std::runtime_error("SHA256 Error"); |
170 | 0 | } |
171 | | |
172 | 0 | return hash; |
173 | 0 | } |
174 | | |
175 | 0 | void MbedTlsWrapper::SHA256State::FinishHex(char *out) { |
176 | 0 | auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context); |
177 | |
|
178 | 0 | std::string hash; |
179 | 0 | hash.resize(MbedTlsWrapper::SHA256_HASH_LENGTH_BYTES); |
180 | |
|
181 | 0 | if (mbedtls_sha256_finish(context, (unsigned char *)hash.data())) { |
182 | 0 | throw std::runtime_error("SHA256 Error"); |
183 | 0 | } |
184 | | |
185 | 0 | MbedTlsWrapper::ToBase16(const_cast<char *>(hash.c_str()), out, MbedTlsWrapper::SHA256_HASH_LENGTH_BYTES); |
186 | 0 | } |
187 | | |
188 | 0 | MbedTlsWrapper::SHA1State::SHA1State() : sha_context(new mbedtls_sha1_context()) { |
189 | 0 | auto context = reinterpret_cast<mbedtls_sha1_context *>(sha_context); |
190 | |
|
191 | 0 | mbedtls_sha1_init(context); |
192 | |
|
193 | 0 | if (mbedtls_sha1_starts(context)) { |
194 | 0 | throw std::runtime_error("SHA1 Error"); |
195 | 0 | } |
196 | 0 | } |
197 | | |
198 | 0 | MbedTlsWrapper::SHA1State::~SHA1State() { |
199 | 0 | auto context = reinterpret_cast<mbedtls_sha1_context *>(sha_context); |
200 | 0 | mbedtls_sha1_free(context); |
201 | 0 | delete context; |
202 | 0 | } |
203 | | |
204 | 0 | void MbedTlsWrapper::SHA1State::AddString(const std::string &str) { |
205 | 0 | auto context = reinterpret_cast<mbedtls_sha1_context *>(sha_context); |
206 | 0 | if (mbedtls_sha1_update(context, (unsigned char *)str.data(), str.size())) { |
207 | 0 | throw std::runtime_error("SHA1 Error"); |
208 | 0 | } |
209 | 0 | } |
210 | | |
211 | 0 | std::string MbedTlsWrapper::SHA1State::Finalize() { |
212 | 0 | auto context = reinterpret_cast<mbedtls_sha1_context *>(sha_context); |
213 | |
|
214 | 0 | std::string hash; |
215 | 0 | hash.resize(MbedTlsWrapper::SHA1_HASH_LENGTH_BYTES); |
216 | |
|
217 | 0 | if (mbedtls_sha1_finish(context, (unsigned char *)hash.data())) { |
218 | 0 | throw std::runtime_error("SHA1 Error"); |
219 | 0 | } |
220 | | |
221 | 0 | return hash; |
222 | 0 | } |
223 | | |
224 | 0 | void MbedTlsWrapper::SHA1State::FinishHex(char *out) { |
225 | 0 | auto context = reinterpret_cast<mbedtls_sha1_context *>(sha_context); |
226 | |
|
227 | 0 | std::string hash; |
228 | 0 | hash.resize(MbedTlsWrapper::SHA1_HASH_LENGTH_BYTES); |
229 | |
|
230 | 0 | if (mbedtls_sha1_finish(context, (unsigned char *)hash.data())) { |
231 | 0 | throw std::runtime_error("SHA1 Error"); |
232 | 0 | } |
233 | | |
234 | 0 | MbedTlsWrapper::ToBase16(const_cast<char *>(hash.c_str()), out, MbedTlsWrapper::SHA1_HASH_LENGTH_BYTES); |
235 | 0 | } |
236 | | |
237 | 0 | const mbedtls_cipher_info_t *MbedTlsWrapper::AESStateMBEDTLS::GetCipher(){ |
238 | 0 | switch(metadata->GetCipher()) { |
239 | 0 | case CipherType::GCM: |
240 | 0 | switch (metadata->GetKeyLen()) { |
241 | 0 | case 16: |
242 | 0 | return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_128_GCM); |
243 | 0 | case 24: |
244 | 0 | return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_192_GCM); |
245 | 0 | case 32: |
246 | 0 | return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_256_GCM); |
247 | 0 | default: |
248 | 0 | throw std::runtime_error("Invalid AES key length for GCM"); |
249 | 0 | } |
250 | 0 | case CipherType::CTR: |
251 | 0 | switch (metadata->GetKeyLen()) { |
252 | 0 | case 16: |
253 | 0 | return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_128_CTR); |
254 | 0 | case 24: |
255 | 0 | return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_192_CTR); |
256 | 0 | case 32: |
257 | 0 | return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_256_CTR); |
258 | 0 | default: |
259 | 0 | throw std::runtime_error("Invalid AES key length for CTR"); |
260 | 0 | } |
261 | 0 | case CipherType::CBC: |
262 | 0 | switch (metadata->GetKeyLen()) { |
263 | 0 | case 16: |
264 | 0 | return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_128_CBC); |
265 | 0 | case 24: |
266 | 0 | return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_192_CBC); |
267 | 0 | case 32: |
268 | 0 | return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_256_CBC); |
269 | 0 | default: |
270 | 0 | throw std::runtime_error("Invalid AES key length for CBC"); |
271 | 0 | } |
272 | 0 | default: |
273 | 0 | throw duckdb::InternalException("Invalid Encryption/Decryption Cipher: %s", duckdb::EncryptionTypes::CipherToString(metadata->GetCipher())); |
274 | 0 | } |
275 | 0 | } |
276 | | |
277 | 0 | void MbedTlsWrapper::AESStateMBEDTLS::SecureClearData(duckdb::data_ptr_t data, duckdb::idx_t len) { |
278 | 0 | mbedtls_platform_zeroize(data, len); |
279 | 0 | } |
280 | | |
281 | 0 | MbedTlsWrapper::AESStateMBEDTLS::AESStateMBEDTLS(duckdb::unique_ptr<duckdb::EncryptionStateMetadata> metadata_p) : EncryptionState(std::move(metadata_p)), context(duckdb::make_uniq<mbedtls_cipher_context_t>()) { |
282 | 0 | mbedtls_cipher_init(context.get()); |
283 | |
|
284 | 0 | auto cipher_info = GetCipher(); |
285 | |
|
286 | 0 | if (!cipher_info) { |
287 | 0 | throw std::runtime_error("Failed to get Cipher"); |
288 | 0 | } |
289 | | |
290 | 0 | if (mbedtls_cipher_setup(context.get(), cipher_info)) { |
291 | 0 | throw std::runtime_error("Failed to initialize cipher context"); |
292 | 0 | } |
293 | | |
294 | 0 | if (metadata->GetCipher() == duckdb::EncryptionTypes::CBC && mbedtls_cipher_set_padding_mode(context.get(), MBEDTLS_PADDING_PKCS7)) { |
295 | 0 | throw std::runtime_error("Failed to set CBC padding"); |
296 | |
|
297 | 0 | } |
298 | 0 | } |
299 | | |
300 | 0 | MbedTlsWrapper::AESStateMBEDTLS::~AESStateMBEDTLS() { |
301 | 0 | if (context) { |
302 | 0 | mbedtls_cipher_free(context.get()); |
303 | 0 | } |
304 | 0 | } |
305 | | |
306 | 0 | void MbedTlsWrapper::AESStateMBEDTLS::GenerateRandomDataInsecure(duckdb::data_ptr_t data, duckdb::idx_t len) { |
307 | 0 | if (!force_mbedtls) { |
308 | | // To use this insecure MbedTLS random number generator |
309 | | // we double check if force_mbedtls_unsafe is set |
310 | | // such that we do not accidentaly opt-in |
311 | 0 | throw duckdb::InternalException("Insecure random generation called without setting 'force_mbedtls_unsafe' = true"); |
312 | 0 | } |
313 | | |
314 | 0 | duckdb::RandomEngine random_engine; |
315 | |
|
316 | 0 | while (len != 0) { |
317 | 0 | const auto random_integer = random_engine.NextRandomInteger(); |
318 | 0 | const auto next = duckdb::MinValue<duckdb::idx_t>(len, sizeof(random_integer)); |
319 | 0 | memcpy(data, duckdb::const_data_ptr_cast(&random_integer), next); |
320 | 0 | data += next; |
321 | 0 | len -= next; |
322 | 0 | } |
323 | 0 | } |
324 | | |
325 | 0 | void MbedTlsWrapper::AESStateMBEDTLS::GenerateRandomData(duckdb::data_ptr_t data, duckdb::idx_t len) { |
326 | | // generate insecure random data |
327 | 0 | GenerateRandomDataInsecure(data, len); |
328 | 0 | } |
329 | | |
330 | 0 | void MbedTlsWrapper::AESStateMBEDTLS::InitializeInternal(duckdb::EncryptionNonce &nonce, duckdb::const_data_ptr_t aad, duckdb::idx_t aad_len){ |
331 | 0 | if (mbedtls_cipher_set_iv(context.get(), nonce.data(), nonce.total_size())) { |
332 | 0 | throw std::runtime_error("Failed to set IV for encryption"); |
333 | 0 | } |
334 | | |
335 | 0 | if (aad_len > 0) { |
336 | 0 | if (mbedtls_cipher_update_ad(context.get(), aad, aad_len)) { |
337 | 0 | throw std::runtime_error("Failed to set AAD"); |
338 | 0 | } |
339 | 0 | } |
340 | 0 | } |
341 | | |
342 | 0 | void MbedTlsWrapper::AESStateMBEDTLS::InitializeEncryption(duckdb::EncryptionNonce &nonce, duckdb::const_data_ptr_t key, duckdb::const_data_ptr_t aad, duckdb::idx_t aad_len) { |
343 | 0 | mode = duckdb::EncryptionTypes::ENCRYPT; |
344 | |
|
345 | 0 | if (mbedtls_cipher_setkey(context.get(), key, metadata->GetKeyLen() * 8, MBEDTLS_ENCRYPT) != 0) { |
346 | 0 | throw std::runtime_error("Failed to set AES key for encryption"); |
347 | 0 | } |
348 | | |
349 | 0 | InitializeInternal(nonce, aad, aad_len); |
350 | 0 | } |
351 | | |
352 | 0 | void MbedTlsWrapper::AESStateMBEDTLS::InitializeDecryption(duckdb::EncryptionNonce &nonce, duckdb::const_data_ptr_t key, duckdb::const_data_ptr_t aad, duckdb::idx_t aad_len) { |
353 | 0 | mode = duckdb::EncryptionTypes::DECRYPT; |
354 | |
|
355 | 0 | if (mbedtls_cipher_setkey(context.get(), key, metadata->GetKeyLen() * 8, MBEDTLS_DECRYPT)) { |
356 | 0 | throw std::runtime_error("Failed to set AES key for encryption"); |
357 | 0 | } |
358 | | |
359 | 0 | InitializeInternal(nonce, aad, aad_len); |
360 | 0 | } |
361 | | |
362 | | size_t MbedTlsWrapper::AESStateMBEDTLS::Process(duckdb::const_data_ptr_t in, duckdb::idx_t in_len, duckdb::data_ptr_t out, |
363 | 0 | duckdb::idx_t out_len) { |
364 | | |
365 | | // GCM works in-place, CTR and CBC don't |
366 | 0 | auto use_out_copy = in == out && metadata->GetCipher() != CipherType::GCM; |
367 | |
|
368 | 0 | auto out_ptr = out; |
369 | 0 | std::unique_ptr<duckdb::data_t[]> out_copy; |
370 | 0 | if (use_out_copy) { |
371 | 0 | out_copy.reset(new duckdb::data_t[out_len]); |
372 | 0 | out_ptr = out_copy.get(); |
373 | 0 | } |
374 | |
|
375 | 0 | size_t out_len_res = duckdb::NumericCast<size_t>(out_len); |
376 | 0 | if (mbedtls_cipher_update(context.get(), reinterpret_cast<const unsigned char *>(in), in_len, out_ptr, |
377 | 0 | &out_len_res)) { |
378 | 0 | throw std::runtime_error("Encryption or Decryption failed at Process"); |
379 | 0 | }; |
380 | |
|
381 | 0 | if (use_out_copy) { |
382 | 0 | memcpy(out, out_ptr, out_len_res); |
383 | 0 | } |
384 | 0 | return out_len_res; |
385 | 0 | } |
386 | | |
387 | 0 | void MbedTlsWrapper::AESStateMBEDTLS::FinalizeGCM(duckdb::data_ptr_t tag, duckdb::idx_t tag_len){ |
388 | |
|
389 | 0 | switch (mode) { |
390 | | |
391 | 0 | case duckdb::EncryptionTypes::ENCRYPT: { |
392 | 0 | if (mbedtls_cipher_write_tag(context.get(), tag, tag_len)) { |
393 | 0 | throw std::runtime_error("Writing tag failed"); |
394 | 0 | } |
395 | 0 | break; |
396 | 0 | } |
397 | | |
398 | 0 | case duckdb::EncryptionTypes::DECRYPT: { |
399 | 0 | if (mbedtls_cipher_check_tag(context.get(), tag, tag_len)) { |
400 | 0 | throw duckdb::InvalidInputException( |
401 | 0 | "Computed AES tag differs from read AES tag, are you using the right key?"); |
402 | 0 | } |
403 | 0 | break; |
404 | 0 | } |
405 | | |
406 | 0 | default: |
407 | 0 | throw duckdb::InternalException("Unhandled encryption mode %d", static_cast<int>(mode)); |
408 | 0 | } |
409 | 0 | } |
410 | | |
411 | | size_t MbedTlsWrapper::AESStateMBEDTLS::Finalize(duckdb::data_ptr_t out, duckdb::idx_t out_len, duckdb::data_ptr_t tag, |
412 | 0 | duckdb::idx_t tag_len) { |
413 | 0 | size_t result = out_len; |
414 | 0 | if (mbedtls_cipher_finish(context.get(), out, &result)) { |
415 | 0 | throw std::runtime_error("Encryption or Decryption failed at Finalize"); |
416 | 0 | } |
417 | 0 | if (metadata->GetCipher() == duckdb::EncryptionTypes::GCM) { |
418 | 0 | FinalizeGCM(tag, tag_len); |
419 | 0 | } |
420 | 0 | return result; |
421 | 0 | } |