Coverage Report

Created: 2026-03-23 07:13

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/aws-lc-rs-1.16.2/src/kem.rs
Line
Count
Source
1
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
// SPDX-License-Identifier: Apache-2.0 OR ISC
3
4
//! Key-Encapsulation Mechanisms (KEMs), including support for NIST FIPS 203 ML-KEM.
5
//!
6
//! # Example
7
//!
8
//! Note that this example uses the ML-KEM-512 algorithm, but other algorithms can be used
9
//! in the exact same way by substituting
10
//! `kem::<desired_algorithm_here>` for `kem::ML_KEM_512`.
11
//!
12
//! ```rust
13
//! use aws_lc_rs::{
14
//!     kem::{Ciphertext, DecapsulationKey, EncapsulationKey},
15
//!     kem::{ML_KEM_512}
16
//! };
17
//!
18
//! // Alice generates their (private) decapsulation key.
19
//! let decapsulation_key = DecapsulationKey::generate(&ML_KEM_512)?;
20
//!
21
//! // Alice computes the (public) encapsulation key.
22
//! let encapsulation_key = decapsulation_key.encapsulation_key()?;
23
//!
24
//! let encapsulation_key_bytes = encapsulation_key.key_bytes()?;
25
//!
26
//! // Alice sends the encapsulation key bytes to Bob through some
27
//! // protocol message.
28
//! let encapsulation_key_bytes = encapsulation_key_bytes.as_ref();
29
//!
30
//! // Bob constructs the (public) encapsulation key from the key bytes provided by Alice.
31
//! let retrieved_encapsulation_key = EncapsulationKey::new(&ML_KEM_512, encapsulation_key_bytes)?;
32
//!
33
//! // Bob executes the encapsulation algorithm to produce their copy of the secret, and associated ciphertext.
34
//! let (ciphertext, bob_secret) = retrieved_encapsulation_key.encapsulate()?;
35
//!
36
//! // Bob sends the ciphertext bytes to Alice through some protocol message.
37
//! let ciphertext_bytes = ciphertext.as_ref();
38
//!
39
//! // Alice runs decapsulation on the ciphertext received from Bob to derive her
40
//! // copy of the secret.
41
//! let alice_secret = decapsulation_key.decapsulate(Ciphertext::from(ciphertext_bytes))?;
42
//!
43
//! // Alice and Bob have now arrived to the same secret
44
//! assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
45
//!
46
//! # Ok::<(), aws_lc_rs::error::Unspecified>(())
47
//! ```
48
use crate::aws_lc::{
49
    EVP_PKEY_CTX_kem_set_params, EVP_PKEY_decapsulate, EVP_PKEY_encapsulate,
50
    EVP_PKEY_kem_new_raw_public_key, EVP_PKEY_kem_new_raw_secret_key, EVP_PKEY, EVP_PKEY_KEM,
51
};
52
use crate::buffer::Buffer;
53
use crate::encoding::generated_encodings;
54
use crate::error::{KeyRejected, Unspecified};
55
use crate::ptr::LcPtr;
56
use alloc::borrow::Cow;
57
use core::cmp::Ordering;
58
use zeroize::Zeroize;
59
60
// Byte lengths of the ML-KEM artifacts for each parameter set. They populate the size
// fields of the `Algorithm` constants, which in turn size the key, ciphertext, and
// shared-secret buffers allocated by this module.

// Encapsulation (public) key lengths.
const ML_KEM_512_PUBLIC_KEY_LENGTH: usize = 800;
const ML_KEM_768_PUBLIC_KEY_LENGTH: usize = 1184;
const ML_KEM_1024_PUBLIC_KEY_LENGTH: usize = 1568;

// Decapsulation (secret) key lengths.
const ML_KEM_512_SECRET_KEY_LENGTH: usize = 1632;
const ML_KEM_768_SECRET_KEY_LENGTH: usize = 2400;
const ML_KEM_1024_SECRET_KEY_LENGTH: usize = 3168;

// Ciphertext lengths.
const ML_KEM_512_CIPHERTEXT_LENGTH: usize = 768;
const ML_KEM_768_CIPHERTEXT_LENGTH: usize = 1088;
const ML_KEM_1024_CIPHERTEXT_LENGTH: usize = 1568;

// Shared-secret lengths (the same for every parameter set).
const ML_KEM_512_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_768_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_1024_SHARED_SECRET_LENGTH: usize = 32;
74
75
/// NIST FIPS 203 ML-KEM-512 algorithm.
76
pub const ML_KEM_512: Algorithm<AlgorithmId> = Algorithm {
77
    id: AlgorithmId::MlKem512,
78
    decapsulate_key_size: ML_KEM_512_SECRET_KEY_LENGTH,
79
    encapsulate_key_size: ML_KEM_512_PUBLIC_KEY_LENGTH,
80
    ciphertext_size: ML_KEM_512_CIPHERTEXT_LENGTH,
81
    shared_secret_size: ML_KEM_512_SHARED_SECRET_LENGTH,
82
};
83
84
/// NIST FIPS 203 ML-KEM-768 algorithm.
85
pub const ML_KEM_768: Algorithm<AlgorithmId> = Algorithm {
86
    id: AlgorithmId::MlKem768,
87
    decapsulate_key_size: ML_KEM_768_SECRET_KEY_LENGTH,
88
    encapsulate_key_size: ML_KEM_768_PUBLIC_KEY_LENGTH,
89
    ciphertext_size: ML_KEM_768_CIPHERTEXT_LENGTH,
90
    shared_secret_size: ML_KEM_768_SHARED_SECRET_LENGTH,
91
};
92
93
/// NIST FIPS 203 ML-KEM-1024 algorithm.
94
pub const ML_KEM_1024: Algorithm<AlgorithmId> = Algorithm {
95
    id: AlgorithmId::MlKem1024,
96
    decapsulate_key_size: ML_KEM_1024_SECRET_KEY_LENGTH,
97
    encapsulate_key_size: ML_KEM_1024_PUBLIC_KEY_LENGTH,
98
    ciphertext_size: ML_KEM_1024_CIPHERTEXT_LENGTH,
99
    shared_secret_size: ML_KEM_1024_SHARED_SECRET_LENGTH,
100
};
101
102
use crate::aws_lc::{NID_MLKEM1024, NID_MLKEM512, NID_MLKEM768};
103
104
/// An identifier for a KEM algorithm.
105
pub trait AlgorithmIdentifier:
106
    Copy + Clone + Debug + PartialEq + crate::sealed::Sealed + 'static
107
{
108
    /// Returns the algorithm's associated AWS-LC nid.
109
    fn nid(self) -> i32;
110
}
111
112
/// A KEM algorithm
113
#[derive(PartialEq)]
114
pub struct Algorithm<Id = AlgorithmId>
115
where
116
    Id: AlgorithmIdentifier,
117
{
118
    pub(crate) id: Id,
119
    pub(crate) decapsulate_key_size: usize,
120
    pub(crate) encapsulate_key_size: usize,
121
    pub(crate) ciphertext_size: usize,
122
    pub(crate) shared_secret_size: usize,
123
}
124
125
impl<Id> Algorithm<Id>
126
where
127
    Id: AlgorithmIdentifier,
128
{
129
    /// Returns the identifier for this algorithm.
130
    #[must_use]
131
0
    pub fn id(&self) -> Id {
132
0
        self.id
133
0
    }
134
135
    #[inline]
136
    #[allow(dead_code)]
137
0
    pub(crate) fn decapsulate_key_size(&self) -> usize {
138
0
        self.decapsulate_key_size
139
0
    }
140
141
    #[inline]
142
0
    pub(crate) fn encapsulate_key_size(&self) -> usize {
143
0
        self.encapsulate_key_size
144
0
    }
145
146
    #[inline]
147
0
    pub(crate) fn ciphertext_size(&self) -> usize {
148
0
        self.ciphertext_size
149
0
    }
150
151
    #[inline]
152
0
    pub(crate) fn shared_secret_size(&self) -> usize {
153
0
        self.shared_secret_size
154
0
    }
155
}
156
157
impl<Id> Debug for Algorithm<Id>
158
where
159
    Id: AlgorithmIdentifier,
160
{
161
0
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
162
0
        Debug::fmt(&self.id, f)
163
0
    }
164
}
165
166
/// A serializable decapulsation key usable with KEMs. This can be randomly generated with `DecapsulationKey::generate`.
167
pub struct DecapsulationKey<Id = AlgorithmId>
168
where
169
    Id: AlgorithmIdentifier,
170
{
171
    algorithm: &'static Algorithm<Id>,
172
    evp_pkey: LcPtr<EVP_PKEY>,
173
}
174
175
/// Identifies one of the KEM algorithms supported by this module.
///
/// This enum is `#[non_exhaustive]`: code outside this crate that matches on it must include
/// a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AlgorithmId {
    /// ML-KEM-512, as specified in NIST FIPS 203.
    MlKem512,

    /// ML-KEM-768, as specified in NIST FIPS 203.
    MlKem768,

    /// ML-KEM-1024, as specified in NIST FIPS 203.
    MlKem1024,
}
188
189
impl AlgorithmIdentifier for AlgorithmId {
190
0
    fn nid(self) -> i32 {
191
0
        match self {
192
0
            AlgorithmId::MlKem512 => NID_MLKEM512,
193
0
            AlgorithmId::MlKem768 => NID_MLKEM768,
194
0
            AlgorithmId::MlKem1024 => NID_MLKEM1024,
195
        }
196
0
    }
197
}
198
199
impl crate::sealed::Sealed for AlgorithmId {}
200
201
impl<Id> DecapsulationKey<Id>
202
where
203
    Id: AlgorithmIdentifier,
204
{
205
    /// Creates a new KEM decapsulation key from raw bytes. This method MUST NOT be used to generate
206
    /// a new decapsulation key, rather it MUST be used to construct `DecapsulationKey` previously serialized
207
    /// to raw bytes.
208
    ///
209
    /// `alg` is the [`Algorithm`] to be associated with the generated `DecapsulationKey`.
210
    ///
211
    /// `bytes` is a slice of raw bytes representing a `DecapsulationKey`.
212
    ///
213
    /// # Security Considerations
214
    ///
215
    /// This function performs size validation but does not fully validate key material integrity.
216
    /// Invalid key bytes (e.g., corrupted or tampered data) may be accepted by this function but
217
    /// will cause [`Self::decapsulate`] to fail. Only use bytes that were previously obtained from
218
    /// [`Self::key_bytes`] on a validly generated key.
219
    ///
220
    /// # Limitations
221
    ///
222
    /// The `DecapsulationKey` returned by this function will NOT provide the associated
223
    /// `EncapsulationKey` via [`Self::encapsulation_key`]. The `EncapsulationKey` must be
224
    /// serialized and restored separately using [`EncapsulationKey::key_bytes`] and
225
    /// [`EncapsulationKey::new`].
226
    ///
227
    /// # Errors
228
    ///
229
    /// Returns `KeyRejected::too_small()` if `bytes.len() < alg.decapsulate_key_size()`.
230
    ///
231
    /// Returns `KeyRejected::too_large()` if `bytes.len() > alg.decapsulate_key_size()`.
232
    ///
233
    /// Returns `KeyRejected::unexpected_error()` if the underlying cryptographic operation fails.
234
0
    pub fn new(alg: &'static Algorithm<Id>, bytes: &[u8]) -> Result<Self, KeyRejected> {
235
0
        match bytes.len().cmp(&alg.decapsulate_key_size()) {
236
0
            Ordering::Less => Err(KeyRejected::too_small()),
237
0
            Ordering::Greater => Err(KeyRejected::too_large()),
238
0
            Ordering::Equal => Ok(()),
239
0
        }?;
240
0
        let evp_pkey = LcPtr::new(unsafe {
241
0
            EVP_PKEY_kem_new_raw_secret_key(alg.id.nid(), bytes.as_ptr(), bytes.len())
242
0
        })?;
243
0
        Ok(DecapsulationKey {
244
0
            algorithm: alg,
245
0
            evp_pkey,
246
0
        })
247
0
    }
248
249
    /// Generate a new KEM decapsulation key for the given algorithm.
250
    ///
251
    /// # Errors
252
    /// `error::Unspecified` when operation fails due to internal error.
253
0
    pub fn generate(alg: &'static Algorithm<Id>) -> Result<Self, Unspecified> {
254
0
        let kyber_key = kem_key_generate(alg.id.nid())?;
255
0
        Ok(DecapsulationKey {
256
0
            algorithm: alg,
257
0
            evp_pkey: kyber_key,
258
0
        })
259
0
    }
260
261
    /// Return the algorithm associated with the given KEM decapsulation key.
262
    #[must_use]
263
0
    pub fn algorithm(&self) -> &'static Algorithm<Id> {
264
0
        self.algorithm
265
0
    }
266
267
    /// Returns the raw bytes of the `DecapsulationKey`.
268
    ///
269
    /// The returned bytes can be used with [`Self::new`] to reconstruct the `DecapsulationKey`.
270
    ///
271
    /// # Errors
272
    ///
273
    /// Returns [`Unspecified`] if the key bytes cannot be retrieved from the underlying
274
    /// cryptographic implementation.
275
0
    pub fn key_bytes(&self) -> Result<DecapsulationKeyBytes<'static>, Unspecified> {
276
0
        let decapsulation_key_bytes = self.evp_pkey.as_const().marshal_raw_private_key()?;
277
0
        debug_assert_eq!(
278
0
            decapsulation_key_bytes.len(),
279
0
            self.algorithm.decapsulate_key_size()
280
        );
281
0
        Ok(DecapsulationKeyBytes::new(decapsulation_key_bytes))
282
0
    }
283
284
    /// Returns the `EncapsulationKey` associated with this `DecapsulationKey`.
285
    ///
286
    /// # Errors
287
    ///
288
    /// Returns [`Unspecified`] in the following cases:
289
    /// * The `DecapsulationKey` was constructed from raw bytes using [`Self::new`],
290
    ///   as the underlying key representation does not include the public key component.
291
    ///   In this case, the `EncapsulationKey` must be serialized and restored separately.
292
    /// * An internal error occurs while extracting the public key.
293
    #[allow(clippy::missing_panics_doc)]
294
0
    pub fn encapsulation_key(&self) -> Result<EncapsulationKey<Id>, Unspecified> {
295
0
        let evp_pkey = self.evp_pkey.clone();
296
297
0
        let encapsulation_key = EncapsulationKey {
298
0
            algorithm: self.algorithm,
299
0
            evp_pkey,
300
0
        };
301
302
        // Verify the encapsulation key is valid by attempting to get its bytes.
303
        // Keys constructed from raw secret bytes may not have a valid public key.
304
0
        if encapsulation_key.key_bytes().is_err() {
305
0
            return Err(Unspecified);
306
0
        }
307
308
0
        Ok(encapsulation_key)
309
0
    }
310
311
    /// Performs the decapsulate operation using this `DecapsulationKey` on the given ciphertext.
312
    ///
313
    /// `ciphertext` is the ciphertext generated by the encapsulate operation using the `EncapsulationKey`
314
    /// associated with this `DecapsulationKey`.
315
    ///
316
    /// # Errors
317
    ///
318
    /// Returns [`Unspecified`] in the following cases:
319
    /// * The `ciphertext` is malformed or was not generated for this key's algorithm.
320
    /// * The `DecapsulationKey` was constructed from invalid bytes (e.g., corrupted or tampered
321
    ///   key material passed to [`Self::new`]). Note that [`Self::new`] only validates the size
322
    ///   of the key bytes, not their cryptographic validity.
323
    /// * An internal cryptographic error occurs.
324
    #[allow(clippy::needless_pass_by_value)]
325
0
    pub fn decapsulate(&self, ciphertext: Ciphertext<'_>) -> Result<SharedSecret, Unspecified> {
326
0
        let mut shared_secret_len = self.algorithm.shared_secret_size();
327
0
        let mut shared_secret: Vec<u8> = vec![0u8; shared_secret_len];
328
329
0
        let mut ctx = self.evp_pkey.create_EVP_PKEY_CTX()?;
330
331
0
        let ciphertext = ciphertext.as_ref();
332
333
0
        if 1 != unsafe {
334
0
            EVP_PKEY_decapsulate(
335
0
                ctx.as_mut_ptr(),
336
0
                shared_secret.as_mut_ptr(),
337
0
                &mut shared_secret_len,
338
0
                // AWS-LC incorrectly has this as an unqualified `uint8_t *`, it should be qualified with const
339
0
                ciphertext.as_ptr().cast_mut(),
340
0
                ciphertext.len(),
341
0
            )
342
0
        } {
343
0
            return Err(Unspecified);
344
0
        }
345
346
        // This is currently pedantic but done for safety in-case the shared_secret buffer
347
        // size changes in the future. `EVP_PKEY_decapsulate` updates `shared_secret_len` with
348
        // the length of the shared secret in the event the buffer provided was larger then the secret.
349
        // This truncates the buffer to the proper length to match the shared secret written.
350
0
        debug_assert_eq!(shared_secret_len, shared_secret.len());
351
0
        shared_secret.truncate(shared_secret_len);
352
353
0
        Ok(SharedSecret(shared_secret.into_boxed_slice()))
354
0
    }
355
}
356
357
unsafe impl<Id> Send for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}
358
359
unsafe impl<Id> Sync for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}
360
361
impl<Id> Debug for DecapsulationKey<Id>
362
where
363
    Id: AlgorithmIdentifier,
364
{
365
0
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
366
0
        f.debug_struct("DecapsulationKey")
367
0
            .field("algorithm", &self.algorithm)
368
0
            .finish_non_exhaustive()
369
0
    }
370
}
371
372
// Expands to the `EncapsulationKeyBytes` / `DecapsulationKeyBytes` types returned by the
// `key_bytes` methods (neither is defined elsewhere in this file). The `*Type` names are
// forwarded to the macro; see `crate::encoding` for exactly what it generates.
generated_encodings! {
    (EncapsulationKeyBytes, EncapsulationKeyBytesType),
    (DecapsulationKeyBytes, DecapsulationKeyBytesType)
}
376
377
/// A serializable encapsulation key usable with KEM algorithms. Constructed
378
/// from either a `DecapsulationKey` or raw bytes.
379
pub struct EncapsulationKey<Id = AlgorithmId>
380
where
381
    Id: AlgorithmIdentifier,
382
{
383
    algorithm: &'static Algorithm<Id>,
384
    evp_pkey: LcPtr<EVP_PKEY>,
385
}
386
387
impl<Id> EncapsulationKey<Id>
388
where
389
    Id: AlgorithmIdentifier,
390
{
391
    /// Return the algorithm associated with the given KEM encapsulation key.
392
    #[must_use]
393
0
    pub fn algorithm(&self) -> &'static Algorithm<Id> {
394
0
        self.algorithm
395
0
    }
396
397
    /// Performs the encapsulate operation using this KEM encapsulation key, generating a ciphertext
398
    /// and associated shared secret.
399
    ///
400
    /// # Errors
401
    /// `error::Unspecified` when operation fails due to internal error.
402
0
    pub fn encapsulate(&self) -> Result<(Ciphertext<'static>, SharedSecret), Unspecified> {
403
0
        let mut ciphertext_len = self.algorithm.ciphertext_size();
404
0
        let mut shared_secret_len = self.algorithm.shared_secret_size();
405
0
        let mut ciphertext: Vec<u8> = vec![0u8; ciphertext_len];
406
0
        let mut shared_secret: Vec<u8> = vec![0u8; shared_secret_len];
407
408
0
        let mut ctx = self.evp_pkey.create_EVP_PKEY_CTX()?;
409
410
0
        if 1 != unsafe {
411
0
            EVP_PKEY_encapsulate(
412
0
                ctx.as_mut_ptr(),
413
0
                ciphertext.as_mut_ptr(),
414
0
                &mut ciphertext_len,
415
0
                shared_secret.as_mut_ptr(),
416
0
                &mut shared_secret_len,
417
0
            )
418
0
        } {
419
0
            return Err(Unspecified);
420
0
        }
421
422
        // The following two steps are currently pedantic but done for safety in-case the buffer allocation
423
        // sizes change in the future. `EVP_PKEY_encapsulate` updates `ciphertext_len` and `shared_secret_len` with
424
        // the length of the ciphertext and shared secret respectivly in the event the buffer provided for each was
425
        // larger then the actual values. Thus these two steps truncate the buffers to the proper length to match the
426
        // value lengths written.
427
0
        debug_assert_eq!(ciphertext_len, ciphertext.len());
428
0
        ciphertext.truncate(ciphertext_len);
429
0
        debug_assert_eq!(shared_secret_len, shared_secret.len());
430
0
        shared_secret.truncate(shared_secret_len);
431
432
0
        Ok((
433
0
            Ciphertext::new(ciphertext),
434
0
            SharedSecret::new(shared_secret.into_boxed_slice()),
435
0
        ))
436
0
    }
437
438
    /// Returns the `EnscapsulationKey` bytes.
439
    ///
440
    /// # Errors
441
    /// * `Unspecified`: Any failure to retrieve the `EnscapsulationKey` bytes.
442
0
    pub fn key_bytes(&self) -> Result<EncapsulationKeyBytes<'static>, Unspecified> {
443
0
        let mut encapsulate_bytes = vec![0u8; self.algorithm.encapsulate_key_size()];
444
0
        let encapsulate_key_size = self
445
0
            .evp_pkey
446
0
            .as_const()
447
0
            .marshal_raw_public_to_buffer(&mut encapsulate_bytes)?;
448
449
0
        debug_assert_eq!(encapsulate_key_size, encapsulate_bytes.len());
450
0
        encapsulate_bytes.truncate(encapsulate_key_size);
451
452
0
        Ok(EncapsulationKeyBytes::new(encapsulate_bytes))
453
0
    }
454
455
    /// Creates a new KEM encapsulation key from raw bytes. This method MUST NOT be used to generate
456
    /// a new encapsulation key, rather it MUST be used to construct `EncapsulationKey` previously serialized
457
    /// to raw bytes.
458
    ///
459
    /// `alg` is the [`Algorithm`] to be associated with the generated `EncapsulationKey`.
460
    ///
461
    /// `bytes` is a slice of raw bytes representing a `EncapsulationKey`.
462
    ///
463
    /// # Errors
464
    /// `error::KeyRejected` when operation fails during key creation.
465
0
    pub fn new(alg: &'static Algorithm<Id>, bytes: &[u8]) -> Result<Self, KeyRejected> {
466
0
        match bytes.len().cmp(&alg.encapsulate_key_size()) {
467
0
            Ordering::Less => Err(KeyRejected::too_small()),
468
0
            Ordering::Greater => Err(KeyRejected::too_large()),
469
0
            Ordering::Equal => Ok(()),
470
0
        }?;
471
0
        let pubkey = LcPtr::new(unsafe {
472
0
            EVP_PKEY_kem_new_raw_public_key(alg.id.nid(), bytes.as_ptr(), bytes.len())
473
0
        })?;
474
0
        Ok(EncapsulationKey {
475
0
            algorithm: alg,
476
0
            evp_pkey: pubkey,
477
0
        })
478
0
    }
479
}
480
481
unsafe impl<Id> Send for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}
482
483
unsafe impl<Id> Sync for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}
484
485
impl<Id> Debug for EncapsulationKey<Id>
486
where
487
    Id: AlgorithmIdentifier,
488
{
489
0
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
490
0
        f.debug_struct("EncapsulationKey")
491
0
            .field("algorithm", &self.algorithm)
492
0
            .finish_non_exhaustive()
493
0
    }
494
}
495
496
/// The encrypted bytes produced by [`EncapsulationKey::encapsulate`] and consumed by
/// [`DecapsulationKey::decapsulate`].
///
/// A `Ciphertext` either borrows caller-provided bytes (via `From<&[u8]>`) or owns the bytes
/// produced by encapsulation.
pub struct Ciphertext<'a>(Cow<'a, [u8]>);

impl<'a> Ciphertext<'a> {
    /// Wraps an owned byte buffer.
    fn new(value: Vec<u8>) -> Ciphertext<'a> {
        let owned = Cow::Owned(value);
        Ciphertext(owned)
    }
}
505
506
impl Drop for Ciphertext<'_> {
507
0
    fn drop(&mut self) {
508
0
        if let Cow::Owned(ref mut v) = self.0 {
509
0
            v.zeroize();
510
0
        }
511
0
    }
512
}
513
514
impl AsRef<[u8]> for Ciphertext<'_> {
515
0
    fn as_ref(&self) -> &[u8] {
516
0
        match self.0 {
517
0
            Cow::Borrowed(v) => v,
518
0
            Cow::Owned(ref v) => v.as_ref(),
519
        }
520
0
    }
521
}
522
523
impl<'a> From<&'a [u8]> for Ciphertext<'a> {
524
0
    fn from(value: &'a [u8]) -> Self {
525
0
        Self(Cow::Borrowed(value))
526
0
    }
527
}
528
529
/// The shared secret produced by the KEM encapsulate and decapsulate operations.
pub struct SharedSecret(Box<[u8]>);

impl SharedSecret {
    /// Takes ownership of the secret bytes.
    fn new(value: Box<[u8]>) -> Self {
        SharedSecret(value)
    }
}
537
538
impl Drop for SharedSecret {
539
0
    fn drop(&mut self) {
540
0
        self.0.zeroize();
541
0
    }
542
}
543
544
impl AsRef<[u8]> for SharedSecret {
545
0
    fn as_ref(&self) -> &[u8] {
546
0
        self.0.as_ref()
547
0
    }
548
}
549
550
// Returns an LcPtr to an EVP_PKEY
551
#[inline]
552
0
fn kem_key_generate(nid: i32) -> Result<LcPtr<EVP_PKEY>, Unspecified> {
553
0
    let params_fn = |ctx| {
554
0
        if 1 == unsafe { EVP_PKEY_CTX_kem_set_params(ctx, nid) } {
555
0
            Ok(())
556
        } else {
557
0
            Err(())
558
        }
559
0
    };
560
561
0
    LcPtr::<EVP_PKEY>::generate(EVP_PKEY_KEM, Some(params_fn))
562
0
}
563
564
#[cfg(test)]
565
mod tests {
566
    use super::{Ciphertext, DecapsulationKey, EncapsulationKey, SharedSecret};
567
    use crate::error::KeyRejected;
568
569
    use crate::kem::{ML_KEM_1024, ML_KEM_512, ML_KEM_768};
570
571
    #[test]
572
    fn ciphertext() {
573
        let ciphertext_bytes = vec![42u8; 4];
574
        let ciphertext = Ciphertext::from(ciphertext_bytes.as_ref());
575
        assert_eq!(ciphertext.as_ref(), &[42, 42, 42, 42]);
576
        drop(ciphertext);
577
578
        let ciphertext_bytes = vec![42u8; 4];
579
        let ciphertext = Ciphertext::<'static>::new(ciphertext_bytes);
580
        assert_eq!(ciphertext.as_ref(), &[42, 42, 42, 42]);
581
    }
582
583
    #[test]
584
    fn shared_secret() {
585
        let secret_bytes = vec![42u8; 4];
586
        let shared_secret = SharedSecret::new(secret_bytes.into_boxed_slice());
587
        assert_eq!(shared_secret.as_ref(), &[42, 42, 42, 42]);
588
    }
589
590
    #[test]
591
    fn test_kem_serialize() {
592
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
593
            let priv_key = DecapsulationKey::generate(algorithm).unwrap();
594
            assert_eq!(priv_key.algorithm(), algorithm);
595
596
            // Test DecapsulationKey serialization
597
            let priv_key_raw_bytes = priv_key.key_bytes().unwrap();
598
            assert_eq!(
599
                priv_key_raw_bytes.as_ref().len(),
600
                algorithm.decapsulate_key_size()
601
            );
602
            let priv_key_from_bytes =
603
                DecapsulationKey::new(algorithm, priv_key_raw_bytes.as_ref()).unwrap();
604
605
            assert_eq!(
606
                priv_key.key_bytes().unwrap().as_ref(),
607
                priv_key_from_bytes.key_bytes().unwrap().as_ref()
608
            );
609
            assert_eq!(priv_key.algorithm(), priv_key_from_bytes.algorithm());
610
611
            // Test EncapsulationKey serialization
612
            let pub_key = priv_key.encapsulation_key().unwrap();
613
            let pubkey_raw_bytes = pub_key.key_bytes().unwrap();
614
            let pub_key_from_bytes =
615
                EncapsulationKey::new(algorithm, pubkey_raw_bytes.as_ref()).unwrap();
616
617
            assert_eq!(
618
                pub_key.key_bytes().unwrap().as_ref(),
619
                pub_key_from_bytes.key_bytes().unwrap().as_ref()
620
            );
621
            assert_eq!(pub_key.algorithm(), pub_key_from_bytes.algorithm());
622
        }
623
    }
624
625
    #[test]
626
    fn test_kem_wrong_sizes() {
627
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
628
            // Test EncapsulationKey size validation
629
            let too_long_bytes = vec![0u8; algorithm.encapsulate_key_size() + 1];
630
            let long_pub_key_from_bytes = EncapsulationKey::new(algorithm, &too_long_bytes);
631
            assert_eq!(
632
                long_pub_key_from_bytes.err(),
633
                Some(KeyRejected::too_large())
634
            );
635
636
            let too_short_bytes = vec![0u8; algorithm.encapsulate_key_size() - 1];
637
            let short_pub_key_from_bytes = EncapsulationKey::new(algorithm, &too_short_bytes);
638
            assert_eq!(
639
                short_pub_key_from_bytes.err(),
640
                Some(KeyRejected::too_small())
641
            );
642
643
            // Test DecapsulationKey size validation
644
            let too_long_bytes = vec![0u8; algorithm.decapsulate_key_size() + 1];
645
            let long_priv_key_from_bytes = DecapsulationKey::new(algorithm, &too_long_bytes);
646
            assert_eq!(
647
                long_priv_key_from_bytes.err(),
648
                Some(KeyRejected::too_large())
649
            );
650
651
            let too_short_bytes = vec![0u8; algorithm.decapsulate_key_size() - 1];
652
            let short_priv_key_from_bytes = DecapsulationKey::new(algorithm, &too_short_bytes);
653
            assert_eq!(
654
                short_priv_key_from_bytes.err(),
655
                Some(KeyRejected::too_small())
656
            );
657
        }
658
    }
659
660
    #[test]
661
    fn test_kem_e2e() {
662
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
663
            let priv_key = DecapsulationKey::generate(algorithm).unwrap();
664
            assert_eq!(priv_key.algorithm(), algorithm);
665
666
            // Serialize and reconstruct the decapsulation key
667
            let priv_key_bytes = priv_key.key_bytes().unwrap();
668
            let priv_key_from_bytes =
669
                DecapsulationKey::new(algorithm, priv_key_bytes.as_ref()).unwrap();
670
671
            // Keys reconstructed from bytes cannot provide encapsulation_key()
672
            assert!(priv_key_from_bytes.encapsulation_key().is_err());
673
674
            let pub_key = priv_key.encapsulation_key().unwrap();
675
676
            let (alice_ciphertext, alice_secret) =
677
                pub_key.encapsulate().expect("encapsulate successful");
678
679
            // Decapsulate using the reconstructed key
680
            let bob_secret = priv_key_from_bytes
681
                .decapsulate(alice_ciphertext)
682
                .expect("decapsulate successful");
683
684
            assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
685
        }
686
    }
687
688
    #[test]
689
    fn test_serialized_kem_e2e() {
690
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
691
            let priv_key = DecapsulationKey::generate(algorithm).unwrap();
692
            assert_eq!(priv_key.algorithm(), algorithm);
693
694
            let pub_key = priv_key.encapsulation_key().unwrap();
695
696
            // Generate public key bytes to send to bob
697
            let pub_key_bytes = pub_key.key_bytes().unwrap();
698
699
            // Generate private key bytes for alice to store securely
700
            let priv_key_bytes = priv_key.key_bytes().unwrap();
701
702
            // Test that priv_key's EVP_PKEY isn't entirely freed since we remove this pub_key's reference.
703
            drop(pub_key);
704
            drop(priv_key);
705
706
            let retrieved_pub_key =
707
                EncapsulationKey::new(algorithm, pub_key_bytes.as_ref()).unwrap();
708
            let (ciphertext, bob_secret) = retrieved_pub_key
709
                .encapsulate()
710
                .expect("encapsulate successful");
711
712
            // Alice reconstructs her private key from stored bytes
713
            let retrieved_priv_key =
714
                DecapsulationKey::new(algorithm, priv_key_bytes.as_ref()).unwrap();
715
            let alice_secret = retrieved_priv_key
716
                .decapsulate(ciphertext)
717
                .expect("decapsulate successful");
718
719
            assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
720
        }
721
    }
722
723
    #[test]
724
    fn test_decapsulation_key_serialization_roundtrip() {
725
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
726
            // Generate original key
727
            let original_key = DecapsulationKey::generate(algorithm).unwrap();
728
729
            // Test key_bytes() returns correct size
730
            let key_bytes = original_key.key_bytes().unwrap();
731
            assert_eq!(key_bytes.as_ref().len(), algorithm.decapsulate_key_size());
732
733
            // Test round-trip serialization/deserialization
734
            let reconstructed_key = DecapsulationKey::new(algorithm, key_bytes.as_ref()).unwrap();
735
736
            // Verify algorithm consistency
737
            assert_eq!(original_key.algorithm(), reconstructed_key.algorithm());
738
            assert_eq!(original_key.algorithm(), algorithm);
739
740
            // Test serialization produces identical bytes (stability check)
741
            let key_bytes_2 = reconstructed_key.key_bytes().unwrap();
742
            assert_eq!(key_bytes.as_ref(), key_bytes_2.as_ref());
743
744
            // Test functional equivalence: both keys decrypt the same ciphertext identically
745
            let pub_key = original_key.encapsulation_key().unwrap();
746
            let (ciphertext, expected_secret) =
747
                pub_key.encapsulate().expect("encapsulate successful");
748
749
            let secret_from_original = original_key
750
                .decapsulate(Ciphertext::from(ciphertext.as_ref()))
751
                .expect("decapsulate with original key");
752
            let secret_from_reconstructed = reconstructed_key
753
                .decapsulate(Ciphertext::from(ciphertext.as_ref()))
754
                .expect("decapsulate with reconstructed key");
755
756
            // Verify both keys produce identical secrets
757
            assert_eq!(expected_secret.as_ref(), secret_from_original.as_ref());
758
            assert_eq!(expected_secret.as_ref(), secret_from_reconstructed.as_ref());
759
760
            // Verify secret length matches algorithm specification
761
            assert_eq!(expected_secret.as_ref().len(), algorithm.shared_secret_size);
762
        }
763
    }
764
765
    #[test]
    fn test_decapsulation_key_zeroed_bytes() {
        // Documents how a `DecapsulationKey` built from an all-zero byte string of the
        // correct length (`decapsulate_key_size()`) behaves:
        //   * construction succeeds and `key_bytes()` round-trips the input exactly;
        //   * `encapsulation_key()` fails for a key constructed from raw bytes;
        //   * decapsulation is rejected.
        // Successful construction therefore does NOT mean the bytes form a usable key.
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let zeroed_bytes = vec![0u8; algorithm.decapsulate_key_size()];

            // Construction accepts correctly-sized bytes even though they are not a real key.
            let key = DecapsulationKey::new(algorithm, &zeroed_bytes).unwrap_or_else(|err| {
                panic!(
                    "DecapsulationKey::new should accept zeroed bytes of correct size for {:?}: {err:?}",
                    algorithm.id()
                )
            });

            // Serialization must return exactly the bytes the key was built from.
            let key_bytes = key
                .key_bytes()
                .expect("key_bytes() should succeed for key constructed from zeroed bytes");
            assert_eq!(key_bytes.as_ref(), zeroed_bytes.as_slice());

            // encapsulation_key() should fail since key was constructed from raw bytes.
            assert!(
                key.encapsulation_key().is_err(),
                "encapsulation_key() should fail for key constructed from raw bytes"
            );

            // Obtain a well-formed ciphertext from a properly generated key pair, so that
            // any failure below is attributable to the zeroed key, not the ciphertext.
            let valid_key = DecapsulationKey::generate(algorithm).unwrap();
            let valid_pub_key = valid_key.encapsulation_key().unwrap();
            let (ciphertext, _) = valid_pub_key.encapsulate().unwrap();

            // Decapsulating with the zeroed key must be rejected rather than yielding a
            // meaningless shared secret. NOTE(review): presumably this is the FIPS 203
            // decapsulation-key input check (the embedded hash of the encapsulation key
            // does not match); confirm against AWS-LC.
            let decapsulate_result = key.decapsulate(Ciphertext::from(ciphertext.as_ref()));
            assert!(
                decapsulate_result.is_err(),
                "decapsulate should fail with invalid (zeroed) key material for {:?}",
                algorithm.id()
            );
        }
    }
815
816
    #[test]
    fn test_cross_algorithm_key_rejection() {
        // Key bytes produced under one ML-KEM parameter set must be rejected by every
        // other parameter set. The parameter sets used here have distinct decapsulation
        // key sizes, so the rejection is expected to be a size error whose kind depends
        // on whether the supplied bytes are shorter or longer than the target expects.
        let algorithms = [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024];

        for &source_alg in &algorithms {
            let key = DecapsulationKey::generate(source_alg).unwrap();
            let key_bytes = key.key_bytes().unwrap();
            let source_size = source_alg.decapsulate_key_size();

            for &target_alg in &algorithms {
                let result = DecapsulationKey::new(target_alg, key_bytes.as_ref());

                if source_alg.id() == target_alg.id() {
                    // Same algorithm should succeed.
                    assert!(
                        result.is_ok(),
                        "Same algorithm should accept its own key bytes"
                    );
                    continue;
                }

                // Different algorithm should fail due to size mismatch.
                let err = match result {
                    Ok(_) => panic!(
                        "Algorithm {:?} should reject key bytes from {:?}",
                        target_alg.id(),
                        source_alg.id()
                    ),
                    Err(err) => err,
                };

                // Verify the error is size-related. `Equal` is handled explicitly so that
                // two distinct algorithms sharing a key size is reported instead of being
                // silently checked against `too_large`.
                let target_size = target_alg.decapsulate_key_size();
                match source_size.cmp(&target_size) {
                    core::cmp::Ordering::Less => assert_eq!(
                        err,
                        KeyRejected::too_small(),
                        "Smaller key should be rejected as too_small"
                    ),
                    core::cmp::Ordering::Greater => assert_eq!(
                        err,
                        KeyRejected::too_large(),
                        "Larger key should be rejected as too_large"
                    ),
                    core::cmp::Ordering::Equal => panic!(
                        "Algorithms {:?} and {:?} unexpectedly share a decapsulation key size",
                        source_alg.id(),
                        target_alg.id()
                    ),
                }
            }
        }

        // Also test EncapsulationKey cross-algorithm rejection for completeness.
        for &source_alg in &algorithms {
            let decap_key = DecapsulationKey::generate(source_alg).unwrap();
            let encap_key = decap_key.encapsulation_key().unwrap();
            let key_bytes = encap_key.key_bytes().unwrap();

            for &target_alg in &algorithms {
                let result = EncapsulationKey::new(target_alg, key_bytes.as_ref());

                if source_alg.id() == target_alg.id() {
                    assert!(
                        result.is_ok(),
                        "Same algorithm should accept its own encapsulation key bytes"
                    );
                } else {
                    assert!(
                        result.is_err(),
                        "Algorithm {:?} should reject encapsulation key bytes from {:?}",
                        target_alg.id(),
                        source_alg.id()
                    );
                }
            }
        }
    }
890
891
    #[test]
    fn test_debug_fmt() {
        // `Debug` for both key types must name the algorithm and elide key material
        // (the trailing `..` in the expected output).
        let decapsulation_key =
            DecapsulationKey::generate(&ML_KEM_512).expect("successful generation");
        let decapsulation_debug = format!("{decapsulation_key:?}");
        assert_eq!(
            decapsulation_debug,
            "DecapsulationKey { algorithm: MlKem512, .. }"
        );

        let encapsulation_key = decapsulation_key
            .encapsulation_key()
            .expect("public key retrievable");
        let encapsulation_debug = format!("{encapsulation_key:?}");
        assert_eq!(
            encapsulation_debug,
            "EncapsulationKey { algorithm: MlKem512, .. }"
        );
    }
906
}